<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/BraTS_MRI_Prediction_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Spatially Varying Label Smoothing: Capturing Uncertainty from Expert Annotations
[Preprint](https://arxiv.org/pdf/2104.05788.pdf)
[Code](https://github.com/mobarakol/SVLS)

Downloading SVLS and Surface Dice codes

In [1]:
!rm -rf SVLS
!git clone https://github.com/mobarakol/SVLS.git
%cd SVLS
!git clone https://github.com/deepmind/surface-distance.git
!mv surface-distance surface_distance

Cloning into 'SVLS'...
remote: Enumerating objects: 400, done.[K
remote: Counting objects: 100% (400/400), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 400 (delta 351), reused 372 (delta 336), pack-reused 0[K
Receiving objects: 100% (400/400), 120.63 KiB | 5.48 MiB/s, done.
Resolving deltas: 100% (351/351), done.
/content/SVLS
Cloning into 'surface-distance'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 50 (delta 11), reused 8 (delta 7), pack-reused 33[K
Receiving objects: 100% (50/50), 36.67 KiB | 3.05 MiB/s, done.
Resolving deltas: 100% (24/24), done.


Install require packages:

In [2]:
!pip install -U -q SimpleITK

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25h

### Download Dataset and Trained Models

In [3]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

Trained Models: https://drive.google.com/file/d/1evE2VqBGdY-0VPB8OeArHMPdRXhWuxFm/view?usp=sharing <br>
Validation Data: https://drive.google.com/file/d/1oZ9z-l9lBjKGNZCCTK819Z1ufwe90mg3/view?usp=sharing

In [4]:
ids = ['1evE2VqBGdY-0VPB8OeArHMPdRXhWuxFm', '1oZ9z-l9lBjKGNZCCTK819Z1ufwe90mg3']
zip_files = ['ckpt_brats19.zip','train_valid.zip']
for id, zip_file in zip(ids, zip_files):
    downloaded = drive.CreateFile({'id':id})
    downloaded.GetContentFile(zip_file)
    !unzip -q $zip_file

Download BraTS validation set

In [4]:
!mkdir brats_valid
ids = ['1dZDw0wFDTIZlAbFEqSAfMaKOAlEG_Fat', '1evE2VqBGdY-0VPB8OeArHMPdRXhWuxFm']
zip_files = ['ASNR-MICCAI-BraTS2023-GLI-Challenge-ValidationData.zip', 'ckpt_brats19.zip']
for id, zip_file in zip(ids, zip_files):
    downloaded = drive.CreateFile({'id':id})
    downloaded.GetContentFile(zip_file)
    !unzip -q $zip_file -d brats_valid

Prediction to MRI:

In [None]:
import argparse
import os
import numpy as np
import pathlib
import torch
from torch.nn import functional as F
from model import UNet3D
from datasets import get_datasets_brats
from utils import seed_everything, EDiceLoss
import nibabel as nib
from google.colab import files
import numpy as np
import SimpleITK as sitk
import random
from glob import glob
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from datasets import custom_collate, determinist_collate, pad_batch_to_max_shape, \
pad_batch1_to_compatible_size, irm_min_max_preprocess, pad_or_crop_image

import warnings
warnings.filterwarnings("ignore")


parser = argparse.ArgumentParser(description='SVLS Brats Training')
parser.add_argument('--batch_size', default=2, type=int,help='mini-batch size')
parser.add_argument('--num_classes', default=4, type=int, help="num of classes")
parser.add_argument('--in_channels', default=4, type=int, help="num of input channels")
parser.add_argument('--train_option', default='SVLS', help="options:[SVLS, LS, OH]")
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
parser.add_argument('--data_root', default='train_valid', help='data directory')
parser.add_argument('--ckpt_dir', default='ckpt_brats19', help='ckpt directory')
args = parser.parse_args(args=[])

def get_datasets_brats(data_root=None, normalisation="zscore", no_seg=False):
    data_rootn = data_root
    data_root = pathlib.Path(data_root)
    base_folder_valid = pathlib.Path(data_root).resolve()
    subdirectories = [x for x in base_folder_valid.iterdir() if x.is_dir()]
    patients_dir_valid = sorted([data_root/x.name for x in base_folder_valid.iterdir() if (data_root/x.name).is_dir()])
    val_dataset = brats19(patients_dir_valid, training=False, normalisation=normalisation, no_seg=no_seg)
    return val_dataset


class brats19(Dataset):
    def __init__(self, patients_dir, training=True, no_seg=False, normalisation="minmax"):
        super(brats19, self).__init__()
        self.normalisation = normalisation
        self.training = training
        self.datas = []
        self.validation = no_seg
        self.patterns = [ "-t2f", "-t1n", "-t1c", "-t2w"]
        self.mean = dict(flair=0.0860377, t1=0.1216296, t1ce=0.07420689, t2=0.09033176)
        if not no_seg:
            self.patterns += ["_seg"]
        for patient_dir in patients_dir:
            patient_id = patient_dir.name
            paths = [patient_dir / f"{patient_id}{value}.nii.gz" for value in self.patterns]
            patient = dict(
                id=patient_id, flair=paths[0], t1=paths[1], t1ce=paths[2],
                t2=paths[3], seg=paths[4] if not no_seg else None
            )
            self.datas.append(patient)

    def __getitem__(self, idx):
        _patient = self.datas[idx]
        patient_image = {key: self.load_nii(_patient[key]) for key in _patient if key not in ["id", "seg"]}

        patient_image = {key: (irm_min_max_preprocess(patient_image[key]) - self.mean[key]) for key in patient_image}
        patient_image = np.stack([patient_image[key] for key in patient_image])
        patient_label = torch.zeros(patient_image[0].shape)

        # Remove maximum extent of the zero-background to make future crop more useful
        z_indexes, y_indexes, x_indexes = np.nonzero(np.sum(patient_image, axis=0) != 0)

        # Add 1 pixel in each side
        zmin, ymin, xmin = [max(0, int(np.min(arr) - 1)) for arr in (z_indexes, y_indexes, x_indexes)]
        zmax, ymax, xmax = [int(np.max(arr) + 1) for arr in (z_indexes, y_indexes, x_indexes)]
        patient_image = patient_image[:, 13:141, 24:216, 24:216]
        if _patient["seg"] is not None:
            patient_label = self.load_nii(_patient["seg"])
            patient_label = patient_label[None,:,:,:]
            patient_label = patient_label[:, 13:141, 24:216, 24:216]
            patient_label = patient_label.astype("long")
            patient_label = torch.from_numpy(patient_label)

        patient_image = patient_image.astype("float16")
        patient_image = torch.from_numpy(np.array(patient_image))
        # print(patient_image.shape)
        return dict(patient_id=_patient["id"],
                    image=patient_image, label=patient_label,
                    seg_path=str(_patient["seg"]),
                    crop_indexes=((zmin, zmax), (ymin, ymax), (xmin, xmax)),
                    )

    @staticmethod
    def load_nii(path_folder):
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path_folder)))

    def __len__(self):
        return len(self.datas)


def save_prediction_to_mri(args, predictions, patient_ids):
    os.makedirs('predicted_segs', exist_ok=True)
    for idx, patient_id in enumerate(patient_ids):
        path = os.path.join(args.data_root, patient_id, patient_id+'-t1n.nii.gz')
        img_original = nib.load(path)
        img_nifti = nib.Nifti1Image(predictions[idx], img_original.affine, header=img_original.header)
        nib.save(img_nifti,'predicted_segs/'+patient_id+'.nii.gz') #not as expected

args.data_root = '/content/SVLS/brats_valid/ASNR-MICCAI-BraTS2023-GLI-Challenge-ValidationData/'
val_dataset = get_datasets_brats(data_root=args.data_root, no_seg=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
    pin_memory=False, num_workers=2)

print('valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))
model = UNet3D(inplanes=args.in_channels, num_classes=args.num_classes).cuda()
model = torch.nn.DataParallel(model)
criterion_dice = EDiceLoss().cuda()
model.load_state_dict(torch.load(os.path.join(args.ckpt_dir, 'best_oh.pth.tar')))
model.eval()
with torch.no_grad():
    metrics = []
    for i, batch in enumerate(val_loader):
        targets = batch["label"].squeeze(1).cuda(non_blocking=True)
        inputs = batch["image"].float().cuda()
        preds = model(inputs)
        pred_full = torch.zeros(targets.shape)
        preds = preds.data.max(1)[1].squeeze_(1)
        pred_full[:,13:141, 24:216, 24:216] = preds.cpu()
        pred_full = pred_full.permute(0,3,2,1).detach().cpu().numpy()
        save_prediction_to_mri(args, pred_full, batch['patient_id'])


!zip -r predicted_segs.zip predicted_segs
files.download('predicted_segs.zip')

valid sample: 219 valid minibatch: 110


Demo

Method-1: Saving the prediction: same size as training

In [32]:
import argparse
import os
import numpy as np
import pathlib
import torch
from torch.nn import functional as F
from model import UNet3D
from datasets import get_datasets_brats
from utils import seed_everything, EDiceLoss
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import random
from glob import glob
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from datasets import custom_collate, determinist_collate, pad_batch_to_max_shape, \
pad_batch1_to_compatible_size, irm_min_max_preprocess, pad_or_crop_image

import warnings
warnings.filterwarnings("ignore")


parser = argparse.ArgumentParser(description='SVLS Brats Training')
parser.add_argument('--batch_size', default=2, type=int,help='mini-batch size')
parser.add_argument('--num_classes', default=4, type=int, help="num of classes")
parser.add_argument('--in_channels', default=4, type=int, help="num of input channels")
parser.add_argument('--train_option', default='SVLS', help="options:[SVLS, LS, OH]")
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
parser.add_argument('--data_root', default='train_valid', help='data directory')
parser.add_argument('--ckpt_dir', default='ckpt_brats19', help='ckpt directory')
args = parser.parse_args(args=[])

def get_datasets_brats(data_root=None, normalisation="zscore"):

    data_root = pathlib.Path(data_root)
    base_folder_train = pathlib.Path('data/BraTS19/train_train/').resolve()
    base_folder_valid = pathlib.Path('data/BraTS19/train_valid/').resolve()
    patients_dir_train = sorted([data_root/x.name for x in base_folder_train.iterdir() if (data_root/x.name).is_dir()])
    patients_dir_valid = sorted([data_root/x.name for x in base_folder_valid.iterdir() if (data_root/x.name).is_dir()])
    train_dataset = brats19(patients_dir_train, training=True, normalisation=normalisation)
    val_dataset = brats19(patients_dir_valid, training=False, normalisation=normalisation)
    return train_dataset, val_dataset


class brats19(Dataset):
    def __init__(self, patients_dir, training=True, no_seg=False, normalisation="minmax"):
        super(brats19, self).__init__()
        self.normalisation = normalisation
        self.training = training
        self.datas = []
        self.validation = no_seg
        self.patterns = [ "_flair", "_t1", "_t1ce", "_t2"]
        self.mean = dict(flair=0.0860377, t1=0.1216296, t1ce=0.07420689, t2=0.09033176)
        if not no_seg:
            self.patterns += ["_seg"]
        for patient_dir in patients_dir:
            patient_id = patient_dir.name
            paths = [patient_dir / f"{patient_id}{value}.nii.gz" for value in self.patterns]
            patient = dict(
                id=patient_id, flair=paths[0], t1=paths[1], t1ce=paths[2],
                t2=paths[3], seg=paths[4] if not no_seg else None
            )
            self.datas.append(patient)

    def __getitem__(self, idx):
        _patient = self.datas[idx]
        patient_image = {key: self.load_nii(_patient[key]) for key in _patient if key not in ["id", "seg"]}
        if _patient["seg"] is not None:
            patient_label = self.load_nii(_patient["seg"])

        patient_image = {key: (irm_min_max_preprocess(patient_image[key]) - self.mean[key]) for key in patient_image}
        patient_image = np.stack([patient_image[key] for key in patient_image])
        patient_label[patient_label==4] = 3
        patient_label = patient_label[None,:,:,:]

        # Remove maximum extent of the zero-background to make future crop more useful
        z_indexes, y_indexes, x_indexes = np.nonzero(np.sum(patient_image, axis=0) != 0)

        # Add 1 pixel in each side
        zmin, ymin, xmin = [max(0, int(np.min(arr) - 1)) for arr in (z_indexes, y_indexes, x_indexes)]
        zmax, ymax, xmax = [int(np.max(arr) + 1) for arr in (z_indexes, y_indexes, x_indexes)]
        patient_image = patient_image[:, 13:141, 24:216, 24:216]
        patient_image, patient_label = patient_image.astype("float16"), patient_label.astype("long")
        patient_image, patient_label = [torch.from_numpy(x) for x in [patient_image, patient_label]]
        return dict(patient_id=_patient["id"],
                    image=patient_image, label=patient_label,
                    seg_path=str(_patient["seg"]),
                    crop_indexes=((zmin, zmax), (ymin, ymax), (xmin, xmax)),
                    )

    @staticmethod
    def load_nii(path_folder):
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path_folder)))

    def __len__(self):
        return len(self.datas)


