## Libraries

In [None]:
%pip install pynrrd numpy torch torchvision monai

In [2]:
import monai
import os
import csv
import numpy as np
import nrrd
import torch
import PIL
import IPython.display

In [3]:
def pad(array, target_shape, value=-3024):
    """Centre ``array`` inside a new array of ``target_shape`` filled with ``value``.

    Args:
        array: Input array; its ``shape`` and contents are used.
        target_shape: Desired shape, with one entry per dimension of ``array``
            and each entry at least as large as the matching input dimension.
        value: Fill value used for the border around the original data.

    Returns:
        A ``(padded, indices)`` tuple. ``padded`` is a float64 array of
        ``target_shape``. ``indices`` is a list with one ``[start, stop]``
        pair per dimension marking where the original data sits inside
        ``padded``; pass it to ``unpad`` to recover the original extent.

    Raises:
        ValueError: If the number of dimensions differs, or if any target
            dimension is smaller than the corresponding input dimension.
    """
    current_shape = array.shape
    if len(current_shape) != len(target_shape):
        raise ValueError("Target shape does not have same amount of dimensions as input array.")
    if any(wanted < present for present, wanted in zip(current_shape, target_shape)):
        raise ValueError("Atleast one target dimension is smaller than the current dimension.")

    # With an odd amount of padding, floor division puts the extra element
    # at the end of the axis.
    indices = []
    for present, wanted in zip(current_shape, target_shape):
        offset = (wanted - present) // 2
        indices.append([offset, offset + present])

    padded = np.full(target_shape, value, dtype=np.float64)
    window = tuple(slice(bounds[0], bounds[1]) for bounds in indices)
    padded[window] = array
    return padded, indices

def unpad(array, indices):
    """Cut the region described by ``indices`` out of ``array``.

    Args:
        array: Array to crop.
        indices: One ``[start, stop]`` pair per dimension of ``array``, in
            the form returned by ``pad``.

    Returns:
        ``array`` sliced to ``start:stop`` along every dimension.

    Raises:
        ValueError: If ``indices`` does not have one entry per dimension.
    """
    if len(array.shape) != len(indices):
        raise ValueError("Number of dimensions differs between array and indices.")

    window = tuple(slice(bounds[0], bounds[1]) for bounds in indices)
    return array[window]

In [4]:
def build_dict_ASOCA(data_path, mode="train"):
    """List the file paths of the ASOCA samples below ``data_path``.

    Paths are only constructed; nothing checks that the files exist.

    Args:
        data_path: Root directory containing the ``Diseased`` and ``Normal``
            folders.
        mode: ``"train"`` or ``"test"``.

    Returns:
        A list of dicts, ``Diseased`` entries first, then ``Normal``.
        In ``"train"`` mode each dict has an ``"img"`` and a ``"mask"`` path
        for cases 1-20 of each class. In ``"test"`` mode each dict has only
        ``"img"``: cases 10-19 for ``Diseased`` and 0-9 for ``Normal``.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"test"``.
    """
    if mode not in ("train", "test"):
        raise ValueError(f"Please choose a mode in ['train', 'test']. Current mode is {mode}.")

    if mode == "train":
        entries = []
        for clazz in ("Diseased", "Normal"):
            class_root = os.path.join(data_path, clazz)
            for index in range(1, 21):
                file_name = f"{clazz}_{index}.nrrd"
                entries.append(
                    {
                        "img": os.path.join(class_root, "CTCA", file_name),
                        "mask": os.path.join(class_root, "Annotations", file_name),
                    }
                )
        return entries

    # Test data: (class folder, test-set folder, case numbers). Note that the
    # diseased test folder is named "Testset_Disease", not "Testset_Diseased".
    test_layout = (
        ("Diseased", "Testset_Disease", range(10, 20)),
        ("Normal", "Testset_Normal", range(10)),
    )
    return [
        {"img": os.path.join(data_path, clazz, folder, f"{index}.nrrd")}
        for clazz, folder, case_numbers in test_layout
        for index in case_numbers
    ]

