### Global definitions

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from models.basic_unet import BasicUNET
from tqdm import tqdm
from loader import load_BrainTissue_data
from PIL import Image
from models_training import save_params, train_network
import configs
import numpy as np

In [3]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = 'cpu'
    print("Running on the CPU")

# Fix the torch and NumPy RNGs so network initialisation and the torchvision random
# augmentations (which draw from torch's RNG) are reproducible across runs.
# torch.manual_seed also seeds every CUDA device.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Index ranges handed to load_BrainTissue_data for training and validation data.
# NOTE(review): whether *_STOP is inclusive is decided by load_BrainTissue_data —
# confirm that index 33 is not silently skipped between the two ranges.
TRAIN_IDX_START = 1
TRAIN_IDX_STOP = 33
VAL_IDX_START = 34
VAL_IDX_STOP = 35


def get_model_path(model_name, axis):
    """Return the path of the saved model file for `model_name` trained on `axis`."""
    return f"{configs.BASE_PATH}/trained_models/model_files/{model_name}_{axis}"


def get_parms_path(model_name, axis):
    """Return the path of the saved hyperparameter file for `model_name` / `axis`."""
    return f"{configs.BASE_PATH}/trained_models/model_params/{model_name}_{axis}_params"


def get_loss_path(model_name, axis):
    """Return the path of the saved loss history for `model_name` / `axis`."""
    return f"{configs.BASE_PATH}/trained_models/model_losses/{model_name}_{axis}_loss"

Running on the CPU


### 3 x Basic Unet 32x32

In [4]:
# Baseline: one BasicUNET per axis at 32x32, no data augmentation.
load_model = False
model_name = 'basic_unet_test'
axes = ['X', 'Y', 'Z']

# Hyperparameters shared by all three per-axis runs.
optimizer_name = "adam"
learning_rate = 0.0005
batch_size = 16
epochs = 5
img_height = 32
img_width = 32

for axis in axes:
    print(f'===== Training {model_name} for axis {axis} =====')

    # Every axis starts from a freshly initialised network and its own loss object.
    network = BasicUNET(in_channels=1, classes=8)
    loss_fn = nn.CrossEntropyLoss()

    # Deterministic nearest-neighbour resize, shared by the train and validation loaders.
    transform = transforms.Compose([
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST)
    ])
    train_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, TRAIN_IDX_START, TRAIN_IDX_STOP, axis, transform, batch_size
    )
    val_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, VAL_IDX_START, VAL_IDX_STOP, axis, transform, batch_size
    )

    # Persist the run configuration before training starts.
    save_params(
        model_name=model_name,
        axis=axis,
        optimizer_name=optimizer_name,
        loss_fn=str(loss_fn),
        batch_size=batch_size,
        epochs=epochs,
        img_height=img_height,
        img_width=img_width,
        parms_path=get_parms_path(model_name, axis),
    )

    train_network(
        train_loader=train_loader,
        val_loader=val_loader,
        network=network,
        optimizer_name=optimizer_name,
        learning_rate=learning_rate,
        loss_fn=loss_fn,
        batch_size=batch_size,
        epochs=epochs,
        load_model=load_model,
        model_path=get_model_path(model_name, axis),
        loss_path=get_loss_path(model_name, axis),
        device=DEVICE,
    )


===== Training basic_unet_test for axis X =====
  Epoch: 0
[epoch: 0, batches:     0 -    15] train loss: 0.859
[epoch: 0, batches:    16 -    31] train loss: 0.299
[epoch: 0, batches:    32 -    47] train loss: 0.223
[epoch: 0, batches:    48 -    63] train loss: 0.211
[epoch: 0, batches:    64 -    79] train loss: 0.199
[epoch: 0, batches:    80 -    95] train loss: 0.215
[epoch: 0, batches:    96 -   111] train loss: 0.203
[epoch: 0, batches:   112 -   127] train loss: 0.185
[epoch: 0, batches:   128 -   143] train loss: 0.184
[epoch: 0, batches:   144 -   159] train loss: 0.172
[epoch: 0, batches:   160 -   175] train loss: 0.169
[epoch: 0, batches:   176 -   191] train loss: 0.155
[epoch: 0, batches:   192 -   207] train loss: 0.158
[epoch: 0, batches:   208 -   223] train loss: 0.150
[epoch: 0, batches:   224 -   239] train loss: 0.152
[epoch: 0, batches:   240 -   255] train loss: 0.134
[epoch: 0, batches:   256 -   271] train loss: 0.131
[epoch: 0, batches:   272 -   287] train

### 3 x Basic Unet 64x64 + data augmentation

In [4]:
# One BasicUNET per axis at 64x64, trained with on-the-fly data augmentation.
load_model = False
model_name = 'basic_unet_64_local'
axes = ['X', 'Y', 'Z']

