In [70]:
import torch
import torchio as tio
from torch.utils.data import Dataset
import numpy as np
import numpy.typing as npt
import os
import torch.nn as nn
import pydicom
import SimpleITK as sk
from pathlib import Path
from rt_utils import RTStructBuilder
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from functools import reduce

In [72]:
class DosePrdictionDataset(Dataset):
    """Dataset pairing CT + structure masks (input) with a dose volume (target).

    Each entry of ``input_dir`` is one patient folder. Somewhere beneath it the
    code looks for folders whose names end with "CT", "STRUCT" and "DOSE"; the
    DICOM files are read from the first subfolder (in sorted name order) of
    each of those.

    Args:
        input_dir: list of patient folder paths, one sample per entry.
        struct_types: ROI names to extract from the RTSTRUCT; each becomes one
            input channel (in this order), followed by the CT channel.
        if_transform: stored on the instance; not used by this class.
        out_dim: spatial shape every sample is cropped/padded to, in the
            SimpleITK array axis order (z, y, x).

    Raises:
        ValueError: if ``struct_types`` is empty (the crop needs a mask).
    """

    def __init__(self, input_dir: list, struct_types: list,
                 if_transform: bool = False, out_dim: tuple = (100, 200, 300)):
        if not struct_types:
            raise ValueError("struct_types must name at least one structure")
        self.if_transform = if_transform  # kept for API compatibility; unused here
        self.input_dir = input_dir
        self.struct_types = struct_types
        self.out_dim = tuple(out_dim)

    def __len__(self):
        return len(self.input_dir)

    def __getitem__(self, index: int):
        patient_dir = self.input_dir[index]  # folder path for one specific patient
        X, y = self.get_X_y(patient_dir)     # inputs and target for this patient
        return X, y

    ## FUNCTION TO CREATE INPUTS AND OUTPUT
    def get_X_y(self, patient_dir: str):
        """Build the (input, target) pair for one patient folder.

        Returns:
            X: float32 tensor of shape (len(struct_types) + 1, *out_dim) — one
               binary mask channel per structure, then the CT rescaled to [0, 1].
            y: int32 numpy array of shape (1, *out_dim) holding the dose volume.
        """
        # ====== Locate the modality folders of this patient ======
        input_dirs = self.get_subfolder_path(patient_dir)
        # A folder ending in "STRUCT" also ends in "CT", so exclude it explicitly.
        ct_dir = self._find_input_dir(input_dirs, "CT", exclude_suffix="STRUCT")
        struct_dir = self._find_input_dir(input_dirs, "STRUCT")
        dose_dir = self._find_input_dir(input_dirs, "DOSE")
        # The DICOM files sit one level below each modality folder.
        X_paths_CT = self._first_subfolder(ct_dir)
        X_paths_STRUCT = self._first_subfolder(struct_dir)
        y_paths = self._first_subfolder(dose_dir)

        # ====== Create input ======
        ct = self.extract_images(X_paths_CT)
        struct_contours = self.extract_struct(X_paths_CT, X_paths_STRUCT, self.struct_types)
        # Union of all requested structures; used to centre the crop.
        struct_mask = np.any(np.stack(struct_contours), axis=0).astype(np.int32)
        # Channels: structure masks first, CT volume last.
        X = torch.from_numpy(np.stack([*struct_contours, ct])).type(torch.float32)
        X = self.crop(X, struct_mask, self.out_dim)
        # Normalise only the CT channel: the masks are already 0/1, and a joint
        # min/max over CT values would squash them to near-identical values.
        X[-1] = self.rescale_intensity(X[-1])

        # ====== Create ground truth dose map ======
        # NOTE(review): y is centre-cropped/padded on its own grid rather than
        # with struct_mask like X — confirm the dose grid is aligned with the CT.
        # NOTE(review): the int32 cast truncates if the dose array is fractional.
        y = self.pad(self.extract_images(y_paths).astype("int32"), self.out_dim)
        return X, y

    @staticmethod
    def _find_input_dir(paths: list, suffix: str, exclude_suffix: str = None) -> str:
        """Return the first path ending with ``suffix`` (and not ``exclude_suffix``)."""
        for path in paths:
            if path.endswith(suffix) and not (exclude_suffix and path.endswith(exclude_suffix)):
                return path
        raise FileNotFoundError(f"No folder ending with {suffix!r} found among {paths!r}")

    def _first_subfolder(self, folder_path: str) -> str:
        """Return the first subfolder (sorted walk order) below ``folder_path``."""
        subfolders = self.get_subfolder_path(folder_path)
        if len(subfolders) < 2:
            raise FileNotFoundError(f"No subfolder found inside {folder_path!r}")
        return subfolders[1]

    ## FUNCTION TO EXTRACT CT AND DOSE IMAGES
    def extract_images(self, folder_path: str):
        """Read the first DICOM series in ``folder_path`` into a numpy array.

        The array comes from ``SimpleITK.GetArrayFromImage`` (z, y, x order).

        Raises:
            FileNotFoundError: if SimpleITK finds no DICOM series in the folder.
        """
        folder = str(Path(folder_path))
        series_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(folder)
        if not series_ids:
            raise FileNotFoundError(f"No DICOM series found in {folder!r}")
        file_names = sk.ImageSeriesReader.GetGDCMSeriesFileNames(folder, series_ids[0])
        series_reader = sk.ImageSeriesReader()
        series_reader.SetFileNames(file_names)
        return sk.GetArrayFromImage(series_reader.Execute())

    ## FUNCTION TO CREATE STRUCTURE CONTOURS
    def extract_struct(self, CT_folder_path: str, strcut_folder_path: str, struct_name: list):
        """Rasterise the named RTSTRUCT ROIs onto the CT series.

        Args:
            CT_folder_path: folder with the CT DICOM series.
            strcut_folder_path: folder whose first file (sorted by name) is the RTSTRUCT.
            struct_name: ROI names to extract, in output order.

        Returns:
            List of int32 arrays, one per ROI, with the last axis of the
            rt_utils mask moved to the front.

        Raises:
            FileNotFoundError: if ``strcut_folder_path`` contains no files.
        """
        struct_files = sorted(
            os.path.join(strcut_folder_path, name)
            for name in os.listdir(strcut_folder_path)
            if os.path.isfile(os.path.join(strcut_folder_path, name))
        )
        if not struct_files:
            raise FileNotFoundError(f"No RTSTRUCT file found in {strcut_folder_path!r}")
        rtstruct = RTStructBuilder.create_from(dicom_series_path=CT_folder_path,
                                               rt_struct_path=struct_files[0])
        masks = []
        for roi in struct_name:
            mask = rtstruct.get_roi_mask_by_name(roi).astype("int32")
            # Assumes rt_utils masks are (rows, cols, slices); move slices first to
            # match the CT array. NOTE(review): confirm the slice ordering agrees.
            masks.append(np.ascontiguousarray(np.moveaxis(mask, -1, 0)))
        return masks

    ## FUNCTION TO CROP INPUT IMAGES TO TARGET SIZE USING CONTOUR MASK
    def crop(self, X: torch.Tensor, struct_mask: npt.ArrayLike, out_dim: tuple):
        """Crop/pad the 4D tensor ``X`` (channels first) to ``out_dim``, centred on the mask."""
        struct_mask = torch.from_numpy(np.stack([struct_mask])).type(torch.float32)
        subject = tio.Subject(
            X=tio.ScalarImage(tensor=X),
            mask=tio.LabelMap(tensor=struct_mask),
        )
        transform = tio.CropOrPad(out_dim, mask_name='mask')
        return transform(subject).X.tensor

    ## FUNCTION TO PAD INPUT TO THE TARGET SIZE
    def pad(self, y: npt.ArrayLike, out_dim: tuple):
        """Centre crop/zero-pad ``y`` to ``out_dim``; returns a (C, *out_dim) numpy array.

        A 3D input gets a leading channel axis, since torchio images must be 4D.
        """
        y = np.asarray(y)
        if y.ndim == 3:
            y = y[np.newaxis]
        padder = tio.CropOrPad(out_dim, padding_mode=0)
        return padder(tio.ScalarImage(tensor=y)).numpy()

    ## FUNCTION TO RESCALE INTENSITY OF PIXEL VALUE
    def rescale_intensity(self, X: torch.Tensor, out_min: float = 0.0, out_max: float = 1.0):
        """Linearly map the min/max of ``X`` onto ``[out_min, out_max]``.

        Works for tensors of any shape and returns float32. A constant (or
        empty) input has no range to map, so it is returned filled with
        ``out_min`` (or unchanged when empty).
        """
        X = torch.as_tensor(X).to(torch.float32)
        if X.numel() == 0:
            return X
        in_min = X.min().item()
        in_max = X.max().item()
        in_range = in_max - in_min
        if in_range == 0:
            return torch.full_like(X, out_min)
        return (X - in_min) / in_range * (out_max - out_min) + out_min

    ## FUNCTION TO EXTRACT SUBFOLDER PATH
    def get_subfolder_path(self, folder_path):
        """Return ``folder_path`` followed by every nested directory.

        The walk is top-down; sibling directories are visited in sorted name
        order so that index-based lookups are deterministic.
        """
        subfolder_path = []
        for root, dirs, _files in os.walk(folder_path):
            dirs.sort()  # in-place: controls the order os.walk descends
            subfolder_path.append(root)
        return subfolder_path