# Dataset

In [None]:
import shutil
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [None]:
# Root folders of the BraTS 2020 training and validation data (absolute Windows
# paths for the original machine; adjust them before running elsewhere).
# Both end with a path separator.
data_root = 'C:\\Users\\Admin\\Desktop\\KohanovDiploma\\BraTS\\MICCAI_BraTS2020_TrainingData\\'
val_root = 'C:\\Users\\Admin\\Desktop\\KohanovDiploma\\BraTS\\MICCAI_BraTS2020_ValidationData\\'

In [None]:
def list_of_images(data_root, modal, folder_start=0, folder_finish=None):
    """Collect the paths of all files in the ``modal`` subfolder of every patient folder.

    Parameters
    ----------
    data_root : str
        Directory that contains one sub-directory per patient.
    modal : str
        Name of the per-patient subfolder to read, e.g. ``'t1ce'`` or ``'seg'``.
    folder_start, folder_finish : int or None
        Slice bounds applied to the sorted listing of ``data_root``.
        ``folder_finish=None`` (the default) means "up to the end of *this*
        ``data_root``". It is not fixed at definition time from a global.

    Returns
    -------
    list[str]
        File paths ordered by patient folder name, then by file name. Two calls
        with different ``modal`` values therefore line up element by element
        whenever the per-patient file names correspond.
    """
    # sorted(): os.listdir() order is arbitrary, and images/masks are later
    # paired purely by position in create_dataset_paths.
    folders = sorted(os.listdir(data_root))[folder_start:folder_finish]
    images = []
    for fold in tqdm(folders):
        patient_dir = os.path.join(data_root, fold)
        if not os.path.isdir(patient_dir):
            continue  # skip non-directory entries in the dataset root
        modal_dir = os.path.join(patient_dir, modal)
        images.extend(os.path.join(modal_dir, item) for item in sorted(os.listdir(modal_dir)))
    return images

def create_dataset_paths(images, masks):
    """Pair image paths with mask paths row by row.

    Returns a 2-D NumPy array with one row per sample. Column 0 holds the image
    path and column 1 the matching mask path. Both inputs must have the same
    length, otherwise NumPy raises ValueError.
    """
    return np.column_stack((np.asarray(images), np.asarray(masks)))

In [None]:
# Training pairs: files from every patient's 't1ce' folder matched by position
# with the files from the same patient's 'seg' folder.
images, masks = (list_of_images(data_root, modality) for modality in ('t1ce', 'seg'))

train_paths = create_dataset_paths(images, masks)

In [None]:
# Validation pairs, built the same way from val_root.
images_val, masks_val = (list_of_images(val_root, modality) for modality in ('t1ce', 'seg'))

val_paths = create_dataset_paths(images_val, masks_val)

In [None]:
# Number of (image, mask) pairs in the training set.
len(train_paths)

In [None]:
# Number of (image, mask) pairs in the validation set.
len(val_paths)

# PyTorch Dataset

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from collections import defaultdict
import torch.nn.functional as F
from torch.optim import lr_scheduler
import time
import copy

