Ce notebook n'est qu'un prototype.
Le code définitif doit être écrit dans le fichier `main.py`.

# Imports et data

In [57]:
import vtk
import itk

In [2]:
# Ajouter ici les import nécessaires autres que vtk et itk
# import ...
import numpy as np
from vtk.util import numpy_support

In [80]:
# Input volumes, as paths relative to the current working directory.
IMAGE1_PATH, IMAGE2_PATH = (
    "./Data/case6_gre1.nrrd",
    "./Data/case6_gre2.nrrd",
)

In [90]:
def _read_nrrd_vtk(path):
    """Read a NRRD volume with ``vtk.vtkNrrdReader`` and return the updated reader.

    The reader does not raise on failure; it just leaves an empty output.
    Checking here makes a bad path fail immediately instead of surfacing later
    as an empty extent in the viewer.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        RuntimeError: if the reader produced an image with no points.
    """
    import os

    if not os.path.isfile(path):
        raise FileNotFoundError(f"Fichier introuvable : {path}")
    reader = vtk.vtkNrrdReader()
    reader.SetFileName(path)
    reader.Update()
    if reader.GetOutput().GetNumberOfPoints() == 0:
        raise RuntimeError(f"Impossible de lire l'image NRRD : {path}")
    return reader


reader1 = _read_nrrd_vtk(IMAGE1_PATH)
image1_vtk = reader1.GetOutput()

reader2 = _read_nrrd_vtk(IMAGE2_PATH)
image2_vtk = reader2.GetOutput()

# Visualisation des images

