In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output
import itk
import SimpleITK as sitk
import matplotlib.patches as patches
from ipywidgets import interact, fixed
from matplotlib.widgets import Slider
import multiprocessing
import pandas as pd

In [None]:
def change_voxel_spacing(image, new_spacing=(0.71875, 0.71875, 1.0)):
    """Resample `image` onto a grid with the requested voxel spacing.

    The output keeps the input's origin, direction and pixel type. Its size on
    each axis is round(input_size * input_spacing / new_spacing), so the
    physical extent is approximately preserved.

    Args:
        image: SimpleITK image to resample.
        new_spacing: Target spacing, one entry per image dimension. The default
            is the spacing of our private dataset.

    Returns:
        The resampled SimpleITK image. It uses Hamming-windowed-sinc
        interpolation with an identity transform, and voxels outside the
        input's extent are set to 0.0.

    Raises:
        ValueError: If `new_spacing` does not have one entry per image dimension.
    """
    new_spacing = [float(spacing) for spacing in new_spacing]
    dimension = image.GetDimension()
    if len(new_spacing) != dimension:
        raise ValueError(
            f'new_spacing has {len(new_spacing)} entries but the image is {dimension}-D'
        )

    out_size = [
        # Never let an axis collapse to 0 voxels when the image is tiny.
        max(1, int(np.round(size * (spacing_in / spacing_out))))
        for size, spacing_in, spacing_out in zip(image.GetSize(), image.GetSpacing(), new_spacing)
    ]

    # Empty image that only defines the output grid.
    reference_image = sitk.Image(out_size, image.GetPixelID())
    reference_image.SetOrigin(image.GetOrigin())
    reference_image.SetSpacing(new_spacing)
    reference_image.SetDirection(image.GetDirection())

    # Pure regridding: no spatial transformation between input and output.
    identity_transform = sitk.Transform(dimension, sitk.sitkIdentity)
    return sitk.Resample(
        image, reference_image, identity_transform,
        sitk.sitkHammingWindowedSinc, 0.0, image.GetPixelID(),
    )