In [None]:
class BratsDataset(Dataset):
    """Map-style dataset yielding ``(image, mask)`` pairs read from disk with OpenCV.

    ``files[i][0]`` is the image path and ``files[i][1]`` the mask path. The
    array built by ``create_dataset_paths`` has this layout. ``transform`` is
    called as ``transform(image=..., mask=...)`` and must return a dict with
    ``"image"`` and ``"mask"`` keys (albumentations style). Both returned
    values are divided by 255.
    """

    def __init__(self, files, transform=None):
        self.transform = transform
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        image = self._read(self.files[index][0])
        mask = self._read(self.files[index][1])

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]

        # Without a transform the raw arrays are scaled and returned. Previously
        # this path hit an unbound local variable.
        return image / 255, mask / 255

    @staticmethod
    def _read(path):
        """Read ``path`` with cv2.imread, raising instead of silently returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read {path!r}")
        return img

# Augment Data

In [None]:
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Training augmentation: resize to 256x256 with nearest-neighbour interpolation,
# then random vertical/horizontal flips and 90-degree rotations. The pipeline is
# applied jointly to image and mask (BratsDataset calls transform(image=, mask=)).
# ToTensorV2 returns torch tensors; transpose_mask=True puts the mask's channel
# axis first as well, so both come out channel-first (CHW).
train_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    ToTensorV2(transpose_mask=True)
    
])

# Validation: only the deterministic resize and tensor conversion.
val_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    ToTensorV2(transpose_mask=True)
])

train_dataset = BratsDataset(train_paths, train_transform)
val_dataset = BratsDataset(val_paths, val_transform)

In [None]:
batch_size = 16

# Both loaders reshuffle every epoch; train_model reads them from `dataloaders`.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

dataloaders = dict(train=train_loader, val=val_loader)

In [None]:
# Show every (image, mask) pair of one random training batch side by side.
sample = next(iter(train_loader))
for image, mask in zip(sample[0], sample[1]):
    f, axes = plt.subplots(1, 2)
    f.set_size_inches(12,12)
    # Batch tensors are CHW and matplotlib expects HWC. permute() returns a view,
    # so no torch.tensor(...) copy is needed (that copy also raised a UserWarning).
    image = image.permute(1,2,0)
    axes[0].imshow(image)
    axes[0].set_title('image')
    mask = mask.permute(1,2,0)
    axes[1].set_title('mask')
    axes[1].imshow(mask)

In [None]:
# Show every (image, mask) pair of one random validation batch side by side.
sample = next(iter(val_loader))
for image, mask in zip(sample[0], sample[1]):
    f, axes = plt.subplots(1, 2)
    f.set_size_inches(12,12)
    # CHW -> HWC for matplotlib. permute() returns a view; no extra copy is needed.
    image = image.permute(1,2,0)
    axes[0].imshow(image)
    axes[0].set_title('image')
    mask = mask.permute(1,2,0)
    axes[1].set_title('mask')
    axes[1].imshow(mask)

# Train function, metrics, loss

In [None]:
def dice_score(pred, target, smooth=1e-6):
    """Soft Dice coefficient between ``pred`` and ``target``.

    Every element of both tensors, across the whole batch, is flattened into a
    single vector, so the result is one scalar tensor:

        (2 * sum(p * t) + smooth) / (sum(p * p) + sum(t * t) + smooth)

    It accepts real-valued inputs and is differentiable, so ``1 - dice_score``
    can be used as a loss. ``smooth`` keeps the ratio defined when both inputs
    are all zeros; the score is then 1.
    """
    # reshape() copies only for non-contiguous inputs, like contiguous().view().
    p = pred.reshape(-1)
    t = target.reshape(-1)

    overlap = torch.sum(p * t)
    denominator = torch.sum(p * p) + torch.sum(t * t)

    return (2. * overlap + smooth) / (denominator + smooth)

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss ``alpha * (1 - pt) ** gamma * BCE`` with ``pt = exp(-BCE)``.

    Args:
        alpha: constant factor applied to every element.
        gamma: focusing exponent; larger values down-weight elements with a
            small BCE (well-classified ones).
        logits: if True, ``inputs`` are raw scores and the BCE is computed with
            ``binary_cross_entropy_with_logits``. Otherwise ``inputs`` must
            already be probabilities in [0, 1].
        reduce: if True, return the mean over all elements; otherwise return the
            per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # reduction='none' replaces the deprecated `reduce=False` keyword of the
        # functional BCE ops; both give the per-element loss.
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            bce = F.binary_cross_entropy(inputs, targets, reduction='none')
        # For 0/1 targets, exp(-BCE) equals the predicted probability of the target class.
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce
        return torch.mean(focal) if self.reduce else focal

In [None]:
def calc_loss(pred, target, metrics, bce_weight=0.5, dice_weight=0.5, focal_weight=0):
    """Weighted sum of BCE, soft-Dice and focal losses; also accumulates metrics.

    Args:
        pred: raw model outputs (logits), same shape as ``target``.
        target: ground-truth mask.
        metrics: mapping updated in place. ``'dice'`` and ``'loss'`` are
            increased by the batch value times ``target.size(0)``, so dividing
            by the number of samples seen gives an epoch average.
        bce_weight, dice_weight, focal_weight: weights of the three terms.

    Returns:
        The scalar loss tensor, still attached to the autograd graph.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    # FocalLoss is configured for logits, so it gets the raw outputs. Passing
    # sigmoid(pred) here would apply the sigmoid twice.
    focal = FocalLoss(logits=True)(pred, target)

    probs = torch.sigmoid(pred)
    dice = dice_score(probs, target)
    dice_loss = 1 - dice

    loss = bce * bce_weight + dice_loss * dice_weight + focal_weight * focal

    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print each accumulated metric divided by ``epoch_samples`` on one line.

    Example output: ``train: dice: 0.812345, loss: 0.201234``.
    """
    summary = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples) for name, total in metrics.items()
    )
    print("{}: {}".format(phase, summary))

def train_model(model, optimizer, scheduler,
                bce_weight=0.5, dice_weight=0.5, focal_weight=0,
                num_epochs=25, isDeepLabV3=False):
    """Train ``model`` and return it with the weights of the best validation Dice.

    Reads the notebook-level ``dataloaders`` dict (``'train'``/``'val'``) and
    ``device``. Each epoch runs a training pass and then a validation pass.
    ``scheduler.step()`` is called once per epoch, after the training pass.

    Args:
        model: network to train. Pass ``isDeepLabV3=True`` for torchvision
            DeepLabV3, whose forward returns a dict with the logits under ``'out'``.
        optimizer, scheduler: optimizer over the model parameters and its LR scheduler.
        bce_weight, dice_weight, focal_weight: forwarded to ``calc_loss``.
        num_epochs: number of epochs.

    Returns:
        ``(model, loss_train, dice_train, loss_val, dice_val)``: the model with
        the best-validation-Dice weights loaded, plus per-epoch averaged metrics.

    Raises:
        ValueError: if a dataloader yields no samples.
    """
    loss_train, dice_train = [], []
    loss_val, dice_val = [], []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = 0

    for epoch in range(num_epochs):
        print(' Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 65)

        since = time.time()

        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()
            else:
                model.eval()

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    if isDeepLabV3:
                        outputs = outputs['out']
                    loss = calc_loss(outputs, labels, metrics, bce_weight, dice_weight, focal_weight)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                epoch_samples += inputs.size(0)

            if phase == 'train':
                # Step the LR schedule once per epoch. Stepping after the validation
                # pass as well would make StepLR decay twice as often as its step_size.
                scheduler.step()

            if epoch_samples == 0:
                raise ValueError(f"the '{phase}' dataloader yielded no samples")

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            epoch_dice = metrics['dice'] / epoch_samples

            if phase == 'train':
                loss_train.append(epoch_loss)
                dice_train.append(epoch_dice)
            else:
                loss_val.append(epoch_loss)
                dice_val.append(epoch_dice)

                # Keep a deep copy of the best weights seen so far.
                if epoch_dice > best_dice:
                    print("saving best model")
                    best_dice = epoch_dice
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val dice: {:4f}'.format(best_dice))

    model.load_state_dict(best_model_wts)

    return model, loss_train, dice_train, loss_val, dice_val

In [None]:
import torch

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

# Plotting

In [None]:
def plot_curves(*paths):
    plt.figure(figsize=(16,8))
    plt.grid(True)
    for i in range(len(paths)):
        with open(paths[i], "r") as f:
            values = f.read()
            values = values.split(",")
            values = values[:-1]
            values = list(map(lambda x: float(x), values))
            f.close()
        plt.plot(values, label=paths[i][38:-13])
    plt.legend()
    plt.show()

# Predict

In [None]:
def predict_masks(model, inputs, isDeepLabV3=False):
    """Run ``model`` on ``inputs`` and return per-pixel probabilities as a CPU tensor.

    ``inputs`` are moved to the notebook-level ``device``. The models output
    logits (the training losses apply the sigmoid themselves), so a sigmoid is
    applied here. Set ``isDeepLabV3=True`` for torchvision DeepLabV3, whose
    output dict holds the logits under ``'out'``. The model's train/eval mode is
    not changed; call ``model.eval()`` beforehand.
    """
    # Inference only: no_grad avoids building an autograd graph (saves memory).
    with torch.no_grad():
        outputs = model(inputs.to(device))
        if isDeepLabV3:
            outputs = outputs['out']
        probs = torch.sigmoid(outputs)
    # Same values as the former numpy round trip, without the extra copies.
    return probs.cpu()

def plot_predictions(inputs, labels, pred):
    """Plot input image, ground-truth mask and thresholded prediction per sample.

    ``inputs``, ``labels`` and ``pred`` are batches of CHW CPU tensors indexed
    along the first dimension. NOTE: each prediction is binarised in place
    (> 0.5 -> 1, otherwise 0). After this call ``pred`` holds 0/1 masks, and any
    later scoring of ``pred`` sees the thresholded values.
    """
    for i, real_mask in enumerate(labels):
        f, axes = plt.subplots(1, 3)
        f.set_size_inches(12,12)
        axes[0].imshow(inputs[i].permute(1,2,0))
        axes[0].set_title('MRI-photo')
        axes[1].imshow(real_mask.permute(1,2,0))
        axes[1].set_title('real_mask')
        predicted_mask = pred[i].permute(1,2,0)
        # Vectorised, still in-place equivalent of
        # apply_(lambda x: 1 if x > 0.5 else 0): no Python call per element.
        predicted_mask.copy_(predicted_mask > 0.5)
        axes[2].set_title('predicted_mask')
        axes[2].imshow(predicted_mask)
        
def show_dices_fot_predict(labels, pred):
    """Print the Dice score of every (label, prediction) pair and return their mean.

    Raises ZeroDivisionError when ``labels`` is empty.
    """
    average_dice = 0
    for i, label in enumerate(labels):
        # Compute once and reuse for both the print and the running sum.
        score = dice_score(label, pred[i])
        print(score)
        average_dice += score
    return average_dice/(len(labels))

# Unet with ResNet18-Encoder 

In [None]:
import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    """Return ``nn.Sequential(Conv2d, in-place ReLU)``.

    The submodule indices (0 = conv, 1 = ReLU) become part of the state_dict
    keys of the networks built from this helper, so the order must stay fixed.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)


