# Use HuggingFace Accelerate to Train a Stacked ResNet Model on Temporal (monthly) Sentinel Data

## Imports

In [1]:
from time import time

In [2]:
from accelerate import Accelerator
from biomasstry.datasets import TemporalSentinel2Dataset, TemporalSentinel1Dataset
from biomasstry.models import TemporalSentinelModel, UTAE
# from biomasstry.models.utils import run_training
import numpy as np
import pandas as pd
from pynvml import *
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from transformers import TrainingArguments, Trainer, logging
from tqdm.notebook import tqdm

In [3]:
logging.set_verbosity_error()

## Utility Functions

In [4]:
# Utility functions for printing GPU utilization
def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

In [5]:
print_gpu_utilization()

GPU memory occupied: 107 MB.


## Dataset

In [6]:
S3_DIRECT = False  # Access S3 directly or as a mounted data source
USE_SENTINEL1 = False
if S3_DIRECT:
    data_url="s3://drivendata-competition-biomassters-public-us"
else:
    data_url = ""
if USE_SENTINEL1:
    ds = TemporalSentinel1Dataset(data_url=data_url)
    input_nc = 4
    n_tsamples = 12
else:
    ds = TemporalSentinel2Dataset(data_url=data_url)
    input_nc = 10
    n_tsamples = 5

In [7]:
torch.manual_seed(0)
train_size = int(0.8*len(ds))
valid_size = len(ds) - train_size
train_set, val_set = random_split(ds, [train_size, valid_size])
print(f"Train samples: {len(train_set)} "
      f"Val. samples: {len(val_set)}")

Train samples: 6951 Val. samples: 1738


## Model

In [8]:
# model = TemporalSentinelModel(
#     n_tsamples=n_tsamples, 
#     input_nc=input_nc,
#     output_nc=1,
# )  # .to(device)

In [9]:
model = UTAE(input_nc)

In [10]:
loss_function = nn.MSELoss(reduction='mean')  # .to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

## Training Arguments

In [11]:
default_args = {
    "output_dir": "/notebooks/artifacts",
    "overwrite_output_dir": "True",
    "evaluation_strategy": "steps",
    "num_train_epochs": 10,
    "log_level": "error",
    "report_to": "none",
}

In [12]:
training_args = TrainingArguments(
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    **default_args,
)

## DataLoaders

In [13]:
# DataLoaders
num_workers = 6
train_dataloader = DataLoader(train_set,
                      batch_size=training_args.per_device_train_batch_size,
                      shuffle=True,
                      pin_memory=True,
                      num_workers=num_workers)
eval_dataloader = DataLoader(val_set,
                    batch_size=training_args.per_device_train_batch_size,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=num_workers)

## Prepare

In [14]:
accel_model = Trainer(model, args=training_args, train_dataset=train_set)

# HuggingFace Accelerator with Gradient Accumulation
accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision='fp16')

accel_model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(accel_model,
                                                                optimizer,
                                                                train_dataloader,
                                                                eval_dataloader)

In [15]:
artifacts_dir = "/notebooks/artifacts"
if USE_SENTINEL1:
    model_name = "UTAE_S1"
else:
    model_name = "UTAE_S2"
nb_epochs = 10
date = "20230109"
save_path = artifacts_dir + (f"/{date}_{model_name}_B"
    f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
    f"_E{nb_epochs}.pt")
print(f"Model file path: {save_path}")

Model file path: /notebooks/artifacts/20230109_UTAE_S2_B32_E10.pt


## Training and Evaluation Loop

In [16]:
num_batches = len(eval_dataloader)
train_metrics = []
val_metrics = []
min_valid_metric = np.inf
for i in tqdm(range(nb_epochs)):
    train_metrics_epoch = []
    epoch_start = time()
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        inputs = batch["image"]
        targets = batch["target"]
        inputs = torch.stack(inputs, dim=1)
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        train_metrics_epoch.append(np.round(np.sqrt(loss.item()), 5))

    epoch_end = time()
    print(f"Epoch training time: {epoch_end - epoch_start}")
    # Validation Loop
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            inputs = batch["image"]
            targets = batch["target"]
            inputs = torch.stack(inputs, dim=1)
            predictions = model(inputs)
            # Gather all predictions and targets
            all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))
            val_loss += loss_function(predictions, targets).item()

    val_loss /= num_batches
    val_rmse = np.round(np.sqrt(val_loss), 5)
    print(f"Validation Error: \n RMSE: {val_rmse:>8f} \n")
    train_metrics.extend(train_metrics_epoch)
    val_metrics.append((len(train_metrics), val_rmse))
    # check validation score, if improved then save model
    if min_valid_metric > val_rmse:
        print(f'Validation RMSE Decreased({min_valid_metric:.6f}--->{val_rmse:.6f}) \t Saving The Model')
        min_valid_metric = val_rmse

        # Saving State Dict
        torch.save(model.state_dict(), save_path)


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1071.1505773067474


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 75.412090 

Validation RMSE Decreased(inf--->75.412090) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1104.5562286376953


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 60.441380 

Validation RMSE Decreased(75.412090--->60.441380) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1108.6039881706238


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 50.115050 

Validation RMSE Decreased(60.441380--->50.115050) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1103.446296453476


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 44.618340 

Validation RMSE Decreased(50.115050--->44.618340) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1143.2115604877472


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 41.978310 

Validation RMSE Decreased(44.618340--->41.978310) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1154.649334192276


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 41.715980 

Validation RMSE Decreased(41.978310--->41.715980) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1107.8819320201874


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 41.145460 

Validation RMSE Decreased(41.715980--->41.145460) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1133.770623922348


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 40.956760 

Validation RMSE Decreased(41.145460--->40.956760) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1106.3496730327606


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 40.581980 

Validation RMSE Decreased(40.956760--->40.581980) 	 Saving The Model


  0%|          | 0/869 [00:00<?, ?it/s]

Epoch training time: 1117.100100517273


  0%|          | 0/218 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 40.630280 



In [17]:
##### Save the metrics to a file
train_metrics_zipped = list(zip(np.arange(0, len(train_metrics)), train_metrics))
metrics = {'training': train_metrics_zipped, 'validation': val_metrics}
train_metrics_df = pd.DataFrame(metrics['training'], columns=["step", "score"])
val_metrics_df = pd.DataFrame(metrics["validation"], columns=["step", "score"])
train_metrics_df.to_csv(artifacts_dir + "/train_metrics.csv")
val_metrics_df.to_csv(artifacts_dir + "/val_metrics.csv")