# Hyperparameters shared by all three per-axis runs.
optimizer_name = "adam"
learning_rate = 0.0005
batch_size = 32
epochs = 10
img_height = 64
img_width = 64

for axis in axes:
    print(f'===== Training {model_name} for axis {axis} =====')

    # Every axis starts from a freshly initialised network and its own loss object.
    network = BasicUNET(in_channels=1, classes=8)
    loss_fn = nn.CrossEntropyLoss()

    # Augment the training data only; validation data is just resized.
    # Every resampling step uses NEAREST, so class indices are never blended if the
    # label masks also pass through this pipeline. RandomResizedCrop defaults to
    # BILINEAR, so its interpolation is set explicitly.
    # NOTE(review): the random ops must draw identical parameters for image and mask —
    # confirm that load_BrainTissue_data transforms them jointly.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=50, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.RandomResizedCrop(
            size=(configs.IMG_ORYG_HEIGHT, configs.IMG_ORYG_WIDTH),
            scale=(0.6, 1.0),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    train_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, TRAIN_IDX_START, TRAIN_IDX_STOP, axis, train_transform, batch_size
    )
    val_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, VAL_IDX_START, VAL_IDX_STOP, axis, val_transform, batch_size
    )

    # Persist the run configuration before training starts.
    save_params(
        model_name=model_name,
        axis=axis,
        optimizer_name=optimizer_name,
        loss_fn=str(loss_fn),
        batch_size=batch_size,
        epochs=epochs,
        img_height=img_height,
        img_width=img_width,
        parms_path=get_parms_path(model_name, axis),
    )

    train_network(
        train_loader=train_loader,
        val_loader=val_loader,
        network=network,
        optimizer_name=optimizer_name,
        learning_rate=learning_rate,
        loss_fn=loss_fn,
        batch_size=batch_size,
        epochs=epochs,
        load_model=load_model,
        model_path=get_model_path(model_name, axis),
        loss_path=get_loss_path(model_name, axis),
        device=DEVICE,
    )


===== Training basic_unet_64_local for axis Y =====
  Epoch: 0
[epoch: 0, batches:     0 -    15] train loss: 0.960
[epoch: 0, batches:    16 -    31] train loss: 0.298
[epoch: 0, batches:    32 -    47] train loss: 0.246
[epoch: 0, batches:    48 -    63] train loss: 0.230
[epoch: 0, batches:    64 -    79] train loss: 0.224
[epoch: 0, batches:    80 -    95] train loss: 0.228
[epoch: 0, batches:    96 -   111] train loss: 0.206
[epoch: 0, batches:   112 -   127] train loss: 0.191
[epoch: 0, batches:   128 -   143] train loss: 0.188
[epoch: 0, batches:   144 -   159] train loss: 0.174
[epoch: 0, batches:   160 -   175] train loss: 0.176
[epoch: 0, batches:   176 -   191] train loss: 0.179
Training in epoch 0 finished. Train loss: 0.1792494971305132]
[Validation in epoch 0 finished. Validation loss: 0.166]
  Epoch: 1
[epoch: 1, batches:     0 -    15] train loss: 0.168
[epoch: 1, batches:    16 -    31] train loss: 0.171
[epoch: 1, batches:    32 -    47] train loss: 0.156
[epoch: 1, b

### 3 x Basic Unet 128x128 + data augmentation

In [4]:
# One BasicUNET per axis at 128x128, trained with on-the-fly data augmentation.
load_model = False
model_name = 'basic_unet_128_local'
axes = ['X', 'Y', 'Z']

# Hyperparameters shared by all three per-axis runs.
optimizer_name = "adam"
learning_rate = 0.0005
batch_size = 32
epochs = 15
img_height = 128
img_width = 128

for axis in axes:
    print(f'===== Training {model_name} for axis {axis} =====')

    # Every axis starts from a freshly initialised network and its own loss object.
    network = BasicUNET(in_channels=1, classes=8)
    loss_fn = nn.CrossEntropyLoss()

    # Augment the training data only; validation data is just resized.
    # Every resampling step uses NEAREST, so class indices are never blended if the
    # label masks also pass through this pipeline. RandomResizedCrop defaults to
    # BILINEAR, so its interpolation is set explicitly.
    # NOTE(review): the random ops must draw identical parameters for image and mask —
    # confirm that load_BrainTissue_data transforms them jointly.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=50, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.RandomResizedCrop(
            size=(configs.IMG_ORYG_HEIGHT, configs.IMG_ORYG_WIDTH),
            scale=(0.6, 1.0),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    train_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, TRAIN_IDX_START, TRAIN_IDX_STOP, axis, train_transform, batch_size
    )
    val_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, VAL_IDX_START, VAL_IDX_STOP, axis, val_transform, batch_size
    )

    # Persist the run configuration before training starts.
    save_params(
        model_name=model_name,
        axis=axis,
        optimizer_name=optimizer_name,
        loss_fn=str(loss_fn),
        batch_size=batch_size,
        epochs=epochs,
        img_height=img_height,
        img_width=img_width,
        parms_path=get_parms_path(model_name, axis),
    )

    train_network(
        train_loader=train_loader,
        val_loader=val_loader,
        network=network,
        optimizer_name=optimizer_name,
        learning_rate=learning_rate,
        loss_fn=loss_fn,
        batch_size=batch_size,
        epochs=epochs,
        load_model=load_model,
        model_path=get_model_path(model_name, axis),
        loss_path=get_loss_path(model_name, axis),
        device=DEVICE,
    )

