In [28]:
"""LDDMM functions.

Set of functions that wrap around [deformetrica](
https://gitlab.com/icm-institute/aramislab/deformetrica) to perform registration, parallel
transport, geodesic and spline regression with the Large Deformations Diffeomorphic Metric
Mapping (LDDMM) framework.

For a brief introduction to LDDMM see [NG22](https://hal.science/tel-03563980v1) chapter 5.
"""

import time
from os.path import join
import sys
import pandas as pd
import nibabel.freesurfer
from pathlib import Path

import numpy as np
import pyvista as pv
import torch
from api.deformetrica import Deformetrica
from in_out.array_readers_and_writers import read_2D_array, read_3D_array
from launch.compute_parallel_transport import compute_pole_ladder
from launch.compute_shooting import compute_shooting
from support.kernels.torch_kernel import TorchKernel

In [4]:
cp_str = "DeterministicAtlas__EstimatedParameters__ControlPoints.txt"
registration_str = "DeterministicAtlas__flow__shape__subject_ventricle__tp_{}.vtk"
template_str = "DeterministicAtlas__EstimatedParameters__Template_shape.vtk"
momenta_str = "DeterministicAtlas__EstimatedParameters__Momenta.txt"
residual_str = "DeterministicAtlas__EstimatedParameters__Residuals.txt"
residual_str_spline = "SplineRegression__EstimatedParameters__Residuals.txt"
mom_str_spline = "SplineRegression__EstimatedParameters__Momenta.txt"
cp_str_spline = "SplineRegression__EstimatedParameters__ControlPoints.txt"
ext_forces_str = "SplineRegression__EstimatedParameters__ExternalForces.txt"
regression_str = "SplineRegression__Reconstruction__shape__tp_{}__age_{:.2f}.vtk"
shoot_str = "Shooting__GeodesicFlow__shape__tp_{}__age_1.00.vtk"

In [5]:
def registration(
    source,
    target,
    output_dir,
    kernel_width=20.0,
    regularisation=1.0,
    number_of_time_steps=10,
    metric="landmark",
    kernel_type="torch",
    kernel_device="cuda",
    tol=1e-5,
    use_svf=False,
    initial_control_points=None,
    max_iter=200,
    freeze_control_points=False,
    use_rk2_for_shoot=False,
    use_rk2_for_flow=False,
    dimension=3,
    use_rk4_for_shoot=False,
    preserve_volume=False,
    print_every=20,
    filter_cp=False,
    threshold=1.0,
    attachment_kernel_width=4.0,
):
    r"""Registration.

    Estimates the best possible deformation between two shapes, i.e. minimizes the following
    criterion:

    ..math::

         C(c, \mu) = \frac{1}{\alpha^2} d(q, \phi_1^{c,\mu}(\bar{q}))^2 + \| v_0^{c,
         \mu} \|_K^2.

    where $c, \mu$ are the control points and momenta that parametrize the deformation, $v_0^{c,
    \mu}$ is the associated velocity field defined by the convolution $v_t(x) = \sum_{k=1}^{N_c}
    K(x, c^{(t)}_k) \mu^{(t)}_K$, K is the Gaussian kernel, $\phi_1^{c,\mu}$ is the flow of $v_t$
    at time 1, $\bar{q}$ is the source shape being deformed, $q$ is the target shape,
    and $\alpha$ is a regularization term that controls the tradeoff between exact matching and
    smoothness of the deformation. $d$ is a distance function on shapes (point-to-point L2,
    varifold, metric, etc).

    Control points can be passed as parameters or are initialized on a grid that contains the
    source shapes. They are optimized if `freeze_control_points` is set to false.

    Resulting control points and momenta are saved in the ouput dir as txt files. Control points
    are also saved with attached momenta as a vtk file to allow visualization with paraview.

    Parameters
    ----------
    source: str or pathlib.Path
        Path to the vtk file that contains the source mesh.
    target: str or pathlib.Path
        Path to the vtk file that contains the target mesh.
    output_dir: str or pathlib.Path
        Path a directory where results will be saved.
    kernel_width: float
        Width of the Gaussian kernel. Controls the spatial smoothness of the deformation and
        influences the number of parameters required to represent the deformation.
        Optional, default: 20.
    regularisation: float
        $\alpha$ in the above equation. Smaller values will yeild larger deformations to reduce
        the data attachment term, while larger values will allow attachment errors for a smoother
        deformation.
        Optional, default: 1.
    number_of_time_steps: int
        Number used in the discretization of the flow equation.
        Optional, default: 10.
    metric: str, {landmark, varifold, current}
        Metric to use to measure attachment between meshes. Landmark refers to L2.
    attachment_kernel_width: float,
        If using varifold or currents, width of the kernel used in the attachment metric. Defines
        the scale at which differences must be taken into account.
    dimension: int {2, 3}
        Dimension of the shape embedding space.
    kernel_type: str, {torch, keops}
        Package to use for convolutions of velocity fields and loss functions.
    kernel_device: str, {cuda, cpu}
    use_svf: bool
        Whether to use stationary velocity fields instead of time evolving velocity. The
        deformation is no longer a geodesic but there is more symmetry wrt source / target.
        Optional, default: False
    initial_control_points: str or pathlib.Path
        Path to the txt file that contains the initial control points.
        Optional
    freeze_control_points: bool
        Whether to optimize control points jointly with momenta.
        Optional, default: False
    preserve_volume: bool
        Whether to use volume preserving deformation. This modifies the metric on deformations.
        Optional, default: False
    use_rk2_for_flow: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the flow equation, i.e. when
        warping the shape. If False, a Euler step is used.
        Optional, default: False
    use_rk2_for_shoot: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the Hamiltonian equation that
        governs the time evolution of control points and momenta. If False, a Euler step is used.
        Optional, default: False
    use_rk4_for_shoot: bool
        Whether to use Runge-Kutta order 4 steps in the integration of the Hamiltonian equation that
        governs the time evolution of control points and momenta. Overrides use_rk2_for_shoot.
        RK4 steps are required when estimating a geodesic that will be used for parallel transport.
        Optional, default: False
    print_every: int
        Sets the verbosity level of the optimization scheme.
    filter_cp: bool
        Whether to filter control points saved in the vtk file to exclude those whose momenum
        vector is not significative and does not contribute to the deformation.
        Optional, default: False
    threshold: float
        Threshold to use on momenta norm when filtering. Ignored if `filter_cp` is set to `False`.
    max_iter: int
        Maximum number of iteration in the optimization scheme.
        Optional, default: 200.
    tol: float
        Tolerance to evaluate convergence.
    """
    optimization_parameters = {
        "max_iterations": max_iter,
        "freeze_template": False,
        "freeze_control_points": freeze_control_points,
        "freeze_momenta": False,
        "use_sobolev_gradient": True,
        "sobolev_kernel_width_ratio": 1,
        "max_line_search_iterations": 50,
        "initial_control_points": initial_control_points,
        "initial_cp_spacing": None,
        "initial_momenta": None,
        "dense_mode": False,  # dense is for image vs mesh data
        "number_of_threads": 1,
        "print_every_n_iters": print_every,
        "downsampling_factor": 1,
        "dimension": dimension,
        "optimization_method_type": "ScipyLBFGS",
        "convergence_tolerance": tol,
    }

    # register source on target
    deformetrica = Deformetrica(output_dir, verbosity="DEBUG")

    model_options = {
        "deformation_kernel_type": kernel_type,
        "deformation_kernel_width": kernel_width,
        "deformation_kernel_device": kernel_device,
        "use_svf": use_svf,
        "preserve_volume": preserve_volume,
        "number_of_time_points": number_of_time_steps + 1,
        "use_rk2_for_shoot": use_rk2_for_shoot,
        "use_rk4_for_shoot": use_rk4_for_shoot,
        "use_rk2_for_flow": use_rk2_for_flow,
        "freeze_template": False,
        "freeze_control_points": freeze_control_points,
        "initial_control_points": initial_control_points,
        "dimension": dimension,
        "output_dir": output_dir,
    }

    template = {
        "shape": {
            "deformable_object_type": "SurfaceMesh",
            "kernel_type": kernel_type,
            "kernel_width": attachment_kernel_width,
            "kernel_device": kernel_device,
            "noise_std": regularisation,
            "filename": source,
            "noise_variance_prior_scale_std": None,
            "noise_variance_prior_normalized_dof": 0.01,
            "attachment_type": metric,
        }
    }

    data_set = {
        "visit_ages": [[]],
        "dataset_filenames": [[{"shape": target}]],
        "subject_ids": ["ventricle"],
    }

    deformetrica.estimate_registration(
        template_specifications=template,
        dataset_specifications=data_set,
        model_options=model_options,
        estimator_options=optimization_parameters,
    )

    path_cp = join(output_dir, lddmm_strings.cp_str)
    cp = read_2D_array(path_cp)

    path_momenta = join(output_dir, lddmm_strings.momenta_str)
    momenta = read_3D_array(path_momenta)
    poly_cp = momenta_to_vtk(cp, momenta, kernel_width, filter_cp, threshold)
    poly_cp.save(join(output_dir, "initial_control_points.vtk"))
    pv.read(target).save(join(output_dir, "target_shape.vtk"))
    return time.gmtime()