class ResNetUNet(nn.Module):
    """U-Net-style decoder on top of a torchvision ResNet-18 encoder.

    The encoder is created with ``pretrained=False``, so it starts from random
    weights. Each encoder stage feeds a skip connection through a 1x1
    ``convrelu``. The decoder upsamples by 2 five times, concatenating a skip
    tensor each time, and a final 1x1 convolution produces ``n_class`` channels
    with no activation (raw logits; the training loss applies the sigmoid).

    Input: 3-channel images (see ``conv_original_size0``). NOTE(review): the
    height and width presumably must be divisible by 32 so the upsampled
    features match the skip tensors in ``torch.cat``; confirm before using
    sizes other than 256x256.
    """

    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=False)
        # Plain list for slicing; the layers are registered as submodules through
        # self.base_model and the attributes below.
        self.base_layers = list(self.base_model.children())

        # Encoder stages (presumably stem conv/bn/relu, then maxpool + layer1, ...).
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder fusion convs: input channels = upsampled features + skip channels.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch computed directly from the input image.
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        # Full-resolution features for the last skip connection.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder: upsample x2, concatenate the 1x1-projected skip, fuse with a 3x3 convrelu.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        # Raw logits, n_class channels.
        out = self.conv_last(x)

        return out

# Classic Unet (with dropout)

