In [3]:
import imp
from pydoc import cli
from traceback import print_tb
# from grpc import ClientCallDetails
import torch
from torch import nn
from torch.utils.data import Dataset
from torchmtlr import MTLR
from torchmtlr.utils import make_time_bins, encode_survival
import SimpleITK as sitk
import nibabel as nib
from joblib import Parallel, delayed
from sklearn.preprocessing import scale

from typing import Sequence, Tuple, Union, Optional, Callable, Any, List

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
#from monai.networks.nets.vit import ViT
from monai.utils import ensure_tuple_rep
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.transforms.transform import Transform
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option
from monai.transforms.croppad.array import Pad
from monai.utils import (
    InterpolateMode,
    NumpyPadMode,
    PytorchPadMode,
    ensure_tuple_rep
)
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
from einops import repeat, rearrange
import pathlib
from pathlib import Path
import pandas as pd
import numpy as np
import sys
import os
import random
import numpy as np
import torch
import elasticdeform
from skimage.transform import rotate



# Seed Python's built-in `random` module so that draws made through it
# (several transforms below call `random.random()` / `random.randint()`) are reproducible.
# NOTE(review): this does not seed NumPy's or PyTorch's generators.
random.seed(260520)

ModuleNotFoundError: No module named 'grpc'

# TRANSFORMS

In [None]:
class Compose:
    """Chain several sample transforms into one callable.

    Parameters
    ----------
    transforms : iterable of callables, optional
        Each callable receives the sample returned by the previous one.
        ``None`` (the default) means "no transforms": the sample is returned unchanged.
    """

    def __init__(self, transforms=None):
        self.transforms = transforms

    def __call__(self, sample):
        """Apply every transform in order and return the resulting sample."""
        # The constructor default is None, which cannot be iterated; treat it as an empty pipeline.
        if self.transforms is None:
            return sample

        for transform in self.transforms:
            sample = transform(sample)

        return sample


class ToTensor:
    """Turn channel-last numpy volumes of a sample into channel-first float tensors.

    In ``'train'`` mode both ``sample['input']`` and ``sample['target_mask']`` are
    converted; in ``'test'`` mode only ``sample['input']`` is touched.
    """

    def __init__(self, mode='train'):
        if mode not in ['train', 'test']:
            raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}")
        self.mode = mode

    def __call__(self, sample):
        """Move axis 3 to the front of each volume and convert it to a float32 tensor."""
        keys = ('input', 'target_mask') if self.mode == 'train' else ('input',)

        # Fetch every array before converting so a missing key fails before any work is done.
        arrays = [sample[key] for key in keys]
        tensors = [
            torch.from_numpy(np.transpose(array, axes=[3, 0, 1, 2])).float()
            for array in arrays
        ]

        # Write back only once every conversion has succeeded.
        for key, tensor in zip(keys, tensors):
            sample[key] = tensor

        return sample


class Mirroring:
    """Randomly flip image and mask along a random subset of the first three axes.

    With probability ``p`` between zero and three of the axes 0, 1, 2 are chosen
    and both arrays are flipped along them. Constructing the transform reseeds
    Python's global ``random`` module.
    """

    def __init__(self, p=0.5):
        self.p = p
        random.seed(260520)

    def __call__(self, sample):
        """Flip ``sample['input']`` and ``sample['target_mask']`` together, or return the sample as is."""
        if not random.random() < self.p:
            return sample

        volume = sample['input']
        segmentation = sample['target_mask']

        how_many = random.randint(0, 3)
        flip_axes = tuple(random.sample(range(3), how_many))

        # np.flip returns views; store copies so the sample does not alias the originals.
        flipped_volume = np.flip(volume, axis=flip_axes)
        flipped_segmentation = np.flip(segmentation, axis=flip_axes)
        sample['input'] = flipped_volume.copy()
        sample['target_mask'] = flipped_segmentation.copy()

        return sample


class NormalizeIntensity:
    """Normalise the first two channels of ``sample['input']`` in place.

    Channel 0 goes through ``normalize_ct`` and channel 1 through ``normalize_pt``.
    The input array is modified in place and stored back into the sample.
    """

    def __call__(self, sample):
        volume = sample['input']
        normalisers = (self.normalize_ct, self.normalize_pt)
        for channel, normalise in enumerate(normalisers):
            volume[:, :, :, channel] = normalise(volume[:, :, :, channel])

        sample['input'] = volume
        return sample

    @staticmethod
    def normalize_ct(img):
        """Clip to [-1024, 1024] and divide by 1024."""
        clipped = np.clip(img, -1024, 1024)
        return clipped / 1024

    @staticmethod
    def normalize_pt(img):
        """Subtract the mean and divide by (standard deviation + 1e-3)."""
        centred = img - np.mean(img)
        spread = np.std(img)
        return centred / (spread + 1e-3)