def spline_regression(
    source,
    targets,
    output_dir,
    times,
    subject_id=None,
    t0=0,
    max_iter=200,
    kernel_width=15.0,
    regularisation=1.0,
    number_of_time_steps=10,
    initial_step_size=1e-4,
    kernel_type="torch",
    kernel_device="cuda",
    initial_control_points=None,
    tol=1e-5,
    freeze_control_points=False,
    use_rk2_for_flow=False,
    use_rk2_for_shoot=False,
    dimension=3,
    freeze_external_forces=False,
    target_weights=None,
    geodesic_weight=0.1,
    metric="landmark",
    filter_cp=False,
    threshold=1.0,
    attachment_kernel_width=15.0,
    print_every=20,
):
    r"""Geodesic or Spline Regression.

    Estimates the best possible time-constrained deformation to fit a set of observations indexed
    by a covariable.

    The following criterion is minimized:

    ..math::

        C_S(c, \mu, u_t) &=  \frac{1}{\alpha^2d} \sum_{i=1}^d d( x_{t_i}, \phi_{t_i}(x_{t_0}))^2  +
        \int_0^1 \|u^{(t)}\|^2 dt + \|v_0^{c,\mu}\|_K^2,

    where $x_{t_i}$ are the observations observed at variable $t_i$, $c,\mu, u$ parametrize the
    deformation. $c,\mu$ define a velocity field by the convolution $v_t(x) = \sum_{k=1}^{N_c}
    K(x, c^{(t)}_k) \mu^{(t)}_K$ where K is the Gaussian kernel. $u^t$ is a second-order term
    that can be interpreted as random external forces smoothly perturbing the trajectory around a
    mean geodesic. If `freeze_external_forces` is set to True, they are fixed to 0 and in this
    case the regression model estimates a geodesic.

    Parameters
    ----------
    source: str or pathlib.Path
        Path to the vtk file that contains the source mesh.
    targets: list of dict
        Path to the vtk files that contain the target meshes. Must be formatted as a list of
        dictionaries, where each dict represents a time points and has a key 'shape' with the
        path to the shape as value.
    times: list of floats in [0, 1]
        Covariable used in the regression.
    subject_id: list of str
        Not used.
    t0: float
        Time of the first shape.
    initial_step_size: float
        Initial learning rate.
        Optional, default: 1e-4.
    freeze_external_forces: bool
        Whether to use external forces in the regression model. When used, splines are used
        instead of geodesics.
    output_dir: str or pathlib.Path
        Path to a directory where results will be saved. It will be created if it does not
        already exist.
    kernel_width: float
        Width of the Gaussian kernel. Controls the spatial smoothness of the deformation and
        influences the number of parameters required to represent the deformation.
        Optional, default: 20.
    regularisation: float
        $\alpha$ in the above equation. Smaller values will yeild larger deformations to reduce
        the data attachment term, while larger values will allow attachment errors for a smoother
        deformation.
        Optional, default: 1.
    number_of_time_steps: int
        Number used in the discretization of the flow equation.
        Optional, default: 10.
    metric: str, {landmark, varifold, current}
        Metric to use to measure attachment between meshes. Landmark refers to L2.
    attachment_kernel_width: float,
        If using varifold or currents, width of the kernel used in the attachment metric. Defines
        the scale at which differences must be taken into account.
    dimension: int {2, 3}
        Dimension of the shape embedding space.
    kernel_type: str, {torch, keops}
        Package to use for convolutions of velocity fields and loss functions.
    kernel_device: str, {cuda, cpu}
    initial_control_points: str or pathlib.Path
        Path to the txt file that contains the initial control points.
        Optional
    freeze_control_points: bool
        Whether to optimize control points jointly with momenta.
        Optional, default: False
    use_rk2_for_flow: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the flow equation, i.e. when
        warping the shape. If False, a Euler step is used.
        Optional, default: False
    use_rk2_for_shoot: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the Hamiltonian equation that
        governs the time evolution of control points and momenta. If False, a Euler step is used.
        Optional, default: False
    print_every: int
        Sets the verbosity level of the optimization scheme.
    target_weights: list or array
        Coefficient to weight observations' contributions to the loss function.
    geodesic_weight: float
        Coefficient to weight the geodesic part compared to the external forces.
        Optional, default: 0.1.
    filter_cp: bool
        Whether to filter control points saved in the vtk file to exclude those whose momenum
        vector is not significative and does not contribute to the deformation.
        Optional, default: False
    threshold: float
        Threshold to use on momenta norm when filtering. Ignored if `filter_cp` is set to `False`.
    max_iter: int
        Maximum number of iteration in the optimization scheme.
        Optional, default: 200.
    tol: float
        Tolerance to evaluate convergence.
    """
    if subject_id is None:
        subject_id = ["patient"]
    template = {
        "shape": {
            "deformable_object_type": "SurfaceMesh",
            "kernel_type": kernel_type,
            "kernel_width": attachment_kernel_width,
            "kernel_device": kernel_device,
            "noise_std": regularisation,
            "filename": source,
            "noise_variance_prior_scale_std": None,
            "noise_variance_prior_normalized_dof": 0.01,
            "attachment_type": metric,
        }
    }

    data_set = {
        "visit_ages": [times],
        "dataset_filenames": [targets],
        "subject_ids": subject_id,
    }

    model = {
        "deformation_kernel_type": kernel_type,
        "deformation_kernel_width": kernel_width,
        "deformation_kernel_device": kernel_device,
        "number_of_time_points": number_of_time_steps + 1,
        "concentration_of_time_points": number_of_time_steps,
        "use_rk2_for_flow": use_rk2_for_flow,
        "use_rk2_for_shoot": use_rk2_for_shoot,
        "freeze_template": True,
        "freeze_control_points": freeze_control_points,
        "freeze_external_forces": freeze_external_forces,
        "freeze_momenta": False,
        "freeze_noise_variance": False,
        "use_sobolev_gradient": True,
        "sobolev_kernel_width_ratio": 1,
        "initial_control_points": initial_control_points,
        "initial_cp_spacing": None,
        "initial_momenta": None,
        "dense_mode": False,
        "number_of_processes": 1,
        "dimension": dimension,
        "random_seed": None,
        "t0": t0,
        "tmin": min(times),
        "tmax": max(times),
        "target_weights": target_weights,
        "geodesic_weight": geodesic_weight,
    }

    optimization_parameters = {
        "initial_step_size": initial_step_size,
        "scale_initial_step_size": True,
        "line_search_shrink": 0.5,
        "line_search_expand": 1.5,
        "max_line_search_iterations": 30,
        "optimized_log_likelihood": "complete",
        "optimization_method_type": "ScipyLBFGS",
        "max_iterations": max_iter,
        "convergence_tolerance": tol,
        "print_every_n_iters": print_every,
        "save_every_n_iters": 100,
        "state_file": None,
        "load_state_file": False,
    }

    if subject_id != "patient":
        patient_output_dir = join(output_dir, subject_id[0])
    else:
        patient_output_dir = output_dir

    deformetrica = Deformetrica(patient_output_dir, verbosity="DEBUG")
    deformetrica.estimate_spline_regression(
        template_specifications=template,
        dataset_specifications=data_set,
        model_options=model,
        estimator_options=optimization_parameters,
    )

    # agregate results in vtk file for paraview
    path_cp = join(output_dir, lddmm_strings.cp_str_spline)
    cp = read_2D_array(path_cp)
    path_momenta = join(output_dir, lddmm_strings.mom_str_spline)
    momenta = read_3D_array(path_momenta)
    poly_cp = momenta_to_vtk(cp, momenta, kernel_width, filter_cp, threshold)
    poly_cp.save(join(output_dir, "initial_control_points.vtk"))

    if not freeze_external_forces:
        forces = read_3D_array(join(output_dir, lddmm_strings.ext_forces_str))
        external_forces_to_vtk(cp, forces, output_dir, filter_cp, threshold)

    return time.gmtime()