In [91]:
class SliceViewerWithDifference:
    """Three-panel slice viewer: image 1, image 2 and their thresholded difference.

    The render window is split horizontally into three viewports:
    - left: the current slice of ``vtk_image1``;
    - middle: the current slice of ``vtk_image2``;
    - right: a red/black map where red marks pixels whose absolute intensity
      difference exceeds ``threshold``.

    The Up/Down arrow keys move one slice forward/backward along ``axis``.

    The two volumes are compared voxel index by voxel index (extent
    coordinates), so they are expected to share the same grid. An example is
    the second image resampled onto the first after registration. When their
    extents differ, only slices present in both volumes can be browsed, and
    only the common in-plane region is compared.

    Parameters
    ----------
    vtk_image1, vtk_image2 : vtk.vtkImageData
        Volumes to compare. Their point scalars must have a single component
        (they are reshaped to a (z, y, x) array).
    axis : {'x', 'y', 'z'}
        Axis normal to the displayed slices ('z' axial, 'y' coronal,
        'x' sagittal).
    title : str
        Window title.
    threshold : float
        Absolute difference above which a pixel is painted red (default 20).

    Raises
    ------
    ValueError
        If ``axis`` is invalid, or the volumes share no slice along it.
    """

    # vtkImageViewer2 slice orientations: 0 = YZ, 1 = XZ, 2 = XY.
    _ORIENTATIONS = {'x': 0, 'y': 1, 'z': 2}

    def __init__(self, vtk_image1, vtk_image2, axis='z', title="Visualisation avec différences",
                 threshold=20):
        if axis not in self._ORIENTATIONS:
            raise ValueError("Axe invalide, doit être 'x', 'y' ou 'z'")
        self.axis = axis
        self.threshold = threshold

        self.vtk_image1 = vtk_image1
        self.vtk_image2 = vtk_image2

        # Browsable slice range: indices present in BOTH volumes along `axis`.
        orientation = self.get_orientation()
        extent = vtk_image1.GetExtent()
        extent2 = vtk_image2.GetExtent()
        self.extent = extent
        self.min_slice = max(extent[2 * orientation], extent2[2 * orientation])
        self.max_slice = min(extent[2 * orientation + 1], extent2[2 * orientation + 1])
        if self.min_slice > self.max_slice:
            raise ValueError(f"Les deux images n'ont aucune coupe commune selon l'axe {axis}")
        self.slice = (self.min_slice + self.max_slice) // 2

        self.render_window = vtk.vtkRenderWindow()
        self.render_window.SetSize(1800, 600)
        self.render_window.SetWindowName(title)

        self.renderer_left = self._add_renderer(0.0, 0.33)
        self.renderer_middle = self._add_renderer(0.33, 0.66)
        self.renderer_right = self._add_renderer(0.66, 1.0)

        self.viewer_left = vtk.vtkImageViewer2()
        self.viewer_middle = vtk.vtkImageViewer2()
        self.viewer_right = vtk.vtkImageViewer2()

        # Fill the difference image BEFORE configuring the right viewer.
        # Its slice range, display extent and camera are then derived from
        # real data instead of an empty vtkImageData.
        self.diff_vtk_image = vtk.vtkImageData()
        self.update_difference_image()

        for viewer, image, renderer in (
            (self.viewer_left, self.vtk_image1, self.renderer_left),
            (self.viewer_middle, self.vtk_image2, self.renderer_middle),
            (self.viewer_right, self.diff_vtk_image, self.renderer_right),
        ):
            viewer.SetInputData(image)
            viewer.SetRenderWindow(self.render_window)
            viewer.SetRenderer(renderer)

        self.viewer_left.SetSliceOrientation(orientation)
        self.viewer_middle.SetSliceOrientation(orientation)
        # The difference map is always a single 2-D slice laid out in the XY
        # plane (see update_difference_image), whatever the browsing axis.
        self.viewer_right.SetSliceOrientationToXY()

        self.viewer_left.SetSlice(self.slice)
        self.viewer_middle.SetSlice(self.slice)
        self.viewer_right.SetSlice(0)

        # Shared window/level from the actual intensity range of both volumes,
        # so both panels use the same grey scale. A fixed 255/128 only suits
        # 8-bit data.
        low = min(vtk_image1.GetScalarRange()[0], vtk_image2.GetScalarRange()[0])
        high = max(vtk_image1.GetScalarRange()[1], vtk_image2.GetScalarRange()[1])
        for viewer in (self.viewer_left, self.viewer_middle):
            viewer.SetColorWindow(max(high - low, 1.0))
            viewer.SetColorLevel((high + low) / 2.0)

        self.interactor = vtk.vtkRenderWindowInteractor()
        self.interactor.SetRenderWindow(self.render_window)
        self.interactor.AddObserver("KeyPressEvent", self.on_key_press)

        self.render_window.Render()

    def _add_renderer(self, x_min, x_max):
        """Create a renderer spanning [x_min, x_max] horizontally and add it to the window."""
        renderer = vtk.vtkRenderer()
        renderer.SetViewport(x_min, 0.0, x_max, 1.0)
        renderer.SetBackground(0.1, 0.1, 0.1)
        self.render_window.AddRenderer(renderer)
        return renderer

    def get_orientation(self):
        """Return the vtkImageViewer2 slice orientation for ``self.axis``.

        The values are 0 = sagittal (YZ), 1 = coronal (XZ) and 2 = axial (XY).
        """
        return self._ORIENTATIONS[self.axis]

    def extract_slice_numpy(self, vtk_image, slice_index):
        """Return slice ``slice_index`` (an extent coordinate) along ``self.axis``.

        The result is a 2-D numpy array of shape (ny, nx) for 'z',
        (nz, nx) for 'y' and (nz, ny) for 'x'.
        """
        extent = vtk_image.GetExtent()
        dims = (extent[1] - extent[0] + 1, extent[3] - extent[2] + 1, extent[5] - extent[4] + 1)
        scalars = vtk_image.GetPointData().GetScalars()
        # VTK stores points with x varying fastest, hence the (z, y, x) shape.
        volume = numpy_support.vtk_to_numpy(scalars).reshape(dims[::-1])
        # Extents need not start at 0, whereas numpy indices do.
        index = slice_index - extent[2 * self.get_orientation()]
        if self.axis == 'z':
            return volume[index, :, :]
        if self.axis == 'y':
            return volume[:, index, :]
        return volume[:, :, index]

    @staticmethod
    def _difference_rgb(slice1, slice2, threshold):
        """Return an (H, W, 3) uint8 image: red where |slice1 - slice2| > threshold, black elsewhere.

        Values are cast to float32 before subtracting, so unsigned inputs
        cannot wrap around. If the shapes differ, only the common top-left
        region (in index space) is compared.
        """
        height = min(slice1.shape[0], slice2.shape[0])
        width = min(slice1.shape[1], slice2.shape[1])
        a = slice1[:height, :width].astype(np.float32)
        b = slice2[:height, :width].astype(np.float32)
        mask = np.abs(a - b) > threshold
        rgb = np.zeros((height, width, 3), dtype=np.uint8)
        rgb[mask, 0] = 255
        return rgb

    def update_difference_image(self):
        """Recompute the red/black difference map for ``self.slice`` into ``self.diff_vtk_image``.

        The map is stored as a (width, height, 1) RGB image, i.e. in the XY plane.
        """
        slice1 = self.extract_slice_numpy(self.vtk_image1, self.slice)
        slice2 = self.extract_slice_numpy(self.vtk_image2, self.slice)
        rgb_image = self._difference_rgb(slice1, slice2, self.threshold)
        height, width = rgb_image.shape[:2]

        vtk_rgb = vtk.vtkImageData()
        vtk_rgb.SetDimensions(width, height, 1)
        vtk_array = numpy_support.numpy_to_vtk(
            rgb_image.reshape(-1, 3), deep=True, array_type=vtk.VTK_UNSIGNED_CHAR
        )
        vtk_rgb.GetPointData().SetScalars(vtk_array)
        # Update in place: the right viewer keeps pointing at the same object.
        self.diff_vtk_image.ShallowCopy(vtk_rgb)
        self.diff_vtk_image.Modified()

    def on_key_press(self, obj, event):
        """Handle Up/Down keys by moving one slice within [min_slice, max_slice]."""
        key = obj.GetKeySym()
        if key == "Up" and self.slice < self.max_slice:
            self.slice += 1
        elif key == "Down" and self.slice > self.min_slice:
            self.slice -= 1
        else:
            return
        self.viewer_left.SetSlice(self.slice)
        self.viewer_middle.SetSlice(self.slice)
        self.update_difference_image()
        self.render_window.Render()

    def start(self):
        """Initialize the interactor and enter its event loop."""
        self.interactor.Initialize()
        self.interactor.Start()


