In [28]:
import time
from os.path import join
import sys
import pandas as pd
import nibabel.freesurfer
from pathlib import Path

import numpy as np
import pyvista as pv
import torch
from api.deformetrica import Deformetrica
from in_out.array_readers_and_writers import read_2D_array, read_3D_array
from launch.compute_parallel_transport import compute_pole_ladder
from launch.compute_shooting import compute_shooting
from support.kernels.torch_kernel import TorchKernel

In [12]:
# === CONFIG ===
# subject_name = 'sub-001'
structure_id = 17  # Left-Hippocampus
brain_structure_filename = f"resliced_mesh_{structure_id}"
shapes_dir = Path('/Users/sak/.herbrain/data/pregnancy/neuromaternal_madrid_2021/derivatives/enigma_shape')
output_dir = Path('./output')
atlas_path = Path('./output/atlas') 

In [13]:
# === Mesh Conversion ===
def nibabel_to_pyvista(mesh):
    V, F = mesh
    faces = np.hstack(np.c_[np.full(len(F), 3), F]).astype(np.int64)
    return pv.PolyData(V, faces)

In [21]:
import re

all_subject_meshes = {}

# Extract unique subject IDs
subject_ids = set()
for subdir in os.listdir(shapes_dir):
    # print(subdir)
    match = re.match(r"sub-(\w+)_ses-[34]", subdir)
    if match:
        # print("match!")
        subject_ids.add(match.group(1))

print(subject_ids)
print(len(subject_ids))
for subject_id in sorted(subject_ids):
    sub_meshes = {}
    
    ses3_path = shapes_dir / f"sub-{subject_id}_ses-3" / "resliced_mesh_17"
    ses4_path = shapes_dir / f"sub-{subject_id}_ses-4" / "resliced_mesh_17"

    if not ses3_path.exists() or not ses4_path.exists():
        print("HELP :(")
        continue

    mesh3 = nibabel.freesurfer.read_geometry(str(ses3_path))
    mesh4 = nibabel.freesurfer.read_geometry(str(ses4_path))

    pv_mesh3 = nibabel_to_pyvista(mesh3)
    pv_mesh4 = nibabel_to_pyvista(mesh4)

    # Align ses-4 to ses-3
    aligned_mesh4 = pv_mesh4.align(
        pv_mesh3,
        max_landmarks=100,
        max_mean_distance=1e-5,
        max_iterations=500,
        check_mean_distance=True,
        start_by_matching_centroids=True,
    )

    sub_meshes['t0'] = pv_mesh3
    sub_meshes['t1'] = aligned_mesh4
    all_subject_meshes[subject_id] = sub_meshes

