In [1]:
import os
import math
import random
import pandas as pd
import numpy as np

import cv2
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import notebook
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn import metrics

import torch
import torch.nn as nn

from torch.nn import Parameter
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.lr_scheduler import _LRScheduler
from torchvision import transforms as T

from torch.utils.tensorboard import SummaryWriter

from albumentations.pytorch import ToTensorV2
import albumentations as A
from pathlib import Path

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
import shutil
shutil.copy('/content/drive/MyDrive/unet mask/mask_data.zip', '/content/')

'/content/mask_data.zip'

In [3]:
!unzip mask_data.zip

Archive:  mask_data.zip
   creating: mask_data/images/
   creating: mask_data/images/test/
  inflating: mask_data/images/test/7.jpg  
  inflating: mask_data/images/test/78.jpg  
  inflating: mask_data/images/test/79.jpg  
  inflating: mask_data/images/test/8.jpg  
  inflating: mask_data/images/test/80.jpg  
  inflating: mask_data/images/test/81.jpg  
  inflating: mask_data/images/test/82.jpg  
  inflating: mask_data/images/test/83.jpg  
  inflating: mask_data/images/test/84.jpg  
  inflating: mask_data/images/test/85.jpg  
  inflating: mask_data/images/test/86.jpg  
  inflating: mask_data/images/test/87.jpg  
  inflating: mask_data/images/test/88.jpg  
  inflating: mask_data/images/test/89.jpg  
  inflating: mask_data/images/test/9.jpg  
  inflating: mask_data/images/test/90.jpg  
  inflating: mask_data/images/test/91.jpg  
  inflating: mask_data/images/test/92.jpg  
  inflating: mask_data/images/test/93.jpg  
  inflating: mask_data/images/test/94.jpg  
  inflating: mask_data/images/te

In [4]:
class configs:
    IMAGE_SIZE = 224
    NUM_WORKERS = 2
    BATCH_SIZE = 8
    EPOCHS = 100
    SEED = 69
    CHECKPOINT = "/content/drive/MyDrive/unet mask/checkpoints/ResUnet++/"
    MODEL_NAME = 'SCSEUnet'
    scheduler_params = {
        "lr_start": 3e-5,
        "lr_max": 5e-5,
        "lr_min": 1e-5,
        "lr_ramp_ep": 5,
        "lr_sus_ep": 0,
        "lr_decay": 0.8,
    }

    model_params = {
      'channel':3
    }

In [5]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [6]:
def seed_torch(seed=1):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(configs.SEED)

In [7]:
extra_trns = A.Compose([
    A.OneOf([
        A.Resize(configs.IMAGE_SIZE,configs.IMAGE_SIZE,p=1),
        A.RandomResizedCrop(configs.IMAGE_SIZE,configs.IMAGE_SIZE,p=1)
        ]),
    A.OneOf([
        A.HorizontalFlip(p=1),
        A.VerticalFlip(p=1),
        A.RandomRotate90(p=1)
    ]),
    A.Transpose(p=0.5),
    ToTensorV2()
])

