In [None]:
# %matplotlib widget

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import json, os
from pydicom import dcmread
import seaborn as sns   
import pandas as pd
from utilities import natural_keys

The cell below creates a log file for the current notebook. This can help with debugging and record keeping. You can find the file in the `logs` folder.

In [None]:
import logging, os

logger = logging.getLogger(__name__)
log_file = os.path.join('logs', 'preprocessing.log')
# basicConfig opens the file immediately, so the folder must exist (e.g. on a fresh checkout)
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logging.basicConfig(filename=log_file, filemode='w', level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

The JSON file containing the user preferences is central to all of the classes used in this demo. The path to this file is passed as an input when the classes are instantiated.

In [None]:
# Path (not a folder) of the JSON file with the user preferences shared by all classes in this demo
user_inputs_dir = "/".join(["configuration_files", "user_config.json"])

---
## The DicomToolbox package

This is a class designed for parsing DICOM files. Unlike previous iterations, this version does not rely on HDF5 files, which were used to minimize the RAM load. The HDF5 dependency served well when we were using this class to simulate perturbed dose distributions, which used as many as 23 processors and was memory demanding. This version of the DicomToolbox is intended for serial applications.

In [None]:
from dicom_toolbox import DicomToolbox

# Instantiate the toolbox from the user configuration file
dt = DicomToolbox(user_inputs_dir)
# Optional: only require these data types (e.g. to process patients without a plan file)
# dt.expected_data = ['ct', 'rtdose', 'rtstruct']

**Identifying patient folders with the necessary files**

The cell below identifies all the patient folders that contain the files necessary for dose prediction with this tool. The method searches for folders inside the folder you specify in the `raw_patient_data` field of the `user_config` JSON file.


Note: If you want to use the `DicomToolbox` class to read data with missing plan files, you can remove `rtplan` from `dt.expected_data` (see the commented-out line in the cell that creates `dt`). This should rarely be needed, so you can ignore this note if it is not clear.

In [None]:
patient_ids_found = dt.identify_patient_files()

# Natural sort ('2' before '10'); a plain sorted() would fall back to lexicographic order
print(f'No. of patients found: {len(patient_ids_found)}')
print(f'Patients found: {", ".join(sorted(patient_ids_found, key=natural_keys))}')

### NEW: You can display the content of the header of the DICOM files with the `get_header_info` method as shown below.

In [None]:
# Print the CT header of patient 1 and keep a copy in the `logs` folder
dt.get_header_info(1, 'ct', echo=True, save_to_file=True)

### Parsing the DICOM files of a single patient

In this example, the resolution of the masks or structures will be set to match the resolution of the CT volume. This is done by setting the `mask_resolution` input variable to `"ct"` when calling the `parse_dicom_files()` method.

In [None]:
patient_ID = 1

# Optional (demonstration only): choose which contour set to read; some datasets
# also contain auto-contours, which can be selected with 'auto'.
# dt.desired_contour_set = 'clinical'

dt.parse_dicom_files(patient_ID, mask_resolution="ct")

When you parse data for a patient, the structures/contours/masks are stored in a dictionary attribute, `dt.contours`. You can list its keys (the structure names) by running the cell below.

In [None]:
# One structure name per line
for contour_name in dt.contours.keys():
    print(contour_name)

In [None]:
np.max(dt.cumulative_dose)  # maximum of the cumulative dose volume

In [None]:
sn = 80  # axial slice index to display
color_rng = np.random.default_rng(seed=0)  # seeded so contour colours are reproducible

fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(dt.ct.data[sn, :, :], cmap='gray', vmin=-100, vmax=300)
for name, mask in dt.contours.items():
    try:
        ax.contour(mask.data[sn, :, :], colors=[color_rng.random(3)])
    except Exception as err:
        # Report masks that cannot be drawn on this slice instead of silently skipping them
        logger.warning(f'Could not draw contour "{name}" on slice {sn}: {err}')
ax.set_title(f'CT slice {sn} with contours')
ax.axis('off');

Getting the cumulative dose is easy with the DicomToolbox class: use `cumulative_dose`, which is a class property. Since it is a property, you do not need parentheses to access it.

In [None]:
cum_dose = dt.cumulative_dose  # property: accessed without parentheses

for dose_label, dose_value in (("Mean", np.mean(cum_dose)), ("Max", np.max(cum_dose))):
    print(f"{dose_label} dose: {dose_value:.2f} Gy")

In [None]:
sn = 73  # axial slice index to display
fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(cum_dose[sn, :, :], cmap='jet')

# Give the colorbar its own axis so that it matches the height of the image
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

cbar = fig.colorbar(im, cax=cax)
cbar.ax.set_ylabel('Cumulative Dose (Gy)', labelpad=10)
ax.axis('off');

In [None]:
getattr(dt, "radiation_type")  # radiation type detected from the plan file (e.g. 'proton')

Note for Josiane: Change the number below to show only the dose above the value you specify.

Another option when you parse the data, which is helpful when you are inspecting a dataset and do not want to build all of the masks, is to set `mask_names_only` to `True`. In this case, the contours are not returned as a dictionary but as a list of the contour names. This is demonstrated in the cell below.

Note that this call takes only a few seconds, at the expense of less data: you still get the CT and dose volumes as outputs, but no masks.

In [None]:
# Parse only the structure names (no masks are built).
# NOTE(review): the DicomToolbox class defined later in this notebook exposes a
# `parse_structures` flag rather than `mask_names_only`; confirm that the imported
# `dicom_toolbox.DicomToolbox` accepts this keyword.
dt.parse_dicom_files(patient_ID, mask_names_only=True)

In the previous example, we parsed the contours at the resolution of the CT. Now we can check the option for parsing at the resolution of the dose. This option is useful for tasks like computing a DVH, which we can perform in the cells below.  

In [None]:
# Re-parse with the masks on the dose grid (useful for DVH calculations)
dt.parse_dicom_files(patient_ID, mask_resolution="dose")

In [None]:
from pydicom import dcmread
from pydicom.sequence import Sequence
from pydicom.dataset import Dataset

from scipy.interpolate import interp1d
from skimage.draw import polygon
from dataclasses import dataclass, field
from utilities import interpolate_volume, natural_keys
from tqdm import tqdm
from skimage.measure import find_contours
import multiprocessing as mp
from multiprocessing import Pool
import numpy as np
import logging

import os, json, copy, pandas, h5py, logging, re, pickle, pydicom, traceback

import numpy as np
import pandas as pd

@dataclass
class Coordinates:
    """Coordinate grid of a 3D volume.

    Holds the coordinate array of each axis, the grid spacing along each axis,
    and the position of the first voxel.
    """
    x: np.ndarray  # coordinate of every column
    y: np.ndarray  # coordinate of every row
    z: np.ndarray  # coordinate of every slice
    dx: float  # spacing along x
    dy: float  # spacing along y
    dz: float  # spacing along z
    image_position: np.ndarray  # position of the first voxel (ImagePositionPatient for CT)

@dataclass
class CT:
    """CT volume together with the metadata needed to interpret it.

    The fields are positional, in the order used by `parse_ct_study_files`.
    """
    shape: tuple  # shape of `data`
    resolution: np.ndarray  # grid spacing (dx, dy, dz)
    max_value: float  # largest value in `data`
    min_value: float  # smallest value in `data`
    units: str  # 'hu' or 'original' when filled by parse_ct_study_files
    rescale_slope: float  # DICOM RescaleSlope
    rescale_intercept: float  # DICOM RescaleIntercept
    patient_position: str  # DICOM PatientPosition, lower case
    data: np.ndarray  # voxel values
    slice_thickness: float  # DICOM SliceThickness
    coordinates: Coordinates  # coordinate grid of `data`
    
@dataclass
class CommonDoseTags:
    """Dose volume plus the beam/plan attributes shared by photon and proton doses.

    The first eight fields are required; every other field falls back to a
    neutral placeholder ('not_specified', 0, or a zero vector).
    """
    # --- required fields ---
    shape: tuple
    max_value: float
    min_value: float
    resolution: np.ndarray
    dose_grid_scaling: float
    dose_units: str
    data: np.ndarray
    coordinates: Coordinates
    # --- beam geometry ---
    beam_number: int = 1
    beam_type: str = 'not_specified'
    gantry_angle: float = 0.0
    patient_support_angle: float = 0.0
    table_top_pitch_angle: float = 0.0
    table_top_roll_angle: float = 0.0
    isocenter: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))
    # --- beam description ---
    radiation_type: str = 'not_specified'
    treatment_delivery_type: str = 'not_specified'
    beam_name: str = 'not_specified'
    treatment_machine: str = 'not_specified'
    beam_description: str = 'not_specified'
    number_of_control_points: int = 0
    final_cumulative_meterset_weight: float = 0.0
    beam_dose: float = 0.0
    scan_mode: str = 'not_specified'
    primary_dosimetric_units: str = 'not_specified'
     
@dataclass
class ProtonDose(CommonDoseTags):
    """Proton dose volume; adds the virtual source-axis distances of the beam."""
    vsad: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))
    
@dataclass
class PhotonDose(CommonDoseTags):
    """Photon dose volume; adds the source-axis distance and gantry rotation direction(s)."""
    sad: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))
    gantry_rotation_direction: list = field(default_factory=lambda: ['not_specified'])
    
@dataclass
class Mask:
    """Binary mask of one structure plus the grid it was built on."""
    data: np.ndarray  # mask volume
    resolution: str  # name of the grid used, presumably 'ct' or 'dose' — confirm with the parser
    coordinates: Coordinates  # coordinate grid of `data`
    number: int = 0
    structure_set: str = 'clinical'
    generation_method: str = 'unknown'
    
@dataclass
class Plan:
    """Summary of an RT Plan; `beam` maps beam numbers to `Beam` objects."""
    number_of_beams: int = 1
    geometry: str = 'not_specified'
    patient_position: str = 'not_specified'
    patient_sex: str = 'not_specified'
    plan_label: str = 'not_specified'
    number_of_fractions_planned: int = 0
    dose_per_fraction: float = 0.0
    dose_reference_type: str = 'not_specified'
    dose_reference_description: str = 'not_specified'
    dose_reference_dose: float = 0.0
    radiation_type: str = 'not_specified'
    beam: dict = field(default_factory=dict)  # fresh dict per instance
    
@dataclass
class Beam:
    """Geometry and machine information of one treatment beam."""
    gantry_angle: float = 0.0
    patient_support_angle: float = 0.0
    isocenter: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))
    treatment_delivery_type: str = 'not_specified'
    treatment_machine: str = 'not_specified'
    type: str = 'not_specified'
    sad: float = 0.0
    vsad: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))
    
