In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
import sys, os, pathlib
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import pandas as pd
import nibabel as nib
import torch

sys.path.append('../..')
import project

torch.cuda.is_available()

  import pkg_resources


True

In [3]:
# nb configuration
# The dataset location defaults to the original cluster path, but can be
# redirected with the COPDGENE_DATA_ROOT environment variable so the notebook
# can run on another machine without editing this cell.
data_root = pathlib.Path(os.environ.get(
    'COPDGENE_DATA_ROOT',
    '/ocean/projects/asc170022p/mtragoza/lung-project/data/COPDGene'
))
data_file = data_root / 'sample1000_2025-07-22.csv'
example_idx = 0         # position of the example within the dataset
source_state = 'INSP'   # state the displacement field maps from
target_state = 'EXP'    # state whose image and mesh are simulated on

In [4]:
# The sample table is read as tab-delimited, despite its .csv extension.
df = pd.read_csv(data_file, sep='\t', low_memory=False)
dataset = project.copdgene.COPDGeneDataset(df, data_root)
# Indexing the dataset yields a (metadata row, visit) pair.
row, visit = dataset[example_idx]
row

sid               16514P
ccenter              TEM
kernel               STD
Emphysema              0
pctEmph           0.3373
pctEmph_Slicer    0.3147
FEV1pp_utah         77.6
FVCpp_utah          88.9
FEV1_FVC_utah       0.71
finalGold           -1.0
catEmph           normal
Name: 0, dtype: object

In [5]:
# Every volume, mask and mesh below shares variant='ISO' and recon='STD';
# masks and meshes additionally use the 'lung_regions' segmentation.
_iso_std = dict(variant='ISO', recon='STD')
_lung_regions = dict(_iso_std, mask_name='lung_regions')

source_image = visit.load_image(state=source_state, **_iso_std)
target_image = visit.load_image(state=target_state, **_iso_std)
s_mask_image = visit.load_mask(state=source_state, **_lung_regions)
t_mask_image = visit.load_mask(state=target_state, **_lung_regions)
disp_image = visit.load_displacement_field(
    target_state=target_state, source_state=source_state, **_iso_std
)

source_mesh = visit.load_mesh(state=source_state, mesh_tag='volume', **_lung_regions)
target_mesh = visit.load_mesh(state=target_state, mesh_tag='volume', **_lung_regions)

In [6]:
# Pull the voxel data of each loaded image into an in-memory array.
source_array, target_array, s_mask_array, t_mask_array, disp_array = (
    image.get_fdata()
    for image in (source_image, target_image, s_mask_image, t_mask_image, disp_image)
)

In [7]:
%autoreload
# Density map computed from the target-state CT array by the project helper.
density_array = project.segmentation.compute_density_map(target_array)

In [11]:
import pyvista as pv

# Voxel spacing along each image axis is the Euclidean length of the matching
# *column* of the affine's 3x3 block, since each column maps one voxel step
# into world space. Row norms (axis=1) only agree with this when the affine
# is axis-aligned, so the column norms (axis=0) are used here.
pv_grid = pv.ImageData(
    dimensions=target_image.shape,
    spacing=np.linalg.norm(target_image.affine[:3,:3], axis=0)
)
# Flatten in Fortran order so the first image axis varies fastest in the
# point arrays.
pv_grid.point_data['CT'] = target_array.flatten(order='F')
pv_grid.point_data['density'] = density_array.flatten(order='F')
pv_grid

Header,Data Arrays
"ImageDataInformation N Cells36096652 N Points36429056 X Bounds0.000e+00, 3.670e+02 Y Bounds0.000e+00, 3.670e+02 Z Bounds0.000e+00, 2.680e+02 Dimensions368, 368, 269 Spacing1.000e+00, 1.000e+00, 1.000e+00 N Arrays2",NameFieldTypeN CompMinMax CTPointsfloat641-1.160e+033.567e+03 densityPointsfloat6411.000e+004.567e+03

ImageData,Information
N Cells,36096652
N Points,36429056
X Bounds,"0.000e+00, 3.670e+02"
Y Bounds,"0.000e+00, 3.670e+02"
Z Bounds,"0.000e+00, 2.680e+02"
Dimensions,"368, 368, 269"
Spacing,"1.000e+00, 1.000e+00, 1.000e+00"
N Arrays,2

Name,Field,Type,N Comp,Min,Max
CT,Points,float64,1,-1160.0,3567.0
density,Points,float64,1,1.0,4567.0


In [12]:
def get_opacity_values(vmin, vmax, center, width, low=0.0, high=1.0, n=201):
    '''
    Build a sigmoid-shaped opacity transfer function.

    Args:
        vmin, vmax: range over which the ramp is sampled
        center: value at which the ramp is halfway between low and high
        width: scale of the transition (the sigmoid argument is
            (x - center) / width)
        low, high: opacity limits approached far below / above center
        n: number of evenly spaced samples
    Returns:
        list of n opacity values
    '''
    samples = np.linspace(vmin, vmax, n)
    ramp = sigmoid((samples - center) / width)
    return (low + (high - low) * ramp).tolist()

def sigmoid(x):
    '''Logistic function 1 / (1 + exp(-x)), applied element-wise.'''
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator

def gaussian(x):
    '''Unnormalised Gaussian bump exp(-x^2), applied element-wise.'''
    squared = x**2
    return np.exp(-squared)

# Quick check: a ramp centred at 0 over a symmetric range should pass through 0.5 at its midpoint.
get_opacity_values(-1000, 1000, 0, 100, n=11)

[4.5397868702434395e-05,
 0.0003353501304664781,
 0.0024726231566347743,
 0.01798620996209156,
 0.11920292202211755,
 0.5,
 0.8807970779778823,
 0.9820137900379085,
 0.9975273768433653,
 0.9996646498695336,
 0.9999546021312976]

In [13]:
vmin, vmax = (-1000, 3000)
cmap = 'jet'

# Initial opacity ramp: centred mid-window, with a width of 1/8 of the
# distance from vmin to the centre.
center  = (vmin + vmax) / 2
width   = (center - vmin) / 8
opacity = get_opacity_values(vmin, vmax, center, width)

p = pv.Plotter()
# The actor is named so the slider callback below can remove and replace it.
p.add_volume(pv_grid, name='vol', scalars='CT', cmap=cmap, clim=(vmin, vmax), opacity=opacity)

def update_opacity(center, plotter=p, cmap=cmap):
    '''
    Slider callback: rebuild the CT volume actor with its opacity ramp
    re-centred at `center`.

    Args:
        center: value at which the sigmoid opacity ramp reaches its midpoint
        plotter: plotter owning the 'vol' actor. Bound when the function is
            defined, because later cells rebind the global `p` and would
            otherwise redirect this callback to a different plotter.
        cmap: colormap for the volume. Bound at definition time for the same
            reason (the global `cmap` is reassigned further down).
    '''
    # Floor the width: with center == vmin (the slider minimum) the original
    # width was 0, which yields NaN opacities. The floor also keeps the
    # sigmoid argument below 500 so np.exp does not overflow.
    min_width = (vmax - vmin) / 500
    width = max((center - vmin) / 8, min_width)
    opacity = get_opacity_values(vmin, vmax, center, width)
    plotter.remove_actor('vol', render=False)
    plotter.add_volume(pv_grid, name='vol', scalars='CT', cmap=cmap, clim=(vmin, vmax), opacity=opacity)

# The slider spans the full intensity window and starts at the initial centre.
p.add_slider_widget(update_opacity, rng=(vmin, vmax), value=center)
p.show()

[0m[33m2025-10-10 10:45:41.686 (  23.710s) [    14D977182200]vtkXOpenGLRenderWindow.:1416  WARN| bad X server connection. DISPLAY=[0m


Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d7eebf3a10_0" src="https://ondemand.bridges2.p…

In [14]:
def voxel_to_world_coords(points, affine):
    '''
    Map voxel coordinates to world coordinates with a 4x4 affine.

    Args:
        points: (N, 3) array of voxel coordinates
        affine: (4, 4) homogeneous transform whose last row is [0, 0, 0, 1]
    Returns:
        (N, 3) array of world coordinates
    '''
    assert points.ndim == 2 and points.shape[1] == 3
    assert affine.shape == (4, 4) and np.allclose(affine[3], [0,0,0,1])
    n_points = points.shape[0]
    # append a unit homogeneous coordinate to every point
    homogeneous = np.column_stack((points, np.ones(n_points)))
    world = (affine @ homogeneous.T).T
    return world[:,:3]

def world_to_voxel_coords(points, affine):
    '''
    Map world coordinates back to voxel coordinates (inverse of the affine).

    Args:
        points: (N, 3) array of world coordinates
        affine: (4, 4) voxel-to-world transform whose last row is [0, 0, 0, 1]
    Returns:
        (N, 3) array of (fractional) voxel coordinates
    '''
    assert points.ndim == 2 and points.shape[1] == 3
    assert affine.shape == (4, 4) and np.allclose(affine[3], [0,0,0,1])
    n_points = points.shape[0]
    homogeneous = np.column_stack((points, np.ones(n_points)))
    # solve affine @ x = p rather than forming the inverse explicitly
    voxel = np.linalg.solve(affine, homogeneous.T).T
    return voxel[:,:3]

# Express the target mesh vertices in the target image's voxel coordinates
# (presumably the mesh is stored in world coordinates — confirm).
points_array = world_to_voxel_coords(target_mesh.points, target_image.affine)
points_array

array([[132.67303467, 101.42856598, 109.7533493 ],
       [129.50520325, 247.43318176, 160.07617188],
       [132.0176239 , 241.45373535, 174.35232544],
       ...,
       [ 75.49620056, 210.52752686, 100.24464417],
       [ 83.56447601, 213.04563904,  96.55213165],
       [124.33358002, 221.4677887 , 154.52536011]], shape=(17314, 3))

In [15]:
def print_tensor_info(*args, **kwargs):
    '''
    Print one line per tensor with its shape, dtype and device.

    Positional tensors are labelled by their index, keyword tensors by
    their keyword name.
    '''
    labelled = list(enumerate(args)) + list(kwargs.items())
    for label, tensor in labelled:
        print(f'{label}: {tensor.shape} {tensor.dtype} {tensor.device}')

device = 'cuda'

# Scalar volumes get a trailing channel axis so they match the (X, Y, Z, C)
# layout expected by interpolate_image; the displacement array is passed through as-is.
image_tensor   = torch.as_tensor(target_array,  dtype=torch.float32, device=device).unsqueeze(-1)
mask_tensor    = torch.as_tensor(t_mask_array,  dtype=torch.int32,   device=device).unsqueeze(-1)
density_tensor = torch.as_tensor(density_array, dtype=torch.float32, device=device).unsqueeze(-1)
disp_tensor    = torch.as_tensor(disp_array,    dtype=torch.float32, device=device)
points_tensor  = torch.as_tensor(points_array,  dtype=torch.float32, device=device)

print_tensor_info(image_tensor, mask_tensor, density_tensor, disp_tensor, points_tensor)

0: torch.Size([368, 368, 269, 1]) torch.float32 cuda:0
1: torch.Size([368, 368, 269, 1]) torch.int32 cuda:0
2: torch.Size([368, 368, 269, 1]) torch.float32 cuda:0
3: torch.Size([368, 368, 269, 3]) torch.float32 cuda:0
4: torch.Size([17314, 3]) torch.float32 cuda:0


In [16]:
import corrfield
import torch.nn.functional as F

def interpolate_image(image, points, mode='bilinear'):
    '''
    Sample an image volume at arbitrary (fractional) voxel locations.

    Args:
        image:  (X, Y, Z, C) input image tensor
        points: (N, 3) tensor of voxel indices (ijk)
        mode:   interpolation mode forwarded to F.grid_sample
    Returns:
        (N, C) tensor of interpolated values
    '''
    # presumably converts voxel indices to grid_sample's normalised
    # coordinates — see corrfield.utils.kpts_pt
    grid = corrfield.utils.kpts_pt(points, image.shape[:3], align_corners=True)
    volume = image.permute(3, 0, 1, 2)[None]    # (1, C, X, Y, Z)
    grid = grid[None, None, None]               # (1, 1, 1, N, 3)
    sampled = F.grid_sample(
        input=volume,
        grid=grid,
        mode=mode,
        padding_mode='border',  # samples outside the volume take border values
        align_corners=True
    )                                           # (1, C, 1, 1, N)
    # drop the three singleton axes and put points first: (N, C)
    return sampled[0, :, 0, 0].transpose(0, 1)

mm_to_m = 1e-3

# Sample the image-derived quantities at the mesh vertices (voxel coordinates).
ct_values    = interpolate_image(image_tensor, points_tensor, mode='bilinear')
rho_values   = interpolate_image(density_tensor, points_tensor, mode='bilinear')
# The displacement is rescaled by mm_to_m (presumably stored in millimetres — confirm).
u_obs_values = interpolate_image(disp_tensor, points_tensor, mode='bilinear') * mm_to_m

# Nearest-neighbour sampling keeps the mask values as integer labels instead of blending them.
mask_values  = interpolate_image(mask_tensor.float(), points_tensor, mode='nearest').int()

print_tensor_info(ct_values, mask_values, rho_values, u_obs_values)

0: torch.Size([17314, 1]) torch.float32 cuda:0
1: torch.Size([17314, 1]) torch.int32 cuda:0
2: torch.Size([17314, 1]) torch.float32 cuda:0
3: torch.Size([17314, 3]) torch.float32 cuda:0


In [17]:
def compute_lame_parameters(E, nu):
    '''
    Convert Young's modulus and Poisson's ratio to the Lame parameters.

    Args:
        E:  Young's modulus (scalar or tensor)
        nu: Poisson's ratio
    Returns:
        (mu, lam): shear modulus and first Lame parameter, in the units of E
    '''
    one_plus_nu = 1 + nu
    shear = E / (2*one_plus_nu)
    first = E * nu / (one_plus_nu*(1 - 2*nu))
    return shear, first

# Homogeneous initial Young's modulus, one value per mesh vertex.
E_values = torch.ones_like(rho_values) * 3e3 # kPa
E_values.requires_grad = True

# NOTE(review): E is annotated in kPa, while elsewhere in this notebook the mesh
# is scaled by mm_to_m and g is 9.81 (presumably m/s^2) — confirm that the
# stiffness and body-force units are consistent.
mu_values, lam_values = compute_lame_parameters(E_values, nu=0.4)

print_tensor_info(E_values, mu_values, lam_values)

0: torch.Size([17314, 1]) torch.float32 cuda:0
1: torch.Size([17314, 1]) torch.float32 cuda:0
2: torch.Size([17314, 1]) torch.float32 cuda:0


In [18]:
# Attach the vertex-sampled quantities to the mesh for visualisation.
pv_mesh = pv.from_meshio(target_mesh)
pv_mesh.point_data['CT']         = ct_values.detach().cpu().numpy()
pv_mesh.point_data['rho']        = rho_values.detach().cpu().numpy()
pv_mesh.point_data['u_obs']      = u_obs_values.detach().cpu().numpy()
# per-vertex displacement magnitude
pv_mesh.point_data['u_obs_norm'] = np.linalg.norm(pv_mesh.point_data['u_obs'], axis=1)
pv_mesh

Header,Data Arrays
"UnstructuredGridInformation N Cells88042 N Points17314 X Bounds-1.421e+02, 9.507e+01 Y Bounds7.690e+01, 2.581e+02 Z Bounds-2.651e+02, -5.242e+01 N Arrays5",NameFieldTypeN CompMinMax CTPointsfloat321-9.898e+026.840e+02 rhoPointsfloat3211.751e+011.684e+03 u_obsPointsfloat323-3.677e-022.145e-02 u_obs_normPointsfloat3211.065e-033.969e-02 labelCellsint6411.000e+007.000e+00

UnstructuredGrid,Information
N Cells,88042
N Points,17314
X Bounds,"-1.421e+02, 9.507e+01"
Y Bounds,"7.690e+01, 2.581e+02"
Z Bounds,"-2.651e+02, -5.242e+01"
N Arrays,5

Name,Field,Type,N Comp,Min,Max
CT,Points,float32,1,-989.8,684.0
rho,Points,float32,1,17.51,1684.0
u_obs,Points,float32,3,-0.03677,0.02145
u_obs_norm,Points,float32,1,0.001065,0.03969
label,Cells,int64,1,1.0,7.0


In [19]:
# Show the cell labels: labels 1-5 semi-transparent, labels 6-7 opaque.
p = pv.Plotter()
p.add_mesh(
    pv_mesh.threshold((1, 5), scalars='label'),
    scalars='label',
    cmap='Set1',
    clim=(0, 8),
    opacity=0.5,
)
p.add_mesh(
    pv_mesh.threshold((6, 7), scalars='label'),
    scalars='label',
    cmap='Set1',
    clim=(0, 8),
    opacity=1.0
)
# depth peeling so the overlapping translucent surfaces composite correctly
p.enable_depth_peeling(10)
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d77cac7190_1" src="https://ondemand.bridges2.p…

In [20]:
import matplotlib.colors

# White -> blue -> black colormap for the density rendering.
# Note: this rebinds the global `cmap` defined for the CT volume above.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    name='density',
    colors=[(1,1,1), (0,0,1), (0,0,0)]
)

p = pv.Plotter()
p.add_volume(
    pv_mesh,
    scalars='rho',
    opacity=[0.0, 1/8, 1/2],
    cmap=cmap,
    clim=(0, 2000)
)
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d77d7d14d0_2" src="https://ondemand.bridges2.p…

In [21]:
# Observed displacement as arrows at the mesh vertices, coloured by magnitude.
p = pv.Plotter()
# factor=1e3 undoes the mm_to_m scaling so the arrows are drawn in the mesh's own units
# (assumes the mesh coordinates are in millimetres — TODO confirm)
arrows = pv_mesh.glyph(scale='u_obs', orient='u_obs', factor=1e3)
p.add_mesh(arrows, scalars='u_obs_norm', cmap='jet', clim=(0, 0.05))
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d77ca95550_3" src="https://ondemand.bridges2.p…

# NVIDIA warp.fem

In [None]:
import warp as wp
import warp.fem
import warp.optim.linear
wp.init()

In [None]:
# Tetrahedral mesh built from the target mesh, with vertex positions rescaled by mm_to_m.
geometry = wp.fem.Tetmesh(
    wp.array(target_mesh.cells_dict['tetra'], dtype=wp.int32),
    wp.array(target_mesh.points * mm_to_m,    dtype=wp.vec3f)
)

# integration domains
domain = wp.fem.Cells(geometry)
boundary = wp.fem.BoundarySides(geometry)

# function spaces
# S: degree-1 scalar space (material fields); V: degree-1 vector space (displacements)
S = wp.fem.make_polynomial_space(geometry, degree=1, dtype=wp.float32)
V = wp.fem.make_polynomial_space(geometry, degree=1, dtype=wp.vec3f)

# trial and test functions
u_trial = wp.fem.make_trial(V, domain=domain)
v_test  = wp.fem.make_test(V, domain=domain)

# boundary-restricted counterparts, used for the boundary condition terms
ub_trial = wp.fem.make_trial(V, domain=boundary)
vb_test  = wp.fem.make_test(V, domain=boundary)

Physical model:
\begin{align}
    \nabla \cdot \sigma + \rho \mathbf{g} &= 0 \\
    \sigma &= 2 \mu \epsilon + \lambda \operatorname{tr}(\epsilon) \mathbf{I} \\
    \epsilon &= \tfrac{1}{2} \left( \nabla \mathbf{u} + \nabla \mathbf{u}^\top \right)
\end{align}
Where:
\begin{align}
    \sigma: \Omega &\to \mathbb{R}^{3 \times 3} & \rho: \Omega &\to \mathbb{R} & \mathbf{g} &\in \mathbb{R}^3 \\
    \epsilon: \Omega &\to \mathbb{R}^{3 \times 3} & \mu: \Omega &\to \mathbb{R} & \lambda: \Omega &\to \mathbb{R} \\
    \mathbf{u}: \Omega &\to \mathbb{R}^3 \\
\end{align}

In [None]:
# physical fields and constants
u_obs_field = V.make_field()    # observed displacement
u_sim_field = V.make_field()    # simulated displacement (solution of the forward problem)
r_field     = V.make_field()    # PDE residual, used by the adjoint computation

mu_field  = S.make_field()
lam_field = S.make_field()
rho_field = S.make_field()

g = wp.vec3f([0, 0, -9.81])     # constant vector passed to the body-force integrand
I = wp.diag(wp.vec3f(1.0))      # 3x3 identity, read as a global by pde_bilinear_form

In [25]:
# assign dof values
# The torch tensors are wrapped as Warp arrays via wp.from_torch.
mu_field.dof_values    = wp.from_torch(mu_values.contiguous(), dtype=wp.float32, requires_grad=True)
lam_field.dof_values   = wp.from_torch(lam_values.contiguous(), dtype=wp.float32, requires_grad=True)

rho_field.dof_values   = wp.from_torch(rho_values.contiguous(), dtype=wp.float32, requires_grad=True)
u_obs_field.dof_values = wp.from_torch(u_obs_values.contiguous(), dtype=wp.vec3f, requires_grad=True)

# start the solver from a zero displacement
u_sim_field.dof_values.zero_()
u_sim_field.dof_values.requires_grad = True

r_field.dof_values.zero_()
r_field.dof_values.requires_grad = True

  if t.grad is not None:


Weak formulation:
\begin{align}
    \nabla \cdot \sigma + \rho \mathbf{g} &= 0 \\
    -\int_\Omega \left( \nabla \cdot \sigma \right) \cdot \mathbf{v} \ d\mathbf{x} &= \int_\Omega \rho \mathbf{g} \cdot \mathbf{v} \ d\mathbf{x} \\
    \int_\Omega \sigma(\mathbf{u}) : \nabla \mathbf{v} \ d\mathbf{x} &= \int_\Omega \rho \mathbf{g}\cdot \mathbf{v} \ d\mathbf{x} + \int_{\partial \Omega} \left(\sigma \cdot \mathbf{n} \right) \cdot \mathbf{v} \ d\mathbf{s} \\
    \int_\Omega \sigma(\mathbf{u}) : \epsilon(\mathbf{v}) \ d\mathbf{x} &= \int_\Omega \rho \mathbf{g}\cdot \mathbf{v} \ d\mathbf{x} + \int_{\partial \Omega} \mathbf{t} \cdot \mathbf{v} \ d\mathbf{s} \\
    \int_\Omega a(\mathbf{u},\mathbf{v}) \ d\mathbf{x} &= \int_\Omega L(\mathbf{v}) \ d\mathbf{x} \\
    \ldots \\
    \mathbf{K} \mathbf{u} &= \mathbf{f}
\end{align}

In [26]:
@wp.fem.integrand
def pde_bilinear_form(
    s: wp.fem.Sample,
    u: wp.fem.Field,
    v: wp.fem.Field,
    mu: wp.fem.Field,
    lam: wp.fem.Field
):
    '''
    Integrand of the linear-elasticity bilinear form, sigma(u) : eps(v),
    with sigma = 2 mu eps + lam tr(eps) I.

    Args:
        s:   sample point
        u:   trial (or solution) displacement field
        v:   test displacement field
        mu:  shear modulus field
        lam: first Lame parameter field
    '''
    strain_u = wp.fem.D(u, s)   # symmetric gradient
    strain_v = wp.fem.D(v, s)
    # volumetric part; `I` is the module-level identity matrix
    volumetric = lam(s)*wp.fem.div(u, s)*I
    stress_u = 2.0*mu(s)*strain_u + volumetric
    return wp.ddot(stress_u, strain_v)

@wp.fem.integrand
def pde_linear_form(
    s: wp.fem.Sample,
    v: wp.fem.Field,
    rho: wp.fem.Field,
    g: wp.vec3
):
    '''
    Integrand of the body-force linear form, rho * (g . v).

    Args:
        s:   sample point
        v:   test displacement field
        rho: density field
        g:   constant acceleration vector
    '''
    return rho(s) * wp.dot(g, v(s))


In [27]:
%%time
# assemble linear system
# K: stiffness operator from the bilinear form (trial x test over the cells)
K = wp.fem.integrate(
    pde_bilinear_form,
    fields={'u': u_trial, 'v': v_test, 'mu': mu_field, 'lam': lam_field},
    domain=domain,
    output_dtype=wp.float32
)
# f: load vector from the body-force linear form
f = wp.fem.integrate(
    pde_linear_form,
    fields={'v': v_test, 'rho': rho_field},
    values={'g': g},
    domain=domain,
    output_dtype=wp.vec3f
)
K, f

Module __main__.pde_bilinear_form__itg_f4f8_uTrialFieldTetmesh_T_39b2334d c0290b1 load on device 'cuda:0' took 2.99 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_bilinear_kernel_fn_TrialFieldTetmesh_Tet_P1_Tet_P1__e0684a98 0ef8876 load on device 'cuda:0' took 6.09 ms  (cached)
Module warp.sparse fec1336 load on device 'cuda:0' took 2.97 ms  (cached)
Module __main__.pde_linear_form__itg_f4f8_vTestFieldTetmesh_Te_a4f0b61c eee9421 load on device 'cuda:0' took 2.70 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_linear_kernel_fn_TestFieldTetmesh_Tet_P1_Tet_P1_v_3827beed e9d3cb7 load on device 'cuda:0' took 2.70 ms  (cached)
CPU times: user 97.3 ms, sys: 14.2 ms, total: 112 ms
Wall time: 198 ms


(BsrMatrix_float32_3_3(
 	nrow=17314,
 	ncol=17314,
 	nnz=1408672,
 	offsets=array(shape=(17315,), dtype=int32),
 	columns=array(shape=(1408672,), dtype=int32),
 	values=array(shape=(1408672,), dtype=mat33(f)),
 ),
 array(shape=(17314,), dtype=vec3f))

In [28]:
@wp.fem.integrand
def dbc_form(s: wp.fem.Sample, u: wp.fem.Field, v: wp.fem.Field, alpha: float):
    '''
    Penalty integrand alpha * (u . v), used on the boundary for both the
    operator (u = trial function) and the right-hand side (u = observed field).
    '''
    return alpha * wp.dot(u(s), v(s))

In [29]:
%%time
# apply boundary conditions
# Penalty weight: larger values pull u_sim more strongly towards u_obs on the boundary.
alpha = 1e6

# NOTE: both integrals below use add=True on the existing K and f, so
# re-running this cell adds the penalty terms again; re-assemble K and f first.
wp.fem.integrate(
    dbc_form,
    fields={'u': ub_trial, 'v': vb_test},
    values={'alpha': alpha},
    domain=boundary,
    output=K,
    add=True
)
wp.fem.integrate(
    dbc_form,
    fields={'u': u_obs_field.trace(), 'v': vb_test},
    values={'alpha': alpha},
    domain=boundary,
    output=f,
    add=True
)
K, f

Module __main__.dbc_form__itg_f4f8_uTrialFieldTetmesh_T_6c12eb67 215eaae load on device 'cuda:0' took 3.10 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_bilinear_kernel_fn_TrialFieldTetmesh_Tet_P1_Tet_P1__6b9704c0 46a0144 load on device 'cuda:0' took 3.04 ms  (cached)
Module warp.sparse 78297c8 load on device 'cuda:0' took 2.69 ms  (cached)
Module __main__.dbc_form__itg_f4f8_uNodalFieldTrace_Te_b398abbc 2d3caea load on device 'cuda:0' took 2.73 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_linear_kernel_fn_TestFieldTetmesh_Tet_P1_Tet_P1_v_79ccaa80 0a3d5e7 load on device 'cuda:0' took 2.73 ms  (cached)
CPU times: user 315 ms, sys: 13.1 ms, total: 328 ms
Wall time: 334 ms


(BsrMatrix_float32_3_3(
 	nrow=17314,
 	ncol=17314,
 	nnz=2175136,
 	offsets=array(shape=(17315,), dtype=int32),
 	columns=array(shape=(2175136,), dtype=int32),
 	values=array(shape=(2175136,), dtype=mat33(f)),
 ),
 array(shape=(17314,), dtype=vec3f))

In [32]:
%%time
# solve linear system
# Conjugate gradient with a diagonal preconditioner; the current contents of
# u_sim_field.dof_values are passed as x and receive the solution.
warp.optim.linear.cg(
    A=K,
    b=f,
    x=u_sim_field.dof_values,
    M=warp.optim.linear.preconditioner(K, ptype='diag'),
    tol=1e-4
)

Module warp.optim.linear 5c44d62 load on device 'cuda:0' took 3.30 ms  (cached)
CPU times: user 5.99 ms, sys: 20.8 ms, total: 26.8 ms
Wall time: 27.4 ms


(0, 0.0052820605194991025, 0.005571819810573425)

In [33]:
@wp.fem.integrand
def error_form(s: wp.fem.Sample, y_pred: wp.fem.Field, y_true: wp.fem.Field):
    '''Integrand of the squared error |y_pred - y_true|^2.'''
    diff = y_pred(s) - y_true(s)
    return wp.dot(diff, diff)

@wp.fem.integrand
def norm_form(s: wp.fem.Sample, y_true: wp.fem.Field):
    '''Integrand of the squared magnitude |y_true|^2.'''
    value = y_true(s)
    return wp.dot(value, value)

@wp.fem.integrand
def volume_form(s: wp.fem.Sample):
    '''Constant integrand 1; integrating it gives the measure of the domain.'''
    return 1.0


In [34]:
%%time
# compute domain error

# integral of |u_sim - u_obs|^2 over the cells
error = wp.fem.integrate(
    error_form,
    fields={'y_pred': u_sim_field, 'y_true': u_obs_field},
    domain=domain
)
# integral of |u_obs|^2 over the cells
norm = wp.fem.integrate(
    norm_form,
    fields={'y_true': u_obs_field},
    domain=domain
)
vol = wp.fem.integrate(volume_form, domain=domain)

# guards the divisions below; also reused by later cells
eps = 1e-12

# RMS error normalises by the domain measure, relative error by the norm of u_obs
rms_error = wp.sqrt(error / (vol + eps))
rel_error = wp.sqrt(error / (norm + eps))

print(rms_error / mm_to_m, 'mm')
print(rel_error * 100, '%')

Module __main__.error_form__itg_f8f8_y_predNodalField_Tet_5500e500 4d44c8b load on device 'cuda:0' took 2.93 ms  (cached)
Module __main__.norm_form__itg_f8f8_y_trueNodalField_Tet_a4e1d7eb 02ba242 load on device 'cuda:0' took 3.10 ms  (cached)
Module __main__.volume_form__itg_f8f8_RegularQuadrature_Te_bccf3f8a e39acea load on device 'cuda:0' took 3.02 ms  (cached)
2.2916406668152423 mm
12.734242210346828 %
CPU times: user 25.2 ms, sys: 8.7 ms, total: 33.9 ms
Wall time: 37.6 ms


In [35]:
# compute boundary error
# Same metrics as the domain error, but integrated over the boundary sides
# using the traces of the fields. Relies on `eps` from the domain-error cell.

error_b = wp.fem.integrate(
    error_form,
    fields={'y_pred': u_sim_field.trace(), 'y_true': u_obs_field.trace()},
    domain=boundary
)
norm_b = wp.fem.integrate(
    norm_form,
    fields={'y_true': u_obs_field.trace()},
    domain=boundary
)
vol_b = wp.fem.integrate(volume_form, domain=boundary)

rms_error_b = wp.sqrt(error_b / (vol_b + eps))
rel_error_b = wp.sqrt(error_b / (norm_b + eps))

print(rms_error_b / mm_to_m, 'mm')
print(rel_error_b * 100, '%')

Module __main__.error_form__itg_f8f8_y_predNodalFieldTra_29f0cf49 abd8076 load on device 'cuda:0' took 3.03 ms  (cached)
Module __main__.norm_form__itg_f8f8_y_trueNodalFieldTra_2613b841 0a77733 load on device 'cuda:0' took 2.72 ms  (cached)
Module __main__.volume_form__itg_f8f8_RegularQuadrature_Te_28285060 2422843 load on device 'cuda:0' took 2.78 ms  (cached)
1.444328928974238 mm
7.6163842055649065 %


In [36]:
# Copy the simulated displacement to the mesh and derive the per-vertex error.
pv_mesh.point_data['u_sim']      = u_sim_field.dof_values.numpy()
pv_mesh.point_data['u_sim_norm'] = np.linalg.norm(pv_mesh.point_data['u_sim'], axis=1)

# error is observed minus simulated
pv_mesh.point_data['u_err']      = pv_mesh.point_data['u_obs'] - pv_mesh.point_data['u_sim']
pv_mesh.point_data['u_err_norm'] = np.linalg.norm(pv_mesh.point_data['u_err'], axis=1)

pv_mesh

Header,Data Arrays
"UnstructuredGridInformation N Cells88042 N Points17314 X Bounds-1.421e+02, 9.507e+01 Y Bounds7.690e+01, 2.581e+02 Z Bounds-2.651e+02, -5.242e+01 N Arrays9",NameFieldTypeN CompMinMax CTPointsfloat321-9.898e+026.840e+02 rhoPointsfloat3211.751e+011.684e+03 u_obsPointsfloat323-3.677e-022.145e-02 u_obs_normPointsfloat3211.065e-033.969e-02 u_simPointsfloat323-3.590e-021.955e-02 u_sim_normPointsfloat3212.034e-043.960e-02 u_errPointsfloat323-4.797e-036.498e-03 u_err_normPointsfloat3212.125e-056.754e-03 labelCellsint6411.000e+007.000e+00

UnstructuredGrid,Information
N Cells,88042
N Points,17314
X Bounds,"-1.421e+02, 9.507e+01"
Y Bounds,"7.690e+01, 2.581e+02"
Z Bounds,"-2.651e+02, -5.242e+01"
N Arrays,9

Name,Field,Type,N Comp,Min,Max
CT,Points,float32,1,-989.8,684.0
rho,Points,float32,1,17.51,1684.0
u_obs,Points,float32,3,-0.03677,0.02145
u_obs_norm,Points,float32,1,0.001065,0.03969
u_sim,Points,float32,3,-0.0359,0.01955
u_sim_norm,Points,float32,1,0.0002034,0.0396
u_err,Points,float32,3,-0.004797,0.006498
u_err_norm,Points,float32,1,2.125e-05,0.006754
label,Cells,int64,1,1.0,7.0


In [37]:
# Simulated displacement as arrows, with the same glyph scaling and colour range as the u_obs plot.
p = pv.Plotter()
arrows = pv_mesh.glyph(scale='u_sim', orient='u_sim', factor=1e3)
p.add_mesh(arrows, scalars='u_sim_norm', cmap='jet', clim=(0, 0.05))
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d75f575310_4" src="https://ondemand.bridges2.p…

In [38]:
# Displacement error (u_obs - u_sim) as arrows, on the same colour range for direct comparison.
p = pv.Plotter()
arrows = pv_mesh.glyph(scale='u_err', orient='u_err', factor=1e3)
p.add_mesh(arrows, scalars='u_err_norm', cmap='jet', clim=(0, 0.05))
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d75e79c310_5" src="https://ondemand.bridges2.p…

# Inverse problem

PDE constraint:
$$
\begin{align}
    \mathbf{K}(\theta) \mathbf{u} &= \mathbf{f}(\theta) \\
    \mathbf{r}(\mathbf{u}, \theta) &= \mathbf{f}(\theta) - \mathbf{K}(\theta)\mathbf{u} \\
    \mathbf{r}(\mathbf{u}, \theta) &= 0
\end{align}
$$

Implicit differentiation:
$$
\begin{align}
    \frac{\partial \bf r}{\partial \bf u} \frac{d \bf u}{d \theta} + \frac{\partial \bf r}{\partial \theta} &= 0 \\
    \frac{d \bf u}{d \theta} &= -\left( \frac{\partial \bf r}{\partial \bf u} \right)^{-1}\frac{\partial \bf r}{\partial \theta} 
\end{align}
$$

Loss gradient:
$$
\begin{align}
    \frac{\partial L}{\partial \theta} &= \frac{\partial L}{\partial \bf u} \frac{d \bf u}{d \theta} \\
    \frac{\partial L}{\partial \theta} &= -\frac{\partial L}{\partial \bf u} \left( \frac{\partial \bf r}{\partial \bf u} \right)^{-1}\frac{\partial \bf r}{\partial \theta} \\
    \frac{\partial L}{\partial \theta} &= \frac{\partial L}{\partial \bf r} \frac{\partial \bf r}{\partial \theta} \\
\end{align}
$$

Adjoint method:
$$
\begin{align}
    \frac{\partial L}{\partial \bf r} &= -\frac{\partial L}{\partial \bf u} \left( \frac{\partial \bf r}{\partial \bf u} \right)^{-1} \\
    \frac{\partial L}{\partial \bf r} \left( \frac{\partial \bf r}{\partial \bf u} \right) &= -\frac{\partial L}{\partial \bf u} \\
    \left( \frac{\partial \bf r}{\partial \bf u} \right)^\top \frac{\partial L}{\partial \bf r}  &= -\frac{\partial L}{\partial \bf u} \\
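    \mathbf{K}^\top \Psi &= \frac{\partial L}{\partial \bf u}, \qquad \text{where } \Psi := \frac{\partial L}{\partial \bf r} \text{ and } \frac{\partial \bf r}{\partial \bf u} = -\mathbf{K}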
\end{align}
$$

In [39]:
@wp.fem.integrand
def pde_residual_form(
    s: wp.fem.Sample,
    u: wp.fem.Field,
    v: wp.fem.Field,
    mu: wp.fem.Field,
    lam: wp.fem.Field,
    rho: wp.fem.Field,
    g: wp.vec3
):
    '''
    Integrand of the PDE residual tested against v: body-force linear form
    minus elasticity bilinear form, evaluated at the current displacement u.
    '''
    stiffness_term = pde_bilinear_form(s, u, v, mu, lam)
    load_term = pde_linear_form(s, v, rho, g)
    return load_term - stiffness_term

@wp.fem.integrand
def dbc_residual_form(
    s: wp.fem.Sample,
    u_sim: wp.fem.Field,
    u_obs: wp.fem.Field,
    v: wp.fem.Field,
    alpha: float
):
    '''
    Integrand of the penalty boundary residual tested against v:
    alpha * (u_obs . v) - alpha * (u_sim . v).
    '''
    simulated_term = dbc_form(s, u_sim, v, alpha)
    observed_term = dbc_form(s, u_obs, v, alpha)
    return observed_term - simulated_term


In [40]:
%%time
tape = wp.Tape()

# Start from a clean residual buffer that can receive adjoints.
r_field.dof_values.zero_()
r_field.dof_values.requires_grad = True

# Record the residual r = f - K u (PDE part, then penalty boundary part added
# on top) so the tape can propagate the gradient of r to mu, lam and rho.
with tape:
    wp.fem.integrate(
        pde_residual_form,
        fields={
            'u': u_sim_field,
            'v': v_test,
            'mu': mu_field,
            'lam': lam_field,
            'rho': rho_field
        },
        values={'g': g},
        domain=domain,
        output=r_field.dof_values
    )
    wp.fem.integrate(
        dbc_residual_form,
        fields={
            'u_sim': u_sim_field.trace(),
            'u_obs': u_obs_field.trace(),
            'v': vb_test
        },
        values={'alpha': alpha},
        domain=boundary,
        output=r_field.dof_values,
        add=True
    )

def solve_adjoint_system():
    # Adjoint solve: the gradient w.r.t. u_sim is the right-hand side and the
    # solution is written into the gradient of the residual. K itself is used
    # as the operator (no explicit transpose) — presumably relying on K being symmetric.
    return wp.optim.linear.cg(
        A=K,
        b=u_sim_field.dof_values.grad,
        x=r_field.dof_values.grad,
        M=warp.optim.linear.preconditioner(K, ptype='diag'),
        tol=1e-4
    )
    
# Register the adjoint solve as a custom backward step. It is recorded after
# the residual and before the loss, so in the reverse pass it runs in between.
tape.record_func(
    backward=solve_adjoint_system,
    arrays=[r_field.dof_values, u_sim_field.dof_values]
)

numer = wp.empty(1, dtype=wp.float32, requires_grad=True)
denom = wp.empty(1, dtype=wp.float32, requires_grad=True)

with tape:
    wp.fem.integrate(
        error_form,
        fields={'y_pred': u_sim_field, 'y_true': u_obs_field},
        domain=domain,
        output=numer
    )
    wp.fem.integrate(
        norm_form,
        fields={'y_true': u_obs_field},
        domain=domain,
        output=denom
    )
    # despite the name, this is the relative error (normalised by the norm of u_obs)
    rmse = (numer / (denom + eps))**0.5

# Seed the loss gradient with 1 and run the reverse pass.
rmse.grad = wp.ones(1, dtype=wp.float32)
tape.backward()

# magnitude of the loss gradient w.r.t. the shear modulus dofs
wp.to_torch(mu_field.dof_values.grad).norm()

Module __main__.pde_residual_form__itg_f4f8_uNodalField_Tetmesh__6f525958 45551bf load on device 'cuda:0' took 3.32 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_linear_kernel_fn_TestFieldTetmesh_Tet_P1_Tet_P1_v_62372c9f 6a2088d load on device 'cuda:0' took 3.42 ms  (cached)
Module __main__.dbc_residual_form__itg_f4f8_u_simNodalFieldTrac_0848140d 004b428 load on device 'cuda:0' took 3.13 ms  (cached)
Module warp.fem.field.virtual.dyn.dispatch_linear_kernel_fn_TestFieldTetmesh_Tet_P1_Tet_P1_v_ff8cd726 b22b522 load on device 'cuda:0' took 2.90 ms  (cached)
Module __main__.error_form__itg_f4f8_y_predNodalField_Tet_be3a9c9d e37768c load on device 'cuda:0' took 2.79 ms  (cached)
Module warp.utils 5e751fc load on device 'cuda:0' took 2.67 ms  (cached)
Module __main__.norm_form__itg_f4f8_y_trueNodalField_Tet_c8e884cb 20877cc load on device 'cuda:0' took 2.96 ms  (cached)
Module map_add 11f3678 load on device 'cuda:0' took 2.39 ms  (cached)
Module map_div 97723a9 load on device 'cuda

tensor(9.3631e-07, device='cuda:0')

In [41]:
# Reset the gradients accumulated by the backward pass above before the arrays are reused.
tape.zero()

In [42]:
class WarpFEMCore:
    '''
    Holds the warp.fem state for the linear-elasticity problem on one mesh:
    integration domains, function spaces, fields, the assembled operators and
    right-hand sides, and the forward / adjoint linear solves.

    Methods mutate attributes on the instance and are expected to be called
    in a dependency-respecting order (assign values -> assemble -> apply
    boundary condition -> solve -> residual / error).
    '''

    def __init__(self, geometry, alpha=1e6, tol=1e-4, eps=1e-12):
        '''
        Args:
            geometry: warp.fem geometry the problem is posed on
            alpha: penalty weight of the boundary condition terms
            tol: tolerance passed to the conjugate gradient solves
            eps: small constant added to the error denominator
        '''
        self.geometry = geometry
    
        # integration domains
        self.domain = wp.fem.Cells(geometry)
        self.boundary = wp.fem.BoundarySides(geometry)

        # function spaces
        self.S = wp.fem.make_polynomial_space(geometry, degree=1, dtype=wp.float32)
        self.V = wp.fem.make_polynomial_space(geometry, degree=1, dtype=wp.vec3f)

        # trial and test functions
        self.u_trial = wp.fem.make_trial(self.V, domain=self.domain)
        self.v_test  = wp.fem.make_test(self.V, domain=self.domain)

        self.ub_trial = wp.fem.make_trial(self.V, domain=self.boundary)
        self.vb_test  = wp.fem.make_test(self.V, domain=self.boundary)

        # physical fields and constants
        self.u_obs_field = self.V.make_field()
        self.u_sim_field = self.V.make_field()
        self.u_sim_field.dof_values.requires_grad = True

        self.r_field = self.V.make_field()
        self.r_field.dof_values.requires_grad = True

        self.mu_field  = self.S.make_field()
        self.lam_field = self.S.make_field()
        self.rho_field = self.S.make_field()

        self.g = wp.vec3f([0, 0, -9.81])
        # NOTE: pde_bilinear_form reads the module-level `I`, not this attribute.
        self.I = wp.diag(wp.vec3f(1.0))

        # hyperparameters
        self.alpha = alpha
        self.tol = tol
        self.eps = eps

    def assign_fixed_values(self, rho, u_obs):
        '''Set the density and observed-displacement dofs from torch tensors.'''
        self.rho_field.dof_values   = as_warp_array(rho, dtype=wp.float32, requires_grad=True)
        self.u_obs_field.dof_values = as_warp_array(u_obs, dtype=wp.vec3f, requires_grad=True)

    def assign_param_values(self, mu, lam):
        '''Set the Lame parameter dofs from torch tensors (detached by as_warp_array).'''
        self.mu_field.dof_values  = as_warp_array(mu, dtype=wp.float32, requires_grad=True)
        self.lam_field.dof_values = as_warp_array(lam, dtype=wp.float32, requires_grad=True)

    def assemble_pde_operator(self):
        '''Assemble the elasticity operator into self.K_pde; depends on mu and lam.'''
        self.K_pde = wp.fem.integrate(
            pde_bilinear_form,
            fields={
                'u': self.u_trial,
                'v': self.v_test,
                'mu': self.mu_field,
                'lam': self.lam_field
            },
            domain=self.domain,
            output_dtype=wp.float32
        )

    def assemble_pde_rhs(self):
        '''Assemble the body-force right-hand side into self.f_pde; depends on rho.'''
        self.f_pde = wp.fem.integrate(
            pde_linear_form,
            fields={'v': self.v_test, 'rho': self.rho_field},
            values={'g': self.g},
            domain=self.domain,
            output_dtype=wp.vec3f
        )

    def assemble_dbc_operator(self):
        '''Assemble the boundary penalty operator into self.K_bc.'''
        self.K_bc = wp.fem.integrate(
            dbc_form,
            fields={'u': self.ub_trial, 'v': self.vb_test},
            values={'alpha': self.alpha},
            domain=self.boundary,
            output_dtype=wp.float32
        )

    def assemble_dbc_rhs(self):
        '''Assemble the boundary penalty right-hand side into self.f_bc; depends on u_obs.'''
        self.f_bc = wp.fem.integrate(
            dbc_form,
            fields={'u': self.u_obs_field.trace(), 'v': self.vb_test},
            values={'alpha': self.alpha},
            domain=self.boundary,
            output_dtype=wp.vec3f
        )

    def apply_boundary_condition(self):
        '''
        Combine the PDE and boundary terms into the system (self.K, self.f)
        and build the diagonal preconditioner self.M. Requires all four
        assemble_* methods to have been called.
        '''
        self.K = self.K_pde + self.K_bc
        self.f = self.f_pde + self.f_bc
        self.M = wp.optim.linear.preconditioner(self.K, ptype='diag')

    def solve_forward_system(self):
        '''
        Solve K u = f by conjugate gradient. The current u_sim dofs are passed
        as x and receive the solution. Returns the solver's result unchanged.
        '''
        return wp.optim.linear.cg(
            A=self.K,
            b=self.f,
            x=self.u_sim_field.dof_values,
            M=self.M,
            tol=self.tol
        )

    def solve_adjoint_system(self):
        '''
        Adjoint solve: the gradient of u_sim is the right-hand side and the
        result is written into the gradient of the residual field.
        Uses self.K without transposing — presumably relies on K being symmetric.
        '''
        return wp.optim.linear.cg(
            A=self.K,
            b=self.u_sim_field.dof_values.grad,
            x=self.r_field.dof_values.grad,
            M=self.M,
            tol=self.tol
        )

    def compute_residual(self):
        '''
        Evaluate the residual of the full system at the current u_sim into
        self.r_field.dof_values.
        '''
        # PDE part: written into the residual buffer (no add=True here)
        wp.fem.integrate(
            pde_residual_form,
            fields={
                'u': self.u_sim_field,
                'v': self.v_test,
                'mu': self.mu_field,
                'lam': self.lam_field,
                'rho': self.rho_field
            },
            values={'g': self.g},
            domain=self.domain,
            output=self.r_field.dof_values
        )
        # boundary penalty part: accumulated on top (add=True)
        wp.fem.integrate(
            dbc_residual_form,
            fields={
                'u_sim': self.u_sim_field.trace(),
                'u_obs': self.u_obs_field.trace(),
                'v': self.vb_test
            },
            values={'alpha': self.alpha},
            domain=self.boundary,
            output=self.r_field.dof_values,
            add=True
        )

    def compute_error(self, relative=False):
        '''
        Squared error between u_sim and u_obs integrated over the domain,
        divided by either the integrated squared norm of u_obs
        (relative=True) or the domain measure (relative=False).

        Returns:
            the quotient numer / (denom + eps) of two 1-element Warp arrays;
            no square root is applied here
        '''
        numer = wp.empty(1, dtype=wp.float32, requires_grad=True)
        denom = wp.empty(1, dtype=wp.float32, requires_grad=True)

        wp.fem.integrate(
            error_form,
            fields={'y_pred': self.u_sim_field, 'y_true': self.u_obs_field},
            domain=self.domain,
            output=numer
        )
        if relative:
            wp.fem.integrate(
                norm_form,
                fields={'y_true': self.u_obs_field},
                domain=self.domain,
                output=denom
            )
        else:
            wp.fem.integrate(
                volume_form,
                domain=self.domain,
                output=denom
            )

        return numer / (denom + self.eps)

def as_warp_array(t, **kwargs):
    '''
    Wrap a torch tensor as a Warp array.

    The tensor is made contiguous and detached from the torch autograd graph
    first; extra keyword arguments are forwarded to wp.from_torch.
    '''
    detached = t.contiguous().detach()
    return wp.from_torch(detached, **kwargs)


In [43]:
class WarpFEMModule(torch.nn.Module):
    """``torch.nn.Module`` front end for a ``WarpFEMCore`` solver.

    Construction builds the core, assigns the fixed inputs (``rho`` and
    ``u_obs``), and assembles the pieces that are set up once here: the
    boundary-condition operator, the boundary-condition right-hand side,
    and the PDE right-hand side.
    """

    def __init__(self, geometry, rho, u_obs, **kwargs):
        """Create the core solver and run the one-time assembly steps.

        Args:
            geometry: Passed to ``WarpFEMCore`` as its first argument.
            rho: Passed to ``core.assign_fixed_values`` together with ``u_obs``.
            u_obs: See ``rho``.
            **kwargs: Forwarded unchanged to ``WarpFEMCore``.
        """
        super().__init__()
        core = WarpFEMCore(geometry, **kwargs)
        self.core = core

        core.assign_fixed_values(rho, u_obs)
        core.assemble_dbc_operator()
        core.assemble_dbc_rhs()
        core.assemble_pde_rhs()

    def forward(self, mu, lam):
        """Evaluate ``WarpFEMFn`` on this module's core with ``mu`` and ``lam``."""
        core = self.core
        return WarpFEMFn.apply(core, mu, lam)


class WarpFEMFn(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Warp FEM solver.

    ``forward`` solves the forward problem and records, on a ``wp.Tape``,
    the residual evaluation, a custom backward step (the adjoint solve), and
    the error evaluation. ``backward`` seeds the error gradient, replays the
    tape, and returns the gradients accumulated on the ``mu`` and ``lam``
    fields.
    """

    @staticmethod
    def forward(ctx, core, mu, lam):
        """Solve the forward system and return the error as a torch tensor.

        Args:
            ctx: Autograd context; stores ``core``, the tape and the error
                array for use in ``backward``.
            core: Solver object providing the assemble/solve/compute methods
                called below.
            mu, lam: Parameter tensors handed to ``core.assign_param_values``.
        """
        ctx.core = core
        core.assign_param_values(mu, lam)
        core.assemble_pde_operator()
        core.apply_boundary_condition()
        core.solve_forward_system()

        # The linear solve above runs outside the tape; only the residual
        # and error evaluations are recorded for differentiation.
        ctx.tape = wp.Tape()
        with ctx.tape:
            core.compute_residual()

        # Registered between the residual and error recordings, so the
        # adjoint solve sits between them in the tape's recorded sequence.
        ctx.tape.record_func(
            backward=core.solve_adjoint_system,
            arrays=[
                core.r_field.dof_values,
                core.u_sim_field.dof_values
            ]
        )
        with ctx.tape:
            ctx.error = core.compute_error()

        return wp.to_torch(ctx.error)

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate ``grad_out`` through the recorded tape.

        Returns:
            ``(None, mu_grad, lam_grad)``: no gradient for ``core``; the
            other two are copies of the field gradient buffers, or ``None``
            for an input that does not require a gradient.
        """
        # Use the shared helper so the incoming gradient is contiguous and
        # detached before Warp wraps it.
        ctx.error.grad = as_warp_array(grad_out)
        ctx.tape.backward()

        needs_mu, needs_lam = ctx.needs_input_grad[1], ctx.needs_input_grad[2]

        # wp.to_torch shares memory with the Warp array. Clone so the tensors
        # handed back to autograd are not altered when Warp later zeroes or
        # re-accumulates these gradient buffers.
        mu_grad = None
        if needs_mu:
            mu_grad = wp.to_torch(ctx.core.mu_field.dof_values.grad).clone()

        lam_grad = None
        if needs_lam:
            lam_grad = wp.to_torch(ctx.core.lam_field.dof_values.grad).clone()

        return (None, mu_grad, lam_grad)


# Build the differentiable FEM module. `alpha` and `tol` are not named
# parameters of WarpFEMModule; they pass through **kwargs to WarpFEMCore.
# NOTE(review): `rho_values` and `u_obs_values` are assigned in the cell
# below this one; unless they are also defined earlier in the notebook,
# this cell fails on a fresh kernel. Confirm the intended cell order.
fem = WarpFEMModule(
    geometry,
    rho=rho_values,
    u_obs=u_obs_values,
    alpha=1e6,
    tol=1e-6,
)
fem

WarpFEMModule()

In [None]:
# Sample the density and displacement images at the mesh points; the
# displacement is scaled by `mm_to_m`.
rho_values = interpolate_image(density_tensor, points_tensor, mode='bilinear')
u_obs_values = interpolate_image(disp_tensor, points_tensor, mode='bilinear') * mm_to_m

# The optimization loop uses E = E0 * exp(theta), so theta = 0 starts every
# point at E0.
# NOTE(review): units of E0 are not stated here; presumably Pa. Confirm.
E0 = 3e3
theta = torch.zeros_like(rho_values).requires_grad_(True)
optim = torch.optim.Adam(params=[theta], lr=1e-3)

In [70]:
%%time
wp.config.quiet = True

for i in range(1000):
    optim.zero_grad()
    E_values = E0 * torch.exp(theta)
    loss = fem.forward(*compute_lame_parameters(E_values, nu=0.4))**0.5 / mm_to_m
    loss.backward()
    optim.step()
    grad_norm = theta.grad.norm()
    print(f'iteration {i} | loss={loss.item()**0.5:.8f} | grad_norm={grad_norm.item():.8e}')
    assert grad_norm > 0


iteration 0 | loss=0.87536290 | grad_norm=8.15776456e-03
iteration 1 | loss=0.87523393 | grad_norm=8.15249700e-03
iteration 2 | loss=0.87510595 | grad_norm=8.14736541e-03
iteration 3 | loss=0.87497728 | grad_norm=8.14212486e-03
iteration 4 | loss=0.87484982 | grad_norm=8.13707057e-03
iteration 5 | loss=0.87472114 | grad_norm=8.13181233e-03
iteration 6 | loss=0.87459350 | grad_norm=8.12661275e-03
iteration 7 | loss=0.87446486 | grad_norm=8.12141225e-03
iteration 8 | loss=0.87433763 | grad_norm=8.11632443e-03
iteration 9 | loss=0.87420993 | grad_norm=8.11118167e-03
iteration 10 | loss=0.87408164 | grad_norm=8.10598303e-03
iteration 11 | loss=0.87395425 | grad_norm=8.10092501e-03
iteration 12 | loss=0.87382667 | grad_norm=8.09576735e-03
iteration 13 | loss=0.87369853 | grad_norm=8.09059292e-03
iteration 14 | loss=0.87357149 | grad_norm=8.08551349e-03
iteration 15 | loss=0.87344293 | grad_norm=8.08033440e-03
iteration 16 | loss=0.87331634 | grad_norm=8.07529222e-03
iteration 17 | loss=0.87

In [71]:
# Store the per-point E values on the mesh for visualization.
# `E_values` is the tensor from the last iteration of the optimization loop,
# computed from theta *before* that iteration's optim.step(), so it lags the
# final theta by one update.
pv_mesh.point_data['E'] = E_values.detach().cpu().numpy()
pv_mesh

Header,Data Arrays
"UnstructuredGridInformation N Cells88042 N Points17314 X Bounds-1.421e+02, 9.507e+01 Y Bounds7.690e+01, 2.581e+02 Z Bounds-2.651e+02, -5.242e+01 N Arrays10",NameFieldTypeN CompMinMax CTPointsfloat321-9.898e+026.840e+02 rhoPointsfloat3211.751e+011.684e+03 u_obsPointsfloat323-3.677e-022.145e-02 u_obs_normPointsfloat3211.065e-033.969e-02 u_simPointsfloat323-3.590e-021.955e-02 u_sim_normPointsfloat3212.034e-043.960e-02 u_errPointsfloat323-4.797e-036.498e-03 u_err_normPointsfloat3212.125e-056.754e-03 EPointsfloat3219.829e+014.506e+05 labelCellsint6411.000e+007.000e+00

UnstructuredGrid,Information
N Cells,88042
N Points,17314
X Bounds,"-1.421e+02, 9.507e+01"
Y Bounds,"7.690e+01, 2.581e+02"
Z Bounds,"-2.651e+02, -5.242e+01"
N Arrays,10

Name,Field,Type,N Comp,Min,Max
CT,Points,float32,1,-989.8,684.0
rho,Points,float32,1,17.51,1684.0
u_obs,Points,float32,3,-0.03677,0.02145
u_obs_norm,Points,float32,1,0.001065,0.03969
u_sim,Points,float32,3,-0.0359,0.01955
u_sim_norm,Points,float32,1,0.0002034,0.0396
u_err,Points,float32,3,-0.004797,0.006498
u_err_norm,Points,float32,1,2.125e-05,0.006754
E,Points,float32,1,98.29,450600.0
label,Cells,int64,1,1.0,7.0


In [69]:
# Volume-render the 'E' point array on the mesh.
# The color range is fixed to (0, 1e4); the mesh summary above lists E up to
# about 4.5e5, so larger values fall outside this range.
p = pv.Plotter()
p.add_volume(pv_mesh, scalars='E', clim=(0, 1e4), cmap='jet', opacity=0.05)
p.show()

Widget(value='<iframe id="pyvista-jupyter_trame__template_P_0x14d5fb3e0350_17" src="https://ondemand.bridges2.…