In [None]:
class double_conv(nn.Module):
    """Two ``Conv2d -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. The six layers live in ``self.conv`` at indices 0-5. Those
    indices are part of saved state_dict keys (``conv.0.weight``, ``conv.1.running_mean``, ...).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
        
# Number of feature maps of the first U-Net stage; each deeper stage doubles it.
start_fm = 16

class Unet(nn.Module):
    """Classic 4-level U-Net with dropout after the 2nd and 4th pooling.

    Contracting path: double_conv + 2x2 max-pool, four times (16 -> 256 feature
    maps with start_fm = 16). Expanding path: 2x2 transposed conv, concatenation
    with the matching contracting feature map, double_conv. A final 1x1 conv
    outputs 3 channels with no activation (raw logits; the training loss
    applies the sigmoid).

    Input: 3-channel images. NOTE(review): height and width presumably need to
    be divisible by 16 (four poolings) for the skip concatenations to line up.
    """
    
    def __init__(self):
        super(Unet, self).__init__()
        
        # Input 256x256x3
        
        #Contracting Path
        
        #(Double) Convolution 1        
        self.double_conv1 = double_conv(3, start_fm, 3, 1, 1)
        #Max Pooling 1
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        
        #Convolution 2
        self.double_conv2 = double_conv(start_fm, start_fm * 2, 3, 1, 1)
        #Max Pooling 2
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        
        #Convolution 3
        self.double_conv3 = double_conv(start_fm * 2, start_fm * 4, 3, 1, 1)
        #Max Pooling 3
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        
        #Convolution 4
        self.double_conv4 = double_conv(start_fm * 4, start_fm * 8, 3, 1, 1)
        #Max Pooling 4
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        
        #Convolution 5 (bottleneck)
        self.double_conv5 = double_conv(start_fm * 8, start_fm * 16, 3, 1, 1)
        
        #Transposed Convolution 4 (upsamples x2, halves the channels)
        self.t_conv4 = nn.ConvTranspose2d(start_fm * 16, start_fm * 8, 2, 2)
        # Expanding Path Convolution 4 (input = skip + upsampled channels)
        self.ex_double_conv4 = double_conv(start_fm * 16, start_fm * 8, 3, 1, 1)
        
        #Transposed Convolution 3
        self.t_conv3 = nn.ConvTranspose2d(start_fm * 8, start_fm * 4, 2, 2)
        #Expanding Path Convolution 3
        self.ex_double_conv3 = double_conv(start_fm * 8, start_fm * 4, 3, 1, 1)
        
        #Transposed Convolution 2
        self.t_conv2 = nn.ConvTranspose2d(start_fm * 4, start_fm * 2, 2, 2)
        #Expanding Path Convolution 2
        self.ex_double_conv2 = double_conv(start_fm * 4, start_fm * 2, 3, 1, 1)
        
        #Transposed Convolution 1
        self.t_conv1 = nn.ConvTranspose2d(start_fm * 2, start_fm, 2, 2)
        #Expanding Path Convolution 1
        self.ex_double_conv1 = double_conv(start_fm * 2, start_fm, 3, 1, 1)
        
        # One by One Conv: maps to 3 output channels
        self.one_by_one = nn.Conv2d(start_fm, 3, 1, 1, 0)
        # No final activation: the model returns logits (the loss applies the sigmoid).
        #self.final_act = nn.Sigmoid()
        
        # Shared dropout module, used twice in forward(); inactive in eval() mode.
        self.dropout = nn.Dropout(p=0.3)
        
        
    def forward(self, inputs):
        # Contracting Path
        conv1 = self.double_conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.double_conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)
        
        dropout1 = self.dropout(maxpool2)

        conv3 = self.double_conv3(dropout1)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.double_conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)
        
        dropout2 = self.dropout(maxpool4)
            
        # Bottom
        conv5 = self.double_conv5(dropout2)
        
        # Expanding Path: upsample, concatenate with the skip tensor along channels, convolve
        t_conv4 = self.t_conv4(conv5)
        cat4 = torch.cat([conv4 ,t_conv4], 1)
        ex_conv4 = self.ex_double_conv4(cat4)
        
        t_conv3 = self.t_conv3(ex_conv4)
        cat3 = torch.cat([conv3 ,t_conv3], 1)
        ex_conv3 = self.ex_double_conv3(cat3)

        t_conv2 = self.t_conv2(ex_conv3)
        cat2 = torch.cat([conv2 ,t_conv2], 1)
        ex_conv2 = self.ex_double_conv2(cat2)
        
        t_conv1 = self.t_conv1(ex_conv2)
        cat1 = torch.cat([conv1 ,t_conv1], 1)
        ex_conv1 = self.ex_double_conv1(cat1)
        
        one_by_one = self.one_by_one(ex_conv1)
        
        return one_by_one

# Train and save

In [None]:
def train_and_save(model, model_name, bce_weight=0.5, 
                    dice_weight=0.5, focal_weight=0, 
                    num_epochs=25, isDeepLabV3=False,
                    models_root='C:\\Users\\Admin\\Desktop\\KohanovDiploma\\models_v2'):
    """Train ``model`` (Adam, lr=1e-4, StepLR) and save its weights and learning curves.

    Output folder:
    ``<models_root>\\<model_name>\\epochs_<num_epochs>\\bce<b>_dice<d>_focal<f>\\``.
    It receives ``loss_train.txt``, ``dice_train.txt``, ``loss_val.txt`` and
    ``dice_val.txt`` (per-epoch values, each followed by a comma), plus the
    best-validation state_dict saved under the file name ``model_name``.
    ``models_root`` defaults to the original absolute location.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Device:', device)
    print('Batch size:', batch_size)

    save_root = (f'{models_root}\\{model_name}\\epochs_{num_epochs}\\'
                 f'bce{bce_weight}_dice{dice_weight}_focal{focal_weight}\\')
    # Create the folder before training so a bad path fails immediately.
    # exist_ok=True lets a configuration be re-run; without it os.makedirs
    # raised FileExistsError after training and the trained model was lost.
    os.makedirs(save_root, exist_ok=True)

    optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.3)

    model, loss_train, dice_train, loss_val, dice_val = train_model(model, optimizer_ft,
                                                                    exp_lr_scheduler, num_epochs=num_epochs, 
                                                                    isDeepLabV3=isDeepLabV3,
                                                                    bce_weight=bce_weight,
                                                                    dice_weight=dice_weight,
                                                                    focal_weight=focal_weight)

    logs = {
        'loss_train.txt': loss_train,
        'dice_train.txt': dice_train,
        'loss_val.txt': loss_val,
        'dice_val.txt': dice_val,
    }
    for file_name, values in logs.items():
        with open(save_root + file_name, 'w') as filehandle:  # closed by `with`
            filehandle.writelines("%s," % value for value in values)

    torch.save(model.state_dict(),
               save_root + model_name)
    print('MODEL AND LOGS SAVED TO: ' + save_root)