def transport(
    control_points,
    momenta,
    control_points_to_transport,
    momenta_to_transport,
    output_dir,
    kernel_type="torch",
    kernel_width=15,
    kernel_device="cuda",
    n_rungs=10,
):
    """Compute parallel transport with the pole ladder.

    Transports a tangent vector along a geodesic (called main geodesic). Both must have been
    estimated by using the `registration` function. The main geodesic must be estimated using RK4
    steps. Kernel parameters should match the ones used in the registration function.

    Parameters
    ----------
    control_points: str or pathlib.Path
        Path to the txt file that contains the initial control points for the main geodesic.
    momenta: str or pathlib.Path
        Path to the txt file that contains the initial momenta for the main geodesic.
    control_points_to_transport: str or pathlib.Path
        Path to the txt file that contains the initial control points of the deformation to
        transport.
    momenta_to_transport: str or pathlib.Path
        Path to the txt file that contains the initial momenta to transport.
    output_dir: str or pathlib.Path
        Path to a directory where results will be saved. It will be created if it does not
        already exist.
    kernel_width: float
        Width of the Gaussian kernel. Controls the spatial smoothness of the deformation and
        influences the number of parameters required to represent the deformation.
        Optional, default: 20.
    kernel_type: str, {torch, keops}
        Package to use for convolutions of velocity fields and loss functions.
    kernel_device: str, {cuda, cpu}
    n_rungs: int
        Number of discretization steps in the pole ladder algorithm. Should match
        number_of_time_points in the registration of the main geodesic.
        Optional, default: 10.
    """
    Deformetrica(output_dir, verbosity="INFO")

    deformation_parameters = {
        "deformation_kernel_type": kernel_type,
        "deformation_kernel_width": kernel_width,
        "deformation_kernel_device": kernel_device,
        "concentration_of_time_points": n_rungs,
        "number_of_time_points": n_rungs + 1,
        "tmin": 0,
        "tmax": 1,
        "output_dir": output_dir,
    }

    transported_cp, transported_mom = compute_pole_ladder(
        initial_control_points=control_points,
        initial_momenta=momenta,
        initial_momenta_to_transport=momenta_to_transport,
        initial_control_points_to_transport=control_points_to_transport,
        **deformation_parameters,
    )
    return transported_cp, transported_mom


