In [1]:
import sys
sys.path.append("..") 

In [2]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import os
import torch.cuda
import albumentations as albu
import segmentation_models_pytorch as smp
from  segmentation_models_pytorch.utils.base import Metric
from segmentation_models_pytorch.base.modules import Activation
from collections import defaultdict

In [3]:
from Utils.dataset_utils import *
from Utils.visualization_utils import *

In [4]:
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [5]:
DATASET_2D_BASE_PATH=Path('G:\Projects and Work\Mouse Heart Segmentation\Mice_CT_Dataset\Axials')
WIDTH=320
HEIGHT=320

In [6]:


def get_training_augmentation():
    train_transform = [

        albu.Resize(HEIGHT, WIDTH),
        albu.ShiftScaleRotate(scale_limit=0.20, rotate_limit=30, shift_limit=0.1, p=1, border_mode=cv2.BORDER_CONSTANT),
        albu.RandomCrop(height=320, width=320),
        albu.Blur(blur_limit=3, p=0.4),
        albu.GaussNoise(p=0.5),
        albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        albu.RandomBrightness(p=0.75)

    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    test_transform = [
        albu.Resize(512, 512)
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform

    Args:
        preprocessing_fn (callbale): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """

    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [7]:
df_train=load_data(DATASET_2D_BASE_PATH/'train')
df_val=load_data(DATASET_2D_BASE_PATH/'val')
df_train.head()

Unnamed: 0,images,masks
0,G:\Projects and Work\Mouse Heart Segmentation\...,G:\Projects and Work\Mouse Heart Segmentation\...
1,G:\Projects and Work\Mouse Heart Segmentation\...,G:\Projects and Work\Mouse Heart Segmentation\...
2,G:\Projects and Work\Mouse Heart Segmentation\...,G:\Projects and Work\Mouse Heart Segmentation\...
3,G:\Projects and Work\Mouse Heart Segmentation\...,G:\Projects and Work\Mouse Heart Segmentation\...
4,G:\Projects and Work\Mouse Heart Segmentation\...,G:\Projects and Work\Mouse Heart Segmentation\...


In [8]:
class DCS(Metric):
    __name__ = 'DCS'

    def __init__(self, eps=0.00001, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
     
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        y_pr = self.activation(y_pr)
        dice_numerator = 2 * torch.sum(y_pr * y_gt) + self.eps
        dice_denominator = torch.sum(y_pr) + torch.sum(y_gt) + self.eps
        dice_coefficient = dice_numerator / dice_denominator
        return dice_coefficient


### Training

In [9]:
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'
TRAIN_RUNS_PATH=r'G:\Projects and Work\Mouse Heart Segmentation\runs'
MODEL_NAME='Unet'
BATCH_SIZE=8
EPOCHS=100


In [10]:
ENCODERS = ['se_resnext50_32x4d','efficientnet-b2','resnet101','resnet34','densenet121']

In [11]:
for j in range(len(ENCODERS)):
    
    model = smp.Unet(
    encoder_name=ENCODERS[j], 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=1, 
    activation=ACTIVATION,
    )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODERS[j], ENCODER_WEIGHTS)
    
    WEIGHTS_PATH=os.path.join(TRAIN_RUNS_PATH,f'{MODEL_NAME}_{ENCODERS[j]}')
    if os.path.exists(WEIGHTS_PATH)==False:
        os.mkdir(WEIGHTS_PATH)
    else:
        print(f"Warning! Directory {WEIGHTS_PATH } already exists")
    
    train_dataset = Dataset(
    df_train['images'], 
    df_train['masks'], 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn)
    )

    valid_dataset = Dataset(
        df_val['images'], 
        df_val['masks'], 
        augmentation=get_validation_augmentation(), 
        preprocessing=get_preprocessing(preprocessing_fn)
    ) 

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
        DCS()

    ]

    optimizer = torch.optim.Adam([ 
        dict(params=model.parameters(), lr=0.0001),
    ])
    train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
    )
    min_loss = 100000000
    train_history=defaultdict(list)
    valid_history=defaultdict(list)

    for i in range(0, EPOCHS):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # do something (save model, change lr, etc.)
        if min_loss > valid_logs['dice_loss']:
            min_loss = valid_logs['dice_loss']
            torch.save(model, os.path.join(WEIGHTS_PATH,f'best_{str(i)}_{round(min_loss,4)}.pt'))
            print('Model saved!')

        if i == 25:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
        # Maintain History
        for log_key in train_logs.keys():
            train_history[log_key].append(train_logs[log_key])
            valid_history[log_key].append(valid_logs[log_key])
    pd.DataFrame(valid_history).to_csv(os.path.join(WEIGHTS_PATH,'validation_logs.csv'))
    pd.DataFrame(train_history).to_csv(os.path.join(WEIGHTS_PATH,'train_logs.csv'))
    print(f"{WEIGHTS_PATH} Completed!!!")




Epoch: 0
train: 100%|███████████████████| 576/576 [05:59<00:00,  1.60it/s, dice_loss - 0.5367, iou_score - 0.4991, DCS - 0.4633]
valid: 100%|███████████████████| 192/192 [01:28<00:00,  2.16it/s, dice_loss - 0.6823, iou_score - 0.6554, DCS - 0.3177]
Model saved!

Epoch: 1
train: 100%|███████████████████| 576/576 [05:56<00:00,  1.62it/s, dice_loss - 0.2124, iou_score - 0.7065, DCS - 0.7876]
valid: 100%|███████████████████| 192/192 [01:29<00:00,  2.14it/s, dice_loss - 0.6863, iou_score - 0.6571, DCS - 0.3135]

Epoch: 2
train: 100%|███████████████████| 576/576 [05:56<00:00,  1.62it/s, dice_loss - 0.1689, iou_score - 0.7388, DCS - 0.8311]
valid: 100%|████████████████████| 192/192 [01:29<00:00,  2.14it/s, dice_loss - 0.772, iou_score - 0.5726, DCS - 0.2273]

Epoch: 3
train: 100%|█████████████████████| 576/576 [05:56<00:00,  1.61it/s, dice_loss - 0.148, iou_score - 0.7573, DCS - 0.852]
valid: 100%|███████████████████| 192/192 [01:29<00:00,  2.14it/s, dice_loss - 0.6837, iou_score - 0.6807, D

KeyboardInterrupt: 