In [1]:
import os
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

In [2]:
def load_dicom_series(folder_path):
    """Read the DICOM series found in ``folder_path`` as one SimpleITK image.

    Args:
        folder_path: Directory that GDCM scans for the files of a series.

    Returns:
        The image produced by ``sitk.ImageSeriesReader.Execute()``.

    Raises:
        ValueError: If GDCM reports no series files for ``folder_path``.
    """
    series_reader = sitk.ImageSeriesReader()
    series_files = series_reader.GetGDCMSeriesFileNames(folder_path)

    # Fail early with a readable message instead of letting the reader
    # run on an empty file list.
    if not series_files:
        raise ValueError(f"No DICOM files found in {folder_path}")

    series_reader.SetFileNames(series_files)
    return series_reader.Execute()

def normalize_image(image):
    """Min-max normalize voxel intensities to the range [0, 1].

    Args:
        image: Either a SimpleITK image or a NumPy array that was already
            extracted from one. Accepting an array lets callers that hold
            the converted volume avoid a second SimpleITK -> NumPy copy.

    Returns:
        A new float32 NumPy array with the same shape as the input data.
        When the intensity range is below 1e-5 the result is all zeros,
        which avoids dividing by a (near-)zero range.
    """
    if isinstance(image, np.ndarray):
        # astype() always copies, so the caller's array is never modified.
        img_array = image.astype(np.float32)
    else:
        img_array = sitk.GetArrayFromImage(image).astype(np.float32)

    img_min = np.min(img_array)
    img_range = np.max(img_array) - img_min

    # A flat image carries no contrast; return zeros rather than NaN/inf.
    if img_range < 1e-5:
        return np.zeros_like(img_array)
    return (img_array - img_min) / img_range

In [4]:
# ================== 1) Load CT and PET images ==================

# The defaults below are the locations of the ACRIN-NSCLC-FDG-PET-001 series on
# the original workstation. Set the CT_DICOM_DIR / PET_DICOM_DIR environment
# variables to point the notebook at another copy of the data without editing
# this cell; when they are unset the behavior is exactly as before.
ct_folder_path = os.environ.get(
    "CT_DICOM_DIR",
    r"D:\NIH\PET-CT\acrin_nsclc_fdg_pet\ACRIN-NSCLC-FDG-PET-001\1.3.6.1.4.1.14519.5.2.1.7009.2403.156046015078185438233607422806\CT_1.3.6.1.4.1.14519.5.2.1.7009.2403.192241118078441962132923230489",
)
pet_folder_path = os.environ.get(
    "PET_DICOM_DIR",
    r"D:\NIH\PET-CT\acrin_nsclc_fdg_pet\ACRIN-NSCLC-FDG-PET-001\1.3.6.1.4.1.14519.5.2.1.7009.2403.156046015078185438233607422806\PT_1.3.6.1.4.1.14519.5.2.1.7009.2403.121694709831221676480030303736",
)

# Load images
print("Loading CT image...")
ct_image = load_dicom_series(ct_folder_path)
print("Loading PET image...")
pet_image = load_dicom_series(pet_folder_path)

# Print image information
def print_image_info(name, image):
    """Print the basic geometry of ``image`` under a heading built from ``name``.

    Args:
        name: Label used in the heading line (e.g. "CT").
        image: Object exposing ``GetSize``, ``GetSpacing``, ``GetOrigin``,
            ``GetDirection`` and ``GetDimension`` accessors.
    """
    print(f"\n{name} Image Information:")
    # Each reported field maps onto the accessor named "Get<Field>".
    for field in ("Size", "Spacing", "Origin", "Direction", "Dimension"):
        accessor = getattr(image, "Get" + field)
        print(f"  {field}: {accessor()}")

# Report the geometry of each loaded volume before processing it.
print_image_info("CT", ct_image)
print_image_info("PET", pet_image)

# The steps that follow operate on volumes, so stop unless both inputs are 3D.
if not (ct_image.GetDimension() == 3 and pet_image.GetDimension() == 3):
    raise ValueError("Both CT and PET images must be 3D.")

