#### Training

This notebook contains code to train the model for crypt segmentation in colon images. We use the `segmentation_models_pytorch` library to build a U-Net and train it with k-fold cross-validation.

In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
#import libraries
import os
import warnings
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import utils
import cv2
import torch.optim as optim
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import OneCycleLR



### Dataset

In [3]:
class Colon_Dataset(Dataset):
    """Dataset of image/mask patch pairs listed in a CSV file.

    Column 0 of the CSV holds the image path and column 1 the mask path;
    both files are read with ``tifffile``.

    Args:
        data_csv_path: CSV file listing the image and mask paths.
        indexes: Positional row indexes (anything ``DataFrame.iloc`` accepts)
            selecting the rows served by this dataset. ``None`` selects
            every row of the CSV.
        valid: Unused; kept so existing callers keep working.
        transform: Optional callable invoked as ``transform(image=, mask=)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        target_transform: Stored on the instance but not applied here.
        preprocessing: Optional callable with the same calling convention as
            ``transform``; applied after ``transform``.
    """

    def __init__(self,data_csv_path:str="train_data.csv",indexes:list=None,valid:bool=False,transform=None,target_transform=None,preprocessing=None):
        self.data = pd.read_csv(data_csv_path)
        # The default ``indexes=None`` used to be handed straight to ``.iloc``
        # and failed; treat it as "use the whole CSV".
        if indexes is None:
            self.indexed_data = self.data
        else:
            self.indexed_data = self.data.iloc[indexes,:]
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing

    def __getitem__(self,idx):
        """Return ``(image, mask)`` for row ``idx`` of the selected rows.

        The image is returned as a float tensor. The mask is channel 0 of the
        mask file divided by 255 and cast to ``torch.LongTensor``.
        """
        image = tifffile.imread(self.indexed_data.iloc[idx,0])
        mask = tifffile.imread(self.indexed_data.iloc[idx,1]).astype(float)

        if self.transform:
            augmentations = self.transform(image=image,mask=mask)
            image,mask = augmentations['image'],augmentations['mask']

        if self.preprocessing:
            preprocessed = self.preprocessing(image=image,mask=mask)
            image,mask = preprocessed['image'],preprocessed['mask']

        # ``torch.as_tensor`` is a no-op for tensors and converts NumPy arrays,
        # so the dataset no longer crashes when no preprocessing step has
        # turned the arrays into tensors.
        # NOTE(review): the cast to long truncates, so this presumably expects
        # mask values of exactly 0 or 255 - confirm against the mask files.
        image_tensor = torch.as_tensor(image).float()
        mask_tensor = torch.as_tensor(mask[:,:,0]/255.0).type(torch.LongTensor)
        return image_tensor,mask_tensor

    def __len__(self):
        """Number of rows selected for this dataset."""
        return len(self.indexed_data)

### Config 

In [4]:
#Defining configurations
class Configuration:
    """Hyper-parameters and shared objects used by the training cells."""
    MODEL_SAVEPATH = "models/"           # directory for checkpoints and per-epoch logs
    ENCODER = "efficientnet-b2"          # encoder name passed to the segmentation model
    PRETRAINED_WEIGHTS = "imagenet"      # encoder weights passed to the segmentation model
    BATCH_SIZE = 16
    INPUT_CHANNELS = 3
    INPUT_SHAPE = (512,512,3)
    NFOLDS = 5                           # number of cross-validation folds
    ACTIVATION = None
    CLASSES = 2 #(crypts 1 background 0)
    # Fall back to the CPU when CUDA is unavailable. This must be the string
    # "cpu": the bare name ``cpu`` is undefined and raised NameError on
    # machines without a GPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    EPOCHS = 25
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss()
    LOSS_DICE = utils.DiceScore(loss=True)   # DiceScore configured with loss=True
    DICE_COEF = utils.DiceScore(loss=False)  # DiceScore configured with loss=False, used as the metric
    WEIGHT_DECAY = 1e-4
    LEARNING_RATE = 1e-3
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER,PRETRAINED_WEIGHTS)
    ONECYCLELR = False                   # enable the OneCycleLR schedule in the Trainer
    MODEL_NAME = 0                       # overwritten per fold by kfold_training
    ARCHITECTURE = "UNET"
    MAX_LR_FOR_ONECYCLELR = 1e-3

