In [1]:
import itk
import vtk
import os
import matplotlib
import matplotlib.pyplot as plt

## Lecture des données

In [2]:
# Input volumes: the first acquisition is the fixed (reference) image, the
# second is the moving image that will be aligned onto it.
file1_path = "Data/case6_gre1.nrrd"
file2_path = "Data/case6_gre2.nrrd"

# Both volumes are read with float pixels (itk.F).
PixelType = itk.F
fixed_image, moving_image = (
    itk.imread(source_path, PixelType) for source_path in (file1_path, file2_path)
)

# Save NIfTI copies of the two inputs next to the notebook.
for image_to_save, nifti_name in (
    (fixed_image, "fixed_image.nii.gz"),
    (moving_image, "moving_image.nii.gz"),
):
    itk.imwrite(image_to_save, nifti_name)

## Recalage d'images

In [3]:
# Image types and dimensionality shared by the registration cells below.
FixedImageType, MovingImageType = type(fixed_image), type(moving_image)
dimension = fixed_image.GetImageDimension()

### Recalage par translation

In [4]:
# Rigid pre-alignment: estimate a pure translation between the two volumes.

# Transform: translation in `dimension` dimensions, starting from the identity.
TransformType = itk.TranslationTransform[itk.D, dimension]
initial_transform = TransformType.New()
initial_transform.SetIdentity()

# Optimizer: regular-step gradient descent.
optimizer = itk.RegularStepGradientDescentOptimizerv4.New()
optimizer.SetLearningRate(4.0)
optimizer.SetMinimumStepLength(0.001)
optimizer.SetNumberOfIterations(100)  # decrease this to see the effect of fewer iterations

# Similarity metric: mean squared intensity difference.
metric = itk.MeanSquaresImageToImageMetricv4[FixedImageType, MovingImageType].New()

# Interpolator: while the transform is optimised, the moving image is the one
# evaluated at transformed (off-grid) positions, so the linear interpolator is
# attached to the moving side and templated on the moving image type.
interpolator = itk.LinearInterpolateImageFunction[MovingImageType, itk.D].New()
metric.SetMovingInterpolator(interpolator)

# Registration method: wire images, transform, metric and optimizer together.
registration = itk.ImageRegistrationMethodv4[FixedImageType, MovingImageType].New()
registration.SetFixedImage(fixed_image)
registration.SetMovingImage(moving_image)
registration.SetInitialTransform(initial_transform)
registration.SetMetric(metric)
registration.SetOptimizer(optimizer)

# Run the optimisation.
registration.Update()

# Optimised translation (the registration's output transform).
final_translation_transform = registration.GetTransform()

### Recalage b-spline

In [5]:
# Deformable refinement: a cubic (order 3) B-spline transform optimised on top
# of the translation estimated in the previous cell.
transform_type = itk.BSplineTransform[itk.D, dimension, 3]
transform_bspline = transform_type.New()

# Lay the control-point grid over the fixed image.
spacing = fixed_image.GetSpacing()
origin = fixed_image.GetOrigin()
direction = fixed_image.GetDirection()
region = fixed_image.GetLargestPossibleRegion()
size = region.GetSize()
# The origin is the physical position of the first voxel centre (assumes the
# largest possible region starts at index 0), so the domain spans
# (size - 1) spacings per axis, ending on the last voxel centre.
physical_dimensions = [spacing[i] * (size[i] - 1) for i in range(dimension)]
mesh_size = [8] * dimension  # number of B-spline mesh elements per axis

transform_bspline.SetTransformDomainOrigin(origin)
transform_bspline.SetTransformDomainPhysicalDimensions(physical_dimensions)
transform_bspline.SetTransformDomainMeshSize(mesh_size)
transform_bspline.SetTransformDomainDirection(direction)

# Start from zero coefficients (no deformation). `parameters` is kept as a
# module-level name on purpose: it stays alive while the transform refers to it.
parameters = transform_bspline.GetParameters()
for i in range(parameters.size()):
    parameters[i] = 0.0
transform_bspline.SetParameters(parameters)

# Optimise the B-spline coefficients. The translation found earlier is held
# fixed as the moving initial transform; only `transform_bspline` is updated.
# Hyper-parameters below are starting values to tune.
bspline_optimizer = itk.RegularStepGradientDescentOptimizerv4.New()
bspline_optimizer.SetLearningRate(1.0)
bspline_optimizer.SetMinimumStepLength(0.001)
bspline_optimizer.SetNumberOfIterations(100)