def shoot(
    source,
    control_points,
    momenta,
    output_dir,
    kernel_width=20.0,
    regularisation=1.0,
    number_of_time_steps=10,
    kernel_type="torch",
    kernel_device="cuda",
    write_params=True,
    deformation="geodesic",
    external_forces=None,
    use_rk2_for_flow=False,
    use_rk2_for_shoot=False,
):
    """Exponential map.

    Compute the deformation of a source shape by the flow parametrized by control points and
    momenta.

    Parameters
    ----------
    source: str or pathlib.Path
        Path to the vtk file that contains the source mesh.
    control_points: str or pathlib.Path
        Path to the txt file that contains the initial control points.
    momenta: str or pathlib.Path
        Path to the txt file that contains the initial momenta.
    kernel_width: float
        Width of the Gaussian kernel. Controls the spatial smoothness of the deformation and
        influences the number of parameters required to represent the deformation.
        Optional, default: 20.
    kernel_type: str, {torch, keops}
        Package to use for convolutions of velocity fields and loss functions.
    kernel_device: str, {cuda, cpu}
    regularisation: unused.
    write_params: bool
    external_forces: str or pathlib.Path
        Path to the vtk file that contains the external forces to compute a spline deformation.
    use_rk2_for_flow: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the flow equation, i.e. when
        warping the shape. If False, a Euler step is used.
        Optional, default: False
    use_rk2_for_shoot: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the Hamiltonian equation that
        governs the time evolution of control points and momenta. If False, a Euler step is used.
        Optional, default: False
    """
    deformation_parameters = {
        "deformation_model": deformation,
        "deformation_kernel_type": kernel_type,
        "deformation_kernel_width": kernel_width,
        "deformation_kernel_device": kernel_device,
        "concentration_of_time_points": number_of_time_steps,
        "number_of_time_points": number_of_time_steps + 1,
        "use_rk2_for_flow": use_rk2_for_flow,
        "use_rk2_for_shoot": use_rk2_for_shoot,
        "output_dir": output_dir,
        "write_adjoint_parameters": write_params,
    }

    template_specifications = {
        "shape": {
            "deformable_object_type": "landmark",
            "kernel_type": kernel_type,
            "kernel_width": kernel_width,
            "kernel_device": kernel_device,
            "noise_std": regularisation,
            "filename": source,
            "noise_variance_prior_scale_std": None,
            "noise_variance_prior_normalized_dof": 0.01,
        }
    }

    Deformetrica(output_dir, verbosity="INFO")
    compute_shooting(
        template_specifications,
        initial_control_points=control_points,
        external_forces=external_forces,
        initial_momenta=momenta,
        **deformation_parameters,
    )

    return time.gmtime()