# Give both volumes the same 32-bit floating point pixel type.
ct_image = sitk.Cast(ct_image, sitk.sitkFloat32)
pet_image = sitk.Cast(pet_image, sitk.sitkFloat32)

# ================== 2) Register PET to CT ==================

# Seed for the metric's random voxel sampling. When no seed is given SimpleITK
# seeds the sampler from the wall clock, so each run evaluates the metric on a
# different 1% of voxels and the resulting transform / metric value drift from
# run to run. A fixed seed makes the sampling repeatable.
REGISTRATION_SEED = 42

# Initialize the registration method
registration_method = sitk.ImageRegistrationMethod()

# Multi-resolution strategy: three levels, coarse to fine, with the smoothing
# sigmas given in physical units rather than voxels.
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Registration metric: Mattes mutual information evaluated on a random 1%
# sample of voxels, drawn with the fixed seed above.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01, REGISTRATION_SEED)

# Rigid (Euler 3D) transform, initialized from the geometry of the two volumes
# (CT is the fixed image, PET the moving image).
initial_transform = sitk.CenteredTransformInitializer(
    ct_image,
    pet_image,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.GEOMETRY
)

# inPlace=False leaves initial_transform untouched; the optimized result is
# the value returned by Execute().
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Optimizer: gradient descent with parameter scales estimated from physical shift
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0,
                                                  numberOfIterations=100,
                                                  convergenceMinimumValue=1e-6,
                                                  convergenceWindowSize=10)
registration_method.SetOptimizerScalesFromPhysicalShift()

# Interpolation used while evaluating the metric
registration_method.SetInterpolator(sitk.sitkLinear)

# Execute registration
print("\nStarting registration...")
try:
    # Both images were already cast to float32 earlier in this cell, so they
    # are passed through directly instead of being cast a second time.
    final_transform = registration_method.Execute(ct_image, pet_image)

    print("\nRegistration completed.")
    print("Optimizer Converged:", registration_method.GetOptimizerStopConditionDescription())
    print("Final metric value:", registration_method.GetMetricValue())
except Exception as e:
    print(f"\nRegistration failed: {e}")
    # Bare raise re-raises the active exception with its original traceback.
    raise

# ================== 3) Resample PET image ==================

# Build a resampler that places the PET volume on the CT reference grid using
# the transform found by the registration step. Linear interpolation is used,
# and 0 is the default value for output voxels.
resampler = sitk.ResampleImageFilter()
resampler.SetTransform(final_transform)
resampler.SetReferenceImage(ct_image)
resampler.SetDefaultPixelValue(0)
resampler.SetInterpolator(sitk.sitkLinear)

print("\nResampling PET image to CT space...")
pet_resampled = resampler.Execute(pet_image)
print("Resampling completed.")

Loading CT image...
Loading PET image...