cfg = Configuration()

### Trainer class

In [5]:
class Trainer:
    """Trains and validates a U-Net segmentation model on one train/valid split.

    The constructor builds a ``segmentation_models_pytorch`` U-Net from the
    configuration, a Ranger optimizer and, when ``cfg.ONECYCLELR`` is set, a
    OneCycleLR schedule. ``fit`` runs the epoch loop, writes a per-epoch CSV
    log, saves the weights whenever the validation loss improves and stops
    once ``max_patience`` consecutive epochs bring no improvement.

    Args:
        cfg: Configuration object providing the hyper-parameters.
        train_data_loader: Loader yielding ``(image, mask)`` training batches.
        valid_data_loader: Loader yielding ``(image, mask)`` validation batches.
    """

    def __init__(self,cfg:"Configuration",train_data_loader:DataLoader,valid_data_loader:DataLoader)->None:
        self.cfg = cfg
        # Number of consecutive non-improving epochs tolerated before stopping.
        self.max_patience = 5
        self.patience = self.max_patience
        self.model = sm.Unet(encoder_name=self.cfg.ENCODER,
                     encoder_weights=self.cfg.PRETRAINED_WEIGHTS,
                     in_channels=self.cfg.INPUT_CHANNELS,
                     classes=self.cfg.CLASSES)
        self.loss_function = self.cfg.LOSS_CROSSENTROPY
        # Previously this value was overwritten by a hard-coded 1e-3 and never
        # handed to the optimizer; the configured rate is now actually used.
        self.lr = self.cfg.LEARNING_RATE
        self.batch_size = self.cfg.BATCH_SIZE
        self.train_dataloader = train_data_loader
        self.valid_dataloader = valid_data_loader
        self.device = self.cfg.DEVICE
        self.epochs = self.cfg.EPOCHS
        self.track_best_valid = []
        # Best validation loss seen so far; starts at a large sentinel value.
        self.val_for_early_stopping = 9999999
        os.makedirs(self.cfg.MODEL_SAVEPATH,exist_ok=True)

        self.log = pd.DataFrame(columns=["model_name","train_loss","train_dice","valid_loss","valid_dice"])

        self.optimizer = t_optim.Ranger(self.model.parameters(),lr=self.lr,weight_decay=self.cfg.WEIGHT_DECAY)
        # Keep the scheduler separate from the optimizer. It used to overwrite
        # ``self.optimizer`` (so no optimizer step could run) and read the
        # non-existent ``self.EPOCHS``.
        self.scheduler = None
        if self.cfg.ONECYCLELR:
            self.scheduler = OneCycleLR(self.optimizer, max_lr=self.cfg.MAX_LR_FOR_ONECYCLELR, steps_per_epoch=len(self.train_dataloader), epochs=self.epochs)

    def calculate_metrics(self,data_loader:DataLoader):
        """Evaluate the model on ``data_loader`` without tracking gradients.

        Returns:
            Tuple ``(mean_dice, mean_loss)``, each averaged over the batches.
        """
        self.model.eval()
        total_loss = 0
        total_dice = 0  # sum of the per-batch dice scores
        with torch.no_grad():
            for data in tqdm(data_loader,total=len(data_loader)):
                im = data[0].to(self.device)
                mask = data[1].to(self.device)
                out = self.model(im)
                loss = self.loss_function(out,mask)
                total_loss+=loss.item()
                total_dice+= self.cfg.DICE_COEF(out.to("cpu"),mask.cpu())
        return total_dice/len(data_loader),total_loss/len(data_loader)

    def earlystopping(self,val_loss):
        """Record ``val_loss`` and report whether it is a new best.

        Returns ``True`` and stores the value when it is lower than the best
        loss seen so far; otherwise decrements ``self.patience`` and returns
        ``False``.
        """
        if val_loss < self.val_for_early_stopping:
            self.val_for_early_stopping = val_loss
            return True
        self.patience-=1
        return False

    def fit(self)->None:
        """Run the training loop with checkpointing and early stopping."""
        print("started fitting the model")
        # File name of the checkpoint written below; also recorded in the log
        # (the log previously repeated the encoder name twice instead).
        checkpoint_name = f"fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.pth"
        # Moving the model once is enough; it does not need repeating per epoch.
        self.model.to(self.device)

        for epoch in range(self.epochs):
            self.model.train()

            dice_score_ = 0
            loss_ = 0

            for data in tqdm(self.train_dataloader,total = len(self.train_dataloader)):
                input_image_batch = data[0].to(self.device)
                mask_batch = data[1].to(self.device)
                self.optimizer.zero_grad()
                output = self.model(input_image_batch)
                loss = self.loss_function(output,mask_batch)
                loss.backward()
                self.optimizer.step()
                # OneCycleLR is defined per optimizer step, so advance it after every batch.
                if self.scheduler is not None:
                    self.scheduler.step()
                loss_+=loss.item()

                dice_score_+= self.cfg.DICE_COEF(output.data.to("cpu"),mask_batch.to("cpu"))

            dice_score_valid,loss_valid = self.calculate_metrics(self.valid_dataloader)
            train_dice = dice_score_/len(self.train_dataloader)
            train_loss = loss_/len(self.train_dataloader)
            print(f"train dice score : {train_dice}, train loss {train_loss}")
            print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")

            # The log is rewritten every epoch so an interrupted run keeps its history.
            self.log.loc[epoch,:] = [checkpoint_name,f"{train_loss}",f"{train_dice}",f"{loss_valid}",f"{dice_score_valid}"]
            self.log.to_csv(self.cfg.MODEL_SAVEPATH+f"/fold_{self.cfg.MODEL_NAME}__{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.csv",index=False)

            if self.patience >= 0 and self.earlystopping(loss_valid):
                print("saving model")

                torch.save(self.model.state_dict(),self.cfg.MODEL_SAVEPATH+"/"+checkpoint_name)
                # An improvement resets the early-stopping countdown.
                self.patience = self.max_patience

            if self.patience <= 0:
                print("Training terminated, no improvement in valid loss")
                break
                
        
        
        

        