class RandomRotation:
    """Randomly rotate image and mask around one to three of the volume axes.

    Parameters
    ----------
    p : float
        Probability of applying the rotation.
    angle_range : sequence of two ints
        Passed to ``random.randrange(*angle_range)``, so the upper bound is
        exclusive. The sign of the drawn angle is then chosen at random.

    Constructing the transform reseeds Python's global ``random`` module.
    """

    # The default is a tuple rather than a list so it cannot be mutated between instances.
    def __init__(self, p=0.5, angle_range=(5, 15)):
        self.p = p
        self.angle_range = angle_range
        random.seed(260520)

    def __call__(self, sample):
        """Rotate ``sample['input']`` and ``sample['target_mask']`` in place with probability ``p``."""
        if random.random() < self.p:
            img, mask = sample['input'], sample['target_mask']

            num_of_seqs = img.shape[-1]
            n_axes = random.randint(1, 3)
            random_axes = random.sample([0, 1, 2], n_axes)

            for axis in random_axes:

                angle = random.randrange(*self.angle_range)
                angle = -angle if random.random() < 0.5 else angle

                # Every image channel gets the same rotation, with interpolation order 1.
                for i in range(num_of_seqs):
                    img[:, :, :, i] = RandomRotation.rotate_3d_along_axis(img[:, :, :, i], angle, axis, 1)

                # Only mask channel 0 is rotated, with order 0 so labels are not blended.
                mask[:, :, :, 0] = RandomRotation.rotate_3d_along_axis(mask[:, :, :, 0], angle, axis, 0)

            sample['input'], sample['target_mask'] = img, mask
        return sample

    @staticmethod
    def rotate_3d_along_axis(img, angle, axis, order):
        """Rotate a 3D array with ``skimage.transform.rotate`` after permuting its axes.

        For ``axis`` 1 and 2 the array is transposed before the call and
        transposed back afterwards; for ``axis`` 0 it is passed through directly.

        Raises
        ------
        ValueError
            If ``axis`` is not 0, 1 or 2.
        """
        if axis == 0:
            return rotate(img, angle, order=order, preserve_range=True)

        if axis == 1:
            rot_img = np.transpose(img, axes=(1, 2, 0))
            rot_img = rotate(rot_img, angle, order=order, preserve_range=True)
            return np.transpose(rot_img, axes=(2, 0, 1))

        if axis == 2:
            rot_img = np.transpose(img, axes=(2, 0, 1))
            rot_img = rotate(rot_img, angle, order=order, preserve_range=True)
            return np.transpose(rot_img, axes=(1, 2, 0))

        # Previously this fell through to an unbound local variable.
        raise ValueError(f"Argument 'axis' must be 0, 1 or 2. Received {axis}")


class ZeroPadding:
    """Zero-pad and/or crop a channel-last sample to a fixed spatial shape.

    Padding and cropping are both applied at the trailing end of each of the
    first three axes; the last (channel) axis is left untouched.

    Parameters
    ----------
    target_shape : sequence of three ints
        Desired spatial shape, without the channel dimension.
    mode : {'train', 'test'}
        In 'train' mode ``sample['target_mask']`` is resized together with
        ``sample['input']``; in 'test' mode only the input is resized.
    """

    def __init__(self, target_shape, mode='train'):
        self.target_shape = np.array(target_shape)  # without channel dimension
        if mode not in ['train', 'test']:
            raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}")
        self.mode = mode

    @staticmethod
    def _pad_and_crop(volume, deltas):
        """Return a copy of ``volume`` grown (zeros) or shrunk by ``deltas`` along axes 0-2."""
        pad_after = [d if d > 0 else 0 for d in deltas]
        # A negative delta doubles as the slice stop that trims the excess; None keeps the axis whole.
        crop_stop = [d if d < 0 else None for d in deltas]

        volume = np.pad(
            volume,
            ((0, pad_after[0]), (0, pad_after[1]), (0, pad_after[2]), (0, 0)),
            'constant',
            constant_values=(0, 0),
        )
        return volume[: crop_stop[0], : crop_stop[1], : crop_stop[2], :].copy()

    def __call__(self, sample):
        """Resize the sample's arrays to ``target_shape``; a sample that already fits is returned untouched."""
        is_train = self.mode == 'train'

        img = sample['input']
        mask = sample['target_mask'] if is_train else None

        input_shape = np.array(img.shape[:-1])  # last (channel) dimension is ignored
        d_x, d_y, d_z = self.target_shape - input_shape
        deltas = (int(d_x), int(d_y), int(d_z))

        if all(d == 0 for d in deltas):
            return sample

        img = self._pad_and_crop(img, deltas)

        if is_train:
            # The mask is resized with the deltas computed from the image.
            mask = self._pad_and_crop(mask, deltas)
            assert img.shape[:-1] == mask.shape[:-1], f'Shape mismatch for the image {img.shape[:-1]} and mask {mask.shape[:-1]}'
            sample['input'], sample['target_mask'] = img, mask
        else:
            sample['input'] = img

        return sample


class ExtractPatch:
    """Extract a patch of a given size from an image (4D numpy array, channel last).

    Parameters
    ----------
    patch_size : sequence of three ints
        Spatial size of the patch, without the channel dimension.
    p_tumor : float
        Probability of placing the patch near the centre of the bounding box of
        the non-zero mask voxels instead of at a uniformly random position.
    """

    def __init__(self, patch_size, p_tumor=0.5):
        self.patch_size = patch_size  # without channel dimension!
        self.p_tumor = p_tumor  # probs to extract a patch with a tumor

    @staticmethod
    def _origin_near(coords, size, extent):
        """Draw a patch origin along one axis around the middle of ``coords`` and clamp it into the image."""
        low = int(np.min(coords))
        high = int(np.max(coords))
        center = low + (high - low) // 2

        origin = random.randint(center - size, center)
        # Clamp so the whole patch lies inside [0, extent).
        return min(max(origin, 0), extent - size)

    def __call__(self, sample):
        """Replace ``sample['input']`` and ``sample['target_mask']`` with copies of the same patch."""
        img = sample['input']
        mask = sample['target_mask']

        assert all(x <= y for x, y in zip(self.patch_size, img.shape[:-1])), \
            f"Cannot extract the patch with the shape {self.patch_size} from  " \
                f"the image with the shape {img.shape}."

        # patch_size components:
        ps_x, ps_y, ps_z = self.patch_size
        sizes = (ps_x, ps_y, ps_z)
        extents = img.shape[:3]

        origin = None
        if random.random() < self.p_tumor:
            xs, ys, zs, _ = np.where(mask != 0)
            # An empty mask has no bounding box (np.min would raise); fall back to a random patch.
            if xs.size > 0:
                origin = [
                    self._origin_near(coords, size, extent)
                    for coords, size, extent in zip((xs, ys, zs), sizes, extents)
                ]

        if origin is None:
            origin = [random.randint(0, extent - size) for size, extent in zip(sizes, extents)]

        org_x, org_y, org_z = origin

        # extract the patch:
        patch_img = img[org_x: org_x + ps_x,
                        org_y: org_y + ps_y,
                        org_z: org_z + ps_z,
                        :].copy()

        patch_mask = mask[org_x: org_x + ps_x,
                          org_y: org_y + ps_y,
                          org_z: org_z + ps_z,
                          :].copy()

        # Compare as tuples so a list-valued patch_size does not fail a correct extraction.
        assert tuple(patch_img.shape[:-1]) == tuple(self.patch_size), \
            f"Shape mismatch for the patch with the shape {patch_img.shape[:-1]}, " \
                f"whereas the required shape is {self.patch_size}."

        sample['input'] = patch_img
        sample['target_mask'] = patch_mask

        return sample