# Experiments with DeepLabV3 (finding the best loss)

In [None]:
from torchvision.models.segmentation import deeplabv3_resnet50

BCE_weight = 0.3, DICE_weight = 0.7, FOCAL_weight = 0

In [None]:
# DeepLabV3-ResNet50 with 3 output channels: 0.3*BCE + 0.7*Dice, 30 epochs.
deeplabv3_model = deeplabv3_resnet50(num_classes=3).to(device)
model_name = 'deeplabv3'

train_and_save(
    deeplabv3_model, model_name,
    bce_weight=0.3, dice_weight=0.7, focal_weight=0,
    num_epochs=30, isDeepLabV3=True,
)

In [None]:
# Validation Dice per epoch for the three U-Net augmentation runs
# (affine, soft, strong; 50 epochs, 0.3*BCE + 0.7*Dice).
plot_curves('C:\\Users\\Admin\\Desktop\\KohanovDiploma\\models_v2\\unet_aug_affine\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'C:\\Users\\Admin\\Desktop\\KohanovDiploma\\models_v2\\unet_aug_soft\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'C:\\Users\\Admin\\Desktop\\KohanovDiploma\\models_v2\\unet_aug_strong\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt')

BCE_weight = 0.5, DICE_weight = 0.5, FOCAL_weight = 0

In [None]:
# DeepLabV3-ResNet50: 0.5*BCE + 0.5*Dice, 10 epochs.
deeplabv3_model = deeplabv3_resnet50(num_classes=3).to(device)
model_name = 'deeplabv3'

train_and_save(
    deeplabv3_model, model_name,
    bce_weight=0.5, dice_weight=0.5, focal_weight=0,
    num_epochs=10, isDeepLabV3=True,
)

BCE_weight = 0.3, DICE_weight = 0.2, FOCAL_weight = 0.5

In [None]:
# DeepLabV3-ResNet50: 0.3*BCE + 0.2*Dice + 0.5*Focal, 10 epochs.
deeplabv3_model = deeplabv3_resnet50(num_classes=3).to(device)
model_name = 'deeplabv3'

train_and_save(
    deeplabv3_model, model_name,
    bce_weight=0.3, dice_weight=0.2, focal_weight=0.5,
    num_epochs=10, isDeepLabV3=True,
)

Best loss for DeepLabV3: bce03_dice02_focal05 (0.3·BCE + 0.2·Dice + 0.5·Focal)

In [None]:
# Load the U-Net trained with affine augmentation (0.3*BCE + 0.7*Dice, 50 epochs).
unet_model = Unet().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
unet_model.load_state_dict(torch.load('C:\\Users\\Admin\\Desktop\\KohanovDiploma\\models_v2\\unet_aug_affine\\epochs_50\\bce0.3_dice0.7_focal0\\unet_aug_affine', map_location=device))
unet_model.eval()

