#### Training

This notebook contains code to train the model for crypt segmentation in colon images. We use segmentation models to train 

In [1]:
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import utils
import torch.optim as optim
import tqdm as tqdm
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping



In [2]:
class Colon_Dataset(Dataset):
    def __init__(self,data_csv_path:str="train_data.csv",indexes:list= None,valid:bool=False,transform:transforms = None, target_transform:transforms=None,preprocessing=None):
        self.data = pd.read_csv(data_csv_path)
        self.indexed_data = self.data.iloc[indexes,:]
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing
        
    def __getitem__(self,idx):
        
        image = tifffile.imread(self.indexed_data.iloc[idx,0])
        mask = tifffile.imread(self.indexed_data.iloc[idx,1]).astype(float)
        
        
        
        
        if self.transform:
            augmentations = self.transform(image=image,mask=mask)
            image,mask = augmentations['image'],augmentations['mask']
            
            
        if self.preprocessing:
            preprocessed = self.preprocessing(image=image,mask=mask)
            image,mask = preprocessed['image'],preprocessed['mask']
        
            
#         backgroud_mask = (mask==0)[:,:,0]
#         foreground_mask = (mask==255)[:,:,0]
        
#         mask = torch.stack([backgroud_mask,foreground_mask])
    
        return image.float(),(mask[:,:,0]/255.0).type(torch.LongTensor)
    
    
    def __len__(self):
        return len(self.indexed_data)        

In [3]:
train_transform =  A.Compose([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                         border_mode=cv2.BORDER_REFLECT),
        A.OneOf([
            A.ElasticTransform(p=.3),
            A.GaussianBlur(p=.3),
            A.GaussNoise(p=.3),
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=.1),
            A.PiecewiseAffine(p=0.3),
        ], p=0.3),
        A.OneOf([
            A.HueSaturationValue(15,25,0),
            A.CLAHE(clip_limit=2),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.3),
    
        
    ])

validation_transform = A.Compose([ToTensorV2()])

def preprocessing_fucntion(preprocesing_function=None):
    return A.Compose([A.Lambda(image=preprocesing_function),ToTensorV2()])
    

In [4]:
#Defining configurations
class Configuration:
    MODEL_SAVEPATH = "models/"
    ENCODER = "efficientnet-b2"#"se_resnext50_32x4d"
    PRETRAINED_WEIGHTS = "imagenet"
    BATCH_SIZE = 16
    INPUT_CHANNELS = 3

    INPUT_SHAPE = (512,512,3)
    NFOLDS = 5
    ACTIVATION = None
    CLASSES = 2 #(crypts 1 background 0)
    DEVICE = "cuda" if torch.cuda.is_available() else cpu
    EPOCHS = 15
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss() 
    LOSS_DICE = utils.DiceLoss()
    WEIGHT_DECAY = 1e-4
    METRICS = [sm.utils.metrics.IoU(0.5),
sm.utils.metrics.Fscore()]
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER,PRETRAINED_WEIGHTS)
    
    
    
    
    
cfg = Configuration()  

In [5]:
#init kfold

In [6]:
Train_ids = pd.read_csv("Colonic_crypt_dataset/train.csv").iloc[0:-1,:]['id'].values

In [7]:
kfold = KFold(cfg.NFOLDS, shuffle=True, random_state=0)


In [8]:
df = pd.read_csv("train_data.csv")


In [9]:
def calculate_metrics(model,data_loader):
    model.eval()
    total_loss = 0
    total_dice = 0 #batch wise dice loss
    with torch.no_grad():
        for data in tqdm.tqdm(data_loader,total=len(data_loader)):
            im = data[0].cuda()
            mask = data[1].cuda()
            out = model(im)
            loss = cfg.LOSS_CROSSENTROPY(out,mask) #+ cfg.LOSS_DICE(out,mask.unsqueeze(1))
            total_loss+=loss.item()
            total_dice+= 1-utils.DiceLoss()(out,mask.unsqueeze(1))
            
    return total_dice/len(data_loader),total_loss/len(data_loader)