class InverseToTensor:
    """Convert the tensor in ``sample['output']`` into a numpy array with all size-1 dimensions removed."""

    def __call__(self, sample):
        output = sample['output']

        output = torch.squeeze(output)  # squeeze the batch and channel dimensions
        # Tensor.numpy() refuses tensors that require grad or live on another device;
        # detach and move to the CPU first so raw model outputs are accepted as well.
        output = output.detach().cpu().numpy()

        sample['output'] = output
        return sample


class CheckOutputShape:
    """Assert that ``sample['output']`` has the expected shape and pass the sample through.

    Parameters
    ----------
    shape : sequence of ints
        Expected shape of the output; a tuple or a list is accepted.
    """

    def __init__(self, shape=(144, 144, 144)):
        self.shape = shape

    def __call__(self, sample):
        output = sample['output']
        # Normalise both sides to tuples: a list-valued `shape` never equals an array's tuple shape.
        assert tuple(output.shape) == tuple(self.shape), \
            f'Received wrong output shape. Must be {self.shape}, but received {output.shape}.'
        return sample


class ProbsToLabels:
    """Binarise ``sample['output']``: values strictly above 0.5 become 1, everything else 0."""

    def __call__(self, sample):
        probabilities = sample['output']
        is_foreground = probabilities > 0.5
        sample['output'] = is_foreground.astype(int)  # integer 0/1 labels
        return sample

class AdjustContrast(Transform):
    """
    Gamma-adjust channel 1 of ``sample['input']`` (named ``pet_img`` in the code). Each voxel is updated as::
        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
    Channel 0 is passed through unchanged and the result is re-stacked as a
    two-channel, channel-last array.
    Args:
        gamma: gamma value used when ``random`` is False.
        random: if True (default), a fresh gamma is drawn from ``np.random.uniform(0.5, 2.0)``
            on every call and ``gamma`` is ignored. The draw uses NumPy's global
            generator, not Python's ``random`` module.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, gamma: float, random=True) -> None:
        if not isinstance(gamma, (int, float)):
            raise ValueError("gamma must be a float or int number.")
        self.gamma = gamma
        self.random = random

    def __call__(self, sample):
        """
        Apply the transform to ``sample['input']`` and return the sample.
        """
        # Use a local gamma so the configured value is not overwritten by each random draw.
        gamma = np.random.uniform(0.5, 2.0) if self.random else self.gamma

        # Only the image is needed; the mask is not modified by an intensity transform.
        images = sample['input']
        ct_img = images[:, :, :, 0]
        pet_img = images[:, :, :, 1]

        epsilon = 1e-7  # avoids division by zero on a constant image
        img_min = pet_img.min()
        img_range = pet_img.max() - img_min

        adjusted: NdarrayOrTensor = ((pet_img - img_min) / float(img_range + epsilon)) ** gamma * img_range + img_min
        sample['input'] = np.stack([ct_img, adjusted], axis=-1)

        return sample
class AdjustContrastCT(Transform):
    """
    With probability ``p``, gamma-adjust channel 0 of ``sample['input']`` (named ``ct_img`` in the code).
    Each voxel is updated as::
        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
    Channel 1 is passed through unchanged and the result is re-stacked as a
    two-channel, channel-last array.
    Args:
        gamma: gamma value used when ``random`` is False.
        p: probability of applying the adjustment (drawn with Python's ``random`` module).
        random: if True (default), a fresh gamma is drawn from ``np.random.uniform(0.5, 2.0)``
            each time the transform is applied and ``gamma`` is ignored.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, gamma: float, p=0.5, random=True) -> None:
        if not isinstance(gamma, (int, float)):
            raise ValueError("gamma must be a float or int number.")
        self.p = p
        self.gamma = gamma
        self.random = random

    def __call__(self, sample):
        """
        Apply the transform to ``sample['input']`` with probability ``p`` and return the sample.
        """
        if not random.random() < self.p:
            return sample

        # Use a local gamma so the configured value is not overwritten by each random draw.
        gamma = np.random.uniform(0.5, 2.0) if self.random else self.gamma

        # Only the image is needed; the mask is not modified by an intensity transform.
        images = sample['input']
        ct_img = images[:, :, :, 0]
        pet_img = images[:, :, :, 1]

        epsilon = 1e-7  # avoids division by zero on a constant image
        img_min = ct_img.min()
        img_range = ct_img.max() - img_min

        adjusted: NdarrayOrTensor = ((ct_img - img_min) / float(img_range + epsilon)) ** gamma * img_range + img_min
        sample['input'] = np.stack([adjusted, pet_img], axis=-1)

        return sample

