#### Training

This notebook trains a U-Net segmentation model (built with `segmentation_models_pytorch`) to segment crypts in colon images.

In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
#import libraries
import os
import warnings
from typing import Callable, Optional, Sequence
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import utils
import cv2
import torch.optim as optim
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import OneCycleLR



### Dataset

In [3]:
class AIRS_dataset(Dataset):
    """Segmentation dataset whose samples are (image, mask) file paths listed in a CSV.

    Rows are read positionally: column 1 holds the image path and column 2 the mask
    path (column 0 is not used here). Both files are read with ``tifffile``.

    Args:
        data_csv_path: CSV listing the samples.
        indexes: Positional row indices to keep (anything ``DataFrame.iloc`` accepts).
            ``None`` (the default) keeps every row.
        valid: Accepted for API compatibility; not used by this class.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.
        target_transform: Stored but not used by this class.
        preprocessing: Optional callable with the same calling convention as
            ``transform``. Its outputs must be torch tensors, because
            ``__getitem__`` calls ``.float()`` / ``.type(...)`` on them.
    """

    def __init__(self, data_csv_path: str = "train_data.csv", indexes: Optional[Sequence[int]] = None,
                 valid: bool = False, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None, preprocessing: Optional[Callable] = None):
        self.data = pd.read_csv(data_csv_path)
        # `iloc[None, :]` is not a valid selection, so treat the default as "use all rows".
        self.indexed_data = self.data if indexes is None else self.data.iloc[indexes, :]
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        """Return ``(image, mask)``.

        ``image`` is float; ``mask`` is an int64 CPU tensor of class ids.
        """
        image = tifffile.imread(self.indexed_data.iloc[idx, 1])
        # Masks are scaled by 255 before augmentation and divided back afterwards.
        # NOTE(review): presumably the stored masks are {0, 1}; confirm against the patch generator.
        mask = tifffile.imread(self.indexed_data.iloc[idx, 2]).astype(float) * 255.

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        if self.preprocessing:
            preprocessed = self.preprocessing(image=image, mask=mask)
            image, mask = preprocessed['image'], preprocessed['mask']

        # Convert back to integer class ids (CrossEntropyLoss targets).
        # `.type(torch.LongTensor)` truncates toward zero and returns a CPU tensor.
        return image.float(), (mask / 255.0).type(torch.LongTensor)

    def __len__(self):
        """Number of selected rows."""
        return len(self.indexed_data)

### Config 

In [4]:
# Defining configurations
class Configuration:
    """Hyper-parameters and paths shared by the dataset, the trainer and the training cell.

    All settings are class-level constants read as ``cfg.<NAME>``. The ``training``
    function overwrites ``MODEL_NAME`` at run time.
    """

    # --- output ---
    MODEL_SAVEPATH = "models/"   # directory for checkpoints and per-epoch CSV logs
    MODEL_NAME = 0               # placeholder; training() replaces it with "_" + ARCHITECTURE
    ARCHITECTURE = "UNET"

    # --- model ---
    ENCODER = "efficientnet-b2"
    PRETRAINED_WEIGHTS = "imagenet"
    INPUT_CHANNELS = 3
    INPUT_SHAPE = (512, 512, 3)  # NOTE(review): not read anywhere in this notebook
    ACTIVATION = None            # NOTE(review): not passed to sm.Unet by Trainer
    CLASSES = 2                  # crypts = 1, background = 0

    # --- data / training loop ---
    BATCH_SIZE = 16
    NFOLDS = 5                   # NOTE(review): the K-fold loop is disabled; training() ignores n_folds
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    EPOCHS = 25
    PATIENCE = 5                 # epochs without validation-loss improvement before early stopping

    # --- losses / metrics (constructed in this order) ---
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss()
    LOSS_DICE = utils.DiceScore(loss=True)   # currently unused: the dice term is commented out in Trainer
    DICE_COEF = utils.DiceScore(loss=False)  # reported train/valid dice metric

    # --- optimization ---
    WEIGHT_DECAY = 1e-4
    LEARNING_RATE = 1e-3
    ONECYCLELR = False           # when True, Trainer builds a OneCycleLR schedule
    MAX_LR_FOR_ONECYCLELR = 1e-3

    # --- preprocessing (depends on ENCODER / PRETRAINED_WEIGHTS defined above) ---
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER, PRETRAINED_WEIGHTS)


cfg = Configuration()

### Trainer class

