# This notebook trains a UNet model
[UNet](https://arxiv.org/pdf/1505.04597.pdf) is best known for image segmentation,
where the task is to assign a class to each pixel in the input.
The UNet architecture consists of a contracting path that captures context
and a symmetric expanding path that enables precise localization.
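
To make the two paths concrete, here is a minimal two-level sketch in PyTorch. It is not the `UNet` class from `oil_seep_detection`; the name `TinyUNet`, the `width` parameter, and the layer choices are assumptions made for illustration only.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Two 3x3 convolutions; padding=1 keeps the spatial size unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(),
    )


class TinyUNet(nn.Module):
    """Hypothetical two-level UNet, for illustration only."""

    def __init__(self, in_channels: int = 1, out_channels: int = 8, width: int = 16):
        super().__init__()
        # Contracting path: convolve, then halve the resolution.
        self.enc1 = conv_block(in_channels, width)
        self.enc2 = conv_block(width, 2 * width)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(2 * width, 4 * width)
        # Expanding path: upsample, concatenate the skip connection, convolve.
        self.up2 = nn.ConvTranspose2d(4 * width, 2 * width, kernel_size=2, stride=2)
        self.dec2 = conv_block(4 * width, 2 * width)
        self.up1 = nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2)
        self.dec1 = conv_block(2 * width, width)
        # 1x1 convolution maps features to per-pixel class scores.
        self.head = nn.Conv2d(width, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        b = self.bottleneck(self.pool(e2))
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)


model = TinyUNet()
scores = model(torch.zeros(2, 1, 32, 32))  # height and width must be divisible by 4
```

The skip connections (`torch.cat`) are what let the expanding path recover fine spatial detail that pooling discards.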

## Loss functions

I use a weighted average of Dice loss and mean-squared error (MSE) loss for image segmentation.
Dice loss helps the model separate objects from the background.
MSE loss drives the per-pixel class assignment.
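
A minimal sketch of such a combined loss follows. The actual implementation lives in `oil_seep_detection` and may differ. The function name `dice_mse_loss`, the softmax over logits, and the assumption that `dice_coeff` weights the Dice term while `1 - dice_coeff` weights the MSE term are all illustrative guesses.

```python
import torch
import torch.nn.functional as F


def dice_mse_loss(logits, target, class_weights, dice_coeff=0.5, eps=1e-6):
    """Hypothetical combined loss.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer labels;
    class_weights: (C,) non-negative weights.
    """
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)

    dims = (0, 2, 3)  # reduce over batch and spatial dims, keep the class dim
    # Soft Dice loss per class: 1 - 2|P∩T| / (|P| + |T|)
    intersection = (probs * one_hot).sum(dims)
    denom = probs.sum(dims) + one_hot.sum(dims)
    dice_per_class = 1.0 - (2.0 * intersection + eps) / (denom + eps)

    # Mean-squared error per class between probabilities and one-hot masks
    mse_per_class = ((probs - one_hot) ** 2).mean(dims)

    # Class-weighted averages of both terms, then blend them
    w = class_weights / class_weights.sum()
    dice = (w * dice_per_class).sum()
    mse = (w * mse_per_class).sum()
    return dice_coeff * dice + (1.0 - dice_coeff) * mse


target = torch.tensor([[[0, 1], [2, 1]]])  # one 2x2 mask with 3 classes
weights = torch.tensor([0.2, 0.4, 0.4])
uniform = dice_mse_loss(torch.zeros(1, 3, 2, 2), target, weights)
confident = dice_mse_loss(
    50.0 * F.one_hot(target, 3).permute(0, 3, 1, 2).float(), target, weights
)
```

Here `confident` (near-perfect predictions) comes out far smaller than `uniform` (equal probability for every class), which is the behaviour a segmentation loss should have.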

The distribution of pixel classes in the masks is also highly imbalanced,
so I use weighted losses in which each class weight is
inversely proportional to that class's share of pixels in the training data.
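
As a sketch of how such weights could be derived, the snippet below turns per-class pixel counts into normalized inverse-frequency weights. The function name and the example counts are hypothetical; normalizing the weights to sum to one is an assumption, although it is consistent with the `CLASS_WEIGHTS` values used later in this notebook, which sum to approximately 1.

```python
def inverse_frequency_weights(pixel_counts):
    """Return weights proportional to 1 / class frequency, normalized to sum to 1."""
    if any(count <= 0 for count in pixel_counts):
        raise ValueError("every class needs at least one pixel")
    total = sum(pixel_counts)
    inverse = [total / count for count in pixel_counts]
    norm = sum(inverse)
    return [value / norm for value in inverse]


# Hypothetical counts: a dominant background class and two rare classes.
example_weights = inverse_frequency_weights([900, 50, 50])
```

The dominant class receives the smallest weight, so errors on rare classes contribute more to the loss.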

## Modules used in this project
I use PyTorch Lightning to train, validate, and test the model.
The modules are in the `oil_seep_detection` directory.

In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import os
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

In [3]:
from oil_seep_detection.trainer import LitUNet
from oil_seep_detection.datasets import SARDataModule
torch.set_default_dtype(torch.float64)
pl.seed_everything(40)

Global seed set to 40


40

In [4]:
DATA_PATH = os.path.join(os.getcwd(), "data")
NUM_CLASSES = 8

IN_CHANNELS = 1
MID_CHANNELS = [64, 128, 256]
OUT_CHANNELS = 8
KERNEL_SIZE = 3
STRIDE = 1
PADDING = 1
PADDING_MODE = "zeros"
NORM = "batchnorm"
ACTIVATION = "lrelu"
DICE_COEFF = .5
LR = 1e-3
CLASS_WEIGHTS = [
    0.00253406, 0.14137538, 0.14266087, 0.14278807, 0.14272933, 0.14255343, 0.14252834, 0.14283051]

MAX_EPOCHS = 100

LOG_DIR = "logs"
EXPERIMENT_NAME = "oil_seep_detection_dice_mse_loss"

CHECKPOINT_FP = "{epoch}-{step}-{val_total_loss:.3f}-{val_acc:.3f}"
VAL_CHECK_INTERVAL = 40

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Check the device type rather than comparing a torch.device to a string
NGPUS = 1 if DEVICE.type == "cuda" else None
MONITOR, MODE = "val_total_loss", "min"
SAVE_TOP_K = 1
NUM_SANITY_VAL_STEPS = 1
LOG_EVERY_N_STEPS = 1

In [5]:
logger = CSVLogger(save_dir=LOG_DIR, name=EXPERIMENT_NAME)

checkpoint_callback = ModelCheckpoint(
    filename=CHECKPOINT_FP, 
    monitor=MONITOR, 
    mode=MODE,
    save_top_k=SAVE_TOP_K)

trainer = Trainer(
    gpus=NGPUS,
    max_epochs=MAX_EPOCHS, 
    log_every_n_steps=LOG_EVERY_N_STEPS, 
    logger=logger, 
    callbacks=[checkpoint_callback], 
    num_sanity_val_steps=NUM_SANITY_VAL_STEPS)

lit_unet = LitUNet(
    in_channels=IN_CHANNELS,
    mid_channels=MID_CHANNELS,
    out_channels=OUT_CHANNELS,
    kernel_size=KERNEL_SIZE,
    stride=STRIDE,
    padding=PADDING,
    padding_mode=PADDING_MODE,
    norm=NORM,
    activation=ACTIVATION,
    dice_coeff=DICE_COEFF,
    lr=LR,
    val_mask_dir=os.path.join(logger.log_dir, "val_masks"), 
    class_weights=CLASS_WEIGHTS)

dm = SARDataModule(
    train_dirs=[
        os.path.join(DATA_PATH, "train", "images"),
        os.path.join(DATA_PATH, "train", "masks")], 
    val_dirs=[
        os.path.join(DATA_PATH, "val", "images"),
        os.path.join(DATA_PATH, "val", "masks")],
    batch_sizes=[12, 32],
    num_workers=[4, 2], 
    num_classes=NUM_CLASSES)

trainer.fit(model=lit_unet, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name | Type | Params
------------------------------
0 | unet | UNet | 1.9 M 
------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.459     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 40


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]
