In [9]:
import itk
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
PixelType = itk.F  # 32-bit float pixels used when reading both scans (itk.UC would be 8-bit unsigned char)

## Load the fixed and moving scans

In [65]:
# Fixed (reference) scan, read with the pixel type configured above.
fixed = itk.imread("Data/case6_gre1.nrrd", pixel_type=PixelType)

In [66]:
# Moving scan: the image that registration will align onto the fixed scan.
moving = itk.imread("Data/case6_gre2.nrrd", pixel_type=PixelType)

In [67]:
# Image types used to instantiate the metric and registration templates.
# Each type is taken from its own image, so the templates stay correct even if
# the two scans are read with different pixel types.
dimension = fixed.GetImageDimension()
assert moving.GetImageDimension() == dimension, "fixed and moving scans must have the same dimension"
FixedImageType = type(fixed)
MovingImageType = type(moving)
MovingImageType

itk.itkImagePython.itkImageF3

In [68]:
# Translation-only model with double-precision parameters; a freshly created
# instance is the identity (zero offset). Its class is reused for the other
# transforms below.
initial_transform = itk.TranslationTransform[itk.D, dimension].New()
TransformType = type(initial_transform)

In [70]:
# Regular-step gradient descent. Parameters are passed as New() keywords,
# which ITK forwards to the matching Set<Name>() methods.
optimizer = itk.RegularStepGradientDescentOptimizerv4.New(
    LearningRate=4.0,
    MinimumStepLength=0.001,
    NumberOfIterations=10,
)

In [71]:
# Mean-squares intensity difference between the fixed and moving images; this
# metric assumes both scans have comparable intensity scales.
MetricType = itk.MeanSquaresImageToImageMetricv4[FixedImageType, MovingImageType]
metric = MetricType.New()

In [72]:
# Linear interpolation when the metric samples the fixed image.
FixedInterpolatorType = itk.LinearInterpolateImageFunction[FixedImageType, itk.D]
fixed_interpolation = FixedInterpolatorType.New()
metric.SetFixedInterpolator(fixed_interpolation)

In [73]:
# Registration driver wiring images, metric, optimizer and the transform being
# optimized. Each New() keyword is forwarded to the matching Set<Name>() call,
# in the order listed.
RegistrationType = itk.ImageRegistrationMethodv4[FixedImageType, MovingImageType]
registration = RegistrationType.New(
    Metric=metric,
    Optimizer=optimizer,
    FixedImage=fixed,
    MovingImage=moving,
    InitialTransform=initial_transform,
)

In [74]:
# Fixed (non-optimized) pre-transform applied to the moving image. Every
# translation component is set explicitly. Previously only X and Y were
# zeroed, which left Z untouched for 3-D scans. Replace the zeros with a
# known offset to give the optimizer a head start.
moving_initial_transform = TransformType.New()
initial_parameters = moving_initial_transform.GetParameters()
for axis in range(dimension):
    initial_parameters[axis] = 0
moving_initial_transform.SetParameters(initial_parameters)
registration.SetMovingInitialTransform(moving_initial_transform)

In [75]:
# Single resolution level: no multi-resolution pyramid.
registration.SetNumberOfLevels(1)

# The fixed image is not moved: give it an explicit identity transform.
identity_transform = TransformType.New()
identity_transform.SetIdentity()
registration.SetFixedInitialTransform(identity_transform)

In [76]:
# Run the registration. Afterwards the optimized transform can be read from
# `registration` and the final metric value from `optimizer`.
registration.Update()

In [77]:
# Report the optimized translation, one offset per image axis. Looping over
# `dimension` avoids indexing a Z component that 2-D images do not have.
transform = registration.GetTransform()
final_parameters = transform.GetParameters()
for axis_index, axis_name in enumerate("XYZ"[:dimension]):
    print(f"Translation {axis_name}: {final_parameters.GetElement(axis_index)}")

# The optimizer is capped at NumberOfIterations (set above). Equality with that
# cap suggests it stopped on the limit rather than on convergence.
print(f"Iterations: {optimizer.GetCurrentIteration()}")

# Final mean-squares metric value, displayed as the cell output.
optimizer.GetValue()

Translation X: -0.02322083939182773
Translation Y: -6.70791941471933
Translation Z: -38.00334177337564


30881.19912346788

In [80]:
# Chain the moving initial transform with the optimized transform (same order
# as ITK's ImageRegistration1 example). The resampled image then reflects the
# complete moving-image mapping, not only the optimized part.
CompositeTransformType = itk.CompositeTransform[itk.D, dimension]
output_composite_transform = CompositeTransformType.New()
output_composite_transform.AddTransform(moving_initial_transform)
output_composite_transform.AddTransform(registration.GetModifiableTransform())

# Resample the moving scan onto the fixed scan's grid, using the composite
# (previously the composite was built but the bare optimized transform was used).
resampler = itk.ResampleImageFilter.New(Input=moving,
                                        Transform=output_composite_transform,
                                        UseReferenceImage=True,
                                        ReferenceImage=fixed)
# Value given to output voxels that map outside the moving image.
resampler.SetDefaultPixelValue(100)
# Actually execute the pipeline so the registered volume is available.
resampler.Update()
registered_moving = resampler.GetOutput()

### Tumor segmentation