In [None]:
def center_crop_1024_image(sitk_image, offset, size):
    """Crop a region of shape `size` centred on the (shifted) image centre.

    Despite its name, nothing here is specific to 1024-wide images or to 3-D.
    On each axis the crop centre is `GetSize()[axis] // 2 + offset[axis]`, and
    the region's lower corner is that centre minus `size[axis] // 2`.

    Args:
        sitk_image: SimpleITK image to crop.
        offset: Per-axis shift, in voxels, applied to the image centre.
        size: Per-axis size, in voxels, of the cropped region.

    Returns:
        The cropped SimpleITK image. SimpleITK raises if the requested region
        leaves the image.
    """
    extents = sitk_image.GetSize()
    shifted_centre = [extents[axis] // 2 + shift for axis, shift in enumerate(offset)]
    lower_corner = [shifted_centre[axis] - span // 2 for axis, span in enumerate(size)]

    cropper = sitk.RegionOfInterestImageFilter()
    cropper.SetSize(size)
    cropper.SetIndex(lower_corner)
    return cropper.Execute(sitk_image)

In [None]:
def register_adc_image(reference, adc_image, seed=42):
    """Align an ADC image to a reference image and resample it onto the reference grid.

    The alignment uses an Euler3D transform (rotation + translation). That
    transform is initialised by matching the images' geometric centres, then
    optimised with gradient descent on Mattes mutual information. The
    optimisation runs over a 3-level pyramid: shrink factors 4/2/1 and
    smoothing sigmas 2/1/0 in physical units.

    Args:
        reference: Fixed image. Its grid defines the output.
        adc_image: Moving image to align to `reference`.
        seed: Seed for the random metric sampling. The fixed default makes
            repeated runs sample the same voxels. Pass `sitk.sitkWallClock` to
            get the previous time-seeded behaviour.

    Returns:
        `adc_image` resampled into `reference`'s space with the optimised transform.
    """
    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric: Mattes mutual information on a random 1% of voxels.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    # Without an explicit seed, SimpleITK seeds from the wall clock. Every run
    # would then sample different voxels and give a different registration.
    registration_method.SetMetricSamplingPercentage(0.01, seed)

    # NOTE(review): the interpolator is evaluated at every metric sample.
    # Windowed sinc is far costlier than the usual sitkLinear here; confirm it is wanted.
    registration_method.SetInterpolator(sitk.sitkHammingWindowedSinc)

    # Optimizer settings
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid: coarse-to-fine shrinking and smoothing.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Start from a transform that overlays the two images' geometric centres.
    initial_transform = sitk.CenteredTransformInitializer(reference, adc_image, sitk.Euler3DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY)
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    final_transform = registration_method.Execute(reference, adc_image)

    # Map the moving image onto the reference grid with the optimised transform.
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    resampler.SetInterpolator(sitk.sitkHammingWindowedSinc)
    resampler.SetTransform(final_transform)
    return resampler.Execute(adc_image)

In [None]:
def is_cancer_positive(patient_id, dataset_metadata):
    """Return True if the patient's `lesion_GS` entry contains any non-zero digit.

    Args:
        patient_id: Value compared with `==` against the `patient_id` column.
        dataset_metadata: DataFrame with at least `patient_id` and `lesion_GS` columns.

    Returns:
        False if no row matches, or if the first matching `lesion_GS` is
        NaN/missing. Otherwise, whether its string form contains a digit other
        than '0'.
    """
    lesion_gs_values = dataset_metadata.loc[
        dataset_metadata['patient_id'] == patient_id, 'lesion_GS'
    ].values
    # An unknown patient counts as negative. Check explicitly instead of
    # catching a blanket Exception around the indexing.
    if len(lesion_gs_values) == 0:
        return False

    # NOTE(review): if a patient can have several rows, only the first one is
    # inspected here. Confirm that is intended.
    lesion_gs = lesion_gs_values[0]
    if pd.isna(lesion_gs):
        return False
    return any(char.isdigit() and char != '0' for char in str(lesion_gs))

In [None]:
def find_fixed_image(files_path):
    """Load the T2W scan of a patient directory as the registration's fixed image.

    Args:
        files_path: Directory holding the patient's image files.

    Returns:
        `(image, file_name)` for the first file, in sorted name order, whose
        name contains 't2w'. The image is read as float32. Returns
        `(None, None)` when no such file exists.
    """
    # Sort so the choice does not depend on the filesystem's listing order.
    for image_name in sorted(os.listdir(files_path)):
        if 't2w' in image_name:
            return sitk.ReadImage(os.path.join(files_path, image_name), sitk.sitkFloat32), image_name
    # Return a pair, not a bare None, so callers can always unpack two values.
    return None, None

In [None]:
# Base directory of the PI-CAI inputs and of the registered output.
# Set the PROSTATE_DATA_ROOT environment variable to run on another machine.
DATA_ROOT = os.environ.get('PROSTATE_DATA_ROOT', '/local_ssd/practical_wise24/prostate_cancer')
root_dir = os.path.join(DATA_ROOT, 'PICAIDataset', 'input', 'images')
registered_dataset_path = os.path.join(DATA_ROOT, 'registered_picai_dataset')
marksheet_path = os.path.join(
    DATA_ROOT, 'PICAIDataset', 'input', 'picai_labels', 'clinical_information', 'marksheet.csv'
)
# patient_id is later compared against patient directory names, which are strings.
df = pd.read_csv(marksheet_path).astype({'patient_id': str})

In [None]:
def process_patient(patient_folder):
    """Register one patient's ADC scan onto its resampled T2W scan and save both.

    Only patients for which `is_cancer_positive` returns True are processed.
    Outputs are written to `registered_dataset_path/<patient_folder>/` under
    their original file names. A patient whose output folder already holds two
    files is treated as done and skipped. Patients without a T2W or an ADC scan
    are reported and skipped instead of crashing the worker.

    Args:
        patient_folder: Name of the patient's sub-directory under `root_dir`.
            It is also used as the patient id looked up in `df`.
    """
    if not is_cancer_positive(patient_folder, df):
        return

    write_file_path = os.path.join(registered_dataset_path, patient_folder)
    # Two files (fixed T2W + registered ADC) mean a previous run finished this patient.
    if os.path.isdir(write_file_path) and len(os.listdir(write_file_path)) == 2:
        return

    files_path = os.path.join(root_dir, patient_folder)
    # Accept both a bare None and (None, None) as "no T2W image found".
    found = find_fixed_image(files_path)
    fixed_image, fixed_image_name = found if found is not None else (None, None)
    if fixed_image is None:
        print(f'{patient_folder}: no t2w image found, skipping')
        return

    adc_file_names = sorted(
        name for name in os.listdir(files_path) if 'adc' in name.split('.')[0]
    )
    if not adc_file_names:
        print(f'{patient_folder}: no adc image found, skipping')
        return
    # Only one ADC fits the two-file layout, so register just one. Prefer the
    # file named like the T2W with the modality tag swapped; otherwise take the
    # first by name. NOTE(review): assumes scans of the same study share a
    # name prefix. Confirm for folders holding several studies.
    paired_name = fixed_image_name.replace('t2w', 'adc')
    adc_file_name = paired_name if paired_name in adc_file_names else adc_file_names[0]

    # 1024-wide scans are cropped to the central 612x612 in-plane region (all slices kept).
    if fixed_image.GetSize()[0] == 1024:
        fixed_image = center_crop_1024_image(fixed_image, [0, 0, 0], [612, 612, fixed_image.GetSize()[2]])
    fixed_image = change_voxel_spacing(fixed_image)

    moving_image = sitk.ReadImage(os.path.join(files_path, adc_file_name), sitk.sitkFloat32)
    registered_adc = register_adc_image(fixed_image, moving_image)

    # Write only after registration succeeded, so a failure never leaves a half-filled folder.
    os.makedirs(write_file_path, exist_ok=True)
    sitk.WriteImage(fixed_image, os.path.join(write_file_path, fixed_image_name))
    sitk.WriteImage(registered_adc, os.path.join(write_file_path, adc_file_name))
    print(patient_folder)

In [None]:
# Process every patient directory in parallel, with one worker process per CPU core.
patient_folders = os.listdir(root_dir)
worker_count = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=worker_count) as pool:
    pool.map(process_patient, patient_folders)