In [1]:
import sys
sys.path.append("..") 

In [2]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import os
import torch.cuda
import albumentations as albu
import segmentation_models_pytorch as smp
from  segmentation_models_pytorch.utils.base import Metric
from segmentation_models_pytorch.base.modules import Activation
from collections import defaultdict

In [3]:
from Utils.dataset_utils import *
from Utils.visualization_utils import *

In [4]:
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [5]:
DATASET_2D_BASE_PATH=Path(r'C:\Users\lm3088\Documents\POM-CTproject\Muhammad\MiceCT_2Daxials\Axials')
WIDTH=320
HEIGHT=320

In [6]:


def get_training_augmentation():
    train_transform = [

        albu.Resize(HEIGHT, WIDTH),
        albu.ShiftScaleRotate(scale_limit=0.20, rotate_limit=30, shift_limit=0.1, p=1, border_mode=cv2.BORDER_CONSTANT),
        albu.RandomCrop(height=320, width=320),
        albu.Blur(blur_limit=3, p=0.4),
        albu.GaussNoise(p=0.5),
        albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        albu.RandomBrightness(p=0.75)

    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    test_transform = [
        albu.Resize(512, 512)
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform

    Args:
        preprocessing_fn (callbale): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """

    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [7]:
df_train=load_data(DATASET_2D_BASE_PATH/'train')
df_val=load_data(DATASET_2D_BASE_PATH/'val')
df_train.head()

Unnamed: 0,images,masks
0,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...
1,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...
2,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...
3,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...
4,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...,C:\Users\lm3088\Documents\POM-CTproject\Muhamm...


In [8]:
class DCS(Metric):
    __name__ = 'DCS'

    def __init__(self, eps=0.00001, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
     
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        y_pr = self.activation(y_pr)
        dice_numerator = 2 * torch.sum(y_pr * y_gt) + self.eps
        dice_denominator = torch.sum(y_pr) + torch.sum(y_gt) + self.eps
        dice_coefficient = dice_numerator / dice_denominator
        return dice_coefficient


### Training

In [9]:
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'
TRAIN_RUNS_PATH=r'C:\Users\lm3088\Documents\GitHub\MicroCTsegmentation\runs'
MODEL_NAME='Unet'
BATCH_SIZE=8
EPOCHS=100


In [10]:
ENCODERS = ['se_resnext50_32x4d','efficientnet-b2','resnet101','resnet34','densenet121']

In [11]:
for j in range(len(ENCODERS)):
    
    model = smp.Unet(
    encoder_name=ENCODERS[j], 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=1, 
    activation=ACTIVATION,
    )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODERS[j], ENCODER_WEIGHTS)
    
    WEIGHTS_PATH=os.path.join(TRAIN_RUNS_PATH,f'{MODEL_NAME}_{ENCODERS[j]}')
    if os.path.exists(WEIGHTS_PATH)==False:
        os.mkdir(WEIGHTS_PATH)
    else:
        print(f"Warning! Directory {WEIGHTS_PATH } already exists")
    
    train_dataset = Dataset(
    df_train['images'], 
    df_train['masks'], 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn)
    )

    valid_dataset = Dataset(
        df_val['images'], 
        df_val['masks'], 
        augmentation=get_validation_augmentation(), 
        preprocessing=get_preprocessing(preprocessing_fn)
    ) 

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
        DCS()

    ]

    optimizer = torch.optim.Adam([ 
        dict(params=model.parameters(), lr=0.0001),
    ])
    train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
    )
    min_loss = 100000000
    train_history=defaultdict(list)
    valid_history=defaultdict(list)

    for i in range(0, EPOCHS):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # do something (save model, change lr, etc.)
        if min_loss > valid_logs['dice_loss']:
            min_loss = valid_logs['dice_loss']
            torch.save(model, os.path.join(WEIGHTS_PATH,f'best_{str(i)}_{round(min_loss,4)}.pt'))
            print('Model saved!')

        if i == 25:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
        # Maintain History
        for log_key in train_logs.keys():
            train_history[log_key].append(train_logs[log_key])
            valid_history[log_key].append(valid_logs[log_key])
    pd.DataFrame(valid_history).to_csv(os.path.join(WEIGHTS_PATH,'validation_logs.csv'))
    pd.DataFrame(train_history).to_csv(os.path.join(WEIGHTS_PATH,'train_logs.csv'))
    print(f"{WEIGHTS_PATH} Completed!!!")






Epoch: 0
train: 100%|█████████████████████| 576/576 [02:43<00:00,  3.53it/s, dice_loss - 0.583, iou_score - 0.4728, DCS - 0.417]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.36it/s, dice_loss - 0.7285, iou_score - 0.5773, DCS - 0.2715]
Model saved!

Epoch: 1
train: 100%|███████████████████| 576/576 [02:39<00:00,  3.61it/s, dice_loss - 0.2357, iou_score - 0.6886, DCS - 0.7643]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.38it/s, dice_loss - 0.6946, iou_score - 0.6369, DCS - 0.3051]
Model saved!

Epoch: 2
train: 100%|███████████████████| 576/576 [02:39<00:00,  3.61it/s, dice_loss - 0.1845, iou_score - 0.7189, DCS - 0.8155]
valid: 100%|████████████████████| 192/192 [00:44<00:00,  4.35it/s, dice_loss - 0.6799, iou_score - 0.698, DCS - 0.3195]
Model saved!

Epoch: 3
train: 100%|███████████████████| 576/576 [02:41<00:00,  3.58it/s, dice_loss - 0.1681, iou_score - 0.7341, DCS - 0.8319]
valid: 100%|███████████████████| 192/192 [00:44<00:00,  4.34it/s, dice_loss - 0.6

valid: 100%|████████████████████| 192/192 [00:43<00:00,  4.46it/s, dice_loss - 0.1943, iou_score - 0.7648, DCS - 0.372]

Epoch: 64
train: 100%|██████████████████| 576/576 [02:36<00:00,  3.68it/s, dice_loss - 0.07959, iou_score - 0.8563, DCS - 0.9101]
valid: 100%|█████████████████████| 192/192 [00:43<00:00,  4.46it/s, dice_loss - 0.184, iou_score - 0.774, DCS - 0.3705]
Model saved!

Epoch: 65
train: 100%|██████████████████| 576/576 [02:35<00:00,  3.70it/s, dice_loss - 0.08223, iou_score - 0.8533, DCS - 0.9092]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.43it/s, dice_loss - 0.1929, iou_score - 0.7673, DCS - 0.3697]

Epoch: 66
train: 100%|██████████████████| 576/576 [02:36<00:00,  3.69it/s, dice_loss - 0.07848, iou_score - 0.8581, DCS - 0.9161]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.45it/s, dice_loss - 0.1932, iou_score - 0.7652, DCS - 0.3672]

Epoch: 67
train: 100%|██████████████████| 576/576 [02:36<00:00,  3.69it/s, dice_loss - 0.07665, iou_score - 0.859

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth" to C:\Users\lm3088/.cache\torch\hub\checkpoints\efficientnet-b2-8bb594d6.pth


  0%|          | 0.00/35.1M [00:00<?, ?B/s]


Epoch: 0
train: 100%|███████████████████| 576/576 [02:26<00:00,  3.93it/s, dice_loss - 0.5677, iou_score - 0.4376, DCS - 0.4323]
valid: 100%|████████████████████| 192/192 [00:37<00:00,  5.11it/s, dice_loss - 0.686, iou_score - 0.6683, DCS - 0.3139]
Model saved!

Epoch: 1
train: 100%|███████████████████| 576/576 [02:26<00:00,  3.92it/s, dice_loss - 0.2426, iou_score - 0.6643, DCS - 0.7574]
valid: 100%|███████████████████| 192/192 [00:38<00:00,  4.97it/s, dice_loss - 0.6711, iou_score - 0.5799, DCS - 0.3286]
Model saved!

Epoch: 2
train: 100%|████████████████████| 576/576 [02:26<00:00,  3.94it/s, dice_loss - 0.1908, iou_score - 0.708, DCS - 0.8092]
valid: 100%|███████████████████| 192/192 [00:39<00:00,  4.87it/s, dice_loss - 0.6564, iou_score - 0.6545, DCS - 0.3427]
Model saved!

Epoch: 3
train: 100%|███████████████████| 576/576 [02:27<00:00,  3.90it/s, dice_loss - 0.1639, iou_score - 0.7344, DCS - 0.8361]
valid: 100%|████████████████████| 192/192 [00:38<00:00,  5.00it/s, dice_loss - 0.

valid: 100%|███████████████████| 192/192 [00:37<00:00,  5.12it/s, dice_loss - 0.2078, iou_score - 0.7344, DCS - 0.3712]

Epoch: 64
train: 100%|████████████████████| 576/576 [02:25<00:00,  3.95it/s, dice_loss - 0.094, iou_score - 0.8307, DCS - 0.8966]
valid: 100%|█████████████████████| 192/192 [00:37<00:00,  5.10it/s, dice_loss - 0.2096, iou_score - 0.73, DCS - 0.3717]

Epoch: 65
train: 100%|██████████████████| 576/576 [02:24<00:00,  3.98it/s, dice_loss - 0.09447, iou_score - 0.8323, DCS - 0.8987]
valid: 100%|███████████████████| 192/192 [00:37<00:00,  5.14it/s, dice_loss - 0.2033, iou_score - 0.7365, DCS - 0.3726]

Epoch: 66
train: 100%|██████████████████| 576/576 [02:25<00:00,  3.97it/s, dice_loss - 0.09116, iou_score - 0.8365, DCS - 0.9003]
valid: 100%|███████████████████| 192/192 [00:37<00:00,  5.14it/s, dice_loss - 0.2117, iou_score - 0.7297, DCS - 0.3714]

Epoch: 67
train: 100%|██████████████████| 576/576 [02:25<00:00,  3.96it/s, dice_loss - 0.08918, iou_score - 0.8382, DCS - 0.90

valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.43it/s, dice_loss - 0.3319, iou_score - 0.7263, DCS - 0.3493]
Model saved!

Epoch: 28
train: 100%|███████████████████| 576/576 [02:39<00:00,  3.62it/s, dice_loss - 0.1071, iou_score - 0.8138, DCS - 0.8893]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.43it/s, dice_loss - 0.3235, iou_score - 0.7131, DCS - 0.3572]
Model saved!

Epoch: 29
train: 100%|███████████████████| 576/576 [02:38<00:00,  3.63it/s, dice_loss - 0.1102, iou_score - 0.8104, DCS - 0.8853]
valid: 100%|████████████████████| 192/192 [00:43<00:00,  4.43it/s, dice_loss - 0.2965, iou_score - 0.743, DCS - 0.3552]
Model saved!

Epoch: 30
train: 100%|███████████████████| 576/576 [02:40<00:00,  3.60it/s, dice_loss - 0.1055, iou_score - 0.8166, DCS - 0.8894]
valid: 100%|███████████████████| 192/192 [00:43<00:00,  4.37it/s, dice_loss - 0.2982, iou_score - 0.7373, DCS - 0.3544]

Epoch: 31
train: 100%|███████████████████| 576/576 [02:39<00:00,  3.62it/s, dice_loss -

valid: 100%|███████████████████| 192/192 [00:44<00:00,  4.28it/s, dice_loss - 0.2419, iou_score - 0.6982, DCS - 0.3636]

Epoch: 92
train: 100%|██████████████████| 576/576 [02:42<00:00,  3.55it/s, dice_loss - 0.09092, iou_score - 0.8373, DCS - 0.9012]
valid: 100%|████████████████████| 192/192 [00:44<00:00,  4.28it/s, dice_loss - 0.2615, iou_score - 0.679, DCS - 0.3698]

Epoch: 93
train: 100%|███████████████████| 576/576 [02:42<00:00,  3.55it/s, dice_loss - 0.0902, iou_score - 0.8386, DCS - 0.9012]
valid: 100%|███████████████████| 192/192 [00:44<00:00,  4.29it/s, dice_loss - 0.2417, iou_score - 0.7015, DCS - 0.3668]

Epoch: 94
train: 100%|██████████████████| 576/576 [02:43<00:00,  3.53it/s, dice_loss - 0.09033, iou_score - 0.8391, DCS - 0.9036]
valid: 100%|███████████████████| 192/192 [00:44<00:00,  4.28it/s, dice_loss - 0.2367, iou_score - 0.7017, DCS - 0.3666]

Epoch: 95
train: 100%|██████████████████| 576/576 [02:41<00:00,  3.57it/s, dice_loss - 0.08943, iou_score - 0.8413, DCS - 0.90

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\lm3088/.cache\torch\hub\checkpoints\resnet34-333f7ec4.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]


Epoch: 0
train: 100%|███████████████████| 576/576 [02:06<00:00,  4.54it/s, dice_loss - 0.6727, iou_score - 0.2483, DCS - 0.3272]
valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.19it/s, dice_loss - 0.8285, iou_score - 0.6021, DCS - 0.1714]
Model saved!

Epoch: 1
train: 100%|███████████████████| 576/576 [02:07<00:00,  4.53it/s, dice_loss - 0.4711, iou_score - 0.6229, DCS - 0.5289]
valid: 100%|█████████████████████| 192/192 [00:36<00:00,  5.22it/s, dice_loss - 0.7693, iou_score - 0.6335, DCS - 0.23]
Model saved!

Epoch: 2
train: 100%|███████████████████| 576/576 [02:07<00:00,  4.52it/s, dice_loss - 0.4295, iou_score - 0.6859, DCS - 0.5705]
valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.20it/s, dice_loss - 0.8126, iou_score - 0.6464, DCS - 0.1859]

Epoch: 3
train: 100%|███████████████████| 576/576 [02:07<00:00,  4.51it/s, dice_loss - 0.3957, iou_score - 0.7346, DCS - 0.6043]
valid: 100%|███████████████████| 192/192 [00:37<00:00,  5.19it/s, dice_loss - 0.7552, iou_scor

valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.20it/s, dice_loss - 0.1826, iou_score - 0.7598, DCS - 0.3744]

Epoch: 64
train: 100%|███████████████████| 576/576 [02:07<00:00,  4.52it/s, dice_loss - 0.08552, iou_score - 0.8447, DCS - 0.911]
valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.22it/s, dice_loss - 0.1805, iou_score - 0.7592, DCS - 0.3742]

Epoch: 65
train: 100%|██████████████████| 576/576 [02:07<00:00,  4.51it/s, dice_loss - 0.08588, iou_score - 0.8454, DCS - 0.9038]
valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.22it/s, dice_loss - 0.1893, iou_score - 0.7488, DCS - 0.3741]

Epoch: 66
train: 100%|█████████████████████| 576/576 [02:07<00:00,  4.52it/s, dice_loss - 0.08986, iou_score - 0.84, DCS - 0.908]
valid: 100%|███████████████████| 192/192 [00:36<00:00,  5.20it/s, dice_loss - 0.1673, iou_score - 0.7747, DCS - 0.3738]
Model saved!

Epoch: 67
train: 100%|██████████████████| 576/576 [02:06<00:00,  4.55it/s, dice_loss - 0.08871, iou_score - 0.842

Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/densenet121-fbdb23505.pth" to C:\Users\lm3088/.cache\torch\hub\checkpoints\densenet121-fbdb23505.pth


  0%|          | 0.00/30.9M [00:00<?, ?B/s]


Epoch: 0
train: 100%|███████████████████| 576/576 [02:36<00:00,  3.68it/s, dice_loss - 0.6668, iou_score - 0.4141, DCS - 0.3332]
valid: 100%|███████████████████| 192/192 [00:42<00:00,  4.52it/s, dice_loss - 0.7796, iou_score - 0.6158, DCS - 0.2203]
Model saved!

Epoch: 1
train: 100%|███████████████████| 576/576 [02:35<00:00,  3.70it/s, dice_loss - 0.3096, iou_score - 0.6606, DCS - 0.6904]
valid: 100%|███████████████████| 192/192 [00:42<00:00,  4.52it/s, dice_loss - 0.6797, iou_score - 0.5977, DCS - 0.3202]
Model saved!

Epoch: 2
train: 100%|████████████████████| 576/576 [02:36<00:00,  3.68it/s, dice_loss - 0.2064, iou_score - 0.711, DCS - 0.7936]
valid: 100%|████████████████████| 192/192 [00:42<00:00,  4.52it/s, dice_loss - 0.664, iou_score - 0.6239, DCS - 0.3357]
Model saved!

Epoch: 3
train: 100%|████████████████████| 576/576 [02:35<00:00,  3.70it/s, dice_loss - 0.1723, iou_score - 0.735, DCS - 0.8277]
valid: 100%|███████████████████| 192/192 [00:42<00:00,  4.53it/s, dice_loss - 0.6

valid: 100%|████████████████████| 192/192 [00:42<00:00,  4.53it/s, dice_loss - 0.221, iou_score - 0.7352, DCS - 0.3695]

Epoch: 64
train: 100%|███████████████████| 576/576 [02:36<00:00,  3.69it/s, dice_loss - 0.08629, iou_score - 0.844, DCS - 0.9047]
valid: 100%|███████████████████| 192/192 [00:42<00:00,  4.50it/s, dice_loss - 0.2152, iou_score - 0.7441, DCS - 0.3733]

Epoch: 65
train: 100%|██████████████████| 576/576 [02:35<00:00,  3.70it/s, dice_loss - 0.08785, iou_score - 0.8441, DCS - 0.9019]
valid: 100%|████████████████████| 192/192 [00:42<00:00,  4.49it/s, dice_loss - 0.2088, iou_score - 0.7458, DCS - 0.371]

Epoch: 66
train: 100%|██████████████████| 576/576 [02:35<00:00,  3.69it/s, dice_loss - 0.08466, iou_score - 0.8464, DCS - 0.9116]
valid: 100%|███████████████████| 192/192 [00:42<00:00,  4.48it/s, dice_loss - 0.2007, iou_score - 0.7572, DCS - 0.3718]

Epoch: 67
train: 100%|██████████████████| 576/576 [02:36<00:00,  3.68it/s, dice_loss - 0.08652, iou_score - 0.8439, DCS - 0.90