# Read Data

In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=False
%env JAX_ENABLE_X64=True

from typing import List, NamedTuple
import os
import pyvista as pv
pv.start_xvfb()

from morphomatics.geom import Surface

# Root directory of the dataset, relative to this notebook. The loader below
# expects one sub-directory per subject inside it.
# An alternative absolute location ('/data/visual/online/projects/shape_trj/ventricle_MV/2_echo_LV')
# used to be assigned first and was immediately overwritten; that dead
# assignment has been dropped, the effective value is unchanged.
path = '../../../data/ventricle_MV/2_echo_LV'

class Subject(NamedTuple):
    """The two surfaces loaded for a single subject.

    Field order is significant: the loading loop builds instances positionally
    via ``Subject(*[...])`` from the alphabetically sorted file names of a
    subject directory, so the first sorted file ends up in ``systole`` and the
    second in ``diastole``.
    NOTE(review): confirm the file naming on disk really sorts systole first.
    """
    # surface of the systole configuration
    systole: Surface
    # surface of the diastole configuration
    diastole: Surface

def to_surf(mesh):
    """Create a Surface from a pv.PolyData mesh.

    The flat ``mesh.faces`` array is viewed as rows of four entries and the
    leading entry of every row is discarded; the remaining three columns are
    passed on as the face list together with ``mesh.points``.
    Assumes every face is a triangle (rows of the form ``count, i, j, k``) --
    TODO confirm for all input meshes.
    """
    face_records = mesh.faces.reshape(-1, 4)
    triangles = face_records[:, 1:]
    return Surface(mesh.points, triangles)

# Load one Subject per sub-directory of `path`.
subjects: List[Subject] = []
# os.listdir() yields entries in arbitrary order; sort them so that subject
# indices (used by the sliders further down) are reproducible across runs.
for d in sorted(os.listdir(path)):
    subject_dir = os.path.join(path, d)
    if not os.path.isdir(subject_dir):
        # stray files next to the subject folders would make listdir() fail
        continue
    # sorted file order determines which mesh fills which Subject field
    files = sorted(os.listdir(subject_dir))
    if len(files) != len(Subject._fields):
        raise ValueError(
            f'expected {len(Subject._fields)} mesh files in {subject_dir}, found {len(files)}: {files}'
        )
    sub = Subject(*[to_surf(pv.read(os.path.join(subject_dir, f))) for f in files])
    subjects.append(sub)

env: XLA_PYTHON_CLIENT_PREALLOCATE=False
env: JAX_ENABLE_X64=True


# Determine Reference Shape

In [2]:
from ipywidgets import interact

import numpy as np
import jax.numpy as jnp

from morphomatics.manifold import FundamentalCoords, DifferentialCoords
from morphomatics.stats import StatisticalShapeModel

# Build the statistical shape model (SSM) from the diastole surfaces only;
# the shape space for a given reference is FundamentalCoords.
SSM = StatisticalShapeModel(lambda ref: FundamentalCoords(ref))
diastole_surfaces = [s.diastole for s in subjects]
SSM.construct(diastole_surfaces)
print(f'variances: {SSM.variances}')

# Display the mean shape of the SSM.
# pv.PolyData receives the faces with a leading column holding the value 3.
mean_surface = SSM.mean
face_counts = np.full(len(mean_surface.f), 3)
mesh = pv.PolyData(np.asarray(mean_surface.v), np.c_[face_counts, mean_surface.f])
pl = pv.Plotter(notebook=True)
pl.add_mesh(mesh)
viewer = pl.show(return_viewer=True, jupyter_backend='ipygany')

# square root of the first variance; bounds the slider of the plot below
dev = np.sqrt(SSM.variances[0])
@interact
def plot(t=(-dev, dev, dev/10)):
    """Update the viewer with the mean shape displaced along the first SSM mode.

    ``t`` scales ``SSM.modes[0]``; the slider covers ``[-dev, dev]`` in steps
    of ``dev/10``.
    """
    space = SSM.space
    # interpolate: excite 1st principal geodesic mode
    excited_coords = space.connec.exp(SSM.mean_coords, t*SSM.modes[0])
    vertices = space.from_coords(excited_coords)
    viewer.children[0].vertices = np.asarray(vertices)

viewer



12.922552288178787
0.3827014279413084
0.003720572252159453
tol 0.01993782764566655 reached
variances: [0.02045371 0.01286745 0.00793972 0.00513852 0.00433427 0.00273761
 0.00196636 0.00107949]