def save_prediction_to_mri(args, predictions, patient_ids):
    os.makedirs('predicted_segs', exist_ok=True)
    for idx, patient_id in enumerate(patient_ids):
        path = os.path.join(args.data_root, patient_id, patient_id+'_t1.nii.gz')
        img_original = nib.load(path)
        img_nifti = nib.Nifti1Image(predictions[idx], img_original.affine, header=img_original.header)
        nib.save(img_nifti,'predicted_segs/'+patient_id+'_pred.nii.gz') #not as expected

_, val_dataset = get_datasets_brats(data_root=args.data_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
    pin_memory=False, num_workers=2)

print('valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))
model = UNet3D(inplanes=args.in_channels, num_classes=args.num_classes).cuda()
model = torch.nn.DataParallel(model)
criterion_dice = EDiceLoss().cuda()
model.load_state_dict(torch.load(os.path.join(args.ckpt_dir, 'best_oh.pth.tar')))
model.eval()
with torch.no_grad():
    metrics = []
    for i, batch in enumerate(val_loader):
        targets = batch["label"].squeeze(1).cuda(non_blocking=True)
        inputs = batch["image"].float().cuda()
        preds = model(inputs)
        pred_full = torch.zeros(targets.shape)
        print(targets.shape)
        preds = preds.data.max(1)[1].squeeze_(1)
        pred_full[:,13:141, 24:216, 24:216] = preds.cpu()
        if len(targets.shape) < 4:#if batch size=1
            targets = targets.unsqueeze(0)
        metric_ = criterion_dice.metric_brats(pred_full, targets.cpu())
        metrics.extend(metric_)
        pred_full = pred_full.permute(0,3,2,1).detach().cpu().numpy()
        save_prediction_to_mri(args, pred_full, batch['patient_id'])
        break
    metrics = list(zip(*metrics))
    metrics = [torch.tensor(dice, device="cpu").numpy() for dice in metrics]
    avg_dices = np.mean(metrics,1)

print('dice[ET:%.3f, TC:%.3f, WT:%.3f]'%(avg_dices[0], avg_dices[1], avg_dices[2]))


valid sample: 66 valid minibatch: 33
torch.Size([2, 155, 240, 240])


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d834a00f250>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


dice[ET:0.402, TC:0.687, WT:0.730]


Method-2: Saving the prediction: including padding(default volume 155 244 244)

In [None]:
import argparse
import os
import numpy as np
import pathlib
import torch
from torch.nn import functional as F
from model import UNet3D
from datasets import get_datasets_brats
from utils import seed_everything, EDiceLoss
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import random
from glob import glob
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from datasets import custom_collate, determinist_collate, pad_batch_to_max_shape, \
pad_batch1_to_compatible_size, irm_min_max_preprocess, pad_or_crop_image

import warnings
warnings.filterwarnings("ignore")


parser = argparse.ArgumentParser(description='SVLS Brats Training')
parser.add_argument('--batch_size', default=2, type=int,help='mini-batch size')
parser.add_argument('--num_classes', default=4, type=int, help="num of classes")
parser.add_argument('--in_channels', default=4, type=int, help="num of input channels")
parser.add_argument('--train_option', default='SVLS', help="options:[SVLS, LS, OH]")
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
parser.add_argument('--data_root', default='train_valid', help='data directory')
parser.add_argument('--ckpt_dir', default='ckpt_brats19', help='ckpt directory')
args = parser.parse_args(args=[])

def get_datasets_brats(data_root=None, normalisation="zscore"):

    data_root = pathlib.Path(data_root)
    base_folder_train = pathlib.Path('data/BraTS19/train_train/').resolve()
    base_folder_valid = pathlib.Path('data/BraTS19/train_valid/').resolve()
    patients_dir_train = sorted([data_root/x.name for x in base_folder_train.iterdir() if (data_root/x.name).is_dir()])
    patients_dir_valid = sorted([data_root/x.name for x in base_folder_valid.iterdir() if (data_root/x.name).is_dir()])
    train_dataset = brats19(patients_dir_train, training=True, normalisation=normalisation)
    val_dataset = brats19(patients_dir_valid, training=False, normalisation=normalisation)
    return train_dataset, val_dataset


class brats19(Dataset):
    def __init__(self, patients_dir, training=True, no_seg=False, normalisation="minmax"):
        super(brats19, self).__init__()
        self.normalisation = normalisation
        self.training = training
        self.datas = []
        self.validation = no_seg
        self.patterns = [ "_flair", "_t1", "_t1ce", "_t2"]
        self.mean = dict(flair=0.0860377, t1=0.1216296, t1ce=0.07420689, t2=0.09033176)
        if not no_seg:
            self.patterns += ["_seg"]
        for patient_dir in patients_dir:
            patient_id = patient_dir.name
            paths = [patient_dir / f"{patient_id}{value}.nii.gz" for value in self.patterns]
            patient = dict(
                id=patient_id, flair=paths[0], t1=paths[1], t1ce=paths[2],
                t2=paths[3], seg=paths[4] if not no_seg else None
            )
            self.datas.append(patient)

    def __getitem__(self, idx):
        _patient = self.datas[idx]
        patient_image = {key: self.load_nii(_patient[key]) for key in _patient if key not in ["id", "seg"]}
        if _patient["seg"] is not None:
            patient_label = self.load_nii(_patient["seg"])

        patient_image = {key: (irm_min_max_preprocess(patient_image[key]) - self.mean[key]) for key in patient_image}
        patient_image = np.stack([patient_image[key] for key in patient_image])
        patient_label[patient_label==4] = 3
        patient_label = patient_label[None,:,:,:]

        # Remove maximum extent of the zero-background to make future crop more useful
        z_indexes, y_indexes, x_indexes = np.nonzero(np.sum(patient_image, axis=0) != 0)

        # Add 1 pixel in each side
        zmin, ymin, xmin = [max(0, int(np.min(arr) - 1)) for arr in (z_indexes, y_indexes, x_indexes)]
        zmax, ymax, xmax = [int(np.max(arr) + 1) for arr in (z_indexes, y_indexes, x_indexes)]
        patient_image, patient_label = patient_image.astype("float16"), patient_label.astype("long")
        patient_image, patient_label = [torch.from_numpy(x) for x in [patient_image, patient_label]]
        return dict(patient_id=_patient["id"],
                    image=patient_image, label=patient_label,
                    seg_path=str(_patient["seg"]),
                    crop_indexes=((zmin, zmax), (ymin, ymax), (xmin, xmax)),
                    )

    @staticmethod
    def load_nii(path_folder):
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path_folder)))

    def __len__(self):
        return len(self.datas)


