# Fine-tune UTAE model to predict AGBM from Temporal Sentinel-2 Images

## Imports

In [1]:
from time import time

In [2]:
from biomasstry.datasets import TemporalSentinel2Dataset, TemporalSentinel1Dataset
from biomasstry.models import TemporalSentinelModel, UTAE
from biomasstry.models.unet_tae import ConvBlock
# from biomasstry.models.utils import run_training
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from tqdm.notebook import tqdm

## Device Setup

In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


## Dataset

In [4]:
# Metadata: read the feature metadata table and collect the unique chip IDs
# of all rows whose `split` column is "train".
metadata_file = "/notebooks/data/metadata_parquet/features_metadata_slim.parquet"
metadata_df = pd.read_parquet(metadata_file)
is_train_row = metadata_df["split"] == "train"
chip_ids = metadata_df.loc[is_train_row, "chip_id"].unique().tolist()

In [5]:
# Shuffle positions into `chip_ids` and hold out the tail for evaluation.
# The permutation is drawn from a seeded generator so the train/eval partition
# is identical after a kernel restart; otherwise a checkpoint saved in one
# session could not be re-evaluated on the same held-out chips.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
split_rng = np.random.default_rng(SPLIT_SEED)
random_perm = split_rng.permutation(len(chip_ids))
cut = int(TRAIN_FRACTION * len(chip_ids))
train_split = random_perm[:cut]  # positions into chip_ids used for training
eval_split = random_perm[cut:]   # remaining positions used for validation

In [6]:
# Data location: either the public S3 bucket or a mounted data source, which
# is signalled to the dataset classes by an empty URL.
S3_DIRECT = False  # Access S3 directly or as a mounted data source
if S3_DIRECT:
    data_url = "s3://drivendata-competition-biomassters-public-us"
else:
    data_url = ""

datasets = ["Sentinel-1A",  # Sentinel-1 Ascending only
            "Sentinel-1D",  # Sentinel-1 Descending only
            "Sentinel-2all"]

# Select the input modality. Each branch also sets `input_nc` (passed to the
# model constructor later) and `n_tsamples`.
dataset = datasets[2]
if dataset == "Sentinel-1A":
    # NOTE(review): the Sentinel-1 branches only build `ds`; they never define
    # train_set / eval_set, which the DataLoader cells below require. Split
    # handling for Sentinel-1 still needs to be added before these branches
    # can be used end to end.
    ds = TemporalSentinel1Dataset(data_url=data_url, bands=["VVA", "VHA"])
    input_nc = 2
    n_tsamples = 6
elif dataset == "Sentinel-1D":
    # This branch previously re-tested "Sentinel-1A" and was unreachable.
    ds = TemporalSentinel1Dataset(data_url=data_url, bands=["VVD", "VHD"])
    input_nc = 2
    n_tsamples = 6
elif dataset == "Sentinel-2all":
    # Separate dataset objects for the training and evaluation chip IDs.
    train_set = TemporalSentinel2Dataset([chip_ids[i] for i in train_split],
                                        data_url=data_url)
    eval_set = TemporalSentinel2Dataset([chip_ids[i] for i in eval_split],
                                       data_url=data_url)
    input_nc = 10
    n_tsamples = 5
else:
    # Fail loudly instead of leaving input_nc / the datasets undefined.
    raise ValueError(f"Unknown dataset: {dataset!r}; expected one of {datasets}")

In [7]:
# Report how many chips ended up in each split.
n_train, n_eval = len(train_set), len(eval_set)
print(f"Train samples: {n_train} Val. samples: {n_eval}")

Train samples: 6951 Val. samples: 1738


## Model

In [8]:
# Build a UTAE, load the pretrained checkpoint into it, then replace its
# output head with a fresh ConvBlock before moving the model to `device`.
# The head is replaced only AFTER load_state_dict, so the checkpoint is loaded
# into the architecture it was constructed for (out_conv=[32, 20]).
artifacts_dir = "/notebooks/artifacts"
pretrained_weights_path = f"{artifacts_dir}/pretrained_utae/f1model.pth.tar"

saved_dict = torch.load(pretrained_weights_path, map_location=device)

model = UTAE(input_nc, out_conv=[32, 20])
model.load_state_dict(saved_dict["state_dict"])

# New head: presumably ends in one output channel for the regression target
# (the AGBM named in the title) — TODO confirm against ConvBlock.
model.out_conv = ConvBlock([32, 32, 1], padding_mode="reflect")
model = model.to(device)

