#### Training

This notebook trains a model for crypt segmentation in colon images. We use the `segmentation_models_pytorch` library (a U-Net with an EfficientNet-B2 encoder) and evaluate it with k-fold cross-validation.

In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
#import libraries

import warnings
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import utils
import cv2
import torch.optim as optim
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import OneCycleLR



### Dataset

In [3]:
class Colon_Dataset(Dataset):
    """Patch-level dataset returning ``(image, mask)`` pairs for crypt segmentation.

    Rows come from a CSV whose column 0 holds an image path and column 1 a
    mask path; both files are read with ``tifffile``.

    Args:
        data_csv_path: CSV listing the image/mask paths.
        indexes: positional row indices of the CSV to expose. ``None`` (the
            default) exposes every row.
        valid: currently unused; kept for backward compatibility.
        transform: augmentation callable invoked as
            ``transform(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` keys.
        target_transform: stored but not applied in ``__getitem__``.
        preprocessing: callable with the same ``(image=..., mask=...) -> dict``
            contract as ``transform``; applied after it.
    """

    def __init__(self,data_csv_path:str="train_data.csv",indexes:list= None,valid:bool=False,transform:transforms = None, target_transform:transforms=None,preprocessing=None):
        self.data = pd.read_csv(data_csv_path)
        # ``DataFrame.iloc[None, :]`` raises, so the default ``indexes=None``
        # is interpreted as "use every row".
        if indexes is None:
            self.indexed_data = self.data
        else:
            self.indexed_data = self.data.iloc[indexes, :]
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing

    def __getitem__(self,idx):
        """Return the ``idx``-th ``(image, mask)`` pair.

        The image is returned via ``.float()``; the mask's first channel is
        divided by 255 and cast to ``torch.LongTensor`` class labels.
        """
        # Column 0 holds the image path, column 1 the mask path.
        image = tifffile.imread(self.indexed_data.iloc[idx, 0])
        mask = tifffile.imread(self.indexed_data.iloc[idx, 1]).astype(float)

        if self.transform:
            augmentations = self.transform(image=image, mask=mask)
            image, mask = augmentations['image'], augmentations['mask']

        if self.preprocessing:
            preprocessed = self.preprocessing(image=image, mask=mask)
            image, mask = preprocessed['image'], preprocessed['mask']

        # NOTE(review): ``.float()`` / ``.type(...)`` assume ``preprocessing``
        # returns torch tensors, and ``mask[:, :, 0]`` assumes the mask keeps a
        # trailing channel axis — confirm against utils.preprocessing_fucntion.
        # The cast to long truncates, so any value in [0, 255) maps to class 0.
        return image.float(), (mask[:, :, 0] / 255.0).type(torch.LongTensor)

    def __len__(self):
        """Number of rows selected by ``indexes``."""
        return len(self.indexed_data)

### Config 

In [4]:
#Defining configurations
class Configuration:
    """Run-wide settings shared by the dataset, the Trainer and kfold_training.

    All settings are class attributes. ``kfold_training`` overwrites
    ``MODEL_NAME`` on the ``cfg`` instance for every fold.
    """
    MODEL_SAVEPATH = "models/"  # directory for checkpoints (.pth) and per-fold CSV logs
    ENCODER = "efficientnet-b2"  # encoder_name passed to sm.Unet
    PRETRAINED_WEIGHTS = "imagenet"  # encoder_weights passed to sm.Unet
    BATCH_SIZE = 16
    INPUT_CHANNELS = 3
    INPUT_SHAPE = (512,512,3)  # not read by the visible training code
    NFOLDS = 5
    ACTIVATION = None  # not passed to sm.Unet by Trainer
    CLASSES = 2 #(crypts 1 background 0)
    # The fallback used to be the undefined name ``cpu`` (NameError without CUDA).
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    EPOCHS = 25  # upper bound; early stopping may end a fold sooner
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss()
    LOSS_DICE = utils.DiceScore(loss=True)  # currently unused by the training loop
    DICE_COEF = utils.DiceScore(loss=False)  # dice metric reported every epoch
    WEIGHT_DECAY = 1e-4
    LEARNING_RATE = 1e-3
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER,PRETRAINED_WEIGHTS)
    ONECYCLELR = False  # when True, Trainer builds a OneCycleLR scheduler
    MODEL_NAME = 0  # placeholder; set per fold to "<fold>_<ARCHITECTURE>"
    ARCHITECTURE = "UNET"  # only used in MODEL_NAME; Trainer always builds sm.Unet
    MAX_LR_FOR_ONECYCLELR = 1e-3

