In [2]:
import torch
from torch.utils.data import Dataset
import numpy as np
import os
import torch.nn as nn
import pydicom
import SimpleITK as sk
from pathlib import Path

In [5]:
class DosePrdictionDataset(Dataset):
    """Dataset yielding one ``(X, y)`` tensor pair per patient directory.

    ``X`` stacks one image array per entry of ``input_types``; ``y`` is the
    dose array with a leading axis of length 1. Both are ``torch.float32``.

    Args:
        input_dir: List of patient directories; the dataset has one item per entry.
        input_types: Names of the image folders to load as inputs (e.g. ``"CT"``).
            Each name must be the last path component of a directory located
            at or below the patient directory.
        transform: Optional callable applied to every loaded array (inputs
            and dose alike) before the arrays are stacked.
    """

    def __init__(self,
                 input_dir:list,
                 input_types:list,
                 transform=None,
                ):
        self.transform = transform
        self.input_dir = input_dir
        self.input_types = input_types

    def __len__(self):
        """Return the number of patient directories."""
        return len(self.input_dir)

    def __getitem__(self, index:int):
        """Load and return the ``(X, y)`` pair of the patient at ``index``."""
        return self.get_X_y(self.input_dir[index])

    def _apply_transform(self, array):
        """Return ``array`` passed through ``self.transform`` when one is set."""
        if self.transform is None:
            return array
        return self.transform(array)

    ## SHOULD BE CHANGED ACCORDINGLY
    def get_X_y(self, patient_dir):
        """Load the input arrays and the dose array of one patient.

        Args:
            patient_dir: Directory searched recursively for the input folders
                and for a folder whose path ends with ``"DOSE"``.

        Returns:
            Tuple ``(X, y)`` of ``torch.float32`` tensors. ``X`` has one entry
            along axis 0 per input type (all inputs must share a shape to be
            stackable); ``y`` is the dose array with a leading axis of size 1.

        Raises:
            FileNotFoundError: If an input folder or the dose folder is missing.
        """
        # Every directory at or below the patient folder, in os.walk order.
        folders = [root for root, _dirs, _files in os.walk(patient_dir)]

        X = []
        for input_type in self.input_types:
            matches = [folder for folder in folders
                       if os.path.basename(folder) == input_type]
            if not matches:
                raise FileNotFoundError(
                    f"No '{input_type}' folder found under {patient_dir!r}")
            # Decide dose handling from the requested type, not from the full
            # path: a parent directory that merely contains "DOSE" in its name
            # must not cause an ordinary input volume to be sliced.
            image = self.extract_images(matches[0], is_dose="DOSE" in input_type)
            X.append(self._apply_transform(image))

        dose_dirs = [folder for folder in folders if folder.endswith("DOSE")]
        if not dose_dirs:
            raise FileNotFoundError(
                f"No 'DOSE' folder found under {patient_dir!r}")
        dose = self._apply_transform(self.extract_images(dose_dirs[0], is_dose=True))

        # Convert straight to float32: an intermediate integer cast would
        # truncate fractional values (scaled dose, normalised transform output).
        X = torch.from_numpy(np.stack(X).astype("float32"))
        y = torch.from_numpy(np.asarray(dose)[np.newaxis].astype("float32"))

        return X,y

    ## Function to extract images (inputs or dose) from a folder
    def extract_images(self, folder_path, is_dose=None):
        """Read the first DICOM series found for ``folder_path`` as an array.

        The series is read from the first sub-directory of ``folder_path``
        (in ``os.walk`` order); when there is no sub-directory, ``folder_path``
        itself is used.

        Args:
            folder_path: Folder holding (a sub-directory with) the DICOM series.
            is_dose: When true, only element 0 along the first axis of the
                loaded array is returned. ``None`` (default) keeps the legacy
                rule of treating any path containing ``"DOSE"`` as a dose folder.

        Returns:
            The array produced by ``SimpleITK.GetArrayFromImage`` for the
            series, or its first element along axis 0 for dose folders.

        Raises:
            FileNotFoundError: If the folder does not exist or holds no series.
        """
        folder_path = os.fspath(folder_path)
        if is_dose is None:
            is_dose = "DOSE" in folder_path

        # Only the top level is needed, so take the first os.walk entry
        # instead of traversing the whole tree.
        top_level = next(os.walk(folder_path), None)
        if top_level is None:
            raise FileNotFoundError(f"Image folder does not exist: {folder_path!r}")
        root, subdirs, _files = top_level
        series_dir = os.path.join(root, subdirs[0]) if subdirs else root

        file_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(series_dir)
        if not file_ids:
            raise FileNotFoundError(f"No DICOM series found in {series_dir!r}")
        file_names = sk.ImageSeriesReader.GetGDCMSeriesFileNames(series_dir, file_ids[0])

        series_reader = sk.ImageSeriesReader()
        series_reader.SetFileNames(file_names)
        images_3D = sk.GetArrayFromImage(series_reader.Execute())
        if is_dose:
            return images_3D[0]
        return images_3D

In [8]:
# Scratch check: load a single CT series step by step, mirroring the
# series-reading steps of DosePrdictionDataset.extract_images.
folder_path = "/Users/wangyangwu/Documents/Maastro/NeuralNets/PROTON/P0439C0006I1473766/CT"

# os.walk yields the folder itself first, so index 1 is the first
# sub-directory it visits.
subfolder_dirs = [current_dir for current_dir, _subdirs, _filenames in os.walk(folder_path)]
subfolder_path = Path(subfolder_dirs[1])
series_dir = str(subfolder_path)

# Look up the series stored in that sub-directory and take the first one.
file_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(series_dir)
file_names = sk.ImageSeriesReader.GetGDCMSeriesFileNames(series_dir, file_ids[0])

series_reader = sk.ImageSeriesReader()
series_reader.SetFileNames(file_names)
image_data = series_reader.Execute()

# Convert to a NumPy array; dose folders keep only element 0 of the first axis.
images_3D = sk.GetArrayFromImage(image_data)
if "DOSE" in folder_path:
    images_3D = images_3D[0]

In [9]:
# Show the shape of the array loaded by the preceding cell.
images_3D.shape

(189, 512, 512)