bspline_metric = itk.MeanSquaresImageToImageMetricv4[FixedImageType, MovingImageType].New()

bspline_registration = itk.ImageRegistrationMethodv4[FixedImageType, MovingImageType].New()
bspline_registration.SetFixedImage(fixed_image)
bspline_registration.SetMovingImage(moving_image)
bspline_registration.SetMovingInitialTransform(final_translation_transform)
bspline_registration.SetInitialTransform(transform_bspline)
bspline_registration.SetMetric(bspline_metric)
bspline_registration.SetOptimizer(bspline_optimizer)
# Optimise `transform_bspline` itself, so the composite cell can use it directly.
bspline_registration.SetInPlace(True)
bspline_registration.Update()

### Recalage composite

In [6]:
# Chain the translation and the B-spline into a single transform for resampling.
# NOTE(review): per the ITK CompositeTransform docs, the most recently added
# transform is applied first (here the B-spline, then the translation) — confirm.
composite_type = itk.CompositeTransform[itk.D, dimension]
composite_transform = composite_type.New()
for stage_transform in (final_translation_transform, transform_bspline):
    composite_transform.AddTransform(stage_transform)

## Resampler

In [7]:
# Resample the moving image through the composite transform onto the fixed
# image's grid (reference-image geometry); the default pixel value is 0.
resample_filter = itk.ResampleImageFilter[MovingImageType, FixedImageType].New()
resample_filter.SetReferenceImage(fixed_image)
resample_filter.SetUseReferenceImage(True)
resample_filter.SetTransform(composite_transform)
resample_filter.SetDefaultPixelValue(0)
resample_filter.SetInput(moving_image)
resample_filter.Update()

aligned_image = resample_filter.GetOutput()
itk.imwrite(aligned_image, "aligned_image.nii.gz")

## Segmentation

In [8]:
# Segmentation of the fixed image: edge-preserving smoothing, then
# confidence-connected region growing from a hand-picked seed voxel, then
# conversion to an 8-bit image saved to disk.
ImageTypeOut = itk.Image[itk.UC, 3]

# Step 1 - gradient anisotropic diffusion (5 iterations, time step 0.05,
# conductance 3.0).
smoother = itk.GradientAnisotropicDiffusionImageFilter.New(fixed_image)
smoother.SetNumberOfIterations(5)
smoother.SetTimeStep(0.05)
smoother.SetConductanceParameter(3.0)
smoother.Update()
smooth_image = smoother.GetOutput()

# Step 2 - region growing. The seed is a voxel index (i, j, k) in the fixed
# image grid; voxels accepted into the region are written as 1.
seed = itk.Index[3]()
for axis, voxel_index in enumerate((84, 72, 47)):
    seed[axis] = voxel_index

region_grower = itk.ConfidenceConnectedImageFilter.New(smooth_image)
region_grower.SetInitialNeighborhoodRadius(2)
region_grower.SetMultiplier(2.3)
region_grower.SetNumberOfIterations(2)
region_grower.SetReplaceValue(1)
region_grower.AddSeed(seed)
region_grower.Update()
segmentation = region_grower.GetOutput()

# Step 3 - rescale the label intensities to 0..255 in an unsigned-char image.
rescaler = itk.RescaleIntensityImageFilter[type(segmentation), ImageTypeOut].New()
rescaler.SetInput(segmentation)
rescaler.SetOutputMinimum(0)
rescaler.SetOutputMaximum(255)
rescaler.Update()
final_image = rescaler.GetOutput()

itk.imwrite(final_image, "segmentation_fixed.nii.gz")

In [9]:
# Segmentation of the aligned (registered) moving image, with the same
# pipeline as for the fixed image but a multiplier of 2.5. The aligned image
# was resampled onto the fixed image grid, so the same seed index is reused.
ImageTypeOut = itk.Image[itk.UC, 3]

# Step 1 - gradient anisotropic diffusion.
diffusion = itk.GradientAnisotropicDiffusionImageFilter.New(aligned_image)
diffusion.SetNumberOfIterations(5)
diffusion.SetTimeStep(0.05)
diffusion.SetConductanceParameter(3.0)
diffusion.Update()
smooth_image = diffusion.GetOutput()

# Step 2 - confidence-connected region growing; region voxels are set to 1.
seed = itk.Index[3]()
for axis, voxel_index in enumerate((84, 72, 47)):
    seed[axis] = voxel_index