In [9]:
# Mean-reduced MSE objective, optimised with Adam over every model parameter
# (the pretrained body as well as the replaced output head).
learning_rate = 0.001
loss_function = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Training Arguments

In [10]:
# Run configuration. The checkpoint file name encodes the date, model name,
# dataset, batch size and epoch count of this run.
batch_size = 4
nb_epochs = 10
model_name = "UTAE-pretrainedF1"
date = "20230118"

checkpoint_name = f"{date}_{model_name}_{dataset}_B{batch_size}_E{nb_epochs}.pt"
save_path = f"{artifacts_dir}/{checkpoint_name}"
print(f"Model file path: {save_path}")

Model file path: /notebooks/artifacts/20230118_UTAE-pretrainedF1_Sentinel-2all_B4_E10.pt


## DataLoaders

In [11]:
# DataLoaders: both loaders share batch size, pinned memory and worker count;
# only the training loader shuffles its samples.
num_workers = 6
loader_options = dict(batch_size=batch_size,
                      pin_memory=True,
                      num_workers=num_workers)
train_dataloader = DataLoader(train_set, shuffle=True, **loader_options)
eval_dataloader = DataLoader(eval_set, shuffle=False, **loader_options)

## Training and Evaluation Loop

In [12]:
# Fine-tuning loop. For each epoch:
#   1. train over every batch, recording the per-batch RMSE,
#   2. save a per-epoch checkpoint,
#   3. evaluate on the held-out set and keep a "_BEST" checkpoint whenever the
#      validation RMSE improves.
num_batches = len(eval_dataloader)
train_metrics = []  # per-step training RMSE (sqrt of the batch MSE), all epochs
val_metrics = []    # one (training steps completed so far, validation RMSE) per epoch
min_valid_metric = np.inf
checkpoint_stem = save_path[:-3]  # save_path without its trailing ".pt"

for i in tqdm(range(nb_epochs)):
    train_metrics_epoch = []
    epoch_start = time()

    # Training mode must be re-enabled every epoch because validation below
    # switches the model to eval mode.
    model.train()
    for batch in tqdm(train_dataloader):
        inputs = batch["image"].to(device)
        targets = batch["target"].to(device)
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_metrics_epoch.append(np.round(np.sqrt(loss.item()), 5))

    epoch_end = time()
    print(f"Epoch training time: {epoch_end - epoch_start}")

    # Saving State Dict after each epoch. The f-string prefix matters: without
    # it every epoch overwrote a single file literally named "..._Ep{i}.pt".
    torch.save(model.state_dict(), f"{checkpoint_stem}_Ep{i}.pt")

    # Validation Loop: eval mode makes layers such as dropout / batch norm use
    # their inference behaviour; no_grad skips building the autograd graph.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            inputs = batch["image"].to(device)
            targets = batch["target"].to(device)
            predictions = model(inputs)
            val_loss += loss_function(predictions, targets).item()

    # RMSE derived from the mean of the per-batch MSE values.
    val_loss /= num_batches
    val_rmse = np.round(np.sqrt(val_loss), 5)
    print(f"Validation Error: \n RMSE: {val_rmse:>8f} \n")
    train_metrics.extend(train_metrics_epoch)
    val_metrics.append((len(train_metrics), val_rmse))
    # check validation score, if improved then save model
    if min_valid_metric > val_rmse:
        print(f'Validation RMSE Decreased({min_valid_metric:.6f}--->{val_rmse:.6f}) \t Saving The Model')
        min_valid_metric = val_rmse

        # Saving State Dict
        torch.save(model.state_dict(), checkpoint_stem + "_BEST.pt")


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1738 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
##### Save the metrics to a file
# Training scores are paired with their 0-based step index; validation entries
# already carry the step count at which they were recorded.
train_steps = np.arange(0, len(train_metrics))
train_metrics_zipped = list(zip(train_steps, train_metrics))
metrics = {'training': train_metrics_zipped, 'validation': val_metrics}

metric_columns = ["step", "score"]
train_metrics_df = pd.DataFrame(metrics['training'], columns=metric_columns)
val_metrics_df = pd.DataFrame(metrics["validation"], columns=metric_columns)

train_metrics_df.to_csv(f"{artifacts_dir}/train_metrics.csv")
val_metrics_df.to_csv(f"{artifacts_dir}/val_metrics.csv")