In [None]:


import importlib
import pip
import time


def install_and_reload(package_name, module_name=None):
    """Install ``package_name`` with pip into the running interpreter and import it.

    Args:
        package_name: pip distribution name (e.g. 'rt-utils', 'pyradiomics').
        module_name: import name when it differs from the pip name (e.g.
            'rt_utils', 'radiomics'). Defaults to ``package_name``.

    Returns:
        The freshly imported (and reloaded) module.

    Raises:
        subprocess.CalledProcessError: pip exited with a non-zero status.
        ImportError: the module still cannot be imported after installation.
    """
    import subprocess
    import sys

    print(f"Installing {package_name}...")
    # `python -m pip` installs into the same interpreter that runs this kernel;
    # the in-process `pip.main` entry point is not a supported pip API.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_name])
    print(f"Installation of {package_name} complete.")
    # Drop stale import-system caches so the newly installed files are found.
    importlib.invalidate_caches()
    return importlib.reload(importlib.import_module(module_name or package_name))

try:
    import pydicom
except ImportError:
    pydicom = install_and_reload('pydicom')
else:
    print("pydicom already installed, importing...")
try:
    import radiomics
except ImportError:
    # The PyPI distribution is 'pyradiomics'; 'radiomics' is only the import name.
    radiomics = install_and_reload('pyradiomics', 'radiomics')
else:
    print("pyradiomics already installed, importing...")
try:
    import rt_utils
except ImportError:
    # pip name 'rt-utils' is imported as 'rt_utils'.
    rt_utils = install_and_reload('rt-utils', 'rt_utils')
else:
    print("rt-utils already installed, importing...")
try:
    import dcm2niix
except ImportError:
    dcm2niix = install_and_reload('dcm2niix')
else:
    print("dcm2niix already installed, importing...")

import pydicom
import os
import numpy as np
from rt_utils import RTStructBuilder
from radiomics import featureextractor
import nibabel as nib
import SimpleITK as sitk

import subprocess
from scipy import ndimage
import SimpleITK as sitk
from typing import List
import dcm2niix
import pandas as pd
import shutil
import re
import matplotlib.pyplot as plt
import subprocess
import tempfile
import re