def save_prediction_to_mri(args, predictions, patient_ids):
    os.makedirs('predicted_segs', exist_ok=True)
    for idx, patient_id in enumerate(patient_ids):
        path = os.path.join(args.data_root, patient_id, patient_id+'_t1.nii.gz')
        img_original = nib.load(path)
        img_nifti = nib.Nifti1Image(predictions[idx], img_original.affine, header=img_original.header)
        nib.save(img_nifti,'predicted_segs/'+patient_id+'_pred.nii.gz') #not as expected

_, val_dataset = get_datasets_brats(data_root=args.data_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
    pin_memory=False, num_workers=2)

print('valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))
model = UNet3D(inplanes=args.in_channels, num_classes=args.num_classes).cuda()
model = torch.nn.DataParallel(model)
criterion_dice = EDiceLoss().cuda()
model.load_state_dict(torch.load(os.path.join(args.ckpt_dir, 'best_oh.pth.tar')))
model.eval()
with torch.no_grad():
    metrics = []
    for i, batch in enumerate(val_loader):
        targets = batch["label"].squeeze(1).cuda(non_blocking=True)
        inputs = batch["image"].float().cuda()
        preds = model(inputs)
        preds = preds.data.max(1)[1].squeeze_(1)
        if len(targets.shape) < 4:#if batch size=1
            targets = targets.unsqueeze(0)
        metric_ = criterion_dice.metric_brats(preds, targets)
        metrics.extend(metric_)
        preds = preds.permute(0,3,2,1).detach().cpu().numpy()
        save_prediction_to_mri(args, preds, batch['patient_id'])

    metrics = list(zip(*metrics))
    metrics = [torch.tensor(dice, device="cpu").numpy() for dice in metrics]
    avg_dices = np.mean(metrics,1)

print('dice[ET:%.3f, TC:%.3f, WT:%.3f]'%(avg_dices[0], avg_dices[1], avg_dices[2]))


valid sample: 66 valid minibatch: 33
dice[ET:0.776, TC:0.826, WT:0.872]


Method-3: Saving the prediction: excluding padding

In [None]:
import argparse
import os
import numpy as np
import pathlib
import torch
from torch.nn import functional as F
from model import UNet3D
from datasets import get_datasets_brats
from utils import seed_everything, EDiceLoss
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import random
from glob import glob
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.utils.data._utils.collate import default_collate
from datasets import custom_collate, determinist_collate, pad_batch_to_max_shape, \
pad_batch1_to_compatible_size, irm_min_max_preprocess, pad_or_crop_image

import warnings
warnings.filterwarnings("ignore")


parser = argparse.ArgumentParser(description='SVLS Brats Training')
parser.add_argument('--batch_size', default=1, type=int,help='mini-batch size')
parser.add_argument('--num_classes', default=4, type=int, help="num of classes")
parser.add_argument('--in_channels', default=4, type=int, help="num of input channels")
parser.add_argument('--train_option', default='SVLS', help="options:[SVLS, LS, OH]")
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
parser.add_argument('--data_root', default='train_valid', help='data directory')
parser.add_argument('--ckpt_dir', default='ckpt_brats19', help='ckpt directory')
args = parser.parse_args(args=[])

def get_datasets_brats(data_root=None, normalisation="zscore"):

    data_root = pathlib.Path(data_root)
    base_folder_train = pathlib.Path('data/BraTS19/train_train/').resolve()
    base_folder_valid = pathlib.Path('data/BraTS19/train_valid/').resolve()
    patients_dir_train = sorted([data_root/x.name for x in base_folder_train.iterdir() if (data_root/x.name).is_dir()])
    patients_dir_valid = sorted([data_root/x.name for x in base_folder_valid.iterdir() if (data_root/x.name).is_dir()])
    train_dataset = brats19(patients_dir_train, training=True, normalisation=normalisation)
    val_dataset = brats19(patients_dir_valid, training=False, normalisation=normalisation)
    return train_dataset, val_dataset


class brats19(Dataset):
    def __init__(self, patients_dir, training=True, no_seg=False, normalisation="minmax"):
        super(brats19, self).__init__()
        self.normalisation = normalisation
        self.training = training
        self.datas = []
        self.validation = no_seg
        self.patterns = [ "_flair", "_t1", "_t1ce", "_t2"]
        # self.mean = dict(flair=0.0860377, t1=0.1216296, t1ce=0.07420689, t2=0.09033176)
        self.mean = [0.0860377, 0.1216296, 0.07420689, 0.09033176]
        if not no_seg:
            self.patterns += ["_seg"]
        for patient_dir in patients_dir:
            patient_id = patient_dir.name
            paths = [patient_dir / f"{patient_id}{value}.nii.gz" for value in self.patterns]
            patient = dict(
                id=patient_id, flair=paths[0], t1=paths[1], t1ce=paths[2],
                t2=paths[3], seg=paths[4] if not no_seg else None
            )
            self.datas.append(patient)

    def __getitem__(self, idx):
        _patient = self.datas[idx]
        patient_image = {key: self.load_nii(_patient[key]) for key in _patient if key not in ["id", "seg"]}
        if _patient["seg"] is not None:
            patient_label = self.load_nii(_patient["seg"])

        # Remove maximum extent of the zero-background to make future crop more useful
        patient_image = np.stack([patient_image[key] for key in patient_image])
        z_indexes, y_indexes, x_indexes = np.nonzero(np.sum(patient_image, axis=0) != 0)

        # Add 1 pixel in each side
        zmin, ymin, xmin = [max(0, int(np.min(arr) - 1)) for arr in (z_indexes, y_indexes, x_indexes)]
        zmax, ymax, xmax = [int(np.max(arr) + 1) for arr in (z_indexes, y_indexes, x_indexes)]
        patient_image = [(irm_min_max_preprocess(patient_image[mod_idx]) - self.mean[mod_idx]) for mod_idx, key in enumerate(patient_image)]
        patient_image = np.stack(patient_image)
        patient_label[patient_label==4] = 3
        patient_label = patient_label[None,:,:,:]
        patient_image = patient_image[:, zmin:zmax, ymin:ymax, xmin:xmax]
        patient_image, patient_label = patient_image.astype("float16"), patient_label.astype("long")
        patient_image, patient_label = [torch.from_numpy(x) for x in [patient_image, patient_label]]
        return dict(patient_id=_patient["id"],
                    image=patient_image, label=patient_label,
                    seg_path=str(_patient["seg"]),
                    crop_indexes=((zmin, zmax), (ymin, ymax), (xmin, xmax)),
                    )

    @staticmethod
    def load_nii(path_folder):
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path_folder)))

    def __len__(self):
        return len(self.datas)