def deterministic_atlas(
    source,
    targets,
    subject_id,
    output_dir,
    t0=0,
    max_iter=200,
    kernel_width=15.0,
    regularisation=1.0,
    number_of_time_steps=11,
    metric="landmark",
    kernel_type="torch",
    kernel_device="auto",
    initial_control_points=None,
    tol=1e-5,
    freeze_control_points=False,
    use_rk2_for_flow=False,
    use_rk2_for_shoot=False,
    preserve_volume=False,
    use_svf=False,
    dimension=3,
    print_every=20,
    attachment_kernel_width=4.0,
    initial_step_size=1e-4,
    **kwargs,
):
    r"""Atlas computation.

    Estimates an average shape from a collection of shapes, and the
    deformations from this average to each sample in the collection. This is similar to
    computing a Frechet mean, i.e. to minimize the following:

    ..math::

         C(c, \mu) = \frac{1}{\alpha^2} \sum_i \left( d(q_i, \phi_1^{c_i,\mu_i}(\bar{q}))^2 + \|
         v_0^{c_i,\mu_i} \|_K^2 \right).

    where $c_i, \mu_i$ are the control points and momenta that parametrize the deformation, $v_0^{c,
    \mu}$ is the associated velocity field defined by the convolution $v_t(x) = \sum_{k=1}^{N_c}
    K(x, c^{(t)}_k) \mu^{(t)}_K$, K is the Gaussian kernel, $\phi_1^{c,\mu}$ is the flow of $v_t$
    at time 1, $\bar{q}$ is the source shape being deformed, $q$ is the target shape,
    and $\alpha$ is a regularization term that controls the tradeoff between exact matching and
    smoothness of the deformation. $d$ is a distance function on shapes (point-to-point L2,
    varifold, metric, etc).

    Resulting control points and momenta are saved in the ouput dir as txt files.

    Parameters
    ----------
    source: str or pathlib.Path
        Path to the vtk file that contains the source mesh.
    targets: list of dict
        Path to the vtk files that contain the target meshes. Must be formatted as a list of
        dictionaries, where each dict represents a time points and has a key 'shape' with the
        path to the shape as value.
    output_dir: str or pathlib.Path
        Path a directory where results will be saved.
    kernel_width: float
        Width of the Gaussian kernel. Controls the spatial smoothness of the deformation and
        influences the number of parameters required to represent the deformation.
        Optional, default: 20.
    regularisation: float
        $\alpha$ in the above equation. Smaller values will yeild larger deformations to reduce
        the data attachment term, while larger values will allow attachment errors for a smoother
        deformation.
        Optional, default: 1.
    number_of_time_steps: int
        Number used in the discretization of the flow equation.
        Optional, default: 10.
    metric: str, {landmark, varifold, current}
        Metric to use to measure attachment between meshes. Landmark refers to L2.
    attachment_kernel_width: float,
        If using varifold or currents, width of the kernel used in the attachment metric. Defines
        the scale at which differences must be taken into account.
    dimension: int {2, 3}
        Dimension of the shape embedding space.
    kernel_type: str, {torch, keops}
        Package to use for convolutions of velocity fields and loss functions.
    kernel_device: str, {cuda, cpu}
    use_svf: bool
        Whether to use stationary velocity fields instead of time evolving velocity. The
        deformation is no longer a geodesic but there is more symmetry wrt source / target.
        Optional, default: False
    initial_control_points: str or pathlib.Path
        Path to the txt file that contains the initial control points.
        Optional
    freeze_control_points: bool
        Whether to optimize control points jointly with momenta.
        Optional, default: False
    preserve_volume: bool
        Whether to use volume preserving deformation. This modifies the metric on deformations.
        Optional, default: False
    use_rk2_for_flow: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the flow equation, i.e. when
        warping the shape. If False, a Euler step is used.
        Optional, default: False
    use_rk2_for_shoot: bool
        Whether to use Runge-Kutta order 2 steps in the integration of the Hamiltonian equation that
        governs the time evolution of control points and momenta. If False, a Euler step is used.
        Optional, default: False
    print_every: int
        Sets the verbosity level of the optimization scheme.
    max_iter: int
        Maximum number of iteration in the optimization scheme.
        Optional, default: 200.
    tol: float
        Tolerance to evaluate convergence.
    initial_step_size: float
        Initial learning rate.
        Optional, default: 1e-4.
    """
    template = {
        "shape": {
            "deformable_object_type": "SurfaceMesh",
            "kernel_type": kernel_type,
            "kernel_width": attachment_kernel_width,
            "kernel_device": kernel_device,
            "noise_std": regularisation,
            "filename": source,
            "noise_variance_prior_scale_std": None,
            "noise_variance_prior_normalized_dof": 0.01,
            "attachment_type": metric,
        }
    }

    data_set = {
        "dataset_filenames": [[k] for k in targets],
        "visit_ages": None,  # [[1.]] * len(targets),
        "subject_ids": [subject_id] * len(targets),
    }

    model = {
        "deformation_kernel_type": kernel_type,
        "deformation_kernel_width": kernel_width,
        "deformation_kernel_device": kernel_device,
        "number_of_time_points": number_of_time_steps,
        "concentration_of_time_points": number_of_time_steps - 1,
        "use_rk2_for_flow": use_rk2_for_flow,
        "use_rk2_for_shoot": use_rk2_for_shoot,
        "use_svf": use_svf,
        "preserve_volume": preserve_volume,
        "freeze_template": False,
        "freeze_control_points": freeze_control_points,
        "freeze_momenta": False,
        "freeze_noise_variance": False,
        "use_sobolev_gradient": True,
        "sobolev_kernel_width_ratio": 1,
        "initial_control_points": initial_control_points,
        "initial_cp_spacing": None,
        "initial_momenta": None,
        "dense_mode": False,
        "number_of_processes": 1,
        "dimension": dimension,
        "random_seed": None,
        "t0": t0,
        "tmin": t0,
        "tmax": 1.0,
    }

    optimization_parameters = {
        "max_iterations": max_iter,
        "freeze_template": False,
        "freeze_control_points": freeze_control_points,
        "freeze_momenta": False,
        "use_sobolev_gradient": True,
        "sobolev_kernel_width_ratio": 1,
        "max_line_search_iterations": 50,
        "initial_control_points": initial_control_points,
        "initial_cp_spacing": None,
        "initial_momenta": None,
        "dense_mode": False,
        "number_of_threads": 1,
        "print_every_n_iters": print_every,
        "downsampling_factor": 1,
        "dimension": dimension,
        "optimization_method_type": "ScipyLBFGS",
        "convergence_tolerance": tol,
        "initial_step_size": initial_step_size,
    }

    deformetrica = Deformetrica(output_dir, verbosity="DEBUG")
    deformetrica.estimate_deterministic_atlas(
        template_specifications=template,
        dataset_specifications=data_set,
        model_options=model,
        estimator_options=optimization_parameters,
    )
    return time.gmtime()


