<a href="https://colab.research.google.com/github/jaideep11061982/Pytorch-UNet/blob/master/simple_unet_pytorch_baseline_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple Unet Baseline (Train)

This is the training part of the two part Unet Baseline for this competition.
#### Inference Notebook: [Simple Unet Baseline (Infer)][1].
You can find the notebook to create the dataset used for training [here][2] to get a better understanding of how everything works.
* Smp library is used to get the unet model.
* EfficientNetB0 is used as the backbone initialized on imagenet weight.
* Ash color images are used for training (With only the labeled frames and human_pixel_masks.
* Custom implementation of dice score is used according to this competition.
* After training, we find the best threshold for the valid set, which will then be used for the submission.
* Wandb can also be used with this notebook to log experiments, just uncomment the wandb code snippets.

**Version 5** Updates:
* Added some Augmentations
* Trained for more Epochs
* Option to increase image size

### Please upvote if you find this useful.

[1]: https://www.kaggle.com/code/shashwatraman/simple-unet-pytorch-baseline-infer
[2]: https://www.kaggle.com/code/shashwatraman/contrails-dataset-ash-color/notebook

In [1]:
install='N'
if install=='Y':

  !mkdir -p ~/.kaggle
  !cp kaggle.json ~/.kaggle/
  !pip install segmentation_models_pytorch
  #!mkdir ../input/
  #!mkdir ../input/gi-dataset-fold
  #!mkdir ../input/uwmgi-25d-stride2-dataset
  #!kaggle datasets download -d awsaf49/uwmgi-25d-stride2-dataset
  #!kaggle datasets download -d jaideepvalani/gi-trac-public-part1
  !kaggle datasets download -d shashwatraman/contrails-images-ash-color
  !pip install transformers
  !mkdir ../input/
  !mkdir ../input/contrails-images-ash-color
  !mkdir ../input/contrails-images-ash-color/contrails/
  !unzip -q  /content/contrails-images-ash-color.zip -d  .
  !mv /content/contrails/*.npy ../input/contrails-images-ash-color/contrails/

In [2]:

#!unzip -q  /content/contrails-images-ash-color.zip -d  .
#!mv /content/contrails/*.npy ../input/contrails-images-ash-color/contrails/
#!unzip -q  /content/gi-trac-public-part1.zip -d  .
!mv *.csv  ../input/contrails-images-ash-color/

mv: cannot stat '*.csv': No such file or directory


In [3]:
#!mv /content/contrails/*.npy ../input/contrails-images-ash-color/contrails/

In [4]:
#!unzip -q /content/contrails-images-ash-color.zip #-d ../input/

In [5]:
#!ls -l /content/contrails/
!nvidia-smi

Sat Jul 22 16:49:55 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    43W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Import Libraries

In [6]:
from pathlib import Path
import os
import random
import math
from collections import defaultdict
import cv2
import skimage

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import albumentations as A
import torch.nn.functional as F

from PIL import Image
from tqdm.notebook import tqdm
from transformers import get_cosine_schedule_with_warmup

torch.__version__

'2.0.1+cu118'

In [7]:
#!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp

In [8]:
# !pip install -qU wandb
# import wandb
# wandb.login(key='')
#!mv *.csv ../input/contrails-images-ash-color/

## Data Preparation

In [9]:
class Config:
    train = True
    train_aug=True

    num_epochs = 80
    num_classes = 1
    batch_size = 32
    seed = 42

    encoder = 'timm-efficientnet-b7'
    pretrained =True
    weights = 'imagenet'
    classes = ['contrail']
    activation = None
    in_chans = 3

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image_size = 256
    warmup = 0
    #lr = 3e-3
    lr=8e-4

class Paths:
    data_root = '/kaggle/input/google-research-identify-contrails-reduce-global-warming'
    contrails = '../input/contrails-images-ash-color/contrails/'
    train_path = '../input/contrails-images-ash-color/train_df.csv'
    valid_path = '../input/contrails-images-ash-color/valid_df.csv'

In [10]:
def set_seed(seed=1234):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [11]:
set_seed(9)

In [12]:
# Import dataframes
train_df = pd.read_csv(Paths.train_path)
valid_df = pd.read_csv(Paths.valid_path)

train_df['path'] = Paths.contrails + train_df['record_id'].astype(str) + '.npy'
valid_df['path'] = Paths.contrails + valid_df['record_id'].astype(str) + '.npy'

train_df.shape, valid_df.shape

((20529, 3), (1856, 3))

In [13]:
transform_size = A.Compose([
    A.Resize(Config.image_size, Config.image_size, interpolation=cv2.INTER_LANCZOS4, always_apply=True)
]) #cv2.INTER_LANCZOS4

train_transform = A.Compose([
    A.OneOf([
            A.Sharpen (alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5),

# #             A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
            A.ColorJitter(brightness=0.15, contrast=0.2),
            A.RandomBrightnessContrast (brightness_limit=0.15, contrast_limit=0.2,
                                        brightness_by_max=True, always_apply=False, p=0.3),
            A.RandomGamma(gamma_limit=(30,150),p=0.2)
        ], p=0.0),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=15, p=0.25),

    A.VerticalFlip(p=0.25),
    #A.RandomResizedCrop(height=384, width=384, scale=(0.25, 1.0), p=1,interpolation=cv2.INTER_LANCZOS4)
     #A.Resize(Config.image_size, Config.image_size, interpolation=cv2.INTER_LANCZOS4, always_apply=True)
])

In [14]:
class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, df, train=True, transform=None):

        self.df = df
        self.trn = train
        self.transform = transform

    def __getitem__(self, index):
        row = self.df.iloc[index]
        con_path = row.path
        con = np.load(str(con_path))

        img = con[..., :-1]
        label = con[..., -1]

        img = img.astype(np.float32)
        label = label.astype(np.float32)

        if Config.train_aug and self.trn:
            if self.transform is not None:
                augmented = self.transform(image=img, mask=label)
                img = augmented['image']
                label = augmented['mask']
        if not self.trn:
          if self.transform is not None:
                augmented = self.transform(image=img, mask=label)
                img = augmented['image']
                label = augmented['mask']

        #if Config.image_size != 256:
        #    img = transform_size(image=img)["image"]

        img = torch.tensor(img)
        label = torch.tensor(label)

        img = img.permute(2, 0, 1)

        return img.float(), label.float()

    def __len__(self):
        return len(self.df)

In [15]:
train_ds = ContrailsDataset(
        train_df,
        train=True,
        transform=train_transform
    )

valid_ds = ContrailsDataset(
        valid_df,
        train=False,
        transform=None #transform_size
    )

train_dl = DataLoader(train_ds, batch_size=Config.batch_size , shuffle=True, num_workers = 8)
valid_dl = DataLoader(valid_ds, batch_size=Config.batch_size, num_workers = 8)

In [16]:
train_ds[0][0].shape

torch.Size([3, 256, 256])

In [17]:
img, label = next(iter(train_dl))
img.shape, label.shape

(torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))

In [18]:
img, label = next(iter(valid_dl))
img.shape, label.shape

(torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))

In [19]:
def display_random_images(dataset, n=10, seed=None):
    if seed:
        random.seed(seed)
    random_samples_idx = random.sample(range(len(dataset)), k=n)
    plt.figure(figsize=(30, 20))

    for i, targ_sample in enumerate(random_samples_idx):
        targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]
        print(targ_image.shape)

        targ_image = targ_image.permute(1, 2, 0)

        plt.subplot(1, n, i+1)
        plt.imshow(targ_image)
        plt.axis(False)

In [20]:
#display_random_images(train_ds, 2, 42)

In [21]:
#display_random_images(valid_ds, 4, 42)

## Training

In [22]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=0.001):
    y_true = y_true.flatten()
    y_pred = (y_pred>thr).astype(np.float32).flatten()
    inter = (y_true*y_pred).sum()
    den = y_true.sum() + y_pred.sum()
    dice = ((2*inter+epsilon)/(den+epsilon))
    return dice

def dice_avg(y_p, y_t,smooth=1e-3):
    i = torch.sum(y_p * y_t, dim=(2, 3))
    u = torch.sum(y_p, dim=(2, 3)) + torch.sum(y_t, dim=(2, 3))
    score = (2 * i + smooth)/(u + smooth)
    return torch.mean(score)
def dice_loss(input, target):
    input = torch.sigmoid(input)
    smooth = 1.0
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

def dice_loss_avg(y_p,y_t):
    return -torch.log(dice_loss(y_p,y_t))
def iou_loss(input, target,epsilon=1e-5):
    input = torch.sigmoid(input)
    smooth = 1.0
    union = (target + input - target*input).sum()
    inter = (target*input).sum()
    iou = ((inter+epsilon)/(union+epsilon)) #.mean(0)
    return iou

def iou_loss_avg(y_p,y_t):
    return -torch.log(iou_loss(y_p,y_t))

In [23]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
    def forward(self, logit, target):
        target = target.float()
        max_val = (-logit).clamp(min=0)
        loss = logit - logit * target + max_val + \
               ((-max_val).exp() + (-logit - max_val).exp()).log()
        invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        if len(loss.size())==2:
            loss = loss.sum(dim=1)
        return loss.mean()
def dice_global(y_p,y_t,smooth=1e-3):

    intersection = torch.sum(y_p * y_t)
    union = torch.sum(y_p) + torch.sum(y_t)

    dice = (2.0 * intersection + smooth) / (union + smooth)

    return dice

def dice_loss_global(y_p,y_t):
    return 1-dice_global(y_p,y_t)


class UNet(nn.Module):
    def __init__(self, cfg):
        super(UNet, self).__init__()

        self.cfg = cfg
        self.training = True

        self.model = smp.Unet(
            encoder_name=cfg.encoder,
            encoder_weights=cfg.weights,
            decoder_use_batchnorm=True,
            classes=len(cfg.classes),
            activation=cfg.activation,
        )

        self.loss_fn1 = smp.losses.SoftBCEWithLogitsLoss( )
        self.loss_fn2= smp.losses.DiceLoss(mode='binary') #returns mean ice_loss_avg#
        self.loss_fn3=smp.losses.JaccardLoss(mode='binary') #iou_loss_avg#
        self.lossfn4= FocalLoss( )

    def forward(self, imgs, targets):

        x = imgs
        y = targets

        logits = self.model(x)
        #print(logits.size())

        #if Config.image_size != 256:
        #    logits = F.interpolate(logits, size=(256, 256), mode='nearest-exact')+self.loss_fn3(logits, y)*0.1

        loss = self.loss_fn1(logits.squeeze(1), y )*0.5+0.5*(self.loss_fn2(logits, y)#*0.5
                                                #+self.loss_fn3(logits,y)*0.2
                                                             )

        return {"loss": loss, "logits": logits.sigmoid(), "logits_raw": logits, "target": y}

In [24]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [25]:
scaler = torch.cuda.amp.GradScaler()
def train_step(model, dataloader, optimizer, device):

    model.train()

    train_losses = []
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    losses = AverageMeter()

    for step, (X, y) in pbar:

        X, y = X.to(device), y.to(device)
        torch.set_grad_enabled(True)
        bs=X[0].shape[0]
        with torch.cuda.amp.autocast():

	          # Forward Pass
            #outputs = model(inputs)
            # Compute Loss and Perform Back-propagation
	          # loss = loss_fn(outputs, labels)


          output_dict = model(X, y)
        loss = output_dict["loss"]
        scaler.scale(loss).backward()

        # Uncales gradients and calls
        # or skips optimizer.step()
        scaler.step(optimizer)

        # Updates the scale for next iteration
        scaler.update()
        losses.update(loss.item(), bs)
        #loss = output_dict["loss"]
        train_losses.append(loss.item())
        pbar.set_description(f"Train - Loss:{losses.avg:0.4f} ")

        #loss.backward()
        #optimizer.step()
        optimizer.zero_grad()

        if scheduler is not None:
            scheduler.step()

    train_loss = np.sum(train_losses)

    return train_loss

In [26]:
def test_step(model, dataloader, device):

    model.eval()
    torch.set_grad_enabled(False)

    val_data = defaultdict(list)
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid')
    losses=AverageMeter()
    for step, (X, y) in pbar:

        X, y = X.to(device), y.to(device)
        bs=X[0].shape[0]
        output = model(X, y)
        loss = output["loss"]
        for key, val in output.items():
            val_data[key] += [output[key]]
        losses.update(loss.item(), bs)
        pbar.set_description(f"Val Loss:{losses.avg:0.4f} ")
    for key, val in output.items():
        value = val_data[key]
        if len(value[0].shape) == 0:
            val_data[key] = torch.stack(value)
        else:
            val_data[key] = torch.cat(value, dim=0).cpu().detach().numpy()

    val_losses = val_data["loss"].cpu().numpy()
    val_loss = np.sum(val_losses)

    val_dice = dice_coef(val_data['target'], val_data['logits'])

    return val_loss, val_dice

In [27]:
from tqdm.auto import tqdm

In [28]:
def train(model, train_dataloader, test_dataloader, optimizer, epochs, device):
    results = {'train_loss': [],
              'val_loss': [],
              'val_dice': []}
    best_dice=-1

    for epoch in range(epochs):

        set_seed(Config.seed + epoch)
        print("EPOCH:", epoch)

        train_loss = train_step(model,
                              train_dataloader,
                              optimizer,
                              device)
        val_loss, val_dice = test_step(model,
                            test_dataloader,
                            device)

        train_loss = train_loss / len(train_ds)
        val_loss = val_loss / len(valid_ds)

        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}')
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")

        results['train_loss'].append(train_loss)
        results['val_loss'].append(val_loss)
        results['val_dice'].append(val_dice)

#         wandb.log({
#         "Train Loss": train_loss,
#         "Valid Loss": val_loss,
#         'Valid Dice': val_dice})
        if best_dice<val_dice:
          best_dice=val_dice
          print('saving best dice',val_dice)

          PATH = f"epoch-{epoch}_best.pth"
          torch.save(model.state_dict(), PATH)

#         wandb.save(PATH)

    return results

In [29]:
def get_optimizer(lr, params):

    model_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, params),
            lr=lr,
            weight_decay=3e-5)

    return model_optimizer

In [30]:
def get_scheduler(cfg, optimizer, total_steps,annealing=True):
  if not annealing:

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps= cfg.warmup * (total_steps // cfg.batch_size),
        num_training_steps= cfg.num_epochs * (total_steps // cfg.batch_size)
    )
  else:
      scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 8e-4,
                                                           eta_min=7e-5, last_epoch=- 1, verbose=False)
  return scheduler

In [31]:
state['model'].keys()
('module.model.encoder._conv_stem.weight' ).replace('._','.')

NameError: ignored

In [32]:
#model.state_dict().keys()

from collections import OrderedDict
model = UNet(Config).to(Config.device)
'''
state=torch.load('/content/drive/MyDrive/vision_Exp/checkpoint_dice_ctrl_fold0.pth')
#torch.load('/content/drive/MyDrive/vision_Exp/epoch-29_best.pth')
state=( (k,v) for k,v in zip(model.state_dict().keys(),state['model'].values()) )
#state={k.replace('._','.'):v for k,v in state['model'].items()}
state=OrderedDict(state)
state.keys()
'''

"\nstate=torch.load('/content/drive/MyDrive/vision_Exp/checkpoint_dice_ctrl_fold0.pth')\n#torch.load('/content/drive/MyDrive/vision_Exp/epoch-29_best.pth')\nstate=( (k,v) for k,v in zip(model.state_dict().keys(),state['model'].values()) )\n#state={k.replace('._','.'):v for k,v in state['model'].items()}\nstate=OrderedDict(state)\nstate.keys()\n"

In [34]:
NUM_EPOCHS = Config.num_epochs
model=nn.DataParallel(model ).cuda()

#state=torch.load('/content/drive/MyDrive/vision_Exp/checkpoint_dice_ctrl_fold0.pth')
state=torch.load('/content/drive/MyDrive/vision_Exp/epoch-73_best.pth')
#state=( (k,v) for k,v in zip(model.state_dict().keys(),state['model'].values()) )
#state={k.replace('._','.'):v for k,v in state['model'].items()}
#state=OrderedDict(state)
model.load_state_dict(state  )
print('x')
# run = wandb.init(project='Google Contrails',
#                      config={k:v for k, v in dict(vars(Config)).items() if '__' not in k},
#                      name=f"{Config.encoder}-{Config.num_epochs}epos-{Config.lr}-unet"
#                     )
#i=torch.randn(2,3,384,384).to(Config.device)

#model(i,i[:,0,:,:].unsqueeze(1))['logits'].shape

x


In [52]:
val_loss, val_dice = test_step(model,
                            valid_dl,
                            Config.device)
val_loss, val_dice

Valid:   0%|          | 0/78 [00:00<?, ?it/s]

(34.61483, 0.15687073161309625)

In [221]:
torch.cuda.empty_cache()
val_dice

0.6208948367943696

In [35]:
total_steps = len(train_ds)
optimizer = get_optimizer(lr=Config.lr, params=model.parameters())
scheduler = get_scheduler(Config, optimizer, total_steps)

# wandb.watch(model, log_freq=100, log='all') 663

from timeit import default_timer as timer
start_time = timer()

model_results = train(model, train_dl, valid_dl, optimizer, NUM_EPOCHS, Config.device)

end_time = timer()

# run.finish() 663
print(f'Total Training Time: {end_time-start_time:.3f} seconds')

EPOCH: 0


Train :   0%|          | 0/642 [00:00<?, ?it/s]

Valid:   0%|          | 0/58 [00:00<?, ?it/s]

Train Loss: 0.0053 | Val Loss: 0.0064 | Val Dice: 0.6214
Learning rate: 0.0008
saving best dice 0.6214187383143658
EPOCH: 1


Train :   0%|          | 0/642 [00:00<?, ?it/s]

Valid:   0%|          | 0/58 [00:00<?, ?it/s]

Train Loss: 0.0054 | Val Loss: 0.0063 | Val Dice: 0.6226
Learning rate: 0.0008
saving best dice 0.6225820774452732
EPOCH: 2


Train :   0%|          | 0/642 [00:00<?, ?it/s]

Valid:   0%|          | 0/58 [00:00<?, ?it/s]

Train Loss: 0.0055 | Val Loss: 0.0060 | Val Dice: 0.6366
Learning rate: 0.0008
saving best dice 0.6365892407318583
EPOCH: 3


Train :   0%|          | 0/642 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:


total_steps = len(train_ds)
optimizer = get_optimizer(lr=Config.lr, params=model.parameters())
scheduler = get_scheduler(Config, optimizer, total_steps)

# wandb.watch(model, log_freq=100, log='all')

from timeit import default_timer as timer
start_time = timer()

model_results = train(model, train_dl, valid_dl, optimizer, NUM_EPOCHS, Config.device)

end_time = timer()

# run.finish()
print(f'Total Training Time: {end_time-start_time:.3f} seconds')

## Finding the Best Threshold

In [None]:
# Predicting the Valid Set
model.eval()
torch.set_grad_enabled(False)

val_data = defaultdict(list)
pbar = tqdm(enumerate(valid_dl), total=len(valid_dl), desc='Valid')
for step, (X, y) in pbar:
    X, y = X.to(Config.device), y.to(Config.device)

    output = model(X, y)
    for key, val in output.items():
        val_data[key] += [output[key]]

for key, val in output.items():
    value = val_data[key]
    if len(value[0].shape) == 0:
        val_data[key] = torch.stack(value)
    else:
        val_data[key] = torch.cat(value, dim=0).cpu().detach().numpy()

val_losses = val_data["loss"].cpu().numpy()
val_loss = np.sum(val_losses)
val_loss = val_loss / len(valid_ds)

In [None]:
predictions = val_data['logits']
ground_truths = val_data['target']

In [None]:
predictions.shape, ground_truths.shape

In [None]:
# Finding the Best Threshold
bdice = -1
bi = None
for i in tqdm(np.arange(0, 1.01, 0.01)):
    val_dice = dice_coef(ground_truths, predictions, i)
    if val_dice > bdice:
        bdice = val_dice
        bi = i

In [None]:
print(f'Best Threshold: {bi}')
print(f'Best Validation Dice Score: {bdice}')