class Zoom(Transform):
    """Rescale both image channels and the mask of a sample with the module-level ``zoom`` helper.

    The image is read as two channel-last volumes (indices 0 and 1 of the last
    axis); the mask's trailing singleton axis is removed before zooming and
    restored afterwards. The zoomed mask is re-binarised at 0.5.

    NOTE(review): each volume is handed to ``zoom`` as a 3D array, and ``zoom``
    treats the first axis of its input as the channel axis, so only the last
    two axes appear to be rescaled — confirm this is intended.
    """

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        volume, segmentation = sample['input'], sample['target_mask']
        first_channel = volume[:, :, :, 0]
        second_channel = volume[:, :, :, 1]
        segmentation = segmentation.squeeze(-1)

        resized_channels = [zoom(channel, self.factor) for channel in (first_channel, second_channel)]
        sample['input'] = np.stack(resized_channels, axis=-1)

        resized_mask = zoom(segmentation, self.factor)
        # Interpolation yields fractional values; threshold back to {0, 1} in place.
        resized_mask[resized_mask < 0.5] = 0
        resized_mask[resized_mask > 0] = 1
        sample['target_mask'] = np.expand_dims(resized_mask, axis=-1)

        return sample


def zoom(
    img,
    factor,
    padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
    align_corners: Optional[bool] = True,
    keep_size = True,

) -> NdarrayOrTensor:
    """
    Zoom an array with bilinear interpolation, optionally padding/cropping back to the input size.

    Args:
        img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
            The interpolation mode is fixed to ``"bilinear"``.
            NOTE(review): PyTorch's bilinear mode expects a 4D input, so after the
            batch axis is added ``img`` presumably has to be 3D (channel + two
            spatial axes) — confirm against the callers.
        factor: zoom factor; a scalar is repeated for every non-channel axis.
        padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            The mode to pad data after zooming.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        align_corners: forwarded to ``torch.nn.functional.interpolate``. Defaults to ``True``.
            See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
        keep_size: if True (default), the zoomed result is centre-padded or centre-cropped
            so that it has the same shape as ``img``.

    Returns:
        The zoomed data, converted back to the type of ``img``.
    """

    img_t: torch.Tensor
    img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)  # type: ignore
    _zoom = ensure_tuple_rep(factor, img.ndim - 1)  # match the spatial image dim
    zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(  # type: ignore
        recompute_scale_factor=True,
        input=img_t.unsqueeze(0),  # add the batch axis interpolate expects
        scale_factor=list(_zoom),
        mode=InterpolateMode.BILINEAR.value,
        align_corners=align_corners,
    )
    zoomed = zoomed.squeeze(0)

    if keep_size and not np.allclose(img_t.shape, zoomed.shape):

        pad_vec = [(0, 0)] * len(img_t.shape)
        slice_vec = [slice(None)] * len(img_t.shape)
        for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)):
            diff = od - zd
            half = abs(diff) // 2
            if diff > 0:  # need padding
                pad_vec[idx] = (half, diff - half)
            elif diff < 0:  # need slicing
                slice_vec[idx] = slice(half, half + od)

        # Fall back to the documented default instead of handing `None` to Pad.
        padder = Pad(pad_vec, padding_mode or "constant")
        zoomed = padder(zoomed)
        zoomed = zoomed[tuple(slice_vec)]

    out, *_ = convert_to_dst_type(zoomed, dst=img)
    return out



class ElasticDeformation():
    """Apply the same random elastic deformation to both image channels and the mask.

    With probability ``p`` the two channels of ``sample['input']`` and
    ``sample['target_mask']`` are deformed together by
    ``elasticdeform.deform_random_grid`` along axes 0-2, with ``sigma`` drawn
    from 5..10 and ``points`` from 1..3. The deformed mask is min-max
    normalised and re-binarised at 0.5.

    (This class was previously defined twice with identical bodies; only one
    definition is kept.)
    """

    def __init__(self, p = 0.5):
        self.p = p

    def __call__(self, sample):
        images, mask = sample['input'], sample['target_mask']
        ct_img = images[:, :, :, 0]
        pet_img = images[:, :, :, 1]

        if random.random() < self.p:
            new_ct, new_pet, new_mask = elasticdeform.deform_random_grid(
                [ct_img, pet_img, mask],
                sigma=random.randint(5, 10),
                points=random.randint(1, 3),
                axis=(0, 1, 2),
            )

            # Min-max normalise before thresholding. A constant mask (e.g. no tumour)
            # has zero range; skip the division there instead of producing NaNs.
            mask_min = np.min(new_mask)
            mask_range = np.max(new_mask) - mask_min
            if mask_range > 0:
                new_mask = (new_mask - mask_min) / mask_range

            new_mask[new_mask < 0.5] = 0
            new_mask[new_mask > 0] = 1

            img = np.stack([new_ct, new_pet], axis=-1)
            sample['input'], sample['target_mask'] = img, new_mask

        return sample

#DATASET

In [None]:
def find_centroid(mask: sitk.Image):
    """Return the centroid of label 1 in `mask` as a float64 array of image index coordinates."""
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(mask)

    # The filter reports the centroid in physical space; map it back to voxel indices.
    physical_centroid = shape_stats.GetCentroid(1)
    index_centroid = mask.TransformPhysicalPointToIndex(physical_centroid)

    return np.asarray(index_centroid, dtype=np.float64)