In [33]:
import sys
import itk
import os
import matplotlib
import matplotlib.pyplot as plt

def segment_tumor(input_filepath, output_filepath, seed, lower, upper, image_nb=50):
    """Segment a region of a 3-D scan by connected-threshold region growing.

    Steps:

    1. Read the volume as float and smooth it with gradient anisotropic diffusion.
    2. Show one slice with the seed marked; wait for a click or keypress.
    3. Grow a region from the seed over voxels whose smoothed intensity lies in
       ``[lower, upper]``, labelling them 255 (all others 0).
    4. Rescale the labels to unsigned char and write them to ``output_filepath``.

    Parameters
    ----------
    input_filepath : str
        Path of the 3-D scan to segment.
    output_filepath : str
        Destination of the 8-bit label volume.
    seed : sequence of 3 ints
        Seed voxel in ITK index order ``(i, j, k)``. Only ``seed[0]`` and
        ``seed[2]`` are used. The middle index is replaced by ``image_nb`` so
        the seed always lies in the displayed slice.
    lower, upper : float
        Intensity window applied to the smoothed image.
    image_nb : int, default 50
        Index ``j`` of the displayed slice, which also holds the seed.

    Returns
    -------
    itk.Image
        The float label image produced by the connected-threshold filter.

    Raises
    ------
    ValueError
        If ``lower > upper`` or the adjusted seed lies outside the image.

    Notes
    -----
    The click/keypress only unblocks execution. The seed coordinates always
    come from ``seed`` and ``image_nb``, never from the mouse position.
    """
    if lower > upper:
        raise ValueError(f"lower ({lower}) must not exceed upper ({upper})")

    # Read the input image as float so smoothing operates on real values
    input_image = itk.imread(input_filepath, pixel_type=itk.F)
    dimension = input_image.GetImageDimension()

    # The middle (j) index is forced to image_nb so the seed is in the shown slice.
    # Validate it up front: a seed outside the image would otherwise only show
    # up later as an empty segmentation, after the interactive wait.
    new_seed = (seed[0], image_nb, seed[2])
    region = input_image.GetLargestPossibleRegion()
    start, size = region.GetIndex(), region.GetSize()
    for axis, value in enumerate(new_seed):
        if not start[axis] <= value < start[axis] + size[axis]:
            raise ValueError(
                f"seed {new_seed} is outside the image along axis {axis} "
                f"(valid range {start[axis]}..{start[axis] + size[axis] - 1})")

    # Apply anisotropic diffusion filter for edge-preserving smoothing
    smoother = itk.GradientAnisotropicDiffusionImageFilter.New(
        Input=input_image, NumberOfIterations=20, TimeStep=0.04, ConductanceParameter=3)
    smoother.Update()
    smoothed = smoother.GetOutput()

    # Display the smoothed slice with the seed marked. The NumPy view is indexed
    # [k, j, i], so the slice's columns are i (seed[0]) and its rows are k (seed[2]).
    plt.ion()
    plt.imshow(itk.GetArrayViewFromImage(smoothed)[:, image_nb, :], cmap="gray")
    plt.plot(new_seed[0], new_seed[2], "r+", markersize=12)
    plt.title("Select seed point and press a key")
    plt.waitforbuttonpress()

    # Region growing from the seed over the intensity window
    connected_threshold = itk.ConnectedThresholdImageFilter.New(smoothed)
    connected_threshold.SetReplaceValue(255)
    connected_threshold.SetLower(lower)
    connected_threshold.SetUpper(upper)
    connected_threshold.SetSeed(new_seed)
    connected_threshold.Update()
    segmented = connected_threshold.GetOutput()

    # Rescale the float label image to 8-bit before writing it out
    in_type = type(segmented)
    output_type = itk.Image[itk.UC, dimension]
    rescaler = itk.RescaleIntensityImageFilter[in_type, output_type].New(connected_threshold)
    itk.imwrite(rescaler, output_filepath)

    # Display the segmented slice over the previous one
    plt.imshow(itk.GetArrayViewFromImage(segmented)[:, image_nb, :], cmap="gray")
    plt.title("Segmented tumor")

    return segmented

In [40]:
# Separate Tk GUI window; presumably chosen so plt.waitforbuttonpress() inside
# segment_tumor can receive clicks/keys (confirm if running headless).
matplotlib.use('TkAgg')

# Input scans and the files their segmentations are written to
scan1_filepath, scan2_filepath = "Data/case6_gre1.nrrd", "Data/case6_gre2.nrrd"
output1_filepath, output2_filepath = 'case6_gre1_segmented.nrrd', 'case6_gre2_segmented.nrrd'

# Example seeds in ITK index order (i, j, k), to be adjusted per scan.
# NOTE: segment_tumor replaces the middle entry with its image_nb slice index
# (default 50), so only the first and last entries take effect here.
seed1 = (120, 105, 10)
seed2 = (120, 105, 30)

# Intensity window for region growing on the smoothed image
lower = 20
upper = 255

# Segment the first scan (trailing ';' keeps any return value out of the cell output)
segment_tumor(scan1_filepath, output1_filepath, seed1, lower=lower, upper=upper);
# Second scan currently disabled:
# segment_tumor(scan2_filepath, output2_filepath, seed2, lower=lower, upper=upper)