train_trns = T.Compose([
    T.ToTensor(),
    T.RandomApply([T.Lambda(lambda x : x + (0.1**0.7)*torch.randn(3,configs.IMAGE_SIZE,configs.IMAGE_SIZE))],p=0.08),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

test_trns = T.Compose([
    T.ToPILImage(),
    T.Resize(size = (configs.IMAGE_SIZE,configs.IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

mask_trns = T.Compose([
    T.ToPILImage(),
    T.Resize(size = (configs.IMAGE_SIZE,configs.IMAGE_SIZE)),
    T.ToTensor(),
])

In [8]:
def dim_check(img_path,mask_path):
  img = cv2.imread(img_path)
  mask = cv2.imread(mask_path)
  return img.shape == mask.shape

In [10]:
folder_path = Path("./mask_data")
train_img_paths = []
train_mask_paths = []
test_img_paths = []
test_mask_paths = []
for file_path in folder_path.rglob("*"):
  if file_path.is_file():
    path_ls = str(file_path). split("/")
    file_path = str(file_path)

    if "train" in path_ls:
      if "masks" in path_ls:
        train_mask_paths.append(file_path)
      elif "images" in path_ls:
          train_img_paths.append (file_path)
    elif "test" in path_ls:
      if "masks" in path_ls:
        test_mask_paths.append(file_path)
      elif "images" in path_ls:
        test_img_paths.append(file_path)

train_img_paths.sort()
train_mask_paths.sort()
test_img_paths.sort()
test_mask_paths.sort()

final_train_img_paths = []
final_train_mask_paths = []
final_test_img_paths = []
final_test_mask_paths = []

for i in range(len(train_img_paths)):
  if dim_check(train_img_paths[i],train_mask_paths[i]):
    final_train_img_paths.append(train_img_paths[i])
    final_train_mask_paths.append(train_mask_paths[i])

for i in range(len(test_img_paths)):
  if dim_check(test_img_paths[i],test_mask_paths[i]):
    final_test_img_paths.append(test_img_paths[i])
    final_test_mask_paths.append(test_mask_paths[i])


df_train =pd.DataFrame({"img_locs":final_train_img_paths, "mask_locs":final_train_mask_paths})
df_test =pd.DataFrame({"img_locs":final_test_img_paths, "mask_locs":final_test_mask_paths})
df_train.to_csv("./train.csv",index=False)
df_test.to_csv("./test.csv",index=False)

In [11]:
df_train.head()

Unnamed: 0,img_locs,mask_locs
0,mask_data/images/train/1.jpg,mask_data/masks/train/1.jpg
1,mask_data/images/train/10.jpg,mask_data/masks/train/10.jpg
2,mask_data/images/train/100.jpg,mask_data/masks/train/100.jpg
3,mask_data/images/train/11.jpg,mask_data/masks/train/11.jpg
4,mask_data/images/train/12.jpg,mask_data/masks/train/12.jpg


In [12]:
class UNetDataLoader(Dataset):
  def __init__(self, split_type, img_locs, mask_locs):
    super().__init__()
    self.img_locs = img_locs
    self.mask_locs = mask_locs
    self.spatial_trns = extra_trns
    self.train_trns = train_trns
    self.test_trns = test_trns
    self.mask_trns = mask_trns
    self.split_type = split_type

  def __len__(self):
    return len(self.img_locs)
  def __getitem__(self, idx):
    img_loc = self.img_locs[idx]
    mask_loc = self.mask_locs[idx]
    img = cv2.imread(img_loc)
    img = cv2. cvtColor (img, cv2. COLOR_BGR2RGB)
    mask = cv2. imread(mask_loc)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if self.split_type == "train":
      transformed = self.spatial_trns(image=img, mask=mask)
      img = transformed["image"]
      mask = transformed["mask"]
      mask = self.mask_trns(mask)
      img = self.mask_trns(img)
    else:
      img = self.test_trns(img)
      mask = self.mask_trns(mask)
    return img, mask

In [13]:
train_ds = UNetDataLoader("train",df_train["img_locs"],df_train["mask_locs"])
train_loader = DataLoader(train_ds, batch_size=configs.BATCH_SIZE, shuffle=True, drop_last=True)
val_ds = UNetDataLoader("test",df_test["img_locs"],df_test["mask_locs"])
val_loader = DataLoader(val_ds, batch_size=configs.BATCH_SIZE, shuffle=False, drop_last=True)


In [14]:
datat = next(iter(train_loader))

In [15]:
class lr_scheduler(_LRScheduler):
    def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5,
                 lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0, lr_decay=0.8,
                 last_epoch=-1):
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.lr_ramp_ep = lr_ramp_ep
        self.lr_sus_ep = lr_sus_ep
        self.lr_decay = lr_decay
        super(lr_scheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            self.last_epoch += 1
            return [self.lr_start for _ in self.optimizer.param_groups]

        lr = self._compute_lr_from_epoch()
        self.last_epoch += 1

        return [lr for _ in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return self.base_lrs

    def _compute_lr_from_epoch(self):
        if self.last_epoch < self.lr_ramp_ep:
            lr = ((self.lr_max - self.lr_start) /
                  self.lr_ramp_ep * self.last_epoch +
                  self.lr_start)

        elif self.last_epoch < self.lr_ramp_ep + self.lr_sus_ep:
            lr = self.lr_max

        else:
            lr = ((self.lr_max - self.lr_min) * self.lr_decay**
                  (self.last_epoch - self.lr_ramp_ep - self.lr_sus_ep) +
                  self.lr_min)
        return lr

In [16]:
from core.res_unet_plus import ResUnetPlusPlus

In [17]:
from torch import nn


class BCEDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, input, target):
        pred = input.view(-1)
        truth = target.view(-1)

        # BCE loss
        bce_loss = nn.BCELoss()(pred, truth).double()

        # Dice Loss
        dice_coef = (2.0 * (pred * truth).double().sum() + 1) / (
            pred.double().sum() + truth.double().sum() + 1
        )

        return bce_loss + (1 - dice_coef)


# https://github.com/pytorch/examples/blob/master/imagenet/main.py
class MetricTracker(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# https://stackoverflow.com/questions/48260415/pytorch-how-to-compute-iou-jaccard-index-for-semantic-segmentation
def jaccard_index(input, target):

    intersection = (input * target).long().sum().data.cpu()[0]
    union = (
        input.long().sum().data.cpu()[0]
        + target.long().sum().data.cpu()[0]
        - intersection
    )

    if union == 0:
        return float("nan")
    else:
        return float(intersection) / float(max(union, 1))


# https://github.com/pytorch/pytorch/issues/1249
def dice_coeff(input, target):
    num_in_target = input.size(0)

    smooth = 1.0

    pred = input.view(num_in_target, -1)
    truth = target.view(num_in_target, -1)

    intersection = (pred * truth).sum(1)

    loss = (2.0 * intersection + smooth) / (pred.sum(1) + truth.sum(1) + smooth)

    return loss.mean().item()

In [18]:
class trainer:
    def __init__(self,train_dataloader,val_dataloader,load_checkpoint = False):
        if(load_checkpoint):
            print("Loading pretrained model...")
            self.model = torch.load(configs.CHECKPOINT)
        else:
            self.model  = ResUnetPlusPlus(**configs.model_params)
        no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']
        para_optimizer = list(self.model.named_parameters())
        self.optimizer_parameters = [
        {'params':[p for n,p in para_optimizer if not any(nd in n for nd in no_decay)],'weight_decay':1e-5},
        {'params':[p for n,p in para_optimizer if  any(nd in n for nd in no_decay)],'weight_decay':0.0}
        ]

        self.optimizer = AdamW(amsgrad = True,params = self.optimizer_parameters,lr = 3e-5)
        self.lr_scheduler = lr_scheduler(self.optimizer,**configs.scheduler_params)
        self.criterion = BCEDiceLoss()
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.model = self.model.to(device)
        self.writer = SummaryWriter(configs.CHECKPOINT+"tboard")
        self.step_train = 0
        self.step_val = 0



    def train_fn(self,epoch):
        self.model.train()
        count,total_loss = 0,0
        loop = notebook.tqdm(enumerate(self.train_dataloader),total = len(self.train_dataloader))
        for bi,data in loop:
            images = data[0].to(device)
            targets = data[1].to(device)
            preds =  self.model(images)
            self.optimizer.zero_grad()
            loss =self.criterion(preds, targets)
            total_loss += loss.item()
            self.writer.add_scalar("Training Loss",loss.item(),global_step=self.step_train)
            count +=1
            self.step_train +=1
            loss.backward()
            self.optimizer.step()
            loop.set_postfix(Epoch=epoch,Avg_Train_Loss=total_loss/count,Current_Train_Loss=loss.item(),LR=self.optimizer.param_groups[0]['lr'])
        self.lr_scheduler.step()

    def eval_fn(self,epoch):
        self.model.eval()
        count,total_loss = 0,0
        with torch.no_grad():
            loop = notebook.tqdm(enumerate(self.val_dataloader),total = len(self.val_dataloader))
            for bi,data in loop:
                images = data[0].to(device)
                targets = data[1].to(device)
                preds =  self.model(images)
                loss =self.criterion(preds, targets)
                total_loss += loss.item()
                self.writer.add_scalar("Validation Loss",loss.item(),global_step=self.step_val)
                count +=1
                self.step_val +=1
                loop.set_postfix(Epoch=epoch,Avg_Val_Loss=total_loss/count,Current_Val_Loss=loss.item())

        return total_loss / count

    def run(self,epochs = 5):
        best_loss = 1e9
        for epoch in range (epochs):
            self.train_fn(epoch)
            val_loss = self.eval_fn(epoch)
            print("Epoch {} complete! Validation Loss : {}".format(epoch, val_loss))
            if best_loss > val_loss:
                print("Best validation loss improved from {} to {}, saving model...".format(best_loss, val_loss))
                best_loss = val_loss
                torch.save(self.model, configs.CHECKPOINT+"model.pt")


In [19]:
train = trainer(train_loader,val_loader)

In [None]:
train.run(configs.EPOCHS)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 0 complete! Validation Loss : 1.5268886698502302
Best validation loss improved from 1000000000.0 to 1.5268886698502302, saving model...


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 1 complete! Validation Loss : 16.848109276241946


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 2 complete! Validation Loss : 43.40152724354066


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 3 complete! Validation Loss : 61.326002025707105


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 4 complete! Validation Loss : 54.84373019318303


  0%|          | 0/20 [00:00<?, ?it/s]