# Read Data

In [5]:
%env JAX_ENABLE_X64=True

from typing import List, NamedTuple
import os
import pyvista as pv

from morphomatics.geom import Surface

path = '/data/visual/online/projects/shape_trj/ventricle_MV/2_echo_LV'

class Subject(NamedTuple):
    systole: Surface
    diastole: Surface

# create Surface from pv.PolyData
to_surf = lambda mesh: Surface(mesh.points, mesh.faces.reshape(-1, 4)[:, 1:])

subjects: List[Subject] = []
for d in os.listdir(path):
    files = os.listdir(f'{path}/{d}')
    files.sort()
    sub = Subject(*[to_surf(pv.read(f'{path}/{d}/{f}')) for f in files])
    subjects.append(sub)

env: JAX_ENABLE_X64=True


# Determine Reference Shape

In [6]:
from ipywidgets import interact

import numpy as np
import jax.numpy as jnp

from morphomatics.manifold import FundamentalCoords, DifferentialCoords
from morphomatics.stats import StatisticalShapeModel

# Statistical shape model over the diastolic meshes, with shapes encoded in
# differential coordinates w.r.t. the model's reference.
SSM = StatisticalShapeModel(lambda ref: DifferentialCoords(ref))
diastolic_shapes = [subject.diastole for subject in subjects]
SSM.construct(diastolic_shapes)
print(f'variances: {SSM.variances}')

# Render the SSM mean; pyvista expects each face row as [3, i, j, k].
pl = pv.Plotter(notebook=True)
mean_shape = SSM.mean
face_rows = np.c_[np.full(len(mean_shape.f), 3), mean_shape.f]
mesh = pv.PolyData(np.asarray(mean_shape.v), face_rows)
pl.add_mesh(mesh)
viewer = pl.show(jupyter_backend='ipygany', return_viewer=True)

# The slider spans +/- one standard deviation of the first mode
# (SSM.variances is printed in descending order above).
dev = np.sqrt(SSM.variances[0])

@interact
def plot(t=(-dev, dev, dev / 10)):
    """Deform the displayed mesh: the SSM mean excited by t along the 1st principal geodesic mode."""
    space = SSM.space
    excited_coords = space.connec.exp(SSM.mean_coords, t * SSM.modes[0])
    viewer.children[0].vertices = np.asarray(space.from_coords(excited_coords))

viewer

12.086290492774422
0.6852953940161997
0.09134331739290773
0.025202615292526216
0.02249216238716494
0.022001253914874258
0.02345693510186036 > 0.022001253914874258 --> divergence
variances: [7.24471167e+02 5.19666191e+02 3.15014672e+02 1.93083761e+02
 1.29079316e+02 9.83369661e+01 7.92818487e+01 5.48260584e+01
 2.15675218e-02]


