#### Training

This notebook contains code to train the model for crypt segmentation in colon images. We use segmentation models to train 

In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import utils
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.optim.lr_scheduler import OneCycleLR



In [3]:
class Colon_Dataset(Dataset):
    def __init__(self,data_csv_path:str="train_data.csv",indexes:list= None,valid:bool=False,transform:transforms = None, target_transform:transforms=None,preprocessing=None):
        self.data = pd.read_csv(data_csv_path)
        self.indexed_data = self.data.iloc[indexes,:]
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing
        
    def __getitem__(self,idx):
        
        image = tifffile.imread(self.indexed_data.iloc[idx,0])
        mask = tifffile.imread(self.indexed_data.iloc[idx,1]).astype(float)
        
        
        if self.transform:
            augmentations = self.transform(image=image,mask=mask)
            image,mask = augmentations['image'],augmentations['mask']
            
            
        if self.preprocessing:
            preprocessed = self.preprocessing(image=image,mask=mask)
            image,mask = preprocessed['image'],preprocessed['mask']
    
        return image.float(),(mask[:,:,0]/255.0).type(torch.LongTensor)
    
    def __len__(self):
        return len(self.indexed_data)        

In [4]:
train_transform =  A.Compose([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                         border_mode=cv2.BORDER_REFLECT),
        A.OneOf([
            A.ElasticTransform(p=.3),
            A.GaussianBlur(p=.3),
            A.GaussNoise(p=.3),
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=.1),
            A.PiecewiseAffine(p=0.3),
        ], p=0.3),
        A.OneOf([
            A.HueSaturationValue(15,25,0),
            A.CLAHE(clip_limit=2),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.3),
    
        
    ])

validation_transform = A.Compose([ToTensorV2()])

def preprocessing_fucntion(preprocesing_function=None):
    return A.Compose([A.Lambda(image=preprocesing_function),ToTensorV2()])
    

In [5]:
#Defining configurations
class Configuration:
    MODEL_SAVEPATH = "models/"
    ENCODER = "efficientnet-b2"
    PRETRAINED_WEIGHTS = "imagenet"
    BATCH_SIZE = 16
    INPUT_CHANNELS = 3
    INPUT_SHAPE = (512,512,3)
    NFOLDS = 5
    ACTIVATION = None
    CLASSES = 2 #(crypts 1 background 0)
    DEVICE = "cuda" if torch.cuda.is_available() else cpu
    EPOCHS = 10
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss() 
    LOSS_DICE = utils.DiceScore()
    DICE_COEF = utils.DiceLoss()
    WEIGHT_DECAY = 1e-4
    LEARNING_RATE = 1e-3
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER,PRETRAINED_WEIGHTS)
    ONECYCLELR = False
    MODEL_NAME = 0
    MAX_LR_FOR_ONECYCLELR = 1e-3
       
cfg = Configuration()  

In [6]:
#init kfold

In [7]:
Train_ids = pd.read_csv("Colonic_crypt_dataset/train.csv").iloc[0:-1,:]['id'].values

In [8]:
kfold = KFold(cfg.NFOLDS, shuffle=True, random_state=0)


In [9]:
df = pd.read_csv("train_data.csv")


