# In [1]:
import torch
import torchio as tio
from torch.utils.data import Dataset
import numpy as np
import numpy.typing as npt
import os
import torch.nn as nn
import pydicom
import SimpleITK as sk
from pathlib import Path
from rt_utils import RTStructBuilder
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from functools import reduce
import random
from dicom_contour.contour import *

# In [2]:
class DosePrdictionDataset(Dataset):
    def __init__(self, input_dir:list, struct_types:list, out_dim:tuple):
        self.input_dir = input_dir
        self.struct_types = struct_types
        self.out_dim = out_dim
        self.seed_count = 0
        
    def __len__(self):
        return len(self.input_dir)
        
    def __getitem__(self, index:int):
        patient_dir = self.input_dir[index] #get folder path for one specific patient
        X, y = self.get_X_y(patient_dir) #create inputs and output for this patient
        return X,y
    
    
    ## FUNCTION TO CREATE INPUTS AND OUTPUT
    def get_X_y(self, patient_dir:str):
        """Build the model input ``X`` and the dose label ``y`` for one patient.

        Steps: locate the CT and DOSE folders below ``patient_dir``, load the
        CT volume plus one mask per entry of ``self.struct_types``, merge the
        masks into a single binary structure mask (used to position the crop),
        load the dose volume, then run ``preprocess`` and ``transform``.

        Args:
            patient_dir: root folder of one patient.

        Returns:
            ``(X, y)`` -- ``X`` is the augmented output of ``transform``;
            ``y`` is the label as returned by ``preprocess``.

        Raises:
            IndexError: if no folder ending in "CT" (excluding "...STRUCT") or
                no folder ending in "DOSE" exists below ``patient_dir``, or if
                such a folder has no nested sub-folder.
        """
        # ====== Get folder paths of all inputs ======
        subfolders = self.get_subfolder_path(patient_dir)
        # "STRUCT" itself ends with "CT", so structure folders have to be
        # excluded explicitly.  This must be `not`: `~` applied to a bool gives
        # -1 / -2, both truthy, so it never filtered anything out.
        ct_dirs = [p for p in subfolders
                   if p.endswith("CT") and not p.endswith("STRUCT")]
        dose_dirs = [p for p in subfolders if p.endswith("DOSE")]
        if not ct_dirs or not dose_dirs:
            raise IndexError(
                f"expected a '*CT' and a '*DOSE' folder below {patient_dir!r}; "
                f"found CT={ct_dirs}, DOSE={dose_dirs}"
            )
        # ====== Get subfolder paths for all inputs ======
        # entry 0 is the modality folder itself, entry 1 the first folder
        # os.walk reports beneath it
        X_paths = self.get_subfolder_path(ct_dirs[0])[1]
        y_paths = self.get_subfolder_path(dose_dirs[0])[1]
        # ====== Create input and output ======
        X = self.get_CT_Struct(X_paths, self.struct_types)
        print("SC",X[0].shape)
        # add up structure contours to create a structure mask; np.array()
        # copies, because with a single structure reduce() hands back X[1]
        # itself and binarising it in place would modify the input channel
        struct_mask = np.array(reduce(lambda a, b: a+b, X[1:]))
        struct_mask[struct_mask != 0] = 1
        print("SM:", struct_mask.shape)
        #concatenate ct and contours into combined tensor input X & get label y
        X = torch.from_numpy(np.stack(X)).type(torch.float32)
        y = self.extract_images(y_paths).astype("float32")

        #preprocess X and y
        X,y = self.preprocess(X,y,struct_mask)
        #augmentation
        # NOTE(review): only X is augmented; if the affine part is meant to keep
        # X and y spatially aligned, y needs the same transform -- confirm.
        X = self.transform(X)
        print("X:", X.shape, "y:",y.shape)
        #update seed
        self.seed_count+=1
        return X,y
    
    
    ## FUNCTION TO EXTRACT CT AND DOSE IMAGES
    def extract_images(self, folder_path:str):
        """Read the first DICOM series found in ``folder_path`` as a NumPy array.

        The series IDs and file names come from SimpleITK's GDCM helpers; only
        the first reported series ID is loaded.  The array is whatever
        ``sk.GetArrayFromImage`` returns for the assembled image.
        """
        series_dir = str(Path(folder_path))
        series_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(series_dir)
        # file list of the first series only
        dicom_files = sk.ImageSeriesReader.GetGDCMSeriesFileNames(series_dir, series_ids[0])
        reader = sk.ImageSeriesReader()
        reader.SetFileNames(dicom_files)
        volume = reader.Execute()
        return sk.GetArrayFromImage(volume)
    