grower = itk.ConfidenceConnectedImageFilter.New(smooth_image)
grower.SetInitialNeighborhoodRadius(2)
grower.SetMultiplier(2.5)
grower.SetNumberOfIterations(2)
grower.SetReplaceValue(1)
grower.AddSeed(seed)
grower.Update()
segmentation_image = grower.GetOutput()

# Step 3 - rescale the label intensities to 0..255 in an unsigned-char image.
to_uchar = itk.RescaleIntensityImageFilter[type(segmentation_image), ImageTypeOut].New()
to_uchar.SetInput(segmentation_image)
to_uchar.SetOutputMinimum(0)
to_uchar.SetOutputMaximum(255)
to_uchar.Update()
final_image = to_uchar.GetOutput()

itk.imwrite(final_image, "segmentation_aligned.nii.gz")

## Visualisation

In [10]:
# Initial visibility of each surface in the 3-D viewer (toggled from the
# keyboard): fixed segmentation, aligned segmentation, brain.
print_fixed, print_aligned, print_brain = True, True, False

In [11]:
def callback(obj, event):
    """Toggle the visibility flag bound to the pressed key.

    Intended as a VTK ``KeyPressEvent`` observer: ``obj`` is the object that
    fired the event and must provide ``GetKeySym()``; ``event`` is unused.
    Key ``'1'`` flips ``print_fixed``, ``'2'`` flips ``print_aligned`` and
    ``'3'`` flips ``print_brain``; any other key is ignored. Only the global
    flags are changed; nothing is redrawn here.
    """
    global print_fixed, print_brain, print_aligned
    key_sym = obj.GetKeySym()
    if key_sym not in ('1', '2', '3'):
        return
    if key_sym == '1':
        print_fixed = not print_fixed
    elif key_sym == '2':
        print_aligned = not print_aligned
    else:
        print_brain = not print_brain

In [12]:
# 3-D viewer: iso-surfaces of the fixed and aligned segmentations, with keys
# '1' / '2' / '3' toggling the fixed, aligned and brain surfaces.
seg_image_fixed = itk.imread("segmentation_fixed.nii.gz")
seg_image_aligned = itk.imread("segmentation_aligned.nii.gz")

vtk_images = [
    itk.vtk_image_from_image(seg_image_fixed),
    itk.vtk_image_from_image(seg_image_aligned),
]

# One RGB colour per surface; VTK colour components are expected in [0, 1].
surface_colors = [(1.0, 1.0, 0.5), (0.5, 1.0, 1.0)]


def make_surface_actor(vtk_image, color, iso_value=0.5, opacity=0.5):
    """Build a semi-transparent iso-surface actor from a label volume.

    Extracts the ``iso_value`` contour of ``vtk_image`` with marching cubes
    and returns an actor drawn in the uniform RGB ``color`` with the given
    ``opacity`` (scalar colouring is turned off on the mapper).
    """
    surface_extractor = vtk.vtkMarchingCubes()
    surface_extractor.SetInputData(vtk_image)
    surface_extractor.SetValue(0, iso_value)
    surface_extractor.Update()

    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputConnection(surface_extractor.GetOutputPort())
    mapper.ScalarVisibilityOff()

    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(*color)
    actor.GetProperty().SetOpacity(opacity)
    return actor


# actors[0] = fixed segmentation, actors[1] = aligned segmentation.
actors = [
    make_surface_actor(image, color)
    for image, color in zip(vtk_images, surface_colors)
]

renderer = vtk.vtkRenderer()

window = vtk.vtkRenderWindow()
window.AddRenderer(renderer)

windowInteractor = vtk.vtkRenderWindowInteractor()
windowInteractor.SetRenderWindow(window)
windowInteractor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())


def show():
    """Rebuild the scene from the visibility flags and re-render.

    A flag whose actor has not been built (no brain surface exists yet, so
    ``print_brain`` has no actor at index 2) is ignored instead of raising
    IndexError.
    """
    renderer.RemoveAllViewProps()
    for visible, index in ((print_fixed, 0), (print_aligned, 1), (print_brain, 2)):
        if visible and index < len(actors):
            renderer.AddActor(actors[index])
    window.Render()


def on_key_press(obj, event):
    """Toggle the flag bound to the pressed key, then redraw the scene.

    `callback` only flips the global flags; calling `show()` afterwards makes
    the change visible immediately.
    NOTE(review): the default VTK interactor style also binds '3' (stereo
    toggle) on its character event — consider another key for the brain.
    """
    callback(obj, event)
    show()


windowInteractor.AddObserver('KeyPressEvent', on_key_press)

show()
windowInteractor.Initialize()
windowInteractor.Start()