# Train a Baseline U-Net Model with a Single Satellite Image (Sentinel-2)

In [1]:
import os
from time import time

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm

In [3]:
# Apply seaborn's default plotting theme (`sns.set` is an alias of `sns.set_theme`).
sns.set_theme()

In [4]:
# Automatically reload edited project modules (such as biomasstry) before each cell executes.
%load_ext autoreload
%autoreload 2

In [5]:
from biomasstry.datasets import Sentinel2

In [6]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
# Model
# The Sentinel2 dataset keeps only the first 10 raster channels (the cloud-coverage
# channel is left out), so the U-Net consumes 10 input bands.
in_channels = 10
print(f'# input channels: {in_channels}')

model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None,  # 'imagenet' weights don't seem to help, so start from scratch
    in_channels=in_channels,
    classes=1,  # one output channel, regressed with an MSE loss below
)
model = model.to(device)

# input channels: 10


In [8]:
# Loss and Optimizer
# Element-wise squared error averaged over the batch ('mean' reduction).
loss_module = nn.MSELoss(reduction='mean')
# Adam over every parameter of the U-Net built in the previous cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.02)

In [9]:
# Train and Validation Loops
def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch and return the per-batch RMSE values.

    Args:
        dataloader: sized iterable of dict batches with 'image' and 'target' tensors.
        model: torch.nn.Module being trained (updated in place).
        loss_fn: loss callable ``loss_fn(pred, target)``; assumed to be an MSE-style loss,
            so its square root is reported as an RMSE.
        optimizer: optimizer over ``model``'s parameters.
        device: device to move each batch to. Defaults to the device of the model's
            parameters (CPU if the model has none), so no global ``device`` is needed.

    Returns:
        list of numpy floats: ``sqrt(loss)`` of every batch, rounded to 5 decimals.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Make sure BatchNorm/Dropout use training behaviour, whatever mode the model
    # was left in by earlier evaluation code.
    model.train()
    train_metrics = []

    print('Training')
    for batch in tqdm(dataloader, total=len(dataloader)):
        X = batch['image'].to(device)
        y = batch['target'].to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_metrics.append(np.round(np.sqrt(loss.item()), 5))

    return train_metrics