class DicomToolbox():
    """Class for parsing a set of DICOM-RT files for a patient.

        Typical use: `identify_patient_files()` lists the patient folders that hold
        every file type in `expected_data`; `parse_dicom_files(patient_id)` then
        fills `ct`, `dose`, `plan` and `contours` for one patient.

        Created in the fall of 2021 by Ivan Vazquez in collaboration with Ming Yang.

        Last updated: January 2024

        Copyright 2021-2024 Ivan Vazquez
    """
    # Class-level placeholder coordinate grids; reset() (called from __init__)
    # overrides them with per-instance values of None.
    original_ct_coordinates =Coordinates(0, 0, 0, 0, 0, 0, [0, 0, 0])
    original_dose_coordinates = Coordinates(0, 0, 0, 0, 0, 0, [0, 0, 0])
    
    def __init__(self, user_inputs_dir=None, patient_data_directory=None, lut_directory=None) -> None:
        """Set up the toolbox from a user-configuration JSON file and/or explicit directories.

        Parameters
        ----------
        `user_inputs_dir` : str, optional
            Path of the JSON configuration file (a file path despite the name).
            When omitted, built-in defaults are used.
        `patient_data_directory` : str, optional
            Folder with one sub-folder per patient. Takes precedence over the
            `raw_patient_data` entry of the configuration file.
        `lut_directory` : str, optional
            Folder with the RSP look-up tables; defaults to `utilities/LUT`.
            `self.lut_directory` is None when the folder does not exist.

        Raises
        ------
        AssertionError
            If no patient data directory is available from either source.
        """
        self.__logger = logging.getLogger(__name__)

        # Prepare RSP LUT directory: use the default location only when none is given
        if lut_directory is None:
            lut_directory = os.path.join('utilities', 'LUT')
        if os.path.isdir(lut_directory):
            self.lut_directory = lut_directory
        else:
            self.__logger.warning(f'The directory {lut_directory} for the look-up tables does not exist.')
            self.lut_directory = None

        # Load user input information if a configuration file is provided
        if user_inputs_dir is not None:
            with open(user_inputs_dir, "r") as f:
                self.user_inputs = json.load(f)
            self.patient_data_directory = self.user_inputs['DIRECTORIES']['raw_patient_data']
        else:
            self.__set_default_user_inputs()

        # An explicitly given patient data directory overrides the configuration file
        if patient_data_directory is not None:
            self.patient_data_directory = patient_data_directory
        assert self.patient_data_directory is not None

        # Make sure the working directories used by the toolbox exist
        for directory in ['logs', 'temp', os.path.join('temp', 'data')]:
            os.makedirs(directory, exist_ok=True)

        # Initialize variables
        self.reset()
        self.parallelize = None
        self.n_threads = self.user_inputs["PARALLELIZATION"]["number_of_processors"]
        self.min_coordinate_precision = 3
    
    def __set_default_user_inputs(self):
        """Fill `self.user_inputs` with built-in defaults when no configuration file is given.

        The patient data directory is left as None; it must then be supplied
        explicitly (e.g. via the `patient_data_directory` argument of `__init__`).
        """
        self.user_inputs = {
            "TYPE_OF_TARGET_VOLUME": "ctv",
            "PARALLELIZATION": {
                "number_of_processors": 1,
            },
            "DATA_PREPROCESSING": {
                "contour_interpolation_method": "nearest",
            },
            # Same layout as the configuration file, so that code reading
            # user_inputs['DIRECTORIES'] does not fail without a JSON file
            "DIRECTORIES": {
                "raw_patient_data": None,
            },
        }
        self.patient_data_directory = None
        self.__logger.warning("No user inputs were provided. Setting default values.")
            
    def reset(self):
        """Restore the per-patient state and the processing options to their default values."""
        # Per-patient state
        self.patient_id = self.original_ct_coordinates = self.original_dose_coordinates = None

        # Mask generation options
        self.mask_interpolation_technique = self.user_inputs["DATA_PREPROCESSING"]["contour_interpolation_method"]
        self.mask_generation_method = 'interpolate'

        # Storage options
        self.write_new_hdf5_file = True
        self.relevant_masks = None
        self.compression = 'lzf'

        # Expected inputs and reporting
        self.expected_data = ['ct', 'rtdose', 'rtplan', 'rtstruct']
        self.radiation_type = None
        self.echo_progress, self.echo_level = True, 0
        self.coordinate_precision = 3
        self.equalize_dose_grid_dimensions = True
          
    def identify_patient_files(self, patient_data_directory = None, echo=False):
        """Identify the patient folders that contain every DICOM-RT file type in `self.expected_data`.

            Parameters
            ----------

            `patient_data_directory` : str, optional
                Folder containing one sub-folder per patient. When given, it also
                replaces the toolbox's current `patient_data_directory`.

            `echo` : bool
                If True, log the number of complete patient folders found.

            Returns
            -------

            `list`
                Names of the complete patient folders, in natural sort order.
        """
        # Use (and remember) the directory passed by the caller, if any
        if patient_data_directory is not None:
            self.patient_data_directory = patient_data_directory

        # Only sub-directories can be patient folders; loose files are ignored
        folders = [f for f in os.listdir(self.patient_data_directory)
                   if os.path.isdir(os.path.join(self.patient_data_directory, f))]

        self.__logger.info(f'Checking the content of {self.patient_data_directory} to identify patient folders with the required DICOM-RT files.')

        patient_files_info = {}
        for folder in folders:
            patient_folder_directory = os.path.join(self.patient_data_directory, folder)
            modalities = set()

            for root, _, files in os.walk(patient_folder_directory):
                for file in files:
                    file_directory = os.path.join(root, file)
                    try:
                        # Only the header is needed to get the modality, so skip the pixel data
                        ds = dcmread(file_directory, stop_before_pixels=True)
                        modalities.add(ds.data_element('Modality').value.lower())
                    except Exception:
                        # Skip unreadable / non-DICOM files but keep checking the rest of the folder
                        self.__logger.warning(f"The file '{file_directory}' in folder '{folder}' could not be read")

            patient_files_info[folder] = {'modalities': sorted(modalities),
                                          'folder_directory': patient_folder_directory}

        # Record the data folders with the required DICOM-RT files
        patients = []
        for p, info in patient_files_info.items():
            if all(m in info['modalities'] for m in self.expected_data):
                patients.append(p)
            else:
                self.__logger.warning(f"The folder '{p}' is missing one or more of the required DICOM-RT files. "
                                      f"Current modalities: {', '.join(info['modalities'])}")

        patients.sort(key=natural_keys)

        if echo: self.__logger.info(f'Found {len(patients)} patient folders in {self.patient_data_directory} with the required DICOM-RT files.')

        return patients
    
    def get_header_info(self, patient_id, file_type, save_to_file=False, echo=False):
        """Print and/or save the DICOM header(s) of one file type for a patient.

        Parameters
        ----------
        `patient_id` : str or int
            Name of the patient folder.
        `file_type` : str
            'ct', 'dose', 'plan' or 'structures'; the modality names 'rtdose',
            'rtplan' and 'rtstruct' are accepted as aliases.
        `save_to_file` : bool
            Write each header as text to `logs/<type>_header_info_pat-<id>_<n>.txt`.
        `echo` : bool
            Print each header.

        Only the first CT file is processed; for the other types, every file is.
        Errors are logged and the method returns None.
        """
        aliases = {'ct': 'ct', 'dose': 'dose', 'rtdose': 'dose', 'plan': 'plan', 'rtplan': 'plan',
                   'structures': 'structures', 'rtstruct': 'structures'}
        file_key = aliases.get(file_type)
        if file_key is None:
            self.__logger.error(f'Invalid file type {file_type}.')
            return

        patient_files = self.run_initial_check(patient_id)
        if patient_files is None:
            self.__logger.error(f'The DICOM files for patient {self.patient_id} could not be collected.')
            return

        # A CT study has one file per slice; only the first one is shown
        files = patient_files[file_key][:1] if file_key == 'ct' else patient_files[file_key]

        for n, f in enumerate(files):
            ds = dcmread(f)

            if echo: print(ds)  # print the header information

            if save_to_file:
                # print(ds) produces pydicom's text dump (not JSON), so save it as .txt;
                # the patient id keeps files from different patients apart
                header_output_dir = os.path.join('logs', f'{file_key}_header_info_pat-{self.patient_id}_{n}.txt')
                with open(header_output_dir, 'w', encoding='utf-8') as outfile:
                    print(ds, file=outfile)
                        
    def identify_radiation_type(self, patient_files, patient_id):
        """Set `self.radiation_type` to the most common radiation type of the plan's beams.

        Nothing is done when `self.radiation_type` is already set (e.g. specified
        manually). Only the first plan file is inspected; `BeamSequence` is checked
        first, then `IonBeamSequence`. Failures are logged, not raised.

        Parameters
        ----------
        `patient_files` : dict
            File lists keyed by type, as built by `run_initial_check`.
        `patient_id` : str
            Used in the log messages only.
        """
        if self.radiation_type is not None: return

        plan_files = patient_files.get('plan') or []
        if not plan_files:
            # Plans are optional when 'rtplan' is removed from expected_data
            self.__logger.warning(f'No plan file is available to identify the radiation type for pat-{patient_id}.')
            self.__logger.info('If you know the radiation type, please specify it with the class attribute "radiation_type".')
            return

        ds = dcmread(plan_files[0])

        for sequence_name in ('BeamSequence', 'IonBeamSequence'):
            try:
                radiation_types = [b.RadiationType.lower() for b in getattr(ds, sequence_name)]
            except AttributeError:
                continue  # sequence (or a beam's RadiationType) not present
            if radiation_types:
                self.radiation_type = max(set(radiation_types), key=radiation_types.count)
                return

        self.__logger.error(f'Failed to identify the radiation type for pat-{patient_id}.')
        self.__logger.info('If you know the radiation type, please specify it with the class attribute "radiation_type".')
                                                        
    def run_initial_check(self, patient_id=None):
        """Collect a patient's DICOM files grouped by type and validate them.

        Parameters
        ----------
        `patient_id` : str or int, optional
            Patient folder name; when omitted, `self.patient_id` is used.

        Returns
        -------
        `dict` or None
            {'ct': [...], 'plan': [...], 'structures': [...], 'dose': [...]} with
            sorted file paths, or None when the folder cannot be read or a type
            listed in `self.expected_data` is missing (details are logged).

        Raises
        ------
        ValueError
            Propagated from the dose-file consistency check.
        """
        assert self.patient_data_directory is not None
        if patient_id is not None:
            self.patient_id = str(patient_id)
        if self.patient_id is None:
            self.__logger.error('No patient ID was specified.')
            return None

        patient_directory = os.path.join(self.patient_data_directory, self.patient_id)

        modality_to_key = {'ct': 'ct', 'rtdose': 'dose', 'rtplan': 'plan', 'rtstruct': 'structures'}
        patient_files = {'ct': [], 'plan': [], 'structures': [], 'dose': []}

        try:
            os.listdir(patient_directory)  # fail early if the patient folder is missing
            for root, _, files in os.walk(patient_directory):
                for f in files:
                    file_path = os.path.join(root, f)
                    try:
                        # Only the header is needed to get the modality
                        modality = dcmread(file_path, stop_before_pixels=True).Modality.lower()
                    except Exception:
                        self.__logger.warning(f"Skipping '{file_path}': it could not be read as a DICOM file.")
                        continue
                    key = modality_to_key.get(modality)
                    if key is not None:
                        patient_files[key].append(file_path)
        except Exception:
            self.__logger.error(f"An error occured while trying to read the files for patient {self.patient_id}.")
            self.__logger.error(traceback.format_exc())
            return None

        for file_list in patient_files.values():
            file_list.sort()
        self.__logger.debug(f"Found {len(patient_files['ct'])} CT files for patient {self.patient_id}.")

        # Detect incomplete data before anything relies on a specific file type
        if any(not patient_files[modality_to_key[x]] for x in self.expected_data):
            self.__logger.error(f'The full DICOM-RT set ({", ".join(self.expected_data)}) for patient {self.patient_id} could not be read.')
            self.__logger.info("Please change the expected data types by specifying the class attribute 'expected_data' or check the patient folder.")
            return None

        # Identify the radiation type (needs a plan file, which may be optional)
        if self.radiation_type is None:
            if patient_files['plan']:
                self.identify_radiation_type(patient_files, self.patient_id)
            else:
                self.__logger.info(f'No plan file for patient {self.patient_id}; the radiation type was not identified.')

        # Check the dose files for beam-specific or cumulative dose
        patient_files['dose'] = self.__check_dose_files(patient_files['dose'])

        return patient_files
    
    def __check_dose_files(self, dose_files):
        """Select the RT Dose files to use for the current patient.

        Parameters
        ----------
        `dose_files` : list of str
            Paths of the patient's RT Dose files.

        Returns
        -------
        `list`
            `dose_files` unchanged when all files share one DoseSummationType.
            When both 'beam' and 'plan' files are present, the 'beam' files, after
            checking that their sum matches the first 'plan' dose. Any other mix
            of summation types is logged and `dose_files` is returned unchanged.

        Raises
        ------
        ValueError
            If the largest absolute voxel difference between the summed beam
            doses and the plan dose (after DoseGridScaling) exceeds 0.5.
        """
        # The summation type is in the header, so the pixel data is not needed yet
        summation_types = {f: dcmread(f, stop_before_pixels=True).DoseSummationType.lower() for f in dose_files}

        if len(set(summation_types.values())) <= 1:
            return dose_files

        self.__logger.info(f'Multiple dose summation types were identified for patient {self.patient_id}: {", ".join(summation_types.values())}')

        beam_files = [f for f in dose_files if summation_types[f] == 'beam']
        plan_files = [f for f in dose_files if summation_types[f] == 'plan']
        if not beam_files or not plan_files:
            self.__logger.warning(f'The dose files for patient {self.patient_id} could not be cross-checked: '
                                  f'both "beam" and "plan" dose files are needed. All dose files will be used.')
            return dose_files

        def scaled_dose(path):
            # Stored integer values times DoseGridScaling give the dose values
            ds = dcmread(path)
            return ds.pixel_array * ds.DoseGridScaling

        cum_dose = np.sum([scaled_dose(f) for f in beam_files], axis=0)
        plan_dose = scaled_dose(plan_files[0])

        # Largest absolute difference, so voxels where the beam sum is below the plan dose count too
        max_abs_difference = np.abs(np.subtract(cum_dose, plan_dose)).max()
        if np.round(max_abs_difference) > 0.0:
            self.__logger.error(f'The sum of the beam dose for patient {self.patient_id} does not match the dose for the plan file.')
            raise ValueError(f'Summed beam dose and plan dose differ by up to {max_abs_difference:.3f} '
                             f'for patient {self.patient_id}.')

        return beam_files
                        
    def parse_dicom_files(self, patient_id=None, parse_structures=True, mask_resolution = None, patient_data_directory =None):
        """Parse a patient's CT, dose, plan and structure files into class attributes.

        Populates `self.dicom_files`, `self.ct`, `self.dose`, `self.plan` and
        `self.contours`, and prints the time taken by each stage.

        Parameters
        ----------
        `patient_id` : str or int, optional
            Patient folder name; defaults to the last patient processed.
        `parse_structures` : bool
            If False, only the structure names are read (`names_only=True`).
        `mask_resolution` : str, optional
            Grid used for the masks, e.g. 'ct' or 'dose'; defaults to 'dose'.
        `patient_data_directory` : str, optional
            Replaces the folder the patient data is read from.

        Raises
        ------
        Exception
            If no patient data directory or patient ID is available, or the
            patient's DICOM files cannot be collected (details in the log file).
        """
        # import tools for timing execution
        from time import time

        start = time()
        if patient_data_directory is not None:
            # Keep the attribute used by run_initial_check() and the configuration in sync
            self.patient_data_directory = patient_data_directory
            self.user_inputs.setdefault('DIRECTORIES', {})['raw_patient_data'] = patient_data_directory
        else:
            patient_data_directory = self.patient_data_directory

        if patient_data_directory is None:
            self.__logger.error('No patient data directory was specified.')
            raise Exception('Check log file')

        # initial checks
        if patient_id is not None:
            self.patient_id = str(patient_id)
        if self.patient_id is None:
            self.__logger.error('Calling DICOM parsing function without specifying a `patient_id`.')
            raise Exception('Check log file')
        self.mask_resolution = mask_resolution if mask_resolution is not None else 'dose'

        if patient_data_directory.split('.')[-1] in ['h5', 'hdf5']:
            self.read_data_from_hdf5()
            return

        # Identify the file types in the patient folder and type of radiation therapy
        self.dicom_files = self.run_initial_check(self.patient_id)
        if self.dicom_files is None:
            self.__logger.error(f'The DICOM files for patient {self.patient_id} could not be collected.')
            raise Exception('Check log file')

        print(f'Time to set up: {time()-start:.2f} seconds')

        # Parse the CT volume
        start = time()
        self.ct = self.parse_ct_study_files(self.dicom_files['ct'])
        print(f'Time to parse CT: {time()-start:.2f} seconds')

        # Parse the dose volume and (optionally) the plan
        start = time()
        if self.dicom_files.get('plan'):
            self.dose, self.plan = self.parse_rt_dose_files(self.dicom_files['dose'], self.dicom_files['plan'])
        else:
            self.__logger.warning(f'No plan file was found for patient {self.patient_id}. Using default values for the plan.')
            self.dose = self.parse_rt_dose_files(self.dicom_files['dose'])
            self.plan = Plan()
        print(f'Time to parse dose: {time()-start:.2f} seconds')

        # Parse the contours
        start = time()
        self.contours = self.parse_structure_files(sorted(self.dicom_files['structures']), names_only=not parse_structures, resolution=self.mask_resolution)
        print(f'Time to parse structures: {time()-start:.2f} seconds')

    def parse_ct_study_files(self, files=None, patient_id = None, units='hu'):
        """Build a CT volume and its coordinate grid from a set of CT slice files.

        Parameters
        ----------
        `files` : list of str, optional
            Paths of the CT slice files. Ignored when `patient_id` is given.
        `patient_id` : str or int, optional
            When given, the CT files are collected with `run_initial_check`.
        `units` : str
            'hu' applies RescaleSlope/RescaleIntercept to the stored values; any
            other value keeps the stored values and reports the units as 'original'.

        Returns
        -------
        `CT`
            Volume indexed as data[slice, row, column], with slices sorted by the
            z component of ImagePositionPatient. When more than one slice spacing
            is found, the volume is linearly interpolated onto a uniform grid with
            the smallest spacing. ImageOrientationPatient is not used, so
            axis-aligned images are assumed.

        Raises
        ------
        ValueError
            If no CT files are available.

        Also sets `self.original_ct_coordinates` and `self.original_ct_shape`.
        """
        if patient_id is not None:
            self.patient_id = str(patient_id)
            patient_files = self.run_initial_check(self.patient_id)
            if patient_files is None:
                raise ValueError(f'The DICOM files for patient {self.patient_id} could not be collected (see log file).')
            files = patient_files['ct']
        if not files:
            raise ValueError(f'No CT files were provided for patient {self.patient_id}.')

        # Read each slice file once, keyed by its z position
        ct_slices = {}
        for f in files:
            ds = dcmread(f)
            ct_slices[ds.ImagePositionPatient[-1]] = ds.pixel_array

        ## Sorted z coordinates, 3D volume and the set of distinct slice spacings
        z = sorted(ct_slices)
        data = np.array([ct_slices[i].astype(float) for i in z])
        z_spacing = list(set(np.round(np.diff(z), self.coordinate_precision)))
        z = np.round(z, self.coordinate_precision)

        with dcmread(files[0]) as ds:
            # Grab the position of the patient
            patient_position = ds.PatientPosition.lower()

            ## Image position of the first voxel, with z set to the lowest slice
            image_position = [np.round(p, self.coordinate_precision) for p in ds.ImagePositionPatient]
            image_position[-1] = z[0]

            ## PixelSpacing is (row spacing, column spacing), i.e. (dy, dx)
            row_spacing, column_spacing = [np.round(float(i), self.coordinate_precision) for i in ds.PixelSpacing]
            xy_resolution = [column_spacing, row_spacing]
            rescale_slope = ds.RescaleSlope
            rescale_intercept = ds.RescaleIntercept
            slice_thickness = np.round(ds.SliceThickness, self.coordinate_precision)

        ## x runs along the columns, y along the rows
        x = np.round(np.arange(data.shape[2]) * xy_resolution[0] + image_position[0], self.coordinate_precision)
        y = np.round(np.arange(data.shape[1]) * xy_resolution[1] + image_position[1], self.coordinate_precision)

        ## Interpolate volume if multiple slice thicknesses were used
        if len(z_spacing) > 1:
            min_dz = np.round(min(z_spacing), self.coordinate_precision)
            min_z, max_z = np.round(np.min(z), self.coordinate_precision), np.round(np.max(z), self.coordinate_precision)
            self.__logger.warning(f"Multiple slice thicknesses were identified for the CT data of patient {self.patient_id}: {', '.join([str(i) for i in z_spacing])} mm")
            self.__logger.info(f'Interpolating CT data to achieve a uniform slice thickness of {min_dz}-mm')
            ### np.arange excludes its stop value; extend it by half a step so max_z is kept
            z_new = np.round(np.arange(min_z, max_z + min_dz / 2, min_dz), self.coordinate_precision)
            data = interpolate_volume(data, (z, y, x), (z_new, y, x),
                                      intMethod='linear', boundError=False, fillValue=0)
            z = z_new

        ## A single slice has no measurable spacing; use its nominal thickness instead
        dz = np.round(min(z_spacing), self.coordinate_precision) if z_spacing else slice_thickness
        dx, dy, dz = np.array(xy_resolution + [dz])
        coordinates = Coordinates(x, y, z, dx, dy, dz, image_position)

        ## Convert units to HU if specified
        if units == 'hu':
            data = data * rescale_slope + rescale_intercept
        else:
            units = 'original'

        ## Save a copy of the original CT information to help create the masks
        self.original_ct_coordinates = Coordinates(x, y, z, dx, dy, dz, image_position)
        self.original_ct_shape = data.shape

        return CT(data.shape, (dx, dy, dz), np.max(data), np.min(data), units, rescale_slope,
                  rescale_intercept, patient_position, data, slice_thickness, coordinates)
    
    def get_plan_info(self, dsP):
        """Extract plan-level and per-beam information from an RT (Ion) Plan dataset.

        Parameters
        ----------
        `dsP` : pydicom Dataset
            The RT Plan or RT Ion Plan dataset.

        Returns
        -------
        `Plan`
            Plan summary; `plan.beam` maps each treatment beam number to a `Beam`
            (setup beams are skipped). With several fraction groups, the values
            of the last group are kept.
        """
        plan = Plan()

        plan.geometry = dsP.RTPlanGeometry.lower() if hasattr(dsP, "RTPlanGeometry") else 'not_specified'
        patient_sex = getattr(dsP, 'PatientSex', '')
        plan.patient_sex = patient_sex.lower() if patient_sex else 'not_specified'
        plan.radiation_type = self.radiation_type.lower() if self.radiation_type else 'not_specified'
        plan.plan_label = dsP.RTPlanLabel.lower() if hasattr(dsP, "RTPlanLabel") else 'not_specified'

        # Patient position (does not depend on the fraction group)
        patient_positions = sorted({x.PatientPosition for x in getattr(dsP, 'PatientSetupSequence', [])})
        if len(patient_positions) > 1:
            self.__logger.warning(f'Multiple patient positions were identified for patient {self.patient_id}: {", ".join(patient_positions)}')
        elif patient_positions:
            plan.patient_position = patient_positions[0].lower()

        fraction_groups = list(getattr(dsP, 'FractionGroupSequence', []))
        for fgs in fraction_groups:
            plan.number_of_fractions_planned = int(fgs.NumberOfFractionsPlanned) if hasattr(fgs, "NumberOfFractionsPlanned") else 0
            plan.number_of_beams = int(fgs.NumberOfBeams)
            beam_doses = [float(rbs.BeamDose) if hasattr(rbs, "BeamDose") else 0.0
                          for rbs in getattr(fgs, 'ReferencedBeamSequence', [])]
            plan.dose_per_fraction = np.sum(beam_doses)

        # The Dose Reference Sequence belongs to the top level of the plan (RT Prescription
        # module); fraction-group level sequences are still read for backward compatibility.
        dose_references = list(getattr(dsP, 'DoseReferenceSequence', []))
        for fgs in fraction_groups:
            dose_references += list(getattr(fgs, 'DoseReferenceSequence', []))
        for drs in dose_references:
            if (getattr(drs, 'DoseReferenceStructureType', '') or '').lower() == 'site':
                plan.dose_reference_type = drs.DoseReferenceType.lower() if hasattr(drs, "DoseReferenceType") else 'not_specified'
                plan.dose_reference_description = drs.DoseReferenceDescription.lower() if hasattr(drs, "DoseReferenceDescription") else 'not_specified'
                # TargetPrescriptionDose is a numeric (DS) value, not text
                plan.dose_reference_dose = float(drs.TargetPrescriptionDose) if hasattr(drs, "TargetPrescriptionDose") else 0.0

        # Ion plans store their beams in IonBeamSequence / IonControlPointSequence
        use_ion_sequences = hasattr(dsP, 'IonBeamSequence')
        beam_sequence = dsP.IonBeamSequence if use_ion_sequences else dsP.BeamSequence

        for b in beam_sequence:
            if b.TreatmentDeliveryType.lower() == 'setup': continue  # skip setup beams
            cps = (b.IonControlPointSequence if use_ion_sequences else b.ControlPointSequence)[0]

            beam = Beam()
            beam.type = b.BeamType.lower()
            if hasattr(b, "VirtualSourceAxisDistances"):
                # kept in `sad` for backward compatibility and also stored in the dedicated `vsad` field
                beam.sad = b.VirtualSourceAxisDistances
                beam.vsad = np.array(b.VirtualSourceAxisDistances, dtype=float)
            else:
                beam.sad = float(b.SourceAxisDistance)

            beam.gantry_angle = cps.GantryAngle
            beam.patient_support_angle = cps.PatientSupportAngle
            beam.isocenter = cps.IsocenterPosition
            beam.treatment_delivery_type = b.TreatmentDeliveryType.lower()
            beam.treatment_machine = b.TreatmentMachineName.lower()
            plan.beam[int(b.BeamNumber)] = beam

        return plan
    
    def get_additional_details_from_plan_file(self, dsP, dose, bn):
        """Copy beam-level metadata for beam ``bn`` from an RT plan onto ``dose``.

        Parameters
        ----------
        dsP : pydicom Dataset
            RT (Ion) Plan dataset.
        dose : object
            Dose object whose attributes are filled in (``beam_dose``, ``beam_type``,
            ``gantry_angle``, ``sad``/``vsad``, ...).
        bn : int
            Beam number to look up in the plan.

        Returns
        -------
        object
            The same ``dose`` object, updated in place.
        """
        # TODO: VMAT plans have more than one angle. This needs to be handled.

        # grab the beam dose
        for fgs in dsP.FractionGroupSequence:
            for rbs in fgs.ReferencedBeamSequence:
                if int(rbs.ReferencedBeamNumber) == bn:
                    dose.beam_dose = rbs.BeamDose if hasattr(rbs, "BeamDose") else 0

        is_proton = self.radiation_type == 'proton'
        information_sequence = dsP.IonBeamSequence if is_proton else dsP.BeamSequence

        for b in information_sequence:
            if int(b.BeamNumber) != bn:
                continue

            # angles and isocenter are taken from the first control point only
            cps = b.IonControlPointSequence[0] if is_proton else b.ControlPointSequence[0]

            dose.beam_type = b.BeamType.lower()
            dose.radiation_type = b.RadiationType.lower()
            dose.beam_name = b.BeamName.lower()
            dose.beam_number = int(b.BeamNumber)
            dose.beam_description = b.BeamDescription.lower() if hasattr(b, "BeamDescription") else 'not_specified'
            dose.treatment_machine = b.TreatmentMachineName.lower()
            dose.final_cumulative_meterset_weight = b.FinalCumulativeMetersetWeight
            dose.scan_mode = b.ScanMode.lower() if hasattr(b, "ScanMode") else 'not_specified'
            dose.treatment_delivery_type = b.TreatmentDeliveryType.lower()
            dose.primary_dosimetric_units = b.PrimaryDosimeterUnit.lower()
            dose.number_of_control_points = int(b.NumberOfControlPoints)
            dose.gantry_angle = cps.GantryAngle
            dose.patient_support_angle = cps.PatientSupportAngle if hasattr(cps, "PatientSupportAngle") else 0.0
            dose.table_top_pitch_angle = cps.TableTopPitchAngle if hasattr(cps, "TableTopPitchAngle") else 0.0
            dose.table_top_roll_angle = cps.TableTopRollAngle if hasattr(cps, "TableTopRollAngle") else 0.0
            dose.isocenter = cps.IsocenterPosition

            if hasattr(b, "VirtualSourceAxisDistances"):
                dose.vsad = b.VirtualSourceAxisDistances
            else:
                dose.sad = b.SourceAxisDistance
                # Gantry Rotation Direction (300A,011F) is an attribute of the control
                # point, not of the beam; the beam is only checked as a fallback.
                rotation = getattr(cps, "GantryRotationDirection", None) or getattr(b, "GantryRotationDirection", None)
                dose.gantry_rotation_direction = rotation.lower() if rotation else 'not_specified'

            break  # beam numbers are unique within a plan

        return dose

    def parse_rt_dose_files(self, dose_files=None, plan_file=None, patient_id = None):
        """Parse the RT dose file(s) of a patient and, when available, the RT plan.

        Parameters
        ----------
        dose_files : list of str, optional
            Paths of the RT dose files. Replaced by the files found with
            ``run_initial_check`` when ``patient_id`` is given.
        plan_file : list of str, optional
            Paths of the RT plan files; only the first one is read. Replaced by the
            files found with ``run_initial_check`` when ``patient_id`` is given.
        patient_id : str or int, optional
            Patient whose files should be located automatically.

        Returns
        -------
        dict or tuple
            Dose objects keyed by beam number (``1`` for a single cumulative file).
            When a plan file is available, ``(dose, plan)`` where ``plan`` comes
            from ``get_plan_info``.

        Raises
        ------
        ValueError
            If none of the dose files could be used.
        """
        if patient_id is not None:
            self.patient_id = str(patient_id)
            patient_files = self.run_initial_check(self.patient_id)
            dose_files = patient_files['dose']
            plan_file = patient_files['plan']

        assert dose_files is not None or patient_id is not None

        # Read the plan once instead of once per dose file.
        has_plan = plan_file is not None and plan_file != []
        dsP = dcmread(plan_file[0]) if has_plan else None

        def referenced_beam_number(ds):
            # Referenced RT Plan > Referenced Fraction Group (300C,0020)
            # > Referenced Beam (300C,0004) > Referenced Beam Number (300C,0006)
            return int(ds.ReferencedRTPlanSequence[0][('300c','0020')][0][('300c','0004')][0][('300c','0006')].value)

        prec = self.coordinate_precision
        single_file = len(dose_files) == 1
        dose = {}
        for f in dose_files:
            with dcmread(f) as ds:

                if single_file: # handles the case for just one dose file
                    try: # check if the dose file is a beam-specific dose file
                        bn = referenced_beam_number(ds)
                    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
                        bn = 1 # if not, assume it is a cumulative dose file
                    self.__logger.info(f'Only one dose file was found for patient {self.patient_id}.')
                    self.__logger.info('Assuming that the dose file contains the cummulative dose for the plan.')
                elif ds.DoseSummationType.lower() == 'plan':
                    # A plan-summed file among beam-specific files has no beam number of
                    # its own; skip it instead of overwriting (or failing on) a beam entry.
                    self.__logger.warning(f'Skipping plan-summed dose file "{f}" for patient {self.patient_id}: beam-specific dose files are present.')
                    continue
                else: # handles the case for multiple dose files (beam-specific)
                    bn = referenced_beam_number(ds)

                # Grab data
                data = ds.pixel_array * ds.DoseGridScaling
                # Grab some data properties
                units = ds.DoseUnits
                # Pixel Spacing (0028,0030) is (row spacing, column spacing), i.e. (dy, dx)
                dy, dx = [np.round(float(v), prec) for v in ds.PixelSpacing]
                dose_grid_scaling = float(ds.DoseGridScaling)
                image_position = [np.round(float(i), prec) for i in ds.ImagePositionPatient]
                grid_offset_vector = np.round(np.array(ds.GridFrameOffsetVector), prec)

                # Prepare coordinates: columns run along x, rows along y
                x = np.round(np.arange(ds.Columns) * dx + image_position[0], prec)
                y = np.round(np.arange(ds.Rows) * dy + image_position[1], prec)
                z = np.round(grid_offset_vector + image_position[2], prec)
                ## Determine the slice spacings used for the dose data
                z_spacing = sorted(set(np.round(np.diff(z), prec)))
                ## Interpolate volume if multiple slice thicknesses were used
                if len(z_spacing) > 1:
                    min_dz = z_spacing[0]
                    min_z, max_z = np.round(np.min(z), prec), np.round(np.max(z), prec)
                    self.__logger.warning(f'Two or more slice thicknesses identified for the dose data of patient {self.patient_id}.')
                    self.__logger.info(f'Interpolating dose data to achieve a uniform slice thickness of {min_dz}-mm')
                    ### define new z-coordinates; np.arange excludes its stop value, so go
                    ### half a step past max_z to keep the last slice
                    z_new = np.round(np.arange(min_z, max_z + min_dz / 2, min_dz), prec)
                    data = interpolate_volume(data, (z, y, x), (z_new, y, x),
                                              intMethod='linear', boundError=False, fillValue=0)
                    z = z_new

                # A single-slice grid has no spacing between slices; fall back to SliceThickness
                dz = z_spacing[0] if z_spacing else float(getattr(ds, 'SliceThickness', None) or 0.0)
                coordinates = Coordinates(x, y, z, dx, dy, dz, image_position)

                # Prepare dose object
                dose_class = PhotonDose if self.radiation_type == 'photon' else ProtonDose
                dose[bn] = dose_class(data.shape, data.max(), data.min(), (dx, dy, dz),
                                      dose_grid_scaling, units, data, coordinates)

                # Grab additional information from plan file if available
                if dsP is not None:
                    dose[bn] = self.get_additional_details_from_plan_file(dsP, dose[bn], bn)

        if not dose:
            raise ValueError(f'No usable dose file was found for patient {self.patient_id}.')

        if len(set(d.data.shape for d in dose.values())) > 1:
            self.__logger.warning(f'Not all of the dose volumes have the same shape for patient {self.patient_id}.')
            if self.equalize_dose_grid_dimensions:
                self.__logger.info('Equalizing dose grid dimensions')
                dose = self.__equalize_dose_grid_dimensions(dose)

        # Save a copy of the original dose information to help create the masks.
        # The shape is taken from the stored (possibly equalized) dose, not from the
        # last raw array read in the loop.
        assert len(set(d.data.shape for d in dose.values())) == 1
        reference_dose = list(dose.values())[-1]
        self.original_dose_coordinates = copy.deepcopy(reference_dose.coordinates)
        self.original_dose_shape = reference_dose.data.shape

        if dsP is not None:
            plan = self.get_plan_info(dsP)
            return dose, plan
        return dose
        
    def __equalize_dose_grid_dimensions(self, dose):
        """Resample beam doses onto the grid of the largest dose volume.

        The reference is the first beam whose array is the largest along every
        axis. Every beam with a different shape is linearly interpolated onto the
        reference coordinates (``interpolate_volume`` with ``fillValue=0``), and
        its coordinates and shape are replaced by the reference ones.

        Parameters
        ----------
        dose : dict
            Dose objects keyed by beam number.

        Returns
        -------
        dict
            The same dict, with the mismatching entries updated in place.

        Raises
        ------
        ValueError
            If no single beam has the maximum size along all dimensions.
        """
        shapes = {bn: tuple(dose[bn].data.shape) for bn in dose}
        max_shape = tuple(int(s) for s in np.max(list(shapes.values()), axis=0))

        # first beam that is the largest along all axes
        max_dim_bn = next((bn for bn, shape in shapes.items() if shape == max_shape), None)
        if max_dim_bn is None:
            raise ValueError('Unable to find a beam with the maximum shape along all dimensions.')

        # each beam is listed once, even when it differs along several axes
        bn_to_correct = [bn for bn, shape in shapes.items() if shape != max_shape]

        reference = dose[max_dim_bn].coordinates
        interpolation_coordinates = (reference.z, reference.y, reference.x)
        for b in bn_to_correct:
            original_coordinates = (dose[b].coordinates.z, dose[b].coordinates.y, dose[b].coordinates.x)
            data = interpolate_volume(dose[b].data, original_coordinates, interpolation_coordinates,
                                      intMethod='linear', boundError=False, fillValue=0)

            dose[b].data = data
            dose[b].coordinates = copy.deepcopy(reference)
            dose[b].shape = data.shape

        return dose
                       
    def parse_structure_files(self, structure_files=None, patient_id=None, structure_names=None, resolution='dose', names_only=False, parallelize=True, n_threads=None):
        """Read RT structure files and build one mask per requested structure.

        Parameters
        ----------
        structure_files : list of str, optional
            Structure-set paths; located with ``run_initial_check`` when omitted.
        patient_id : str, optional
            Patient to process; defaults to ``self.patient_id``.
        structure_names : str or list of str, optional
            Structures to rasterise; defaults to every structure found (or to
            ``self.relevant_masks`` when that is set).
        resolution : str
            ``'ct'`` for masks on the CT grid, anything else for the dose grid.
        names_only : bool
            When True, only return the structure names.
        parallelize : bool
            Use a process pool when more than one worker is useful.
        n_threads : int, optional
            Number of workers; defaults to half of ``mp.cpu_count()``.

        Returns
        -------
        list of str or dict
            Structure names when ``names_only`` is True, otherwise a dict of
            ``Mask`` objects keyed by structure name.
        """
        assert patient_id is not None or self.patient_id is not None
        self.patient_id = patient_id or self.patient_id
        structure_files = structure_files or self.run_initial_check(self.patient_id)['structures']

        struc_ds = [dcmread(f) for f in structure_files]
        self.all_stucture_names = [c['name'] for ds in struc_ds for c in self.read_structure(ds)]

        if names_only: return self.all_stucture_names

        if self.original_ct_coordinates is None:
            self.parse_ct_study_files(patient_id=self.patient_id)

        if self.original_dose_coordinates is None:
            self.parse_rt_dose_files(patient_id=self.patient_id)

        coordinates = copy.deepcopy(self.original_ct_coordinates if resolution.lower() == 'ct' else self.original_dose_coordinates)

        if self.relevant_masks is not None:
            assert all(m in self.all_stucture_names for m in self.relevant_masks)
            self.all_stucture_names = self.relevant_masks

        # accept a single name as well as a list of names
        if isinstance(structure_names, str):
            structure_names = [structure_names]
        structure_names = list(structure_names or self.all_stucture_names)
        if not structure_names:
            return {}

        if parallelize:
            n_threads = n_threads or mp.cpu_count() // 2
        else:
            n_threads = 1
        # At least one worker (cpu_count()//2 is 0 on a single core) and no idle workers.
        n_threads = max(1, min(n_threads, len(structure_names)))

        if n_threads == 1:
            # Serial path: no pool, no pickling of self, logs go to this process.
            return self.process_structures_subset(struc_ds, structure_names, resolution, coordinates)

        structure_subsets = np.array_split(structure_names, n_threads)
        with Pool(processes=n_threads) as pool:
            args = [(struc_ds, subset.tolist(), resolution, coordinates) for subset in structure_subsets]
            results = pool.starmap(self.process_structures_subset, args)

        # Merging dictionaries from all processes
        return {k: v for d in results for k, v in d.items()}
    
    def process_structures_subset(self,
                                  struc_ds,
                                  subset,
                                  resolution,
                                  coordinates):
        """Build ``Mask`` objects for the structure names in ``subset``.

        Worker of ``parse_structure_files``; it is called directly or from a pool process.

        Parameters
        ----------
        struc_ds : list
            Structure-set datasets read with ``dcmread``.
        subset : list of str
            Structure names to rasterise.
        resolution : str
            Passed to ``get_mask`` and stored on every ``Mask``.
        coordinates : Coordinates
            Grid used for empty masks and stored on every ``Mask``.

        Returns
        -------
        dict
            ``Mask`` objects keyed by structure name. A name found in several
            structure files is de-duplicated with ``get_unique_structure_name``.
        """
        self.__logger.debug(f'Process {os.getpid()} building masks: {subset}')

        contours = {}

        # Names absent from every structure file get an empty mask. This is done once
        # here, rather than once per structure file.
        available = []
        for s in subset:
            if s in self.all_stucture_names:
                available.append(s)
            else:
                self.__logger.warning(f'A mask "{s}" was not found for patient-{self.patient_id}.')
                contours[s] = Mask(self.return_empty_mask(coordinates), resolution, coordinates)

        for ds in struc_ds:
            structures = self.read_structure(ds)
            info = {st['name']: {'number': st['number'], 'generation_algorithm': st['generation_algorithm']} for st in structures}
            for s in available:
                if s not in info: continue  # structure lives in another structure file

                assigned_name = self.get_unique_structure_name(s, contours.keys())
                data = self.get_mask(structures, s, resolution)
                contours[assigned_name] = Mask(data, resolution, coordinates, info[s]['number'], info[s]['generation_algorithm'])

        return contours
                        
    @staticmethod
    def get_unique_structure_name(structure_name, current_structure_names):
        """Return ``structure_name`` or, if it is taken, the first free ``name_<n>`` (n >= 1).

        Parameters
        ----------
        structure_name : str
            Desired name.
        current_structure_names : container of str
            Names already in use (anything supporting ``in``).

        Returns
        -------
        str
            A name not contained in ``current_structure_names``.
        """
        if structure_name not in current_structure_names: return structure_name
        structure_number = 1
        final_name = f'{structure_name}_{structure_number}'
        while final_name in current_structure_names:
            structure_number += 1
            final_name = f'{structure_name}_{structure_number}'
        return final_name
    
    @staticmethod
    def return_empty_mask(coordinates):
        """Allocate an all-zero uint8 volume matching a coordinate grid.

        Parameters
        ----------
        coordinates : Coordinates
            Object with 1-D ``z``, ``y`` and ``x`` coordinate arrays.

        Returns
        -------
        numpy.ndarray
            Zeros of shape ``(len(z), len(y), len(x))`` and dtype uint8.
        """
        grid_shape = tuple(getattr(coordinates, axis).shape[0] for axis in ('z', 'y', 'x'))
        return np.zeros(grid_shape, dtype=np.uint8)
    
    @staticmethod
    def read_structure(ds):
        """Auxiliary function for reading the content of the structure file.

        Each ROI Contour item is matched to its Structure Set ROI item through
        ``ReferencedROINumber`` -> ``ROINumber``. DICOM does not require the two
        sequences to be in the same order, so pairing them by position can
        attach the wrong name to a contour.

        Parameters
        ----------
        ds : pydicom object
            Handle for the file opened using the pydicom module

        Returns
        -------
        list of dict
            One dict per ROI with keys ``contour_data``, ``color`` (None if absent),
            ``number``, ``name`` (lower case) and ``generation_algorithm``
            (lower case, empty if absent). ROIs without a ContourSequence, or whose
            number has no Structure Set ROI entry, are skipped.
        """
        roi_by_number = {int(roi.ROINumber): roi for roi in getattr(ds, 'StructureSetROISequence', [])}

        contours = []
        for roi_contour in getattr(ds, 'ROIContourSequence', []):
            number = getattr(roi_contour, 'ReferencedROINumber', None)
            roi = roi_by_number.get(int(number)) if number is not None else None
            contour_sequence = getattr(roi_contour, 'ContourSequence', None)
            # nothing to rasterise, or no name to attach it to
            if roi is None or contour_sequence is None:
                continue

            contours.append({
                'contour_data': [s.ContourData for s in contour_sequence if hasattr(s, 'ContourData')],
                'color': getattr(roi_contour, 'ROIDisplayColor', None),
                'number': number,
                'name': str(roi.ROIName).lower(),
                'generation_algorithm': str(getattr(roi, 'ROIGenerationAlgorithm', '')).lower(),
            })

        return contours

    def get_mask(self, structures, name, resolution = 'ct', method = 'interpolate'):
        """Rasterise the contours of structure ``name`` into a mask volume.

        Parameters
        ----------
        structures : list of dict
            Output of ``read_structure``.
        name : str
            Structure name (compared case-insensitively).
        resolution : str
            ``'dose'`` returns the mask on the dose grid; otherwise the CT grid is used.
        method : str or None
            ``'interpolate'`` draws the contours on the CT grid and, for
            ``resolution == 'dose'``, interpolates the result onto the dose grid.
            Any other value uses the dose-grid geometry directly. ``None`` means
            ``self.mask_generation_method``.

        Returns
        -------
        numpy.ndarray
            Mask indexed as (z, row, column).
        """
        self.__logger.debug(f'Process ID: {os.getpid()} - {name}')

        method = self.mask_generation_method if method is None else method
        on_ct_grid = method == 'interpolate'
        grid = self.original_ct_coordinates if on_ct_grid else self.original_dose_coordinates

        # round the z coordinates to avoid floating point errors
        z = np.round(grid.z[:], self.coordinate_precision)
        z_min, z_max = z.min(), z.max()
        x_0, y_0 = grid.image_position[0], grid.image_position[1]
        dx, dy = grid.dx, grid.dy

        # allocate the mask volume
        shape = self.original_dose_shape if resolution == 'dose' and not on_ct_grid else self.original_ct_shape
        mask = np.zeros(shape, dtype=np.uint8)
        max_row, max_col = mask.shape[1] - 1, mask.shape[2] - 1

        ## Grab the structure data that matches the name of the desired mask
        contour_data = [s['contour_data'] for s in structures if s['name'].lower() == name.lower()][0]
        structure_data = [np.array(i).reshape(-1, 3) for i in contour_data]

        for nodes in structure_data:

            z_node = np.round(nodes[0, 2], self.coordinate_precision)
            if not (z_min <= z_node <= z_max): # ignore slices outside of the volume
                continue

            slice_indices = np.flatnonzero(z == z_node)
            if slice_indices.size == 0:
                # contour plane does not coincide with any slice of this grid
                self.__logger.warning(f'Contour of "{name}" at z={z_node} matches no slice for patient {self.patient_id}; it was skipped.')
                continue
            z_index = slice_indices[0]

            # Fractional (row, column) vertex positions, clamped to the grid on both
            # sides so that no negative pixel index (which numpy would wrap around to
            # the opposite edge) reaches the mask.
            r = np.clip((nodes[:, 1] - y_0) / dy, 0, max_row)
            c = np.clip((nodes[:, 0] - x_0) / dx, 0, max_col)

            rr, cc = polygon(r, c)
            mask[z_index, rr, cc] += 1

        # Even-odd rule: a pixel is inside when covered by an odd number of contours.
        # This removes holes (count 2) and keeps islands drawn inside holes (count 3).
        mask %= 2

        if resolution == 'dose' and on_ct_grid:
            oc = (self.original_ct_coordinates.z, self.original_ct_coordinates.y, self.original_ct_coordinates.x)
            ic = (self.original_dose_coordinates.z, self.original_dose_coordinates.y, self.original_dose_coordinates.x)
            return interpolate_volume(mask, oc, ic, intMethod=self.mask_interpolation_technique, boundError=0, fillValue=0)

        return mask
        
    def convert_hu_to_rsp(self, ct=None, in_place=True, interpolation_kind = 'linear', lut_directory=None):
        """Convert CT Hounsfield units to relative stopping power (RSP) with a CSV lookup table.

        Parameters
        ----------
        ct : object, optional
            Object exposing a ``data`` array of HU values; defaults to ``self.ct``.
        in_place : bool
            When True, replace ``ct.data`` with the RSP volume and return None.
            Otherwise leave ``ct`` untouched and return the RSP volume.
        interpolation_kind : str
            ``kind`` argument of ``scipy.interpolate.interp1d``.
        lut_directory : str, optional
            Folder containing ``mda_relative_stopping_power.csv`` (columns ``HU``
            and ``rsp``). When given, it is stored on ``self.rsp_lut_directory``.

        Returns
        -------
        numpy.ndarray or None
            The RSP volume when ``in_place`` is False.

        Raises
        ------
        ValueError
            If no LUT directory is configured or the LUT cannot be read.
        """
        # Read LUT for HU to RSP conversion
        if lut_directory is not None: self.rsp_lut_directory = lut_directory
        if self.rsp_lut_directory is None:
            self.__logger.error('No LUT directory was provided for the HU to RSP conversion.')
            raise ValueError('No LUT directory was found or provided for the HU to RSP conversion.')

        rsp_lut_dir = os.path.join(self.rsp_lut_directory, 'mda_relative_stopping_power.csv')

        try:
            HU_2_RSP_LUT = pandas.read_csv(rsp_lut_dir)
        except Exception as err:
            self.__logger.error(f'Error reading the HU to RSP LUT from "{rsp_lut_dir}". This should be a CSV file.')
            raise ValueError(f'Error reading the HU to RSP LUT from "{rsp_lut_dir}".') from err
        HU = HU_2_RSP_LUT['HU'].values
        rsp = HU_2_RSP_LUT['rsp'].values

        # create interpolating function
        hu_to_rsp = interp1d(HU, rsp, kind = interpolation_kind)

        if ct is None: ct = self.ct

        # Keep HU values inside the LUT range. np.clip returns a copy, so ct.data is
        # only changed below when in_place is True.
        clipped_hu = np.clip(ct.data, HU.min(), HU.max())

        # Use interpolating function to convert the CT array
        rspVol = hu_to_rsp(clipped_hu.ravel()).reshape(clipped_hu.shape)

        if in_place:
            ct.data = rspVol
        else:
            return rspVol

    def __get_patient_list(self, patient_data_directory):
        """List the patients available at ``patient_data_directory``.

        A directory is scanned with ``identify_patient_files``. A ``.h5``/``.hdf5``
        file yields its top-level group names. Anything else raises ``Exception``.
        """
        if os.path.isdir(patient_data_directory):
            return self.identify_patient_files(patient_data_directory)

        extension = patient_data_directory.rpartition('.')[2]
        is_hdf5_file = os.path.isfile(patient_data_directory) and extension in ['h5', 'hdf5']
        if not is_hdf5_file:
            raise Exception (f"Unable to work with the patient data directory: {patient_data_directory}")

        with h5py.File(patient_data_directory, 'r') as hf:
            return list(hf.keys())

    def get_dicom_data_report(self, patient_list=None, save = True):
        """Collect a per-patient summary of the DICOM data (beams, grids, contours, ...).

        Parameters
        ----------
        patient_list : list or str or int, optional
            Patients to summarise. Defaults to every patient found in
            ``self.patient_data_directory``.
        save : bool
            When True, the report is pickled to ``temp/data/basic_data_report.pickle``.

        Returns
        -------
        dict
            ``{patient_id: {property: value}}``. A patient whose data could not be
            analysed maps to an empty dict, and the error is logged.
        """
        # grab list of patient files to explore
        if patient_list is None:
            patient_list = self.__get_patient_list(self.patient_data_directory)
        else:
            if not isinstance(patient_list, list): patient_list = [patient_list]
            patient_list = sorted([str(n) for n in patient_list])

        # check if parallel processing should be used
        # NOTE(review): the original statement was the uncalled, misspelled attribute
        # access `self.identify_parallel_capabilitie`; confirm the method name.
        if self.parallelize is None:
            self.identify_parallel_capabilities()

        temp_dir = os.path.join('temp', 'data')
        os.makedirs(temp_dir, exist_ok=True)

        def get_report(my_patients, process_id=None, save_file = True):
            """Summarise ``my_patients`` and optionally pickle the result."""
            data_report = {p:{} for p in my_patients}

            # a progress bar is only shown in serial mode
            if not self.parallelize and self.echo_progress:
                my_patients = tqdm(my_patients, desc="Generating basic report", leave=True)

            for p in my_patients:

                self.reset()

                try:
                    self.parse_dicom_files(p, mask_names_only=True) # parse the data and grab contour names only
                    beams = list(self.plan.beam.keys())
                    data_report[p]['number_of_beams'] = len(beams) if beams != [] else 1
                    data_report[p]['radiation_type'] = self.radiation_type
                    data_report[p]['gantry_angles'] = [self.plan.beam[b].gantry_angle for b in beams]
                    data_report[p]['couch_angles'] = [self.plan.beam[b].patient_support_angle for b in beams]
                    data_report[p]['ct_array_dimensions'] = self.ct.data.shape
                    data_report[p]['dose_array_dimensions'] = self.dose[beams[0]].data.shape
                    data_report[p]['dose_array_resolution'] = {'dx':self.dose[beams[0]].coordinates.dx,
                                                               'dy':self.dose[beams[0]].coordinates.dy,
                                                               'dz':self.dose[beams[0]].coordinates.dz}
                    if self.radiation_type == 'proton':
                        data_report[p]['vsad'] = [self.plan.beam[b].vsad for b in beams]
                    else:
                        data_report[p]['sad'] = [self.plan.beam[b].sad for b in beams]
                        data_report[p]['gantry_rotation_direction'] = self.dose[beams[0]].gantry_rotation_direction
                    data_report[p]['isocenter'] = [self.plan.beam[b].isocenter for b in beams]
                    data_report[p]['contours'] = ','.join(sorted(self.contours))
                    data_report[p]['dose_reference_dose']= self.plan.dose_reference_dose
                    data_report[p]['dose_reference_type']= self.plan.dose_reference_type
                    data_report[p]['dose_reference_description']= self.plan.dose_reference_description
                    data_report[p]['beam_type'] = list(set([self.plan.beam[b].type for b in beams]))
                    data_report[p]['patient_position'] = self.ct.patient_position

                except Exception:
                    self.__logger.error(f"The data for pat-{p} could not be analyzed")
                    self.__logger.error(traceback.format_exc())

            if save_file:
                tag = f'_{process_id}' if process_id is not None else ''
                with open(os.path.join(temp_dir, f'basic_data_report{tag}.pickle'), 'wb') as handle:
                    pickle.dump(data_report, handle, protocol=pickle.HIGHEST_PROTOCOL)

            return data_report

        if self.parallelize:

            self.__logger.info(f'Parallelizing the data report generation using {self.n_threads} threads.')

            # divide the patients into groups based on the number of available threads
            chunked_patients = [x for x in np.array_split(patient_list, self.n_threads) if x.size != 0]

            # A Process cannot return a value, so every worker always writes its chunk
            # to disk (whatever `save` is). These files are merged and deleted below.
            # NOTE(review): a nested target function requires the 'fork' start method.
            processes = []
            for n, patients in enumerate(chunked_patients):
                process = mp.Process(target=get_report, args=(patients, n, True))
                process.start()
                processes.append(process)

            for process in processes:
                process.join()

            # merge data reports and remove temp files
            data_report = {}
            for n in range(len(chunked_patients)):
                chunk_file = os.path.join(temp_dir, f'basic_data_report_{n}.pickle')
                if not os.path.exists(chunk_file):
                    self.__logger.error(f'The report of worker {n} is missing; its patients are not included.')
                    continue
                with open(chunk_file, 'rb') as handle:
                    data_report.update(pickle.load(handle))
                os.remove(chunk_file)

            # save merged data report
            if save:
                with open(os.path.join(temp_dir, 'basic_data_report.pickle'), 'wb') as handle:
                    pickle.dump(data_report, handle, protocol=pickle.HIGHEST_PROTOCOL)

        else:
            data_report = get_report(patient_list, save_file = save)

        self.__logger.info(f"Finished generating basic report for all detected patients.")

        return data_report
    
    def infer_rx_dose(self, patient_list=None, min_infered_rx_dose = 20, get_dose_statistics = True):
        """Infer prescription doses from target-structure names and summarise target dose.

        Results are written to ``temp/data/target_properties.json`` and
        ``temp/data/rx_dose_info.csv``.

        Parameters
        ----------
        patient_list : list or str or int, optional
            Patients to process. Defaults to every patient found in
            ``self.patient_data_directory``.
        min_infered_rx_dose : float
            Numbers parsed from target names below this value are not treated as
            prescription doses.
        get_dose_statistics : bool
            When True, D95, D98, D2, mean and max cumulative dose are computed for
            every target volume.
        """
        if patient_list is None:
            patient_list = self.__get_patient_list(self.patient_data_directory)
        else:
            if not isinstance(patient_list, list): patient_list = [patient_list]
            patient_list = sorted([str(n) for n in patient_list])

        target_prefix = self.user_inputs["TYPE_OF_TARGET_VOLUME"]
        target_properties = {p : {} for p in patient_list} # TODO: change this to dose information
        rx_dose_info = {'AID': list(patient_list), 'Rx Dose':[], 'Plan Type':[], 'dose_scale':[], 'site':[]}

        patients = tqdm(patient_list, desc="Infering prescription doses:", leave=True) if self.echo_progress else patient_list

        for p in patients:

            self.reset()
            self.parse_dicom_files(p, mask_names_only=True)

            # Identify target volumes: names containing the prefix are "target-like",
            # names starting with it are treated as targets
            target_like_structures = [x for x in self.contours if target_prefix in x]
            target_volumes = [x for x in target_like_structures if x.startswith(target_prefix)]

            # Infer the prescription dose from the target names
            rx = [self.__find_posible_rx_dose(x) for x in target_volumes]
            rx = list(set([str(x) for x in rx if float(x) >= min_infered_rx_dose]))
            rx = [x for x in rx if x != '0']
            if len(rx) > 1: # if multiple rx doses are found, check if they are multiples of each other
                rx = [float(x)/100 if float(x) > 100 else float(x) for x in rx]
                rx = [str(x) for x in set(rx)]
            rx_text = ','.join(rx) if rx != [] else '0'

            target_properties[p]['infered_rx_dose'] = rx_text
            target_properties[p]['all_target_like_structures'] = ','.join(target_like_structures)
            rx_dose_info['Rx Dose'].append(rx_text)
            rx_dose_info['Plan Type'].append('Unknown')
            rx_dose_info['dose_scale'].append(1.0)
            rx_dose_info['site'].append('Unknown')

            if not get_dose_statistics:
                continue

            cum_dose = self.cumulative_dose

            for t in target_volumes:
                # parse the target mask on the dose grid
                target_data = self.parse_structure_files(structure_names=[t], resolution='dose')[t].data
                dose_in_mask = cum_dose[target_data > 0]

                if dose_in_mask.size == 0:
                    self.__logger.warning(f'Target "{t}" of patient {p} covers no dose voxels; dose statistics were skipped.')
                    target_properties[p][t] = {}
                    continue

                # Dx (dose received by x% of the volume) is the (100 - x)th percentile.
                # float() keeps the values JSON-serialisable whatever the array dtype.
                target_properties[p][t] = {
                    'D95': float(np.percentile(dose_in_mask, 5)),
                    'D98': float(np.percentile(dose_in_mask, 2)),
                    'D2': float(np.percentile(dose_in_mask, 98)),
                    'mean_dose': float(np.mean(dose_in_mask)),
                    'max_dose': float(np.max(dose_in_mask)),
                }

        output_dir = os.path.join('temp', 'data')
        os.makedirs(output_dir, exist_ok=True)

        # save target properties as json nicely formatted
        with open(os.path.join(output_dir, 'target_properties.json'), 'w') as f:
            json.dump(target_properties, f, indent=4)

        # save rx dose information as csv
        pandas.DataFrame(rx_dose_info).to_csv(os.path.join(output_dir, 'rx_dose_info.csv'), index=False)
                              
    def __find_posible_rx_dose(self, string):
        """Guess a prescription dose from a target structure name.

        The configured target prefix is removed first. The largest number left in
        the name is returned as the string it appears as; among numerically equal
        matches, the last occurrence is kept. Returns ``0`` when no number is
        present or when the name contains ``'mm'`` (a millimetre value is
        presumably a margin, not a dose).
        """
        remainder = string.replace(self.user_inputs["TYPE_OF_TARGET_VOLUME"], '')
        if 'mm' in remainder:
            return 0
        number_pattern = r'\d+(?:\.\d+)?'  # whole numbers and decimal numbers
        candidates = re.findall(number_pattern, remainder)
        if not candidates:
            return 0
        # max() returns the first maximal item, so scanning in reverse keeps the last one
        return max(reversed(candidates), key=float)

    def store_patient_data_as_hdf5(self, patient_data_directory=None, patient_list=None, mask_resolution = None, output_directory=None):
        """Parse each patient's CT, dose and structures and append them to one HDF5 file.

        Parameters
        ----------
        patient_data_directory : str, optional
            Replaces ``self.patient_data_directory`` when given.
        patient_list : iterable, optional
            Patient IDs to store; defaults to ``self.identify_patient_files()``.
        mask_resolution : str, optional
            Grid for the masks, ``'dose'`` by default.
        output_directory : str, optional
            Path of the HDF5 file itself (despite the name); defaults to
            ``temp/data/patient_data.h5``.
        """
        self.mask_resolution = mask_resolution if mask_resolution is not None else 'dose'

        if patient_data_directory is not None: self.patient_data_directory = patient_data_directory

        if output_directory is None:
            output_directory = os.path.join('temp','data', 'patient_data.h5')

        # h5py cannot create the file when its parent folder does not exist
        parent_dir = os.path.dirname(output_directory)
        if parent_dir: os.makedirs(parent_dir, exist_ok=True)

        # Create (or truncate) the h5 file
        if self.write_new_hdf5_file:
            with h5py.File(output_directory, 'w'):
                pass

        if patient_list is None: patient_list = self.identify_patient_files()

        if self.echo_progress: patient_list = tqdm(patient_list, desc="Saving HDF5 patient file", leave=True)

        for p in patient_list:

            # HDF5 dataset names are built with '/'.join, which needs a str ID
            self.patient_id = str(p)
            files = self.run_initial_check(self.patient_id)

            self.ct = self.parse_ct_study_files(files['ct'])

            self.append_data_to_hdf5_file(output_directory, data_type='ct')

            # Parse the dose volume and (optionally) the plan
            if 'plan' in files.keys() and files['plan'] != []:
                self.dose, self.plan = self.parse_rt_dose_files(files['dose'], files['plan'])
            else:
                self.dose = self.parse_rt_dose_files(files['dose'])
                self.plan = Plan()

            self.append_data_to_hdf5_file(output_directory, data_type='dose')

            # Parse the contours, one structure file at a time
            for f in files['structures']:

                self.contours = self.parse_structure_files([f], resolution=self.mask_resolution)

                self.append_data_to_hdf5_file(output_directory, 'contours')

            # NOTE(review): this re-parses the whole patient after its data has been
            # written; confirm it is still needed.
            self.parse_dicom_files(self.patient_id)

        self.__logger.info(f"The patient data was stored as an HDF5 file in {output_directory}.")
    
    def append_data_to_hdf5_file(self, output_directory, data_type):
        """Append the data for each patient to an HDF5 file for faster I/O

        Everything is written below the group named after `self.patient_id`.

        Parameters
        ----------
        `output_directory` : str
            Path of the HDF5 file the patient data is appended to (opened in 'a' mode).
        `data_type` : str
            Part of the patient data to write:

            - 'ct': `self.ct` as the `<patient>/ct` dataset, with its metadata as attributes.
            - 'dose': one `<patient>/dose/<beam>` dataset per beam of `self.dose` (proton
              only for now) and the `<patient>/plan` group holding the `self.plan` metadata.
            - 'contours': one `<patient>/contours/<structure_set>/<name>` dataset per mask in
              `self.contours`. '/' in contour names is replaced by '-' so the name does not
              create nested HDF5 groups.

            Any other value is logged as a warning and nothing is written.
        """
        # Attributes describing the grid of a volume (shared by the CT and the per-beam dose).
        coordinate_fields = ('x', 'y', 'z', 'dx', 'dy', 'dz', 'image_position')

        if data_type not in ('ct', 'dose', 'contours'):
            self.__logger.warning(f"Unknown data type '{data_type}'; nothing was written to {output_directory}.")
            return

        with h5py.File(output_directory, 'a') as f:

            if data_type == 'ct':
                # CT data
                dset = f.create_dataset(name='/'.join([self.patient_id, 'ct']), data=self.ct.data, compression=self.compression)
                for field in ('units', 'rescale_slope', 'rescale_intercept', 'resolution'):
                    dset.attrs[field] = getattr(self.ct, field)
                for field in coordinate_fields:
                    dset.attrs[field] = getattr(self.ct.coordinates, field)

            elif data_type == 'dose':
                # Dose data (previously compared against the misspelt 'pronton', so it was never written)
                if self.radiation_type.lower() == 'proton':
                    beam_fields = ('dose_grid_scaling', 'dose_units', 'beam_type', 'vsad',
                                   'number_of_control_points', 'gantry_angle', 'patient_support_angle',
                                   'table_top_pitch_angle', 'table_top_roll_angle', 'isocenter', 'resolution')
                    for b in self.dose.keys():
                        beam_dose = self.dose[b]
                        dset = f.create_dataset(name='/'.join([self.patient_id, 'dose', str(b)]), data=beam_dose.data, compression=self.compression)
                        for field in beam_fields:
                            dset.attrs[field] = getattr(beam_dose, field)
                        for field in coordinate_fields:
                            dset.attrs[field] = getattr(beam_dose.coordinates, field)
                        # `read_data_from_hdf5` looks up a 'modality' attribute; store it when available
                        # (HDF5 attributes cannot hold None).
                        beam_modality = getattr(beam_dose, 'modality', None)
                        if beam_modality is not None:
                            dset.attrs['modality'] = beam_modality

                #TODO: add values for photon dose

                # Plan data
                plan = f.create_group(name=f'{self.patient_id}/plan')
                for field in ('number_of_beams', 'plan_label', 'patient_sex', 'plan_name'):
                    plan.attrs[field] = getattr(self.plan, field)
                plan_modality = getattr(self.plan, 'modality', None)
                if plan_modality is not None:
                    plan.attrs['modality'] = plan_modality

            else:  # data_type == 'contours'
                for c in self.contours.keys():
                    name = str(c).replace("/", "-")
                    dset = f.create_dataset(name='/'.join([self.patient_id, 'contours', self.contours[c].structure_set, str(name)]),
                                            data=self.contours[c].data, compression=self.compression)
                    dset.attrs['resolution'] = self.contours[c].resolution

    def read_data_from_hdf5(self, patient_id = None, contours_type='clinical', hdf5_file=None):
        """Read patient data from HDF5 files generated by this class using  the `store_patient_data_as_hdf5` method. 
        This allows for fast I/O but increases the hard-drive memory consumption by creating (in many cases) a large 
        HDF5 file with all of the patient data. Please ensure that enough memory is available when using this 
        method of data wrangling.

        Sets `self.plan`, `self.ct`, `self.dose` (dict keyed by integer beam number),
        `self.contours`, `self.original_ct_coordinates` and `self.original_dose_coordinates`.

        Parameters
        ----------
        `patient_id` : str, optional
            ID of the patient data to read, by default None (keeps `self.patient_id`).
        `contours_type` : str, optional
            Structure-set group to read under `<patient>/contours`, by default 'clinical'.
            With 'auto', the clinical CTVs (names containing 'ctv' but not 'fsctv'/'pctv')
            are added as well.
        `hdf5_file` : str, optional
            Path of the HDF5 file to read. By default the `raw_patient_data` entry of the
            DIRECTORIES section of the user configuration is used (previous behaviour).
        """

        def coordinates_from(attrs):
            # Rebuild the grid description stored as dataset attributes.
            return Coordinates(attrs['x'], attrs['y'], attrs['z'],
                               attrs['dx'], attrs['dy'], attrs['dz'],
                               attrs['image_position'])

        def mask_from(dset):
            # Masks on the CT grid use the CT coordinates; any other resolution uses the dose grid.
            resolution = dset.attrs['resolution']
            coordinates = self.original_ct_coordinates if resolution == 'ct' else self.original_dose_coordinates
            return Mask(dset[...], resolution, coordinates)

        if patient_id is not None: self.patient_id = patient_id

        if hdf5_file is None:
            hdf5_file = self.user_inputs["DIRECTORIES"]["raw_patient_data"]

        with h5py.File(hdf5_file, 'r') as f:

            patient = f[f'{self.patient_id}']

            # Plan. 'modality' is optional because files written without it must still be readable.
            plan_attrs = patient['plan'].attrs
            self.plan = Plan(plan_attrs['number_of_beams'], plan_attrs['patient_sex'],
                             plan_attrs['plan_label'], plan_attrs['plan_name'],
                             plan_attrs.get('modality'))

            # CT: read the voxel array once and derive shape/max/min from it
            ct_dset = patient['ct']
            ct_data = ct_dset[...]
            ct_attrs = ct_dset.attrs
            self.original_ct_coordinates = coordinates_from(ct_attrs)
            self.ct = CT(ct_data.shape, ct_attrs['resolution'], ct_data.max(), ct_data.min(),
                         ct_attrs['units'], ct_attrs['rescale_slope'], ct_attrs['rescale_intercept'],
                         ct_data, self.original_ct_coordinates)

            # Dose: one dataset per beam
            self.dose = {}
            for beam_name in patient['dose'].keys():
                dose_dset = patient['dose'][beam_name]
                dose_data = dose_dset[...]
                dose_attrs = dose_dset.attrs
                self.original_dose_coordinates = coordinates_from(dose_attrs)
                self.dose[int(beam_name)] = Dose(dose_data.shape, dose_data.max(), dose_data.min(),
                                                 dose_attrs['resolution'], dose_attrs['dose_grid_scaling'],
                                                 dose_attrs['dose_units'], dose_data,
                                                 self.original_dose_coordinates,
                                                 dose_attrs.get('modality'),
                                                 dose_attrs['beam_type'], dose_attrs['vsad'],
                                                 dose_attrs['number_of_control_points'],
                                                 dose_attrs['gantry_angle'],
                                                 dose_attrs['patient_support_angle'],
                                                 dose_attrs['table_top_pitch_angle'],
                                                 dose_attrs['table_top_roll_angle'],
                                                 dose_attrs['isocenter'])

            #TODO: differentiate between photon and proton dose

            # Contours
            contour_group = patient[f'contours/{contours_type}']
            self.contours = {k: mask_from(contour_group[k]) for k in contour_group.keys()}

            # Grab CTVs in case the desired contours are automatic
            if contours_type == 'auto':
                clinical_group = patient['contours/clinical']
                for k in [c for c in clinical_group.keys() if 'ctv' in c and 'fsctv' not in c and 'pctv' not in c]:
                    self.contours[k] = mask_from(clinical_group[k])

    @property
    def cumulative_dose(self):
        """Total dose of the plan, obtained by adding up the dose of every beam.

        Returns
        -------
        ndarray
            Element-wise sum of the `data` arrays of all entries in `self.dose`.
        """
        beam_volumes = [self.dose[beam].data for beam in self.dose.keys()]
        return np.sum(beam_volumes, axis=0)

    @property
    def identify_parallel_capabilities(self):
        """Decide whether preprocessing runs in parallel and with how many threads.

        Reads the PARALLELIZATION section of the user configuration and sets
        `self.n_threads` and `self.parallelize`:

        - `number_of_processors` of None defaults to half of the virtual CPUs.
        - Other values are converted with `int()` (the configuration may hold a string)
          and capped at the number of virtual CPUs.
        - Parallel processing is enabled only when `parallelize_data_preprocessing` is
          truthy and more than one thread is available; otherwise `n_threads` is 1.

        This is a property with side effects; it returns None.
        """
        settings = self.user_inputs['PARALLELIZATION']
        cpu_count = mp.cpu_count()
        requested = settings['number_of_processors']

        if requested is None:
            n_threads = cpu_count // 2  # assumes that half of the CPUs are available for parallel processing
        else:
            # Normalise before comparing: a string value would otherwise raise a TypeError.
            n_threads = int(requested)
            if n_threads > cpu_count:
                self.__logger.warning(f"The number of threads ({n_threads}) exceeds the number of available CPUs ({cpu_count}).")
                n_threads = cpu_count

        self.parallelize = bool(settings['parallelize_data_preprocessing']) and n_threads > 1

        if not self.parallelize:
            self.n_threads = 1
            self.__logger.info("Preprocessing will be performed sequentially. Consider using parallel processing to speed up the process.")
            return

        self.n_threads = n_threads
        # log the available resources
        self.__logger.info(f"Number of CPUs (Virtual): {cpu_count}")
        self.__logger.info(f"Parallel processing will be used with {self.n_threads} threads")
    
 