In [93]:
def _show_viewer(axis, title):
    """Build a difference viewer for image1_vtk/image2_vtk along *axis*, start it, and return it.

    start() enters the VTK interactor event loop, so the three views below
    open one after another (axial, then coronal, then sagittal).
    """
    viewer = SliceViewerWithDifference(image1_vtk, image2_vtk, axis=axis, title=title)
    viewer.start()
    return viewer


viewer_z = _show_viewer('z', "Coupe axiale (Z)")      # axial
viewer_y = _show_viewer('y', "Coupe coronale (Y)")    # coronal
viewer_x = _show_viewer('x', "Coupe sagittale (X)")   # sagittal


# Recalage des images

In [67]:
def rigid_registration(fixed_path, moving_path, output_path=None):
    """Rigidly register the moving image onto the fixed image with ITK.

    The method is as follows:
    - the moving image is mapped with a versor-based rigid 3-D transform
      (3 rotation + 3 translation parameters);
    - the transform is initialised by aligning the centres of mass;
    - it is optimised on Mattes mutual information with a 3-level
      multi-resolution pyramid;
    - the moving image is finally resampled onto the fixed image's grid with
      linear interpolation.

    Parameters
    ----------
    fixed_path, moving_path : str
        Paths of the 3-D images; both are read as float32.
    output_path : str, optional
        If given, the resampled moving image is written to this path.

    Returns
    -------
    tuple
        ``(final_transform, resampled_image)``. These are the optimised
        transform used by the resampler and the moving image resampled onto
        the fixed grid.
    """
    ImageType = itk.Image[itk.F, 3]
    fixed = itk.imread(fixed_path, itk.F)
    moving = itk.imread(moving_path, itk.F)

    # A genuinely rigid transform. A BSplineTransform is deformable, and
    # without a defined control-point grid it is not a usable initial transform.
    TransformType = itk.VersorRigid3DTransform[itk.D]
    initial_transform = TransformType.New()

    # Start from aligned centres of mass, so the optimiser only has to
    # correct the residual misalignment.
    initializer = itk.CenteredTransformInitializer[TransformType, ImageType, ImageType].New()
    initializer.SetTransform(initial_transform)
    initializer.SetFixedImage(fixed)
    initializer.SetMovingImage(moving)
    initializer.MomentsOn()
    initializer.InitializeTransform()

    metric = itk.MattesMutualInformationImageToImageMetricv4[ImageType, ImageType].New()
    metric.SetNumberOfHistogramBins(50)

    optimizer = itk.RegularStepGradientDescentOptimizerv4[itk.D].New()
    # Step settings suited to versor parameters, whose components are bounded by 1.
    optimizer.SetLearningRate(0.2)
    optimizer.SetMinimumStepLength(0.001)
    optimizer.SetRelaxationFactor(0.5)
    optimizer.SetNumberOfIterations(200)
    optimizer.SetReturnBestParametersAndValue(True)

    # Parameters are [versor x, y, z, translation x, y, z]. The optimiser
    # divides each gradient component by its scale, so a small translation
    # scale lets translations (physical units) take much larger steps than
    # the rotation components.
    n_params = initial_transform.GetNumberOfParameters()
    scales = itk.OptimizerParameters[itk.D](n_params)
    for i in range(n_params):
        scales.SetElement(i, 1.0 if i < 3 else 1.0 / 1000.0)
    optimizer.SetScales(scales)

    registration = itk.ImageRegistrationMethodv4[ImageType, ImageType].New()
    registration.SetFixedImage(fixed)
    registration.SetMovingImage(moving)
    registration.SetInitialTransform(initial_transform)
    registration.SetMetric(metric)
    registration.SetOptimizer(optimizer)
    # Must agree with the lengths of the two per-level lists below.
    registration.SetNumberOfLevels(3)
    registration.SetShrinkFactorsPerLevel([4, 2, 1])
    registration.SetSmoothingSigmasPerLevel([2, 1, 0])

    registration.Update()
    final_transform = registration.GetTransform()

    resampler = itk.ResampleImageFilter[ImageType, ImageType].New()
    resampler.SetInput(moving)
    resampler.SetTransform(final_transform)
    # Output grid (origin, spacing, size, direction) taken from the fixed image.
    resampler.SetUseReferenceImage(True)
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(itk.LinearInterpolateImageFunction[ImageType, itk.D].New())
    resampler.Update()

    if output_path:
        itk.imwrite(resampler.GetOutput(), output_path)
    return final_transform, resampler.GetOutput()

In [68]:
# Destination of the second image after rigid registration onto the first.
IMAGE_RECALEE_PATH = "./Data/case6_gre2_rigid.nrrd"

# rigid_registration writes the resampled image to output_path itself,
# so no separate write is needed here.
transform, moved = rigid_registration(
    IMAGE1_PATH,
    IMAGE2_PATH,
    output_path=IMAGE_RECALEE_PATH,
)

# Segmentation de la tumeur

# Analyse et visualisation des changements