class LoadASOCAData(monai.transforms.Transform):
    """Transform that reads one ASOCA sample from disk and pads it to 512x512x352."""

    def __init__(self, keys=None):
        """Accept ``keys`` in the signature; the argument is not used."""

    def __call__(self, sample):
        """Load and pad the files referenced by ``sample``.

        Args:
            sample: Dict with an ``"img"`` path and optionally a ``"mask"``
                path, as produced by ``build_dict_ASOCA``.

        Returns:
            Dict with:
              * ``"img"``: padded image as int16,
              * ``"indices"``: position of the original image inside the
                padded array, for use with ``unpad``,
              * ``"sample"``: the input dict, kept so padding can be verified,
              * ``"mask"``: padded mask as bool (only if ``sample`` has one).
        """
        padded_shape = (512, 512, 352)

        volume, indices = pad(nrrd.read(sample["img"])[0], padded_shape)
        loaded = {
            "img": volume.astype(np.int16),  # save memory space
            "indices": indices,
            "sample": sample,  # For verifying e.g. padding
        }

        if "mask" in sample:
            # The mask border is filled with 0 (background). The indices from
            # this second pad() call are discarded; the image's are reported.
            annotation, _ = pad(nrrd.read(sample["mask"])[0], padded_shape, value=0)
            loaded["mask"] = annotation.astype(bool)  # save memory space
        return loaded

In [None]:
# Tune cache_rate to the amount of memory available. test_dataset is only used
# for visualisation / at the end, so it does not need caching (cache_rate=0).
# The globals() check skips construction when train_dataset already exists,
# e.g. when this cell is run again.
if "train_dataset" not in globals():
    train_dataset = monai.data.SmartCacheDataset(
        build_dict_ASOCA("ASOCA", mode="train"),
        transform=LoadASOCAData(),
        cache_rate=0.5,
        replace_rate=0.2,
    )
    test_dataset = monai.data.CacheDataset(
        build_dict_ASOCA("ASOCA", mode="test"),
        transform=LoadASOCAData(),
        cache_rate=0,
    )

In [None]:
# Verify that unpad() exactly undoes pad() for one training sample.
# The reference volume is read from the path stored with the sample instead of
# a hard-coded file name: which file ends up at index 0 depends on how the
# dataset orders its data list (SmartCacheDataset may shuffle it), so a fixed
# path would only match by coincidence.
entry = train_dataset[0]  # fetch once; each access returns a full padded volume
print(entry["sample"])  # paths this sample was loaded from
image = nrrd.read(entry["sample"]["img"])[0]
image2 = unpad(entry["img"], entry["indices"])
(image == image2).all()

In [7]:
def RGB_mask(mask):
    """Turn a mask into an RGB overlay with the mask in the green channel.

    Args:
        mask: Array of any shape; non-zero / True marks the foreground.

    Returns:
        Float array of shape ``mask.shape + (3,)``. The red and blue channels
        are zero and the green channel equals ``128 * mask``.
    """
    overlay = np.zeros(tuple(mask.shape) + (3,))
    overlay[..., 1] = mask * 128
    return overlay

def RGB_image(image):
    """Rescale ``image`` to [0, 255] and replicate it into three colour channels.

    Args:
        image: Numeric array of any rank (the notebook passes 3-D volumes).

    Returns:
        Floating-point array of shape ``image.shape + (3,)`` whose three
        channels are identical copies of the min-max normalised image.
        A constant image maps to all zeros.
    """
    values = np.asarray(image)
    # Promote non-float inputs to float64 before subtracting. With a narrow
    # integer type such as int16, ``image - image.min()`` wraps around as soon
    # as the value range exceeds the type's maximum.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    shifted = values - values.min()  # [a, b] -> [0, b-a]
    span = shifted.max()
    if span == 0:
        # Constant image: avoid 0/0 (NaN) and render it as black.
        scaled = np.zeros_like(shifted)
    else:
        scaled = shifted / span * 255  # [0, b-a] -> [0,1] -> [0,255]

    # Append a channel axis and copy the grey value into R, G and B.
    return np.repeat(scaled[..., np.newaxis], 3, axis=-1)

# Render training sample 1 as an animated GIF: one frame per slice along the
# last axis, with the mask blended in green over the grey-scale image.
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded

vis_sample = train_dataset[1]  # fetch once; each access returns a full padded volume
indices = vis_sample["indices"]
image = unpad(vis_sample["img"], indices)  # crop back to the original extent
mask = unpad(vis_sample["mask"], indices)
rgb_image = RGB_image(image)
rgb_mask = RGB_mask(mask)
# 50/50 blend of image and overlay, then convert to 8-bit for PIL.
rgb_image = (0.5 * rgb_image + 0.5 * rgb_mask).astype(np.uint8)
images = [PIL.Image.fromarray(rgb_image[:, :, index, :]) for index in range(image.shape[2])]
images[0].save("array.gif", save_all=True, append_images=images[1:])

In [None]:
# Show the animation inline. The file is a GIF, so declare it as one, and read
# it inside a context manager so the file handle is closed again.
with open("array.gif", "rb") as gif_file:
    gif_bytes = gif_file.read()
IPython.display.display(IPython.display.Image(data=gif_bytes, format="gif"))