# Parse a single patient so its data can be inspected interactively below.
patient_ID = 1
user_inputs_dir = "configuration_files/user_config.json"
dt = DicomToolbox(user_inputs_dir)
dt.parse_dicom_files(patient_ID, mask_resolution='dose')


In [None]:
# Names (keys) of the contours parsed for the patient loaded above.
dt.contours.keys()

In [None]:
import sys
sys.path.insert(0, "./rt_utils")

import RTStructBuilder
import matplotlib.pyplot as plt




**Creating a basic report**

Several of the routines used for preprocessing rely on a *basic report* that can be created with the `DicomToolbox` class. The basic report is created by executing the `get_dicom_data_report()` method. This is normally done automatically by the classes and should be renewed for every new dataset. Currently, the report needs to be deleted manually when the data changes. In the next version, we will add a check to see if the data has changed and renew the report automatically.


The report is helpful to quickly inspect a set of patients. You can pass a list of patients with the `patient_list` argument if you want to run this method on a specific set. If a list of patients is not given, the method uses `identify_patient_files()` to detect all patients with a complete set of DICOM-RT files. This method returns a dictionary with the result of the basic report. The result is also saved as a pickle file in the temporary folder. 

**Important**: If you do not have the prescription values for the patients, you can set the `infer_rx_dose` flag to `True` when calling the `create_basic_report()` method. This will infer the prescription dose from the structure names. This is not a perfect method, but it is helpful when you do not have the prescription dose for the patients. You can further refine the result by manually editing the CSV file produced when inferring the prescription dose.