CT Image Information:
  Size: (512, 512, 307)
  Spacing: (0.9765625, 0.9765625, 2.5)
  Origin: (-249.51172, -460.51172, -1007.0)
  Direction: (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
  Dimension: 3

PET Image Information:
  Size: (128, 128, 307)
  Spacing: (5.3067274, 5.3067274, 2.5)
  Origin: (-339.16485, -553.64165, -1007.0)
  Direction: (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
  Dimension: 3

Starting registration...

Registration completed.
Optimizer Converged: GradientDescentOptimizerv4Template: Convergence checker passed at iteration 39.
Final metric value: -0.352723439501319

Resampling PET image to CT space...
Resampling completed.


In [5]:
# ================== 4) Visualize fusion ==================

def show_fusion_slice(slice_idx):
    """Show one slice of the CT volume with the resampled PET overlaid on it.

    Reads the module-level ``ct_image`` and ``pet_resampled`` volumes,
    normalizes each to [0, 1] and draws the PET slice (``jet`` colormap,
    alpha 0.3) on top of the CT slice (grayscale).

    Args:
        slice_idx: Index along the first array axis (the slice axis). Values
            outside the available range are clamped to the nearest valid
            index.
    """
    # normalize_image converts SimpleITK -> NumPy itself, so no separate array
    # copies are made here. The previous version also built two full-volume
    # arrays on every slider event that were never used.
    ct_norm = normalize_image(ct_image)
    pet_norm = normalize_image(pet_resampled)

    # Largest extent shared by both volumes along each array axis.
    min_z = min(ct_norm.shape[0], pet_norm.shape[0])
    min_y = min(ct_norm.shape[1], pet_norm.shape[1])
    min_x = min(ct_norm.shape[2], pet_norm.shape[2])

    # Clamp the requested index into [0, min_z - 1].
    slice_idx = max(0, min(slice_idx, min_z - 1))

    ct_slice = ct_norm[slice_idx, :min_y, :min_x]
    pet_slice = pet_norm[slice_idx, :min_y, :min_x]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(ct_slice, cmap='gray', interpolation='none')
    ax.imshow(pet_slice, cmap='jet', alpha=0.3, interpolation='none')
    # The title is 1-based while slice_idx is 0-based.
    ax.set_title(f"Fused Slice #{slice_idx+1}")
    ax.axis('off')
    plt.show()

# Only slices present in both volumes can be fused, so the slider range is
# bounded by the smaller of the two slice counts (last entry of GetSize()).
num_slices = min(ct_image.GetSize()[-1], pet_resampled.GetSize()[-1])

# Start the slider on the middle slice.
interact(
    show_fusion_slice,
    slice_idx=IntSlider(value=num_slices // 2, min=0, max=num_slices - 1, step=1),
)

interactive(children=(IntSlider(value=153, description='slice_idx', max=306), Output()), _dom_classes=('widget…

<function __main__.show_fusion_slice(slice_idx)>

In [6]:
# ================== 4) Visualize fusion (faster version)==================

# Do the conversion and normalization once, up front, so the slider callback
# defined below only has to index into arrays that already exist.
# NOTE(review): ct_array / pet_array are not read anywhere else in this cell.
ct_array = sitk.GetArrayFromImage(ct_image)
pet_array = sitk.GetArrayFromImage(pet_resampled)
ct_norm = normalize_image(ct_image)
pet_norm = normalize_image(pet_resampled)

# Largest extent shared by both normalized volumes along each of the three
# array axes.
min_z, min_y, min_x = (
    min(ct_extent, pet_extent)
    for ct_extent, pet_extent in zip(ct_norm.shape, pet_norm.shape)
)

# Crop both volumes to that shared extent so they can be indexed identically.
ct_norm = ct_norm[:min_z, :min_y, :min_x]
pet_norm = pet_norm[:min_z, :min_y, :min_x]

def show_fusion_slice(slice_idx):
    """Draw one pre-normalized CT slice with the matching PET slice on top.

    Uses the module-level ``ct_norm`` / ``pet_norm`` arrays and ``min_z``
    prepared before this definition, so no per-call conversion is needed.

    Args:
        slice_idx: Index along the first array axis; clamped into
            ``[0, min_z - 1]`` before use.
    """
    # Clamp the requested index to the valid range.
    last_idx = min_z - 1
    if slice_idx > last_idx:
        slice_idx = last_idx
    if slice_idx < 0:
        slice_idx = 0

    # Grayscale CT underneath, semi-transparent PET on top.
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(ct_norm[slice_idx], cmap='gray', interpolation='none')
    ax.imshow(pet_norm[slice_idx], cmap='jet', alpha=0.3, interpolation='none')
    # The title is 1-based while slice_idx is 0-based.
    ax.set_title(f"Fused Slice #{slice_idx + 1}")
    ax.axis('off')
    plt.show()

# The slider spans the slices shared by both cropped volumes.
num_slices = min_z

# Start the slider on the middle slice.
interact(
    show_fusion_slice,
    slice_idx=IntSlider(value=num_slices // 2, min=0, max=num_slices - 1, step=1),
)

interactive(children=(IntSlider(value=153, description='slice_idx', max=306), Output()), _dom_classes=('widget…

<function __main__.show_fusion_slice(slice_idx)>