In [10]:
class Trainer:
    def __init__(self,cfg:Configuration,train_data_loader:DataLoader,valid_data_loader:DataLoader)->None:
        self.cfg = cfg
        self.model = sm.Unet(encoder_name=self.cfg.ENCODER, 
                     encoder_weights=self.cfg.PRETRAINED_WEIGHTS, 
                     in_channels=self.cfg.INPUT_CHANNELS, 
                     classes=self.cfg.CLASSES)
        self.loss_function = self.cfg.LOSS_CROSSENTROPY
        self.lr = self.cfg.LEARNING_RATE
        self.batch_size = self.cfg.BATCH_SIZE
        self.train_dataloader = train_data_loader
        self.valid_dataloader = valid_data_loader
        self.device = self.cfg.DEVICE
        self.epochs = self.cfg.EPOCHS
        self.lr = 1e-3
        self.log = pd.DataFrame(columns=["model_name","train_loss","train_dice","valid_loss","valid_dice"])
        
        
        
        self.optimizer = t_optim.Ranger(self.model.parameters(),weight_decay=self.cfg.WEIGHT_DECAY)
        if self.cfg.ONECYCLELR:
            self.optimizer = OneCycleLR(self.optimizer, max_lr=self.cfg.MAX_LR_FOR_ONECYCLELR, steps_per_epoch=len(self.train_dataloader), epochs=self.EPOCHS)
        
   
    
    def calculate_metrics(self,data_loader:DataLoader):
        self.model.eval()
        total_loss = 0
        total_dice = 0 #batch wise dice loss
        with torch.no_grad():
            for data in tqdm(data_loader,total=len(data_loader)):
                im = data[0].to(self.device)
                mask = data[1].to(self.device)
                out = self.model(im)
                loss = self.loss_function(out.data,mask) #+ cfg.LOSS_DICE(out,mask)
                total_loss+=loss.item()
                total_dice+= self.cfg.DICE_COEF(out.data,mask)
        return total_dice/len(data_loader),total_loss/len(data_loader)


        
    def fit(self)->None:
        print("started fitting the model")
        best_loss = 9999999
        for epoch in range(self.epochs):
            self.model.train()
            self.model.to(self.device)
            
            dice_score_ = 0
            loss_ = 0
            
            for j,data in enumerate(tqdm(self.train_dataloader,total = len(self.train_dataloader))):
                input_image_batch = data[0].to(self.device)
                mask_batch = data[1].to(self.device)
                self.optimizer.zero_grad()
                output = self.model(input_image_batch)
                loss = self.loss_function(output,mask_batch)#+cfg.LOSS_DICE(output,mask_batch.unsqueeze(1))
                loss.backward()
                self.optimizer.step()
                loss_+=loss.item()
                dice_score_+= self.cfg.DiceScore(output.data,mask_batch)

            
            dice_score_valid,loss_valid, = self.calculate_metrics(self.valid_dataloader)
            train_dice = dice_score_/len(self.train_dataloader)
            train_loss = loss_/len(self.train_dataloader)
            print(f"train dice score : {train_dice}, train loss {train_loss}")
            print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")
            
            self.log.loc[epoch,:] = [f"fold_{i}_{self.cfg.ENCODER}.pth",f"{train_loss}",f"{train_dice}",f"{loss_valid}",f"{dice_score_valid}"]
            self.log.to_csv(self.cfg.MODEL_SAVEPATH+f"/fold_{i}_{cfg.ENCODER}_CE_Valid_slicing_all.csv",index=False)

            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(self.model.state_dict(),self.cfg.MODEL_SAVEPATH+f"/fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}_CE_Valid_slicing_all.pth")
        
        
        

        

