In [1]:
import os
import warnings

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from IPython.display import display
from scipy.ndimage import zoom

In [2]:
def read_dicom_series(directory):
    """Read a DICOM series from ``directory`` as a single-channel Float32 SimpleITK image.

    Multi-component (e.g. RGB) series are collapsed to one channel by averaging
    their components.

    Parameters
    ----------
    directory : str
        Folder containing the DICOM files. When no series ID is given, SimpleITK's
        ``GetGDCMSeriesFileNames`` returns the files of the first series it finds.

    Returns
    -------
    SimpleITK.Image
        Scalar image of pixel type ``sitkFloat32``; spacing/origin/direction come
        from the series reader.

    Raises
    ------
    RuntimeError
        If no DICOM series is found in ``directory``.
    """
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(directory)
    if not dicom_names:
        # Fail with a clear message instead of an opaque reader error on an empty file list.
        raise RuntimeError(f"No DICOM series found in {directory!r}")
    reader.SetFileNames(dicom_names)
    image = reader.Execute()

    n_components = image.GetNumberOfComponentsPerPixel()
    if n_components > 1:
        # Extract every channel as Float32 *before* summing: adding e.g. UInt8 channels in
        # their native pixel type overflows (255 + 255 does not fit), and the division would
        # also be carried out in that integer type.
        channels = [
            sitk.VectorIndexSelectionCast(image, i, sitk.sitkFloat32)
            for i in range(n_components)
        ]
        image = sum(channels) / n_components

    return sitk.Cast(image, sitk.sitkFloat32)

# Input series: the CTP volume is registered (moved) onto the CTH volume (fixed).
primary_ctp_directory = r'D:\CTH_archive\PN1'
fixed_image_directory = r'D:\CTH_archive\CTH_DICOM_SINGLE_FILES'

# Read in the same order as before: CTP (moving) first, then CTH (fixed).
moving_image, fixed_image = (
    read_dicom_series(series_dir)
    for series_dir in (primary_ctp_directory, fixed_image_directory)
)

In [3]:
# Affine registration of the CTP volume (moving) onto the CTH volume (fixed).
# SimpleITK transforms map points from the fixed image domain into the moving image domain.

# Initial alignment: an affine transform centred on the fixed image whose translation maps the
# fixed image's geometric centre onto the moving image's geometric centre. Doing this by hand is
# error-prone: the offset must be moving_center - fixed_center (fixed -> moving mapping), the
# centre index is (size - 1) / 2, and the affine centre should sit at the image centre so that
# rotation/scaling pivot about the anatomy rather than the physical origin.
affine_transform = sitk.CenteredTransformInitializer(
    fixed_image,
    moving_image,
    sitk.AffineTransform(fixed_image.GetDimension()),
    sitk.CenteredTransformInitializerFilter.GEOMETRY,
)

registration_method = sitk.ImageRegistrationMethod()

# Similarity metric: Mattes mutual information evaluated on a random 40% of voxels.
# A fixed seed makes the random sampling, and therefore the registration result, reproducible;
# sitk.sitkWallClock reseeds from the clock on every run.
REGISTRATION_SEED = 42
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.4, REGISTRATION_SEED)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescentLineSearch(
    learningRate=0.5,
    numberOfIterations=500,
    convergenceMinimumValue=1e-6,
    convergenceWindowSize=20,
)
registration_method.SetOptimizerScalesFromPhysicalShift()

# Multi-resolution pyramid: shrink factor per level with matching Gaussian smoothing (in mm).
# NOTE(review): the CTP volume has ~23 slices, so a shrink factor of 16 leaves only 1-2 voxels
# along z at the coarsest level — confirm that level is useful.
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[16, 8, 4, 2])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[5, 4, 2, 1])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

registration_method.SetInitialTransform(affine_transform)

final_transform = registration_method.Execute(fixed_image, moving_image)
print(f"Optimizer stop condition: {registration_method.GetOptimizerStopConditionDescription()}")
print(f"Final metric value: {registration_method.GetMetricValue()}")

