In [1]:
1

1

In [2]:
import os
import math
import copy
import time
import random
import pprint
import tqdm
import platform
from datetime import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torchvision import utils
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
# from torchinfo import summary

import cv2
import nibabel as nib
import skimage.transform as skTrans
from numpy import logical_and as l_and, logical_not as l_not
from scipy.spatial.distance import directed_hausdorff

%matplotlib inline

%load_ext autoreload
%autoreload 2

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
# Set random seed for reproduciablity
torch.manual_seed(42)
random.seed(42)

In [5]:
class ScaleToFixed(object):

    def __init__(self, new_shape, interpolation=1, channels=4):
        self.shape= new_shape
        self.interpolation = interpolation
        self.channels = channels

    def __call__(self, image):
        # print('first shape', image.shape)
        if image is not None: # (some patients don't have segmentations)
            if self.channels == 1:
                short_shape = (self.shape[1], self.shape[2], self.shape[3])
                image = skTrans.resize(image, short_shape, order=self.interpolation, preserve_range=True)  #
                image = image.reshape(self.shape)
            else:
                image = skTrans.resize(image, self.shape, order=self.interpolation, preserve_range=True)  #

        # print('second shape', image.shape)
        # print()
        return image

class RandomFlip(object):
    """Randomly flips (horizontally as well as vertically) the given PIL.Image with a probability of 0.5
    """
    def __init__(self, prob_flip=0.5):
        self.prob_flip= prob_flip
    def __call__(self, image):

        if random.random() < self.prob_flip:
            flip_type = np.random.randint(0, 3) # flip across any 3D axis
            image = np.flip(image, flip_type)
        return image

class ZeroChannel(object):
    """Randomly sets channel to zero the given PIL.Image with a probability of 0.25
    """
    def __init__(self, prob_zero=0.25, channels=4):
        self.prob_zero= prob_zero
        self.channels = channels
    def __call__(self, image):

        if np.random.random() < self.prob_zero:
            channel_to_zero = np.random.randint(0, self.channels) # flip across any 3D axis
            zeros = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
            image[channel_to_zero, :, :, :] = zeros
        return image

class ZeroSprinkle(object):
    def __init__(self, prob_zero=0.25, prob_true=0.5, channels=4):
        self.prob_zero=prob_zero
        self.prob_true=prob_true
        self.channels=channels
    def __call__(self, image):

        if self.prob_true:
            mask = np.random.rand(image.shape[0], image.shape[1], image.shape[2], image.shape[3])
            mask[mask < self.prob_zero] = 0
            mask[mask > 0] = 1
            image = image*mask

        return image


class MinMaxNormalize(object):
    """Min-Max normalization
    """
    def __call__(self, image):
        def norm(im):
            im = im.astype(np.float32)
            min_v = np.min(im)
            max_v = np.max(im)
            im = (im - min_v)/(max_v - min_v)
            return im
        image = norm(image)
        return image

class ToTensor(object):
    def __init__(self, scale=1):
        self.scale = scale

    def __call__(self, image):
        if image is not None:
            image = image.astype(np.float32)
            image = image.reshape((image.shape[0], int(image.shape[1]/self.scale), int(image.shape[2]/self.scale), int(image.shape[3]/self.scale)))
            image_tensor = torch.from_numpy(image)
            return image_tensor
        else:
            return image