In [None]:
# Basic report for all patients (a dictionary keyed by patient); used by several cells below.
basic_report = dt.get_dicom_data_report()

# To restrict the report to specific patients:
# basic_report = dt.get_dicom_data_report(patient_list=['67'])

In [None]:
# Infer the prescription dose from the structure names (useful when prescriptions are not available).
dt.infer_rx_dose()

In [None]:
# Beam type recorded in the basic report for each patient
for patient_key, patient_report in basic_report.items():
    print(patient_key, patient_report['beam_type'])

For your convenience, you can quickly check the names of potential target volumes for the patients. The result will be saved in a file called `target_report.json` in the `temp/data` directory. Note that the candidate target volumes are found based on the value you specify for `TYPE_OF_TARGET_VOLUME` in the configuration JSON file. Once the file is generated, you can open it and compare it to the file produced when inferring the prescription dose. This should help you clean the prescription dose file.

In [None]:
patient_ids = sorted(basic_report.keys(), key=natural_keys)

# Split the comma-separated candidate names. Surrounding whitespace is stripped and empty entries
# are dropped, so a patient without candidates maps to [] instead of [''].
target_report = {
    p: [name.strip() for name in basic_report[p]['candidate_target_volumes'].split(",") if name.strip()]
    for p in patient_ids
}