In [11]:
for i, (train_idx, val_idx) in enumerate(kfold.split(Train_ids)):
    train_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[train_idx]))]).index
    valid_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[val_idx]))]).index
    train_dataset = Colon_Dataset("train_data.csv",indexes=train_ids,transform=train_transform,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
    valid_dataset = Colon_Dataset("train_data.csv",indexes=valid_ids,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
    
    train_dataloader = DataLoader(train_dataset,batch_size=cfg.BATCH_SIZE,shuffle=True,num_workers=8)
    valid_dataloader = DataLoader(valid_dataset,batch_size=cfg.BATCH_SIZE,shuffle=False,num_workers=8)
    cfg.MODEL_NAME = str(i)
    trainer =Trainer(cfg,train_dataloader,valid_dataloader)
    trainer.fit()

started fitting the model


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.14549367129802704, train loss 0.70356880730771
valid dice score : 0.15102466940879822, valid loss 0.6046241745352745


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.1555912047624588, train loss 0.5069651863676436
valid dice score : 0.13499605655670166, valid loss 0.49028777331113815


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.162323459982872, train loss 0.3650266496424979
valid dice score : 0.16826657950878143, valid loss 0.3502356596291065


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.2991184890270233, train loss 0.24406688042143557
valid dice score : 0.4026572108268738, valid loss 0.2617147322744131


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.45884042978286743, train loss 0.1622293450413866
valid dice score : 0.45979875326156616, valid loss 0.18297969829291105


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.5297406911849976, train loss 0.13078120747145186
valid dice score : 0.5077509880065918, valid loss 0.17555162589997053


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.6000689268112183, train loss 0.10682061536515013
valid dice score : 0.5431884527206421, valid loss 0.15840877080336213


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.6243547797203064, train loss 0.09650277552452494
valid dice score : 0.5526847839355469, valid loss 0.13707481743767858


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.649639904499054, train loss 0.0887027752209217
valid dice score : 0.5567726492881775, valid loss 0.13557483535259962


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train dice score : 0.6986401677131653, train loss 0.07517618257948692
valid dice score : 0.5633401870727539, valid loss 0.13439303822815418
started fitting the model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.15324977040290833, train loss 0.8544206085957979
valid dice score : 0.13735713064670563, valid loss 0.681135892868042


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.16521991789340973, train loss 0.653639576937023
valid dice score : 0.14650863409042358, valid loss 0.5488022706087898


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.16448916494846344, train loss 0.4737762777428878
valid dice score : 0.14905805885791779, valid loss 0.45816096488167257


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.18631479144096375, train loss 0.35881907845798294
valid dice score : 0.18868623673915863, valid loss 0.32867976321893577


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.2802342474460602, train loss 0.2629535409965013
valid dice score : 0.36661335825920105, valid loss 0.20789140199913697


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.4229559600353241, train loss 0.19155175121206985
valid dice score : 0.4125586152076721, valid loss 0.16200096729923696


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.5042730569839478, train loss 0.15347032974425115
valid dice score : 0.48493537306785583, valid loss 0.12307130282416064


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.5407430529594421, train loss 0.13195736510188957
valid dice score : 0.5336718559265137, valid loss 0.10778799214783837


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.5959455370903015, train loss 0.11315887107660896
valid dice score : 0.5672350525856018, valid loss 0.09841509163379669


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.6151795983314514, train loss 0.10684863654406447
valid dice score : 0.6183412671089172, valid loss 0.09275092820034307
started fitting the model


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.134597510099411, train loss 0.5266858135399065
valid dice score : 0.1393374353647232, valid loss 0.5607414561159471


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.17195294797420502, train loss 0.3943428561875695
valid dice score : 0.13538770377635956, valid loss 0.45143697542302746


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.17245040833950043, train loss 0.3069569288115752
valid dice score : 0.1774313598871231, valid loss 0.3543376221376307


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.2984291613101959, train loss 0.21653759989299273
valid dice score : 0.3688996732234955, valid loss 0.23774818199522355


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.43144094944000244, train loss 0.15550456333317256
valid dice score : 0.4702044427394867, valid loss 0.1635129723478766


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.4933711588382721, train loss 0.13464617454691938
valid dice score : 0.5087540149688721, valid loss 0.14304827372817433


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.5690075755119324, train loss 0.11373356965027358
valid dice score : 0.5691515803337097, valid loss 0.1253201698555666


  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

train dice score : 0.6395724415779114, train loss 0.09855852040805314
valid dice score : 0.6070603132247925, valid loss 0.12679636631818378


  0%|          | 0/38 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
# def calculate_metrics(model,data_loader):
#     model.eval()
#     total_loss = 0
#     total_dice = 0 #batch wise dice loss
#     with torch.no_grad():
#         for data in tqdm.tqdm(data_loader,total=len(data_loader)):
#             im = data[0].cuda()
#             mask = data[1].cuda()
#             out = model(im)
#             loss = cfg.LOSS_CROSSENTROPY(out,mask) #+ cfg.LOSS_DICE(out,mask)
#             total_loss+=loss.item()
#             total_dice+= 1-cfg.LOSS_DICE(out,mask)
            
            
#     return total_dice/len(data_loader),total_loss/len(data_loader)

In [18]:
# for i, (train_idx, val_idx) in enumerate(kfold.split(Train_ids)):
#     # if i == 0 or i == 1:
#     #     continue
#     log = pd.DataFrame(columns=["model_name","train_loss","train_dice","valid_loss","valid_dice"])
#     train_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[train_idx]))]).index
#     valid_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[val_idx]))]).index
#     train_dataset = Colon_Dataset("train_data.csv",indexes=train_ids,transform=train_transform,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
#     valid_dataset = Colon_Dataset("train_data.csv",indexes=valid_ids,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
    
#     train_dataloader = DataLoader(train_dataset,batch_size=cfg.BATCH_SIZE,shuffle=True,num_workers=8)
#     valid_dataloader = DataLoader(valid_dataset,batch_size=cfg.BATCH_SIZE,shuffle=False,num_workers=8)
    
#     model = sm.Unet(encoder_name=cfg.ENCODER, 
#                      encoder_weights=cfg.PRETRAINED_WEIGHTS, 
#                      in_channels=cfg.INPUT_CHANNELS, 
#                      classes=cfg.CLASSES)
    
#     # model = nn.DataParallel(model)
#     model.cuda()
#     optimizer = t_optim.Ranger(model.parameters(),weight_decay=cfg.WEIGHT_DECAY)#optim.Adam(model.parameters())
#     best_loss = 999999
#     total_train_loss = 0
    
#     for epoch in range(cfg.EPOCHS):
#         dice_score_ = 0
#         loss_ = 0
#         for j,data in enumerate(tqdm(train_dataloader,total = len(train_dataloader))):
#             input_image_batch = data[0].cuda()
#             mask_batch = data[1].cuda()
#             optimizer.zero_grad()
#             output = model(input_image_batch)
#             loss = cfg.LOSS_CROSSENTROPY(output,mask_batch)#+cfg.LOSS_DICE(output,mask_batch.unsqueeze(1))
#             loss.backward()
#             optimizer.step()
#             loss_+=loss.item()
#             dice_score_+= 1-cfg.LOSS_DICE(output.detach().cpu(),mask_batch.detach().cpu())#1-utils.DiceLoss()(output.detach(),mask_batch.unsqueeze(1))
            
#     # 1-cfg.LOSS_DICE(output,mask_batch)#
#         dice_score_train,loss_train, = calculate_metrics(model,train_dataloader)
#         dice_score_valid,loss_valid, = calculate_metrics(model,valid_dataloader)
#         print(f"train dice score : {dice_score_train}, train loss {loss_train}")
#         print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")
#         print(f"loop train dice : {dice_score_/len(train_dataloader)}, train loss {loss_/len(train_dataloader)}")
#         log.loc[epoch,:] = [f"fold_{i}_{cfg.ENCODER}.pth",f"{loss_train}",f"{dice_score_train}",f"{loss_valid}",f"{dice_score_valid}"]
#         log.to_csv(cfg.MODEL_SAVEPATH+f"/fold_{i}_{cfg.ENCODER}_CE_Valid_slicing_all.csv",index=False)
            
#         if loss_train < best_loss:
#             best_loss = loss_train
#             torch.save(model.state_dict(),cfg.MODEL_SAVEPATH+f"/fold_{i}_{cfg.ENCODER}_CE_Valid_slicing_all.pth")
    
    
    
    

                
        
            
            
    
    
    

  0%|          | 0/47 [00:00<?, ?it/s]

Exception ignored in: <generator object tqdm.__iter__ at 0x2b2a2d6fb200>
Traceback (most recent call last):
  File "/N/project/DL_MRI/Myocarditis-segmentation/vm_sripad/lib/python3.9/site-packages/tqdm/std.py", line 1181, in __iter__
    yield obj
KeyboardInterrupt: 


KeyboardInterrupt: 