def get_paths_to_patient_files(path_to_imgs, PatientID, append_mask=True):
    """
    Build the paths to the CT & PET images (and optionally the mask) of each patient.

    Parameters
    ----------
    path_to_imgs : str
        A directory with one sub-folder per patient.
    PatientID : iterable of str
        Patient identifiers to look up. Identifiers without a matching
        sub-folder in `path_to_imgs` are skipped.
    append_mask : bool
        Whether to append the path to the ground truth mask.

    Returns
    -------
    list of tuple
        One tuple of `pathlib.Path` objects per patient found, in the order of
        `PatientID`: ``<id>_ct.nii.gz``, ``<id>_pt.nii.gz`` and, if `append_mask`
        is True, ``<id>_ct_gtvt.nii.gz``, all inside the patient's folder.
    """
    root = pathlib.Path(path_to_imgs)

    suffixes = ['_ct.nii.gz', '_pt.nii.gz']
    if append_mask:
        suffixes.append('_ct_gtvt.nii.gz')

    return [
        tuple(root / patient / (patient + suffix) for suffix in suffixes)
        for patient in PatientID
        if os.path.isdir(root / patient)
    ]

class HecktorDataset(Dataset):
    """PET/CT dataset that pairs cached image volumes with clinical data and survival targets.

    Parameters
    ----------
    root_directory
        Root folder of the raw data; only used by the preprocessing helpers.
    clinical_data_path
        Either a directory holding ``hecktor2021_patient_info_training.csv`` and
        ``hecktor2021_patient_endpoint_training.csv``, or an already merged
        `pandas.DataFrame`.
    patch_size
        Edge length used by `_preprocess_subject` when cropping around the tumour.
    time_bins
        Number of bins passed to `make_time_bins`.
    cache_dir
        Directory with one sub-folder per patient containing the cached NIfTI files.
    transform
        Optional callable applied to the sample dict in `__getitem__`.
    num_workers
        Number of parallel jobs used by `_prepare_data`.
    """

    def __init__(self,
                 root_directory:str,
                 clinical_data_path: Union[str, pd.DataFrame],
                 patch_size:int =50,
                 time_bins:int = 14,
                 cache_dir:str = "data_cropped/data_cache/",
                 transform: Optional[Callable] = None,
                 num_workers: int = 1
    ):

        self.num_of_seqs = 2 #CT PT

        self.root_directory = root_directory
        self.patch_size = patch_size
        # Kept so that preprocessing can write into the cache directory and error
        # messages can name it (cache_path below is a list of path tuples, not a directory).
        self.cache_dir = cache_dir

        self.transforms = transform
        self.num_workers = num_workers

        self.clinical_data = self.make_data(clinical_data_path)
        self.time_bins = make_time_bins(times=self.clinical_data["time"], num_bins=time_bins, event = self.clinical_data["event"])
        self.y = encode_survival(self.clinical_data["time"].values, self.clinical_data["event"].values, self.time_bins) # single event

        self.cache_path = get_paths_to_patient_files(cache_dir, self.clinical_data['PatientID'])
        # get_paths_to_patient_files skips patients without a folder, so positions in
        # cache_path need not line up with rows of clinical_data. Index the paths by
        # patient folder name so that __getitem__ always pairs the right images and labels.
        self._paths_by_id = {paths[0].parent.name: paths for paths in self.cache_path}


    def make_data(self, path):
        """Load (or accept) the clinical table and encode it for the model.

        `path` is either a directory containing the two HECKTOR CSV files, which
        are merged on ``PatientID``, or a DataFrame that is used as is. The
        outcome columns are renamed to ``event`` / ``time``, ``Age`` is
        standardised with `scale`, the stage columns are grouped, categorical
        columns are one-hot encoded and unused columns are dropped.
        """

        if isinstance(path, pd.DataFrame):
            df = path
        else:
            # Let read errors propagate: a missing file should not be silently treated as a DataFrame.
            X = pd.read_csv(os.path.join(path, 'hecktor2021_patient_info_training.csv'))
            y = pd.read_csv(os.path.join(path, 'hecktor2021_patient_endpoint_training.csv'))
            df = pd.merge(X, y, on="PatientID")

        clinical_data = df
        clinical_data = clinical_data.rename(columns={"Progression": "event", "Progression free survival": "time", "TNM group":"Stage_group", "Gender (1=M,0=F)":"Gender"})

        clinical_data["Age"] = scale(clinical_data["Age"])

        # group T stage as T1/2, Tx (kept as its own category) and T3/4 (everything else)
        clinical_data["T-stage"] = clinical_data["T-stage"].map(
            lambda x: "T1/2" if x in ["T1", "T2"] else("Tx" if x == "Tx" else "T3/4"), na_action="ignore")

        # use more fine-grained grouping for N stage
        clinical_data["N-stage"] = clinical_data["N-stage"].str.slice(0, 2)

        clinical_data["Stage_group"] = clinical_data["Stage_group"].map(
            lambda x: "I/II" if x in ["I", "II"] else "III/IV", na_action="ignore")

        clinical_data = pd.get_dummies(clinical_data,
                                    columns=["Gender",
                                                "N-stage",
                                                "M-stage",],
                                    drop_first=True)

        cols_to_drop = [
            #"PatientID",
            "Tobacco",
            "Alcohol",
            "Performance status",
            "HPV status (0=-, 1=+)",
            "Estimated weight (kg) for SUV",
            "CenterID",

        ]

        clinical_data = clinical_data.drop(cols_to_drop, axis=1)


        clinical_data = pd.get_dummies(clinical_data,
                                    columns=["T-stage",
                                                "Stage_group",])

        return clinical_data


    def _prepare_data(self):
        """Preprocess and cache the dataset"""

        Parallel(n_jobs=self.num_workers)(
            delayed(self._preprocess_subject)(subject_id)
            for subject_id in self.clinical_data["PatientID"]
        )

    def _preprocess_subject(self, subject_id: str):
        """Crop the CT of one subject around the tumour centroid, resample, clamp and cache it.

        NOTE(review): the result is written to ``<cache_dir>/<subject_id>.nii``,
        whereas `__getitem__` reads ``<cache_dir>/<id>/<id>_ct.nii.gz`` — the two
        layouts look unrelated; confirm which one is current.
        """

        print(self.root_directory)
        print(subject_id)

        path = os.path.join(self.root_directory, "data/hecktor_nii/"
                            "{}",f"{subject_id}"+"{}"+".nii")

        image = sitk.ReadImage(path.format("images", "_ct"))
        mask = sitk.ReadImage(path.format("masks", "_gtvt"))

        #crop the image to (patch_size)^3 patch around the tumor center
        tumour_center = find_centroid(mask)
        # np.int was removed from NumPy; use the explicit 64-bit type like the lines below.
        size = np.ceil(self.patch_size / np.asarray(image.GetSpacing())).astype(np.int64) + 1
        min_coords = np.floor(tumour_center - size / 2).astype(np.int64)
        max_coords = np.floor(tumour_center + size / 2).astype(np.int64)
        min_x, min_y, min_z = min_coords
        max_x, max_y, max_z = max_coords
        image = image[min_x:max_x, min_y:max_y, min_z:max_z]

        # resample to isotropic 1 mm spacing
        reference_image = sitk.Image([self.patch_size]*3, sitk.sitkFloat32)
        reference_image.SetOrigin(image.GetOrigin())
        image = sitk.Resample(image, reference_image)

        # window image intensities to [-500, 1000] HU range
        image = sitk.Clamp(image, sitk.sitkFloat32, -500, 1000)

        # Write into the cache directory (self.cache_path is a list of path tuples).
        sitk.WriteImage(image, os.path.join(self.cache_dir, f"{subject_id}.nii"), True)


    def __getitem__(self, idx: int) -> Tuple[Tuple[dict, np.ndarray], Any, dict]:
        """Get an input-target pair from the dataset.

        The images are assumed to be preprocessed and cached.

        Parameters
        ----------
        idx
            The index to retrieve (note: this is not the subject ID).

        Returns
        -------
        tuple
            ``((sample, clin_var), target, labels)`` where ``sample`` is a dict
            with keys ``'id'``, ``'input'`` (channel-last stack of the image
            sequences) and ``'target_mask'``; ``clin_var`` is the float32 vector
            of clinical variables; ``target`` is ``self.y[idx]``; ``labels`` is
            the full clinical row as a dict.

        Raises
        ------
        FileNotFoundError
            If no cached images exist for the patient at `idx`.
        """

        try:      # training data
            clin_var_data = self.clinical_data.drop(columns=['PatientID', 'time', 'event'])
        except KeyError:   # test data without outcome columns
            clin_var_data = self.clinical_data.drop(columns=['PatientID'])

        clin_var = clin_var_data.iloc[idx].to_numpy(dtype='float32')

        target = self.y[idx]

        row = self.clinical_data.iloc[idx]
        labels = row.to_dict()

        # Look the images up by patient ID, not by position, so that a missing
        # patient folder cannot shift every later image against its labels.
        subject_id = row["PatientID"]
        try:
            paths = self._paths_by_id[subject_id]
        except KeyError:
            raise FileNotFoundError(
                f"No cached images found for patient {subject_id!r} in {self.cache_dir!r}."
            ) from None

        sample = dict()
        sample['id'] = paths[0].parent.stem

        img = [self.read_data(image_path) for image_path in paths[:self.num_of_seqs]]
        sample['input'] = np.stack(img, axis=-1)

        # The mask is the last path of the tuple; give it a trailing channel axis.
        mask = self.read_data(paths[-1])
        sample['target_mask'] = np.expand_dims(mask, axis=3)

        if self.transforms:
            sample = self.transforms(sample)

        return (sample, clin_var), target, labels


    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.clinical_data)

    @staticmethod
    def read_data(path_to_nifti, return_numpy=True):
        """Read a NIfTI image. Return a numpy array (default) or `nibabel.nifti1.Nifti1Image` object"""
        if return_numpy:
            return nib.load(str(path_to_nifti)).get_fdata()
        return nib.load(str(path_to_nifti))