#### Model Training

##### We perform k-fold cross-validation, splitting by source image id so that all patches from one image fall in the same fold

In [6]:
train_data_csv_path = "Colonic_crypt_dataset/train.csv"
patches_csv_path = "train_data.csv"

def kfold_training(train_data_csv_path:str,patches_csv_path:str,n_folds:int)->list:
    """Train one model per cross-validation fold and report the best losses.

    The folds are built over the image ids in ``train_data_csv_path``; the
    patches of each fold are the rows of ``patches_csv_path`` whose
    ``Train_image_path`` contains one of the fold's image ids.

    Args:
        train_data_csv_path: CSV with an ``id`` column of image ids.
        patches_csv_path: CSV of patches with a ``Train_image_path`` column.
        n_folds: Number of folds handed to ``KFold``.

    Returns:
        List of ``[model name, best validation loss]`` pairs, one per fold.
        The same table is written to ``report_<last model name>.csv``.
    """
    import re  # local import: used only to escape the ids for the regex below

    # get all the training data unique image ids
    # NOTE(review): the slice drops the last row of the CSV - presumably an
    # image held out from training; confirm this is intentional.
    train_ids_from_csv = pd.read_csv(train_data_csv_path).iloc[0:-1,:]['id'].values

    nfold = KFold(n_folds, shuffle=True, random_state=0)

    patch_dataframe = pd.read_csv(patches_csv_path)
    track_best_model = []

    def patch_rows_for(image_ids):
        """Return the index of the patch rows whose path contains any of ``image_ids``."""
        # Escape the ids so characters with a regex meaning match literally.
        # NOTE(review): this is substring matching, so an id that is a prefix
        # of another id would select both - verify against the id format.
        pattern = "|".join(re.escape(str(image_id)) for image_id in image_ids)
        return patch_dataframe[patch_dataframe.Train_image_path.str.contains(pattern)].index

    for i, (train_idx, val_idx) in enumerate(nfold.split(train_ids_from_csv)):

        train_ids = patch_rows_for(train_ids_from_csv[train_idx])
        valid_ids = patch_rows_for(train_ids_from_csv[val_idx])
        # Build the datasets from the same CSV the row indexes were computed
        # on; the path used to be hard-coded, silently ignoring the parameter.
        train_dataset = Colon_Dataset(patches_csv_path,indexes=train_ids,transform=utils.get_train_transforms(),preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))
        valid_dataset = Colon_Dataset(patches_csv_path,indexes=valid_ids,preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))

        train_dataloader = DataLoader(train_dataset,batch_size=cfg.BATCH_SIZE,shuffle=True)
        valid_dataloader = DataLoader(valid_dataset,batch_size=cfg.BATCH_SIZE,shuffle=False)
        # The Trainer reads MODEL_NAME from the shared config to name its output files.
        cfg.MODEL_NAME = str(i) + "_"+cfg.ARCHITECTURE
        trainer = Trainer(cfg,train_dataloader,valid_dataloader)
        trainer.fit()
        track_best_model.append([cfg.MODEL_NAME,trainer.val_for_early_stopping])

    # cfg.MODEL_NAME still holds the last fold's name here, so the report file is named after it.
    pd.DataFrame(track_best_model,columns=["model name","valid_loss"]).to_csv("report_"+cfg.MODEL_NAME+".csv",index=False)

    return track_best_model



    


