# LDDMM: how to do regression?

In [1]:
from pathlib import Path

import herbrain.lddmm as lddmm
import herbrain.lddmm.strings as lddmm_strings

import polpo.preprocessing.pd as ppd
from polpo.preprocessing import (
    BranchingPipeline,
    IndexMap,
    IndexSelector,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.dict import (
    DictFilter,
    DictMap,
    DictMerger,
    DictToTuplesList,
    Hash,
)
from polpo.preprocessing.load.pregnancy import (
    FigsharePregnancyDataLoader,
    PregnancyPilotSegmentationsLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData
from polpo.preprocessing.mesh.filter import PvSelectColor
from polpo.preprocessing.mesh.io import PvWriter
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mesh.smoothing import PvSmoothTaubin
from polpo.preprocessing.mesh.transform import MeshCenterer
from polpo.preprocessing.mri import (
    BRAINSTRUCT2COLOR,
    MeshExtractorFromSegmentedImage,
    MriImageLoader,
)

No CUDA runtime is found, using CUDA_HOME='/usr'


In [2]:
# Gestational-week window: sessions whose `gestWeek` falls outside
# [T_MIN, T_MAX] are dropped when the predictor is loaded.
T_MIN = 1.0
T_MAX = 25.0

# Session ids (not list positions): the template is the rigid-alignment
# reference; template/target are the pair used for the initial registration.
TEMPLATE_SESSION = 3
TARGET_SESSION = 14

# Key into BRAINSTRUCT2COLOR; set to -1 to keep the full extracted mesh.
STRUCT_NAME = "PostHipp"

OUTPUTS_DIR = Path("results") / "regression"

# Output location for registration results inside OUTPUTS_DIR.
REGISTRATION_DIR = OUTPUTS_DIR / "registration"

# `parents=True` creates `results/` on a fresh checkout and `exist_ok=True`
# keeps the cell re-runnable (Restart Kernel -> Run All); the previous
# `exist_ok=False` raised FileExistsError on every second execution.
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

# If not None, uses already computed points (assumes consistency)
CTRL_POINTS_FILE = (
    Path("results") / "registration" / "initial_registration" / lddmm_strings.cp_str
)

## Load predictor

Following [How to load a csv file?](./load_csv.ipynb) and doing preprocessing:

In [3]:
loader = FigsharePregnancyDataLoader(
    data_dir="~/.herbrain/data/pregnancy",
    remote_path="28Baby_Hormones.csv",
    use_cache=True,
)

# Session labels are split on "-" and the second piece is cast to int.
session_as_int = ppd.UpdateColumnValues(
    column_name="sessionID", func=lambda entry: int(entry.split("-")[1])
)

# Keep only entries whose value lies inside the [T_MIN, T_MAX] window.
in_time_window = DictFilter(lambda value: T_MIN <= value <= T_MAX)

# csv table -> {session id: gestWeek}
prep_pipe = (
    session_as_int
    + ppd.IndexSetter(key="sessionID", drop=True)
    + ppd.ColumnsSelector("gestWeek")
    + ppd.SeriesToDict()
    + in_time_window
)

predictor = (loader + ppd.CsvReader() + prep_pipe)()

predictor

INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/pregnancy/28Baby_Hormones.csv').


{3: 1.0,
 4: 1.5,
 5: 2.0,
 6: 3.0,
 7: 9.0,
 8: 12.0,
 9: 14.0,
 10: 15.0,
 11: 17.0,
 12: 19.0,
 13: 22.0,
 14: 24.0}

## Load meshes

Following data loading of [LDDMM: how to register a mesh against a template?](./lddmm_register_mesh_template.ipynb).

In [5]:
# Only load segmentations for the sessions kept in `predictor`.
files_pipe = PregnancyPilotSegmentationsLoader(
    predictor.keys(),
    as_dict=True,
)


# Segmented MRI image -> extracted mesh -> pyvista object.
mri2mesh = MriImageLoader() + MeshExtractorFromSegmentedImage() + PvFromData()

# STRUCT_NAME == -1 means "no selection": the mesh is passed through as is.
# Otherwise, keep the part of the mesh colored as the requested structure.
struct_selector = (
    (lambda x: x)
    if STRUCT_NAME == -1
    else PvSelectColor(
        color=BRAINSTRUCT2COLOR[STRUCT_NAME],
        extract_surface=True,
    )
)

pipe = files_pipe + DictMap(mri2mesh + struct_selector)

In [6]:
# Run the loading pipeline; the result is a dict of meshes (one per session).
meshes = pipe()

# Display the keys to check which sessions were loaded.
meshes.keys()

INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/pregnancy/Segmentations').


dict_keys([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])

## Preprocessing meshes

Following the preprocessing of [LDDMM: how to register a mesh against a template?](./lddmm_register_mesh_template.ipynb),
we center, smooth, and rigidly align the meshes against the template.

In [7]:
# TODO: consider decimation if above a given number of points

# Applied to every mesh in the dict: centering followed by Taubin smoothing.
center_and_smooth = DictMap(MeshCenterer() + PvSmoothTaubin(n_iter=20))

# Alignment of every mesh against the template session's mesh.
# NOTE(review): presumably `_target` is evaluated on the incoming data and
# forwarded to `PvAlign` as `target` when the step is built -- confirm against
# `PartiallyInitializedStep`.
align_to_template = PartiallyInitializedStep(
    Step=lambda **kwargs: DictMap(PvAlign(**kwargs)),
    _target=lambda x: x[TEMPLATE_SESSION],
    max_iterations=10,
)

prep_pipe = center_and_smooth + align_to_template

In [8]:
# NOTE(review): `meshes` is rebound to its own preprocessed version, so this
# cell is not idempotent -- executing it twice preprocesses the meshes twice.
# Re-run the loading cell (`meshes = pipe()`) before re-running this one.
meshes = prep_pipe(meshes)

Save meshes in `vtk` format (as required by `deformetrica`).

In [10]:
# {session: mesh} -> list of (session, mesh) pairs, processed by two branches:
#   1. keep the session id of each pair;
#   2. rename each pair to ("mesh_<zero-padded session>", mesh) and write it
#      to OUTPUTS_DIR as a vtk file.
# NOTE(review): the merger presumably recombines both branches into a
# {session: written filename} dict -- confirm against `NestingSwapper`/`Hash`.
#
# The previously defined `meshes_writer` was never used and has been removed;
# writing happens only through the `PvWriter` step below.
write_pipe = DictToTuplesList() + BranchingPipeline(
    [
        Map(IndexSelector(0)),
        Map(
            [
                # tuples are immutable: make the pair a list so index 0 can be replaced
                lambda datum: list(datum),
                IndexMap(index=0, step=lambda session: f"mesh_{str(session).zfill(2)}"),
                PvWriter(dirname=OUTPUTS_DIR, ext="vtk"),
            ]
        ),
    ],
    merger=NestingSwapper() + Hash(),
)

mesh_filenames_dict = write_pipe(meshes)

We can now create the dataset:

In [11]:
# Pair each session's predictor value with its mesh filename, then split the
# pairs into two aligned sequences (times first, filenames second).
dataset_pipe = DictMerger() + NestingSwapper()

times, mesh_filenames = dataset_pipe([predictor, mesh_filenames_dict])

And we also normalize time:

In [12]:
# TODO: do it in a sklearn style
# Min-max scaling: the smallest time maps to 0 and the largest to 1.
# `min_time` and `maxmindiff_time` stay available to map results back to weeks.
min_time = min(times)
maxmindiff_time = max(times) - min_time

times = [(week - min_time) / maxmindiff_time for week in times]

In [13]:
# TODO: just to make it run, needs improvement
# Wrap every filename as {"shape": filename}, the per-target format passed to
# the regression call below.
# Entries that are already dicts are left untouched so that re-running this
# cell does not nest the wrapping ({"shape": {"shape": ...}}).
mesh_filenames = [
    filename if isinstance(filename, dict) else {"shape": filename}
    for filename in mesh_filenames
]

## LDDMM

### Step 1: find control points

Follows [LDDMM: how to register a mesh against a template?](./lddmm_register_mesh_template.ipynb).

In [14]:
# TODO: need to adapt registration parameters to substructure
# Parameters shared by the initial registration and (merged with the spline
# parameters) by the regression step.
registration_kwargs = dict(
    kernel_width=4.0,
    regularisation=1.0,
    max_iter=2000,
    freeze_control_points=False,
    attachment_kernel_width=2.0,
    metric="varifold",
    tol=1e-16,
    filter_cp=True,
    threshold=0.75,
)

if CTRL_POINTS_FILE is not None:
    # Reuse previously computed control points.
    initial_control_points = CTRL_POINTS_FILE
else:
    # Compute control points by registering the template against the target.
    # Sessions are looked up in `mesh_filenames_dict` (keyed by session id):
    # `mesh_filenames` is a positional list of {"shape": ...} dicts, so
    # indexing it with a session id picks the wrong entry or is out of range.
    # `REGISTRATION_DIR` replaces the undefined `INITIAL_REGISTRATION_DIR`.
    lddmm.registration(
        mesh_filenames_dict[TEMPLATE_SESSION],
        mesh_filenames_dict[TARGET_SESSION],
        output_dir=REGISTRATION_DIR,
        **registration_kwargs,
    )
    initial_control_points = REGISTRATION_DIR / lddmm_strings.cp_str

### Step 2: perform regression

In [15]:
# Regression-specific settings; on key clashes they win over the
# registration settings (here `regularisation` and `freeze_control_points`).
spline_kwargs = {
    "initial_step_size": 100,
    "regularisation": 1.0,
    "freeze_external_forces": True,
    "freeze_control_points": True,
}

kwargs = {**registration_kwargs, **spline_kwargs}

# Uniform weights: every observation contributes equally.
n_observations = len(times)
target_weights = [1 / n_observations] * n_observations


lddmm.spline_regression(
    source=mesh_filenames[0]["shape"],
    targets=mesh_filenames,
    output_dir=OUTPUTS_DIR,
    times=times,
    subject_id=[""],
    t0=min(times),
    target_weights=target_weights,
    initial_control_points=initial_control_points,
    **kwargs,
)

Logger has been set to: DEBUG
OMP_NUM_THREADS was not found in environment variables. An automatic value will be set.
OMP_NUM_THREADS will be set to 10
>> Initial t0 set by the user to 0.00 ; note that the mean visit age is 0.46
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: results/regression/deformetrica-state.p.
instantiating kernel torch with kernel_width 4.0 and gpu_mode GpuMode.KERNEL. addr: 0x7d265fde2f10
instantiating kernel torch with kernel_width 2.0 and gpu_mode GpuMode.KERNEL. addr: 0x7d265fdc8ad0
>> Reading 110 initial control points from file results/registration/initial_registration/DeterministicAtlas__EstimatedParameters__ControlPoints.txt.
>> Momenta initialized to zero.
dtype=float32
>> Started estimator: ScipyOptimize

>> Scipy optimization method: L-BFGS-B

------------------------------------- Iteration: 1 -------------------------------------

------------------------------------- Iteration: 20 ------

time.struct_time(tm_year=2025, tm_mon=3, tm_mday=18, tm_hour=2, tm_min=24, tm_sec=41, tm_wday=1, tm_yday=77, tm_isdst=0)

## Further reading

* [LDDMM: how to visualize regression results?](./lddmm_regression_viz.ipynb)