In [1]:
1

1

In [2]:
import os
import math
import copy
import time
import random
import pprint
import tqdm
import platform
from datetime import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torchvision import utils
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
# from torchinfo import summary

import cv2
import nibabel as nib
import skimage.transform as skTrans
from numpy import logical_and as l_and, logical_not as l_not
from scipy.spatial.distance import directed_hausdorff

%matplotlib inline

%load_ext autoreload
%autoreload 2

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
# Set random seed for reproduciablity
torch.manual_seed(42)
random.seed(42)

In [5]:
class ScaleToFixed(object):

    def __init__(self, new_shape, interpolation=1, channels=4):
        self.shape= new_shape
        self.interpolation = interpolation
        self.channels = channels

    def __call__(self, image):
        # print('first shape', image.shape)
        if image is not None: # (some patients don't have segmentations)
            if self.channels == 1:
                short_shape = (self.shape[1], self.shape[2], self.shape[3])
                image = skTrans.resize(image, short_shape, order=self.interpolation, preserve_range=True)  #
                image = image.reshape(self.shape)
            else:
                image = skTrans.resize(image, self.shape, order=self.interpolation, preserve_range=True)  #

        # print('second shape', image.shape)
        # print()
        return image

class RandomFlip(object):
    """Randomly flips (horizontally as well as vertically) the given PIL.Image with a probability of 0.5
    """
    def __init__(self, prob_flip=0.5):
        self.prob_flip= prob_flip
    def __call__(self, image):

        if random.random() < self.prob_flip:
            flip_type = np.random.randint(0, 3) # flip across any 3D axis
            image = np.flip(image, flip_type)
        return image

class ZeroChannel(object):
    """Randomly sets channel to zero the given PIL.Image with a probability of 0.25
    """
    def __init__(self, prob_zero=0.25, channels=4):
        self.prob_zero= prob_zero
        self.channels = channels
    def __call__(self, image):

        if np.random.random() < self.prob_zero:
            channel_to_zero = np.random.randint(0, self.channels) # flip across any 3D axis
            zeros = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
            image[channel_to_zero, :, :, :] = zeros
        return image

class ZeroSprinkle(object):
    def __init__(self, prob_zero=0.25, prob_true=0.5, channels=4):
        self.prob_zero=prob_zero
        self.prob_true=prob_true
        self.channels=channels
    def __call__(self, image):

        if self.prob_true:
            mask = np.random.rand(image.shape[0], image.shape[1], image.shape[2], image.shape[3])
            mask[mask < self.prob_zero] = 0
            mask[mask > 0] = 1
            image = image*mask

        return image


class MinMaxNormalize(object):
    """Min-Max normalization
    """
    def __call__(self, image):
        def norm(im):
            im = im.astype(np.float32)
            min_v = np.min(im)
            max_v = np.max(im)
            im = (im - min_v)/(max_v - min_v)
            return im
        image = norm(image)
        return image

class ToTensor(object):
    def __init__(self, scale=1):
        self.scale = scale

    def __call__(self, image):
        if image is not None:
            image = image.astype(np.float32)
            image = image.reshape((image.shape[0], int(image.shape[1]/self.scale), int(image.shape[2]/self.scale), int(image.shape[3]/self.scale)))
            image_tensor = torch.from_numpy(image)
            return image_tensor
        else:
            return image