In [None]:
# Predict one random validation batch with the loaded U-Net and plot the results.
inputs, labels = next(iter(val_loader))

# The loader yields CPU tensors and predict_masks moves the inputs to `device`
# itself, so no device -> numpy -> tensor round trips are needed here.
pred = predict_masks(unet_model, inputs, isDeepLabV3=False)

plot_predictions(inputs, labels, pred)

In [None]:
# Per-sample Dice (printed) and its mean over the plotted batch.
# NOTE: plot_predictions thresholds `pred` in place at 0.5, so these scores are
# computed on the binarised predictions.
average_dice = show_dices_fot_predict(labels, pred)
average_dice

# Experiments with UNet (finding the best loss)

BCE_weight = 0.3, DICE_weight = 0.7, FOCAL_weight = 0

In [None]:
# U-Net: 0.3*BCE + 0.7*Dice, 50 epochs.
unet_model = Unet().to(device)
model_name = 'unet'

train_and_save(
    unet_model, model_name,
    bce_weight=0.3, dice_weight=0.7, focal_weight=0,
    num_epochs=50, isDeepLabV3=False,
)

BCE_weight = 0.5, DICE_weight = 0.5, FOCAL_weight = 0

In [None]:
# U-Net: 0.5*BCE + 0.5*Dice, 50 epochs.
unet_model = Unet().to(device)
model_name = 'unet'

train_and_save(
    unet_model, model_name,
    bce_weight=0.5, dice_weight=0.5, focal_weight=0,
    num_epochs=50, isDeepLabV3=False,
)

BCE_weight = 0.3, DICE_weight = 0.2, FOCAL_weight = 0.5

In [None]:
# U-Net: 0.3*BCE + 0.2*Dice + 0.5*Focal, 50 epochs.
unet_model = Unet().to(device)
model_name = 'unet'

train_and_save(
    unet_model, model_name,
    bce_weight=0.3, dice_weight=0.2, focal_weight=0.5,
    num_epochs=50, isDeepLabV3=False,
)

Unet losses on Validation (50 epochs)

In [None]:
# Validation loss per epoch for the four U-Net loss-weight configurations
# (paths relative to the notebook's working directory).
plot_curves('models\\unet\\epochs_50\\bce0_dice1_focal0\\loss_val.txt',
            'models\\unet\\epochs_50\\bce0.3_dice0.7_focal0\\loss_val.txt',
            'models\\unet\\epochs_50\\bce0.5_dice0.5_focal0\\loss_val.txt',
            'models\\unet\\epochs_50\\bce0.3_dice0.2_focal0.5\\loss_val.txt')

Unet dices on validation (50 epochs)

In [None]:
# Validation Dice per epoch for the same four U-Net loss-weight configurations.
plot_curves('models\\unet\\epochs_50\\bce0_dice1_focal0\\dice_val.txt',
            'models\\unet\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'models\\unet\\epochs_50\\bce0.5_dice0.5_focal0\\dice_val.txt',
            'models\\unet\\epochs_50\\bce0.3_dice0.2_focal0.5\\dice_val.txt')

Best loss for UNet: bce05_dice05_focal0 (0.5·BCE + 0.5·Dice)

# Experiments with TernausNet (finding the best loss)

BCE_weight = 0.3, DICE_weight = 0.7, FOCAL_weight = 0

In [None]:
# ResNet18-encoder U-Net ("ternausnet"): 0.3*BCE + 0.7*Dice, 50 epochs.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
print('Batch size: ', batch_size)

num_class = 3
resnet_unet_model = ResNetUNet(n_class=num_class).to(device)
model_name = 'ternausnet'

train_and_save(
    resnet_unet_model, model_name,
    bce_weight=0.3, dice_weight=0.7, focal_weight=0,
    num_epochs=50, isDeepLabV3=False,
)

BCE_weight = 0.5, DICE_weight = 0.5, FOCAL_weight = 0

In [None]:
# ResNet18-encoder U-Net ("ternausnet"): 0.5*BCE + 0.5*Dice, 50 epochs.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
print('Batch size: ', batch_size)

num_class = 3
resnet_unet_model = ResNetUNet(n_class=num_class).to(device)
model_name = 'ternausnet'

train_and_save(
    resnet_unet_model, model_name,
    bce_weight=0.5, dice_weight=0.5, focal_weight=0,
    num_epochs=50, isDeepLabV3=False,
)

BCE_weight = 0.3, DICE_weight = 0.2, FOCAL_weight = 0.5

In [None]:
# ResNet18-encoder U-Net ("ternausnet"): 0.3*BCE + 0.2*Dice + 0.5*Focal, 50 epochs.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
print('Batch size: ', batch_size)

num_class = 3
resnet_unet_model = ResNetUNet(n_class=num_class).to(device)
model_name = 'ternausnet'

train_and_save(
    resnet_unet_model, model_name,
    bce_weight=0.3, dice_weight=0.2, focal_weight=0.5,
    num_epochs=50, isDeepLabV3=False,
)

# Experiments with UNet (augmentations)

affine