def momenta_to_vtk(cp, momenta, kernel_width=5.0, filter_cp=True, threshold=1.0):
    """Attach momenta and velocity field to control points and save as vtk."""
    kernel = TorchKernel(kernel_width=kernel_width)
    velocity = kernel.convolve(cp, cp, momenta).cpu()

    if filter_cp:
        vel_thresholded = np.linalg.norm(velocity, axis=-1) > threshold
        cp = cp[vel_thresholded, :]
        momenta = momenta[vel_thresholded, :]
        velocity = velocity[vel_thresholded, :]

    poly = pv.PolyData(cp)
    poly["Momentum"] = momenta
    poly["Velocity"] = velocity
    return poly


def external_forces_to_vtk(cp, forces, output_dir, filter_cp=True, threshold=1.0):
    """Attach external forces to control points and save as vtk."""
    mask = np.linalg.norm(forces, axis=-1) > threshold
    for i, f in enumerate(forces[:-1]):
        filename = join(output_dir, f"cp_with_external_forces_{i}.vtk")
        if filter_cp:
            cp_filtered = cp[mask[i], :]
            f = f[mask[i]]
            poly_cp = pv.PolyData(cp_filtered)
        else:
            poly_cp = pv.PolyData(cp)
        poly_cp["external_force"] = f
        poly_cp.save(filename)


def ssd(atlas_dir, kernel_width):
    """Compute the Sum of squared Riemannian distances from atlas to subjects' shape."""
    momenta = torch.from_numpy(read_3D_array(atlas_dir / lddmm_strings.momenta_str))
    cp = torch.from_numpy(read_2D_array(atlas_dir / lddmm_strings.cp_str))
    kernel = TorchKernel(kernel_width=kernel_width)
    kernel_matrix = kernel.get_kernel_matrix(cp, cp)
    ssd = (torch.einsum("...ij,...kj->...ik", momenta, momenta) * kernel_matrix).sum()
    return ssd

In [12]:
# === CONFIG ===
# subject_name = 'sub-001'
structure_id = 17  # Left-Hippocampus
brain_structure_filename = f"resliced_mesh_{structure_id}"
shapes_dir = Path('/Users/sak/.herbrain/data/pregnancy/neuromaternal_madrid_2021/derivatives/enigma_shape')
output_dir = Path('./output')
atlas_path = Path('./output/atlas') 

In [13]:
# === Mesh Conversion ===
def nibabel_to_pyvista(mesh):
    V, F = mesh
    faces = np.hstack(np.c_[np.full(len(F), 3), F]).astype(np.int64)
    return pv.PolyData(V, faces)

In [21]:
import re

all_subject_meshes = {}

# Extract unique subject IDs
subject_ids = set()
for subdir in os.listdir(shapes_dir):
    # print(subdir)
    match = re.match(r"sub-(\w+)_ses-[34]", subdir)
    if match:
        # print("match!")
        subject_ids.add(match.group(1))

print(subject_ids)
print(len(subject_ids))
for subject_id in sorted(subject_ids):
    sub_meshes = {}
    
    ses3_path = shapes_dir / f"sub-{subject_id}_ses-3" / "resliced_mesh_17"
    ses4_path = shapes_dir / f"sub-{subject_id}_ses-4" / "resliced_mesh_17"

    if not ses3_path.exists() or not ses4_path.exists():
        print("HELP :(")
        continue

    mesh3 = nibabel.freesurfer.read_geometry(str(ses3_path))
    mesh4 = nibabel.freesurfer.read_geometry(str(ses4_path))

    pv_mesh3 = nibabel_to_pyvista(mesh3)
    pv_mesh4 = nibabel_to_pyvista(mesh4)

    # Align ses-4 to ses-3
    aligned_mesh4 = pv_mesh4.align(
        pv_mesh3,
        max_landmarks=100,
        max_mean_distance=1e-5,
        max_iterations=500,
        check_mean_distance=True,
        start_by_matching_centroids=True,
    )

    sub_meshes['t0'] = pv_mesh3
    sub_meshes['t1'] = aligned_mesh4
    all_subject_meshes[subject_id] = sub_meshes

