In [None]:
import torch
import torchio as tio
from torch.utils.data import Dataset
import numpy as np
import numpy.typing as npt
import os
import torch.nn as nn
import pydicom
import SimpleITK as sk
import itk
from pathlib import Path
from rt_utils import RTStructBuilder
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from functools import reduce
import random
from numpy import load
from dicom_contour.contour import *

In [None]:
class DosePrdictionDataset(Dataset):
    """Map-style dataset yielding one ``(X, y)`` pair per patient folder.

    Each entry of ``npz_pathlist_patients`` is a directory expected to hold two
    ``.npz`` archives whose array is stored under the default ``arr_0`` key.
    ``X`` is loaded as a torch tensor, rescaled to ``[0, 1]`` and randomly
    augmented; ``y`` is returned unchanged as a NumPy array.
    """

    def __init__(self, npz_pathlist_patients: list):
        """Store the per-patient folder paths.

        Args:
            npz_pathlist_patients: Folder paths, one per patient, indexed by
                ``__getitem__``.
        """
        self.patient_folder_paths = npz_pathlist_patients
        # Incremented after every sample; seeds the augmentation hyperparameters.
        self.seed_count = 0

    def __len__(self):
        """Return the number of patient folders."""
        return len(self.patient_folder_paths)

    def __getitem__(self, index: int):
        """Load, rescale and augment the sample of patient ``index``.

        Returns:
            ``(X, y)`` where ``X`` is the rescaled, augmented input tensor and
            ``y`` is the raw label array read from the folder.

        NOTE(review): files are picked by position in the name-sorted folder
        listing (label = first entry, input = second). Confirm the on-disk file
        names sort in that order and that each folder holds only these two files.
        NOTE(review): with a multi-worker DataLoader each worker process keeps
        its own copy of ``seed_count``.
        """
        patient_dir = self.patient_folder_paths[index]
        files = self._list_patient_files(patient_dir)
        X = torch.from_numpy(self._load_npz_array(files[1]))
        y = self._load_npz_array(files[0])
        X = self.rescale_intensity(X)
        X = self.transform(X)
        self.seed_count += 1
        return X, y

    @staticmethod
    def _list_patient_files(patient_dir) -> list:
        """Return the full paths of the entries in ``patient_dir``, sorted by name.

        ``os.listdir`` yields entries in arbitrary, filesystem-dependent order.
        Sorting makes the positional input/label selection deterministic.
        """
        return sorted(os.path.join(patient_dir, name) for name in os.listdir(patient_dir))

    @staticmethod
    def _load_npz_array(path) -> np.ndarray:
        """Read the ``arr_0`` array from the ``.npz`` archive at ``path``.

        The archive is opened in a ``with`` block so its file handle is closed
        as soon as the array has been read into memory.
        """
        with load(path, allow_pickle=False) as archive:
            return archive["arr_0"]

    ## FUNCTION TO RESCALE INTENSITY OF PIXEL VALUES
    def rescale_intensity(self, X: torch.Tensor, out_min: float = 0.0, out_max: float = 1.0):
        """Linearly rescale ``X`` so its min maps to ``out_min`` and its max to ``out_max``.

        Args:
            X: Input tensor; wrapped in ``tio.ScalarImage``. Presumably 4D
                (channels, x, y, z) as torchio expects; TODO confirm.
            out_min: Target value for the smallest intensity in ``X``.
            out_max: Target value for the largest intensity in ``X``.

        Returns:
            The rescaled tensor.
        """
        in_min = X.min().item()
        in_max = X.max().item()
        # Source range is the data's own range; destination is the requested one.
        rescale = tio.RescaleIntensity(out_min_max=(out_min, out_max),
                                       in_min_max=(in_min, in_max))
        return rescale(tio.ScalarImage(tensor=X)).tensor

    ## FUNCTION TO AUGMENT INPUT
    def transform(self, X: torch.Tensor):
        """Randomly augment ``X`` with a torchio affine/blur/noise/swap pipeline.

        The pipeline's hyperparameters (scale, rotation and translation limits,
        blur std, noise mean/std, swap patch size) are drawn from a private
        ``random.Random`` seeded with ``self.seed_count``. They are therefore
        reproducible per call, and Python's global ``random`` module is not
        reseeded.

        NOTE(review): torchio samples the concrete transform parameters within
        these ranges using its own RNG (presumably torch's), which is not
        seeded here.
        """
        rng = random.Random(self.seed_count)
        # Arguments are evaluated left to right, so the draw order is fixed.
        augment = tio.Compose([
            tio.RandomAffine(scales=(rng.uniform(0.0, 0.5),
                                     rng.uniform(0.0, 0.5),
                                     rng.uniform(0.0, 0.5)),
                             degrees=rng.randint(0, 30),
                             translation=(rng.randint(0, 10),
                                          rng.randint(0, 10),
                                          rng.randint(0, 10))),
            tio.RandomBlur(std=(rng.randint(0, 10),
                                rng.randint(0, 10),
                                rng.randint(0, 10))),
            tio.RandomNoise(mean=rng.uniform(0.0, 1.0),
                            std=(0, rng.uniform(0, 0.5))),
            # NOTE(review): randint(0, 5) can yield patch_size=0; confirm this is intended.
            tio.RandomSwap(patch_size=rng.randint(0, 5)),
        ])
        return augment(X)

    ## FUNCTION TO EXTRACT SUBFOLDER PATH
    def get_subfolder_path(self, folder_path):
        """Return ``folder_path`` followed by every directory beneath it.

        Paths come from a top-down ``os.walk``, so ``folder_path`` itself is the
        first entry when it exists; a missing folder yields an empty list.
        """
        return [root for root, _dirs, _files in os.walk(folder_path)]