#     ## FUNCTION TO CREATE STRUCTURE CONTOURS
#     def extract_struct(self, CT_folder_path:str, strcut_folder_path:str, struct_name:list):
#         struct_path = []
#         for path in os.listdir(strcut_folder_path):
#             full_path = os.path.join(strcut_folder_path, path)
#             if os.path.isfile(full_path):
#                 struct_path.append(full_path)
#         rtstruct = RTStructBuilder.create_from(dicom_series_path=CT_folder_path, 
#                                                rt_struct_path=struct_path[0])
#         masks = []
#         for struct in struct_name:
#             mask = rtstruct.get_roi_mask_by_name(struct).astype("int32")
#             mask = np.stack([np.fliplr(mask[:,:,i]) for i in range(mask.shape[2])])
#             masks.append(mask)
#         return masks
    
    
    ## FUNCTION TO GET CT AND CONTOUR MASKS 
    def get_CT_Struct(self, path:str, ROI_names:list):
        """Load the CT volume and one contour mask volume per requested ROI.

        The structure file is located with ``get_contour_file``, its ROI names
        are read with ``get_roi_names``, and ``get_data`` is called once per
        requested ROI (these helpers and ``dicom`` come from the
        ``dicom_contour.contour`` star import).

        Args:
            path: folder holding the CT slices and the structure file.
            ROI_names: names of the structures to extract, in output order.

        Returns:
            A list ``[CT, mask_0, mask_1, ...]`` with one stacked mask volume
            per entry of ``ROI_names``, in the same order.

        Raises:
            IndexError: if ``ROI_names`` is empty.
            ValueError: if a requested name is not present in the structure file.
        """
        if not ROI_names:
            raise IndexError("ROI_names must contain at least one structure name")
        #store dicom file
        contour_file = get_contour_file(path)
        contour_data = dicom.read_file(os.path.join(path, contour_file))
        ROI_list = get_roi_names(contour_data)
        missing = [r for r in ROI_names if r not in ROI_list]
        if missing:
            raise ValueError(
                f"ROI(s) {missing} not found in {contour_file}; "
                f"available ROIs: {ROI_list}"
            )
        target_ROI_index = [ROI_list.index(r) for r in ROI_names]

        combined_images = []
        for position, roi_index in enumerate(target_ROI_index):
            # use the index of the ROI being processed (this was hard-coded
            # to 10, so every channel held the same structure)
            images, contours = get_data(path, index=roi_index)
            if position == 0:
                # get CT images; taken from the first ROI's load so the series
                # is not read an extra time just for the CT
                combined_images.append(np.stack(list(images)))
            #get contour maps: slices whose maximum is 1 go through
            #fill_contour, all other slices are kept unchanged
            contour_3d = np.stack(
                [fill_contour(c) if c.max()==1 else c for c in contours]
            )
            combined_images.append(contour_3d)
        return combined_images
        

    ## FUNCTION TO CROP INPUT IMAGES TO TARGET SIZE USING CONTOUR MASK
    def crop(self, input_tensor:torch.TensorType, struct_mask:npt.ArrayLike=None):
        """Crop or pad ``input_tensor`` to ``self.out_dim`` with tio.CropOrPad.

        When ``struct_mask`` is given it is wrapped in a leading axis, converted
        to a float32 tensor and attached to the subject as the label map named
        'mask', and that name is passed to CropOrPad as ``mask_name``.  Without
        a mask, CropOrPad is built without ``mask_name``.  Padding uses
        ``padding_mode=0`` in both cases.

        Returns the tensor of the cropped/padded 'input_tensor' image.
        """
        subject_images = {}
        cropper_options = {"target_shape": self.out_dim, "padding_mode": 0}

        if struct_mask is None:
            subject_images["input_tensor"] = tio.ScalarImage(tensor=input_tensor)
        else:
            # np.stack([...]) adds the leading axis in front of the mask
            mask_tensor = torch.from_numpy(np.stack([struct_mask])).type(torch.float32)
            subject_images["input_tensor"] = tio.ScalarImage(tensor=input_tensor)
            subject_images["mask"] = tio.LabelMap(tensor=mask_tensor)
            cropper_options["mask_name"] = 'mask'

        subject = tio.Subject(**subject_images)
        crop_or_pad = tio.CropOrPad(**cropper_options)
        result = crop_or_pad(subject)
        return result.input_tensor.tensor
        
    
    ## FUNCTION TO RESCALE INTENSITY OF PIXEL VALUES
    def rescale_intensity(self, X:torch.Tensor, out_min:float=0.0, out_max:float=1.0):
        """Rescale the intensities of ``X`` to the range ``[out_min, out_max]``.

        The input range is the global minimum and maximum of ``X`` (taken over
        the whole tensor, i.e. across all channels together).

        Args:
            X: tensor accepted by ``tio.ScalarImage``.
            out_min: lower bound of the output range.
            out_max: upper bound of the output range.

        Returns:
            The rescaled tensor produced by ``tio.RescaleIntensity``.
        """
        in_min = X.min().item()
        in_max = X.max().item()
        tio_image = tio.ScalarImage(tensor=X)
        # out_min_max is the target range and in_min_max the observed range;
        # the two were previously passed the wrong way round
        rescale = tio.RescaleIntensity(out_min_max=(out_min, out_max),
                                       in_min_max=(in_min, in_max))
        rescaled_image = rescale(tio_image).tensor
        return rescaled_image
    
    ## FUNCTION TO PREPROCESS INPUT AND LABEL   
    def preprocess(self, X:torch.TensorType, y:npt.ArrayLike, struct_mask:npt.ArrayLike):
        X = self.crop(X, struct_mask)