class Compose(object):
    """
    Composes several transforms together.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for i, t in enumerate(self.transforms):
            image = t(image)
        return image

In [6]:
def get_bb_3D(img, pad=0):
    '''
    This function returns a tumor 3D bounding box using a segmentation mask
    '''
    xs = np.nonzero(np.sum(np.sum(img, axis=1), axis=1))
    ys = np.nonzero(np.sum(np.sum(img, axis=0), axis=1))
    zs = np.nonzero(np.sum(np.sum(img, axis=0), axis=0))
    xmin, xmax = np.min(xs), np.max(xs)
    ymin, ymax = np.min(ys), np.max(ys)
    zmin, zmax = np.min(zs), np.max(zs)
    bbox = (xmin-pad, ymin-pad, zmin-pad, xmax+pad, ymax+pad, zmax+pad)
    return bbox

def min_max(img):
    '''
    Min-max normalization
    '''
    return (img - img.min()) / (img.max() - img.min())

def read_mri(mr_path_dict, pad=0):

    image_shape = nib.load(mr_path_dict['flair']).get_fdata().shape
    bb_seg = get_bb_3D(nib.load(mr_path_dict['flair']).get_fdata())
    (xmin, ymin, zmin, xmax, ymax, zmax) = bb_seg

    xmin = np.max([0, xmin-pad])
    ymin = np.max([0, ymin-pad])
    zmin = np.max([0, zmin-pad])

    xmax = np.min([image_shape[0]-1, xmax+pad])
    ymax = np.min([image_shape[1]-1, ymax+pad])
    zmax = np.min([image_shape[2]-1, zmax+pad])


    img_dict = {}
    for key in ['flair', 't1', 't1ce', 't2', 'seg']:
        img = nib.load(mr_path_dict[key])
        img_data = img.get_fdata()
        img_dict[key] = img_data[xmin:xmax, ymin:ymax, zmin:zmax]

    stacked_img = np.stack([min_max(img_dict['flair']), min_max(img_dict['t1']),min_max(img_dict['t1ce']),min_max(img_dict['t2'])], axis=0)
    return stacked_img, img_dict['seg']


In [7]:
def plot_(image, seg, predicted=False):
    #Overlay with Predicted
    img = image[slice, :, :, :].squeeze()
    img = utils.make_grid(img)
    img = img.detach().cpu().numpy()
    
    print(img.shape)
    
    # plot images
    plt.figure(figsize=(10, 8))
    img_list = [img[i].T for i in range(channels)] # 1 image per channel
    plt.imshow(np.hstack(img_list), cmap='Greys_r')
    
    ## plot segmentation mask ##
    seg_img = torch.tensor(pred[slice].squeeze())
    if not predicted:
        seg_img = torch.tensor(seg_img.numpy()[:, ::-1].copy()) #flip
    seg_img = utils.make_grid(seg_img).detach().cpu().numpy()
    
    print(np.unique(seg_img))

    plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r', alpha=0.3)
    plt.show()
    

In [8]:
class GeneralDataset(Dataset):

    def __init__(self,
                metadata_df,
                root_dir,
                transform=None,
                seg_transform=None, ###
                dataformat=None, # indicates what shape (or content) should be returned (2D or 3D, etc.)
                returndims=None, # what size/shape 3D volumes should be returned as.
                output_shape=None,
                visualize=False,
                modality=None,
                pad=2,
                device='cpu'):
        """
        Args:
            metadata_df (string): Path to the csv file w/ patient IDs
            root_dir (string): Directory for MR images
            transform (callable, optional)
        """
        self.device=device
        self.metadata_df = metadata_df
        self.root_dir = root_dir
        self.transform = transform
        self.seg_transform = seg_transform
        self.returndims=returndims
        self.modality = modality
        self.pad = pad
        self.output_shape = output_shape

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        #print(type(idx), idx)
        if torch.is_tensor(idx):
            idx = idx.tolist()

        BraTS20ID = self.metadata_df.iloc[idx].BraTS_2020_subject_ID

        # make dictonary of paths to MRI volumnes (modalities) and segmenation masks
        mr_path_dict = {}
        sequence_type = ['seg', 't1', 't1ce', 'flair', 't2']
        for seq in sequence_type:
            mr_path_dict[seq] = os.path.join(self.root_dir, BraTS20ID, BraTS20ID + '_'+seq+'.nii.gz')

        image, seg_image = read_mri(mr_path_dict=mr_path_dict, pad=self.pad)
        
        if seg_image is not None:
            if self.output_shape == 'wt':
                seg_image[np.nonzero(seg_image)] = 1 #only 0's and 1's for background and tumor
            else:
                seg_image[seg_image == 4] = 3 #0,1,2,3 for background and tumor regions

        if self.transform:
            image = self.transform(image)
        if self.seg_transform:
            seg_image = self.seg_transform(seg_image)
        else:
            print('no transform')
        # print(image.shape)
        return (image, seg_image), BraTS20ID

In [9]:
def read_dataframe(params):
    if platform.system() == 'Windows':
        naming = pd.read_csv(f"{params['image_dir']}\\name_mapping.csv")
    else:
        naming = pd.read_csv(f"{params['image_dir']}/name_mapping.csv")
    
    data_df = pd.DataFrame(naming['BraTS_2020_subject_ID'])

    # n_patients_to_train_with
    total_num_patients = len(data_df)

    assert sum(params['tr_va_te_split']) == 100
    tr_split = int((total_num_patients * params['tr_va_te_split'][0]) / 100)
    va_split = int((total_num_patients * params['tr_va_te_split'][1]) / 100)
    te_split = total_num_patients - (tr_split + va_split)

    print(f"Data is split into train: {tr_split}, validation: {va_split} and test: {te_split}")
                   
    train_df = data_df[: tr_split]
    valid_df = data_df[tr_split : (tr_split + va_split)]
    test_df = data_df[(tr_split + va_split) :]

    return train_df, valid_df, test_df

In [10]:
def retrieve_dataset(df, train=False):

    image_dir, channels, resize_shape, output_shape = params['image_dir'], \
                                                      params['channels'], \
                                                      params['resize_shape'], \
                                                      params['output_shape']

    # basic data augmentation
    prob_voxel_zero = 0 # 0.1
    prob_channel_zero = 0 # 0.5
    prob_true = 0 # 0.8
    randomflip = RandomFlip()

    # MRI transformations
    train_transformations = Compose([
        MinMaxNormalize(),
        ScaleToFixed((channels, resize_shape[0], resize_shape[1],
                      resize_shape[2]), interpolation=1, channels=channels),
        ZeroSprinkle(prob_zero=prob_voxel_zero, prob_true=prob_true),
        ZeroChannel(prob_zero=prob_channel_zero),
        randomflip,
        ToTensor()
        ])
    
    val_transformations = Compose([
            MinMaxNormalize(),
            ScaleToFixed((channels, resize_shape[0], resize_shape[1],
                          resize_shape[2]), interpolation=1, channels=channels),
            ToTensor(),
        ])

    # GT segmentation mask transformations
    seg_transformations = Compose([
        ScaleToFixed((1, resize_shape[0], resize_shape[1],
                      resize_shape[2]), interpolation=0, channels=1),
        randomflip,
        ToTensor(),
        ])
    
    if train:
        dataset = GeneralDataset(metadata_df=df, 
                                root_dir=image_dir,
                                transform=train_transformations,
                                seg_transform=seg_transformations,
                                returndims=resize_shape,
                                output_shape=output_shape)
    else:
        dataset = GeneralDataset(metadata_df=df, 
                                root_dir=image_dir,
                                transform=val_transformations,
                                seg_transform=seg_transformations,
                                returndims=resize_shape,
                                output_shape=output_shape)
    return dataset

In [11]:
def get_data(params):

    train_df, valid_df, test_df = read_dataframe(params)

    train_dataset = retrieve_dataset(train_df, train=True)
    
    valid_dataset = retrieve_dataset(valid_df)

    test_dataset = retrieve_dataset(test_df)
    

    train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=params['train_batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=params['test_batch_size'])

    return train_loader, valid_loader, test_loader

In [12]:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm='b', num_groups=2, k_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv3d = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                                stride=stride, padding=padding)
        if norm == 'b':
            self.norm = nn.BatchNorm3d(num_features=out_channels)
        else:
            # use only one group if the given number of groups is greater than the number of channels
            if out_channels < num_groups:
                num_groups = 1
            assert out_channels % num_groups == 0, f'Expected out_channels{out_channels} in input to be divisible by num_groups{num_groups}'
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)

    def forward(self, x):
        x = self.norm(self.conv3d(x))
        x = F.elu(x) #!
        return x

class ConvTranspose(nn.Module):
    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1, output_padding=1):
        super(ConvTranspose, self).__init__()
        self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels,
                                                   out_channels=out_channels,
                                                   kernel_size=k_size,
                                                   stride=stride,
                                                   padding=padding,
                                                   output_padding=output_padding)

    def forward(self, x):
        return self.conv3d_transpose(x)


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, init_features, norm='b', num_groups=2, model_depth=4, pool_size=2):
        super(EncoderBlock, self).__init__()
        self.root_feat_maps = init_features
        self.num_conv_blocks = 2
        self.module_dict = nn.ModuleDict()
        for depth in range(model_depth):
            feat_map_channels = 2 ** (depth + 1) * self.root_feat_maps
            for i in range(self.num_conv_blocks):
                self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels, norm=norm, num_groups=num_groups)
                self.module_dict["conv_{}_{}".format(depth, i)] = self.conv_block
                in_channels, feat_map_channels = feat_map_channels, feat_map_channels * 2
            if depth == model_depth - 1:
                break
            else:
                self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
                self.module_dict["max_pooling_{}".format(depth)] = self.pooling

    def forward(self, x):
        down_sampling_features = []
        for k, op in self.module_dict.items():
            if k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
                if k.endswith("1"):
                    down_sampling_features.append(x)
            elif k.startswith("max_pooling"):
                x = op(x)
                #print(k, x.shape)

        return x, down_sampling_features


class DecoderBlock(nn.Module):
    def __init__(self, out_channels, init_features, norm, num_groups=2, model_depth=4):
        super(DecoderBlock, self).__init__()
        self.num_conv_blocks = 2
        self.num_feat_maps = init_features
        # user nn.ModuleDict() to store ops
        self.module_dict = nn.ModuleDict()

        for depth in range(model_depth - 2, -1, -1):
            # print(depth)
            feat_map_channels = 2 ** (depth + 1) * self.num_feat_maps
            # print(feat_map_channels * 4)
            self.deconv = ConvTranspose(in_channels=feat_map_channels * 4, out_channels=feat_map_channels * 4)
            self.module_dict["deconv_{}".format(depth)] = self.deconv
            for i in range(self.num_conv_blocks):
                if i == 0:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 6, out_channels=feat_map_channels * 2, norm=norm, num_groups=num_groups)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
                else:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=feat_map_channels * 2, norm=norm, num_groups=num_groups)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
            if depth == 0:
                self.final_conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=out_channels, norm=norm, num_groups=num_groups)
                self.module_dict["final_conv"] = self.final_conv

    def forward(self, x, down_sampling_features):
        """
        :param x: inputs
        :param down_sampling_features: feature maps from encoder path
        :return: output
        """
        for k, op in self.module_dict.items():
            if k.startswith("deconv"):
                x = op(x)
                #print(k, x.shape)
                x = torch.cat((down_sampling_features[int(k[-1])], x), dim=1)
            elif k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
            else:
                x = op(x)
                #print(k, x.shape)
        return x


In [13]:
class UnetModel(nn.Module):

    def __init__(self, in_channels, out_channels, init_features, norm, num_groups=2, model_depth=4, final_activation="sigmoid"):
        super(UnetModel, self).__init__()
        self.encoder = EncoderBlock(in_channels=in_channels,
                                    init_features=init_features,
                                    norm=norm, num_groups=num_groups,
                                    model_depth=model_depth)
        self.decoder = DecoderBlock(out_channels=out_channels,
                                    init_features=init_features,
                                    norm=norm, num_groups=num_groups,
                                    model_depth=model_depth)
        if final_activation == "sigmoid":
            self.sigmoid = nn.Sigmoid()
        else:
            self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, downsampling_features = self.encoder(x)
        x = self.decoder(x, downsampling_features)
        x = self.sigmoid(x)
        # print("Final output shape: ", x.shape)
        return x



In [14]:

class DiceLoss(nn.Module):
    def __init__(self, epsilon=1e-5):
        super(DiceLoss, self).__init__()
        # smooth factor
        self.epsilon = epsilon

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        # log_prob = torch.sigmoid(logits)
        logits = logits.view(batch_size, -1).type(torch.FloatTensor)
        targets = targets.view(batch_size, -1).type(torch.FloatTensor)
        intersection = (logits * targets).sum(-1)
        dice_score = 2. * intersection / ((logits + targets).sum(-1) + self.epsilon)
        # dice_score = 1 - dice_score.sum() / batch_size
        dice_score = torch.mean(1. - dice_score)
        dice_score.requires_grad = True
        return dice_score

In [15]:
class EDiceLoss(nn.Module):
    """Dice loss tailored to Brats need.
    """

    def __init__(self, do_sigmoid=True):
        super(EDiceLoss, self).__init__()
        self.do_sigmoid = do_sigmoid
        self.labels = ["ET", "TC", "WT"]
        self.device = "cpu"

    def binary_dice(self, inputs, targets, label_index, metric_mode=False):
        smooth = 1.
        if self.do_sigmoid:
            inputs = torch.sigmoid(inputs)

        if metric_mode:
            inputs = inputs > 0.5
            if targets.sum() == 0:
                print(f"No {self.labels[label_index]} for this patient")
                if inputs.sum() == 0:
                    return torch.tensor(1., device=device)
                else:
                    return torch.tensor(0., device=device)
            # Threshold the pred
        intersection = EDiceLoss.compute_intersection(inputs, targets)
        if metric_mode:
            dice = (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0)
        else:
            dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth)
        if metric_mode:
            return dice
        return 1 - dice

    @staticmethod
    def compute_intersection(inputs, targets):
        intersection = torch.sum(inputs * targets)
        return intersection

    def forward(self, inputs, target):
        dice = 0
        for i in range(target.size(1)):
            dice = dice + self.binary_dice(inputs[:, i, ...], target[:, i, ...], i)
        final_dice = dice / target.size(1)
        final_dice.requires_grad = True
        return final_dice

    def metric(self, inputs, target):
        dices = []
        for j in range(target.size(0)):
            dice = []
            for i in range(target.size(1)):
                dice.append(self.binary_dice(inputs[j, i], target[j, i], i, True))
            dices.append(dice)
        return dices

In [16]:
def calculate_metrics(preds, targets, patient):
    """
    Parameters
    ----------
    preds:
        torch tensor of size 1*C*Z*Y*X, ours BS*Z*Y*X 
    targets:
        torch tensor of same shape
    patient :
        The patient ID
    """

    assert preds.shape == targets.shape, "Preds and targets do not have the same size"
    pp = pprint.PrettyPrinter(indent=4)
    
    preds, targets = preds.detach().cpu().numpy(), targets.detach().cpu().numpy()

    metrics_list = []

    metrics = dict(
        patient_id=patient,
    )
    # print(targets.shape, targets.dtype, targets)
    
    if np.sum(targets) == 0:
        print(f"{label} not present for {patient}")
    else:
        tp = np.sum(l_and(preds, targets))
        tn = np.sum(l_and(l_not(preds), l_not(targets)))
        fp = np.sum(l_and(preds, l_not(targets)))
        fn = np.sum(l_and(l_not(preds), targets))

        sens = tp / (tp + fn)
        spec = tn / (tn + fp)
        acc = (tn + tp) / (tn + tp + fn + fp)
        dice = 2 * tp / (2 * tp + fp + fn)

    metrics[DICE] = dice
    metrics[ACC] = acc
    metrics[SENS] = sens
    metrics[SPEC] = spec
    # pp.pprint(metrics)
    metrics_list.append(metrics)

    return acc, dice, metrics_list


DICE = "dice"
ACC = "acc"
SENS = "sens"
SPEC = "spec"
METRICS = [DICE, ACC, SENS, SPEC]


In [17]:
def calc_dice(preds, targets):
    return (2 * torch.sum(preds * targets)) / ((preds.sum() + preds.sum()) * 1.0)
    

In [18]:


def plot_metric(train, label, metric_name):
    # Plot losses
    plt.figure(figsize=(10,8))
    plt.semilogy(train, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()
    plt.title(f'Model {metric_name} Plot')
    plt.savefig(f'Model_{metric_name}_{label}_Plot.png')
    plt.show()
    plt.clf()

def plot_result(kfolds, num_epochs, fold_train_history, fold_valid_history):
    final_fold = {'train_loss':[],'valid_loss':[],'train_acc':[],'valid_acc':[]}

    for epoch in range(num_epochs):                                      
        final_fold['train_loss'].append(np.mean([fold_train_history[str(fold)]['train_loss'][epoch] for fold in range(kfolds)]))
        final_fold['train_acc'].append(np.mean([fold_train_history[str(fold)]['train_acc'][epoch]for fold in range(kfolds)]))

    plot_metric(final_fold['train_loss'], 'train', 'Loss')
    plot_metric(final_fold['train_acc'], 'validation', 'Accuracy')

    final_fold['valid_loss'].append([fold_valid_history[str(fold)]['valid_loss'] for fold in range(kfolds)])
    final_fold['valid_acc'].append([fold_valid_history[str(fold)]['valid_acc'] for fold in range(kfolds)])

    print(final_fold)

In [19]:
def load_checkpoint(path, params):

    test_model = UnetModel(params['pretrain_in_channels'],
                           params['pretrain_out_channels'],
                           params['init_features'], params['norm'],
                           params['num_groups'],
                           )

    test_optimizer = torch.optim.AdamW(test_model.parameters(),
                                       lr=params['learning_rate'],
                                       )
    
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location=torch.device('cpu'))

    test_model.load_state_dict(checkpoint['model_state_dict'])
    test_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    return test_model, test_optimizer, epoch, loss

In [20]:
def save_model_folds(model, optimizer, fold, epoch, loss):
    # Saving the model
    save_path = f'model-fold-{fold}.pth'

    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss,
                  }
    torch.save(checkpoint, save_path)


In [21]:
def save_model_nofolds(model, optimizer, epoch, loss, params):
    now = datetime.now()
    # dd/mm/YY H:M:S
    dt_string = now.strftime("%d_%H_%M")

    # Saving the model
    save_path = f"model_{params['output_shape']}_{params['run_name']}_{dt_string}.pth"

    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss,
                  'params': params,
                  }
    torch.save(checkpoint, save_path)

In [22]:
def kFoldRunAll(criterion, dataset, params): 
                
    k_folds, num_epochs, train_batch_size = params['k_folds'],\
                                            params['no_epochs'],\
                                            params['train_batch_size']
    
    use_cuda, loss_name, in_channels, out_channels = params['use_cuda'], \
                                                     params['loss_name'],\
                                                     params['in_channels'],\
                                                     params['out_channels'],

    init_features, learning_rate, norm, num_groups = params['init_features'],\
                                                     params['learning_rate'],\
                                                     params['norm'], \
                                                     params['num_groups'],

    loss_function = criterion

    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)

    fold_train_history = {}
    fold_valid_history = {}
    fold_train_and_valid_acc = {}
    fold_train_and_valid_loss = {}

    print('--------------------------------')
    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        print(f'FOLD {fold}')
        print('--------------------------------')
        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        valid_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for training and testing data in this fold
        dataloader_train = DataLoader(dataset, batch_size=train_batch_size, sampler=train_subsampler, num_workers=0)
        dataloader_valid = DataLoader(dataset, batch_size=train_batch_size, sampler=valid_subsampler, num_workers=0)

        # Initialize optimizer and Model
        model = UnetModel(in_channels=in_channels, out_channels=out_channels,
                          init_features=init_features, norm=norm, num_groups=num_groups)
        if use_cuda:
            model = model.cuda()
        #print(model)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

        # Run the training, testing and saving loop for defined number of epochs
        start_time = time.time()

        t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_folds(model, loss_function, optimizer,
                                                                               dataloader_train, dataloader_valid,
                                                                               fold, num_epochs, use_cuda,
                                                                               loss_name)

        end_time = time.time()
        print(f"Epoch Time: {end_time - start_time}")

        #Saving loss results 
        fold_train_and_valid_loss[str(fold)] = [t_loss, v_loss]
        fold_train_and_valid_acc[str(fold)] = [t_acc, v_acc]
        fold_train_history[str(fold)] = t_history
        fold_valid_history[str(fold)] = v_history

        # Print accuracy
        print(f'Accuracy for fold {fold}: {v_acc}')
        print(f'Loss for fold {fold}: {v_loss}')
        print('--------------------------------')  

    return fold_train_history, fold_valid_history, fold_train_and_valid_loss, fold_train_and_valid_acc
  


In [23]:
def train_test_folds(model, loss_function, optimizer, dataloader_train, dataloader_valid, fold, num_epochs, use_cuda, loss_name):
    train_history = {'train_loss': [], 'train_acc':[], 'train_dice':[]}
    valid_history = {'valid_loss': [], 'valid_acc':[], 'valid_dice':[]}
    best = math.inf

    edice = EDiceLoss()
    if use_cuda:
        edice = edice.cuda()
    metric = edice.metric
        
    for epoch in range(num_epochs):
        print(f'Starting Train epoch: {epoch+1}')

        train_loss = 0.0
        train_acc, train_dice = 0, 0
        model.train()

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            #print(inputs.shape, outputs.shape, targets.squeeze(1).long().shape)

            if loss_name == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            train_loss += loss.item() * outputs.size(0) #multiplying by batchsize
            
            rtrain_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))
            rtrain_acc, rtrain_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
            print(f'Train Dice 1: {rtrain_dice1}, 2: {rtrain_dice2} \t Acc: {rtrain_acc}')
            train_acc += rtrain_acc
            train_dice += rtrain_dice1
            
            print(f'Train Loss :{loss.item()}')

            loss.backward()
            optimizer.step()
           
        train_history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        train_history['train_acc'].append(train_acc / len(dataloader_train.sampler))
        train_history['train_dice'].append(train_dice / len(dataloader_train.sampler))

        print(f"Train Epoch loss: {train_history['train_loss'][-1]}, \t ACC/DICE :{train_history['train_acc'][-1]}/{train_history['train_dice'][-1]} ")

    valid_loss = 0.0
    valid_acc, valid_dice = 0, 0
           
    model.eval()
    #! maybe change later to validate after some epochs
    with torch.no_grad():
        # Iterate over the test data and generate predictions
        for i, data in enumerate(tqdm.tqdm(dataloader_valid)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() 
            outputs = model(inputs)

            if params['loss_name'] == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            # print('Valid Loss:', loss.item())
            valid_loss += loss.item() * inputs.size(0)
            
            rvalid_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))

            rvalid_acc, rvalid_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
            # print(f'Val Dice 1: {rvalid_dice1}, 2: {rvalid_dice2}')
            valid_acc += rvalid_acc
            valid_dice += rvalid_dice1
            
        # Print accuracy
        print(f'Val Dice : {valid_dice}, len {len(dataloader_valid.sampler)}')
        valid_loss /= len(dataloader_valid.sampler) 
        valid_acc = valid_acc / len(dataloader_valid.sampler)
        valid_dice = valid_dice / len(dataloader_valid.sampler)
        #print(f" Fold Accuracy: {valid_acc}")

    valid_history['valid_loss'].append(valid_loss)
    valid_history['valid_acc'].append(valid_acc)
    valid_history['valid_dice'].append(valid_dice)

    print(f"Val Epoch loss: {valid_history['valid_loss'][-1]} \t acc/dice:/ {valid_history['valid_acc'][-1]}/ {valid_history['valid_dice'][-1]}")

    # saving best model for this fold
    if valid_loss < best:
        best = valid_loss
        save_model_folds(model, optimizer, fold, epoch, loss)
    
    
    return train_history['train_loss'][-1], train_history['train_acc'][-1], train_history, valid_history['valid_loss'][-1], valid_history['valid_acc'][-1], valid_history


In [24]:
def train_test_nofolds(model, loss_function, optimizer, dataloader_train, dataloader_valid, params):
    train_history = {'train_loss': [], 'train_acc':[], 'train_dice':[]}
    valid_history = {'valid_loss': [], 'valid_acc':[], 'valid_dice':[]}
    best = math.inf

    edice = EDiceLoss()
    if params['use_cuda']:
        edice = edice.cuda()
    metric = edice.metric

    #Automatic mixed precision addition
    #scaler = torch.cuda.amp.GradScaler()
        
    for epoch in range(params['no_epochs']):
        print(f'Starting Train epoch: {epoch+1}')

        train_loss = 0.0
        train_acc, train_dice = 0, 0
        model.train()

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if params['use_cuda']:
                inputs, targets = inputs.cuda(), targets.cuda()

            optimizer.zero_grad()
            #with torch.cuda.amp.autocast(enabled=params['autocast']):
            outputs = model(inputs)
            #print(inputs.shape, outputs.shape, targets.squeeze(1).long().shape)

            if params['loss_name'] == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            train_loss += loss.item() * outputs.size(0) # multiplying by batchsize
            
            rtrain_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))
            rtrain_acc, rtrain_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)

            if i % 10 == 0:
                print(f'Train Dice 1: {rtrain_dice1}, 2: {rtrain_dice2} \t Acc: {rtrain_acc}')

            train_acc += rtrain_acc
            train_dice += rtrain_dice1
            if i % 10 == 0:
                print(f'Train Loss :{loss.item()}')
            
            #scaler.scale(loss).backward()
            #scaler.step(optimizer)
            #scaler.update()
            loss.backward()
            optimizer.step()
           
        train_history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        train_history['train_acc'].append(train_acc / len(dataloader_train.sampler))
        train_history['train_dice'].append(train_dice / len(dataloader_train.sampler))

        print(f"Train Epoch loss: {train_history['train_loss'][-1]}, \t ACC/DICE :{train_history['train_acc'][-1]}/{train_history['train_dice'][-1]} ")
          
        if epoch % (.1 * params['no_epochs']) == 0:
            valid_loss = 0.0
            valid_acc, valid_dice = 0, 0
            model.eval()

            with torch.no_grad():
                # Iterate over the test data and generate predictions
                for i, data in enumerate(tqdm.tqdm(dataloader_valid)):
                    (inputs, targets), ID = data
                    if params['use_cuda']:
                        inputs, targets = inputs.cuda(), targets.cuda() 
                      
                    #with torch.cuda.amp.autocast(enabled=params['autocast']):
                    outputs = model(inputs)

                    if params['loss_name'] == 'dice':
                        class_outputs = outputs.argmax(dim=1)
                        loss = loss_function(class_outputs, targets.squeeze(1).long())
                        # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
                    else:
                        loss = loss_function(outputs, targets.squeeze(1).long())

                    # print('Valid Loss:', loss.item())
                    valid_loss += loss.item() * inputs.size(0)
                    
                    rvalid_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))

                    rvalid_acc, rvalid_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
                    # print(f'Val Dice 1: {rvalid_dice1}, 2: {rvalid_dice2}')
                    valid_acc += rvalid_acc
                    valid_dice += rvalid_dice1
                    
                # Print accuracy
                print(f'Val Dice : {valid_dice}, len {len(dataloader_valid.sampler)}')
                valid_loss /= len(dataloader_valid.sampler) 
                valid_acc = valid_acc / len(dataloader_valid.sampler)
                valid_dice = valid_dice / len(dataloader_valid.sampler)

            valid_history['valid_loss'].append(valid_loss)
            valid_history['valid_acc'].append(valid_acc)
            valid_history['valid_dice'].append(valid_dice)

            print(f"Val Epoch loss: {valid_history['valid_loss'][-1]} \t acc/dice:/ {valid_history['valid_acc'][-1]}/ {valid_history['valid_dice'][-1]}")

            # saving best model for this fold
            if valid_loss < best:
                best = valid_loss
                save_model_nofolds(model, optimizer, epoch, loss, params)
    
    
    return train_history['train_loss'][-1], train_history['train_acc'][-1], train_history, valid_history['valid_loss'][-1], valid_history['valid_acc'][-1], valid_history


In [25]:
def kfolds(params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        wisdom_weights = [1, 355.36116969, 74.37872817, 254.58104099]
        nick_weights = [ 1.        ,  8.9263424 ,  7.79622053, 31.17438108]
        criterion = CrossEntropyLoss(weight=torch.Tensor(nick_weights))
    else:
        criterion = EDiceLoss().cuda()
      
    if params['use_cuda']:
        criterion.cuda()
    
    train_df, valid_df, _ = read_dataframe(params)

    train_valid_df = pd.concat([train_df, valid_df])

    train_valid_dataset = retrieve_dataset(train_valid_df)

    t_history, v_history, tv_loss, tv_acc = kFoldRunAll(criterion, 
                                                        train_valid_dataset,
                                                        params)
    
    plot_result(k_folds, no_epochs, t_history, v_history)

    return t_history, v_history, tv_loss, tv_acc

In [26]:
def transfer(path, params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        criterion = CrossEntropyLoss(weight=torch.Tensor([1, 355.36116969, 74.37872817, 254.58104099]))
    else:
        criterion = EDiceLoss().cuda()
    
    model, optimizer, _, __ = load_checkpoint(path, params)

    # Transfer by changing(replacing) only last layer and finetuning to outdim=4
    model.decoder.final_conv = ConvBlock(in_channels=params['pretrain_in_final_conv'], 
                                          out_channels=params['out_channels'],
                                          norm=params['norm'],
                                          num_groups=params['num_groups'])

    model.decoder.module_dict.final_conv = ConvBlock(in_channels=params['pretrain_in_final_conv'], 
                                          out_channels=params['out_channels'],
                                          norm=params['norm'],
                                          num_groups=params['num_groups'])

    if params['use_cuda']:
        model.cuda()
        criterion.cuda()
    
    dataloader_train, dataloader_valid, _ = get_data(params)
    
    t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_nofolds(
                                                model, criterion, optimizer,
                                                dataloader_train, dataloader_valid,
                                                params
                                                )

    # plot_result(k_folds, params['no_epochs'], t_history, v_history)

    return t_loss, t_acc, t_history, v_loss, v_acc, v_history

In [27]:
def learn_from_scratch(params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        criterion = CrossEntropyLoss(weight=torch.Tensor([1, 355.36116969, 74.37872817, 254.58104099]))
    else:
        criterion = EDiceLoss()
    


    model = UnetModel(params['in_channels'],
                      params['out_channels'],
                      params['init_features'],
                      params['norm'],
                      params['num_groups'],
                      )

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params['learning_rate'],
                                  )

    if params['use_cuda']:
        model.cuda()
        criterion.cuda()
    
    dataloader_train, dataloader_valid, _ = get_data(params)
    
    t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_nofolds(
                                                model, criterion, optimizer,
                                                dataloader_train, dataloader_valid,
                                                params
                                                )

    # plot_result(k_folds, params['no_epochs'], t_history, v_history)

    return t_loss, t_acc, t_history, v_loss, v_acc, v_history

In [28]:
def get_params():

    if platform.system() == 'Windows':
        image_dir = r"C:\Users\wisdomik\Documents\project\MICCAI_BraTS2020_TrainingData"
    else:
        image_dir = '../../data/data/mri/MICCAI_BraTS2020_TrainingData/'
    
    use_cuda = torch.cuda.is_available()

    params = {'run_name': 'run_1',
              'in_channels': 4,
              'out_channels': 4,
              'no_epochs': 20,
              'k_folds': 5,
              'learning_rate': 5e-4, # 1e-4,
              'loss_name': 'wce',
              'output_shape': 'all',
              'tr_va_te_split': [75, 25, 0],
              'pretrain_in_channels': 4,
              'pretrain_out_channels': 2,
              'pretrain_in_final_conv': 16,
              'init_features': 8,
              'train_batch_size': 1,
              'autocast': False, # not in use
              'test_batch_size': 1,
              'norm': 'g',
              'num_groups': 4,
              'channels': 4,
              'resize_shape': (144, 144, 144), # 128
              'image_dir': image_dir,
              'use_cuda': use_cuda
              }
              
    return params



In [30]:
# set params then run this

#load in trained model for evaluation
# fold_to_check = 9
# PATH = f'drive/MyDrive/Colab Notebooks/model-fold-{fold_to_check}.pth'

params = get_params()
params

{'run_name': 'run_1',
 'in_channels': 4,
 'out_channels': 4,
 'no_epochs': 20,
 'k_folds': 5,
 'learning_rate': 0.0005,
 'loss_name': 'wce',
 'output_shape': 'all',
 'tr_va_te_split': [75, 25, 0],
 'pretrain_in_channels': 4,
 'pretrain_out_channels': 2,
 'pretrain_in_final_conv': 16,
 'init_features': 8,
 'train_batch_size': 1,
 'autocast': False,
 'test_batch_size': 1,
 'norm': 'g',
 'num_groups': 4,
 'channels': 4,
 'resize_shape': (144, 144, 144),
 'image_dir': '../../data/data/mri/MICCAI_BraTS2020_TrainingData/',
 'use_cuda': True}

In [None]:
# Run learn from scratch (UNCOMMENT TO RUN)

t_loss, t_acc, t_history, v_loss, v_acc, v_history = learn_from_scratch(params)

  0%|          | 0/276 [00:00<?, ?it/s]

Data is split into train: 276, validation: 92 and test: 1
Starting Train epoch: 1
Train Dice 1: 0.025612566620111465, 2: 0.04134624607639819 	 Acc: 0.22347474065500686
Train Loss :1.4356637001037598


  4%|▍         | 11/276 [00:49<19:16,  4.37s/it]

Train Dice 1: 0.04484144598245621, 2: 0.04583972751076452 	 Acc: 0.32659250685871055
Train Loss :1.5656408071517944


  8%|▊         | 21/276 [01:35<21:54,  5.15s/it]

Train Dice 1: 0.2602895498275757, 2: 0.1457800531548134 	 Acc: 0.2697770651148834
Train Loss :1.2823047637939453


 11%|█         | 31/276 [02:17<16:19,  4.00s/it]

Train Dice 1: 0.18175466358661652, 2: 0.12217224938362517 	 Acc: 0.3692276984739369
Train Loss :1.2618675231933594


 15%|█▍        | 41/276 [03:00<16:33,  4.23s/it]

Train Dice 1: 0.24503342807292938, 2: 0.17270304096122435 	 Acc: 0.38746255840620714
Train Loss :1.26203191280365


 18%|█▊        | 51/276 [03:42<15:21,  4.10s/it]

Train Dice 1: 0.025027401745319366, 2: 0.027662905549296242 	 Acc: 0.48958802190500683
Train Loss :1.3376550674438477


 22%|██▏       | 61/276 [04:26<14:49,  4.14s/it]

Train Dice 1: 0.20477481186389923, 2: 0.1346265987266496 	 Acc: 0.39267792459705075
Train Loss :1.198861002922058


 26%|██▌       | 71/276 [05:11<14:33,  4.26s/it]

Train Dice 1: 0.05887152627110481, 2: 0.0923110919297697 	 Acc: 0.2378941079389575
Train Loss :1.4800654649734497


 29%|██▉       | 81/276 [05:52<12:50,  3.95s/it]

Train Dice 1: 0.2516501843929291, 2: 0.17371050757934553 	 Acc: 0.5588666918509945
Train Loss :1.296236515045166


 33%|███▎      | 91/276 [06:39<15:09,  4.92s/it]

Train Dice 1: 0.10976696014404297, 2: 0.08725294563111914 	 Acc: 0.5709243586033951
Train Loss :1.1908483505249023


 37%|███▋      | 101/276 [07:23<13:36,  4.67s/it]

Train Dice 1: 0.08324765413999557, 2: 0.08450888983787337 	 Acc: 0.5699471263074417
Train Loss :1.2389973402023315


 40%|████      | 111/276 [08:05<11:03,  4.02s/it]

Train Dice 1: 0.03660077974200249, 2: 0.03002300921416027 	 Acc: 0.5348129795739026
Train Loss :1.3444405794143677


 44%|████▍     | 121/276 [08:49<12:35,  4.87s/it]

Train Dice 1: 0.06960099190473557, 2: 0.1286918751426536 	 Acc: 0.5781136134687929
Train Loss :1.3693088293075562


 47%|████▋     | 131/276 [09:35<11:01,  4.56s/it]

Train Dice 1: 0.04361986368894577, 2: 0.04236256569219014 	 Acc: 0.44802617830504116
Train Loss :1.340135931968689


 51%|█████     | 141/276 [10:20<09:47,  4.35s/it]

Train Dice 1: 0.22386065125465393, 2: 0.176724858763672 	 Acc: 0.5993719323345336
Train Loss :1.334633469581604


 55%|█████▍    | 151/276 [11:03<09:54,  4.76s/it]

Train Dice 1: 0.03826045244932175, 2: 0.04810582835226549 	 Acc: 0.5721078880529835
Train Loss :1.3024797439575195


 58%|█████▊    | 161/276 [11:47<08:07,  4.24s/it]

Train Dice 1: 0.3015596866607666, 2: 0.22634078200147514 	 Acc: 0.6423939311128258
Train Loss :1.1875072717666626


 62%|██████▏   | 171/276 [12:27<06:57,  3.98s/it]

Train Dice 1: 0.025298211723566055, 2: 0.024192793601452687 	 Acc: 0.55655991458762
Train Loss :1.3353639841079712


 66%|██████▌   | 181/276 [13:11<06:56,  4.38s/it]

Train Dice 1: 0.07488054782152176, 2: 0.06712615454131972 	 Acc: 0.5796246731395748
Train Loss :1.2962720394134521


 69%|██████▉   | 191/276 [13:54<06:04,  4.28s/it]

Train Dice 1: 0.05659623071551323, 2: 0.08512300218022008 	 Acc: 0.5716587898662552
Train Loss :1.2953600883483887


 73%|███████▎  | 201/276 [14:40<06:06,  4.89s/it]

Train Dice 1: 0.14607951045036316, 2: 0.13795470144131777 	 Acc: 0.589463305898491
Train Loss :1.3400055170059204


 76%|███████▋  | 211/276 [15:22<04:24,  4.07s/it]

Train Dice 1: 0.07948766648769379, 2: 0.11207326320983713 	 Acc: 0.590217831039952
Train Loss :1.330883502960205


 80%|████████  | 221/276 [16:05<03:42,  4.05s/it]

Train Dice 1: 0.22275646030902863, 2: 0.2226019650206448 	 Acc: 0.6489763508444787
Train Loss :1.198224663734436


 84%|████████▎ | 231/276 [16:48<03:18,  4.42s/it]

Train Dice 1: 0.05495570972561836, 2: 0.045874751642744 	 Acc: 0.5812140989368999
Train Loss :1.194611668586731


 87%|████████▋ | 241/276 [17:28<02:28,  4.24s/it]

Train Dice 1: 0.008762557990849018, 2: 0.00938763014966136 	 Acc: 0.517757295417524
Train Loss :1.3486034870147705


 91%|█████████ | 251/276 [18:09<01:50,  4.40s/it]

Train Dice 1: 0.060977786779403687, 2: 0.0686034479159522 	 Acc: 0.581607938957476
Train Loss :1.280258059501648


 95%|█████████▍| 261/276 [18:49<00:59,  3.95s/it]

Train Dice 1: 0.10280238091945648, 2: 0.09020203181905309 	 Acc: 0.6026087212791496
Train Loss :1.336519479751587


 98%|█████████▊| 271/276 [19:30<00:21,  4.20s/it]

Train Dice 1: 0.2011374831199646, 2: 0.17527629594615887 	 Acc: 0.6318801440329218
Train Loss :1.151358723640442


100%|██████████| 276/276 [19:52<00:00,  4.32s/it]
  0%|          | 0/92 [00:00<?, ?it/s]

Train Epoch loss: 1.2970670500527257, 	 ACC/DICE :0.5110606354336732/0.11756820976734161 


100%|██████████| 92/92 [06:30<00:00,  4.24s/it]


Val Dice : 10.687841415405273, len 92
Val Epoch loss: 1.2212159361528314 	 acc/dice:/ 0.5991750341654177/ 0.116172194480896


  0%|          | 0/276 [00:00<?, ?it/s]

Starting Train epoch: 2


  0%|          | 1/276 [00:03<15:42,  3.43s/it]

Train Dice 1: 0.12764644622802734, 2: 0.10902627672595641 	 Acc: 0.5875229070216049
Train Loss :1.3156973123550415


  4%|▍         | 11/276 [00:38<15:43,  3.56s/it]

Train Dice 1: 0.08103632926940918, 2: 0.08936985489363504 	 Acc: 0.5966441883144719
Train Loss :1.1962621212005615


  8%|▊         | 21/276 [01:15<15:39,  3.68s/it]

Train Dice 1: 0.12764474749565125, 2: 0.1312088920079816 	 Acc: 0.6111161345807613
Train Loss :1.2037479877471924


 11%|█         | 31/276 [01:53<15:30,  3.80s/it]

Train Dice 1: 0.03200957924127579, 2: 0.036242230105894585 	 Acc: 0.6138867455418381
Train Loss :1.325472116470337


 15%|█▍        | 41/276 [02:31<15:03,  3.85s/it]

Train Dice 1: 0.1765252947807312, 2: 0.13508241139940663 	 Acc: 0.6048843530306928
Train Loss :1.1847271919250488


 18%|█▊        | 51/276 [03:09<14:23,  3.84s/it]

Train Dice 1: 0.10857203602790833, 2: 0.10405215754976098 	 Acc: 0.6211446544924554
Train Loss :1.3132230043411255


 22%|██▏       | 61/276 [03:48<13:46,  3.85s/it]

Train Dice 1: 0.10438639670610428, 2: 0.12099078953775391 	 Acc: 0.6070357376328875
Train Loss :1.1853020191192627


 26%|██▌       | 71/276 [04:26<13:07,  3.84s/it]

Train Dice 1: 0.15807104110717773, 2: 0.2559194116352432 	 Acc: 0.6739145286779835
Train Loss :1.1372126340866089


 29%|██▉       | 81/276 [05:05<12:33,  3.86s/it]

Train Dice 1: 0.23200677335262299, 2: 0.19812771017939756 	 Acc: 0.6296597704475309
Train Loss :1.305237054824829


 33%|███▎      | 91/276 [05:44<11:56,  3.87s/it]

Train Dice 1: 0.16071465611457825, 2: 0.1780377718178505 	 Acc: 0.6170552152992113
Train Loss :1.2612611055374146


 37%|███▋      | 101/276 [06:22<11:15,  3.86s/it]

Train Dice 1: 0.05371512472629547, 2: 0.04803452435497558 	 Acc: 0.6096439230752744
Train Loss :1.1879056692123413


 40%|████      | 111/276 [07:01<10:38,  3.87s/it]

Train Dice 1: 0.045875389128923416, 2: 0.036592739312321 	 Acc: 0.578435115526406
Train Loss :1.325801134109497


 44%|████▍     | 121/276 [07:39<09:58,  3.86s/it]

Train Dice 1: 0.24962204694747925, 2: 0.19377168162561043 	 Acc: 0.6285073865097737
Train Loss :1.1404602527618408


 47%|████▋     | 131/276 [08:18<09:22,  3.88s/it]

Train Dice 1: 0.12211979925632477, 2: 0.1183220794427104 	 Acc: 0.6408982767489712
Train Loss :1.3038662672042847


 51%|█████     | 141/276 [08:57<08:42,  3.87s/it]

Train Dice 1: 0.02604909986257553, 2: 0.02385678484468652 	 Acc: 0.5960095566486625
Train Loss :1.2795342206954956


 55%|█████▍    | 151/276 [09:35<08:03,  3.86s/it]

Train Dice 1: 0.019579054787755013, 2: 0.02007707231013126 	 Acc: 0.5781169624485597
Train Loss :1.3551989793777466


 58%|█████▊    | 161/276 [10:14<07:25,  3.87s/it]

Train Dice 1: 0.16179072856903076, 2: 0.17550235934292713 	 Acc: 0.6493085696373457
Train Loss :1.1518036127090454


 62%|██████▏   | 171/276 [10:52<06:44,  3.85s/it]

Train Dice 1: 0.13935935497283936, 2: 0.11614697567771047 	 Acc: 0.6375064300411523
Train Loss :1.2370798587799072


 66%|██████▌   | 181/276 [11:31<06:07,  3.86s/it]

Train Dice 1: 0.009681533090770245, 2: 0.010952637028356783 	 Acc: 0.5873172796639232
Train Loss :1.322075366973877


 69%|██████▉   | 191/276 [12:10<05:28,  3.86s/it]

Train Dice 1: 0.07612819969654083, 2: 0.07360801464100014 	 Acc: 0.6628203634044925
Train Loss :1.3070441484451294


 73%|███████▎  | 201/276 [12:48<04:50,  3.87s/it]

Train Dice 1: 0.05731318145990372, 2: 0.06451498467476445 	 Acc: 0.6467469350137174
Train Loss :1.253334641456604


 76%|███████▋  | 211/276 [13:27<04:12,  3.88s/it]

Train Dice 1: 0.10656635463237762, 2: 0.07337714365245653 	 Acc: 0.5875125251843278
Train Loss :1.1839444637298584


 80%|████████  | 221/276 [14:06<03:32,  3.87s/it]

Train Dice 1: 0.18965715169906616, 2: 0.22170611199219226 	 Acc: 0.6581535600994513
Train Loss :1.176035761833191


 84%|████████▎ | 231/276 [14:45<02:54,  3.89s/it]

Train Dice 1: 0.16157536208629608, 2: 0.12611992784768417 	 Acc: 0.6367311412251372
Train Loss :1.1206719875335693


 87%|████████▋ | 241/276 [15:23<02:15,  3.86s/it]

Train Dice 1: 0.14219099283218384, 2: 0.1396718925117244 	 Acc: 0.6223499523105281
Train Loss :1.3298652172088623


 91%|█████████ | 251/276 [16:02<01:36,  3.87s/it]

Train Dice 1: 0.07858806848526001, 2: 0.10541902610721278 	 Acc: 0.6018230506258574
Train Loss :1.3801623582839966


 95%|█████████▍| 261/276 [16:41<00:57,  3.86s/it]

Train Dice 1: 0.14131391048431396, 2: 0.13009317798590905 	 Acc: 0.6366581334662208
Train Loss :1.2586435079574585


 98%|█████████▊| 271/276 [17:19<00:19,  3.88s/it]

Train Dice 1: 0.2160724401473999, 2: 0.22593346808470446 	 Acc: 0.6569241496270576
Train Loss :1.132326364517212


100%|██████████| 276/276 [17:39<00:00,  3.84s/it]
  0%|          | 0/276 [00:00<?, ?it/s]

Train Epoch loss: 1.260661724684895, 	 ACC/DICE :0.6235798796906689/0.13054293394088745 
Starting Train epoch: 3


  0%|          | 1/276 [00:03<15:01,  3.28s/it]

Train Dice 1: 0.28551849722862244, 2: 0.23706981921992065 	 Acc: 0.6632999373070988
Train Loss :1.1461542844772339


  4%|▍         | 11/276 [00:41<16:58,  3.85s/it]

Train Dice 1: 0.03422345593571663, 2: 0.04578372094416846 	 Acc: 0.5733754768947188
Train Loss :1.2304340600967407


  8%|▊         | 21/276 [01:20<16:28,  3.88s/it]

Train Dice 1: 0.3699691891670227, 2: 0.25183550131022925 	 Acc: 0.6603705847050755
Train Loss :1.2900198698043823


 11%|█         | 31/276 [01:58<15:46,  3.87s/it]

Train Dice 1: 0.1446310132741928, 2: 0.14837338670038244 	 Acc: 0.6106519659850823
Train Loss :1.1792303323745728


 15%|█▍        | 41/276 [02:37<15:06,  3.86s/it]

Train Dice 1: 0.06100082769989967, 2: 0.062360737379011204 	 Acc: 0.6090689032493142
Train Loss :1.3305118083953857


 18%|█▊        | 51/276 [03:16<14:32,  3.88s/it]

Train Dice 1: 0.23813626170158386, 2: 0.19264845975750294 	 Acc: 0.6372733410493827
Train Loss :1.3225902318954468


 22%|██▏       | 61/276 [03:54<13:53,  3.88s/it]

Train Dice 1: 0.12175139039754868, 2: 0.11611608950794398 	 Acc: 0.6213777434842249
Train Loss :1.3393275737762451


 26%|██▌       | 71/276 [04:33<13:12,  3.87s/it]

Train Dice 1: 0.19531121850013733, 2: 0.13590852105166884 	 Acc: 0.6756904926483196
Train Loss :1.1292613744735718


 29%|██▉       | 81/276 [05:12<12:38,  3.89s/it]

Train Dice 1: 0.2640267312526703, 2: 0.19884437726222579 	 Acc: 0.6499746817129629
Train Loss :1.329275131225586


 33%|███▎      | 91/276 [05:51<11:57,  3.88s/it]

Train Dice 1: 0.1431044638156891, 2: 0.17190688885071292 	 Acc: 0.6771362472136488
Train Loss :1.35334050655365


 37%|███▋      | 101/276 [06:29<11:14,  3.86s/it]

Train Dice 1: 0.16481131315231323, 2: 0.14176394429649591 	 Acc: 0.6450232151277435
Train Loss :1.237870216369629


 40%|████      | 111/276 [07:08<10:41,  3.89s/it]

Train Dice 1: 0.07078944891691208, 2: 0.07076928812169254 	 Acc: 0.6326979648919753
Train Loss :1.2388259172439575


 44%|████▍     | 121/276 [07:47<09:59,  3.87s/it]

Train Dice 1: 0.0422416590154171, 2: 0.037289610464545084 	 Acc: 0.5865088359482168
Train Loss :1.2222312688827515


 47%|████▋     | 131/276 [08:26<09:20,  3.86s/it]

Train Dice 1: 0.22610566020011902, 2: 0.20450945715941857 	 Acc: 0.6392127352323388
Train Loss :1.310491919517517


 51%|█████     | 141/276 [09:04<08:44,  3.89s/it]

Train Dice 1: 0.32695063948631287, 2: 0.2658984926285559 	 Acc: 0.6507258578746571
Train Loss :1.2864934206008911


 55%|█████▍    | 151/276 [09:43<08:06,  3.89s/it]

Train Dice 1: 0.24745990335941315, 2: 0.18063286998928837 	 Acc: 0.6459672925240055
Train Loss :1.3480421304702759


 58%|█████▊    | 161/276 [10:22<07:26,  3.88s/it]

Train Dice 1: 0.15328574180603027, 2: 0.12861606156357414 	 Acc: 0.6381507737482853
Train Loss :1.1181141138076782


 62%|██████▏   | 171/276 [11:01<06:47,  3.88s/it]

Train Dice 1: 0.08692029863595963, 2: 0.08018765124377066 	 Acc: 0.6183670106738683
Train Loss :1.1670843362808228


 66%|██████▌   | 181/276 [11:40<06:09,  3.89s/it]

Train Dice 1: 0.15469259023666382, 2: 0.13007876744723995 	 Acc: 0.6334987059542181
Train Loss :1.3364366292953491


 69%|██████▉   | 191/276 [12:19<05:30,  3.89s/it]

Train Dice 1: 0.09946002811193466, 2: 0.1090788049465023 	 Acc: 0.6262156796553497
Train Loss :1.152700662612915


 73%|███████▎  | 201/276 [12:57<04:50,  3.87s/it]

Train Dice 1: 0.18762889504432678, 2: 0.2665027116023534 	 Acc: 0.6669874989283264
Train Loss :1.2278339862823486


 76%|███████▋  | 211/276 [13:36<04:11,  3.87s/it]

Train Dice 1: 0.08106313645839691, 2: 0.07882018484946558 	 Acc: 0.6291966065457819
Train Loss :1.2736802101135254


 80%|████████  | 221/276 [14:15<03:33,  3.89s/it]

Train Dice 1: 0.04751935601234436, 2: 0.04614780166566692 	 Acc: 0.6048545471107681
Train Loss :1.2784249782562256


 84%|████████▎ | 231/276 [14:54<02:54,  3.89s/it]

Train Dice 1: 0.03497398644685745, 2: 0.03814971686199989 	 Acc: 0.6023223835090878
Train Loss :1.2882729768753052


 87%|████████▋ | 241/276 [15:33<02:16,  3.89s/it]

Train Dice 1: 0.03681754320859909, 2: 0.0310271093791206 	 Acc: 0.6246955777391975
Train Loss :1.2889667749404907


 91%|█████████ | 251/276 [16:12<01:37,  3.88s/it]

Train Dice 1: 0.1639164686203003, 2: 0.12587998871585834 	 Acc: 0.6305760513117284
Train Loss :1.143691062927246


 95%|█████████▍| 261/276 [16:50<00:58,  3.89s/it]

Train Dice 1: 0.20893150568008423, 2: 0.14620025396288247 	 Acc: 0.641739875364369
Train Loss :1.1364750862121582


 98%|█████████▊| 271/276 [17:29<00:19,  3.86s/it]

Train Dice 1: 0.13580353558063507, 2: 0.15616174832240848 	 Acc: 0.6509599515603567
Train Loss :1.1915090084075928


100%|██████████| 276/276 [17:49<00:00,  3.87s/it]
  0%|          | 0/92 [00:00<?, ?it/s]

Train Epoch loss: 1.263738748388014, 	 ACC/DICE :0.6286282798257613/0.13062070310115814 


100%|██████████| 92/92 [05:09<00:00,  3.36s/it]


Val Dice : 11.47465991973877, len 92
Val Epoch loss: 1.2110050825969032 	 acc/dice:/ 0.6425074360455525/ 0.12472456693649292


  0%|          | 0/276 [00:00<?, ?it/s]

Starting Train epoch: 4


  0%|          | 1/276 [00:03<15:15,  3.33s/it]

Train Dice 1: 0.046421341598033905, 2: 0.0608318223322356 	 Acc: 0.6365924934627915
Train Loss :1.3444098234176636


  4%|▍         | 11/276 [00:38<15:45,  3.57s/it]

Train Dice 1: 0.17537479102611542, 2: 0.21876026682529137 	 Acc: 0.6384515121313443
Train Loss :1.1861534118652344


  8%|▊         | 21/276 [01:16<16:02,  3.77s/it]

Train Dice 1: 0.05478128045797348, 2: 0.06879928018374257 	 Acc: 0.6707333327974966
Train Loss :1.336034893989563


 11%|█         | 31/276 [01:54<15:38,  3.83s/it]

Train Dice 1: 0.14948683977127075, 2: 0.14390677999920187 	 Acc: 0.6407743644975995
Train Loss :1.3242474794387817


 15%|█▍        | 41/276 [02:33<15:04,  3.85s/it]

Train Dice 1: 0.2812322974205017, 2: 0.2043862253285373 	 Acc: 0.6526615681798696
Train Loss :1.2989312410354614


 18%|█▊        | 51/276 [03:11<14:30,  3.87s/it]

Train Dice 1: 0.14975519478321075, 2: 0.11487469356410787 	 Acc: 0.6788442268947188
Train Loss :1.232873558998108


 22%|██▏       | 61/276 [03:50<13:51,  3.87s/it]

Train Dice 1: 0.10835190117359161, 2: 0.09215898662225717 	 Acc: 0.6067571025162894
Train Loss :1.3888652324676514


 26%|██▌       | 71/276 [04:29<13:15,  3.88s/it]

Train Dice 1: 0.07865556329488754, 2: 0.09560733506833738 	 Acc: 0.6318905258701989
Train Loss :1.3114343881607056


 28%|██▊       | 77/276 [04:52<12:47,  3.86s/it]

In [None]:
# torch.cuda.empty_cache()