cfg = Configuration()

### Trainer class

In [5]:
class Trainer:
    """Train one ``sm.Unet`` segmentation model with early stopping on validation loss.

    The weights with the lowest validation loss seen so far are saved to
    ``cfg.MODEL_SAVEPATH``; per-epoch metrics are written to a CSV log in the
    same directory.

    Args:
        cfg: run configuration (encoder, classes, loss, optimiser settings,
            device, epochs, save path, model name).
        train_data_loader: loader yielding ``(image_batch, mask_batch)``.
        valid_data_loader: loader with the same structure, evaluated after
            every epoch.
    """

    def __init__(self,cfg:Configuration,train_data_loader:DataLoader,valid_data_loader:DataLoader)->None:
        self.cfg = cfg
        # Consecutive non-improving epochs tolerated before training stops.
        self.max_patience = 5
        self.patience = self.max_patience
        self.model = sm.Unet(encoder_name=self.cfg.ENCODER,
                             encoder_weights=self.cfg.PRETRAINED_WEIGHTS,
                             in_channels=self.cfg.INPUT_CHANNELS,
                             classes=self.cfg.CLASSES)
        self.device = self.cfg.DEVICE
        # Move the model before the optimizer is built over its parameters.
        self.model.to(self.device)
        self.loss_function = self.cfg.LOSS_CROSSENTROPY
        # Honour the configured learning rate; it used to be overwritten with a
        # hard-coded 1e-3 and was never passed to the optimizer.
        self.lr = self.cfg.LEARNING_RATE
        self.batch_size = self.cfg.BATCH_SIZE
        self.train_dataloader = train_data_loader
        self.valid_dataloader = valid_data_loader
        self.epochs = self.cfg.EPOCHS
        self.track_best_valid = []
        # Best validation loss seen so far; read by callers after fit().
        self.val_for_early_stopping = 9999999

        self.log = pd.DataFrame(columns=["model_name","train_loss","train_dice","valid_loss","valid_dice"])

        self.optimizer = t_optim.Ranger(self.model.parameters(), lr=self.lr, weight_decay=self.cfg.WEIGHT_DECAY)
        # The scheduler is kept separate from the optimizer. Previously it
        # replaced ``self.optimizer`` and referenced the non-existent
        # ``self.EPOCHS``, so enabling ONECYCLELR crashed.
        self.scheduler = None
        if self.cfg.ONECYCLELR:
            self.scheduler = OneCycleLR(self.optimizer,
                                        max_lr=self.cfg.MAX_LR_FOR_ONECYCLELR,
                                        steps_per_epoch=len(self.train_dataloader),
                                        epochs=self.epochs)

    def calculate_metrics(self,data_loader:DataLoader):
        """Evaluate the model on ``data_loader`` without gradient tracking.

        Returns:
            ``(mean_dice, mean_loss)`` averaged over batches — dice first.
        """
        self.model.eval()
        total_loss = 0
        total_dice = 0  # sum of per-batch dice scores
        with torch.no_grad():
            for data in tqdm(data_loader, total=len(data_loader)):
                im = data[0].to(self.device)
                mask = data[1].to(self.device)
                out = self.model(im)
                loss = self.loss_function(out, mask)
                total_loss += loss.item()
                total_dice += self.cfg.DICE_COEF(out.cpu(), mask.cpu())
        return total_dice/len(data_loader), total_loss/len(data_loader)

    def earlystopping(self,val_loss):
        """Return True when ``val_loss`` improves on the best loss so far.

        On improvement the best loss is updated; otherwise ``self.patience``
        is decremented.
        """
        if val_loss < self.val_for_early_stopping:
            self.val_for_early_stopping = val_loss
            return True
        self.patience -= 1
        return False

    def fit(self)->None:
        """Train for up to ``self.epochs`` epochs, validating after each one.

        Each epoch's metrics are written to the CSV log. When the validation
        loss improves the weights are saved and patience is reset; otherwise
        training stops once patience reaches 0.
        """
        print("started fitting the model")
        checkpoint_name = f"fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.pth"
        checkpoint_path = self.cfg.MODEL_SAVEPATH + "/" + checkpoint_name
        # Log filename kept unchanged (including the double underscore) so
        # anything already reading these logs still finds them.
        log_path = self.cfg.MODEL_SAVEPATH+f"/fold_{self.cfg.MODEL_NAME}__{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.csv"
        n_train_batches = len(self.train_dataloader)

        for epoch in range(self.epochs):
            self.model.train()

            dice_score_ = 0
            loss_ = 0

            for data in tqdm(self.train_dataloader, total=n_train_batches):
                input_image_batch = data[0].to(self.device)
                mask_batch = data[1].to(self.device)
                self.optimizer.zero_grad()
                output = self.model(input_image_batch)
                loss = self.loss_function(output, mask_batch)
                loss.backward()
                self.optimizer.step()
                if self.scheduler is not None:
                    # OneCycleLR is scheduled per optimisation step (batch).
                    self.scheduler.step()
                loss_ += loss.item()
                dice_score_ += self.cfg.DICE_COEF(output.detach().cpu(), mask_batch.cpu())

            dice_score_valid, loss_valid = self.calculate_metrics(self.valid_dataloader)
            train_dice = dice_score_/n_train_batches
            train_loss = loss_/n_train_batches
            print(f"train dice score : {train_dice}, train loss {train_loss}")
            print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")

            # model_name records the checkpoint file actually written below
            # (it previously repeated the encoder name instead).
            self.log.loc[epoch,:] = [checkpoint_name,f"{train_loss}",f"{train_dice}",f"{loss_valid}",f"{dice_score_valid}"]
            self.log.to_csv(log_path, index=False)

            if self.patience >= 0 and self.earlystopping(loss_valid):
                print("saving model")
                torch.save(self.model.state_dict(), checkpoint_path)
                self.patience = self.max_patience

            if self.patience <= 0:
                print("Training terminated, no improvement in valid loss")
                break

        