report_folder = os.path.join('temp', 'data')
os.makedirs(report_folder, exist_ok=True)
with open(os.path.join(report_folder, 'target_report.json'), 'w') as fp:
    json.dump(target_report, fp, indent=4)

### Storing the patient data as a single HDF5 file

**Not available for Photon data yet**

In some circumstances, it may be convenient to save the data as a single HDF5 file to expedite the development of new routines. This is because the I/O of patient data from an HDF5 file is many times faster than reading data from raw DICOM files in Python. The `DicomToolbox` class was designed to know how to obtain all of the relevant patient information from the HDF5 files it creates. These files contain all of the details needed by other classes that later perform the steps involved in preprocessing and creating the training data.

In [None]:
# dt.store_patient_data_as_hdf5()

In [None]:
# import h5py

# with h5py.File('temp/data/patient_data.h5', 'r') as hf:
#     print(hf['16/contours/clinical'].keys())
#     ct = hf['16/ct'][...]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Measured pairs: dose (cGy) and the corresponding reading
dose = np.array([25.03309, 25.03309, 50.06618, 50.06618, 75.09927, 75.09927, 100.1324, 100.1324, 149.8357, 149.8357, 199.9019, 249.9681, 249.9681, 300.0343, 350.1005, 350.1005])
reading = np.array([88918, 89803, 179272, 180574, 271044, 272178, 364934, 361838, 551177, 555773, 744155, 949256, 943599, 1133267, 1354215, 1366654])