In [5]:
class Trainer:
    """Train a ``segmentation_models_pytorch`` U-Net with validation, logging and early stopping.

    The model is optimized with Ranger (``torch_optimizer``) on ``cfg.LOSS_CROSSENTROPY``.
    When ``cfg.ONECYCLELR`` is set, a ``OneCycleLR`` schedule is added and stepped once per batch.

    After every epoch:
    - train/valid loss and dice are appended to a CSV log in ``cfg.MODEL_SAVEPATH``;
    - the weights are saved whenever the validation loss improves.

    Training stops after ``cfg.PATIENCE`` (default 5) consecutive epochs without improvement.
    """

    def __init__(self, cfg: "Configuration", train_data_loader: DataLoader, valid_data_loader: DataLoader) -> None:
        self.cfg = cfg
        # Early-stopping budget: taken from the config when present, else the historical default of 5.
        self.max_patience = getattr(self.cfg, "PATIENCE", 5)
        self.patience = self.max_patience
        self.model = sm.Unet(encoder_name=self.cfg.ENCODER,
                             encoder_weights=self.cfg.PRETRAINED_WEIGHTS,
                             in_channels=self.cfg.INPUT_CHANNELS,
                             classes=self.cfg.CLASSES)
        self.device = self.cfg.DEVICE
        # Move once up front so calculate_metrics() also works when called outside fit().
        self.model.to(self.device)
        self.loss_function = self.cfg.LOSS_CROSSENTROPY
        self.lr = self.cfg.LEARNING_RATE
        self.batch_size = self.cfg.BATCH_SIZE
        self.train_dataloader = train_data_loader
        self.valid_dataloader = valid_data_loader
        self.epochs = self.cfg.EPOCHS
        self.track_best_valid = []
        self.val_for_early_stopping = float("inf")
        os.makedirs(self.cfg.MODEL_SAVEPATH, exist_ok=True)

        self.log = pd.DataFrame(columns=["model_name", "train_loss", "train_dice", "valid_loss", "valid_dice"])

        # Pass the configured learning rate explicitly (it was previously ignored).
        self.optimizer = t_optim.Ranger(self.model.parameters(), lr=self.lr, weight_decay=self.cfg.WEIGHT_DECAY)
        # The scheduler lives in its own attribute; it must not replace the optimizer.
        self.scheduler = None
        if self.cfg.ONECYCLELR:
            self.scheduler = OneCycleLR(self.optimizer,
                                        max_lr=self.cfg.MAX_LR_FOR_ONECYCLELR,
                                        steps_per_epoch=len(self.train_dataloader),
                                        epochs=self.epochs)

    def _checkpoint_name(self) -> str:
        """File name of the saved weights; also recorded in the log's ``model_name`` column."""
        return f"fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.pth"

    def _checkpoint_path(self) -> str:
        """Full path of the saved weights inside ``cfg.MODEL_SAVEPATH``."""
        return os.path.join(self.cfg.MODEL_SAVEPATH, self._checkpoint_name())

    def _log_path(self) -> str:
        """Full path of the per-epoch CSV log (file name kept as before, including the double underscore)."""
        return os.path.join(
            self.cfg.MODEL_SAVEPATH,
            f"fold_{self.cfg.MODEL_NAME}__{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.csv")

    def calculate_metrics(self, data_loader: DataLoader):
        """Evaluate the model on ``data_loader`` without gradient tracking.

        Returns:
            ``(mean_dice, mean_loss)``: per-batch averages of ``cfg.DICE_COEF`` and the loss.
        """
        self.model.eval()
        total_loss = 0
        total_dice = 0  # sum of per-batch dice scores
        with torch.no_grad():
            for images, masks in tqdm(data_loader, total=len(data_loader)):
                images = images.to(self.device)
                masks = masks.to(self.device)
                logits = self.model(images)
                total_loss += self.loss_function(logits, masks).item()
                total_dice += self.cfg.DICE_COEF(logits.cpu(), masks.cpu())
        return total_dice / len(data_loader), total_loss / len(data_loader)

    def earlystopping(self, val_loss):
        """Return True (and record ``val_loss``) if it beats the best validation loss so far.

        Otherwise decrement ``self.patience`` and return False.
        """
        if val_loss < self.val_for_early_stopping:
            self.val_for_early_stopping = val_loss
            return True
        self.patience -= 1
        return False

    def _train_one_epoch(self):
        """Run one optimization pass over the training loader.

        Returns:
            ``(mean_dice, mean_loss)`` averaged over batches.
        """
        self.model.train()
        running_dice = 0
        running_loss = 0
        for images, masks in tqdm(self.train_dataloader, total=len(self.train_dataloader)):
            images = images.to(self.device)
            masks = masks.to(self.device)
            self.optimizer.zero_grad()
            logits = self.model(images)
            loss = self.loss_function(logits, masks)
            loss.backward()
            self.optimizer.step()
            if self.scheduler is not None:
                # OneCycleLR is sized per batch (steps_per_epoch * epochs), so step every batch.
                self.scheduler.step()
            running_loss += loss.item()
            running_dice += self.cfg.DICE_COEF(logits.detach().cpu(), masks.cpu())
        n_batches = len(self.train_dataloader)
        return running_dice / n_batches, running_loss / n_batches

    def fit(self) -> None:
        """Train for up to ``self.epochs`` epochs, logging each epoch and checkpointing on improvement."""
        print("started fitting the model")
        self.model.to(self.device)

        for epoch in range(self.epochs):
            train_dice, train_loss = self._train_one_epoch()
            dice_score_valid, loss_valid = self.calculate_metrics(self.valid_dataloader)
            print(f"train dice score : {train_dice}, train loss {train_loss}")
            print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")

            # Log the name of the checkpoint file actually written below.
            self.log.loc[epoch, :] = [self._checkpoint_name(), f"{train_loss}", f"{train_dice}",
                                      f"{loss_valid}", f"{dice_score_valid}"]
            self.log.to_csv(self._log_path(), index=False)

            if self.patience >= 0 and self.earlystopping(loss_valid):
                print("saving model")
                torch.save(self.model.state_dict(), self._checkpoint_path())
                self.patience = self.max_patience

            if self.patience <= 0:
                print("Training terminated, no improvement in valid loss")
                break
                
        
        
        

        

