In [85]:
import jax, numpy as np, scipy, rvmep, jax.numpy as jnp, jax.scipy, jax.scipy.optimize, jaxopt
import pyvista
from rvmep import computeStrain, anatomicalDirections, tools

In [394]:
m1_ed = pyvista.PolyData('Data/exp1_ed_ref.vtk')
m1_es = pyvista.PolyData('Data/exp1_es_ref.vtk')
m2_ed = pyvista.PolyData('Data/exp2_ed_ref.vtk')
m2_es = pyvista.PolyData('Data/exp2_es_ref.vtk')

In [193]:
E = computeStrain.computeDeformationTensor(m1_ed, m1_es)
apexId, pointsTricuspid, pointsPulmonary, valvesPointsId = tools.getTomtecApexValvePointsRV()


m1_long, m1_circ = anatomicalDirections.computeAnatomicalDirectionsHeatEquation(m1_ed,apexId, valvesPointsId)
dirs_m1 = np.stack((m1_long, m1_circ), axis = 2)
m2_long, m2_circ = anatomicalDirections.computeAnatomicalDirectionsHeatEquation(m2_ed,apexId,valvesPointsId)
dirs_m2 = np.stack((m2_long, m2_circ), axis = 2)

valid_triangles = np.zeros(m1_ed.n_cells, dtype = bool)
for i, t in enumerate(m1_ed.faces.reshape((-1, 4))[:, 1:]):
    valid_triangles[i] = all([p not in valvesPointsId for p in t ])

G = computeStrain.computeStrainTensorGreen(m1_ed, m1_es)
G_anatomic =  np.einsum('nji,njk, nkl->nil', dirs_m1,  G, dirs_m1)[valid_triangles]
dirs_m2 = dirs_m2[valid_triangles]

In [396]:
X = jnp.array(m2_ed.points)
D = np.array([[-1, 1, 0], [-1, 0, 1]])
E_prev = np.zeros((np.sum(valid_triangles), 3, 3))
triangles = m2_ed.faces.reshape((-1, 4))[:, 1:][valid_triangles]
for i, t in enumerate(triangles):
    E_prev[i] = np.linalg.pinv(D@ m2_ed.points[t]) @ D

def loss(x, X_ref = m2_ed.points, E_prev = E_prev, eps = 1e-5, triangles = triangles):
    X = x.reshape((-1, 3))
    X_by_triangle = X[triangles]
    FFT_anatomic =  jnp.einsum('nji,njr,nrk, nlk, ntl,nto ->nio', dirs_m2,  E_prev, X_by_triangle,X_by_triangle,E_prev,dirs_m2)
    G_new_anatomic = (FFT_anatomic - jnp.eye(2))/2
    return jnp.sum((G_new_anatomic - G_anatomic) * (G_new_anatomic - G_anatomic) ) + eps * jnp.sum((X - X_ref) * (X - X_ref))
loss_jit = jax.jit(loss)

In [397]:
%timeit loss(X)
%timeit loss_jit(X)

1.99 ms ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
82.3 µs ± 370 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [398]:
opt = jaxopt.BFGS(loss_jit, maxiter = 1000)
X = jax.numpy.array(m2_ed.points)
r = opt.run(X)

In [402]:
r[1]

BfgsState(iter_num=Array(919, dtype=int32, weak_type=True), value=Array(1.1965063, dtype=float32), grad=Array([[ 6.0869716e-06, -2.5997870e-05, -3.1843723e-05],
       [ 9.8224245e-06,  2.0676438e-05,  6.7161745e-06],
       [ 1.6551377e-05, -2.2108958e-05,  9.5137966e-06],
       ...,
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32), stepsize=Array(1., dtype=float32), error=Array(0.00096062, dtype=float32), H=Array([[ 1.8267883 , -0.06006246, -0.24014543, ...,  0.        ,
         0.        ,  0.        ],
       [-0.06006714,  2.4050725 ,  1.1119647 , ...,  0.        ,
         0.        ,  0.        ],
       [-0.24014798,  1.111931  ,  2.8374205 , ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  1.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0

In [399]:
def numpyToPyvista(x, triangles):
    return pyvista.PolyData(np.asarray(x).reshape((-1, 3)), np.hstack((np.ones((len(triangles), 1)) * 3, triangles)).astype(int).flatten())
pv = numpyToPyvista(r.params, triangles)
pv.save('Data/strain0.vtk')

In [269]:
pv

PolyData,Information
N Cells,1587
N Points,938
N Strips,0
X Bounds,"-2.196e+01, 3.072e+01"
Y Bounds,"-3.821e+01, 4.130e+01"
Z Bounds,"-4.791e+01, 3.458e+01"
N Arrays,0


In [232]:
eps = 1e-5
X_ref = m2_ed.points

X_by_triangle = X[triangles]
G_new_anatomic =  jnp.einsum('nji,njr, nrk, nkl->nil', dirs_m2,  E_prev,  X_by_triangle, dirs_m2)
jnp.linalg.norm(G_new_anatomic - G_anatomic) + eps * jnp.linalg.norm(X - X_ref)

ValueError: Einstein sum subscript 'nrk' does not contain the correct number of indices for operand 2.