# Use HuggingFace Accelerate to Train Model on Temporal (monthly) Sentinel Data

## Imports

In [1]:
import multiprocessing as mp
from time import time

In [2]:
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
from biomasstry.datasets import TemporalSentinel2Dataset, TemporalSentinel1Dataset
from biomasstry.models import TemporalSentinelModel, UTAE
# from biomasstry.models.utils import run_training
import numpy as np
import pandas as pd
from pynvml import *
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from transformers import TrainingArguments, Trainer, logging
from tqdm.auto import tqdm

In [3]:
# `logging` is the module imported from `transformers` above (not the stdlib one);
# this restricts the library's log output to errors.
logging.set_verbosity_error()

In [4]:
# Ask the multiprocessing fork server to import torch up front. This only matters
# when the "forkserver" start method is in use and must run before that server starts.
mp.set_forkserver_preload(["torch"])

## Utility Functions

In [5]:
# Report the PyTorch, CUDA and cuDNN versions of this environment, one per line.
for version_value in (torch.__version__,
                      torch.version.cuda,
                      torch.backends.cudnn.version()):
    print(version_value)

1.12.0+cu116
11.6
8302


In [6]:
# Utility functions for printing GPU utilization
def print_gpu_utilization(device_index: int = 0):
    """Print the memory currently in use on one GPU, as reported by NVML.

    device_index: int
        NVML index of the GPU to query. Defaults to 0, the device the
        notebook has always reported on.
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        # `info.used` is in bytes; integer-divide down to mebibytes for display.
        print(f"GPU memory occupied: {info.used//1024**2} MB.")
    finally:
        # Pair every nvmlInit() with a shutdown so repeated calls (this helper is
        # called many times while profiling) do not keep piling up NVML references.
        nvmlShutdown()


def print_summary(result):
    """Print the run time and throughput stored in `result.metrics`, then GPU memory use.

    result:
        Object exposing a `metrics` mapping with the keys 'train_runtime' and
        'train_samples_per_second' (presumably the value returned by
        `Trainer.train()` -- confirm against the caller).
    """
    stats = result.metrics
    runtime_line = f"Time: {stats['train_runtime']:.2f}"
    print(runtime_line)
    throughput_line = f"Samples/second: {stats['train_samples_per_second']:.2f}"
    print(throughput_line)
    print_gpu_utilization()

In [7]:
# Baseline GPU memory reading taken before any dataset or model has been created.
print_gpu_utilization()

GPU memory occupied: 107 MB.


## Dataset and DataLoaders

In [8]:
def get_dataloaders(dataset: str, batch_size: int=8, num_workers: int=0):
    """Build the train and eval DataLoaders for one of the Sentinel datasets.

    The selected dataset is split 80/20 into train and eval subsets with
    `random_split`; the train loader shuffles, the eval loader does not.

    dataset: str
        Dataset identifier. Must be one of "Sentinel-1A", "Sentinel-1D" or
        "Sentinel-2all" (the match is case-sensitive).
    batch_size: int
        Batch size used by both loaders.
    num_workers: int
        Forwarded to both DataLoaders as `num_workers`.

    Returns a `(train_dataloader, eval_dataloader)` tuple. If `dataset` is not
    recognized, a message is printed and `(None, None)` is returned.
    """
    # Switch to True to read directly from S3. When False an empty data_url is
    # passed and the data is assumed to be mounted under '/datasets/biomassters'.
    use_s3_directly = False
    data_url = "s3://drivendata-competition-biomassters-public-us" if use_s3_directly else ""

    # One lazy constructor per identifier, so only the requested dataset is built.
    dataset_factories = {
        # Sentinel-1 Ascending only
        "Sentinel-1A": lambda: TemporalSentinel1Dataset(data_url=data_url, bands=["VVA", "VHA"]),
        # Sentinel-1 Descending only
        "Sentinel-1D": lambda: TemporalSentinel1Dataset(data_url=data_url, bands=["VVD", "VHD"]),
        "Sentinel-2all": lambda: TemporalSentinel2Dataset(data_url=data_url),
    }
    make_dataset = dataset_factories.get(dataset)
    if make_dataset is None:
        print("Unrecognized dataset identifier. Must be one of 'Sentinel-1A', 'Sentinel-1D' or 'Sentinel-2all'")
        return None, None
    full_dataset = make_dataset()

    # 80% of the samples for training, the remainder for evaluation.
    n_total = len(full_dataset)
    n_train = int(0.8 * n_total)
    n_eval = n_total - n_train
    train_subset, eval_subset = random_split(full_dataset, [n_train, n_eval])

    print(f"Train samples: {len(train_subset)} Val. samples: {len(eval_subset)}")

    # Options shared by both loaders; only `shuffle` differs between them.
    shared_loader_options = dict(batch_size=batch_size,
                                 pin_memory=True,
                                 num_workers=num_workers)
    train_loader = DataLoader(train_subset, shuffle=True, **shared_loader_options)
    eval_loader = DataLoader(eval_subset, shuffle=False, **shared_loader_options)

    return train_loader, eval_loader

## Training Loop

In [9]:
def training_loop(dataset: str,
                  mixed_precision: str="fp16",
                  seed: int=123,
                  batch_size: int=8,
                  gradient_accumulation_steps: int=4,
                  nb_epochs=2,
                  train_mode: str=""
    ):
    """Main Training and Evaluation Loop to be called by accelerator.notebook_launcher().

    dataset: str
        Dataset identifier understood by `get_dataloaders`
        ("Sentinel-1A", "Sentinel-1D" or "Sentinel-2all").
    mixed_precision: str
        Mixed-precision mode handed to `Accelerator`.
    seed: int
        Seed passed to `set_seed`.
    batch_size: int
        Per-step batch size for both DataLoaders.
    gradient_accumulation_steps: int
        Number of batches over which gradients are accumulated.
    nb_epochs: int
        Number of epochs to run.
    train_mode: str
        "tune" loads pre-trained UTAE weights and replaces the output block,
        "resume" loads a previously saved state dict, anything else trains
        a fresh model.

    Reads these notebook-level globals: `pretrained_weights_path`,
    `saved_state_path`, `artifacts_dir`, `date`, `model_name` and
    `best_model_path`. A checkpoint is written after every epoch, and the
    best-scoring weights (lowest validation RMSE) are written to
    `best_model_path`.

    Raises ValueError if `get_dataloaders` does not recognize `dataset`.
    """
    print(f"Args: {mixed_precision}, {seed}, {batch_size}, "
          f"{gradient_accumulation_steps}, {nb_epochs}, {train_mode}")

    # Set random seed
    set_seed(seed)

    # Initialize Accelerator
    accelerator = Accelerator(mixed_precision=mixed_precision,
        gradient_accumulation_steps=gradient_accumulation_steps)

    # Build DataLoaders
    train_dataloader, eval_dataloader = get_dataloaders(dataset, batch_size=batch_size)
    if train_dataloader is None or eval_dataloader is None:
        # get_dataloaders signals an unknown identifier with (None, None); stop here
        # with a clear message instead of failing later while iterating None.
        raise ValueError(f"Unrecognized dataset identifier: {dataset!r}")

    # Number of input channels: 2 for either Sentinel-1 orbit, 10 otherwise (Sentinel-2).
    if dataset in ("Sentinel-1A", "Sentinel-1D"):
        input_nc = 2
    else:
        input_nc = 10

    # Create model
    if train_mode == "tune":
        # Load the checkpoint onto the CPU; accelerator.prepare() moves the model to
        # the right device afterwards. `map_location` belongs to torch.load, not to
        # load_state_dict, and the model must be built with real (non-meta) tensors
        # for the weights to actually be copied in.
        saved_dict = torch.load(pretrained_weights_path, map_location="cpu")
        # Initialize the original model & load pre-trained weights.
        # This branch always builds a 10-channel model, regardless of `input_nc`.
        model = UTAE(10, out_conv=[32, 20])
        model.load_state_dict(saved_dict["state_dict"])
        # NOTE(review): `ConvBlock` is not imported anywhere in the visible notebook;
        # import it from wherever UTAE's building blocks live before using "tune".
        model.out_conv = ConvBlock([32, 32, 1], padding_mode="reflect")  # Modify the last layer
        lr = 0.001
    else:
        model = UTAE(input_nc)  # modify output layer to predict AGBM
        lr = 0.02
        if train_mode == "resume":
            state_dict = torch.load(saved_state_path, map_location="cpu")
            model.load_state_dict(state_dict)

    loss_function = nn.MSELoss(reduction='mean')  # Loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Optimizer

    # Prepare everything to use accelerator
    # Maintain order while unpacking
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model,
                                                                optimizer,
                                                                train_dataloader,
                                                                eval_dataloader)
    min_valid_metric = np.inf
    save_path = artifacts_dir + (f"/{date}_{model_name}_{dataset}_B"
        f"{batch_size * gradient_accumulation_steps}.pt")

    # Training loop
    for i in tqdm(range(nb_epochs), disable=not accelerator.is_local_main_process):
        accelerator.print(f"Epoch {i+1}")
        epoch_start = time()
        model.train()  # validation below switches to eval mode, so reset every epoch
        for b, batch in enumerate(tqdm(train_dataloader, disable=not accelerator.is_local_main_process)):
            inputs, targets, chip_id = batch
            # accumulate() lets the Accelerator decide on which batches the optimizer
            # step really happens, honouring gradient_accumulation_steps.
            with accelerator.accumulate(model):
                outputs = model(inputs)
                loss = loss_function(outputs, targets)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            if b % 10 == 0:
                print(f"Batch {b+1}. Chip ID: {chip_id}. Input size: {inputs.size()}. Output size: {targets.size()}.")

        epoch_end = time()
        accelerator.print(f"  Training time: {epoch_end - epoch_start}")

        # Save Model State Dict after each epoch in order to continue training later
        unwrap_model = accelerator.unwrap_model(model)  # Unwrap the Accelerator model
        train_model_path = save_path[:-3] + f"_E{i+1}.pt"  # strip ".pt", append epoch tag
        accelerator.save(unwrap_model.state_dict(), train_model_path)
        accelerator.print(f"  Model file path: {train_model_path}")

        # Validation Loop
        model.eval()
        val_loss = 0.0
        num_elements = 0
        for batch in tqdm(eval_dataloader, disable=not accelerator.is_local_main_process):
            inputs, targets, _ = batch
            with torch.no_grad():
                predictions = model(inputs)
            # Gather all predictions and targets
            all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))
            batch_elements = all_predictions.shape[0]
            num_elements += batch_elements
            # The loss is a per-batch mean; weight it by the batch's sample count so
            # that dividing by num_elements below yields the true mean over all samples.
            val_loss += loss_function(all_predictions, all_targets).item() * batch_elements

        if num_elements:  # guard against an empty evaluation set
            val_loss /= num_elements
        val_rmse = np.round(np.sqrt(val_loss), 5)
        accelerator.print(f"  Validation RMSE: {val_rmse:>8f}")
        # check validation score, if improved then save model
        if min_valid_metric > val_rmse:
            accelerator.print(f"  Validation RMSE Decreased({min_valid_metric:.6f}--->{val_rmse:.6f})")
            min_valid_metric = val_rmse

            # Saving Model State Dict
            unwrap_model = accelerator.unwrap_model(model)  # Unwrap the Accelerator model
            accelerator.save(unwrap_model.state_dict(), best_model_path)
            accelerator.print(f"  Best Model file path: {best_model_path}")

In [10]:
# --- Run configuration -------------------------------------------------------
dataset = "Sentinel-2all"
mixed_precision = "fp16"
seed = 123
batch_size = 4
gradient_accumulation_steps = 1
nb_epochs = 10
train_mode = "tune"

# --- Artifact locations (read as globals inside training_loop) ----------------
artifacts_dir = "/notebooks/artifacts"
model_name = "UTAE"
date = "20230118"
pretrained_weights_path = f"{artifacts_dir}/pretrained_utae/f1model.pth.tar"  # for fine tuning
saved_state_path = f"{artifacts_dir}/20230112_UTAE_S2_B32_E20.pt"  # for resuming training

# Checkpoint name encodes date, model, dataset and the effective batch size.
save_path = f"{artifacts_dir}/{date}_{model_name}_{dataset}_B{batch_size * gradient_accumulation_steps}.pt"
best_model_path = f"{save_path[:-3]}_BEST.pt"

# Notebook Launcher for distributed training
# NOTE(review): `train_mode` is not part of `train_args`, so training_loop runs with
# its default train_mode="" and the "tune" setting above is never forwarded (the
# "Args:" line in the output ends with an empty value). Confirm the intended mode.
train_args = (
    dataset,
    mixed_precision,
    seed,
    batch_size,
    gradient_accumulation_steps,
    nb_epochs,
)
notebook_launcher(training_loop, args=train_args, num_processes=1)

Launching training on one GPU.
Args: fp16, 123, 4, 1, 10, 
Train samples: 6951 Val. samples: 1738


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1


  0%|          | 0/1738 [00:00<?, ?it/s]

Batch 1. Chip ID: ['fd839d71', 'a3aac2ce', 'e7325413', '2409a7cf']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 11. Chip ID: ['fe7f68eb', '2f5131c6', '57039896', 'dc8dc18d']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 21. Chip ID: ['b3c29fad', 'e935c106', '4e54622e', 'a759e918']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 31. Chip ID: ['0c1b5ede', '1925f7d8', '52d4b426', 'd4475d0a']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 41. Chip ID: ['2b835255', 'b4d9dfcc', '8c94e3ad', '671cc3e6']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 51. Chip ID: ['ce9d0a64', 'c0d5cc9c', 'd4061b6d', '3ff95cac']. Input size: torch.Size([4, 5, 10, 256, 256]). Output size: torch.Size([4, 1, 256, 256]).
Batch 61. Chip ID: ['35cb0400', '88af4eb2', '56ac3ec6

KeyboardInterrupt: 

root@n47fvylijb:/notebooks# accelerate launch notebooks/distributed_training_utae.py 
1.12.0+cu116
11.6
8302
GPU memory occupied: 107 MB.
Args: fp16, 123, 4, 1, 10, 
Before DataLoaders
  time    PID  rss     pss     uss     shared    shared_file
------  -----  ------  ------  ------  --------  -------------
 61239    473  723.6M  554.6M  393.6M  330.0M    330.0M
Train samples: 6951 Val. samples: 1738
After DataLoaders
  time    PID  rss     pss     uss     shared    shared_file
------  -----  ------  ------  ------  --------  -------------
 61239    473  756.7M  587.0M  425.4M  331.2M    331.2M
After Model
  time    PID  rss     pss     uss     shared    shared_file
------  -----  ------  ------  ------  --------  -------------
 61239    473  762.2M  592.5M  430.8M  331.5M    331.5M
  0%|                                                                                            | 0/10 [00:00<?, ?it/s]Epoch 1
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                                       | 0/1738 [00:00<?, ?it/s]
------  -----  -----  -----  -----  --------  -------------
 61243    473  2.5G   2.4G   2.2G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 10/1738 [00:08<24:10,  1.19it/s]
------  -----  -----  -----  -----  --------  -------------
 61251    473  2.7G   2.5G   2.4G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 20/1738 [00:16<22:01,  1.30it/s]
------  -----  -----  -----  -----  --------  -------------
 61259    473  2.8G   2.6G   2.4G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 30/1738 [00:24<20:59,  1.36it/s]
------  -----  -----  -----  -----  --------  -------------
 61267    473  2.8G   2.6G   2.5G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 40/1738 [00:31<21:10,  1.34it/s]
------  -----  -----  -----  -----  --------  -------------
 61274    473  2.9G   2.7G   2.6G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 50/1738 [00:39<21:23,  1.32it/s]
------  -----  -----  -----  -----  --------  -------------
 61282    473  2.9G   2.8G   2.6G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 60/1738 [00:47<21:32,  1.30it/s]
------  -----  -----  -----  -----  --------  -------------
 61289    473  3.0G   2.8G   2.7G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 70/1738 [00:54<20:33,  1.35it/s]
------  -----  -----  -----  -----  --------  -------------
 61297    473  3.1G   2.9G   2.7G   332.8M    332.8M
                                                                                                                         time    PID  rss    pss    uss    shared    shared_file                              | 80/1738 [01:02<21:15,  1.30it/s]
------  -----  -----  -----  -----  --------  -------------
 61305    473  3.1G   3.0G   2.8G   332.8M    332.8M
                                                                                                                       ^