===== Training basic_unet_128_local for axis Z =====
Model loaded from checkpoint
  Epoch: 4
[epoch: 4, batches:     0 -    15] train loss: 0.152
[epoch: 4, batches:    16 -    31] train loss: 0.149
[epoch: 4, batches:    32 -    47] train loss: 0.147
[epoch: 4, batches:    48 -    63] train loss: 0.153
[epoch: 4, batches:    64 -    79] train loss: 0.133
[epoch: 4, batches:    80 -    95] train loss: 0.153
[epoch: 4, batches:    96 -   111] train loss: 0.141
[epoch: 4, batches:   112 -   127] train loss: 0.133
[epoch: 4, batches:   128 -   143] train loss: 0.139
[epoch: 4, batches:   144 -   159] train loss: 0.129
Training in epoch 4 finished. Train loss: 0.12945549562573433]
[Validation in epoch 4 finished. Validation loss: 0.117]
  Epoch: 5
[epoch: 5, batches:     0 -    15] train loss: 0.136
[epoch: 5, batches:    16 -    31] train loss: 0.135
[epoch: 5, batches:    32 -    47] train loss: 0.119
[epoch: 5, batches:    48 -    63] train loss: 0.140
[epoch: 5, batches:    64 -    79]

### 3 x Resnet-based Encoder Unet 64 x 64 (Cross Entropy Loss)

In [None]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
!pip install git+https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss.git

In [None]:
from generalized_wasserstein_dice_loss.loss import GeneralizedWassersteinDiceLoss
import segmentation_models_pytorch as smp

In [None]:
# One smp.Unet with an ImageNet-pretrained ResNet-34 encoder per axis at 64x64,
# optimising plain cross-entropy.
load_model = False
# Must differ from the Dice + cross-entropy run's model name; otherwise both runs write
# the same model/params/loss files and overwrite each other.
model_name = 'UNetWithResnet34Encoder_64_celoss'
axes = ['X', 'Y', 'Z']

# Hyperparameters shared by all three per-axis runs.
optimizer_name = "adam"
learning_rate = 0.0005
batch_size = 32
epochs = 20
img_height = 64
img_width = 64

for axis in axes:
    print(f'===== Training {model_name} for axis {axis} =====')

    # Fresh network per axis; activation=None keeps raw logits, which CrossEntropyLoss expects.
    network = smp.Unet("resnet34", encoder_weights="imagenet", activation=None, classes=8, in_channels=1)
    loss_fn = nn.CrossEntropyLoss()

    # Augment the training data only; validation data is just resized.
    # Every resampling step uses NEAREST, so class indices are never blended if the
    # label masks also pass through this pipeline. RandomResizedCrop defaults to
    # BILINEAR, so its interpolation is set explicitly.
    # NOTE(review): the random ops must draw identical parameters for image and mask —
    # confirm that load_BrainTissue_data transforms them jointly.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=50, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.RandomResizedCrop(
            size=(configs.IMG_ORYG_HEIGHT, configs.IMG_ORYG_WIDTH),
            scale=(0.6, 1.0),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    train_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, TRAIN_IDX_START, TRAIN_IDX_STOP, axis, train_transform, batch_size
    )
    val_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, VAL_IDX_START, VAL_IDX_STOP, axis, val_transform, batch_size
    )

    # Persist the run configuration before training starts.
    save_params(
        model_name=model_name,
        axis=axis,
        optimizer_name=optimizer_name,
        loss_fn=str(loss_fn),
        batch_size=batch_size,
        epochs=epochs,
        img_height=img_height,
        img_width=img_width,
        parms_path=get_parms_path(model_name, axis),
    )

    train_network(
        train_loader=train_loader,
        val_loader=val_loader,
        network=network,
        optimizer_name=optimizer_name,
        learning_rate=learning_rate,
        loss_fn=loss_fn,
        batch_size=batch_size,
        epochs=epochs,
        load_model=load_model,
        model_path=get_model_path(model_name, axis),
        loss_path=get_loss_path(model_name, axis),
        device=DEVICE,
    )