# Least-squares straight line: reading = a * dose + b
coeffs = np.polyfit(dose, reading, deg=1)
fit = np.polyval(coeffs, dose)
slope, intercept = coeffs

# Measured points together with the fitted line
fig, ax = plt.subplots()
ax.scatter(dose, reading, label='Data')
ax.plot(dose, fit, 'r-', label='Fit: a=%5.3f, b=%5.3f' % (slope, intercept))
ax.set_xlabel('Dose (cGy)')
ax.set_ylabel('Reading')
ax.legend()
plt.show()

---
## The DataExplorer package

This class was developed to help explore a set of patient folders. The `DataExplorer` class inherits from the `DicomToolbox` class and can, therefore, use its functionality. 

In [None]:
from data_explorer import DataExplorer

In [None]:
# DataExplorer configured from the same user configuration file as DicomToolbox.
de = DataExplorer(user_inputs_dir)

### Creating a human friendly report for the DICOM data
The `create_data_report()` method in the `DataExplorer` class generates a "csv" file containing a report of the patient data. This report is also stored as a class variable for other methods to use. For example, the method used to filter the patients relies on this report to detect patients with desirable characteristics.

In [None]:
# Build the patient data report (kept on the instance for later filtering); save=True also writes the CSV.
de.create_data_report(save=True)

### Filtering Patients
The `DataExplorer` class can filter patients based on the properties of their data, like the number of beams used in the plan and the geometries of the beams. 

In [None]:
# Patients that pass the DataExplorer filters; `selected` is reused by later cells.
selected = de.apply_patient_filter()
print('No. of selected patients: {}'.format(len(selected)))

### Generate a report of the available contours

With this class, a comprehensive report of the available contours can be generated. This report is particularly useful when adding new patients to the data pool since it allows you to compare the name of the contours that were selected with those available for each patient. 

In [None]:
# Generate the report of the contours available for the patients.
de.generate_contour_report()

Run the cell below for an example of how to find targets for a single patient.

In [None]:
# Example: find the target volumes for patient 159.
de.find_targets(pat_id=159)

Run the cell below for an example of how to find OARs for a single patient.

In [None]:
# Example: find the organs at risk (OARs) for patient 6.
de.find_oars(pat_id=6)

The output of this method consists of four JSON files describing the available and selected contours for the patients. The files are saved in the `temp/data` folder. You should carefully assess the quality of the picks before performing preprocessing. Since it is unlikely that we have captured all variations for the names of the contours, you may need to make changes to the code, the files used to find structures, or add structures manually.

---
### Analyzing the data further

The routines below can help you get a sense for the properties of your data. They are not part of the preprocessing pipeline, but they can help you understand the data better.

#### Gender information

This will not be informative if the gender was anonymized. You may need to run this on the original DICOM files. If the original data is in a directory different from the one in your `user_config` file, you can specify the path to the data in the `DATA_DIRECTORY` variable below.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pydicom import dcmread
import pickle

DATA_DIRECTORY = dt.patient_data_directory  # directory for data containing gender information

# Basic data report (pickled dictionary) used by some of the cells below.
# The pickle is produced locally by DicomToolbox; never load pickles from untrusted sources.
report_path = os.path.join('temp', 'data', 'basic_data_report.pickle')
with open(report_path, 'rb') as report_file:
    basic_report = pickle.load(report_file)

In [None]:
def get_patient_sex(file_path):
    """Return the PatientSex value stored in a DICOM file, or None.

    Patient folders can contain files that are not DICOM (or that cannot be read), so any
    read error is treated as "no information" instead of stopping the survey.
    """
    try:
        # Only a header tag is needed; skipping the pixel data avoids loading large image files.
        ds = dcmread(file_path, stop_before_pixels=True)
        return ds.PatientSex if 'PatientSex' in ds else None
    except Exception:
        return None


def find_patient_sex(patient_folder):
    """Return the first PatientSex value found in any file below `patient_folder`, or None."""
    for root, _dirs, files in os.walk(patient_folder):
        for file in files:
            gender = get_patient_sex(os.path.join(root, file))
            if gender is not None:
                return gender
    return None


genders = []
for folder in os.listdir(DATA_DIRECTORY):
    gender = find_patient_sex(os.path.join(DATA_DIRECTORY, folder))
    if gender is not None:
        genders.append(gender)

if genders:
    # Count unique gender types
    gender_types, gender_type_counts = np.unique(genders, return_counts=True)
    # Create a pie-chart with patient gender
    plt.pie(gender_type_counts, labels=gender_types, autopct='%1.1f%%')
    plt.legend()
    plt.show()
else:
    print(f"No PatientSex information was found under {DATA_DIRECTORY}.")

In [None]:
# Table showing patient positions 
positions = [basic_report[patient]['patient_position'] for patient in basic_report]
position_types, position_type_counts = np.unique(positions, return_counts=True)
position_dict = dict(zip(position_types, position_type_counts))
position_df = pd.DataFrame.from_dict(position_dict, orient='index', columns=['Count'])
position_df.index.name = 'Patient Position'
position_df

In [None]:
# Table showing the beam types 
beam_types = [basic_report[patient]['beam_type'] for patient in basic_report]
beam_type_types, beam_type_type_counts = np.unique(beam_types, return_counts=True)
beam_type_dict = dict(zip(beam_type_types, beam_type_type_counts))
beam_type_df = pd.DataFrame.from_dict(beam_type_dict, orient='index', columns=['Count'])
beam_type_df.index.name = 'Beam Type'
beam_type_df

### Get the size of the dose grid for the selected and entire patient population

The code below will inspect the dimensions (number of voxels) and resolution (voxel size) of the dose grid for the selected patients and the entire patient population. This is useful to check if the dose grid is consistent across patients. If the dose grid is not consistent, you may need to resample the dose grid to a common resolution. This is currently an option during preprocessing. 