interactive(children=(FloatSlider(value=0.0, description='t', max=0.1430164604671463, min=-0.1430164604671463,…

Scene(background_color='#4c4c4c', camera={'position': [68.39181357543289, 77.4130229895273, 113.39347992072932…

# Construct statistical shape trajectory model (SSTM)

In [3]:
from morphomatics.geom import BezierSpline
from morphomatics.manifold import TangentBundle, ShapeSpace
from morphomatics.stats import PrincipalGeodesicAnalysis as PGA

# shape space of the SSM and the tangent bundle built on it
M: ShapeSpace = SSM.space
TM = TangentBundle(M)


def pts2vec(p, q):
    """Encode the pair (p, q) as [p, log_p(q)] using the connection of M."""
    return [p, M.connec.log(p, q)]


def _trajectory(subject):
    """Encode one subject as [diastole coords, log from diastole to systole]."""
    start = M.to_coords(subject.diastole.v)
    end = M.to_coords(subject.systole.v)
    return pts2vec(start, end)


# compute mean and main modes of variation
data = jnp.array([_trajectory(subject) for subject in subjects])
pga = PGA(TM, data)

## Visualize mean trajectory

In [4]:
def exp(p, v, t):
    """Evaluate ``M.connec.exp`` at ``p`` in direction ``t*v`` and map the result through ``M.from_coords``.

    The result is what the viewers below assign as mesh vertices.
    """
    scaled = t*v
    coords = M.connec.exp(p, scaled)
    return M.from_coords(coords)

# new scene showing `mesh`; its vertices are overwritten by the callback below
pl = pv.Plotter(notebook=True)
pl.add_mesh(mesh)
viewer = pl.show(return_viewer=True, jupyter_backend='ipygany')


@interact
def plot(t=(0., 1., .1)):
    """Show the mean trajectory ``pga.mean`` evaluated at parameter ``t``."""
    # pga.mean unpacks into the two leading arguments of exp
    base, velocity = pga.mean
    moved = exp(base, velocity, t)
    viewer.children[0].vertices = np.asarray(moved)


viewer

interactive(children=(FloatSlider(value=0.5, description='t', max=1.0), Output()), _dom_classes=('widget-inter…

Scene(background_color='#4c4c4c', camera={'position': [68.39181357543289, 77.4130229895273, 113.39347992072932…

## Visualize 1st principal mode of SSTM

In [6]:
# Two meshes in one scene: the first (red) shows the excited trajectory, the
# second (a copy) shows the mean trajectory, shifted along x by the callback.
pl = pv.Plotter(notebook=True)
pl.add_mesh(mesh, color='red')
pl.add_mesh(mesh.copy())
viewer = pl.show(jupyter_backend='ipygany', return_viewer=True)

# Cache of the last (s, i) slider values and the trajectory computed for them,
# so that moving only the t slider does not recompute TM.connec.exp.
s_ = 0
i_ = 0
excited = pga.mean

@interact
def plot(t=(0.,1.,.1), i=(0,len(pga.variances)-1,1), s=(-3,3,.3)):
    """Compare the mean trajectory with one displaced along a principal mode.

    t -- parameter at which both trajectories are evaluated.
    i -- index of the principal mode; bounded by the number of modes of `pga`
         (not by the number of subjects, which may exceed it and would index
         past the end of `pga.variances` / `pga.modes`).
    s -- displacement along mode i, in multiples of sqrt(pga.variances[i]).
    """
    global i_, s_, excited
    if s != s_ or i != i_:
        # mode or displacement changed: recompute the excited trajectory
        s_ = s
        i_ = i
        dev = np.sqrt(pga.variances[i])
        excited = TM.connec.exp(pga.mean, s*dev*pga.modes[i])
    viewer.children[0].vertices = np.asarray(exp(*excited, t))
    # shift the mean trajectory by 100 along x so both meshes are visible side by side
    viewer.children[1].vertices = np.asarray(exp(*pga.mean, t)) + np.array([100.,0.,0.])[None]


viewer

interactive(children=(FloatSlider(value=0.5, description='t', max=1.0), IntSlider(value=4, description='i', ma…

Scene(background_color='#4c4c4c', camera={'position': [68.39181357543289, 77.4130229895273, 113.39347992072932…

# Show input trajectories

In [5]:
# new scene showing `mesh`; its vertices are overwritten by the callback below
pl = pv.Plotter(notebook=True)
pl.add_mesh(mesh)
viewer = pl.show(return_viewer=True, jupyter_backend='ipygany')


@interact
def plot(t=(0., 1., .1), idx=(0, len(data)-1, 1)):
    """Show input trajectory ``data[idx]`` evaluated at parameter ``t``."""
    # each entry of data unpacks into the two leading arguments of exp
    start, velocity = data[idx]
    moved = exp(start, velocity, t)
    viewer.children[0].vertices = np.asarray(moved)


viewer

interactive(children=(FloatSlider(value=0.5, description='t', max=1.0), IntSlider(value=4, description='idx', …

Scene(background_color='#4c4c4c', camera={'position': [70.13329805197279, 72.69836502946262, 112.1941819930345…