### 3 x Resnet-based Encoder Unet 64 x 64 (Dice + Cross Entropy Loss)

In [None]:
class DiceCrossEntropyLoss(nn.Module):
    """Sum of a Generalized Wasserstein Dice loss and a cross-entropy loss.

    Both terms are computed from the network's raw, unnormalised class scores
    (logits). ``nn.CrossEntropyLoss`` applies log-softmax internally, and
    ``GeneralizedWassersteinDiceLoss`` is documented to take scores before softmax.

    Args:
        weight: optional per-class weight tensor forwarded to ``nn.CrossEntropyLoss``.
            ``None`` (the default) gives the unweighted loss.
        size_average: unused; kept so existing call sites keep working.
        num_classes: number of segmentation classes. It sizes the Wasserstein
            distance matrix; the default of 8 matches ``classes=8`` used for the models.
    """

    def __init__(self, weight=None, size_average=True, num_classes=8):
        super().__init__()
        # Distance 0 between identical classes and 1 between any two different classes.
        dist_mat = 1.0 - np.eye(num_classes)
        self.wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
        self.ce_loss = nn.CrossEntropyLoss(weight=weight)

    def forward(self, inputs, targets, smooth=1):
        """Return ``CE(inputs, targets) + GWDL(inputs, targets)``.

        Args:
            inputs: raw logits from the network, presumably (N, C, H, W) — TODO confirm.
            targets: integer class-index map, presumably (N, H, W) — TODO confirm.
            smooth: unused; kept so existing call sites keep working.
        """
        # Do NOT squash the logits first. Applying a sigmoid before CrossEntropyLoss
        # confines every score to (0, 1), so with 8 classes the loss can never drop
        # below log(1 + 7/e) ~= 1.27 and the gradients are heavily damped.
        pred = torch.flatten(inputs, start_dim=-2)
        grnd = torch.flatten(targets, start_dim=-2)
        dice_loss = self.wass_loss(pred, grnd)

        ce_loss = self.ce_loss(inputs, targets)
        return ce_loss + dice_loss


In [None]:
# One smp.Unet with an ImageNet-pretrained ResNet-34 encoder per axis at 64x64,
# optimising the combined Wasserstein Dice + cross-entropy loss.
load_model = False
model_name = 'UNetWithResnet34Encoder_64_diceceloss'
axes = ['X', 'Y', 'Z']

# Hyperparameters shared by all three per-axis runs.
optimizer_name = "adam"
learning_rate = 0.0005
batch_size = 32
epochs = 20
img_height = 64
img_width = 64

for axis in axes:
    print(f'===== Training {model_name} for axis {axis} =====')

    # Fresh network and loss per axis; activation=None keeps the network output as raw logits.
    network = smp.Unet("resnet34", encoder_weights="imagenet", activation=None, classes=8, in_channels=1)
    loss_fn = DiceCrossEntropyLoss()

    # Augment the training data only; validation data is just resized.
    # Every resampling step uses NEAREST, so class indices are never blended if the
    # label masks also pass through this pipeline. RandomResizedCrop defaults to
    # BILINEAR, so its interpolation is set explicitly.
    # NOTE(review): the random ops must draw identical parameters for image and mask —
    # confirm that load_BrainTissue_data transforms them jointly.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=50, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.RandomResizedCrop(
            size=(configs.IMG_ORYG_HEIGHT, configs.IMG_ORYG_WIDTH),
            scale=(0.6, 1.0),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.NEAREST),
    ])
    train_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, TRAIN_IDX_START, TRAIN_IDX_STOP, axis, train_transform, batch_size
    )
    val_loader = load_BrainTissue_data(
        configs.DATA_2D_PATH, VAL_IDX_START, VAL_IDX_STOP, axis, val_transform, batch_size
    )

    # Persist the run configuration before training starts.
    save_params(
        model_name=model_name,
        axis=axis,
        optimizer_name=optimizer_name,
        loss_fn=str(loss_fn),
        batch_size=batch_size,
        epochs=epochs,
        img_height=img_height,
        img_width=img_width,
        parms_path=get_parms_path(model_name, axis),
    )

    train_network(
        train_loader=train_loader,
        val_loader=val_loader,
        network=network,
        optimizer_name=optimizer_name,
        learning_rate=learning_rate,
        loss_fn=loss_fn,
        batch_size=batch_size,
        epochs=epochs,
        load_model=load_model,
        model_path=get_model_path(model_name, axis),
        loss_path=get_loss_path(model_name, axis),
        device=DEVICE,
    )