In [2]:
class Mask:
    """Convert per-case DICOM folders to NIfTI, build ROI masks and extract radiomics features.

    ``self.cases`` maps a case name (the basename of a folder found under
    ``source_folder``) to a dict of per-case information. Keys written by this
    class: 'Folder path', 'Files', 'CT nifti', 'RTStruct path', 'RTDose path',
    'RD nifti', 'ROI names', 'Selected ROI' and 'Mask path'.
    """

    # Prefix of the scratch folders created by copy_ct_files_to_directory;
    # path() skips them so a leftover copy is never mistaken for a case.
    _TEMP_PREFIX = 'temp_ct_files_'
    # Characters that are not allowed in file names on common file systems.
    _UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')
    # Separators used to split ROI names into words in find_similar_ROIs.
    _ROI_SPLIT = re.compile(r'\s+|[-_]')

    def __init__(self, source_folder):
        self.source_folder = source_folder
        self.cases = {}
        self.roi_names = {}  # kept for backward compatibility; not used by this class
        self.cases_not_matching = []  # filled by find_similar_ROIs
        # Extension appended to the names dcm2niix reports (it is run without -z).
        self.nifti_extension = '.nii'

    def path(self, folder_path=None):
        """Register every folder that contains visible (non-dot) files as a case.

        Args:
            folder_path: folder to walk instead of ``self.source_folder``.

        Returns:
            ``self.cases`` (updated in place), or ``{}`` if listing fails.
        """
        try:
            root_folder = folder_path if folder_path else self.source_folder
            for root, _, files in sorted(os.walk(root_folder)):
                case = os.path.basename(root)
                # Scratch CT copies left by an interrupted run are not cases.
                if case.startswith(self._TEMP_PREFIX):
                    continue
                visible = [f for f in files if not f.startswith('.')]
                if not visible:
                    continue
                self.cases[case] = {'Folder path': root, 'Files': visible}
            return self.cases
        except Exception as e:
            print(f"Error while listing cases: {e}")
            return {}

    def process_cases(self):
        """Convert every registered case to NIfTI and collect its RT objects.

        For each case the CT series is converted with dcm2niix, the RTSTRUCT
        and RTDOSE files are located by their Modality, the ROI names are read
        from the RTSTRUCT and the RTDOSE is converted too. A failure in one
        case is printed and processing continues with the next case.

        Returns:
            ``self.cases`` (updated in place), or ``{}`` on an unexpected error.
        """
        try:
            for key in self.cases:
                try:
                    self._process_case(key)
                except Exception as e:
                    print(f"Error processing case {key}: {str(e)}")
            return self.cases
        except Exception as e:
            print(f"Error while processing cases: {e}")
            return {}

    def _process_case(self, key):
        """Run the CT / RTSTRUCT / RTDOSE conversion steps for a single case."""
        case = self.cases[key]
        output_directory = os.path.join(self.source_folder, key)
        temp_dir = None
        try:
            temp_dir = self.copy_ct_files_to_directory(key)
            ct_output = self.run_dcm2niix(temp_dir, output_directory, f"CT_{key}")
            case['CT nifti'] = self.extract_nifti_paths(ct_output)

            rt_struct, rt_dose = self._find_rt_files(key)
            case['RTStruct path'] = rt_struct
            case['RTDose path'] = rt_dose

            if rt_struct:
                # get_ROIs deletes temp_dir itself once the RTSTRUCT is loaded.
                self.get_ROIs(key, temp_dir)

            if rt_dose:
                input_path = self.set_spacing(rt_dose)
                rd_output = self.run_dcm2niix(input_path, output_directory, f"RD_{key}")
                case['RD nifti'] = self.extract_RD_nifti_paths(rd_output)
        finally:
            # Remove the scratch CT copy even when one of the steps above raised.
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

    def _find_rt_files(self, key):
        """Return ``(rtstruct_path, rtdose_path)`` of a case; the last match wins, None if absent."""
        folder = self.cases[key]['Folder path']
        rt_struct = None
        rt_dose = None
        for file in self.cases[key]['Files']:
            file_path = os.path.join(folder, file)
            try:
                # Only the header is needed to read the Modality.
                ds = pydicom.dcmread(file_path, stop_before_pixels=True)
            except Exception as e:
                print(f"Error reading DICOM file: {e}")
                continue
            modality = getattr(ds, 'Modality', None)
            if modality == "RTSTRUCT":
                rt_struct = file_path
            elif modality == 'RTDOSE':
                rt_dose = file_path
        return rt_struct, rt_dose

    def set_spacing(self, dicom_file_path):
        """
        Store the dose-grid frame spacing as SliceThickness in an RTDOSE file.

        The spacing is |GridFrameOffsetVector[1] - GridFrameOffsetVector[0]| and
        the file is overwritten in place. Files without at least two frame
        offsets are left untouched.

        Returns:
            ``dicom_file_path``, so it can be passed straight to dcm2niix.
        """
        ds = pydicom.dcmread(dicom_file_path)
        offsets = getattr(ds, 'GridFrameOffsetVector', None)
        try:
            # Offsets may run in decreasing order; a thickness must be positive.
            slice_thickness = abs(float(offsets[1]) - float(offsets[0]))
        except (TypeError, IndexError):
            # None, a single (non-sequence) value, or a one-element list.
            print(f"No frame spacing found in {dicom_file_path}; SliceThickness left unchanged.")
            return dicom_file_path
        ds.SliceThickness = slice_thickness
        ds.save_as(dicom_file_path)
        return dicom_file_path

    def select_first_case(self, key):
        """Return the ROI names stored for ``key``, or None (with a message) if none were read."""
        if 'ROI names' not in self.cases[key]:
            print(f"Key 'ROI names' not found for case {key}.")
            return None
        return self.cases[key]['ROI names']

    def set_roi(self, key, roi):
        """Record ``roi`` as the selected ROI of case ``key`` and return ``self.cases``."""
        self.cases[key]['Selected ROI'] = roi
        return self.cases

    def get_ROIs(self, key, ct_series_path):
        """Read the ROI names of a case's RTSTRUCT and delete the CT copy it was loaded with.

        Args:
            key: case name whose 'RTStruct path' is loaded.
            ct_series_path: folder holding the case's CT slices (removed afterwards).

        Returns:
            ``self.cases``, with 'ROI names' set for ``key``.
        """
        rt_struct_path = self.cases[key]['RTStruct path']
        rt_struct = RTStructBuilder.create_from(dicom_series_path=ct_series_path, rt_struct_path=rt_struct_path)
        shutil.rmtree(ct_series_path)
        self.cases[key]['ROI names'] = rt_struct.get_roi_names()
        return self.cases

    def create_binary_mask(self, case, ROI):
        """Rasterise one RTSTRUCT ROI of ``case`` and save it as a 0/1 NIfTI mask.

        The mask reuses the affine of the case's 'CT nifti' and is written to
        ``<source_folder>/<case>/mask_<case>_<ROI>.nii``; characters that are
        not allowed in file names are replaced by '_' in the ROI part. The
        path is stored as ``self.cases[case]['Mask path']``.

        Returns:
            ``self.cases``.
        """
        ct_series = self.copy_ct_files_to_directory(case)
        try:
            rt_struct_obj = RTStructBuilder.create_from(
                dicom_series_path=ct_series,
                rt_struct_path=self.cases[case]['RTStruct path'],
            )
            ct_nifti = nib.load(self.cases[case]['CT nifti'])
            mask = rt_struct_obj.get_roi_mask_by_name(ROI)
            # Rotation presumably aligns the rt-utils array with the dcm2niix CT
            # volume — TODO confirm for non-axial acquisitions.
            mask_3d = np.rot90(np.where(mask, 1, 0), -1)
            safe_roi = self._UNSAFE_FILENAME_CHARS.sub('_', ROI)
            output_mask_path = os.path.join(self.source_folder, case, f'mask_{case}_{safe_roi}.nii')
            self.save_image(mask_3d.astype(np.uint16), ct_nifti.affine, output_mask_path)
            self.cases[case]['Mask path'] = output_mask_path
            print('created mask:', output_mask_path)
        finally:
            if os.path.exists(ct_series):
                shutil.rmtree(ct_series)
        return self.cases

    def copy_ct_files_to_directory(self, key):
        """Copy the CT image files of a case into a scratch folder and return that folder.

        Files that cannot be read as DICOM are skipped silently.
        """
        target_dir = os.path.join(self.source_folder, f"{self._TEMP_PREFIX}{key}")
        os.makedirs(target_dir, exist_ok=True)
        folder = self.cases[key]['Folder path']
        for file in self.cases[key]['Files']:
            source_file = os.path.join(folder, file)
            try:
                ds = pydicom.dcmread(source_file)
                # Check for the PixelData element instead of building pixel_array,
                # which would decode every slice just to test that it has pixels.
                if getattr(ds, 'Modality', None) == "CT" and 'PixelData' in ds:
                    shutil.copy(source_file, os.path.join(target_dir, file))
            except Exception:
                continue
        return target_dir

    def save_image(self, image, affine, output_path):
        """Write a numpy array with the given affine to ``output_path`` as NIfTI."""
        nib.save(nib.Nifti1Image(image, affine), output_path)

    def _roi_name_matches(self, roi_name, target):
        """True if a word (>2 chars) of ``roi_name`` is a substring of ``target`` or vice versa."""
        return any(
            len(part) > 2 and (part in target or target in part)
            for part in self._ROI_SPLIT.split(roi_name.lower())
        )

    def find_similar_ROIs(self, roi):
        """Select, for every case, the first ROI whose name loosely matches ``roi``.

        ROI names are lower-cased and split on whitespace, '-' and '_'. A word
        longer than two characters matches when it is a substring of ``roi`` or
        ``roi`` is a substring of it, case-insensitively. The match is stored as
        'Selected ROI'. Cases without 'ROI names' are skipped; cases without
        a match are collected in ``self.cases_not_matching``.

        Returns:
            ``(self.cases, self.cases_not_matching)``.
        """
        target = roi.lower()
        self.cases_not_matching = []
        for key, case_data in self.cases.items():
            if 'ROI names' not in case_data:
                continue
            match = next(
                (name for name in case_data['ROI names'] if self._roi_name_matches(name, target)),
                None,
            )
            if match is None:
                self.cases_not_matching.append(key)
            else:
                case_data['Selected ROI'] = match
        return self.cases, self.cases_not_matching

    def _nifti_path_from_line(self, line):
        """Parse '... DICOM as <path> (<dims>)' into '<path>' + nifti_extension, or None."""
        marker = 'DICOM as'
        line = line.strip()
        position = line.find(marker)
        if position == -1:
            return None
        start = position + len(marker) + 1
        # The dimensions in parentheses come last, so search from the right:
        # a '(' inside the output folder name must not truncate the path.
        end = line.rfind('(')
        if end < start:
            return None
        return line[start:end].strip() + self.nifti_extension

    def extract_nifti_paths(self, output):
        """Return the NIfTI path of the last conversion dcm2niix reports in ``output``, or None."""
        file_path = None
        for line in (output or '').splitlines():
            if 'Convert' in line and 'DICOM as' in line:
                parsed = self._nifti_path_from_line(line)
                if parsed is not None:
                    file_path = parsed
        print(f"Extracted NIfTI path: {file_path}")
        return file_path

    def extract_RD_nifti_paths(self, output):
        """Return the NIfTI path of the last 'Convert 1 DICOM as' line in ``output``, or None.

        Only single-file conversions are considered (presumably the RTDOSE,
        which is stored as one multi-frame file — TODO confirm).
        """
        file_path = None
        for line in (output or '').splitlines():
            if 'Convert 1 DICOM as' not in line:
                continue
            parsed = self._nifti_path_from_line(line)
            if parsed is None:
                continue
            file_path = parsed
            print(file_path)
        return file_path

    def run_dcm2niix(self, input_path, output_directory, output_filename):
        """Run dcm2niix on ``input_path`` and return its stdout ('' if it exits with an error).

        ``-f`` sets the output file name and ``-o`` the output folder; ``-d y``
        and ``-i y`` are passed through unchanged.
        """
        command = [
            'dcm2niix',
            '-f', output_filename,
            '-d', 'y',
            '-i', 'y',
            '-o', output_directory,
            input_path,
        ]
        try:
            result = subprocess.run(command, stdout=subprocess.PIPE, text=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error executing dcm2niix: {e}")
            if e.stdout:
                print(e.stdout)
            return ""
        print(f"dcm2niix output for {input_path}:", result.stdout)
        return result.stdout

    def resample(self, image_path, new_resolution, interpolator, key, dose_scaling=False):
        """Resample an image to ``new_resolution`` spacing and restore its intensity range.

        The output keeps the input origin and direction, and its size preserves
        the physical extent (rounded to whole voxels). After interpolation the
        values are linearly stretched back to the input's [min, max]. A constant
        image is left unstretched. Optionally the result is multiplied by the
        case's DoseGridScaling (see dose_scaling).

        Args:
            image_path: image file readable by SimpleITK.
            new_resolution: output spacing, one value per image dimension.
            interpolator: SimpleITK interpolator constant (e.g. sitk.sitkLinear).
            key: case name; only used when ``dose_scaling`` is True.
            dose_scaling: multiply by the RTDOSE DoseGridScaling of ``key``.

        Returns:
            A SimpleITK image with the resampled geometry.
        """
        image = sitk.ReadImage(image_path)
        size, spacing = image.GetSize(), image.GetSpacing()
        new_size = [
            int(round(size[i] * spacing[i] / new_resolution[i]))
            for i in range(image.GetDimension())
        ]
        resampler = sitk.ResampleImageFilter()
        resampler.SetSize(new_size)
        resampler.SetOutputSpacing(new_resolution)
        resampler.SetOutputOrigin(image.GetOrigin())
        resampler.SetOutputDirection(image.GetDirection())
        resampler.SetInterpolator(interpolator)
        rescaled_image = resampler.Execute(image)

        rescaled_array, rescaled_min, rescaled_max = self.get_array(rescaled_image)
        _, original_min, original_max = self.get_array(image)

        if rescaled_max == rescaled_min:
            # Stretching a constant image would divide by zero and yield NaNs.
            adjusted_array = rescaled_array.astype(float)
        else:
            adjusted_array = (
                (rescaled_array - rescaled_min) / (rescaled_max - rescaled_min)
            ) * (original_max - original_min) + original_min

        if dose_scaling:
            adjusted_array = self.dose_scaling(adjusted_array, key)
        scaled_image = sitk.GetImageFromArray(adjusted_array)
        scaled_image.CopyInformation(rescaled_image)
        return scaled_image

    def save_nifti_image(self, image, key, output_filename):
        """Write a SimpleITK image to ``<source_folder>/<key>/<output_filename>`` and return that path."""
        output_path = os.path.join(self.source_folder, key, output_filename)
        sitk.WriteImage(image, output_path)
        return output_path

    def check_type(self, image_path):
        """Return 'dose' for an RTDOSE file that has DoseGridScaling, otherwise None."""
        ds = pydicom.dcmread(image_path, stop_before_pixels=True)
        if ds.Modality == 'RTDOSE' and hasattr(ds, 'DoseGridScaling'):
            return 'dose'
        return None

    def get_array(self, image):
        """Return ``(array, min, max)`` for a SimpleITK image as a numpy array and its extremes."""
        array = sitk.GetArrayFromImage(image)
        return array, np.min(array), np.max(array)

    def dose_scaling(self, nifti_data, case):
        """Multiply ``nifti_data`` by DoseGridScaling from the case's RTDOSE file.

        If the RTDOSE has no DoseGridScaling, a message is printed and the data
        are returned unscaled.

        Raises:
            KeyError: ``case`` is unknown or has no 'RTDose path' entry.
        """
        ds = pydicom.dcmread(self.cases[case]['RTDose path'], stop_before_pixels=True)
        scaling = getattr(ds, 'DoseGridScaling', None)
        if scaling is None:
            print(f"No DoseGridScaling in the RTDOSE of case {case}; values left unscaled.")
            return nifti_data
        return nifti_data * float(scaling)

    def EQD2(self, alfa_beta, case, scaled_image, sessions=None):
        """Convert a total dose distribution to the equivalent dose in 2 Gy fractions.

        EQD2 = D * (D / n + a/b) / (2 + a/b), with D the total dose, n the number
        of fractions and a/b the alpha/beta ratio. n is read from the first
        RTPLAN ``*.dcm`` file in ``<source_folder>/<case>`` that provides
        FractionGroupSequence[0].NumberOfFractionsPlanned. ``sessions`` is used
        only when no RTPLAN provides it.

        Args:
            alfa_beta: alpha/beta ratio, in the dose unit of ``scaled_image``.
            case: case name (folder under the source folder and key of ``self.cases``).
            scaled_image: total dose; any object supporting arithmetic with
                scalars (e.g. a numpy array).
            sessions: fallback number of fractions.

        Returns:
            The EQD2 dose, of the same kind as ``scaled_image``.

        Raises:
            ValueError: missing case folder, unregistered case, no positive number
                of fractions, or a failure while computing.
        """
        path = os.path.join(self.source_folder, case)
        if not os.path.isdir(path):
            raise ValueError(f"The directory {path} does not exist.")
        if case not in self.cases:
            raise ValueError(f"Case {case} is not registered in this Mask instance.")
        try:
            files = sorted(os.listdir(path))
        except OSError as e:
            raise ValueError(f"Error calculating EQD2 image: {e}") from e

        for file in files:
            if not file.lower().endswith('.dcm'):
                continue
            try:
                ds = pydicom.dcmread(os.path.join(path, file), stop_before_pixels=True)
                if ds.Modality == 'RTPLAN':
                    planned = ds.FractionGroupSequence[0].NumberOfFractionsPlanned
                    # An empty value must not overwrite a user-supplied fallback.
                    if planned not in (None, ''):
                        sessions = planned
                    break
            except Exception as e:
                print(f"Error reading DICOM file {file}: {e}")

        if sessions is None:
            raise ValueError("Number of fractions planned not found in RTPLAN DICOM file. Please enter the number of sessions.")
        if sessions <= 0:
            raise ValueError(f"Number of sessions must be positive, got {sessions}.")
        try:
            return scaled_image * (((scaled_image / sessions) + alfa_beta) / (2 + alfa_beta))
        except Exception as e:
            raise ValueError(f"Error calculating EQD2 image: {e}") from e

    def select_bin_width(self, path, bin_count):
        """Return the width that splits the image's [min, max] range into ``bin_count`` bins.

        Raises:
            ValueError: ``bin_count`` is not positive.
        """
        if bin_count <= 0:
            raise ValueError(f"bin_count must be positive, got {bin_count}.")
        image = sitk.ReadImage(path)
        _, min_value, max_value = self.get_array(image)
        return (max_value - min_value) / bin_count

    def extract_features(self, mask_path, image_path, bin_width):
        """Run the pyradiomics extractor on an image/mask pair; None (with a message) on failure."""
        try:
            mask = sitk.ReadImage(mask_path)
            image = sitk.ReadImage(image_path)
            params = {
                'binWidth': bin_width,
                'symmetricalGLCM': True,
                'correctMask': True,
            }
            extractor = featureextractor.RadiomicsFeatureExtractor(**params)
            return extractor.execute(image, mask)
        except Exception as e:
            print(f"Error extracting features: {e}")
            return None

    def dataframe_features(self, featureVector, key):
        """Tabulate the 'original_<class>_<name>' features of a feature vector.

        Keys that do not have exactly three '_'-separated parts, or whose image
        type is not 'original', are skipped. A feature name already used by an
        earlier class gets '_<class>' appended.

        Returns:
            A DataFrame with columns 'Feature Name' and ``str(key)``, or an
            empty DataFrame when ``featureVector`` is None.
        """
        if featureVector is None:
            return pd.DataFrame()

        feature_names = []
        feature_values = []
        seen = set()
        for feature_key, value in featureVector.items():
            parts = feature_key.split('_')
            if len(parts) != 3:
                continue
            image_type, feature_class, feature_name = parts
            if image_type != 'original':
                continue
            if feature_name in seen:
                feature_name = f"{feature_name}_{feature_class}"
            seen.add(feature_name)
            feature_names.append(feature_name)
            feature_values.append(value)

        return pd.DataFrame({'Feature Name': feature_names, f'{key}': feature_values})

    def df_to_excel(self, df, output_filename):
        """Write ``df`` (without index) to ``<source_folder>/<output_filename>`` and return that path."""
        output_path = os.path.join(self.source_folder, output_filename)
        df.to_excel(output_path, index=False)
        return output_path