#         resamplerX = tio.Resample((X.shape[1]/self.out_dim[0], 
#                                    X.shape[2]/self.out_dim[1], 
#                                    X.shape[3]/self.out_dim[2]),
#                                   image_interpolation="bspline")
#         resamplerY = tio.Resample((y.shape[1]/self.out_dim[0], 
#                                    y.shape[2]/self.out_dim[1], 
#                                    y.shape[3]/self.out_dim[2]),
#                                   image_interpolation="bspline")
#         X = resamplerX(X)
        X = self.rescale_intensity(X)
        y = self.crop(y)
        return X, y
              
    
    ## FUNCTION TO AUGMENT INPUT
    def transform(self, X:torch.Tensor):
        """Apply a randomly parameterised torchio augmentation pipeline to ``X``.

        The parameter ranges of the affine, blur, noise and swap transforms are
        drawn from a generator seeded with ``self.seed_count``, so a given
        counter value always yields the same ranges.

        NOTE(review): the transforms still do their own sampling inside these
        ranges when applied; that sampling is presumably not governed by this
        seed -- confirm if full reproducibility is required.

        Args:
            X: input tensor to augment.

        Returns:
            The augmented tensor.
        """
        # A private generator draws the same values that
        # random.seed(self.seed_count) followed by module-level calls would,
        # without resetting the process-wide `random` state for other code.
        rng = random.Random(self.seed_count)
        transform = tio.Compose([
            tio.RandomAffine(scales=(rng.uniform(0.0,0.5),
                                     rng.uniform(0.0,0.5),
                                     rng.uniform(0.0,0.5)),
                             degrees=rng.randint(0,30),
                             translation=(rng.randint(0,10),
                                          rng.randint(0,10),
                                          rng.randint(0,10))),
            tio.RandomBlur(std=(rng.randint(0,10),
                                rng.randint(0,10),
                                rng.randint(0,10))),
            tio.RandomNoise(mean=rng.uniform(0.0,1.0),
                            std=(0, rng.uniform(0,0.5))),
            tio.RandomSwap(patch_size=rng.randint(0,5))
        ])
        return transform(X)
       
    ## FUNCTION TO EXTRACT SUBFOLDER PATH
    def get_subfolder_path(self, folder_path):
        subfolder_path = []
        for roots, dirs, files in os.walk(folder_path):
            subfolder_path.append(roots)
        return subfolder_path