# Train Baseline UNET model with Single Satellite Image

In [15]:
import os
from time import time

In [16]:
import numpy as np
import pandas as pd
import seaborn as sns
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm

In [17]:
sns.set()

In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
from biomasstry.datasets import Sentinel2

In [20]:
sen2dataset = Sentinel2()

In [21]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [22]:
# split
torch.manual_seed(0)
train_frac = 0.8
train_samples = round(train_frac * len(sen2dataset))
val_samples = round((1 - train_frac) * len(sen2dataset))

train_dataset, val_dataset = random_split(sen2dataset, [train_samples, val_samples])
print(f"Train samples: {len(train_dataset)} "
      f"Val. samples: {len(val_dataset)}")

Train samples: 6951 Val. samples: 1738


In [23]:
# Model
img_data = train_dataset[0]['image']
in_channels = img_data.shape[0]
print(f'# input channels: {in_channels}')
print(f"Image shape: {img_data.shape}")

model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None, # 'imagenet' weights don't seem to help so start clean 
    in_channels=in_channels,                 
    classes=1,                     
).to(device)

# input channels: 10
Image shape: torch.Size([10, 256, 256])


In [24]:
# Loss and Optimizer
loss_module = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

In [25]:
# Train and Validation Loops
def train_loop(dataloader, model, loss_fn, optimizer):
    train_metrics = []
    
    print('Training')
    for ix, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        X = batch['image'].to(device)
        y = batch['target'].to(device)
        
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_metrics.append(np.round(np.sqrt(loss.item()), 5))
            
    return train_metrics

In [26]:
def valid_loop(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    valid_loss = 0
    valid_metrics = {}

    print('Validation')
    with torch.no_grad():
        for batch in tqdm(dataloader, total=num_batches):
            X = batch['image'].to(device)
            y = batch['target'].to(device)
            
            pred = model(X)
            valid_loss += loss_fn(pred, y).item()
            
    valid_loss /= num_batches
    valid_rmse = np.round(np.sqrt(valid_loss), 5)
    print(f"Validation Error: \n RMSE: {valid_rmse:>8f} \n")
    return valid_rmse

In [27]:
def run_training(model, loss_module, optimizer, train_dataloader, val_dataloader, save_path, n_epochs=10):
    min_valid_metric = np.inf
    train_metrics = []
    valid_metrics = []
    
    total_train_time = 0
    total_val_time = 0

    for ix in range(n_epochs):
        print(f"\n-------------------------------\nEpoch {ix+1}")
        start = time()
        train_metrics_epoch = train_loop(train_dataloader, model, loss_module, optimizer)
        end = time()
        train_time = end - start
        total_train_time += train_time
        train_metrics.extend(train_metrics_epoch)
        
        start = time()
        valid_metrics_epoch = valid_loop(val_dataloader, model, loss_module)
        end = time()
        val_time = end - start
        total_val_time += val_time
        valid_metrics.append((len(train_metrics), valid_metrics_epoch))

        # check validation score, if improved then save model
        if min_valid_metric > valid_metrics_epoch:
            print(f'Validation RMSE Decreased({min_valid_metric:.6f}--->{valid_metrics_epoch:.6f}) \t Saving The Model')
            min_valid_metric = valid_metrics_epoch

            # Saving State Dict
            torch.save(model.state_dict(), save_path)
        print(f"Train time: {train_time}. Validation time: {val_time}")
    print("Done!")
    print(f"Total train time: {total_train_time} s. Avg. time per epoch: {total_train_time / n_epochs}")
    print(f"Total val time: {total_val_time} s. Avg. time per epoch: {total_val_time / n_epochs}")
    train_metrics_zipped = list(zip(np.arange(0, len(train_metrics)), train_metrics))
    
    return {'training': train_metrics_zipped, 'validation': valid_metrics}

## Experiment with `num_workers` and `batch_size` for tuning `DataLoader` Throughput

In [28]:
# DataLoaders
# num_workers = 4
# batch_size = 64  # Note: training speed is sensitive to memory usage
                 # set this as high as you can without significantly slowing down training time 

dir_saved_models = "../artifacts"
# Expt. with num_workers and batch_size
timing = []
for batch_size in [64]:
    print(f"Batch size: {batch_size}")
    for num_workers in [6]:
        print(f"Number of workers = {num_workers}")
        train_dataloader = DataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    pin_memory=True
                                    )

        val_dataloader = DataLoader(val_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=num_workers,
                                    pin_memory=True
                                )

        save_file = f"UNET_resnet50_10bandS2Apr_batch_AGBMLinear_1epoch_08DEC.pt"
        save_path = os.path.join(dir_saved_models, save_file)
        # Kickoff training
        n_epochs = 10
        # start = time()
        metrics = run_training(model=model,
                            loss_module=loss_module,
                            optimizer=optimizer,
                            train_dataloader=train_dataloader,
                            val_dataloader=val_dataloader,
                            save_path=save_path,
                            n_epochs=n_epochs)
        # epoch_time = time() - start
        # timing.append((num_workers, batch_size, epoch_time))
        # print(f"time for one epoch = {epoch_time}")

Batch size: 64
Number of workers = 6

-------------------------------
Epoch 1
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 52.592200 

Validation RMSE Decreased(inf--->52.592200) 	 Saving The Model
Train time: 114.13965463638306. Validation time: 20.567407369613647

-------------------------------
Epoch 2
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 51.302870 

Validation RMSE Decreased(52.592200--->51.302870) 	 Saving The Model
Train time: 115.15396451950073. Validation time: 24.65321135520935

-------------------------------
Epoch 3
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 49.778640 

Validation RMSE Decreased(51.302870--->49.778640) 	 Saving The Model
Train time: 115.72116589546204. Validation time: 20.784033060073853

-------------------------------
Epoch 4
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 49.052950 

Validation RMSE Decreased(49.778640--->49.052950) 	 Saving The Model
Train time: 116.42943787574768. Validation time: 21.07456374168396

-------------------------------
Epoch 5
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 48.856960 

Validation RMSE Decreased(49.052950--->48.856960) 	 Saving The Model
Train time: 116.56645798683167. Validation time: 20.7010817527771

-------------------------------
Epoch 6
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 48.622590 

Validation RMSE Decreased(48.856960--->48.622590) 	 Saving The Model
Train time: 116.66453623771667. Validation time: 20.856730699539185

-------------------------------
Epoch 7
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 47.974430 

Validation RMSE Decreased(48.622590--->47.974430) 	 Saving The Model
Train time: 116.04930424690247. Validation time: 20.74225664138794

-------------------------------
Epoch 8
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 47.653650 

Validation RMSE Decreased(47.974430--->47.653650) 	 Saving The Model
Train time: 116.13060021400452. Validation time: 21.036286115646362

-------------------------------
Epoch 9
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 47.049350 

Validation RMSE Decreased(47.653650--->47.049350) 	 Saving The Model
Train time: 115.4872636795044. Validation time: 22.070374011993408

-------------------------------
Epoch 10
Training


  0%|          | 0/109 [00:00<?, ?it/s]

Validation


  0%|          | 0/28 [00:00<?, ?it/s]

Validation Error: 
 RMSE: 47.059340 

Train time: 115.53460574150085. Validation time: 20.003660678863525
Done!
Total train time: 1157.876991033554 s. Avg. time per epoch: 115.78769910335541
Total val time: 212.48960542678833 s. Avg. time per epoch: 21.248960542678834


In [29]:
# sns.catplot(data=timing_df, x="workers", y="time", hue="batch", kind="bar")