In [1]:
import os
import math
import random
import pandas as pd
import numpy as np

import cv2
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import notebook
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn import metrics

import torch
import torch.nn as nn

from torch.nn import Parameter
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.lr_scheduler import _LRScheduler
from torchvision import transforms as T

from torch.utils.tensorboard import SummaryWriter

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [47]:
class configs:
    IMAGE_SIZE = 100
    NUM_WORKERS = 2
    BATCH_SIZE = 2
    EPOCHS = 4
    SEED = 43
    CHECKPOINT = "/checkpoint/model.pt"
    MODEL_NAME = 'custom'
    scheduler_params = {
        "lr_start": 3e-5,
        "lr_max": 3e-5 * BATCH_SIZE,
        "lr_min": 1e-6,
        "lr_ramp_ep": 5,
        "lr_sus_ep": 0,
        "lr_decay": 0.8,
    }

    model_params = {
    }

In [48]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [49]:
def seed_torch(seed=1):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(configs.SEED)

In [68]:
train_trns = T.Compose([
    T.ToPILImage(),
    T.Resize(size = (configs.IMAGE_SIZE,configs.IMAGE_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=(0.8,1.2)),
    T.RandomAffine(degrees=(-10,10),translate =(0.1,0.1),shear =(-5,5,-5,5),interpolation = T.InterpolationMode.BILINEAR),
    T.RandomPerspective(distortion_scale=0.3),
    T.ToTensor(),
    T.RandomApply([T.Lambda(lambda x : x + (0.1**0.7)*torch.randn(3,configs.IMAGE_SIZE,configs.IMAGE_SIZE))],p=0.08),
    T.Grayscale(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

val_trns = T.Compose([
    T.ToPILImage(),
    T.Resize(size = (configs.IMAGE_SIZE,configs.IMAGE_SIZE)),
    T.ToTensor(),
    T.Grayscale(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [69]:
import pathlib,glob,os
data_loc = r"./data/"
data_loc = pathlib.Path(data_loc)
all_data = list(data_loc.glob('*.*'))

In [70]:
all_data

[PosixPath('data/4 (2).jpg'),
 PosixPath('data/5.jpg'),
 PosixPath('data/color_3_0029.png'),
 PosixPath('data/0 (2).jpg'),
 PosixPath('data/2.jpg'),
 PosixPath('data/0.jpg'),
 PosixPath('data/color_3_0002.png'),
 PosixPath('data/6 (2).jpg'),
 PosixPath('data/25.jpg')]

In [71]:
file = open('images_data.txt',"w+")
file.write('img_loc'+ "\n")
for i,x in enumerate(tqdm(all_data)):
    w = str(x)+"\n"
    file.write(w)
file.close()

100%|██████████| 9/9 [00:00<00:00, 9295.43it/s]


In [72]:
df = pd.read_csv('images_data.txt')

In [73]:
df_train,df_val = train_test_split(df,test_size = 0.25,random_state = configs.SEED)
df_train.reset_index(drop=True,inplace=True)
df_val.reset_index(drop=True,inplace=True)

In [77]:
class SiameseNetworkDataset(Dataset):
  def __init__(self,img_locs,transform = None):
    self.img_locs = img_locs
    self.trns = transform

  def __len__(self):
    return len(self.img_locs)

  def __getitem__(self,idx):
    img_loc = self.img_locs[idx]
    imgA = cv2.imread(img_loc)
    imgA = cv2.cvtColor(imgA,cv2.COLOR_BGR2RGB)
    same_img = random.choice([True,False])
    target = 1
    #imgB = None
    if same_img:
      imgB = cv2.imread(img_loc)
      imgB = cv2.cvtColor(imgB,cv2.COLOR_BGR2RGB)
    else :
      img_loc2 = ""
      count = 5
      while count:
        count-=1
        idx2 = random.randint(0,len(self.img_locs)-1)
        if idx2 != idx:
          img_loc2 = self.img_locs[idx2]

      if img_loc2 == "":
        imgB = cv2.imread(img_loc)
        imgB = cv2.cvtColor(imgB,cv2.COLOR_BGR2RGB)

      else:
        imgB = cv2.imread(img_loc2)
        imgB = cv2.cvtColor(imgB,cv2.COLOR_BGR2RGB)
        target = 0
    imgA = self.trns(imgA)
    imgB = self.trns(imgB)

    return imgA, imgB, target

In [78]:
train_ds = SiameseNetworkDataset(df_train["img_loc"],train_trns)
train_loader = DataLoader(train_ds, batch_size=configs.BATCH_SIZE, shuffle=True, drop_last=True)
val_ds = SiameseNetworkDataset(df_val["img_loc"],val_trns)
val_loader = DataLoader(val_ds, batch_size=configs.BATCH_SIZE, shuffle=False, drop_last=True)


In [79]:
data = next(iter(train_loader))

RuntimeError: ignored

In [60]:
len(data),data[0].shape

(3, torch.Size([2, 3, 100, 100]))

In [61]:
def convBlock(ni,no):
  return nn.Sequential(
      nn.Dropout(0.2),
      nn.Conv2d(ni,no,kernel_size=3,padding = 1,padding_mode = 'reflect'),
      nn.ReLU(inplace = True),
      nn.BatchNorm3d(no),
  )

In [62]:
class SiameseNetwork(nn.Module):
  def __init__(self):
    super(SiameseNetwork,self).__init__()
    self.features = nn.Sequential(
        convBlock(1,4),
        convBlock(4,8),
        convBlock(8,8),
        nn.Flatten(),
        nn.Linear(8*100*100, 500),
        nn.ReLU(inplace = True),
        nn.Linear(500,500),
        nn.ReLU(inplace = True),
        nn.Linear(500,5)
    )
  def forward(self,input1,input2):
    output1 = self.features(input1)
    output2 = self.features(input2)
    return output1,output2

In [63]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        acc = ((euclidean_distance > 0.6) == label).float().mean()
        return loss_contrastive, acc


In [64]:
class lr_scheduler(_LRScheduler):
    def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5,
                 lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0, lr_decay=0.8,
                 last_epoch=-1):
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.lr_ramp_ep = lr_ramp_ep
        self.lr_sus_ep = lr_sus_ep
        self.lr_decay = lr_decay
        super(lr_scheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            self.last_epoch += 1
            return [self.lr_start for _ in self.optimizer.param_groups]

        lr = self._compute_lr_from_epoch()
        self.last_epoch += 1

        return [lr for _ in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return self.base_lrs

    def _compute_lr_from_epoch(self):
        if self.last_epoch < self.lr_ramp_ep:
            lr = ((self.lr_max - self.lr_start) /
                  self.lr_ramp_ep * self.last_epoch +
                  self.lr_start)

        elif self.last_epoch < self.lr_ramp_ep + self.lr_sus_ep:
            lr = self.lr_max

        else:
            lr = ((self.lr_max - self.lr_min) * self.lr_decay**
                  (self.last_epoch - self.lr_ramp_ep - self.lr_sus_ep) +
                  self.lr_min)
        return lr

In [65]:
class trainer:
    def __init__(self,train_dataloader,val_dataloader,load_checkpoint = False):
        if(load_checkpoint):
            print("Loading pretrained model...")
            self.model = torch.load(configs.CHECKPOINT)
        else:
            self.model  = SiameseNetwork(**configs.model_params)
        no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']
        para_optimizer = list(self.model.named_parameters())
        self.optimizer_parameters = [
        {'params':[p for n,p in para_optimizer if not any(nd in n for nd in no_decay)],'weight_decay':1e-5},
        {'params':[p for n,p in para_optimizer if  any(nd in n for nd in no_decay)],'weight_decay':0.0}
        ]

        self.optimizer = AdamW(amsgrad = True,params = self.optimizer_parameters,lr = 3e-5)
        self.lr_scheduler = lr_scheduler(self.optimizer,**configs.scheduler_params)
        self.criterion = ContrastiveLoss()
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.model = self.model.to(device)
        self.writer = SummaryWriter("tboard")
        self.step_train = 0
        self.step_val = 0



    def train_fn(self,epoch):
        self.model.train()
        count,total_loss,total_acc = 0,0,0
        loop = notebook.tqdm(enumerate(self.train_dataloader),total = len(self.train_dataloader))
        for bi,data in loop:
            imgA = data[0].to(device)
            imgB = data[1].to(device)
            targets = data[2].to(device)
            codesA, codesB =  self.model(imgA,imgB)
            self.optimizer.zero_grad()
            loss,acc =self.criterion(codesA, codesB, targets)
            total_loss += loss.item()
            total_acc += acc
            self.writer.add_scalar("Training Loss",loss.item(),global_step=self.step_train)
            self.writer.add_scalar("Training Accuracy",acc,global_step=self.step_train)
            count +=1
            self.step_train +=1
            loss.backward()
            self.optimizer.step()
            loop.set_postfix(Epoch=epoch,Avg_Train_Loss=total_loss/count,Avg_Train_Acc=total_acc/count,Current_Train_Loss=loss.item(),Current_Train_Acc=acc,LR=self.optimizer.param_groups[0]['lr'])
        self.lr_scheduler.step()

    def eval_fn(self,epoch):
        self.model.eval()
        count,total_loss,total_acc = 0,0,0
        with torch.no_grad():
            loop = notebook.tqdm(enumerate(self.val_dataloader),total = len(self.val_dataloader))
            for bi,data in loop:
                imgA = data[0].to(device)
                imgB = data[1].to(device)
                targets = data[2].to(device)
                codesA, codesB =  self.model(imgA,imgB)
                loss,acc =self.criterion(codesA, codesB, targets)
                total_loss += loss.item()
                total_acc += acc
                self.writer.add_scalar("Validation Loss",loss.item(),global_step=self.step_val)
                self.writer.add_scalar("Validation Accuracy",acc,global_step=self.step_val)
                count +=1
                self.step_val +=1
                loop.set_postfix(Epoch=epoch,Avg_Val_Loss=total_loss/count,Avg_Val_Acc=total_acc/count,Current_Val_Loss=loss.item(),Current_Val_Acc=acc)

        return total_acc / count

    def run(self,epochs = 5):
        best_acc = 0
        for epoch in range (epochs):
            self.train_fn(epoch)
            val_acc = self.eval_fn(epoch)
            print("Epoch {} complete! Validation Acc : {}".format(epoch, val_acc))
            if val_acc > best_acc:
                print("Best validation Acc improved from {} to {}, saving model...".format(best_acc, val_acc))
                best_acc = val_acc
                torch.save(self.model, configs.CHECKPOINT)


In [66]:
train = trainer(train_loader,val_loader)

In [67]:
train.run(configs.EPOCHS)

  0%|          | 0/3 [00:00<?, ?it/s]

RuntimeError: ignored