In [None]:
logs = []
for i, (train_idx, val_idx) in enumerate(kfold.split(Train_ids)):
    log = pd.DataFrame(columns=["model_name","train_loss","train_dice","valid_loss","valid_dice"])
    train_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[train_idx]))]).index
    valid_ids = (df[df.Train_image_path.str.contains("|".join(Train_ids[val_idx]))]).index
    train_dataset = Colon_Dataset("train_data.csv",indexes=train_ids,transform=train_transform,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
    valid_dataset = Colon_Dataset("train_data.csv",indexes=valid_ids,preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
    
    train_dataloader = DataLoader(train_dataset,batch_size=cfg.BATCH_SIZE,shuffle=True)
    valid_dataloader = DataLoader(valid_dataset,batch_size=cfg.BATCH_SIZE,shuffle=False)
    model = sm.Unet(encoder_name=cfg.ENCODER, 
                     encoder_weights=cfg.PRETRAINED_WEIGHTS, 
                     in_channels=cfg.INPUT_CHANNELS, 
                     classes=cfg.CLASSES)
    model.cuda()
    optimizer = t_optim.Ranger(model.parameters(),)#optim.Adam(model.parameters())
    best_loss = 99
    total_train_loss = 0
    
    for epoch in range(cfg.EPOCHS):
        for j,data in enumerate(tqdm.tqdm(train_dataloader,total = len(train_dataloader))):
            input_image_batch = data[0].cuda()
            mask_batch = data[1].cuda()
            optimizer.zero_grad()
            output = model(input_image_batch)
            loss = cfg.LOSS_CROSSENTROPY(output,mask_batch)#+cfg.LOSS_DICE(output,mask_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
    
        dice_score_train,loss_train = calculate_metrics(model,train_dataloader)
        dice_score_valid,loss_valid = calculate_metrics(model,valid_dataloader)
        print(f"train dice score : {dice_score_train}, train loss {loss_train}")
        print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")
        log.loc[epoch,:] = [f"fold_{i}_{cfg.ENCODER}.pth",f"{loss_train}",f"{dice_score_train}",f"{loss_valid}",f"{dice_score_valid}"]
            
        if loss_train < best_loss:
            best_loss = loss_train
            torch.save(model.state_dict(),cfg.MODEL_SAVEPATH+f"/fold_{i}_{cfg.ENCODER}_CE_Valid_slicing_all.pth")
    logs.append(log)
    
    break
    

                
        
            
            
    
    
    

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
100%|███████████████████████████████████████████| 47/47 [01:40<00:00,  2.14s/it]
100%|███████████████████████████████████████████| 47/47 [01:02<00:00,  1.33s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.74it/s]


train dice score : 0.3657574951648712, train loss 0.9188516203393328
valid dice score : 0.3642618954181671, valid loss 0.9838464334607124


100%|███████████████████████████████████████████| 47/47 [01:30<00:00,  1.93s/it]
100%|███████████████████████████████████████████| 47/47 [01:04<00:00,  1.38s/it]
100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.02it/s]


train dice score : 0.5251930356025696, train loss 0.2702587628935246
valid dice score : 0.5154837965965271, valid loss 0.3009013459086418


100%|███████████████████████████████████████████| 47/47 [01:23<00:00,  1.78s/it]
100%|███████████████████████████████████████████| 47/47 [01:00<00:00,  1.29s/it]
100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.04it/s]


train dice score : 0.730129599571228, train loss 0.12622607856037768
valid dice score : 0.6994335651397705, valid loss 0.16591704543679953


100%|███████████████████████████████████████████| 47/47 [01:19<00:00,  1.69s/it]
100%|███████████████████████████████████████████| 47/47 [01:20<00:00,  1.71s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.85it/s]


train dice score : 0.8035935759544373, train loss 0.09376924374002091
valid dice score : 0.7642971277236938, valid loss 0.13450106605887413