#MODEL

In [None]:
# Number of clinical (tabular) variables — presumably the width of the clinical
# vector produced by the dataset; TODO confirm. In the visible code it is only
# referenced by a commented-out line in UNETR.forward.
n_clin_var = 12


def flatten_layers(arr):
    """Flatten one level of nesting: concatenate the items of every sub-iterable into a new list."""
    flat = []
    for group in arr:
        flat.extend(group)
    return flat


class UNETR(nn.Module):
    """
    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    In addition to the segmentation decoder, this variant feeds the mean of the
    transformer output tokens into an MTLR head and returns a risk prediction.

    Args:
        hparams: configuration of the MTLR head. Keys read here: ``'n_dense'``,
            ``'hidden_size'``, ``'time_bins'`` and, when ``n_dense > 0``,
            ``'dense_factor'`` and ``'dropout'``.
            NOTE(review): the head's input width comes from ``hparams['hidden_size']``
            while the transformer uses the ``hidden_size`` argument — they presumably
            have to be equal.
        in_channels: number of input image channels.
        out_channels: number of segmentation output channels.
        img_size: spatial size of the input image.
        feature_size: base number of feature channels of the convolutional blocks.
        hidden_size: width of the transformer tokens.
        mlp_dim: width of the transformer feed-forward layers.
        num_heads: number of attention heads.
        pos_embed: position embedding type forwarded to ``ViT``.
        norm_name: normalisation used by the convolutional blocks.
        conv_block: forwarded to the ``UnetrPrUpBlock`` encoders.
        res_block: forwarded to the encoder/decoder blocks.
        dropout_rate: dropout rate forwarded to ``ViT``.
        spatial_dims: number of spatial dimensions.

    Raises:
        ValueError: if ``dropout_rate`` is outside [0, 1] or ``hidden_size`` is not
            divisible by ``num_heads``.
    """

    def __init__(
                    self,
                    hparams : dict,
                    in_channels: int,
                    out_channels: int,
                    img_size: Union[Sequence[int], int],
                    feature_size: int = 16,
                    hidden_size: int = 768,
                    mlp_dim: int = 3072,
                    num_heads: int = 12,
                    pos_embed: str = "conv",
                    norm_name: Union[Tuple, str] = "instance",
                    conv_block: bool = True,
                    res_block: bool = True,
                    dropout_rate: float = 0.0,
                    spatial_dims: int = 3,
                ) -> None:


        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.num_layers = 12
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.patch_size = ensure_tuple_rep(16, spatial_dims)
        # Number of patches along each spatial axis.
        self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))
        self.hidden_size = hidden_size
        self.classification = False
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

        if hparams['n_dense'] <= 0:
            self.mtlr = MTLR(hparams['hidden_size'], hparams['time_bins'])

        else:
            # First dense block is 256 * dense_factor wide, every further one 64 * dense_factor.
            # Each layer's input width follows the previous layer's output, so the head is
            # consistent for any n_dense (it used to match only for n_dense == 2, where the
            # resulting architecture is unchanged).
            widths = [256 * hparams['dense_factor']]
            widths += [64 * hparams['dense_factor']] * (hparams['n_dense'] - 1)

            fc_layers = []
            in_features = hparams['hidden_size']
            for width in widths:
                fc_layers.extend([nn.Linear(in_features, width),
                                  nn.BatchNorm1d(width),
                                  nn.ReLU(inplace=True),
                                  nn.Dropout(hparams['dropout'])])
                in_features = width

            self.mtlr = nn.Sequential(*fc_layers,
                                      MTLR(in_features, hparams['time_bins']),)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a token sequence ``(B, N, hidden)`` into a channel-first feature map ``(B, hidden, *feat_size)``."""
        new_view = (x.size(0), *feat_size, hidden_size)
        x = x.view(new_view)
        # Move the hidden dimension from last position to position 1.
        new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size)))
        x = x.permute(new_axes).contiguous()
        return x

    def forward(self, sample):
        """Run segmentation and risk prediction.

        Args:
            sample: tuple ``(sample_img, clin_var)`` where ``sample_img['input']`` is the
                image tensor and ``clin_var`` the clinical variables; both are handed
                to the transformer together.

        Returns:
            Tuple ``(segmentation_output, risk_out)``.
        """

        sample_img, clin_var = sample
        x_in = (sample_img['input'], clin_var)

        x, hidden_states_out = self.vit(x_in)

        enc1 = self.encoder1(sample_img['input'])
        # Skip connections come from transformer layers 3, 6 and 9. The first token is
        # dropped before reshaping — presumably the extra (clinical) token added by the
        # custom ViT; TODO confirm.
        x2 = hidden_states_out[3][:,1:,:]
        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
        x3 = hidden_states_out[6][:,1:,:]
        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
        x4 = hidden_states_out[9][:,1:,:]
        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
        dec4 = self.proj_feat(x[:,1:,:], self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)

        # Average over all tokens (including the first one) to get one feature vector per sample.
        x = torch.mean(x, dim=1)

        risk_out = self.mtlr(x)

        return self.out(out), risk_out





# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Sequence, Union

import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock

# __all__ = ["ViT"]


class ViT(nn.Module):
    """
    Vision Transformer (ViT) encoder, after "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    The input is embedded by a ``PatchEmbeddingBlock``, passed through a stack
    of ``TransformerBlock`` layers and layer-normalised. The output of every
    transformer layer is returned alongside the final result.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_layers: number of transformer blocks.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            classification: bool argument to determine if classification is used.
            num_classes: number of classes if classification is used.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.

        Raises:
            ValueError: if ``dropout_rate`` is outside ``[0, 1]`` or if
                ``hidden_size`` is not a multiple of ``num_heads``.

        Examples::

            # single channel input, image size (96,96,96), conv position embedding, segmentation backbone
            >>> net = ViT(in_channels=1, img_size=(96,96,96), patch_size=16, pos_embed='conv')

            # 3-channel input, image size (128,128,128), classification backbone
            >>> net = ViT(in_channels=3, img_size=(128,128,128), patch_size=16, pos_embed='conv', classification=True)

            # 3-channel 2D input, image size (224,224), classification backbone
            >>> net = ViT(in_channels=3, img_size=(224,224), patch_size=16, pos_embed='conv', classification=True, spatial_dims=2)

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.classification = classification

        embedding_config = dict(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )
        self.patch_embedding = PatchEmbeddingBlock(**embedding_config)

        encoder_layers = [
            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = nn.LayerNorm(hidden_size)

        if self.classification:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
            self.classification_head = nn.Sequential(
                nn.Linear(hidden_size, num_classes),
                nn.Tanh(),
            )

        # Projection of EHR. NOTE: this layer is registered here but is not
        # called in ``forward`` below.
        self.EHR_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """
        Embed ``x`` and run it through every transformer block.

        Args:
            x: passed unchanged to ``self.patch_embedding``
                (presumably an ``(img, clin_var)`` pair -- TODO confirm
                against the patch embedding block in use).

        Returns:
            ``(out, hidden_states_out)``. ``hidden_states_out`` holds the output
            of each transformer block in order. ``out`` is the layer-normalised
            final sequence, or, when ``classification`` is enabled, the
            classification head applied to the token at index 0.
        """
        tokens = self.patch_embedding(x)

        if self.classification:
            # Prepend one class token per batch element.
            batch_size = tokens.shape[0]
            class_tokens = self.cls_token.expand(batch_size, -1, -1)
            tokens = torch.cat((class_tokens, tokens), dim=1)

        hidden_states_out = []
        for block in self.blocks:
            tokens = block(tokens)
            hidden_states_out.append(tokens)

        tokens = self.norm(tokens)
        if not self.classification:
            return tokens, hidden_states_out

        return self.classification_head(tokens[:, 0]), hidden_states_out
    
    
    
    # Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Sequence, Union

import numpy as np
import torch
import torch.nn as nn

from monai.networks.layers import Conv
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

# `Rearrange` is obtained through `optional_import`; the second element of the
# returned pair is discarded. It is only used by the "perceptron" embedding.
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
# Values accepted for the `pos_embed` argument of `PatchEmbeddingBlock`.
SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}