In [None]:
# Number of dose-grid voxels along each axis, for the selected patients
dose_grid_dimensions = {'z': [], 'y': [], 'x': []}
for p in selected:
    array_dimensions = basic_report[p]['dose_array_dimensions']
    for axis_index, axis in enumerate(('z', 'y', 'x')):
        dose_grid_dimensions[axis].append(array_dimensions[axis_index])

# One histogram per axis, side by side
plt.figure(figsize=(15,5))
for panel, (axis, color) in enumerate([('z', 'red'), ('y', 'green'), ('x', 'blue')], start=1):
    plt.subplot(1, 3, panel)
    sns.histplot(dose_grid_dimensions[axis], alpha=0.5, label=axis, color=color)
    if axis == 'x':
        plt.xlim(np.min(dose_grid_dimensions['x']))
    plt.xlabel(f'{axis} dimension')

#### Voxel size analysis

In [None]:
# report the maximum size for each dimension
print(f'Maximum size for z-axis: {np.max(dose_grid_dimensions["z"])}')
print(f'Maximum size for y-axis: {np.max(dose_grid_dimensions["y"])}')
print(f'Maximum size for x-axis: {np.max(dose_grid_dimensions["x"])}')

dose_grid_dimensions = {'dz':[], 'dy':[], 'dx':[]}
for p in basic_report.keys():
    dose_grid_dimensions['dz'].append(basic_report[p]['dose_array_resolution']['dz'])
    dose_grid_dimensions['dy'].append(basic_report[p]['dose_array_resolution']['dy'])
    dose_grid_dimensions['dx'].append(basic_report[p]['dose_array_resolution']['dx'])
    
# print a table of the dose grid resolution
df = pd.DataFrame(dose_grid_dimensions)
df.describe()

In [None]:
# plot the dose grid resolution
plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
sns.histplot(dose_grid_dimensions['dz'], alpha=0.5, label='z', color='red')
plt.xlabel('z resolution')
plt.subplot(1,3,2)
sns.histplot(dose_grid_dimensions['dy'], alpha=0.5, label='y', color='green')
plt.xlabel('y resolution')
plt.subplot(1,3,3)
sns.histplot(dose_grid_dimensions['dx'], alpha=0.5, label='x', color='blue')
plt.xlabel('x resolution');

### Perform an evaluation of the dose properties

This method checks the properties of the dose inside the target volumes and OARs to help identify potential issues like incorrectly assigned prescription dose values. 

In [None]:
# Target metrics: D at 95/98/99 plus max and mean, and V at 95/100/105
# (units presumably % of the prescription; confirm). oar_metrics=None: no OAR metrics requested.
de.get_dose_property_reports(target_metrics={'D':[95,98,99, 'max', 'mean'],'V':[95, 100, 105]}, oar_metrics = None)

In [None]:
# V_100 report built from temp/data/target_dose_analysis.json (presumably written by the cell above; confirm).
de.get_dose_metric_report_as_csv(target_analysis_dir = 'temp/data/target_dose_analysis.json', metric='V_100')

### Get information about the contours available for the selected patients

You can visualize the frequency of some of the OARs for your patients by running the cell below. This will generate a bar plot with the frequency of the OARs in the selected patients. 

In [None]:
# read report of the oars found
with open(os.path.join('utilities', 'liver', 'selected_oars_report.json'), 'r') as f:
    selected_oars_report = json.load(f)

contours_frequency_report = {c:0 for c in selected_oars_report[list(selected_oars_report.keys())[0]].keys()}

for p in selected_oars_report.keys():
    for c in selected_oars_report[p].keys():
        if selected_oars_report[p][c] != 0:
            contours_frequency_report[c]+=1
            
# create histogram of the frequency of each contour
plt.figure(figsize=(15,5))
plt.bar(contours_frequency_report.keys(), contours_frequency_report.values(), color='blue')
plt.xticks(rotation=90)
plt.ylabel('Frequency')
plt.title('Frequency of each contour');

---
## The PreProcessing package

In [None]:
from preprocessing import PreProcessing

In [None]:
user_inputs_dir = "configuration_files/user_config.json"
dp = PreProcessing(user_inputs_dir)
# dp.prepare_training_data(selected_patients = [236]) # you can specify the patients you want to prepare
# NOTE(review): use_contour_record=True presumably reuses the recorded contour selections; confirm.
dp.prepare_training_data(use_contour_record = True) 

In [None]:
# Print every beam of the plan held by the PreProcessing instance (after prepare_training_data).
for n in dp.plan.beam.keys():
    print(n, dp.plan.beam[n])
    # print(dp.plan.beam[n].sad)

In [None]:
import h5py
import matplotlib.pyplot as plt

# Contour colours; 'mint' is not a matplotlib colour name, so the xkcd equivalent is used.
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'black', 'yellow', 'lime', 
          'navy', 'teal', 'maroon', 'aqua', 'fuchsia', 'silver', 'gold', 'coral', 'beige', 'xkcd:mint', 'lavender']

# Patient and slice (index along the first array axis) to display
PATIENT_ID = '1'
SLICE_INDEX = 100

with h5py.File('temp/data/preprocessed.h5', 'r') as f:
    print(len(f.keys()))
    dose = f[f'{PATIENT_ID}/dose'][:]
    ct = f[f'{PATIENT_ID}/ct'][:]
    # Clamp the slice so small volumes do not raise an IndexError
    slice_index = min(SLICE_INDEX, ct.shape[0] - 1, dose.shape[0] - 1)

    plt.imshow(ct[slice_index,:,:], cmap = 'gray')
    plt.imshow(dose[slice_index,:,:], cmap = 'jet', alpha = 0.5) 
    contours = f[f'{PATIENT_ID}/contours']
    for n, c in enumerate(contours.keys()):
        # Cycle through the palette when there are more contours than colours
        color = colors[n % len(colors)]
        print(n, c, color)
        plt.contour(contours[c][slice_index,:,:], colors=[color])
    plt.axis('off');

In [None]:
import h5py 

# Inspect the contour names stored for patient '123' in the preprocessed file.
with h5py.File('temp/data/preprocessed.h5', 'r') as hf:
    print(hf['123/contours'].keys())

___ 
# Train the model

In [None]:
import os, logging, logging.config, random, argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # could be given with argparse
import tensorflow as tf
from input_pipeline_tools import InputPipelineTools
from training_pipeline_tools import TrainingPipelineTools
import numpy as np
from tfrecords_parser import parse_training_tfrecords, parse_validation_tfrecords, random_occlusion, augment_dataset

# Initialize the input pipeline tools 
ipt = InputPipelineTools(user_inputs_dir, output_dir=os.path.join('temp', 'data')) 
tpt = TrainingPipelineTools(user_inputs_dir) 

# set random seeds (TensorFlow, NumPy and Python's random module) for reproducibility
tf.random.set_seed(ipt.seed)
np.random.seed(ipt.seed)
random.seed(ipt.seed)

# Preparing data files for model
## Reuse a saved data split when one is configured, otherwise create new splits.
## The split location may be the string "none" or a JSON null; both mean "not provided".
split_info_dir = ipt.data_split_info_dir
has_split_info = split_info_dir is not None and str(split_info_dir).lower() != "none"
if ipt.load_data_split and has_split_info:
    ipt.get_data_split_info(split_info_dir)
else:
    ipt.create_data_splits()

## Creating new TFRecords
ipt.prepare_tfrecords()


In [None]:
# Reading data from TFRecords
train_ds = tf.data.Dataset.list_files(os.path.join(f"{ipt.train_dir}","*.tfrec"), shuffle=True)
val_ds = tf.data.Dataset.list_files(os.path.join(f"{ipt.val_dir}","*.tfrec"), shuffle=False)

# Interleave TFRecordDataset function to read files
train_ds = train_ds.interleave(tf.data.TFRecordDataset, num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)
val_ds = val_ds.interleave(tf.data.TFRecordDataset, num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)

# Data parsing and augmentation
## Training set
parsing_function = lambda X: parse_training_tfrecords(X,
                                                        features_in_input = tf.constant([x for x in ipt.inputs], dtype=tf.string),
                                                        structures_in_target = tf.constant([x for x in ipt.targets if 'contours/' in x], dtype=tf.string),
                                                        include_voi_weights = tf.constant(ipt.use_weight_matrix))

augmentation_function = lambda inputs, targets, tlc, brc: augment_dataset(inputs, targets, tlc, brc, 
                                                                            input_size = len(ipt.inputs),
                                                                            target_size = len(ipt.targets),
                                                                            aug_parameters = ipt.augmentation_details)

# Apply random occlusion patch inside of training data if specified. The occlusion parameters are
# only read when the augmentation is enabled, so configurations without them do not raise a KeyError.
use_random_occlusion = 'random_occlusion' in ipt.augmentation_types
if use_random_occlusion:
    rw_prob = ipt.augmentation_parameters['random_occlusion']['probability']
    rw_size = ipt.augmentation_parameters['random_occlusion']['window_size']
    random_occlusion_fn = lambda inputs, targets: random_occlusion(inputs, targets, 
                                                                    size = tf.constant(rw_size, dtype=tf.int64), 
                                                                    probability = tf.constant(rw_prob, dtype=tf.float32))

# Read the data from the TFRecordDataset for every patient in the training set
train_ds = train_ds.map(parsing_function, num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)
if ipt.cache_training_dataset: train_ds = train_ds.cache() # cache the data for faster training

# Apply augmentation
train_ds = train_ds.map(augmentation_function, num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)
if use_random_occlusion:
    train_ds = train_ds.map(random_occlusion_fn, num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)

# Shuffle, repeat, batch and prefetch the data
train_ds = train_ds.shuffle(buffer_size=ipt.shuffle_buffer_size)
train_ds = train_ds.repeat(ipt.dataset_repeats["training-set"]) 
train_ds = train_ds.batch(batch_size=ipt.batch_size, drop_remainder = True)
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

## Validation set
parsing_function = lambda X: parse_validation_tfrecords(X, 
                                                        input_size = len(ipt.inputs),
                                                        target_size = len(ipt.targets),
                                                        features_in_input = tf.constant([x for x in ipt.inputs], dtype=tf.string),
                                                        structures_in_target = tf.constant([x for x in ipt.targets if 'contours/' in x], dtype=tf.string),
                                                        include_voi_weights = tf.constant(ipt.use_weight_matrix))

val_ds = val_ds.map(parsing_function, num_parallel_calls=tf.data.AUTOTUNE)
if ipt.cache_validation_dataset: val_ds = val_ds.cache()
val_ds = val_ds.batch(batch_size=ipt.batch_size, drop_remainder = True)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# TRAINING PIPELINE
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    
    model = tpt.get_model()
    model.compile(optimizer=tpt.get_optimizer(),
                    loss=tpt.get_loss_function())

# Train the model (with or without profiler). The profiler is stopped even if training fails,
# so a re-run of this cell does not hit an already-running profiler.
# NOTE(review): 'profiler_lok_dir' looks like a typo of 'profiler_log_dir'; kept as is.
if tpt.activate_profiler : tf.profiler.experimental.start('profiler_lok_dir')
try:
    history = model.fit(train_ds,   
                        epochs=ipt.epochs,
                        validation_data=val_ds,
                        callbacks = tpt.callbacks)
finally:
    if tpt.activate_profiler: tf.profiler.experimental.stop()

# generate a plot of the training history
tpt.plot_training_history(history)

In [None]:
# generate a plot of the training history
tpt.plot_training_history(history)

---
# Evaluate the Model

In [None]:
from evaluate_model import evaluate_model

# evaluate_model(user_inputs_dir)

In [None]:
from evaluation_toolbox import ModelEvaluationTools
from tqdm import tqdm
import os

echo_progress = True
user_inputs_dir = "configuration_files/user_config.json"

met = ModelEvaluationTools(user_inputs_dir)

# Create the model inference folder. A bare attribute reference does nothing for a regular
# method, so call it when it is callable; if it is a property, accessing it already ran it.
_prepare_folder = met.prepare_inference_folder
if callable(_prepare_folder):
    _prepare_folder()

# Determine the fold(s) to evaluate. Always normalise to a list: `len(folds)` and
# `folds[0]` are used below and raise TypeError on a bare int (the old default was 0).
folds = met.eval_folds if met.eval_folds is not None else [0]
if isinstance(folds, int):
    folds = [folds]

# Weight directory
weights_dir = met.weights_dir if met.weights_dir is not None else ""

# Sanity checks: multi-fold evaluation is only allowed on the test set.
if len(folds) > 1 and met.data_set_to_evaluate.lower() != 'test':
    msg = 'Multi-fold evaluation only makes sense on the test set. Give only one fold or change the data set to test.'
    raise ValueError(msg)

# prepare the evaluation data
eval_data_set = met.read_data_ids(data_type=met.data_set_to_evaluate)

pat_eval = {p: {} for p in eval_data_set}
gt_norm_fctr, pr_norm_fctr = 1.0, 1.0
patient_ids = tqdm(eval_data_set) if echo_progress else eval_data_set

# prepare the model if only one fold is used and a prediction file is to be written
if len(folds) == 1 and met.prediction_file_mode in ['w', 'a']:
    model = met.prepare_model(folds[0], weights_dir, compiled=True)
else:
    model = None

# Compute (and optionally save) the prediction for each patient in the evaluation set
for p in patient_ids:

    # get the ground-truth and predicted dose distributions for the patient
    gt_dose, pr_dose = met.get_dose_volumes(p, folds, weights_dir, model)

    # Normalize the dose if desired
    gt_dose, pr_dose, gt_norm_fctr, pr_norm_fctr = met.normalize_dose(p, gt_dose, pr_dose,
                                                                        method=met.normalization_method,
                                                                        normalize=met.reference_normalization_volume)

    # Save the prediction: 'w' always writes; 'a' only writes patients without an existing file
    if met.prediction_file_mode == 'w':
        met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)
    elif met.prediction_file_mode == 'a':
        if not os.path.exists(os.path.join(met.model_inferences_dir, f'AID-{p}.h5')):
            met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)

In [None]:

import os
import pickle  # needed to save the evaluation results at the end of this cell
from tqdm import tqdm

# NOTE: relies on `met` and `echo_progress` defined in the previous evaluation cell.

# Create the model inference folder. A bare attribute reference does nothing for a regular
# method, so call it when it is callable; if it is a property, accessing it already ran it.
_prepare_folder = met.prepare_inference_folder
if callable(_prepare_folder):
    _prepare_folder()

# Determine the fold(s) to evaluate. Always normalise to a list: `len(folds)`, `folds[0]`
# and the results file name below need a sequence (the old default was the int 0).
folds = met.eval_folds if met.eval_folds is not None else [0]
if isinstance(folds, int):
    folds = [folds]

# Weight directory
weights_dir = met.weights_dir if met.weights_dir is not None else ""

# Sanity checks: multi-fold evaluation is only allowed on the test set.
if len(folds) > 1 and met.data_set_to_evaluate.lower() != 'test':
    msg = 'Multi-fold evaluation only makes sense on the test set. Give only one fold or change the data set to test.'
    raise ValueError(msg)

# prepare the evaluation data
eval_data_set = met.read_data_ids(data_type=met.data_set_to_evaluate)

pat_eval = {p: {} for p in eval_data_set}
gt_norm_fctr, pr_norm_fctr = 1.0, 1.0
patient_ids = tqdm(eval_data_set) if echo_progress else eval_data_set

# prepare the model if only one fold is used and a prediction file is to be written
if len(folds) == 1 and met.prediction_file_mode in ['w', 'a']:
    model = met.prepare_model(folds[0], weights_dir, compiled=True)