#### Model Training

In [6]:
train_data_csv_path = "train_patches_from_40_random_sample.csv"
valid_data_csv_path = "valid_patches_from_40_random_sample.csv"
patches_csv_path = "train_patches_from_40_random_sample.csv"


def training(train_data_csv_path: str, patches_csv_path: str, n_folds: int,
             valid_csv_path: Optional[str] = None) -> pd.DataFrame:
    """Train one model on a fixed train/valid split and write a one-row report CSV.

    Args:
        train_data_csv_path: CSV listing the training patches; every row is used.
        patches_csv_path: Not used (the K-fold code that needed it is disabled);
            kept so existing calls keep working.
        n_folds: Not used for the same reason; kept for compatibility.
        valid_csv_path: CSV listing the validation patches. Defaults to the
            notebook-level ``valid_data_csv_path`` when omitted.

    Returns:
        DataFrame with columns ``["model name", "valid_loss"]`` holding the model name
        and the best validation loss seen during training. The same table is
        written to ``report_<MODEL_NAME>.csv``.
    """
    if valid_csv_path is None:
        valid_csv_path = valid_data_csv_path

    train_ids = pd.read_csv(train_data_csv_path).index
    valid_ids = pd.read_csv(valid_csv_path).index

    train_dataset = AIRS_dataset(train_data_csv_path, indexes=train_ids,
                                 transform=utils.get_train_transforms(),
                                 preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))
    valid_dataset = AIRS_dataset(valid_csv_path, indexes=valid_ids,
                                 preprocessing=utils.preprocessing_fucntion(cfg.PREPROCESS))

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)

    cfg.MODEL_NAME = "_" + cfg.ARCHITECTURE
    trainer = Trainer(cfg, train_dataloader, valid_dataloader)
    trainer.fit()

    # Report the best validation loss tracked by the trainer's early stopping.
    report = pd.DataFrame([[cfg.MODEL_NAME, trainer.val_for_early_stopping]],
                          columns=["model name", "valid_loss"])
    report.to_csv("report_" + cfg.MODEL_NAME + ".csv", index=False)
    return report
        



    


In [7]:
# Train on the fixed train/valid split; `n_folds` is accepted but not used by `training`.
training(
    train_data_csv_path=train_data_csv_path,
    patches_csv_path=patches_csv_path,
    n_folds=cfg.NFOLDS,
)

started fitting the model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.7978237271308899, train loss 0.3140485172067181
valid dice score : 0.9462895393371582, valid loss 0.0741326682994851
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.924212634563446, train loss 0.12814334505048625
valid dice score : 0.9721938371658325, valid loss 0.04158883154458177
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9397636651992798, train loss 0.10518355355546483
valid dice score : 0.9749948978424072, valid loss 0.037271681198421704
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9504663944244385, train loss 0.0886891063618924
valid dice score : 0.9814695119857788, valid loss 0.031196880388172525
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9520512819290161, train loss 0.08533431788927075
valid dice score : 0.9842981100082397, valid loss 0.028515982602319096
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9568033218383789, train loss 0.0760528886367694
valid dice score : 0.9803197979927063, valid loss 0.03245662077560823


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9594742059707642, train loss 0.07366119036485348
valid dice score : 0.9848818778991699, valid loss 0.02746455263154812
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.959334135055542, train loss 0.07390707478312109
valid dice score : 0.9868287444114685, valid loss 0.023336154603207714
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9619971513748169, train loss 0.06888800659440321
valid dice score : 0.9863169193267822, valid loss 0.02520659488053535


  0%|          | 0/271 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

train dice score : 0.9632487297058105, train loss 0.0657330662067086
valid dice score : 0.9867939949035645, valid loss 0.02314851393529618
saving model


  0%|          | 0/271 [00:00<?, ?it/s]

KeyboardInterrupt: 