# Bring the moving (CTP) image onto the fixed (CTH) grid using the fitted transform.
resampled_CTH_CTP_Registrion = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelID())

In [4]:
def display_image_slices(fixed_image, moving_image, transformed_image):
    """Show the same slice index of three volumes side by side, driven by a slider.

    Panels, left to right: fixed image, transformed moving image, moving image.
    A panel whose volume has no slice at the selected index shows a
    'Slice not available' placeholder instead.

    Parameters
    ----------
    fixed_image, moving_image, transformed_image : SimpleITK.Image
        3-D volumes; the slider spans the largest slice count (size along the
        third axis) among them.
    """
    # (array, title) per panel, in display order; arrays are indexed by slice first.
    panels = [
        (sitk.GetArrayFromImage(fixed_image), 'Fixed Image'),
        (sitk.GetArrayFromImage(transformed_image), 'Transformed Moving Image'),
        (sitk.GetArrayFromImage(moving_image), 'Moving Image'),
    ]
    n_slices = max(img.GetSize()[2] for img in (fixed_image, moving_image, transformed_image))

    def update_slice(slice_idx):
        # Redraw all three panels for the requested slice.
        _fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (volume, title) in zip(axes, panels):
            if slice_idx >= volume.shape[0]:
                ax.text(0.5, 0.5, 'Slice not available', horizontalalignment='center', verticalalignment='center')
            else:
                ax.imshow(volume[slice_idx], cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    slider = widgets.IntSlider(min=0, max=n_slices - 1, step=1, value=n_slices // 2, description='Slice')
    # The dict key must match update_slice's parameter name.
    plot_output = widgets.interactive_output(update_slice, {'slice_idx': slider})
    display(slider, plot_output)

# Compare fixed, registered and original moving volumes slice by slice.
display_image_slices(
    fixed_image=fixed_image,
    moving_image=moving_image,
    transformed_image=resampled_CTH_CTP_Registrion,
)

IntSlider(value=16, description='Slice', max=31)

Output()

In [5]:
def convert_rgb_series_to_grayscale_and_save_as_nifti(input_directory, output_file):
    # Read the DICOM series
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(input_directory)
    reader.SetFileNames(dicom_names)
    image_series = reader.Execute()
    
    # Convert the image series to a numpy array and then to grayscale
    img_array = sitk.GetArrayFromImage(image_series)  # Shape: [z, y, x, channels]
    grayscale_array = np.dot(img_array[..., :3], [0.299, 0.587, 0.114]).astype(np.uint16)

    # Convert the grayscale numpy array back to a SimpleITK Image
    grayscale_image = sitk.GetImageFromArray(grayscale_array)

    # Copy the spacing, origin, and direction information from the original series
    grayscale_image.SetSpacing(image_series.GetSpacing())
    grayscale_image.SetOrigin(image_series.GetOrigin())
    grayscale_image.SetDirection(image_series.GetDirection())

    # Write the grayscale image as a NIfTI file
    sitk.WriteImage(grayscale_image, output_file)

# TMax map exported as an RGB DICOM series, and where to write its grayscale NIfTI version.
input_directory = r'D:\CTH_archive\TEST_TMAX'
output_nifti_file = r'D:\CTH_archive\TEST_TMAX_Grayscale.nii'

convert_rgb_series_to_grayscale_and_save_as_nifti(
    input_directory=input_directory,
    output_file=output_nifti_file,
)


In [6]:
# Record the CTP (moving image) grid geometry; the TMax map is later placed on this grid.
CTP_size, CTP_spacing, CTP_origin = (
    moving_image.GetSize(),
    moving_image.GetSpacing(),
    moving_image.GetOrigin(),
)

print("\n".join(
    f"Image {label}: {value}"
    for label, value in (("Size", CTP_size), ("Spacing", CTP_spacing), ("Origin", CTP_origin))
))

Image Size: (512, 512, 23)
Image Spacing: (0.429688, 0.429688, 5.0)
Image Origin: (-105.3, -116.7, -49.75)


In [7]:
# Load the grayscale TMax map written by the conversion cell.
gray_scale_nfti_image = sitk.ReadImage(output_nifti_file)

# Burned-in annotations in the exported TMax images, in pixels.
# NOTE(review): these offsets are specific to this export's screen layout — confirm for other series.
TMAX_HEADER_ROWS = 30    # top rows containing the "TMax" text
TMAX_SCALEBAR_COLS = 27  # left columns containing the colour scale

# GetArrayFromImage returns a [z, y, x] array: crop the header rows and blank the scale bar.
gray_scale_nfti_array = sitk.GetArrayFromImage(gray_scale_nfti_image)
gray_scale_nfti_array = gray_scale_nfti_array[:, TMAX_HEADER_ROWS:, :]
gray_scale_nfti_array[:, :, :TMAX_SCALEBAR_COLS] = 0
# GetImageFromArray produces default geometry (unit spacing, zero origin, identity direction);
# the real CTP geometry is assigned after resizing.
gray_scale_nfti_cropped_image = sitk.GetImageFromArray(gray_scale_nfti_array)

# Resize in-plane to the CTP matrix, keeping the slice count.
cropped_size = gray_scale_nfti_cropped_image.GetSize()
cropped_spacing = gray_scale_nfti_cropped_image.GetSpacing()
desired_size = [CTP_size[0], CTP_size[1], cropped_size[2]]
output_spacing = [cropped_spacing[i] * cropped_size[i] / desired_size[i] for i in range(3)]
# Shift the first voxel centre by half the change in voxel size so the output grid spans exactly
# the input's physical extent; keeping the input origin would shift content by up to half a pixel.
output_origin = gray_scale_nfti_cropped_image.TransformContinuousIndexToPhysicalPoint(
    [(output_spacing[i] / cropped_spacing[i] - 1) / 2 for i in range(3)]
)

resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(gray_scale_nfti_cropped_image)
resampler.SetSize(desired_size)
resampler.SetOutputSpacing(output_spacing)
resampler.SetOutputOrigin(output_origin)
resampler.SetTransform(sitk.Transform())  # identity: pure resize, no spatial transform
resampler.SetInterpolator(sitk.sitkLinear)

resized_image = resampler.Execute(gray_scale_nfti_cropped_image)

# Place the resized map on the CTP (moving image) grid so final_transform can be applied to it.
# Direction is copied too; otherwise the map keeps the identity direction from GetImageFromArray.
# NOTE(review): assumes TMax slices correspond one-to-one with CTP slices and that the cropped
# screenshot covers the CTP field of view — confirm.
new_map_spacing = CTP_spacing
new_map_origin = CTP_origin

resized_image.SetSpacing(new_map_spacing)
resized_image.SetOrigin(new_map_origin)
resized_image.SetDirection(moving_image.GetDirection())

if resized_image.GetSize()[2] != CTP_size[2]:
    warnings.warn(
        f"TMax map has {resized_image.GetSize()[2]} slices but the CTP volume has {CTP_size[2]}; "
        "slice-to-slice correspondence with the CTP grid is not guaranteed."
    )

In [8]:

# Map the TMax image, which now sits on the CTP grid, into CTH space. final_transform maps points
# from the fixed (CTH) domain into the moving (CTP) domain, so the CTH volume is the reference grid.
transformed_image_path = r'D:\CTH_archive\TEST_TMAX_Grayscale_transformed.nii'

resampled_image = sitk.Resample(
    resized_image,
    fixed_image,                # reference grid: the CTH volume
    final_transform,
    sitk.sitkLinear,
    0.0,                        # value for voxels that fall outside the moving image
    moving_image.GetPixelID(),
)
sitk.WriteImage(resampled_image, transformed_image_path)

print(f"Transformed image saved to {transformed_image_path}")

Transformed image saved to D:\CTH_archive\TEST_TMAX_Grayscale_transformed.nii