In [None]:
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# "affine" augmentation: resize to 256x256 (nearest neighbour), then random
# vertical/horizontal flips and 90-degree rotations. Image and mask are
# transformed together and returned as channel-first tensors by ToTensorV2.
train_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    ToTensorV2(transpose_mask=True)
    
])

# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    ToTensorV2(transpose_mask=True)
])

train_dataset = BratsDataset(train_paths, train_transform)
val_dataset = BratsDataset(val_paths, val_transform)

batch_size = 16

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# train_model reads the loaders from this global dict.
dataloaders = {'train': train_loader,
               'val': val_loader}

unet_model = Unet().to(device)

model_name = 'unet_aug_affine'

# Fresh U-Net, 0.3*BCE + 0.7*Dice, 50 epochs.
train_and_save(unet_model, model_name, bce_weight=0.3, 
               dice_weight=0.7, focal_weight=0, 
               num_epochs=50, isDeepLabV3=False)

soft

In [None]:
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# "soft" augmentation: the affine flips/rotations plus random brightness/contrast
# and gamma changes.
train_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.RandomGamma(p=0.5),
    ToTensorV2(transpose_mask=True)
])

# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    ToTensorV2(transpose_mask=True)
])


train_dataset = BratsDataset(train_paths, train_transform)
val_dataset = BratsDataset(val_paths, val_transform)

batch_size = 16

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# train_model reads the loaders from this global dict.
dataloaders = {'train': train_loader,
               'val': val_loader}

unet_model = Unet().to(device)

model_name = 'unet_aug_soft'

# Fresh U-Net, 0.3*BCE + 0.7*Dice, 50 epochs.
train_and_save(unet_model, model_name, bce_weight=0.3, 
               dice_weight=0.7, focal_weight=0, 
               num_epochs=50, isDeepLabV3=False)

strong

In [None]:
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# "strong" augmentation: the affine flips/rotations, a small shift/scale/rotate,
# one of three elastic/grid/optical distortions, and brightness/contrast/gamma changes.
train_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.1, 
                       rotate_limit=10, p=0.3),
    # OneOf picks a single distortion from the list (whole block applied with p=0.5).
    A.OneOf([
        # NOTE(review): `alpha_affine` is deprecated/removed in newer albumentations
        # releases; confirm it is accepted by the installed version.
        A.ElasticTransform(p=0.5, alpha=120, 
                         sigma=120*0.02, alpha_affine=120*0.03),
        A.GridDistortion(p=0.3),
        A.OpticalDistortion(p=0.3, distort_limit=2, shift_limit=0.2)
    ], p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.RandomGamma(p=0.5),
    ToTensorV2(transpose_mask=True)
])

# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
    A.Resize(256,256,cv2.INTER_NEAREST),
    ToTensorV2(transpose_mask=True)
])


train_dataset = BratsDataset(train_paths, train_transform)
val_dataset = BratsDataset(val_paths, val_transform)

batch_size = 16

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# train_model reads the loaders from this global dict.
dataloaders = {'train': train_loader,
               'val': val_loader}

unet_model = Unet().to(device)

model_name = 'unet_aug_strong'

# Fresh U-Net, 0.3*BCE + 0.7*Dice, 50 epochs.
train_and_save(unet_model, model_name, bce_weight=0.3, 
               dice_weight=0.7, focal_weight=0, 
               num_epochs=50, isDeepLabV3=False)

In [None]:
# Validation Dice per epoch: soft / affine / strong augmentation and the dropout
# U-Net run (all 50 epochs, 0.3*BCE + 0.7*Dice; paths relative to the notebook).
plot_curves('models_v2\\unet_aug_soft\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'models_v2\\unet_aug_affine\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'models_v2\\unet_aug_strong\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt',
            'models_v2\\unet_with_dropout\\epochs_50\\bce0.3_dice0.7_focal0\\dice_val.txt')

Best augmentation: affine (geometric)

# UNet predictions

In [None]:
# Load the affine-augmentation U-Net checkpoint (path relative to the notebook).
unet_model = Unet().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
unet_model.load_state_dict(torch.load('models_v2\\unet_aug_affine\\epochs_50\\bce0.3_dice0.7_focal0\\unet_aug_affine', map_location=device))
unet_model.eval()

In [None]:
# Predict one random validation batch with the loaded U-Net and plot the results.
inputs, labels = next(iter(val_loader))

# The loader yields CPU tensors and predict_masks moves the inputs to `device`
# itself, so no device -> numpy -> tensor round trips are needed here.
pred = predict_masks(unet_model, inputs, isDeepLabV3=False)

plot_predictions(inputs, labels, pred)

In [None]:
# Per-sample Dice (printed) and its mean over the plotted batch.
# NOTE: plot_predictions thresholds `pred` in place at 0.5, so these scores are
# computed on the binarised predictions.
average_dice = show_dices_fot_predict(labels, pred)
average_dice

In [None]:
# Load the U-Net-with-dropout checkpoint used in the model comparison below.
unet_model = Unet().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
unet_model.load_state_dict(torch.load('models_v2\\unet_with_dropout\\epochs_50\\bce0.3_dice0.7_focal0\\unet_with_dropout', map_location=device))
unet_model.eval()