100%|███████████████████████████████████████████| 47/47 [01:34<00:00,  2.01s/it]
100%|███████████████████████████████████████████| 47/47 [02:02<00:00,  2.61s/it]
100%|█████████████████████████████████████████████| 8/8 [00:06<00:00,  1.16it/s]


train dice score : 0.8238102197647095, train loss 0.08140059956844817
valid dice score : 0.7768575549125671, valid loss 0.12819277867674828


100%|███████████████████████████████████████████| 47/47 [02:33<00:00,  3.27s/it]
100%|███████████████████████████████████████████| 47/47 [01:34<00:00,  2.01s/it]
100%|█████████████████████████████████████████████| 8/8 [00:06<00:00,  1.25it/s]


train dice score : 0.8154136538505554, train loss 0.076536071110279
valid dice score : 0.7641127109527588, valid loss 0.12386440811678767


100%|███████████████████████████████████████████| 47/47 [02:31<00:00,  3.22s/it]
100%|███████████████████████████████████████████| 47/47 [01:30<00:00,  1.93s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.95it/s]


train dice score : 0.8724920153617859, train loss 0.06398663281443272
valid dice score : 0.8159602284431458, valid loss 0.13155545736663043


100%|███████████████████████████████████████████| 47/47 [01:45<00:00,  2.24s/it]
100%|███████████████████████████████████████████| 47/47 [01:49<00:00,  2.32s/it]
100%|█████████████████████████████████████████████| 8/8 [00:06<00:00,  1.30it/s]


train dice score : 0.8766228556632996, train loss 0.062284164586441315
valid dice score : 0.8182024359703064, valid loss 0.13004957768134773


100%|███████████████████████████████████████████| 47/47 [02:16<00:00,  2.91s/it]
100%|███████████████████████████████████████████| 47/47 [01:16<00:00,  1.62s/it]
100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.10it/s]


train dice score : 0.8863102793693542, train loss 0.057779832842185144
valid dice score : 0.8131837844848633, valid loss 0.16279348800890148


100%|███████████████████████████████████████████| 47/47 [01:21<00:00,  1.74s/it]
100%|███████████████████████████████████████████| 47/47 [01:00<00:00,  1.29s/it]
100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.01it/s]


train dice score : 0.8557848334312439, train loss 0.06115677936914119
valid dice score : 0.7727298736572266, valid loss 0.13604851393029094


100%|███████████████████████████████████████████| 47/47 [01:28<00:00,  1.88s/it]
100%|███████████████████████████████████████████| 47/47 [01:02<00:00,  1.33s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.96it/s]


train dice score : 0.9084177017211914, train loss 0.04499176838138002
valid dice score : 0.8242107629776001, valid loss 0.13523804675787687


100%|███████████████████████████████████████████| 47/47 [02:06<00:00,  2.68s/it]
100%|███████████████████████████████████████████| 47/47 [01:53<00:00,  2.43s/it]
100%|█████████████████████████████████████████████| 8/8 [00:06<00:00,  1.27it/s]


train dice score : 0.8908453583717346, train loss 0.04759816486546968
valid dice score : 0.8194962739944458, valid loss 0.13376343436539173


100%|███████████████████████████████████████████| 47/47 [01:38<00:00,  2.09s/it]
100%|███████████████████████████████████████████| 47/47 [01:02<00:00,  1.34s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.89it/s]


train dice score : 0.8844990730285645, train loss 0.05663048834639027
valid dice score : 0.8064718842506409, valid loss 0.13242571288719773


100%|███████████████████████████████████████████| 47/47 [01:40<00:00,  2.14s/it]
100%|███████████████████████████████████████████| 47/47 [01:05<00:00,  1.40s/it]
100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.80it/s]


train dice score : 0.8913809061050415, train loss 0.06653701371334969
valid dice score : 0.8006278872489929, valid loss 0.2557852154131979


 34%|██████████████▋                            | 16/47 [00:33<01:01,  1.99s/it]