{'ccce0782e0', '0872b8db24', '17641fc443', 'a8fd90daa4', 'fcad4b7f8c', 'b75569792e', 'b694dab0c4', '6fd380db03', '254daf33ca', 'f209c3fa4b', '67eb38f7e9', '6733e7702c', 'e4608ddfa3', 'cc8bfad990', 'f6c2d6696b', '343a5d44e7', '3c92b464a0', '049ab1a591', 'a32661b55d', '66ed43826b', 'e0e4067b5d', 'bf8f724976', '082dc070c1', '15255749dc', '3f68c4406d', '63db63081a', '39dd102cbf', 'c152c325f4', '4cede517b5', '8d5824be35', '5aa6c7faa1', '6e5a9fb778', '7304a751ea', '4cf50fd8bc', '9ae83e0927', 'efe5f05d66', '31879f2221', '26e0dc601b', 'b253f8ccbc', '6dc2a2528b', 'c01a46c7f9', '717e805d07', '5f42d47ec5', '690eec56c3', 'b62f625bc7', 'bf028a52b9', '7883dbc6da', '6b9079271c', '47fb269f0c', '68f8586911', '96fab186d7', 'ed2a398ef5', 'f922d45e9a', '604e89c97e', '74ad7c8895', 'fb91476aa2', '4732ae8cde', 'c1be0ff860', '3c4a4a2b19', 'c4c18afc24', 'd367eabca4', 'c7b1145ac8', '471fecf35b', 'f7a4c2297b', '4c1ea69777', '718923d847', 'e72d2b2eff', 'a9ae29e7b8', '9e143ef040', '7bd2334282', 'b0d275dca1', '471d

In [47]:
# all_subject_meshes

In [31]:
# File paths and settings
# subject_ids = ['subj01', 'subj02', 'subj03']
output_dir = Path("output/transport")
output_dir.mkdir(parents=True, exist_ok=True)
atlas_dir = Path("output/atlas")
atlas_dir.mkdir(parents=True, exist_ok=True)

structure = 'hippocampus'

# LDDMM settings
registration_args = {
    'kernel_width': 4,
    'regularisation': 1,
    'max_iter': 2000,
    'freeze_control_points': False,
    'metric': 'varifold',
    'attachment_kernel_width': 2.0,
    'tol': 1e-10,
    'filter_cp': True,
    'threshold': 0.75,
}

transport_args = {
    'kernel_type': 'torch',
    'kernel_width': 4,
    'n_rungs': 10
}

shoot_args = {
    'use_rk2_for_flow': True,
    'kernel_width': 4,
    'number_of_time_steps': 11,
    'write_params': False
}

# Use in-memory meshes
subject_ids = list(all_subject_meshes.keys())

data_set = []
for sid in subject_ids:
    t0_mesh = all_subject_meshes[sid]['t0']
    t0_path = output_dir / f"sub-{sid}_t0.vtk"
    t0_mesh.save(t0_path)
    # write_vtk(t0_path, t0_mesh)
    data_set.append({'shape': str(t0_path)})

In [39]:
source = str(output_dir / f"{subject_ids[0]}_t0.vtk")
targets = [str(output_dir / f"{sid}_t0.vtk") for sid in subject_ids]

In [41]:
targets

['output/transport/049ab1a591_t0.vtk',
 'output/transport/073801a359_t0.vtk',
 'output/transport/082dc070c1_t0.vtk',
 'output/transport/085a9ee2fc_t0.vtk',
 'output/transport/0872b8db24_t0.vtk',
 'output/transport/0a9b5eb6f1_t0.vtk',
 'output/transport/0b6820582a_t0.vtk',
 'output/transport/11a123772c_t0.vtk',
 'output/transport/15255749dc_t0.vtk',
 'output/transport/15901c4398_t0.vtk',
 'output/transport/17641fc443_t0.vtk',
 'output/transport/17ca82ae03_t0.vtk',
 'output/transport/1db8e32656_t0.vtk',
 'output/transport/22cdbe0c1b_t0.vtk',
 'output/transport/254daf33ca_t0.vtk',
 'output/transport/26e0dc601b_t0.vtk',
 'output/transport/2d086560ad_t0.vtk',
 'output/transport/2edf82ffc5_t0.vtk',
 'output/transport/306d3a97b3_t0.vtk',
 'output/transport/31879f2221_t0.vtk',
 'output/transport/343a5d44e7_t0.vtk',
 'output/transport/39dd102cbf_t0.vtk',
 'output/transport/3b4d062226_t0.vtk',
 'output/transport/3c4a4a2b19_t0.vtk',
 'output/transport/3c92b464a0_t0.vtk',
 'output/transport/3cb624

In [49]:
mesh_paths = [[{"shape": str(output_dir / f"{sid}_t0.vtk")}] for sid in subject_ids]
mesh_paths

[[{'shape': 'output/transport/049ab1a591_t0.vtk'}],
 [{'shape': 'output/transport/073801a359_t0.vtk'}],
 [{'shape': 'output/transport/082dc070c1_t0.vtk'}],
 [{'shape': 'output/transport/085a9ee2fc_t0.vtk'}],
 [{'shape': 'output/transport/0872b8db24_t0.vtk'}],
 [{'shape': 'output/transport/0a9b5eb6f1_t0.vtk'}],
 [{'shape': 'output/transport/0b6820582a_t0.vtk'}],
 [{'shape': 'output/transport/11a123772c_t0.vtk'}],
 [{'shape': 'output/transport/15255749dc_t0.vtk'}],
 [{'shape': 'output/transport/15901c4398_t0.vtk'}],
 [{'shape': 'output/transport/17641fc443_t0.vtk'}],
 [{'shape': 'output/transport/17ca82ae03_t0.vtk'}],
 [{'shape': 'output/transport/1db8e32656_t0.vtk'}],
 [{'shape': 'output/transport/22cdbe0c1b_t0.vtk'}],
 [{'shape': 'output/transport/254daf33ca_t0.vtk'}],
 [{'shape': 'output/transport/26e0dc601b_t0.vtk'}],
 [{'shape': 'output/transport/2d086560ad_t0.vtk'}],
 [{'shape': 'output/transport/2edf82ffc5_t0.vtk'}],
 [{'shape': 'output/transport/306d3a97b3_t0.vtk'}],
 [{'shape': 

In [None]:
deterministic_atlas(
    source=str(output_dir / f"{subject_ids[0]}_t0.vtk"),
    targets=mesh_paths,  # ✅ nested dicts, not strings
    subject_id="hippocampus",
    output_dir=atlas_dir,
    max_iter=2000,
    kernel_width=4,
    regularisation=1.0,
    number_of_time_steps=11,
    attachment_kernel_width=2.0,
    initial_step_size=1e-1,
    metric="varifold",
)

Logger has been set to: DEBUG
>> No initial CP spacing given: using diffeo kernel width of 4
OMP_NUM_THREADS found in environment variables. Using value OMP_NUM_THREADS=10
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: output/atlas/deformetrica-state.p.
>> Using a Sobolev gradient for the template data with the ScipyLBFGS estimator memory length being larger than 1. Beware: that can be tricky.
{'max_iterations': 2000, 'freeze_template': False, 'freeze_control_points': False, 'freeze_momenta': False, 'use_sobolev_gradient': True, 'sobolev_kernel_width_ratio': 1, 'max_line_search_iterations': 50, 'initial_control_points': None, 'initial_cp_spacing': None, 'initial_momenta': None, 'dense_mode': False, 'number_of_threads': 1, 'print_every_n_iters': 20, 'downsampling_factor': 1, 'dimension': 3, 'optimization_method_type': 'ScipyLBFGS', 'convergence_tolerance': 1e-05, 'initial_step_size': 0.1, 'gpu_mode': <GpuMode.KERNEL: 4>, 's

RuntimeError: The template object with id shape is not found for the visit 0 of subject 0. Check the dataset xml.

In [42]:
deterministic_atlas(
    source=source,
    targets=targets,
    subject_id="hippocampus",
    output_dir=atlas_dir,
    t0=0,
    max_iter=2000,
    kernel_width=4,
    regularisation=1.0,
    number_of_time_steps=11,
    metric="varifold",
    kernel_type="torch",
    kernel_device="auto",
    tol=1e-10,
    freeze_control_points=False,
    use_rk2_for_flow=False,
    use_rk2_for_shoot=False,
    preserve_volume=False,
    use_svf=False,
    dimension=3,
    print_every=20,
    attachment_kernel_width=2.0,
    initial_step_size=1e-1,
)

cp_atlas = read_2D_array(atlas_dir / 'control_points.txt')
kernel = TorchKernel(kernel_width=registration_args['kernel_width'], device='auto')

transport_result = pd.DataFrame()

Logger has been set to: DEBUG
>> No initial CP spacing given: using diffeo kernel width of 4
OMP_NUM_THREADS found in environment variables. Using value OMP_NUM_THREADS=10
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: output/atlas/deformetrica-state.p.
>> Using a Sobolev gradient for the template data with the ScipyLBFGS estimator memory length being larger than 1. Beware: that can be tricky.
{'max_iterations': 2000, 'freeze_template': False, 'freeze_control_points': False, 'freeze_momenta': False, 'use_sobolev_gradient': True, 'sobolev_kernel_width_ratio': 1, 'max_line_search_iterations': 50, 'initial_control_points': None, 'initial_cp_spacing': None, 'initial_momenta': None, 'dense_mode': False, 'number_of_threads': 1, 'print_every_n_iters': 20, 'downsampling_factor': 1, 'dimension': 3, 'optimization_method_type': 'ScipyLBFGS', 'convergence_tolerance': 1e-10, 'initial_step_size': 0.1, 'gpu_mode': <GpuMode.KERNEL: 4>, 's

RuntimeError: The template object with id shape is not found for the visit 0 of subject 0. Check the dataset xml.

In [None]:
for sid in subject_ids:
    source = str(base_path / sid / 't0.vtk')
    target = str(base_path / sid / 't1.vtk')

    shoot_args['source'] = source
    cp_subject, mom = parallel_transport(
        source=source,
        target=target,
        atlas=str(atlas_dir / 'atlas.vtk'),
        name=sid,
        output_dir=output_dir,
        registration_args=registration_args,
        transport_args=transport_args,
        shoot_args=shoot_args
    )

    # Convolve momenta to velocity field on atlas CP
    vel = kernel.convolve(cp_atlas, cp_subject, mom)
    series = pd.Series([sid] + vel.flatten().tolist())
    transport_result = pd.concat([transport_result, series.to_frame().T], ignore_index=True)

# 4. Save results
n_cp = cp_atlas.shape[1]
transport_result.column_

In [23]:
# === Setup LDDMM + Parallel Transport ===
atlas_cp_path = atlas_path / 'cp.txt'
atlas_cp = read_2D_array(atlas_cp_path)
kernel_width = 4
kernel = TorchKernel(kernel_width=kernel_width, device='auto')

registration_args = dict(
    kernel_width=kernel_width, regularisation=1, max_iter=2000,
    freeze_control_points=False, metric='varifold', attachment_kernel_width=2.,
    tol=1e-10, filter_cp=True, threshold=0.75)

transport_args = {'kernel_type': 'torch', 'kernel_width': kernel_width, 'n_rungs': 10}
shoot_args = {
    'use_rk2_for_flow': True, 'kernel_width': kernel_width,
    'number_of_time_steps': 11, 'write_params': False,
    'source': t0_path
}

cp, mom = parallel_transport(
    t0_path, t1_path, atlas_path / 'atlas', subject_name,
    output_dir, registration_args, transport_args, shoot_args)

vel = kernel.convolve(atlas_cp, cp, mom)

# === Visualize velocity ===
poly = pv.PolyData(atlas_cp.T)
poly['Velocity'] = vel.T

plotter = pv.Plotter()
plotter.add_mesh(poly, scalars=None, render_points_as_spheres=True, point_size=10)
plotter.add_arrows(atlas_cp.T, vel.T, mag=1.0, label='Velocity')
plotter.show()

FileNotFoundError: output/atlas/cp.txt not found.