{'ccce0782e0', '0872b8db24', '17641fc443', 'a8fd90daa4', 'fcad4b7f8c', 'b75569792e', 'b694dab0c4', '6fd380db03', '254daf33ca', 'f209c3fa4b', '67eb38f7e9', '6733e7702c', 'e4608ddfa3', 'cc8bfad990', 'f6c2d6696b', '343a5d44e7', '3c92b464a0', '049ab1a591', 'a32661b55d', '66ed43826b', 'e0e4067b5d', 'bf8f724976', '082dc070c1', '15255749dc', '3f68c4406d', '63db63081a', '39dd102cbf', 'c152c325f4', '4cede517b5', '8d5824be35', '5aa6c7faa1', '6e5a9fb778', '7304a751ea', '4cf50fd8bc', '9ae83e0927', 'efe5f05d66', '31879f2221', '26e0dc601b', 'b253f8ccbc', '6dc2a2528b', 'c01a46c7f9', '717e805d07', '5f42d47ec5', '690eec56c3', 'b62f625bc7', 'bf028a52b9', '7883dbc6da', '6b9079271c', '47fb269f0c', '68f8586911', '96fab186d7', 'ed2a398ef5', 'f922d45e9a', '604e89c97e', '74ad7c8895', 'fb91476aa2', '4732ae8cde', 'c1be0ff860', '3c4a4a2b19', 'c4c18afc24', 'd367eabca4', 'c7b1145ac8', '471fecf35b', 'f7a4c2297b', '4c1ea69777', '718923d847', 'e72d2b2eff', 'a9ae29e7b8', '9e143ef040', '7bd2334282', 'b0d275dca1', '471d

In [47]:
# all_subject_meshes

In [31]:
# File paths and settings
# subject_ids = ['subj01', 'subj02', 'subj03']
output_dir = Path("output/transport")
output_dir.mkdir(parents=True, exist_ok=True)
atlas_dir = Path("output/atlas")
atlas_dir.mkdir(parents=True, exist_ok=True)

structure = 'hippocampus'

# LDDMM settings
registration_args = {
    'kernel_width': 4,
    'regularisation': 1,
    'max_iter': 2000,
    'freeze_control_points': False,
    'metric': 'varifold',
    'attachment_kernel_width': 2.0,
    'tol': 1e-10,
    'filter_cp': True,
    'threshold': 0.75,
}

transport_args = {
    'kernel_type': 'torch',
    'kernel_width': 4,
    'n_rungs': 10
}

shoot_args = {
    'use_rk2_for_flow': True,
    'kernel_width': 4,
    'number_of_time_steps': 11,
    'write_params': False
}

# Use in-memory meshes
subject_ids = list(all_subject_meshes.keys())

data_set = []
for sid in subject_ids:
    t0_mesh = all_subject_meshes[sid]['t0']
    t0_path = output_dir / f"sub-{sid}_t0.vtk"
    t0_mesh.save(t0_path)
    # write_vtk(t0_path, t0_mesh)
    data_set.append({'shape': str(t0_path)})

In [39]:
source = str(output_dir / f"{subject_ids[0]}_t0.vtk")
targets = [str(output_dir / f"{sid}_t0.vtk") for sid in subject_ids]

In [41]:
targets

['output/transport/049ab1a591_t0.vtk',
 'output/transport/073801a359_t0.vtk',
 'output/transport/082dc070c1_t0.vtk',
 'output/transport/085a9ee2fc_t0.vtk',
 'output/transport/0872b8db24_t0.vtk',
 'output/transport/0a9b5eb6f1_t0.vtk',
 'output/transport/0b6820582a_t0.vtk',
 'output/transport/11a123772c_t0.vtk',
 'output/transport/15255749dc_t0.vtk',
 'output/transport/15901c4398_t0.vtk',
 'output/transport/17641fc443_t0.vtk',
 'output/transport/17ca82ae03_t0.vtk',
 'output/transport/1db8e32656_t0.vtk',
 'output/transport/22cdbe0c1b_t0.vtk',
 'output/transport/254daf33ca_t0.vtk',
 'output/transport/26e0dc601b_t0.vtk',
 'output/transport/2d086560ad_t0.vtk',
 'output/transport/2edf82ffc5_t0.vtk',
 'output/transport/306d3a97b3_t0.vtk',
 'output/transport/31879f2221_t0.vtk',
 'output/transport/343a5d44e7_t0.vtk',
 'output/transport/39dd102cbf_t0.vtk',
 'output/transport/3b4d062226_t0.vtk',
 'output/transport/3c4a4a2b19_t0.vtk',
 'output/transport/3c92b464a0_t0.vtk',
 'output/transport/3cb624

In [49]:
mesh_paths = [[{"shape": str(output_dir / f"{sid}_t0.vtk")}] for sid in subject_ids]
mesh_paths

[[{'shape': 'output/transport/049ab1a591_t0.vtk'}],
 [{'shape': 'output/transport/073801a359_t0.vtk'}],
 [{'shape': 'output/transport/082dc070c1_t0.vtk'}],
 [{'shape': 'output/transport/085a9ee2fc_t0.vtk'}],
 [{'shape': 'output/transport/0872b8db24_t0.vtk'}],
 [{'shape': 'output/transport/0a9b5eb6f1_t0.vtk'}],
 [{'shape': 'output/transport/0b6820582a_t0.vtk'}],
 [{'shape': 'output/transport/11a123772c_t0.vtk'}],
 [{'shape': 'output/transport/15255749dc_t0.vtk'}],
 [{'shape': 'output/transport/15901c4398_t0.vtk'}],
 [{'shape': 'output/transport/17641fc443_t0.vtk'}],
 [{'shape': 'output/transport/17ca82ae03_t0.vtk'}],
 [{'shape': 'output/transport/1db8e32656_t0.vtk'}],
 [{'shape': 'output/transport/22cdbe0c1b_t0.vtk'}],
 [{'shape': 'output/transport/254daf33ca_t0.vtk'}],
 [{'shape': 'output/transport/26e0dc601b_t0.vtk'}],
 [{'shape': 'output/transport/2d086560ad_t0.vtk'}],
 [{'shape': 'output/transport/2edf82ffc5_t0.vtk'}],
 [{'shape': 'output/transport/306d3a97b3_t0.vtk'}],
 [{'shape': 

In [None]:
deterministic_atlas(
    source=str(output_dir / f"{subject_ids[0]}_t0.vtk"),
    targets=mesh_paths,  # ✅ nested dicts, not strings
    subject_id="hippocampus",
    output_dir=atlas_dir,
    max_iter=2000,
    kernel_width=4,
    regularisation=1.0,
    number_of_time_steps=11,
    attachment_kernel_width=2.0,
    initial_step_size=1e-1,
    metric="varifold",
)

Logger has been set to: DEBUG
>> No initial CP spacing given: using diffeo kernel width of 4
OMP_NUM_THREADS found in environment variables. Using value OMP_NUM_THREADS=10
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: output/atlas/deformetrica-state.p.
>> Using a Sobolev gradient for the template data with the ScipyLBFGS estimator memory length being larger than 1. Beware: that can be tricky.
{'max_iterations': 2000, 'freeze_template': False, 'freeze_control_points': False, 'freeze_momenta': False, 'use_sobolev_gradient': True, 'sobolev_kernel_width_ratio': 1, 'max_line_search_iterations': 50, 'initial_control_points': None, 'initial_cp_spacing': None, 'initial_momenta': None, 'dense_mode': False, 'number_of_threads': 1, 'print_every_n_iters': 20, 'downsampling_factor': 1, 'dimension': 3, 'optimization_method_type': 'ScipyLBFGS', 'convergence_tolerance': 1e-05, 'initial_step_size': 0.1, 'gpu_mode': <GpuMode.KERNEL: 4>, 's

RuntimeError: The template object with id shape is not found for the visit 0 of subject 0. Check the dataset xml.

In [42]:
deterministic_atlas(
    source=source,
    targets=targets,
    subject_id="hippocampus",
    output_dir=atlas_dir,
    t0=0,
    max_iter=2000,
    kernel_width=4,
    regularisation=1.0,
    number_of_time_steps=11,
    metric="varifold",
    kernel_type="torch",
    kernel_device="auto",
    tol=1e-10,
    freeze_control_points=False,
    use_rk2_for_flow=False,
    use_rk2_for_shoot=False,
    preserve_volume=False,
    use_svf=False,
    dimension=3,
    print_every=20,
    attachment_kernel_width=2.0,
    initial_step_size=1e-1,
)

cp_atlas = read_2D_array(atlas_dir / 'control_points.txt')
kernel = TorchKernel(kernel_width=registration_args['kernel_width'], device='auto')

transport_result = pd.DataFrame()

Logger has been set to: DEBUG
>> No initial CP spacing given: using diffeo kernel width of 4
OMP_NUM_THREADS found in environment variables. Using value OMP_NUM_THREADS=10
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: output/atlas/deformetrica-state.p.
>> Using a Sobolev gradient for the template data with the ScipyLBFGS estimator memory length being larger than 1. Beware: that can be tricky.
{'max_iterations': 2000, 'freeze_template': False, 'freeze_control_points': False, 'freeze_momenta': False, 'use_sobolev_gradient': True, 'sobolev_kernel_width_ratio': 1, 'max_line_search_iterations': 50, 'initial_control_points': None, 'initial_cp_spacing': None, 'initial_momenta': None, 'dense_mode': False, 'number_of_threads': 1, 'print_every_n_iters': 20, 'downsampling_factor': 1, 'dimension': 3, 'optimization_method_type': 'ScipyLBFGS', 'convergence_tolerance': 1e-10, 'initial_step_size': 0.1, 'gpu_mode': <GpuMode.KERNEL: 4>, 's

RuntimeError: The template object with id shape is not found for the visit 0 of subject 0. Check the dataset xml.

In [None]:
for sid in subject_ids:
    source = str(base_path / sid / 't0.vtk')
    target = str(base_path / sid / 't1.vtk')

    shoot_args['source'] = source
    cp_subject, mom = parallel_transport(
        source=source,
        target=target,
        atlas=str(atlas_dir / 'atlas.vtk'),
        name=sid,
        output_dir=output_dir,
        registration_args=registration_args,
        transport_args=transport_args,
        shoot_args=shoot_args
    )

    # Convolve momenta to velocity field on atlas CP
    vel = kernel.convolve(cp_atlas, cp_subject, mom)
    series = pd.Series([sid] + vel.flatten().tolist())
    transport_result = pd.concat([transport_result, series.to_frame().T], ignore_index=True)

# 4. Save results
n_cp = cp_atlas.shape[1]
transport_result.column_

In [23]:
# === Setup LDDMM + Parallel Transport ===
atlas_cp_path = atlas_path / 'cp.txt'
atlas_cp = read_2D_array(atlas_cp_path)
kernel_width = 4
kernel = TorchKernel(kernel_width=kernel_width, device='auto')

registration_args = dict(
    kernel_width=kernel_width, regularisation=1, max_iter=2000,
    freeze_control_points=False, metric='varifold', attachment_kernel_width=2.,
    tol=1e-10, filter_cp=True, threshold=0.75)

transport_args = {'kernel_type': 'torch', 'kernel_width': kernel_width, 'n_rungs': 10}
shoot_args = {
    'use_rk2_for_flow': True, 'kernel_width': kernel_width,
    'number_of_time_steps': 11, 'write_params': False,
    'source': t0_path
}

cp, mom = parallel_transport(
    t0_path, t1_path, atlas_path / 'atlas', subject_name,
    output_dir, registration_args, transport_args, shoot_args)

vel = kernel.convolve(atlas_cp, cp, mom)

# === Visualize velocity ===
poly = pv.PolyData(atlas_cp.T)
poly['Velocity'] = vel.T

plotter = pv.Plotter()
plotter.add_mesh(poly, scalars=None, render_points_as_spheres=True, point_size=10)
plotter.add_arrows(atlas_cp.T, vel.T, mag=1.0, label='Velocity')
plotter.show()

FileNotFoundError: output/atlas/cp.txt not found.