#### Model Training

##### We perform k-fold cross-validation (folds are split by image id, so all patches of an image stay in the same fold)

In [6]:
train_data_csv_path, patches_csv_path = (
    "Colonic_crypt_dataset/train.csv",  # image-level CSV with an ``id`` column
    "train_data.csv",                   # patch-level CSV of image/mask paths
)

def kfold_training(train_data_csv_path:str,patches_csv_path:str,n_folds:int)->list:
    """Run k-fold cross-validation, training one model per fold.

    Folds are split over the image ids in ``train_data_csv_path``. A patch in
    ``patches_csv_path`` belongs to a fold when its ``Train_image_path``
    contains one of that fold's ids, so the split is done per image rather
    than per patch.

    Side effect: sets the global ``cfg.MODEL_NAME`` to ``"<fold>_<ARCHITECTURE>"``
    for each fold.

    Args:
        train_data_csv_path: image-level CSV with an ``id`` column.
        patches_csv_path: patch-level CSV with a ``Train_image_path`` column;
            also the CSV the datasets read (column 0 image, column 1 mask).
        n_folds: number of folds passed to ``KFold``.

    Returns:
        list of ``[model_name, best_valid_loss]`` pairs, one per fold. The same
        rows are written to ``report_<last model name>.csv``.
    """
    import re

    # NOTE(review): ``iloc[0:-1]`` drops the last image id from the split —
    # confirm this is intentional.
    train_ids_from_csv = pd.read_csv(train_data_csv_path).iloc[0:-1,:]['id'].values

    nfold = KFold(n_folds, shuffle=True, random_state=42)

    patch_dataframe = pd.read_csv(patches_csv_path)

    def patch_rows_for(image_ids):
        # Row labels of the patches whose path mentions any of ``image_ids``.
        # Ids are regex-escaped so metacharacters in them cannot widen the match.
        # NOTE(review): this is still substring matching — an id contained in
        # another id also matches that image's patches.
        pattern = "|".join(re.escape(str(image_id)) for image_id in image_ids)
        return patch_dataframe[patch_dataframe.Train_image_path.str.contains(pattern)].index

    track_best_model = []
    for i, (train_idx, val_idx) in enumerate(nfold.split(train_ids_from_csv)):
        train_ids = patch_rows_for(train_ids_from_csv[train_idx])
        valid_ids = patch_rows_for(train_ids_from_csv[val_idx])
        # The datasets read the same CSV as ``patch_dataframe`` (it used to be a
        # hard-coded "train_data.csv"), so these default-RangeIndex labels are
        # valid row positions for Colon_Dataset's ``iloc``.
        train_dataset = Colon_Dataset(patches_csv_path,indexes=train_ids,transform=utils.get_train_transforms(),preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))
        valid_dataset = Colon_Dataset(patches_csv_path,indexes=valid_ids,preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))

        train_dataloader = DataLoader(train_dataset,batch_size=cfg.BATCH_SIZE,shuffle=True)
        valid_dataloader = DataLoader(valid_dataset,batch_size=cfg.BATCH_SIZE,shuffle=False)
        cfg.MODEL_NAME = str(i) + "_"+cfg.ARCHITECTURE
        trainer = Trainer(cfg,train_dataloader,valid_dataloader)
        trainer.fit()
        track_best_model.append([cfg.MODEL_NAME,trainer.val_for_early_stopping])

    # ``pd.Dataframe`` (typo) raised AttributeError after every fold had
    # trained, so the report was never written.
    pd.DataFrame(track_best_model,columns=["model name","valid_loss"]).to_csv("report_"+cfg.MODEL_NAME+".csv",index=False)

    return track_best_model



    


