<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Classwise_DICE_3D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Class-wise Dice from 3D Patch
Challenge-1: Patch to full volume conversion <br>
Challenge-2: Saving the full volume with the same header as the MRI/GT<br>
Challenge-3: Calculating Class-wise Dice Score

Download the SVLS and surface-distance (Surface Dice) code

In [1]:
!rm -rf SVLS
!git clone https://github.com/mobarakol/SVLS.git
%cd SVLS
!git clone https://github.com/deepmind/surface-distance.git
!mv surface-distance surface_distance

Cloning into 'SVLS'...
remote: Enumerating objects: 400, done.[K
remote: Counting objects: 100% (400/400), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 400 (delta 351), reused 372 (delta 336), pack-reused 0 (from 0)[K
Receiving objects: 100% (400/400), 120.63 KiB | 17.23 MiB/s, done.
Resolving deltas: 100% (351/351), done.
/content/SVLS
Cloning into 'surface-distance'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 50 (delta 9), reused 9 (delta 8), pack-reused 35 (from 1)[K
Receiving objects: 100% (50/50), 38.20 KiB | 19.10 MiB/s, done.
Resolving deltas: 100% (22/22), done.


Install required packages:

In [2]:
!pip install -U -q SimpleITK

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.3/52.3 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25h

### Download Dataset and Trained Models

In [3]:
# Fetch the two Google Drive archives by file id. Per the cell output, the
# first id resolves to ckpt_brats19.zip and the second to train_valid.zip.
!gdown 1sI9SzGmKw1tLdRNhNROH7MKHdljKLSFi
!gdown 1oZ9z-l9lBjKGNZCCTK819Z1ufwe90mg3

# Extract both archives quietly into the current directory (the SVLS checkout).
!unzip -q ckpt_brats19.zip
!unzip -q train_valid.zip

Downloading...
From (original): https://drive.google.com/uc?id=1sI9SzGmKw1tLdRNhNROH7MKHdljKLSFi
From (redirected): https://drive.google.com/uc?id=1sI9SzGmKw1tLdRNhNROH7MKHdljKLSFi&confirm=t&uuid=8ae5f78c-b890-4132-be0f-56c2b23aea2e
To: /content/SVLS/ckpt_brats19.zip
100% 64.0M/64.0M [00:02<00:00, 24.5MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1oZ9z-l9lBjKGNZCCTK819Z1ufwe90mg3
From (redirected): https://drive.google.com/uc?id=1oZ9z-l9lBjKGNZCCTK819Z1ufwe90mg3&confirm=t&uuid=8fd84ee2-0892-4377-9aa1-5b57aac45f4d
To: /content/SVLS/train_valid.zip
100% 535M/535M [00:04<00:00, 124MB/s] 


Demo

Saving the prediction, including padding (default volume: 155 244 244)

In [4]:
import argparse
import os
import numpy as np
import pathlib
import torch
from torch import nn
from torch.nn import functional as F
from model import UNet3D
import nibabel as nib
import SimpleITK as sitk
import random
from glob import glob
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from datasets import custom_collate, determinist_collate, pad_batch_to_max_shape, \
pad_batch1_to_compatible_size, irm_min_max_preprocess, pad_or_crop_image

import warnings
warnings.filterwarnings("ignore")


class EDiceLoss(nn.Module):
    """Dice loss and Dice metric helper for segmentation masks.

    ``binary_dice`` returns a smoothed Dice *loss* (``1 - dice``) by default and
    a plain Dice *score* when ``metric_mode=True``. ``metric_classwise`` applies
    the score to integer label maps, one foreground class at a time.
    """

    def __init__(self, do_sigmoid=True):
        super(EDiceLoss, self).__init__()
        # Kept for interface compatibility; no method of this class reads them.
        self.do_sigmoid = do_sigmoid
        self.device = "cpu"

    def binary_dice(self, inputs, targets, metric_mode=False):
        """Compute the Dice loss (default) or Dice score between two masks.

        Args:
            inputs: predicted mask tensor (boolean masks are accepted in metric mode).
            targets: reference mask tensor with the same shape as ``inputs``.
            metric_mode: when True return the Dice score, otherwise ``1 - dice``
                computed with a smoothing term of 1.

        Returns:
            A 0-dim tensor. In metric mode an empty target scores 1 if the
            prediction is empty too and 0 otherwise.
        """
        smooth = 1.
        if metric_mode and targets.sum() == 0:
            # Dice is undefined (0/0) for an empty target, so score it by
            # whether the prediction is empty as well. The result is created on
            # the device of ``inputs`` so the metric also works without a GPU.
            score = 1. if inputs.sum() == 0 else 0.
            return torch.tensor(score, device=inputs.device)

        intersection = EDiceLoss.compute_intersection(inputs, targets)
        if metric_mode:
            return (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0)
        dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth)
        return 1 - dice

    @staticmethod
    def compute_intersection(inputs, targets):
        """Return the sum of the element-wise product of the two masks."""
        return torch.sum(inputs * targets)

    def metric_classwise(self, inputs, target, class_ids=(1, 2, 3)):
        """Compute per-sample, per-class Dice scores from integer label maps.

        Args:
            inputs: predicted labels, indexed by sample along dimension 0.
            target: reference labels with the same layout as ``inputs``.
            class_ids: label values to score; defaults to classes 1, 2 and 3.

        Returns:
            A list with one entry per sample, each a list holding one 0-dim
            tensor per entry of ``class_ids``.
        """
        return [
            [self.binary_dice(inputs[j] == class_id, target[j] == class_id, True)
             for class_id in class_ids]
            for j in range(target.size(0))
        ]

def get_datasets_brats(data_root=None, normalisation="zscore"):
    """Build the BraTS19 training and validation datasets.

    Patient names are taken from the entries of ``data/BraTS19/train_train/``
    and ``data/BraTS19/train_valid/`` (resolved against the current working
    directory). A name is kept only if a directory with that name exists under
    ``data_root``.

    Args:
        data_root: directory holding one sub-directory per patient.
        normalisation: forwarded unchanged to ``brats19``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    data_root = pathlib.Path(data_root)
    split_train = pathlib.Path('data/BraTS19/train_train/').resolve()
    split_valid = pathlib.Path('data/BraTS19/train_valid/').resolve()

    def _patient_dirs(split_folder):
        # Map every entry of the split folder onto data_root, keep real directories.
        candidates = (data_root / entry.name for entry in split_folder.iterdir())
        return sorted(candidate for candidate in candidates if candidate.is_dir())

    train_dirs = _patient_dirs(split_train)
    valid_dirs = _patient_dirs(split_valid)
    return (
        brats19(train_dirs, training=True, normalisation=normalisation),
        brats19(valid_dirs, training=False, normalisation=normalisation),
    )


class brats19(Dataset):
    """BraTS19 dataset returning the four MRI modalities and the label map of one patient.

    Each patient directory is expected to contain ``<id>_flair.nii.gz``,
    ``<id>_t1.nii.gz``, ``<id>_t1ce.nii.gz``, ``<id>_t2.nii.gz`` and, unless
    ``no_seg`` is set, ``<id>_seg.nii.gz``.

    Args:
        patients_dir: iterable of ``pathlib.Path`` patient directories.
        training: stored on the instance; not read by this class.
        no_seg: when True no segmentation file is looked up and ``__getitem__``
            returns an all-zero placeholder label.
        normalisation: stored on the instance; ``__getitem__`` always applies
            ``irm_min_max_preprocess`` regardless of this value.
    """

    def __init__(self, patients_dir, training=True, no_seg=False, normalisation="minmax"):
        super(brats19, self).__init__()
        self.normalisation = normalisation
        self.training = training
        self.datas = []
        self.validation = no_seg
        self.patterns = [ "_flair", "_t1", "_t1ce", "_t2"]
        # Per-modality offsets subtracted from each volume after preprocessing.
        self.mean = dict(flair=0.0860377, t1=0.1216296, t1ce=0.07420689, t2=0.09033176)
        if not no_seg:
            self.patterns += ["_seg"]
        for patient_dir in patients_dir:
            patient_id = patient_dir.name
            paths = [patient_dir / f"{patient_id}{value}.nii.gz" for value in self.patterns]
            patient = dict(
                id=patient_id, flair=paths[0], t1=paths[1], t1ce=paths[2],
                t2=paths[3], seg=paths[4] if not no_seg else None
            )
            self.datas.append(patient)

    def __getitem__(self, idx):
        """Load one patient.

        Returns:
            dict with ``patient_id``, ``image`` (float16 tensor, modalities
            stacked on axis 0), ``label`` (integer tensor with a leading axis of
            size 1), ``seg_path`` and ``crop_indexes`` (per-axis ``(min, max)``
            bounds of the non-zero region, widened by one voxel).
        """
        _patient = self.datas[idx]
        patient_image = {key: self.load_nii(_patient[key]) for key in _patient if key not in ["id", "seg"]}
        patient_image = {key: (irm_min_max_preprocess(patient_image[key]) - self.mean[key]) for key in patient_image}
        patient_image = np.stack([patient_image[key] for key in patient_image])

        if _patient["seg"] is not None:
            patient_label = self.load_nii(_patient["seg"])
            # BraTS19 annotates the third foreground class as 4; remap it to 3.
            patient_label[patient_label==4] = 3
        else:
            # No ground truth (no_seg=True): return an all-zero placeholder with
            # the spatial shape of the image instead of failing on an unbound name.
            patient_label = np.zeros(patient_image.shape[1:], dtype="int64")
        patient_label = patient_label[None,:,:,:]

        # Remove maximum extent of the zero-background to make future crop more useful
        z_indexes, y_indexes, x_indexes = np.nonzero(np.sum(patient_image, axis=0) != 0)

        # Add 1 pixel in each side
        zmin, ymin, xmin = [max(0, int(np.min(arr) - 1)) for arr in (z_indexes, y_indexes, x_indexes)]
        zmax, ymax, xmax = [int(np.max(arr) + 1) for arr in (z_indexes, y_indexes, x_indexes)]
        patient_image, patient_label = patient_image.astype("float16"), patient_label.astype("long")
        patient_image, patient_label = [torch.from_numpy(x) for x in [patient_image, patient_label]]
        return dict(patient_id=_patient["id"],
                    image=patient_image, label=patient_label,
                    seg_path=str(_patient["seg"]),
                    crop_indexes=((zmin, zmax), (ymin, ymax), (xmin, xmax)),
                    )

    @staticmethod
    def load_nii(path_folder):
        """Read an image file with SimpleITK and return it as a numpy array."""
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path_folder)))

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.datas)


def save_prediction_to_mri(args, predictions, patient_ids):
    """Save each predicted volume as ``predicted_segs/<patient_id>_pred.nii.gz``.

    The affine and header are taken from the patient's ``<patient_id>_t1.nii.gz``
    found under ``args.data_root``.

    Args:
        args: namespace providing ``data_root``.
        predictions: indexable collection of volumes, aligned with ``patient_ids``.
        patient_ids: patient identifiers (strings).
    """
    out_dir = 'predicted_segs'
    os.makedirs(out_dir, exist_ok=True)
    for idx, patient_id in enumerate(patient_ids):
        reference_path = os.path.join(args.data_root, patient_id, patient_id+'_t1.nii.gz')
        reference = nib.load(reference_path)
        # NOTE(review): the reference header is reused as-is, so the stored data
        # type presumably follows the T1 header rather than the prediction array
        # - confirm this is intended.
        prediction_img = nib.Nifti1Image(
            predictions[idx], reference.affine, header=reference.header
        )
        nib.save(prediction_img, 'predicted_segs/'+patient_id+'_pred.nii.gz')


# ---- Configuration (parsed from defaults; no command line inside a notebook) ----
parser = argparse.ArgumentParser(description='SVLS Brats Training')
parser.add_argument('--batch_size', default=2, type=int,help='mini-batch size')
parser.add_argument('--num_classes', default=4, type=int, help="num of classes")
parser.add_argument('--in_channels', default=4, type=int, help="num of input channels")
parser.add_argument('--train_option', default='SVLS', help="options:[SVLS, LS, OH]")
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
parser.add_argument('--data_root', default='train_valid', help='data directory')
parser.add_argument('--ckpt_dir', default='ckpt_brats19', help='ckpt directory')
args = parser.parse_args(args=[])

# Use the GPU when there is one and fall back to the CPU otherwise, instead of
# failing on unconditional .cuda() calls.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Data ----
_, val_dataset = get_datasets_brats(data_root=args.data_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
    pin_memory=False, num_workers=2)

print('valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))

# ---- Model ----
model = UNet3D(inplanes=args.in_channels, num_classes=args.num_classes).to(device)
# The model is wrapped in DataParallel before loading, as in the original cell.
model = torch.nn.DataParallel(model)
criterion_dice = EDiceLoss().to(device)
# map_location lets a checkpoint saved from a GPU run load on the selected device.
model.load_state_dict(torch.load(os.path.join(args.ckpt_dir, 'best_oh.pth.tar'), map_location=device))
model.eval()

# ---- Inference: per-patient class-wise Dice, plus saving each predicted volume ----
metrics = []
with torch.no_grad():
    for batch in val_loader:
        # Labels carry a singleton axis at position 1; drop it so they match the predictions.
        targets = batch["label"].squeeze(1).to(device, non_blocking=True)
        inputs = batch["image"].float().to(device)
        # The predicted label is the index of the highest score along dim 1.
        preds = model(inputs).argmax(dim=1)
        metrics.extend(criterion_dice.metric_classwise(preds, targets))
        # Reverse the three spatial axes before handing the volumes to nibabel
        # (presumably SimpleITK's z,y,x array order to nibabel's x,y,z - confirm).
        preds = preds.permute(0,3,2,1).detach().cpu().numpy()
        save_prediction_to_mri(args, preds, batch['patient_id'])

# Transpose [patient][class] -> [class][patient], then average over patients.
metrics = list(zip(*metrics))
metrics = [torch.tensor(dice, device="cpu").numpy() for dice in metrics]
avg_dices = np.mean(metrics,1)

print('dice[Class-1:%.3f, Class-2:%.3f, Class-3:%.3f]'%(avg_dices[0], avg_dices[1], avg_dices[2]))


valid sample: 66 valid minibatch: 33
dice[Class-1:0.589, Class-2:0.774, Class-3:0.776]


# Obtaining class-wise Dice metrics from the saved 3D predictions and 3D GT masks

In [30]:
import os
from glob import glob

import numpy as np
import SimpleITK as sitk
import torch
from torch import nn

class EDiceLoss(nn.Module):
    """Dice loss and Dice metric helper for segmentation masks.

    ``binary_dice`` returns a smoothed Dice *loss* (``1 - dice``) by default and
    a plain Dice *score* when ``metric_mode=True``. ``metric_classwise`` applies
    the score to integer label maps, one foreground class at a time.
    """

    def __init__(self, do_sigmoid=True):
        super(EDiceLoss, self).__init__()
        # Kept for interface compatibility; no method of this class reads them.
        self.do_sigmoid = do_sigmoid
        self.device = "cpu"

    def binary_dice(self, inputs, targets, metric_mode=False):
        """Compute the Dice loss (default) or Dice score between two masks.

        Args:
            inputs: predicted mask tensor (boolean masks are accepted in metric mode).
            targets: reference mask tensor with the same shape as ``inputs``.
            metric_mode: when True return the Dice score, otherwise ``1 - dice``
                computed with a smoothing term of 1.

        Returns:
            A 0-dim tensor. In metric mode an empty target scores 1 if the
            prediction is empty too and 0 otherwise.
        """
        smooth = 1.
        if metric_mode and targets.sum() == 0:
            # Dice is undefined (0/0) for an empty target, so score it by
            # whether the prediction is empty as well. The result is created on
            # the device of ``inputs``: the volumes read back from disk live on
            # the CPU, so a hard-coded "cuda" would fail without a GPU.
            score = 1. if inputs.sum() == 0 else 0.
            return torch.tensor(score, device=inputs.device)

        intersection = EDiceLoss.compute_intersection(inputs, targets)
        if metric_mode:
            return (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0)
        dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth)
        return 1 - dice

    @staticmethod
    def compute_intersection(inputs, targets):
        """Return the sum of the element-wise product of the two masks."""
        return torch.sum(inputs * targets)

    def metric_classwise(self, inputs, target, class_ids=(1, 2, 3)):
        """Compute per-sample, per-class Dice scores from integer label maps.

        Args:
            inputs: predicted labels, indexed by sample along dimension 0.
            target: reference labels with the same layout as ``inputs``.
            class_ids: label values to score; defaults to classes 1, 2 and 3.

        Returns:
            A list with one entry per sample, each a list holding one 0-dim
            tensor per entry of ``class_ids``.
        """
        return [
            [self.binary_dice(inputs[j] == class_id, target[j] == class_id, True)
             for class_id in class_ids]
            for j in range(target.size(0))
        ]



# Score with the EDiceLoss class defined in this cell instead of relying on an
# instance created by an earlier cell.
criterion_dice = EDiceLoss()

GT_ROOT = '/content/SVLS/train_valid'        # one sub-directory per patient
PRED_ROOT = '/content/SVLS/predicted_segs'   # <patient_id>_pred.nii.gz files

# Sorted so the patients are always processed in the same order.
gt_dir_all = sorted(glob(os.path.join(GT_ROOT, '*', '*_seg.nii.gz')))
if not gt_dir_all:
    raise FileNotFoundError('no *_seg.nii.gz ground-truth files found under ' + GT_ROOT)

metrics = []
for path_gt in gt_dir_all:
    patient_label = sitk.GetArrayFromImage(sitk.ReadImage(str(path_gt)))
    patient_label[patient_label==4] = 3 # only needed for BraTS19, where class 3 is annotated as 4
    # Build the prediction path from the patient folder name; a str.replace on
    # the whole path would also rewrite any other 'train_valid' it contains.
    patient_id = os.path.basename(os.path.dirname(path_gt))
    pred_mask_dir = os.path.join(PRED_ROOT, patient_id + '_pred.nii.gz')
    pred_mask = sitk.GetArrayFromImage(sitk.ReadImage(str(pred_mask_dir)))
    # metric_classwise indexes samples along dim 0, so add a batch axis of size 1.
    metric_ = criterion_dice.metric_classwise(torch.tensor(pred_mask).unsqueeze(0), torch.tensor(patient_label).unsqueeze(0))
    metrics.extend(metric_)

# Transpose [patient][class] -> [class][patient], then average over patients.
metrics = list(zip(*metrics))
metrics = [torch.tensor(dice, device="cpu").numpy() for dice in metrics]
avg_dices = np.mean(metrics,1)

print('dice[Class-1:%.3f, Class-2:%.3f, Class-3:%.3f]'%(avg_dices[0], avg_dices[1], avg_dices[2]))

dice[Class-1:0.589, Class-2:0.774, Class-3:0.776]