else:
    model = None

# Evaluate the prediction for each patient in the evaluation set
for p in patient_ids:

    # get the ground-truth and predicted dose distributions for the patient
    gt_dose, pr_dose = met.get_dose_volumes(p, folds, weights_dir, model)

    # Normalize the dose if desired
    gt_dose, pr_dose, gt_norm_fctr, pr_norm_fctr = met.normalize_dose(p, gt_dose, pr_dose,
                                                                        method=met.normalization_method,
                                                                        normalize=met.reference_normalization_volume)

    # Save the prediction: 'w' always writes; 'a' only writes patients without an existing file
    if met.prediction_file_mode == 'w':
        met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)
    elif met.prediction_file_mode == 'a':
        if not os.path.exists(os.path.join(met.model_inferences_dir, f'AID-{p}.h5')):
            met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)

    # Determine minimum dose
    min_dose = met.get_minimum_desired_dose_for_analysis(gt_dose)

    # evaluate the prediction for patient
    pat_eval[p] = met.evaluate_prediction(p, gt_dose, pr_dose, met.vol_for_max_dose, min_dose=min_dose)

# GPR analysis
if met.include_gpr_analysis:
    pat_eval = met.get_gamma_passing_rate(eval_data_set, pat_eval, data_folder=met.model_inferences_dir)

# Save results as a pickle file (creating the output folder if it does not exist yet)
results_dir = os.path.join('temp', 'data')
os.makedirs(results_dir, exist_ok=True)
results_file = os.path.join(results_dir, f'evaluation_results_{"-".join(str(x) for x in folds)}.pickle')
with open(results_file, 'wb') as handle:
    pickle.dump(pat_eval, handle, protocol=pickle.HIGHEST_PROTOCOL)

# NOTE(review): method name is spelled `get_summay`, presumably matching ModelEvaluationTools — verify.
met.get_summay(pat_eval)

---
# A quick way to view the contents of a DICOM header

In [None]:
import os
from pydicom import dcmread

# Path to a single DICOM file whose header should be printed.
file_directory = 'write directory here'

# Skip gracefully while the placeholder is still in place, so "Run All" does not stop
# here with FileNotFoundError.
if os.path.isfile(file_directory):
    with dcmread(file_directory) as ds:
        print(ds)
else:
    print(f"DICOM file not found: {file_directory!r}. Set `file_directory` to the path of a DICOM file.")

In [None]:
# Beam names of a placeholder (currently empty) mapping; presumably meant to be filled
# with a beam-keyed dict such as `dt.dose` (whose keys are used as beam names elsewhere).
beams = [beam_name for beam_name in {}]
len(beams)

---
# Inspecting Target Dose

Specify the patient ID in `PAT_ID` below and run the cell to prepare the data for inspection. 

You should only need to change the values in the `USER PARAMETERS` section below.

In [None]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Select, HBox
from utilities import interpolate_volume
from dicom_toolbox import DicomToolbox
import logging, os

logger  = logging.getLogger(__name__) # create logger (this and the next 3 lines are optional)
log_file = os.path.join('logs','preprocessing.log')
logging.basicConfig(filename=log_file, filemode='w',
                    level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

######### USER PARAMETERS ##############
PAT_ID = 2153559
CONFIG_FILE_DIR = "configuration_files/user_config.json"
TARGET_TYPE = 'ptv'
CT_MIN_VALUE = -300 # HU
CT_MAX_VALUE = 900 # HU
INITIAL_SLICE = None
COLORMAP = 'jet'
#########################################

dt = DicomToolbox(CONFIG_FILE_DIR) # create an instance of the DicomToolbox class
dt.parse_dicom_files(PAT_ID, mask_resolution='dose') # parse the dicom files
ct = dt.ct.data # grab the ct data
dose = dt.cumulative_dose # grab the cumulative dose data
bn = list(dt.dose.keys())[0] # Grab the first beam name/number
targets = [c for c in dt.contours.keys() if TARGET_TYPE in c] # Grab the target names
structures = [c for c in dt.contours.keys() if TARGET_TYPE not in c] # Grab the structure names

# Create a Select widget
select_targets = Select(
    options=targets,
    value=targets[0],
    description='Targets:',
    rows=len(targets),
)

select_structures = Select(
    options=structures,
    value=structures[0],
    description='Structures:',
    rows=len(structures),
)

original_coordinates = (dt.ct.coordinates.z, dt.ct.coordinates.y, dt.ct.coordinates.x)
new_coordinates = (dt.dose[bn].coordinates.z, dt.ct.coordinates.y, dt.ct.coordinates.x)
ct = interpolate_volume(ct, original_coordinates, new_coordinates, intMethod='linear', boundError=0, fillValue=0)

display(HBox([select_targets, select_structures]))

In [None]:
# Balance the contrast of the CT to enhance soft tissue by clipping HU values to
# [CT_MIN_VALUE, CT_MAX_VALUE] in place. The limits come from the USER PARAMETERS of the
# data-preparation cell; do not redefine them here, or edits made there would be ignored.
np.clip(ct, CT_MIN_VALUE, CT_MAX_VALUE, out=ct)

In [None]:
######### USER PARAMETERS ##############

MIN_DOSE = 0 # minimum dose to display in Gy

#########################################

# Close figures from earlier runs (plt.clf() would only clear, or create, the current
# figure, so re-running the cell kept stacking figures).
plt.close('all')
fig, ax = plt.subplots(figsize=(7,7))
plt.tight_layout()

# Image extents [left, right, bottom, top]; the y limits are given as (max, min) so that,
# with imshow's default origin='upper', y increases downwards like the array rows.
dose_extent = [dt.dose[bn].coordinates.x.min(), dt.dose[bn].coordinates.x.max(), dt.dose[bn].coordinates.y.max(), dt.dose[bn].coordinates.y.min()]
ct_extent = [dt.ct.coordinates.x.min(), dt.ct.coordinates.x.max(), dt.ct.coordinates.y.max(), dt.ct.coordinates.y.min()]

# Start at INITIAL_SLICE when given, otherwise at the middle slice.
slice_idx = INITIAL_SLICE if INITIAL_SLICE is not None else ct.shape[0]//2

# Contour grid coordinates: masks were parsed with mask_resolution='dose'.
x, y = dt.dose[bn].coordinates.x, dt.dose[bn].coordinates.y

# Plot initial slice
ct_slice = ax.imshow(ct[slice_idx], cmap='gray', interpolation='none', extent=ct_extent)
# target = dt.contours[select_targets.value].data
structure = dt.contours[select_structures.value].data
dose_sn = np.ma.masked_where(dose[slice_idx] <= MIN_DOSE, dose[slice_idx])
# dose_slice = ax.imshow(dose_sn, cmap=COLORMAP, alpha=0.3, interpolation='none', vmin=np.min(dose), vmax=np.max(dose), extent=dose_extent)
# target_contour = ax.contour(x,y, target[slice_idx], levels=[0.5], colors='yellow')
structure_contour = ax.contour(x, y, structure[slice_idx], levels=[0.5], colors='red')
ax.axis('off')


def _remove_contour(contour_set):
    """Remove a ContourSet from its axes on both older and newer Matplotlib versions.

    `ContourSet.collections` is deprecated since Matplotlib 3.8, where the ContourSet
    itself is the artist to remove; older versions fall back to removing its collections.
    """
    try:
        contour_set.remove()
    except AttributeError:
        for coll in contour_set.collections:
            coll.remove()


# Update function for scroll bar
def update(slice_idx):
    """Redraw the CT slice and the structure contour for `slice_idx`."""
    global target_contour, structure_contour

    ct_slice.set_data(ct[slice_idx])
    # dose_sn = np.ma.masked_where(dose[slice_idx] <= MIN_DOSE, dose[slice_idx])
    # dose_slice.set_data(dose_sn)
    # Rescale the grey levels to the range of the displayed slice.
    ct_slice.set_clim(vmin=np.min(ct[slice_idx]), vmax=np.max(ct[slice_idx]))

    # Remove previous contour lines for target
    # _remove_contour(target_contour)
    # Remove previous contour lines for structure
    _remove_contour(structure_contour)

    # Plot new contour for target
    # target_contour = ax.contour(x, y, target[slice_idx], levels=[0.5], colors='yellow')
    # Plot new contour for structure
    structure_contour = ax.contour(x, y, structure[slice_idx], levels=[0.5], colors='red')

    fig.canvas.draw_idle()


# Create interactive scroll bar. Valid indices are 0 .. ct.shape[0] - 1
# (max=ct.shape[0] raised IndexError at the last slider position).
interact(update, slice_idx=IntSlider(min=0, max=ct.shape[0] - 1, step=1, value=slice_idx));

-------
# Visualize the preprocessed data

In [None]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
import h5py

PAT_ID = '663560'
MIN_DOSE = 0.03
OAR_1 = 'liver'
# OAR_2 = 'body'
FILE_DIR = 'temp/data/preprocessed.h5'

# close any open figures
plt.close('all')

# Initialize Plot

# Initial slice
slice_idx = 0

with h5py.File(FILE_DIR, 'r') as f:
    pat_data = f[str(PAT_ID)]
    print(pat_data['contours'].keys())
    ct = pat_data['ct'][:]
    dose = pat_data['dose'][:]
    oar_1_data = pat_data[f'contours/{OAR_1}'][:]
    # oar_2_data = pat_data[f'contours/{OAR_2}'][:]
    cmb_targets = pat_data['contours/combined_targets'][:]
    max_slice = ct.shape[0]

print(ct.max(), dose.max())

In [None]:


# Close figures from earlier runs of this cell so they do not accumulate.
plt.close('all')
fig, ax = plt.subplots()

# print maximum value of the dose, ct
print(f"Maximum dose: {np.max(dose):.2f} Gy")
print(f"Maximum CT: {np.max(ct):.2f} HU")

# balance the contrast of the CT to enhance soft tissue
# ct[ct<0.05] = 0.05
# ct[ct>.2] = .2

# Plot initial slice
ct_slice = ax.imshow(ct[slice_idx], cmap='gray', interpolation='none')

# dose_sn = np.ma.masked_where(dose[slice_idx] == 0, dose[slice_idx])
# dose_slice = ax.imshow(dose_sn, cmap='jet', alpha=0.3, interpolation='none', vmin=np.min(dose), vmax=np.max(dose))
oar_1_contour_plot = ax.contour(oar_1_data[slice_idx], levels=[0.5], colors='y')
# oar_2_contour_plot = ax.contour(oar_2_data[slice_idx], levels=[0.5], colors='w')
target_contour_plot = ax.contour(cmb_targets[slice_idx], levels=[0.5], colors='r')
ax.axis('off')


def _remove_contour(contour_set):
    """Remove a ContourSet from its axes on both older and newer Matplotlib versions.

    `ContourSet.collections` is deprecated since Matplotlib 3.8, where the ContourSet
    itself is the artist to remove; older versions fall back to removing its collections.
    """
    try:
        contour_set.remove()
    except AttributeError:
        for coll in contour_set.collections:
            coll.remove()


# Update function for scroll bar
def update(slice_idx):
    """Redraw the CT slice and the OAR/target contours for `slice_idx`."""
    global oar_1_contour_plot, target_contour_plot  # , oar_2_contour_plot

    ct_slice.set_data(ct[slice_idx])
    # dose_sn = np.ma.masked_where(dose[slice_idx] == 0, dose[slice_idx])
    # dose_slice.set_data(dose_sn)

    # dose_slice.set_clim(vmin=np.min(dose[slice_idx]), vmax=np.max(dose[slice_idx]))
    # Rescale the grey levels to the range of the displayed slice.
    ct_slice.set_clim(vmin=np.min(ct[slice_idx]), vmax=np.max(ct[slice_idx]))

    # Remove previous contour lines
    _remove_contour(oar_1_contour_plot)
    # _remove_contour(oar_2_contour_plot)
    _remove_contour(target_contour_plot)

    # Plot new contours. The masks are binary, so the 0.5 level traces their boundary;
    # use the same explicit level as the initial plot (levels=1 let Matplotlib pick one).
    oar_1_contour_plot = ax.contour(oar_1_data[slice_idx], levels=[0.5], colors='y')
    # oar_2_contour_plot = ax.contour(oar_2_data[slice_idx], levels=[0.5], colors='w')
    target_contour_plot = ax.contour(cmb_targets[slice_idx], levels=[0.5], colors='r')

    fig.canvas.draw_idle()


# Create interactive scroll bar. Valid indices are 0 .. ct.shape[0] - 1
# (max=ct.shape[0] raised IndexError at the last slider position).
interact(update, slice_idx=IntSlider(min=0, max=ct.shape[0] - 1, step=1, value=slice_idx));


In [None]:
from evaluation_toolbox import ModelEvaluationTools
from tqdm import tqdm
import h5py, pickle, os, logging, datetime
import multiprocessing

user_inputs_dir = "configuration_files/user_config.json"
echo_progress = True

met = ModelEvaluationTools(user_inputs_dir)

# Create the model inference folder. A bare attribute reference does nothing for a regular
# method, so call it when it is callable; if it is a property, accessing it already ran it.
_prepare_folder = met.prepare_inference_folder
if callable(_prepare_folder):
    _prepare_folder()

# Fold(s) to evaluate: fixed to fold 0 here (overrides met.eval_folds from the config).
folds = [0]

# Weight directory: fixed to "" here (overrides met.weights_dir from the config).
weights_dir = ""

# prepare the evaluation data
eval_data_set = met.read_data_ids(data_type=met.data_set_to_evaluate)
gt_norm_fctr, pr_norm_fctr = 1.0, 1.0
patient_ids = tqdm(eval_data_set) if echo_progress else eval_data_set

# prepare the model if only one fold is used and a prediction file is to be written
if len(folds) == 1 and met.prediction_file_mode in ['w', 'a']:
    model = met.prepare_model(folds[0], weights_dir, compiled=True)
else:
    model = None

# Compute (and optionally save) the prediction for each patient in the evaluation set
for p in patient_ids:

    # get the ground-truth and predicted dose distributions for the patient
    gt_dose, pr_dose = met.get_dose_volumes(p, folds, weights_dir, model)

    # Normalize the dose if desired
    gt_dose, pr_dose, gt_norm_fctr, pr_norm_fctr = met.normalize_dose(p, gt_dose, pr_dose,
                                                                        method=met.normalization_method,
                                                                        normalize=met.reference_normalization_volume)

    # Save the prediction: 'w' always writes; 'a' only writes patients without an existing file
    if met.prediction_file_mode == 'w':
        met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)
    elif met.prediction_file_mode == 'a':
        if not os.path.exists(os.path.join(met.model_inferences_dir, f'AID-{p}.h5')):
            met.save_patient_data(p, gt_dose, pr_dose, gt_norm_factor=gt_norm_fctr, pr_norm_factor=pr_norm_fctr)

#     # Determine minimum dose
#     min_dose = met.get_minimum_desired_dose_for_analysis(gt_dose)

#     # evaluate the prediction for patient
#     pat_eval[p] = met.evaluate_prediction(p, gt_dose, pr_dose, met.vol_for_max_dose, min_dose=min_dose)

# # GPR analysis
# if met.include_gpr_analysis:
#     pat_eval = met.get_gamma_passing_rate(eval_data_set, pat_eval, data_folder=met.model_inferences_dir)

# # Save results as a pickle file
# with open(os.path.join('temp', 'data',f'evaluation_results_{"-".join([str(x) for x in folds])}.pickle'), 'wb') as handle:
#     pickle.dump(pat_eval, handle, protocol=pickle.HIGHEST_PROTOCOL)

# met.get_summay(pat_eval)