In [10]:
import os
os.environ["JAXTYPING_DISABLE"] = "1" 

import torch
import skshapes as sks
import pyvista as pv
import numpy as np

def _ellipsoid(semi_axes, center):
    """Build a default ``sks.Sphere()`` stretched per axis, then translated.

    The sphere's points are first multiplied component-wise by ``semi_axes``
    and then shifted by ``center``. These are two separate assignments to
    ``.points``, done in that order.
    """
    shape = sks.Sphere()
    shape.points = shape.points * torch.tensor(semi_axes)
    shape.points = shape.points + torch.tensor(center)
    return shape


# Stand-in for the pre-operative liver: stretched by (10, 8, 8) and
# translated by (0, 0, 0), i.e. left where the default sphere sits.
liver1 = _ellipsoid([10.0, 8.0, 8.0], [0.0, 0.0, 0.0])

# Stand-in for the post-operative liver: smaller along every axis and
# shifted by 0.5 along x.
liver2 = _ellipsoid([8.0, 6.0, 7.0], [0.5, 0.0, 0.0])

# Register the pre-op shape (source) onto the post-op shape (target).
source, target = liver1, liver2

# Control points that drive the deformation: a regular grid built around the
# source by skshapes' `bounding_grid` (N=10, offset=0.5; see the skshapes docs
# for the exact meaning of these two parameters).
source.control_points = source.bounding_grid(N=10, offset=0.5)

# Hyper-parameters of the extrinsic (ambient-space) deformation model.
deformation_settings = {
    "n_steps": 10,             # number of steps; more steps = smoother deformation
    "kernel": "gaussian",      # kernel type
    "scale": 5.0,              # stiffness: larger = more global deformation
    "control_points": True,    # use the control points attached to the source above
}
model = sks.ExtrinsicDeformation(**deformation_settings)

# Data-attachment term comparing the deformed source with the target.
loss = sks.VarifoldLoss(radial_bandwidth=2.0)

# The verbose log prints the objective as
#   fidelity + regularization_weight * regularization.
# NOTE(review): in the captured run the loss stops changing after the first
# iteration, so n_iter=15 may be more than needed; worth re-checking.
registration = sks.Registration(
    model=model,
    loss=loss,
    optimizer=sks.LBFGS(),
    regularization_weight=0.1,  # deformation penalty (0.1 = balanced, per original note)
    n_iter=15,                  # number of optimizer iterations
    verbose=True,               # print progress at each iteration
)

# Fit the deformation and return the deformed copy of the source.
morphed = registration.fit_transform(source=source, target=target)

# --- Visualisation with PyVista ---------------------------------------------
# Convert the three skshapes meshes once; pv_target is drawn in both panels.
pv_target = target.to_pyvista()    # target: blue, translucent
pv_source = source.to_pyvista()    # source before registration: silver wireframe
pv_morphed = morphed.to_pyvista()  # registered source: red with visible edges

# One entry per panel:
# (column, title, legend label of the target, overlay mesh, overlay style).
panels = (
    (
        0,
        "AVANT Recalage",
        "Cible (Post-op)",
        pv_source,
        {"color": "silver", "style": "wireframe", "line_width": 2, "label": "Source (Pré-op)"},
    ),
    (
        1,
        "APRÈS Recalage",
        "Cible",
        pv_morphed,
        {"color": "red", "opacity": 0.8, "show_edges": True, "label": "Recalé"},
    ),
)

# One row with two side-by-side viewports: left = before, right = after.
plotter = pv.Plotter(shape=(1, 2))

for column, title, target_label, overlay, overlay_style in panels:
    plotter.subplot(0, column)
    plotter.add_text(title, font_size=10)
    plotter.add_mesh(pv_target, color="blue", opacity=0.3, label=target_label)
    plotter.add_mesh(overlay, **overlay_style)
    plotter.add_legend()

# Link the viewports' cameras so zooming/rotating one view moves the other.
plotter.link_views()

plotter.show()

Initial loss : 7.62e+02
  = 7.62e+02 + 0.1 * 0.00e+00 (fidelity + regularization_weight * regularization)
Loss after 1 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 2 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 3 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 4 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 5 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 6 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 7 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.51e+00 (fidelity + regularization_weight * regularization)
Loss after 8 iteration(s) : 4.60e-01
  = 8.79e-03 + 0.1 * 4.5

Widget(value='<iframe src="http://localhost:49570/index.html?ui=P_0x1759034a0_9&reconnect=auto" class="pyvista…