In [7]:
# Train one model per fold; the per-fold best validation losses are the cell output.
kfold_training(
    train_data_csv_path,
    patches_csv_path,
    n_folds=cfg.NFOLDS,
)

started fitting the model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.6369217038154602, train loss 0.5261364555672595
valid dice score : 0.6924483180046082, valid loss 0.4183240515344283
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.7037497162818909, train loss 0.4074235848690334
valid dice score : 0.7234495878219604, valid loss 0.3776395461138557
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.7597942352294922, train loss 0.33451917257748154
valid dice score : 0.7461945414543152, valid loss 0.3447716095868279
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.8091416358947754, train loss 0.268326498568058
valid dice score : 0.7854247093200684, valid loss 0.28083915482549104
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.8552007079124451, train loss 0.19400063745285334
valid dice score : 0.8507441878318787, valid loss 0.2004926301100675
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.8954328298568726, train loss 0.14352250667779068
valid dice score : 0.8731256723403931, valid loss 0.17785562048940098
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9173789024353027, train loss 0.11614302722247023
valid dice score : 0.8924328684806824, valid loss 0.15514756476177888
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.928686261177063, train loss 0.10412088154177916
valid dice score : 0.9115232825279236, valid loss 0.13300087215269313
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9389655590057373, train loss 0.09001758116248407
valid dice score : 0.9214329719543457, valid loss 0.11932078617460587
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9424279928207397, train loss 0.08776976158352275
valid dice score : 0.9217832088470459, valid loss 0.12677757967920864


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9474355578422546, train loss 0.07880792109981964
valid dice score : 0.9377257227897644, valid loss 0.1053988116190714
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9516950845718384, train loss 0.0758836267417983
valid dice score : 0.9344229102134705, valid loss 0.107496108421508


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9509716033935547, train loss 0.0774350855499506
valid dice score : 0.9332499504089355, valid loss 0.11454126633265439


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9574984312057495, train loss 0.06486755511478375
valid dice score : 0.9434549808502197, valid loss 0.10391707915593595
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9618700742721558, train loss 0.05834609348522989
valid dice score : 0.9424503445625305, valid loss 0.11383200283436214


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9643775820732117, train loss 0.05539999867936498
valid dice score : 0.9372639656066895, valid loss 0.1289921063272392


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9658106565475464, train loss 0.05208976615808512
valid dice score : 0.9447739720344543, valid loss 0.12049348082612543


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9668692350387573, train loss 0.055333168785038744
valid dice score : 0.9514896273612976, valid loss 0.09946816611815901
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9676990509033203, train loss 0.055552387659094836
valid dice score : 0.9505581855773926, valid loss 0.09806812718948897
saving model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9653061032295227, train loss 0.057587061175390294
valid dice score : 0.9512351751327515, valid loss 0.13590597186018438


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9652447700500488, train loss 0.0628669106548554
valid dice score : 0.9551681280136108, valid loss 0.10343209117212716


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9590009450912476, train loss 0.0707056561092797
valid dice score : 0.952050507068634, valid loss 0.12084642195088022


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9656809568405151, train loss 0.05603289383610612
valid dice score : 0.9524840116500854, valid loss 0.10800084427875631


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.9710423350334167, train loss 0.0463525052917631
valid dice score : 0.9585516452789307, valid loss 0.10401084767106701
Training terminated, no improvement in valid loss
started fitting the model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.4690302312374115, train loss 0.8749170493572316
valid dice score : 0.5111258029937744, valid loss 0.709122896194458
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.5642709136009216, train loss 0.6331067858858311
valid dice score : 0.6073722839355469, valid loss 0.5285227298736572
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.6782964468002319, train loss 0.42757429467870833
valid dice score : 0.7235345244407654, valid loss 0.36185409501194954
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.7959930896759033, train loss 0.26577476269387185
valid dice score : 0.8163834810256958, valid loss 0.2848240099847317
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.8598920106887817, train loss 0.18892411792531927
valid dice score : 0.8632700443267822, valid loss 0.19213929027318954
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.8929295539855957, train loss 0.14397193237822106
valid dice score : 0.8784260749816895, valid loss 0.1938208406791091


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9105132818222046, train loss 0.12426597197005089
valid dice score : 0.9051560163497925, valid loss 0.1418027812615037
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9231297969818115, train loss 0.10856304523792673
valid dice score : 0.9217091798782349, valid loss 0.11756675411015749
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9336158037185669, train loss 0.0966097780681671
valid dice score : 0.9289774298667908, valid loss 0.10949499579146504
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9393783211708069, train loss 0.09014075265285816
valid dice score : 0.9277074337005615, valid loss 0.11271344684064388


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9497079849243164, train loss 0.07291150687539831
valid dice score : 0.9351742267608643, valid loss 0.11794983595609665


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9528099298477173, train loss 0.0709909108725
valid dice score : 0.9390182495117188, valid loss 0.11241443501785398


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9560083150863647, train loss 0.06657761866424947
valid dice score : 0.941484272480011, valid loss 0.09865548508241773
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.95854651927948, train loss 0.06349098817148108
valid dice score : 0.9380764961242676, valid loss 0.12496570264920592


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9639782309532166, train loss 0.05401297940060179
valid dice score : 0.9429079294204712, valid loss 0.09830989176407456
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9651933312416077, train loss 0.053686480493621624
valid dice score : 0.9463388323783875, valid loss 0.11235316703096032


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9675137996673584, train loss 0.05121598618620254
valid dice score : 0.9456503391265869, valid loss 0.10370945185422897


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9695298075675964, train loss 0.04705597083777823
valid dice score : 0.954294741153717, valid loss 0.08998747752048075
saving model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.9708106517791748, train loss 0.04642001757437878
valid dice score : 0.9380759596824646, valid loss 0.11518717603757977


  0%|          | 0/47 [00:00<?, ?it/s]

KeyboardInterrupt: 