interactive(children=(FloatSlider(value=0.0, description='t', max=26.91600206295891, min=-26.91600206295891, s…

Scene(background_color='#4c4c4c', camera={'position': [70.78602653197838, 74.33933922695165, 111.9984249706177…

# Construct statistical shape trajectory model (SSTM)

In [None]:
from morphomatics.geom import BezierSpline
from morphomatics.manifold import Bezierfold, ShapeSpace

M: ShapeSpace = SSM.space
# NOTE(review): the two 1s presumably select one segment of degree 1 (a geodesic
# from diastole to systole) — confirm against the Bezierfold signature.
B = Bezierfold(M, 1, 1)

def _phase_spline(subject):
    """Spline in M whose control points are the diastole then systole coords, nested as [[dia, sys]]."""
    control_points = jnp.array([[M.to_coords(subject.diastole.v), M.to_coords(subject.systole.v)]])
    return BezierSpline(M, control_points)

# shape pairs -> splines -> points in the Bezierfold
splines = [_phase_spline(s) for s in subjects]
trjs = jnp.asarray([B.to_coords(spline) for spline in splines])

# Mean trajectory (plus the `geos` it returns, presumably geodesics from the mean
# to each trajectory) and the Gram matrix used for PCA below.
mean, geos = Bezierfold.FunctionalBasedStructure.mean(B, trjs)
G = B.metric.gram(mean, geos)

0: 0.3315946378782998
1: 0.29787299027276193
2: 0.30869938513116724
3: 0.3152406746620049
4: 0.2904820668869489
5: 0.2919339614302815
6: 0.30347541242671217
7: 0.31452418745339217
8: 0.21645786368959885
9: 0.2926248330297124
10: 0.2718673627557189
11: 0.23655150860907181
12: 0.26397293176812897
13: 0.2927235793624464
14: 0.2098204594024026
15: 0.22395345954041085
16: 0.17032824010971603
17: 0.19456093011982478
18: 0.2742982360701446
19: 0.0990668205360318
20: 0.15425847980901236
21: 0.1820995529915339
22: 0.09701540608927904
23: 0.14066926443902095
24: 0.13193834683842226
25: 0.2627679389527597
26: 0.09186443419727179
27: 0.09867870640791815
28: 0.07507871882547962
29: 0.1072489912663972
30: 0.21033604228550035
31: 0.04112448200757293
32: 0.07030436096007502
33: 0.02674240077540887
34: 0.054040746282997265
35: 0.05484377769000969
36: 0.07130728617118771
37: 0.30564023074233876
38: 0.10658397446554133
39: 0.11950294555009466
40: 0.06753731968297724
41: 0.1166668151502317
42: 0.063726277

## Visualize mean trajectory

In [4]:
# Decode the mean Bezierfold point back into a spline in shape space.
mu: BezierSpline = B.from_coords(mean)

pl = pv.Plotter(notebook=True)
pl.add_mesh(mesh)
viewer = pl.show(jupyter_backend='ipygany', return_viewer=True)

@interact
def plot(t=(0., 1., .1)):
    """Show the shape of the mean trajectory at curve parameter t."""
    shape_coords = mu.eval(t)
    viewer.children[0].vertices = np.asarray(M.from_coords(shape_coords))

viewer

interactive(children=(FloatSlider(value=0.5, description='t', max=1.0), Output()), _dom_classes=('widget-inter…

Scene(background_color='#4c4c4c', camera={'position': [69.84129523599651, 76.87526928002913, 116.2644064186490…

In [5]:
import jax
from morphomatics.manifold import PowerManifold

# eigen decomposition of Gram matrix
# NOTE: eigh returns eigenvalues in ASCENDING order, so vals[-1] is the largest
# variance. `vals` is kept in that order (unlike SSM.variances, which is descending).
vals, vecs = jnp.linalg.eigh(G)
print('variances:', vals)

# compute modes
# N: power manifold of B.K+1 copies of M; presumably the product space in which
# the Bezierfold coordinates (control points) live — confirm.
N = PowerManifold(M, B.K+1)
# Log map from the mean to geos[:, 1] for every trajectory, scaled by B.nsteps.
# NOTE(review): presumably geos[:, 1] is the first step of each discrete geodesic
# from the mean, so this approximates the initial velocities — confirm.
logs = jax.vmap(N.connec.log, (None, 0))(mean, geos[:,1]) * B.nsteps
# Dual (Gram-matrix) PCA: project the flattened logs onto the eigenvectors and
# scale by 1/sqrt(n * lambda). The eigenpairs are REVERSED here, so `modes` is in
# DESCENDING variance order: modes[0] pairs with vals[-1] (largest) and modes[-1]
# with vals[0] (smallest — near zero in the output, so that mode is mostly noise).
modes = jnp.diag(1/jnp.sqrt(len(G)*vals[::-1])) @ vecs[:,::-1].T @ logs.reshape(len(G),-1)
modes = modes.reshape((len(G),)+N.point_shape)

variances: [0.00010014 0.00295755 0.0044298  0.00596598 0.00836817 0.01321803
 0.01437479 0.03599146 0.04748391]


## Visualize 1st principal mode of SSTM

In [4]:
pl = pv.Plotter(notebook=True)
# plot() below writes the mean into children[0] and the excited trajectory into
# children[1] (presumably children follow add_mesh order: red = mean).
pl.add_mesh(mesh, color='red')
pl.add_mesh(mesh.copy())
viewer = pl.show(jupyter_backend='ipygany', return_viewer=True)

# `vals` is ascending (largest last) while `modes` was reordered to descending
# (largest first). The 1st principal mode is therefore modes[0], and its standard
# deviation is sqrt(vals[-1]). The previous modes[-1] picked the smallest-variance
# mode instead.
N_STD = 2  # excitation strength in standard deviations
dev = np.sqrt(vals[-1])
principal_mode = modes[0]
# Extrinsic exponential in the power manifold N; the Bezierfold's own
# B.metric.exp(mean, ...) would be the intrinsic alternative.
excited = B.from_coords(N.metric.exp(mean, N_STD*dev*principal_mode))


@interact
def plot(t=(0.,1.,.1)):
    """Show the mean and the mode-excited trajectories at curve parameter t."""
    viewer.children[0].vertices = np.asarray(M.from_coords(mu.eval(t)))
    viewer.children[1].vertices = np.asarray(M.from_coords(excited.eval(t)))

viewer

NameError: name 'vals' is not defined