In [None]:
# Run the cross-validation over cfg.NFOLDS folds; the call returns one
# [model name, best validation loss] entry per fold.
kfold_training(train_data_csv_path,patches_csv_path,n_folds = cfg.NFOLDS)

started fitting the model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.7793204188346863, train loss 0.36847229143406485
valid dice score : 0.8037179708480835, valid loss 0.3701242655515671
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.8323666453361511, train loss 0.27541542243450245
valid dice score : 0.8157902956008911, valid loss 0.3339341804385185
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.8674011826515198, train loss 0.20667448941063374
valid dice score : 0.871990442276001, valid loss 0.19635743461549282
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9107339978218079, train loss 0.13198905152843354
valid dice score : 0.8997569680213928, valid loss 0.15056656021624804
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9274612069129944, train loss 0.10924337970766615
valid dice score : 0.914080023765564, valid loss 0.13697780389338732
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9372887015342712, train loss 0.09635114923436591
valid dice score : 0.9264574646949768, valid loss 0.12529025366529822
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9480992555618286, train loss 0.08055320397970524
valid dice score : 0.9313315749168396, valid loss 0.1256098411977291


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9538446664810181, train loss 0.07389038222584318
valid dice score : 0.9319142699241638, valid loss 0.11718650185503066
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9551547169685364, train loss 0.07357487296487422
valid dice score : 0.934475839138031, valid loss 0.10941710323095322
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9587458968162537, train loss 0.06537259813952953
valid dice score : 0.928773820400238, valid loss 0.1525818428490311


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9646559953689575, train loss 0.057771198253365276
valid dice score : 0.9346379637718201, valid loss 0.13776662503369153


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9666985273361206, train loss 0.0544922235719067
valid dice score : 0.9410706162452698, valid loss 0.12397479056380689


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9677273631095886, train loss 0.052693513122961874
valid dice score : 0.9407567977905273, valid loss 0.15300136781297624


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9718087911605835, train loss 0.050693327679913094
valid dice score : 0.9424779415130615, valid loss 0.1271718661300838
Training terminated, no improvement in valid loss
started fitting the model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.6077746152877808, train loss 0.5684734959351389
valid dice score : 0.6490503549575806, valid loss 0.48261260635712566
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.6755353808403015, train loss 0.44979054363150345
valid dice score : 0.6986778974533081, valid loss 0.4071301712709315
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.7422655820846558, train loss 0.35709163232853536
valid dice score : 0.7609079480171204, valid loss 0.3294174013768925
saving model


  0%|          | 0/38 [00:00<?, ?it/s]