class Compose(object):
    """
    Composes several transforms together.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for i, t in enumerate(self.transforms):
            image = t(image)
        return image

In [6]:
def get_bb_3D(img, pad=0):
    '''
    This function returns a tumor 3D bounding box using a segmentation mask
    '''
    xs = np.nonzero(np.sum(np.sum(img, axis=1), axis=1))
    ys = np.nonzero(np.sum(np.sum(img, axis=0), axis=1))
    zs = np.nonzero(np.sum(np.sum(img, axis=0), axis=0))
    xmin, xmax = np.min(xs), np.max(xs)
    ymin, ymax = np.min(ys), np.max(ys)
    zmin, zmax = np.min(zs), np.max(zs)
    bbox = (xmin-pad, ymin-pad, zmin-pad, xmax+pad, ymax+pad, zmax+pad)
    return bbox

def min_max(img):
    '''
    Min-max normalization
    '''
    return (img - img.min()) / (img.max() - img.min())

def read_mri(mr_path_dict, pad=0):

    image_shape = nib.load(mr_path_dict['flair']).get_fdata().shape
    bb_seg = get_bb_3D(nib.load(mr_path_dict['flair']).get_fdata())
    (xmin, ymin, zmin, xmax, ymax, zmax) = bb_seg

    xmin = np.max([0, xmin-pad])
    ymin = np.max([0, ymin-pad])
    zmin = np.max([0, zmin-pad])

    xmax = np.min([image_shape[0]-1, xmax+pad])
    ymax = np.min([image_shape[1]-1, ymax+pad])
    zmax = np.min([image_shape[2]-1, zmax+pad])


    img_dict = {}
    for key in ['flair', 't1', 't1ce', 't2', 'seg']:
        img = nib.load(mr_path_dict[key])
        img_data = img.get_fdata()
        img_dict[key] = img_data[xmin:xmax, ymin:ymax, zmin:zmax]

    stacked_img = np.stack([min_max(img_dict['flair']), min_max(img_dict['t1']),min_max(img_dict['t1ce']),min_max(img_dict['t2'])], axis=0)
    return stacked_img, img_dict['seg']


In [7]:
def plot_(image, seg, predicted=False):
    #Overlay with Predicted
    img = image[slice, :, :, :].squeeze()
    img = utils.make_grid(img)
    img = img.detach().cpu().numpy()
    
    print(img.shape)
    
    # plot images
    plt.figure(figsize=(10, 8))
    img_list = [img[i].T for i in range(channels)] # 1 image per channel
    plt.imshow(np.hstack(img_list), cmap='Greys_r')
    
    ## plot segmentation mask ##
    seg_img = torch.tensor(pred[slice].squeeze())
    if not predicted:
        seg_img = torch.tensor(seg_img.numpy()[:, ::-1].copy()) #flip
    seg_img = utils.make_grid(seg_img).detach().cpu().numpy()
    
    print(np.unique(seg_img))

    plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r', alpha=0.3)
    plt.show()
    

In [8]:
class GeneralDataset(Dataset):

    def __init__(self,
                metadata_df,
                root_dir,
                transform=None,
                seg_transform=None, ###
                dataformat=None, # indicates what shape (or content) should be returned (2D or 3D, etc.)
                returndims=None, # what size/shape 3D volumes should be returned as.
                output_shape=None,
                visualize=False,
                modality=None,
                pad=2,
                device='cpu'):
        """
        Args:
            metadata_df (string): Path to the csv file w/ patient IDs
            root_dir (string): Directory for MR images
            transform (callable, optional)
        """
        self.device=device
        self.metadata_df = metadata_df
        self.root_dir = root_dir
        self.transform = transform
        self.seg_transform = seg_transform
        self.returndims=returndims
        self.modality = modality
        self.pad = pad
        self.output_shape = output_shape

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        #print(type(idx), idx)
        if torch.is_tensor(idx):
            idx = idx.tolist()

        BraTS20ID = self.metadata_df.iloc[idx].BraTS_2020_subject_ID

        # make dictonary of paths to MRI volumnes (modalities) and segmenation masks
        mr_path_dict = {}
        sequence_type = ['seg', 't1', 't1ce', 'flair', 't2']
        for seq in sequence_type:
            mr_path_dict[seq] = os.path.join(self.root_dir, BraTS20ID, BraTS20ID + '_'+seq+'.nii.gz')

        image, seg_image = read_mri(mr_path_dict=mr_path_dict, pad=self.pad)
        
        if seg_image is not None:
            if self.output_shape == 'wt':
                seg_image[np.nonzero(seg_image)] = 1 #only 0's and 1's for background and tumor
            else:
                seg_image[seg_image == 4] = 3 #0,1,2,3 for background and tumor regions

        if self.transform:
            image = self.transform(image)
        if self.seg_transform:
            seg_image = self.seg_transform(seg_image)
        else:
            print('no transform')
        # print(image.shape)
        return (image, seg_image), BraTS20ID

In [9]:
def read_dataframe(params):
    if platform.system() == 'Windows':
        naming = pd.read_csv(f"{params['image_dir']}\\name_mapping.csv")
    else:
        naming = pd.read_csv(f"{params['image_dir']}/name_mapping.csv")
    
    data_df = pd.DataFrame(naming['BraTS_2020_subject_ID'])

    # n_patients_to_train_with
    total_num_patients = len(data_df)

    assert sum(params['tr_va_te_split']) == 100
    tr_split = int((total_num_patients * params['tr_va_te_split'][0]) / 100)
    va_split = int((total_num_patients * params['tr_va_te_split'][1]) / 100)
    te_split = total_num_patients - (tr_split + va_split)

    print(f"Data is split into train: {tr_split}, validation: {va_split} and test: {te_split}")
                   
    train_df = data_df[: tr_split]
    valid_df = data_df[tr_split : (tr_split + va_split)]
    test_df = data_df[(tr_split + va_split) :]

    return train_df, valid_df, test_df

In [10]:
def retrieve_dataset(df, train=False):

    image_dir, channels, resize_shape, output_shape = params['image_dir'], \
                                                      params['channels'], \
                                                      params['resize_shape'], \
                                                      params['output_shape']

    # basic data augmentation
    prob_voxel_zero = 0 # 0.1
    prob_channel_zero = 0 # 0.5
    prob_true = 0 # 0.8
    randomflip = RandomFlip()

    # MRI transformations
    train_transformations = Compose([
        MinMaxNormalize(),
        ScaleToFixed((channels, resize_shape[0], resize_shape[1],
                      resize_shape[2]), interpolation=1, channels=channels),
        ZeroSprinkle(prob_zero=prob_voxel_zero, prob_true=prob_true),
        ZeroChannel(prob_zero=prob_channel_zero),
        randomflip,
        ToTensor()
        ])
    
    val_transformations = Compose([
            MinMaxNormalize(),
            ScaleToFixed((channels, resize_shape[0], resize_shape[1],
                          resize_shape[2]), interpolation=1, channels=channels),
            ToTensor(),
        ])

    # GT segmentation mask transformations
    seg_transformations = Compose([
        ScaleToFixed((1, resize_shape[0], resize_shape[1],
                      resize_shape[2]), interpolation=0, channels=1),
        randomflip,
        ToTensor(),
        ])
    
    if train:
        dataset = GeneralDataset(metadata_df=df, 
                                root_dir=image_dir,
                                transform=train_transformations,
                                seg_transform=seg_transformations,
                                returndims=resize_shape,
                                output_shape=output_shape)
    else:
        dataset = GeneralDataset(metadata_df=df, 
                                root_dir=image_dir,
                                transform=val_transformations,
                                seg_transform=seg_transformations,
                                returndims=resize_shape,
                                output_shape=output_shape)
    return dataset

In [11]:
def get_data(params):

    train_df, valid_df, test_df = read_dataframe(params)

    train_dataset = retrieve_dataset(train_df, train=True)
    
    valid_dataset = retrieve_dataset(valid_df)

    test_dataset = retrieve_dataset(test_df)
    

    train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=params['train_batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=params['test_batch_size'])

    return train_loader, valid_loader, test_loader

In [12]:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm='b', num_groups=2, k_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv3d = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                                stride=stride, padding=padding)
        if norm == 'b':
            self.norm = nn.BatchNorm3d(num_features=out_channels)
        else:
            # use only one group if the given number of groups is greater than the number of channels
            if out_channels < num_groups:
                num_groups = 1
            assert out_channels % num_groups == 0, f'Expected out_channels{out_channels} in input to be divisible by num_groups{num_groups}'
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)

    def forward(self, x):
        x = self.norm(self.conv3d(x))
        x = F.elu(x) #!
        return x

class ConvTranspose(nn.Module):
    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1, output_padding=1):
        super(ConvTranspose, self).__init__()
        self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels,
                                                   out_channels=out_channels,
                                                   kernel_size=k_size,
                                                   stride=stride,
                                                   padding=padding,
                                                   output_padding=output_padding)

    def forward(self, x):
        return self.conv3d_transpose(x)


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, init_features, norm='b', num_groups=2, model_depth=4, pool_size=2):
        super(EncoderBlock, self).__init__()
        self.root_feat_maps = init_features
        self.num_conv_blocks = 2
        self.module_dict = nn.ModuleDict()
        for depth in range(model_depth):
            feat_map_channels = 2 ** (depth + 1) * self.root_feat_maps
            for i in range(self.num_conv_blocks):
                self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels, norm=norm, num_groups=num_groups)
                self.module_dict["conv_{}_{}".format(depth, i)] = self.conv_block
                in_channels, feat_map_channels = feat_map_channels, feat_map_channels * 2
            if depth == model_depth - 1:
                break
            else:
                self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
                self.module_dict["max_pooling_{}".format(depth)] = self.pooling

    def forward(self, x):
        down_sampling_features = []
        for k, op in self.module_dict.items():
            if k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
                if k.endswith("1"):
                    down_sampling_features.append(x)
            elif k.startswith("max_pooling"):
                x = op(x)
                #print(k, x.shape)

        return x, down_sampling_features


class DecoderBlock(nn.Module):
    def __init__(self, out_channels, init_features, norm, num_groups=2, model_depth=4):
        super(DecoderBlock, self).__init__()
        self.num_conv_blocks = 2
        self.num_feat_maps = init_features
        # user nn.ModuleDict() to store ops
        self.module_dict = nn.ModuleDict()

        for depth in range(model_depth - 2, -1, -1):
            # print(depth)
            feat_map_channels = 2 ** (depth + 1) * self.num_feat_maps
            # print(feat_map_channels * 4)
            self.deconv = ConvTranspose(in_channels=feat_map_channels * 4, out_channels=feat_map_channels * 4)
            self.module_dict["deconv_{}".format(depth)] = self.deconv
            for i in range(self.num_conv_blocks):
                if i == 0:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 6, out_channels=feat_map_channels * 2, norm=norm, num_groups=num_groups)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
                else:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=feat_map_channels * 2, norm=norm, num_groups=num_groups)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
            if depth == 0:
                self.final_conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=out_channels, norm=norm, num_groups=num_groups)
                self.module_dict["final_conv"] = self.final_conv

    def forward(self, x, down_sampling_features):
        """
        :param x: inputs
        :param down_sampling_features: feature maps from encoder path
        :return: output
        """
        for k, op in self.module_dict.items():
            if k.startswith("deconv"):
                x = op(x)
                #print(k, x.shape)
                x = torch.cat((down_sampling_features[int(k[-1])], x), dim=1)
            elif k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
            else:
                x = op(x)
                #print(k, x.shape)
        return x


In [13]:
class UnetModel(nn.Module):

    def __init__(self, in_channels, out_channels, init_features, norm, num_groups=2, model_depth=4, final_activation="sigmoid"):
        super(UnetModel, self).__init__()
        self.encoder = EncoderBlock(in_channels=in_channels,
                                    init_features=init_features,
                                    norm=norm, num_groups=num_groups,
                                    model_depth=model_depth)
        self.decoder = DecoderBlock(out_channels=out_channels,
                                    init_features=init_features,
                                    norm=norm, num_groups=num_groups,
                                    model_depth=model_depth)
        if final_activation == "sigmoid":
            self.sigmoid = nn.Sigmoid()
        else:
            self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, downsampling_features = self.encoder(x)
        x = self.decoder(x, downsampling_features)
        x = self.sigmoid(x)
        # print("Final output shape: ", x.shape)
        return x



In [14]:

class DiceLoss(nn.Module):
    def __init__(self, epsilon=1e-5):
        super(DiceLoss, self).__init__()
        # smooth factor
        self.epsilon = epsilon

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        # log_prob = torch.sigmoid(logits)
        logits = logits.view(batch_size, -1).type(torch.FloatTensor)
        targets = targets.view(batch_size, -1).type(torch.FloatTensor)
        intersection = (logits * targets).sum(-1)
        dice_score = 2. * intersection / ((logits + targets).sum(-1) + self.epsilon)
        # dice_score = 1 - dice_score.sum() / batch_size
        dice_score = torch.mean(1. - dice_score)
        dice_score.requires_grad = True
        return dice_score

In [15]:
class EDiceLoss(nn.Module):
    """Dice loss tailored to Brats need.
    """

    def __init__(self, do_sigmoid=True):
        super(EDiceLoss, self).__init__()
        self.do_sigmoid = do_sigmoid
        self.labels = ["ET", "TC", "WT"]
        self.device = "cpu"

    def binary_dice(self, inputs, targets, label_index, metric_mode=False):
        smooth = 1.
        if self.do_sigmoid:
            inputs = torch.sigmoid(inputs)

        if metric_mode:
            inputs = inputs > 0.5
            if targets.sum() == 0:
                print(f"No {self.labels[label_index]} for this patient")
                if inputs.sum() == 0:
                    return torch.tensor(1., device=device)
                else:
                    return torch.tensor(0., device=device)
            # Threshold the pred
        intersection = EDiceLoss.compute_intersection(inputs, targets)
        if metric_mode:
            dice = (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0)
        else:
            dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth)
        if metric_mode:
            return dice
        return 1 - dice

    @staticmethod
    def compute_intersection(inputs, targets):
        intersection = torch.sum(inputs * targets)
        return intersection

    def forward(self, inputs, target):
        dice = 0
        for i in range(target.size(1)):
            dice = dice + self.binary_dice(inputs[:, i, ...], target[:, i, ...], i)
        final_dice = dice / target.size(1)
        final_dice.requires_grad = True
        return final_dice

    def metric(self, inputs, target):
        dices = []
        for j in range(target.size(0)):
            dice = []
            for i in range(target.size(1)):
                dice.append(self.binary_dice(inputs[j, i], target[j, i], i, True))
            dices.append(dice)
        return dices

In [16]:
def calculate_metrics(preds, targets, patient):
    """
    Parameters
    ----------
    preds:
        torch tensor of size 1*C*Z*Y*X, ours BS*Z*Y*X 
    targets:
        torch tensor of same shape
    patient :
        The patient ID
    """

    assert preds.shape == targets.shape, "Preds and targets do not have the same size"
    pp = pprint.PrettyPrinter(indent=4)
    
    preds, targets = preds.detach().cpu().numpy(), targets.detach().cpu().numpy()

    metrics_list = []

    metrics = dict(
        patient_id=patient,
    )
    # print(targets.shape, targets.dtype, targets)
    
    if np.sum(targets) == 0:
        print(f"{label} not present for {patient}")
    else:
        tp = np.sum(l_and(preds, targets))
        tn = np.sum(l_and(l_not(preds), l_not(targets)))
        fp = np.sum(l_and(preds, l_not(targets)))
        fn = np.sum(l_and(l_not(preds), targets))

        sens = tp / (tp + fn)
        spec = tn / (tn + fp)
        acc = (tn + tp) / (tn + tp + fn + fp)
        dice = 2 * tp / (2 * tp + fp + fn)

    metrics[DICE] = dice
    metrics[ACC] = acc
    metrics[SENS] = sens
    metrics[SPEC] = spec
    # pp.pprint(metrics)
    metrics_list.append(metrics)

    return acc, dice, metrics_list


DICE = "dice"
ACC = "acc"
SENS = "sens"
SPEC = "spec"
METRICS = [DICE, ACC, SENS, SPEC]


In [17]:
def calc_dice(preds, targets):
    return (2 * torch.sum(preds * targets)) / ((preds.sum() + preds.sum()) * 1.0)
    

In [18]:


def plot_metric(train, label, metric_name):
    # Plot losses
    plt.figure(figsize=(10,8))
    plt.semilogy(train, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()
    plt.title(f'Model {metric_name} Plot')
    plt.savefig(f'Model_{metric_name}_{label}_Plot.png')
    plt.show()
    plt.clf()

def plot_result(kfolds, num_epochs, fold_train_history, fold_valid_history):
    final_fold = {'train_loss':[],'valid_loss':[],'train_acc':[],'valid_acc':[]}

    for epoch in range(num_epochs):                                      
        final_fold['train_loss'].append(np.mean([fold_train_history[str(fold)]['train_loss'][epoch] for fold in range(kfolds)]))
        final_fold['train_acc'].append(np.mean([fold_train_history[str(fold)]['train_acc'][epoch]for fold in range(kfolds)]))

    plot_metric(final_fold['train_loss'], 'train', 'Loss')
    plot_metric(final_fold['train_acc'], 'validation', 'Accuracy')

    final_fold['valid_loss'].append([fold_valid_history[str(fold)]['valid_loss'] for fold in range(kfolds)])
    final_fold['valid_acc'].append([fold_valid_history[str(fold)]['valid_acc'] for fold in range(kfolds)])

    print(final_fold)

In [19]:
def load_checkpoint(path, params):

    test_model = UnetModel(params['pretrain_in_channels'],
                           params['pretrain_out_channels'],
                           params['init_features'], params['norm'],
                           params['num_groups'],
                           )

    test_optimizer = torch.optim.AdamW(test_model.parameters(),
                                       lr=params['learning_rate'],
                                       )
    
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location=torch.device('cpu'))

    test_model.load_state_dict(checkpoint['model_state_dict'])
    test_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    return test_model, test_optimizer, epoch, loss

In [20]:
def save_model_folds(model, optimizer, fold, epoch, loss):
    # Saving the model
    save_path = f'model-fold-{fold}.pth'

    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss,
                  }
    torch.save(checkpoint, save_path)


In [21]:
def save_model_nofolds(model, optimizer, epoch, loss, params):
    now = datetime.now()
    # dd/mm/YY H:M:S
    dt_string = now.strftime("%d_%H_%M")

    # Saving the model
    save_path = f"model_{params['output_shape']}_{params['run_name']}_{dt_string}.pth"

    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss,
                  'params': params,
                  }
    torch.save(checkpoint, save_path)

In [22]:
def kFoldRunAll(criterion, dataset, params): 
                
    k_folds, num_epochs, train_batch_size = params['k_folds'],\
                                            params['no_epochs'],\
                                            params['train_batch_size']
    
    use_cuda, loss_name, in_channels, out_channels = params['use_cuda'], \
                                                     params['loss_name'],\
                                                     params['in_channels'],\
                                                     params['out_channels'],

    init_features, learning_rate, norm, num_groups = params['init_features'],\
                                                     params['learning_rate'],\
                                                     params['norm'], \
                                                     params['num_groups'],

    loss_function = criterion

    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)

    fold_train_history = {}
    fold_valid_history = {}
    fold_train_and_valid_acc = {}
    fold_train_and_valid_loss = {}

    print('--------------------------------')
    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        print(f'FOLD {fold}')
        print('--------------------------------')
        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        valid_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for training and testing data in this fold
        dataloader_train = DataLoader(dataset, batch_size=train_batch_size, sampler=train_subsampler, num_workers=0)
        dataloader_valid = DataLoader(dataset, batch_size=train_batch_size, sampler=valid_subsampler, num_workers=0)

        # Initialize optimizer and Model
        model = UnetModel(in_channels=in_channels, out_channels=out_channels,
                          init_features=init_features, norm=norm, num_groups=num_groups)
        if use_cuda:
            model = model.cuda()
        #print(model)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

        # Run the training, testing and saving loop for defined number of epochs
        start_time = time.time()

        t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_folds(model, loss_function, optimizer,
                                                                               dataloader_train, dataloader_valid,
                                                                               fold, num_epochs, use_cuda,
                                                                               loss_name)

        end_time = time.time()
        print(f"Epoch Time: {end_time - start_time}")

        #Saving loss results 
        fold_train_and_valid_loss[str(fold)] = [t_loss, v_loss]
        fold_train_and_valid_acc[str(fold)] = [t_acc, v_acc]
        fold_train_history[str(fold)] = t_history
        fold_valid_history[str(fold)] = v_history

        # Print accuracy
        print(f'Accuracy for fold {fold}: {v_acc}')
        print(f'Loss for fold {fold}: {v_loss}')
        print('--------------------------------')  

    return fold_train_history, fold_valid_history, fold_train_and_valid_loss, fold_train_and_valid_acc
  


In [23]:
def train_test_folds(model, loss_function, optimizer, dataloader_train, dataloader_valid, fold, num_epochs, use_cuda, loss_name):
    train_history = {'train_loss': [], 'train_acc':[], 'train_dice':[]}
    valid_history = {'valid_loss': [], 'valid_acc':[], 'valid_dice':[]}
    best = math.inf

    edice = EDiceLoss()
    if use_cuda:
        edice = edice.cuda()
    metric = edice.metric
        
    for epoch in range(num_epochs):
        print(f'Starting Train epoch: {epoch+1}')

        train_loss = 0.0
        train_acc, train_dice = 0, 0
        model.train()

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            #print(inputs.shape, outputs.shape, targets.squeeze(1).long().shape)

            if loss_name == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            train_loss += loss.item() * outputs.size(0) #multiplying by batchsize
            
            rtrain_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))
            rtrain_acc, rtrain_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
            print(f'Train Dice 1: {rtrain_dice1}, 2: {rtrain_dice2} \t Acc: {rtrain_acc}')
            train_acc += rtrain_acc
            train_dice += rtrain_dice1
            
            print(f'Train Loss :{loss.item()}')

            loss.backward()
            optimizer.step()
           
        train_history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        train_history['train_acc'].append(train_acc / len(dataloader_train.sampler))
        train_history['train_dice'].append(train_dice / len(dataloader_train.sampler))

        print(f"Train Epoch loss: {train_history['train_loss'][-1]}, \t ACC/DICE :{train_history['train_acc'][-1]}/{train_history['train_dice'][-1]} ")

    valid_loss = 0.0
    valid_acc, valid_dice = 0, 0
           
    model.eval()
    #! maybe change later to validate after some epochs
    with torch.no_grad():
        # Iterate over the test data and generate predictions
        for i, data in enumerate(tqdm.tqdm(dataloader_valid)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() 
            outputs = model(inputs)

            if params['loss_name'] == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            # print('Valid Loss:', loss.item())
            valid_loss += loss.item() * inputs.size(0)
            
            rvalid_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))

            rvalid_acc, rvalid_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
            # print(f'Val Dice 1: {rvalid_dice1}, 2: {rvalid_dice2}')
            valid_acc += rvalid_acc
            valid_dice += rvalid_dice1
            
        # Print accuracy
        print(f'Val Dice : {valid_dice}, len {len(dataloader_valid.sampler)}')
        valid_loss /= len(dataloader_valid.sampler) 
        valid_acc = valid_acc / len(dataloader_valid.sampler)
        valid_dice = valid_dice / len(dataloader_valid.sampler)
        #print(f" Fold Accuracy: {valid_acc}")

    valid_history['valid_loss'].append(valid_loss)
    valid_history['valid_acc'].append(valid_acc)
    valid_history['valid_dice'].append(valid_dice)

    print(f"Val Epoch loss: {valid_history['valid_loss'][-1]} \t acc/dice:/ {valid_history['valid_acc'][-1]}/ {valid_history['valid_dice'][-1]}")

    # saving best model for this fold
    if valid_loss < best:
        best = valid_loss
        save_model_folds(model, optimizer, fold, epoch, loss)
    
    
    return train_history['train_loss'][-1], train_history['train_acc'][-1], train_history, valid_history['valid_loss'][-1], valid_history['valid_acc'][-1], valid_history


In [24]:
def train_test_nofolds(model, loss_function, optimizer, dataloader_train, dataloader_valid, params):
    train_history = {'train_loss': [], 'train_acc':[], 'train_dice':[]}
    valid_history = {'valid_loss': [], 'valid_acc':[], 'valid_dice':[]}
    best = math.inf

    edice = EDiceLoss()
    if params['use_cuda']:
        edice = edice.cuda()
    metric = edice.metric

    #Automatic mixed precision addition
    #scaler = torch.cuda.amp.GradScaler()
        
    for epoch in range(params['no_epochs']):
        print(f'Starting Train epoch: {epoch+1}')

        train_loss = 0.0
        train_acc, train_dice = 0, 0
        model.train()

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if params['use_cuda']:
                inputs, targets = inputs.cuda(), targets.cuda()

            optimizer.zero_grad()
            #with torch.cuda.amp.autocast(enabled=params['autocast']):
            outputs = model(inputs)
            #print(inputs.shape, outputs.shape, targets.squeeze(1).long().shape)

            if params['loss_name'] == 'dice':
                class_outputs = outputs.argmax(dim=1)
                loss = loss_function(class_outputs, targets.squeeze(1).long())
                # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
            else:
                loss = loss_function(outputs, targets.squeeze(1).long())

            train_loss += loss.item() * outputs.size(0) # multiplying by batchsize
            
            rtrain_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))
            rtrain_acc, rtrain_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)

            if i % 10 == 0:
                print(f'Train Dice 1: {rtrain_dice1}, 2: {rtrain_dice2} \t Acc: {rtrain_acc}')

            train_acc += rtrain_acc
            train_dice += rtrain_dice1
            if i % 10 == 0:
                print(f'Train Loss :{loss.item()}')
            
            #scaler.scale(loss).backward()
            #scaler.step(optimizer)
            #scaler.update()
            loss.backward()
            optimizer.step()
           
        train_history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        train_history['train_acc'].append(train_acc / len(dataloader_train.sampler))
        train_history['train_dice'].append(train_dice / len(dataloader_train.sampler))

        print(f"Train Epoch loss: {train_history['train_loss'][-1]}, \t ACC/DICE :{train_history['train_acc'][-1]}/{train_history['train_dice'][-1]} ")
          
        if epoch % (.1 * params['no_epochs']) == 0:
            valid_loss = 0.0
            valid_acc, valid_dice = 0, 0
            model.eval()

            with torch.no_grad():
                # Iterate over the test data and generate predictions
                for i, data in enumerate(tqdm.tqdm(dataloader_valid)):
                    (inputs, targets), ID = data
                    if params['use_cuda']:
                        inputs, targets = inputs.cuda(), targets.cuda() 
                      
                    #with torch.cuda.amp.autocast(enabled=params['autocast']):
                    outputs = model(inputs)

                    if params['loss_name'] == 'dice':
                        class_outputs = outputs.argmax(dim=1)
                        loss = loss_function(class_outputs, targets.squeeze(1).long())
                        # print(np.unique(class_outputs.detach().numpy()), class_outputs.shape, targets.squeeze(1).shape)
                    else:
                        loss = loss_function(outputs, targets.squeeze(1).long())

                    # print('Valid Loss:', loss.item())
                    valid_loss += loss.item() * inputs.size(0)
                    
                    rvalid_dice1 = calc_dice(outputs.argmax(dim=1), targets.squeeze(1))

                    rvalid_acc, rvalid_dice2, _ = calculate_metrics(outputs.argmax(dim=1), targets.squeeze(1), ID)
                    # print(f'Val Dice 1: {rvalid_dice1}, 2: {rvalid_dice2}')
                    valid_acc += rvalid_acc
                    valid_dice += rvalid_dice1
                    
                # Print accuracy
                print(f'Val Dice : {valid_dice}, len {len(dataloader_valid.sampler)}')
                valid_loss /= len(dataloader_valid.sampler) 
                valid_acc = valid_acc / len(dataloader_valid.sampler)
                valid_dice = valid_dice / len(dataloader_valid.sampler)

            valid_history['valid_loss'].append(valid_loss)
            valid_history['valid_acc'].append(valid_acc)
            valid_history['valid_dice'].append(valid_dice)

            print(f"Val Epoch loss: {valid_history['valid_loss'][-1]} \t acc/dice:/ {valid_history['valid_acc'][-1]}/ {valid_history['valid_dice'][-1]}")

            # saving best model for this fold
            if valid_loss < best:
                best = valid_loss
                save_model_nofolds(model, optimizer, epoch, loss, params)
    
    
    return train_history['train_loss'][-1], train_history['train_acc'][-1], train_history, valid_history['valid_loss'][-1], valid_history['valid_acc'][-1], valid_history


In [25]:
def kfolds(params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        criterion = CrossEntropyLoss(weight=torch.Tensor([1, 355.36116969, 74.37872817, 254.58104099]))
    else:
        criterion = EDiceLoss().cuda()
      
    if params['use_cuda']:
        criterion.cuda()
    
    train_df, valid_df, _ = read_dataframe(params)

    train_valid_df = pd.concat([train_df, valid_df])

    train_valid_dataset = retrieve_dataset(train_valid_df)

    t_history, v_history, tv_loss, tv_acc = kFoldRunAll(criterion, 
                                                        train_valid_dataset,
                                                        params)
    
    plot_result(k_folds, no_epochs, t_history, v_history)

    return t_history, v_history, tv_loss, tv_acc

In [26]:
def transfer(path, params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        criterion = CrossEntropyLoss(weight=torch.Tensor([1, 355.36116969, 74.37872817, 254.58104099]))
    else:
        criterion = EDiceLoss().cuda()
    
    model, optimizer, _, __ = load_checkpoint(path, params)

    # Transfer by changing(replacing) only last layer and finetuning to outdim=4
    model.decoder.final_conv = ConvBlock(in_channels=params['pretrain_in_final_conv'], 
                                          out_channels=params['out_channels'],
                                          norm=params['norm'],
                                          num_groups=params['num_groups'])

    model.decoder.module_dict.final_conv = ConvBlock(in_channels=params['pretrain_in_final_conv'], 
                                          out_channels=params['out_channels'],
                                          norm=params['norm'],
                                          num_groups=params['num_groups'])

    if params['use_cuda']:
        model.cuda()
        criterion.cuda()
    
    dataloader_train, dataloader_valid, _ = get_data(params)
    
    t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_nofolds(
                                                model, criterion, optimizer,
                                                dataloader_train, dataloader_valid,
                                                params
                                                )

    # plot_result(k_folds, params['no_epochs'], t_history, v_history)

    return t_loss, t_acc, t_history, v_loss, v_acc, v_history

In [27]:
def learn_from_scratch(params):
    
    if params['loss_name'] == 'ce':
        criterion = CrossEntropyLoss()
    elif params['loss_name'] == 'wce':
        criterion = CrossEntropyLoss(weight=torch.Tensor([1, 355.36116969, 74.37872817, 254.58104099]))
    else:
        criterion = EDiceLoss()
    


    model = UnetModel(params['in_channels'],
                      params['out_channels'],
                      params['init_features'],
                      params['norm'],
                      params['num_groups'],
                      )

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params['learning_rate'],
                                  )

    if params['use_cuda']:
        model.cuda()
        criterion.cuda()
    
    dataloader_train, dataloader_valid, _ = get_data(params)
    
    t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_nofolds(
                                                model, criterion, optimizer,
                                                dataloader_train, dataloader_valid,
                                                params
                                                )

    # plot_result(k_folds, params['no_epochs'], t_history, v_history)

    return t_loss, t_acc, t_history, v_loss, v_acc, v_history

In [28]:
def get_params():

    if platform.system() == 'Windows':
        image_dir = r"C:\Users\wisdomik\Documents\project\MICCAI_BraTS2020_TrainingData"
    else:
        image_dir = '../../data/data/mri/MICCAI_BraTS2020_TrainingData/'
    
    use_cuda = torch.cuda.is_available()

    params = {'run_name': 'run_1',
              'in_channels': 4,
              'out_channels': 4,
              'no_epochs': 1,
              'k_folds': 5,
              'learning_rate': 1e-4,
              'loss_name': 'wce',
              'output_shape': 'all',
              'tr_va_te_split': [75, 25, 0],
              'pretrain_in_channels': 4,
              'pretrain_out_channels': 2,
              'pretrain_in_final_conv': 16,
              'init_features': 8,
              'train_batch_size': 1,
              'autocast': False, # not in use
              'test_batch_size': 1,
              'norm': 'g',
              'num_groups': 4,
              'channels': 4,
              'resize_shape': (128, 128, 128),
              'image_dir': image_dir,
              'use_cuda': use_cuda
              }
              
    return params



In [29]:
# set params then run this

#load in trained model for evaluation
# fold_to_check = 9
# PATH = f'drive/MyDrive/Colab Notebooks/model-fold-{fold_to_check}.pth'

params = get_params()
params

{'run_name': 'run_1',
 'in_channels': 4,
 'out_channels': 4,
 'no_epochs': 1,
 'k_folds': 5,
 'learning_rate': 0.0001,
 'loss_name': 'wce',
 'output_shape': 'all',
 'tr_va_te_split': [75, 25, 0],
 'pretrain_in_channels': 4,
 'pretrain_out_channels': 2,
 'pretrain_in_final_conv': 16,
 'init_features': 8,
 'train_batch_size': 1,
 'autocast': False,
 'test_batch_size': 1,
 'norm': 'g',
 'num_groups': 4,
 'channels': 4,
 'resize_shape': (128, 128, 128),
 'image_dir': '../../data/data/mri/MICCAI_BraTS2020_TrainingData/',
 'use_cuda': True}

In [None]:
# Run learn from scratch (UNCOMMENT TO RUN)

t_loss, t_acc, t_history, v_loss, v_acc, v_history = learn_from_scratch(params)

  0%|          | 0/276 [00:00<?, ?it/s]

Data is split into train: 276, validation: 92 and test: 1
Starting Train epoch: 1


  0%|          | 1/276 [00:02<12:28,  2.72s/it]

Train Dice 1: 0.032238684594631195, 2: 0.04294479606703858 	 Acc: 0.22164011001586914
Train Loss :1.4165858030319214


  1%|          | 2/276 [00:06<14:52,  3.26s/it]

Train Dice 1: 0.03821874409914017, 2: 0.05389257950530035 	 Acc: 0.36163806915283203
Train Loss :1.4056919813156128


  1%|          | 3/276 [00:09<13:53,  3.05s/it]

Train Dice 1: 0.0693030059337616, 2: 0.09814075576126262 	 Acc: 0.43248653411865234
Train Loss :1.42372727394104


  1%|▏         | 4/276 [00:11<13:23,  2.95s/it]

Train Dice 1: 0.0425037145614624, 2: 0.054177816691535846 	 Acc: 0.4514155387878418
Train Loss :1.3172378540039062


  2%|▏         | 5/276 [00:15<14:30,  3.21s/it]

Train Dice 1: 0.12048594653606415, 2: 0.14569366318252921 	 Acc: 0.4005126953125
Train Loss :1.4001191854476929


  2%|▏         | 6/276 [00:18<13:44,  3.05s/it]

Train Dice 1: 0.11704041063785553, 2: 0.11272483381568321 	 Acc: 0.42533063888549805
Train Loss :1.3813878297805786


  3%|▎         | 7/276 [00:21<13:06,  2.92s/it]

Train Dice 1: 0.256578266620636, 2: 0.18946459510731634 	 Acc: 0.45481395721435547
Train Loss :1.2816311120986938


  3%|▎         | 8/276 [00:24<13:53,  3.11s/it]

Train Dice 1: 0.2299259752035141, 2: 0.19034140560725354 	 Acc: 0.4565858840942383
Train Loss :1.3590397834777832


  3%|▎         | 9/276 [00:27<13:14,  2.97s/it]

Train Dice 1: 0.09790834784507751, 2: 0.09978119485596987 	 Acc: 0.4606938362121582
Train Loss :1.3536851406097412


  4%|▎         | 10/276 [00:29<12:45,  2.88s/it]

Train Dice 1: 0.10110199451446533, 2: 0.1012647608270453 	 Acc: 0.5606951713562012
Train Loss :1.3341515064239502


  4%|▍         | 11/276 [00:32<12:34,  2.85s/it]

Train Dice 1: 0.05736158788204193, 2: 0.07602512541979352 	 Acc: 0.5749406814575195
Train Loss :1.3185280561447144


  4%|▍         | 12/276 [00:35<13:06,  2.98s/it]

Train Dice 1: 0.09226773679256439, 2: 0.10600313654966873 	 Acc: 0.5960731506347656
Train Loss :1.3078498840332031


  5%|▍         | 13/276 [00:39<13:57,  3.19s/it]

Train Dice 1: 0.08679462224245071, 2: 0.0851507861251092 	 Acc: 0.5785961151123047
Train Loss :1.3104546070098877


  5%|▌         | 14/276 [00:42<13:11,  3.02s/it]

Train Dice 1: 0.20370236039161682, 2: 0.17591236984453112 	 Acc: 0.6167221069335938
Train Loss :1.3095195293426514


  5%|▌         | 15/276 [00:44<12:39,  2.91s/it]

Train Dice 1: 0.1445893943309784, 2: 0.12751289264415502 	 Acc: 0.5832443237304688
Train Loss :1.2993757724761963


  6%|▌         | 16/276 [00:48<13:30,  3.12s/it]

Train Dice 1: 0.06157253682613373, 2: 0.08266330742341452 	 Acc: 0.5680370330810547
Train Loss :1.3187754154205322


  6%|▌         | 17/276 [00:51<12:52,  2.98s/it]

Train Dice 1: 0.02964138425886631, 2: 0.02543218848806078 	 Acc: 0.5549545288085938
Train Loss :1.329577088356018


  7%|▋         | 18/276 [00:53<12:15,  2.85s/it]

Train Dice 1: 0.035810619592666626, 2: 0.0426551515311093 	 Acc: 0.5839042663574219
Train Loss :1.3505849838256836


  7%|▋         | 19/276 [00:58<15:15,  3.56s/it]

Train Dice 1: 0.05391969531774521, 2: 0.06469396063146783 	 Acc: 0.5935540199279785
Train Loss :1.314012885093689


  7%|▋         | 20/276 [01:02<15:22,  3.60s/it]

Train Dice 1: 0.052867453545331955, 2: 0.09823413161265333 	 Acc: 0.6092252731323242
Train Loss :1.2783105373382568


  8%|▊         | 21/276 [01:05<14:08,  3.33s/it]

Train Dice 1: 0.27461397647857666, 2: 0.2786629065432575 	 Acc: 0.6575145721435547
Train Loss :1.291001796722412


  8%|▊         | 22/276 [01:07<13:15,  3.13s/it]

Train Dice 1: 0.07458652555942535, 2: 0.12200838234636094 	 Acc: 0.5789608955383301
Train Loss :1.3166663646697998


  8%|▊         | 23/276 [01:10<12:46,  3.03s/it]

Train Dice 1: 0.11700263619422913, 2: 0.10174197705581876 	 Acc: 0.5699257850646973
Train Loss :1.3277348279953003


  9%|▊         | 24/276 [01:13<12:10,  2.90s/it]

Train Dice 1: 0.10055471956729889, 2: 0.1544808948022764 	 Acc: 0.6192827224731445
Train Loss :1.298248052597046


  9%|▉         | 25/276 [01:16<11:48,  2.82s/it]

Train Dice 1: 0.20079940557479858, 2: 0.18057888487823087 	 Acc: 0.6064109802246094
Train Loss :1.3204351663589478


  9%|▉         | 26/276 [01:21<14:49,  3.56s/it]

Train Dice 1: 0.19245599210262299, 2: 0.2670975634317086 	 Acc: 0.6251907348632812
Train Loss :1.2835166454315186


 10%|▉         | 27/276 [01:23<13:40,  3.30s/it]

Train Dice 1: 0.07573344558477402, 2: 0.07394162644690319 	 Acc: 0.6058483123779297
Train Loss :1.3231651782989502


 10%|█         | 28/276 [01:26<12:50,  3.11s/it]

Train Dice 1: 0.11823894828557968, 2: 0.17350280346727676 	 Acc: 0.6086320877075195
Train Loss :1.286097526550293


 11%|█         | 29/276 [01:29<12:10,  2.96s/it]

Train Dice 1: 0.10998803377151489, 2: 0.19414880258068837 	 Acc: 0.5868968963623047
Train Loss :1.2846287488937378


 11%|█         | 30/276 [01:31<11:41,  2.85s/it]

Train Dice 1: 0.11105170100927353, 2: 0.10338843272322007 	 Acc: 0.6075558662414551
Train Loss :1.315807819366455


 11%|█         | 31/276 [01:34<11:26,  2.80s/it]

Train Dice 1: 0.1501373052597046, 2: 0.2014561322282519 	 Acc: 0.6167488098144531
Train Loss :1.2923977375030518


 12%|█▏        | 32/276 [01:37<11:10,  2.75s/it]

Train Dice 1: 0.03144967183470726, 2: 0.03982238771785949 	 Acc: 0.5662083625793457
Train Loss :1.3207765817642212


 12%|█▏        | 33/276 [01:39<11:00,  2.72s/it]

Train Dice 1: 0.22652171552181244, 2: 0.21692143509146383 	 Acc: 0.6130199432373047
Train Loss :1.2999036312103271


 12%|█▏        | 34/276 [01:42<10:50,  2.69s/it]

Train Dice 1: 0.04498417302966118, 2: 0.09996063660876756 	 Acc: 0.5780606269836426
Train Loss :1.2784093618392944


 13%|█▎        | 35/276 [01:45<10:43,  2.67s/it]

Train Dice 1: 0.1162809282541275, 2: 0.12049289850051963 	 Acc: 0.5762777328491211
Train Loss :1.3222882747650146


 13%|█▎        | 36/276 [01:47<10:34,  2.64s/it]

Train Dice 1: 0.09782511740922928, 2: 0.11458373286677662 	 Acc: 0.5707030296325684
Train Loss :1.3163785934448242


 13%|█▎        | 37/276 [01:50<10:40,  2.68s/it]

Train Dice 1: 0.11465734988451004, 2: 0.11571749346049677 	 Acc: 0.544619083404541
Train Loss :1.2886465787887573


 14%|█▍        | 38/276 [01:53<10:46,  2.72s/it]

Train Dice 1: 0.06674108654260635, 2: 0.1274311659065687 	 Acc: 0.5890893936157227
Train Loss :1.2770777940750122


 14%|█▍        | 39/276 [01:55<10:37,  2.69s/it]

Train Dice 1: 0.1504504531621933, 2: 0.17906951281738762 	 Acc: 0.6210222244262695
Train Loss :1.277421474456787


 14%|█▍        | 40/276 [01:59<11:33,  2.94s/it]

Train Dice 1: 0.05358698219060898, 2: 0.08067014980346456 	 Acc: 0.55010986328125
Train Loss :1.3073629140853882


 15%|█▍        | 41/276 [02:01<10:59,  2.81s/it]

Train Dice 1: 0.2899678349494934, 2: 0.25257572031252956 	 Acc: 0.6232213973999023
Train Loss :1.3109629154205322


 15%|█▌        | 42/276 [02:05<12:14,  3.14s/it]

Train Dice 1: 0.07611244916915894, 2: 0.15166705735975622 	 Acc: 0.5962009429931641
Train Loss :1.2493623495101929


 16%|█▌        | 43/276 [02:08<11:37,  2.99s/it]

Train Dice 1: 0.2613890469074249, 2: 0.18972774784264473 	 Acc: 0.6185312271118164
Train Loss :1.3221999406814575


 16%|█▌        | 44/276 [02:11<11:11,  2.89s/it]

Train Dice 1: 0.03511171042919159, 2: 0.060517566337011074 	 Acc: 0.5769186019897461
Train Loss :1.3033071756362915


 16%|█▋        | 45/276 [02:13<10:57,  2.85s/it]

Train Dice 1: 0.193710595369339, 2: 0.2008632769844357 	 Acc: 0.5934624671936035
Train Loss :1.3329975605010986


 17%|█▋        | 46/276 [02:16<10:40,  2.79s/it]

Train Dice 1: 0.008259423077106476, 2: 0.010837061069327697 	 Acc: 0.5394315719604492
Train Loss :1.3562748432159424


 17%|█▋        | 47/276 [02:19<10:27,  2.74s/it]

Train Dice 1: 0.12600380182266235, 2: 0.12483486397467931 	 Acc: 0.6073555946350098
Train Loss :1.3067632913589478


 17%|█▋        | 48/276 [02:22<11:31,  3.03s/it]

Train Dice 1: 0.028193339705467224, 2: 0.04718515329793849 	 Acc: 0.5820512771606445
Train Loss :1.3136541843414307


 18%|█▊        | 49/276 [02:25<11:07,  2.94s/it]

Train Dice 1: 0.10936244577169418, 2: 0.17951374712076976 	 Acc: 0.5937099456787109
Train Loss :1.2479761838912964


 18%|█▊        | 50/276 [02:28<10:41,  2.84s/it]

Train Dice 1: 0.15680889785289764, 2: 0.29424969019997665 	 Acc: 0.648590087890625
Train Loss :1.2158234119415283


 18%|█▊        | 51/276 [02:31<11:39,  3.11s/it]

Train Dice 1: 0.01727902516722679, 2: 0.04103795260379057 	 Acc: 0.6099205017089844
Train Loss :1.3312307596206665


 19%|█▉        | 52/276 [02:34<10:57,  2.94s/it]

Train Dice 1: 0.09740380197763443, 2: 0.1001908594561541 	 Acc: 0.6063656806945801
Train Loss :1.2870851755142212


 19%|█▉        | 53/276 [02:37<10:37,  2.86s/it]

Train Dice 1: 0.059994425624608994, 2: 0.07453578284873325 	 Acc: 0.5706033706665039
Train Loss :1.3202948570251465


 20%|█▉        | 54/276 [02:41<12:08,  3.28s/it]

Train Dice 1: 0.05364677682518959, 2: 0.0803351872989298 	 Acc: 0.5892939567565918
Train Loss :1.281049132347107


 20%|█▉        | 55/276 [02:45<12:32,  3.41s/it]

Train Dice 1: 0.0500563345849514, 2: 0.08186059948826505 	 Acc: 0.5878071784973145
Train Loss :1.380062222480774


 20%|██        | 56/276 [02:51<15:51,  4.33s/it]

Train Dice 1: 0.02740693837404251, 2: 0.03006043357178951 	 Acc: 0.5917844772338867
Train Loss :1.3233059644699097


 21%|██        | 57/276 [02:54<13:59,  3.83s/it]

Train Dice 1: 0.1109350398182869, 2: 0.11631145961614181 	 Acc: 0.5976133346557617
Train Loss :1.3070387840270996


 21%|██        | 58/276 [02:56<12:29,  3.44s/it]

Train Dice 1: 0.2975292205810547, 2: 0.23980457759011498 	 Acc: 0.6129932403564453
Train Loss :1.2329202890396118


 21%|██▏       | 59/276 [03:01<13:58,  3.86s/it]

Train Dice 1: 0.03843655064702034, 2: 0.04114199900007248 	 Acc: 0.5079965591430664
Train Loss :1.3304567337036133


 22%|██▏       | 60/276 [03:04<12:35,  3.50s/it]

Train Dice 1: 0.24882760643959045, 2: 0.19532495608701528 	 Acc: 0.5882353782653809
Train Loss :1.2107754945755005


 22%|██▏       | 61/276 [03:06<11:38,  3.25s/it]

Train Dice 1: 0.23252491652965546, 2: 0.19891790820270014 	 Acc: 0.5873351097106934
Train Loss :1.1995532512664795


 22%|██▏       | 62/276 [03:09<11:02,  3.09s/it]

Train Dice 1: 0.052909839898347855, 2: 0.06512365417130242 	 Acc: 0.49114227294921875
Train Loss :1.2472050189971924


 23%|██▎       | 63/276 [03:12<11:08,  3.14s/it]

Train Dice 1: 0.10921042412519455, 2: 0.13835988525746792 	 Acc: 0.4066014289855957
Train Loss :1.315247893333435


 23%|██▎       | 64/276 [03:15<10:42,  3.03s/it]

Train Dice 1: 0.09514334052801132, 2: 0.10586666534968368 	 Acc: 0.1367015838623047
Train Loss :1.2979429960250854


 24%|██▎       | 65/276 [03:19<11:52,  3.38s/it]

Train Dice 1: 0.16433188319206238, 2: 0.12305898302060674 	 Acc: 0.4750180244445801
Train Loss :1.22003972530365


 24%|██▍       | 66/276 [03:22<11:02,  3.16s/it]

Train Dice 1: 0.18096229434013367, 2: 0.14073814337225513 	 Acc: 0.37719011306762695
Train Loss :1.281083583831787


 24%|██▍       | 67/276 [03:25<10:26,  3.00s/it]

Train Dice 1: 0.15561531484127045, 2: 0.14692616213181112 	 Acc: 0.4073958396911621
Train Loss :1.2784898281097412


 25%|██▍       | 68/276 [03:27<10:01,  2.89s/it]

Train Dice 1: 0.08338932693004608, 2: 0.06692371228898425 	 Acc: 0.34444284439086914
Train Loss :1.3029754161834717


 25%|██▌       | 69/276 [03:30<09:44,  2.83s/it]

Train Dice 1: 0.04031975567340851, 2: 0.07697602769479586 	 Acc: 0.27556848526000977
Train Loss :1.2715202569961548


 25%|██▌       | 70/276 [03:33<09:36,  2.80s/it]

Train Dice 1: 0.16271689534187317, 2: 0.14750230062108524 	 Acc: 0.38378477096557617
Train Loss :1.3200029134750366


 26%|██▌       | 71/276 [03:35<09:27,  2.77s/it]

Train Dice 1: 0.08441214263439178, 2: 0.10630378594055578 	 Acc: 0.34956884384155273
Train Loss :1.4235540628433228


 26%|██▌       | 72/276 [03:38<09:07,  2.68s/it]

Train Dice 1: 0.03731360659003258, 2: 0.05639017857363475 	 Acc: 0.27161121368408203
Train Loss :1.2836627960205078


 26%|██▋       | 73/276 [03:40<09:01,  2.67s/it]

Train Dice 1: 0.11042815446853638, 2: 0.07770517204637295 	 Acc: 0.33754873275756836
Train Loss :1.2764519453048706


 27%|██▋       | 74/276 [03:45<10:25,  3.10s/it]

Train Dice 1: 0.16456982493400574, 2: 0.20557208167930363 	 Acc: 0.4060640335083008
Train Loss :1.3009974956512451


 27%|██▋       | 75/276 [03:48<10:30,  3.14s/it]

Train Dice 1: 0.0608585886657238, 2: 0.06161079726390727 	 Acc: 0.2651195526123047
Train Loss :1.398391604423523


 28%|██▊       | 76/276 [03:52<11:35,  3.48s/it]

Train Dice 1: 0.12583661079406738, 2: 0.1397519991725246 	 Acc: 0.3892698287963867
Train Loss :1.2996671199798584


 28%|██▊       | 77/276 [03:58<13:52,  4.18s/it]

Train Dice 1: 0.06491302698850632, 2: 0.07230527214212489 	 Acc: 0.35576868057250977
Train Loss :1.286017656326294


 28%|██▊       | 78/276 [04:01<12:15,  3.72s/it]

Train Dice 1: 0.1972687542438507, 2: 0.1534060953250289 	 Acc: 0.3843879699707031
Train Loss :1.3301645517349243


 29%|██▊       | 79/276 [04:04<11:49,  3.60s/it]

Train Dice 1: 0.018614472821354866, 2: 0.017907333352946538 	 Acc: 0.34737062454223633
Train Loss :1.3499524593353271


 29%|██▉       | 80/276 [04:07<10:55,  3.35s/it]

Train Dice 1: 0.06496443599462509, 2: 0.05231454347483829 	 Acc: 0.31687450408935547
Train Loss :1.3599282503128052


 29%|██▉       | 81/276 [04:09<10:10,  3.13s/it]

Train Dice 1: 0.13904736936092377, 2: 0.12731830677334688 	 Acc: 0.357269287109375
Train Loss :1.3697751760482788


 30%|██▉       | 82/276 [04:13<10:50,  3.35s/it]

Train Dice 1: 0.04990481957793236, 2: 0.04268298835695721 	 Acc: 0.3666563034057617
Train Loss :1.3349661827087402


 30%|███       | 83/276 [04:16<10:03,  3.13s/it]

Train Dice 1: 0.0855380967259407, 2: 0.0889984299901306 	 Acc: 0.350341796875
Train Loss :1.2716405391693115


 30%|███       | 84/276 [04:18<09:32,  2.98s/it]

Train Dice 1: 0.11483065038919449, 2: 0.14251485888760937 	 Acc: 0.34810829162597656
Train Loss :1.1979645490646362


 31%|███       | 85/276 [04:21<09:08,  2.87s/it]

Train Dice 1: 0.1284124255180359, 2: 0.10857182519166812 	 Acc: 0.34471940994262695
Train Loss :1.315753698348999


 31%|███       | 86/276 [04:24<08:49,  2.79s/it]

Train Dice 1: 0.15413640439510345, 2: 0.12990960570536483 	 Acc: 0.34756040573120117
Train Loss :1.216029167175293


 32%|███▏      | 87/276 [04:27<09:13,  2.93s/it]

Train Dice 1: 0.08114365488290787, 2: 0.12177044394943227 	 Acc: 0.34295225143432617
Train Loss :1.3360919952392578


 32%|███▏      | 88/276 [04:29<08:53,  2.84s/it]

Train Dice 1: 0.014436135068535805, 2: 0.018767616030923685 	 Acc: 0.30279064178466797
Train Loss :1.3247886896133423


 32%|███▏      | 89/276 [04:32<08:40,  2.78s/it]

Train Dice 1: 0.07960673421621323, 2: 0.09788816819269287 	 Acc: 0.3426494598388672
Train Loss :1.2119464874267578


 33%|███▎      | 90/276 [04:35<08:27,  2.73s/it]

Train Dice 1: 0.015417358838021755, 2: 0.015515766320854253 	 Acc: 0.3695340156555176
Train Loss :1.392266035079956


 33%|███▎      | 91/276 [04:38<09:24,  3.05s/it]

Train Dice 1: 0.03638974204659462, 2: 0.05748169850883785 	 Acc: 0.3385601043701172
Train Loss :1.1834015846252441


 33%|███▎      | 92/276 [04:41<08:56,  2.91s/it]

Train Dice 1: 0.042360514402389526, 2: 0.048452110894530796 	 Acc: 0.3661494255065918
Train Loss :1.490024447441101


 34%|███▎      | 93/276 [04:44<08:38,  2.83s/it]

Train Dice 1: 0.03934670612215996, 2: 0.031214250803824124 	 Acc: 0.30907392501831055
Train Loss :1.3573824167251587


 34%|███▍      | 94/276 [04:46<08:26,  2.79s/it]

Train Dice 1: 0.018869206309318542, 2: 0.035990337501953615 	 Acc: 0.3294086456298828
Train Loss :1.3625404834747314


 34%|███▍      | 95/276 [04:49<08:19,  2.76s/it]

Train Dice 1: 0.03139930218458176, 2: 0.031164831074748395 	 Acc: 0.3148508071899414
Train Loss :1.3650298118591309


 35%|███▍      | 96/276 [04:52<08:09,  2.72s/it]

Train Dice 1: 0.09503592550754547, 2: 0.12006084712160696 	 Acc: 0.32697296142578125
Train Loss :1.2303242683410645


 35%|███▌      | 97/276 [04:54<08:00,  2.69s/it]

Train Dice 1: 0.08369988948106766, 2: 0.06502908227716313 	 Acc: 0.35767650604248047
Train Loss :1.2312713861465454


 36%|███▌      | 98/276 [04:57<07:55,  2.67s/it]

Train Dice 1: 0.18783220648765564, 2: 0.17162646077761828 	 Acc: 0.4204578399658203
Train Loss :1.3036059141159058


 36%|███▌      | 99/276 [05:00<07:54,  2.68s/it]

Train Dice 1: 0.017945902422070503, 2: 0.030691730635031403 	 Acc: 0.36415767669677734
Train Loss :1.3275119066238403


 36%|███▌      | 100/276 [05:04<09:05,  3.10s/it]

Train Dice 1: 0.059521082788705826, 2: 0.09050920459236136 	 Acc: 0.3517613410949707
Train Loss :1.2785351276397705


 37%|███▋      | 101/276 [05:06<08:39,  2.97s/it]

Train Dice 1: 0.03414662182331085, 2: 0.06621655112937169 	 Acc: 0.3783750534057617
Train Loss :1.3632419109344482


 37%|███▋      | 102/276 [05:09<08:19,  2.87s/it]

Train Dice 1: 0.0185591708868742, 2: 0.03003226397283627 	 Acc: 0.34343910217285156
Train Loss :1.4173084497451782


 37%|███▋      | 103/276 [05:12<08:34,  2.98s/it]

Train Dice 1: 0.1704133152961731, 2: 0.17984117524191864 	 Acc: 0.5150647163391113
Train Loss :1.2552530765533447


 38%|███▊      | 104/276 [05:15<08:10,  2.85s/it]

Train Dice 1: 0.07153069972991943, 2: 0.06964885588177697 	 Acc: 0.41101551055908203
Train Loss :1.276153326034546


 38%|███▊      | 105/276 [05:17<07:53,  2.77s/it]

Train Dice 1: 0.074924536049366, 2: 0.08888488894512572 	 Acc: 0.5124397277832031
Train Loss :1.3377608060836792


 38%|███▊      | 106/276 [05:21<08:29,  3.00s/it]

Train Dice 1: 0.11115088313817978, 2: 0.11643427303426687 	 Acc: 0.5215334892272949
Train Loss :1.3448587656021118


 39%|███▉      | 107/276 [05:24<08:40,  3.08s/it]

Train Dice 1: 0.08925552666187286, 2: 0.09703806580471723 	 Acc: 0.4990091323852539
Train Loss :1.3589311838150024


 39%|███▉      | 108/276 [05:28<09:32,  3.41s/it]

Train Dice 1: 0.1294279247522354, 2: 0.18633991049490084 	 Acc: 0.5059237480163574
Train Loss :1.2500264644622803


In [None]:
t_acc

In [None]:
v_acc

In [None]:
torch.cuda.empty_cache()