def save_prediction_to_mri(args, predictions, patient_ids):
    os.makedirs('predicted_segs', exist_ok=True)
    for idx, patient_id in enumerate(patient_ids):
        path = os.path.join(args.data_root, patient_id, patient_id+'_t1.nii.gz')
        img_original = nib.load(path)
        img_nifti = nib.Nifti1Image(predictions[idx], img_original.affine, header=img_original.header)
        nib.save(img_nifti,'predicted_segs/'+patient_id+'_pred.nii.gz') #not as expected

_, val_dataset = get_datasets_brats(data_root=args.data_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
    pin_memory=False, num_workers=2)

print('valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))
model = UNet3D(inplanes=args.in_channels, num_classes=args.num_classes).cuda()
model = torch.nn.DataParallel(model)
criterion_dice = EDiceLoss().cuda()
model.load_state_dict(torch.load(os.path.join(args.ckpt_dir, 'best_oh.pth.tar')))
model.eval()
with torch.no_grad():
    metrics = []
    for i, batch in enumerate(val_loader):
        targets = batch["label"].squeeze(1).cuda(non_blocking=True)
        inputs = batch["image"].float().cuda()
        ((zmin, zmax), (ymin, ymax), (xmin, xmax)) = batch["crop_indexes"]
        ((zmin, zmax), (ymin, ymax), (xmin, xmax)) = ((zmin.item(), zmax.item()), (ymin.item(), ymax.item()), (xmin.item(), xmax.item()))
        pred_full = torch.zeros(targets.shape)
        preds = model(inputs)
        preds = preds.data.max(1)[1].squeeze_(1)
        if len(targets.shape) < 4:#if batch size=1
            targets = targets.unsqueeze(0)

        pred_full[:,zmin:zmax, ymin:ymax, xmin:xmax] = preds.cpu()
        metric_ = criterion_dice.metric_brats(pred_full, targets.cpu())
        metrics.extend(metric_)
        pred_full = pred_full.permute(0,3,2,1).detach().cpu().numpy()
        save_prediction_to_mri(args, pred_full, batch['patient_id'])

    metrics = list(zip(*metrics))
    metrics = [torch.tensor(dice, device="cpu").numpy() for dice in metrics]
    avg_dices = np.mean(metrics,1)

print('dice[ET:%.3f, TC:%.3f, WT:%.3f]'%(avg_dices[0], avg_dices[1], avg_dices[2]))


valid sample: 66 valid minibatch: 66
dice[ET:0.788, TC:0.817, WT:0.860]