In [None]:
from torchvision.models.segmentation import deeplabv3_resnet50

# DeepLabV3-ResNet50 (3 output channels) trained with 0.5*BCE + 0.5*Dice for 30 epochs.
deeplabv3_model = deeplabv3_resnet50(num_classes=3)
deeplabv3_model = deeplabv3_model.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
deeplabv3_model.load_state_dict(torch.load('models\\deeplabv3\\epochs_30\\bce0.5_dice0.5_focal0\\deeplabv3', map_location=device))
deeplabv3_model.eval()

In [None]:
# ResNet18-encoder U-Net ("ternausnet", 20 epochs, 0.3*BCE + 0.7*Dice).
num_class = 3
ternaus_model = ResNetUNet(num_class).to(device)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
ternaus_model.load_state_dict(torch.load('models_v2\\ternausnet\\epochs_20\\bce0.3_dice0.7_focal0\\ternausnet', map_location=device))
ternaus_model.eval()

In [None]:
# Side-by-side comparison for sample 1 of a random validation batch:
# input | ground truth | U-Net | DeepLabV3 | ResNet-U-Net (predictions thresholded at 0.5).
inputs, labels = next(iter(val_loader))
# The loader yields CPU tensors and predict_masks moves inputs to `device`
# itself, so no device -> numpy -> tensor round trips are needed.

f, axes = plt.subplots(1, 5)
f.set_size_inches(100,100)
image = inputs[1].permute(1,2,0)  # CHW -> HWC for matplotlib
axes[0].imshow(image)
axes[0].set_title('MRI-photo')
real_mask = labels[1].permute(1,2,0)
axes[1].imshow(real_mask)
axes[1].set_title('real_mask')

# copy_(mask > 0.5) thresholds in place, like apply_(lambda x: 1 if x > 0.5 else 0),
# but vectorised instead of one Python call per element.
pred_unet = predict_masks(unet_model, inputs, isDeepLabV3=False)
unet_mask = pred_unet[1].permute(1,2,0)
unet_mask.copy_(unet_mask > 0.5)
axes[2].imshow(unet_mask)
axes[2].set_title('unet')

pred_deeplab = predict_masks(deeplabv3_model, inputs, isDeepLabV3=True)
deeplab_mask = pred_deeplab[1].permute(1,2,0)
deeplab_mask.copy_(deeplab_mask > 0.5)
axes[3].imshow(deeplab_mask)
axes[3].set_title('deeplabv3')

pred_ternaus = predict_masks(ternaus_model, inputs, isDeepLabV3=False)
ternaus_mask = pred_ternaus[1].permute(1,2,0)
ternaus_mask.copy_(ternaus_mask > 0.5)
axes[4].imshow(ternaus_mask)
axes[4].set_title('ternausnet')

# pred_pix2pix = cv2.imread('/home/podaval/Kochanov_diploma/pytorch-CycleGAN-and-pix2pix/results/brats_pix2pix/test_latest/images/56_fake_B.png')
# axes[5].imshow(pred_pix2pix)

In [None]:
# Same comparison on another random validation batch (sample 1):
# input | ground truth | U-Net | DeepLabV3 | ResNet-U-Net (predictions thresholded at 0.5).
inputs, labels = next(iter(val_loader))
# The loader yields CPU tensors and predict_masks moves inputs to `device`
# itself, so no device -> numpy -> tensor round trips are needed.

f, axes = plt.subplots(1, 5)
f.set_size_inches(100,100)
image = inputs[1].permute(1,2,0)  # CHW -> HWC for matplotlib
axes[0].imshow(image)
axes[0].set_title('MRI-photo')

real_mask = labels[1].permute(1,2,0)
axes[1].imshow(real_mask)
axes[1].set_title('real_mask')

# copy_(mask > 0.5) thresholds in place, like apply_(lambda x: 1 if x > 0.5 else 0),
# but vectorised instead of one Python call per element.
pred_unet = predict_masks(unet_model, inputs, isDeepLabV3=False)
unet_mask = pred_unet[1].permute(1,2,0)
unet_mask.copy_(unet_mask > 0.5)
axes[2].imshow(unet_mask)
axes[2].set_title('unet')

pred_deeplab = predict_masks(deeplabv3_model, inputs, isDeepLabV3=True)
deeplab_mask = pred_deeplab[1].permute(1,2,0)
deeplab_mask.copy_(deeplab_mask > 0.5)
axes[3].imshow(deeplab_mask)
axes[3].set_title('deeplabv3')

pred_ternaus = predict_masks(ternaus_model, inputs, isDeepLabV3=False)
ternaus_mask = pred_ternaus[1].permute(1,2,0)
ternaus_mask.copy_(ternaus_mask > 0.5)
axes[4].imshow(ternaus_mask)
axes[4].set_title('ternausnet')

# pred_pix2pix = cv2.imread('/home/podaval/Kochanov_diploma/pytorch-CycleGAN-and-pix2pix/results/brats_pix2pix/test_latest/images/56_fake_B.png')
# axes[5].imshow(pred_pix2pix)