In [10]:
def valid_loop(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` and return the RMSE rounded to 5 decimals.

    The model is put in eval mode for the pass, so BatchNorm uses its running statistics
    and does not update them from validation data. Afterwards it is restored to the mode
    it was in before the call.

    Args:
        dataloader: sized iterable of dict batches with 'image' and 'target' tensors.
        model: torch.nn.Module to evaluate.
        loss_fn: loss callable ``loss_fn(pred, target)``. It is assumed to be mean-reduced
            (this notebook passes ``nn.MSELoss(reduction='mean')``), so the square root of
            the averaged loss is an RMSE.
        device: device to move each batch to. Defaults to the device of the model's
            parameters (CPU if the model has none).

    Returns:
        Validation RMSE as a numpy float, rounded to 5 decimal places.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("valid_loop got an empty dataloader; cannot compute a validation RMSE")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    total_samples = 0
    was_training = model.training
    model.eval()

    print('Validation')
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, total=num_batches):
                X = batch['image'].to(device)
                y = batch['target'].to(device)
                n_in_batch = X.shape[0]

                pred = model(X)
                # Weight each batch mean by its size so a smaller final batch does not skew
                # the average (exact when every sample has the same number of elements).
                total_loss += loss_fn(pred, y).item() * n_in_batch
                total_samples += n_in_batch
    finally:
        # Restore the caller's mode (typically training) even if evaluation raised.
        model.train(was_training)

    valid_loss = total_loss / total_samples
    valid_rmse = np.round(np.sqrt(valid_loss), 5)
    print(f"Validation Error: \n RMSE: {valid_rmse:>8f} \n")
    return valid_rmse

In [11]:
def run_training(model, loss_module, optimizer, train_dataloader, val_dataloader, save_path, n_epochs=10):
    """Train for ``n_epochs`` epochs, validating after each and checkpointing the best model.

    Each epoch calls ``train_loop`` and then ``valid_loop``. Whenever the validation RMSE
    drops below the best value seen so far, ``model.state_dict()`` is saved to
    ``save_path``. Wall-clock training and validation times are printed.

    Args:
        model: torch.nn.Module to train.
        loss_module: loss passed through to the train/validation loops.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: batches for ``train_loop``.
        val_dataloader: batches for ``valid_loop``.
        save_path: checkpoint file path; its parent directory is created if missing.
        n_epochs: number of epochs to run (0 runs nothing).

    Returns:
        dict with
          'training': list of ``(step, rmse)`` for every training batch, step counting from 0;
          'validation': list of ``(step, rmse)`` per epoch, where step is the number of
                        training batches processed so far.
    """
    # Create the checkpoint directory up front so a missing directory does not make
    # the first torch.save fail after a whole epoch has already been spent.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    min_valid_metric = np.inf
    train_metrics = []
    valid_metrics = []

    total_train_time = 0
    total_val_time = 0

    for epoch in range(n_epochs):
        print(f"\n-------------------------------\nEpoch {epoch+1}")
        start = time()
        train_metrics_epoch = train_loop(train_dataloader, model, loss_module, optimizer)
        train_time = time() - start
        total_train_time += train_time
        train_metrics.extend(train_metrics_epoch)

        start = time()
        valid_metrics_epoch = valid_loop(val_dataloader, model, loss_module)
        val_time = time() - start
        total_val_time += val_time
        valid_metrics.append((len(train_metrics), valid_metrics_epoch))

        # check validation score, if improved then save model
        if min_valid_metric > valid_metrics_epoch:
            print(f'Validation RMSE Decreased({min_valid_metric:.6f}--->{valid_metrics_epoch:.6f}) \t Saving The Model')
            min_valid_metric = valid_metrics_epoch

            # Saving State Dict
            torch.save(model.state_dict(), save_path)
        print(f"Train time: {train_time}. Validation time: {val_time}")
    print("Done!")
    # Guard the per-epoch averages so n_epochs=0 does not raise ZeroDivisionError.
    epochs_for_avg = max(n_epochs, 1)
    print(f"Total train time: {total_train_time} s. Avg. time per epoch: {total_train_time / epochs_for_avg}")
    print(f"Total val time: {total_val_time} s. Avg. time per epoch: {total_val_time / epochs_for_avg}")

    return {'training': list(enumerate(train_metrics)), 'validation': valid_metrics}

## Experiment with `num_workers` and `batch_size` to Tune `DataLoader` Throughput

In [12]:
# Train a separate model for each of the months below
months = ["september", "october"]
train_frac = 0.8

torch.manual_seed(0)
dir_saved_models = "/notebooks/artifacts"
# Create the output directory up front; torch.save and to_csv fail if it is missing.
os.makedirs(dir_saved_models, exist_ok=True)
num_workers = 6
batch_size = 32  # Note: training speed is sensitive to memory usage
                 # set this as high as you can without significantly slowing down training time 
n_epochs = 25
learning_rate = 0.02  # same value as the optimizer cell above
split_seed = 0  # fixed seed for the train/val split, independent of the global RNG

# NOTE(review): the saved output of this cell shows Sentinel2 raising FileNotFoundError for a
# missing image (train_features/d81e03d4_S2_00.tif). Samples lacking an image for the
# requested month presumably need to be filtered out of the dataset -- TODO.
for month in months:
    sen2dataset = Sentinel2(month=month)

    # split: validation takes the remainder so the sizes always sum to len(sen2dataset).
    # A dedicated generator makes the split reproducible regardless of how much of the
    # global RNG earlier cells or model initialisation have consumed.
    train_samples = round(train_frac * len(sen2dataset))
    val_samples = len(sen2dataset) - train_samples
    train_dataset, val_dataset = random_split(
        sen2dataset,
        [train_samples, val_samples],
        generator=torch.Generator().manual_seed(split_seed),
    )
    print(f"Train samples: {len(train_dataset)} "
          f"Val. samples: {len(val_dataset)}")

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    val_dataloader = DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers,
                                pin_memory=True)

    # Build a fresh model and optimizer for every month. Reusing the global `model` /
    # `optimizer` would make each month continue from the previous month's weights and
    # Adam state instead of producing an independent per-month model.
    # The architecture mirrors the model cell above.
    model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights=None,  # 'imagenet' weights don't seem to help so start clean
        in_channels=in_channels,
        classes=1,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Derive the checkpoint name from the actual settings so it cannot drift out of sync.
    save_file = (f"UNET_resnet50_{in_channels}bandS2{month}_batch{batch_size}"
                 f"_AGBMLinear_{n_epochs}epoch_10DEC.pt")
    save_path = os.path.join(dir_saved_models, save_file)

    # Kickoff training
    metrics = run_training(model=model,
                           loss_module=loss_module,
                           optimizer=optimizer,
                           train_dataloader=train_dataloader,
                           val_dataloader=val_dataloader,
                           save_path=save_path,
                           n_epochs=n_epochs)
    # Save the metrics to a file
    train_metrics_df = pd.DataFrame(metrics['training'], columns=["step", "score"])
    val_metrics_df = pd.DataFrame(metrics["validation"], columns=["step", "score"])
    train_metrics_df.to_csv(os.path.join(dir_saved_models, f"unet_s2_{month}_train_metrics.csv"))
    val_metrics_df.to_csv(os.path.join(dir_saved_models, f"unet_s2_{month}_val_metrics.csv"))


Train samples: 6951 Val. samples: 1738

-------------------------------
Epoch 1
Training


  0%|          | 0/218 [00:00<?, ?it/s]

FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataset.py", line 290, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/notebooks/src/biomasstry/datasets/sentinel2.py", line 157, in __getitem__
    img_data = load_raster(img_path)[:10]  # only first 10 channels, leave out cloud coverage channel
  File "/notebooks/src/biomasstry/datasets/sentinel2.py", line 31, in load_raster
    with fsspec.open(file_url, **storage_options).open() as f:
  File "/usr/local/lib/python3.9/dist-packages/fsspec/core.py", line 135, in open
    return self.__enter__()
  File "/usr/local/lib/python3.9/dist-packages/fsspec/core.py", line 103, in __enter__
    f = self.fs.open(self.path, mode=mode)
  File "/usr/local/lib/python3.9/dist-packages/fsspec/spec.py", line 1106, in open
    f = self._open(
  File "/usr/local/lib/python3.9/dist-packages/fsspec/implementations/local.py", line 175, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/fsspec/implementations/local.py", line 273, in __init__
    self._open()
  File "/usr/local/lib/python3.9/dist-packages/fsspec/implementations/local.py", line 278, in _open
    self.f = open(self.path, mode=self.mode)
FileNotFoundError: [Errno 2] No such file or directory: '/datasets/biomassters/train_features/d81e03d4_S2_00.tif'