class PatchEmbeddingBlock(nn.Module):
    """
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    In addition to the image patches, a vector of clinical (EHR) variables is
    linearly projected to ``hidden_size`` and prepended to the patch sequence
    as one extra token, so the output sequence has ``n_patches`` tokens where
    ``n_patches`` already counts that extra token.

    Example::

        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
        ...                     pos_embed="conv", ehr_dim=10)

    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int,
        num_heads: int,
        pos_embed: str,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        ehr_dim: Union[int, None] = None,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type ("conv" or "perceptron").
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.
            ehr_dim: number of input features of the clinical-variable
                projection. When ``None`` (default) the module-level name
                ``n_clin_var`` is used, which was the only behaviour before
                this parameter existed.

        Raises:
            ValueError: if ``dropout_rate`` is outside ``[0, 1]``, if
                ``hidden_size`` is not a multiple of ``num_heads``, if a patch
                dimension exceeds the matching image dimension, or if, for
                "perceptron", an image dimension is not a multiple of the
                matching patch dimension.
            NameError: if ``ehr_dim`` is ``None`` and no module-level
                ``n_clin_var`` is defined.
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        for m, p in zip(img_size, patch_size):
            if m < p:
                raise ValueError("patch_size should be smaller than img_size.")
            if self.pos_embed == "perceptron" and m % p != 0:
                # The check is img_size % patch_size, so word the message that way round.
                raise ValueError("img_size should be divisible by patch_size for perceptron.")

        # Stored as plain Python ints rather than numpy scalars.
        grid = [im_d // p_d for im_d, p_d in zip(img_size, patch_size)]
        self.n_patches = int(np.prod(grid)) + 1  # +1 for the EHR token
        self.patch_dim = int(in_channels * np.prod(patch_size))

        self.patch_embeddings: nn.Module
        if self.pos_embed == "conv":
            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
            )
        elif self.pos_embed == "perceptron":
            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
            chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
            from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
            to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
            axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
            self.patch_embeddings = nn.Sequential(
                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
            )

        if ehr_dim is None:
            # Backward-compatible fallback to the module-level global that
            # used to be the only source of this value.
            ehr_dim = n_clin_var
        self.EHR_proj = nn.Sequential(nn.Linear(ehr_dim, hidden_size))

        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
        # NOTE: cls_token is registered but never read in ``forward``; it is
        # kept so the set of parameters stays unchanged.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)
        self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; used as the callback for ``self.apply``.

        ``nn.Linear`` weights get a truncated normal (std 0.02, clamped to
        [-2, 2]) and a zero bias; ``nn.LayerNorm`` gets weight 1 and bias 0.
        Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def trunc_normal_(self, tensor, mean, std, a, b):
        """Fill ``tensor`` in place from a normal(mean, std) truncated to [a, b].

        Args:
            tensor: tensor to overwrite.
            mean: mean of the underlying normal distribution.
            std: standard deviation of the underlying normal distribution.
            a: lower bound; results are clamped to be >= ``a``.
            b: upper bound; results are clamped to be <= ``b``.

        Returns:
            The same ``tensor`` object, modified in place.
        """
        # From PyTorch official master until it's in a few official releases - RW
        # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
        def norm_cdf(x):
            return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

        with torch.no_grad():
            lower_cdf = norm_cdf((a - mean) / std)
            upper_cdf = norm_cdf((b - mean) / std)
            # Sample uniformly in the erf domain, then invert to get normal values.
            tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
            tensor.erfinv_()
            tensor.mul_(std * math.sqrt(2.0))
            tensor.add_(mean)
            tensor.clamp_(min=a, max=b)
            return tensor

    def forward(self, x):
        """Embed an image together with its clinical variables.

        Args:
            x: a pair ``(img, clin_var)``. ``img`` goes through the patch
                embedding; ``clin_var`` goes through ``EHR_proj``
                (assumes ``clin_var`` is ``(batch, ehr_dim)`` -- TODO confirm).

        Returns:
            The token sequence with the projected clinical token at index 0
            along dim 1, followed by the patch tokens, after adding the
            position embeddings and applying dropout.
        """
        img, clin_var = x
        x = self.patch_embeddings(img)

        clin_var = self.EHR_proj(clin_var)
        clin_var = clin_var.unsqueeze(dim=1)  # add a token axis at dim 1

        if self.pos_embed == "conv":
            # Collapse the spatial axes and move channels last.
            x = x.flatten(2).transpose(-1, -2)

        x = torch.cat([clin_var, x], dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


In [None]:
test_set =  