In [2]:
%matplotlib notebook
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import xarray as xr
import pygalmesh
import meshio
from mpi4py import MPI
import fenics as fe
import fenics_adjoint as fa
import torch
import torch.nn.functional as F
import torch_fenics

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

import sys
sys.path.append('..')
import project

# Load lung CT image data

In [None]:
# Inputs that select which case, phases and mask the cells below operate on.
# NOTE(review): `ee` / `ei` presumably mean end-expiration / end-inspiration — confirm.
case_name = 'Case1Pack'
ee_phase = 50 # also used as the fixed phase when loading displacements
ei_phase = 0
lung_roi = 'lung_combined_mask' # ROI name used to select the lung mask

In [None]:
# Load the selected case at both phases, then its masks and the displacements
# that use `ee_phase` as the fixed phase.
ds = project.imaging.Emory4DCTDataset(
    data_root='../data/Emory-4DCT',
    case_names=[case_name],
    phases=[ei_phase, ee_phase]
)
ds.load_niftis()
# Load the configured ROI rather than repeating the literal, so the later
# `roi=lung_roi` selection always matches what was loaded.
ds.load_masks(roi=[lung_roi])
ds.load_displacements(fixed_phase=ee_phase)

In [None]:
case = ds[0]
resolution = list(case.resolution) # mm
print(resolution)

# anatomical images at each phase
# NOTE(review): presumably CT intensities in Hounsfield units — confirm
ee_anat_image = case.array.sel(phase=ee_phase)
ei_anat_image = case.array.sel(phase=ei_phase)

# density estimate from intensity: rho = rho_water * (1 + intensity / 1000)
rho_water = 1000 # kg / m^3
ee_rho_image = rho_water * (1 + ee_anat_image / 1000)
# each phase's density must come from that phase's own anatomical image
ei_rho_image = rho_water * (1 + ei_anat_image / 1000)

ee_lung_mask = case.mask.sel(phase=ee_phase, roi=lung_roi)
ei_lung_mask = case.mask.sel(phase=ei_phase, roi=lung_roi)

# overlay weights: (mask + 1) / 2 maps mask value 0 to 0.5 and 1 to 1.0
ee_overlay = (ee_lung_mask + 1) / 2
ei_overlay = (ei_lung_mask + 1) / 2

# Scale each displacement component by the voxel size along its axis.
# This builds a new array instead of using `*=`: the in-place form could write
# through to `case.disp` and would rescale the field again on every re-run.
voxel_size = xr.DataArray(resolution, dims=['component'], coords={'component': ['x', 'y', 'z']}) # mm
reg_disp_image = case.disp.sel(phase=ei_phase, fixed_phase=ee_phase) * voxel_size

In [None]:
# Show x-z slices at y index 140 for the anatomical images, the masked
# displacement (component index 2) and the density images.
y_index = 140
plane = dict(x='x', y='z')
anat_style = dict(cmap='Greys_r', vmin=-1000, vmax=1000)
rho_style = dict(cmap='Greys_r', vmin=0, vmax=1000)

project.visual.view(ee_anat_image, **plane, **anat_style).update_index(y=y_index)
project.visual.view(ei_anat_image, **plane, **anat_style).update_index(y=y_index)
project.visual.view(reg_disp_image * ee_overlay, **plane).update_index(y=y_index, component=2)
project.visual.view(ee_rho_image, **plane, **rho_style).update_index(y=y_index)
project.visual.view(ei_rho_image, **plane, **rho_style).update_index(y=y_index)

# Anatomical mesh generation

In [None]:
%%time
# Mesh the lung mask volume with pygalmesh, then drop points that no cell
# references. The last line displays the resulting mesh.
ee_label_volume = ee_lung_mask.values.astype(np.uint16)
mesh_options = dict(
    voxel_size=resolution,
    max_cell_circumradius=5.0, # mm
    odt=True,
)
ee_lung_mesh = project.meshing.remove_unused_points(
    pygalmesh.generate_from_array(ee_label_volume, **mesh_options)
)
ee_lung_mesh

In [None]:
%%time
# Pull the raw arrays out of the meshio mesh.
# NOTE(review): assumes cells[0] holds the surface triangles and cells[1] the tetrahedra — confirm.
vertices = ee_lung_mesh.points
facets = ee_lung_mesh.cells[0].data
tetras = ee_lung_mesh.cells[1].data

# Every tetrahedron contributes its four triangular faces, first in one vertex
# order and then in the reversed order (8 triangles per cell in total).
forward_faces = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]
face_corner_ids = forward_faces + [face[::-1] for face in forward_faces]
tetra_facets = np.concatenate([tetras[:, list(ids)] for ids in face_corner_ids])

# currently disabled: orientation of each facet relative to the mesh interior
#angles = project.meshing.compute_angles_to_interior(vertices, facets, tetras)
#tetra_angles = project.meshing.compute_angles_to_interior(vertices, tetra_facets, tetras)

vertices.shape, facets.shape, tetras.shape

In [None]:
# view surface mesh

# opaque, shaded white triangles with thin black edges
surface_style = dict(
    facecolors='white',
    edgecolors='black',
    shade=True,
    alpha=1.0,
    linewidth=0.4,
    figsize=(8,8),
)
# all surface facets are drawn (filtering by `angles > 0` is currently disabled)
fig, ax = project.meshing.plot_mesh(vertices, facets, **surface_style)

# hide the grid, the background panes and the axes so only the mesh is visible
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_pane_color((1,1,1,0))
ax.set_axis_off()

In [None]:
# view volume mesh

# nearly transparent faces so that interior tetrahedra remain visible
volume_style = dict(
    facecolors='white',
    edgecolors='black',
    alpha=0.1,
    linewidth=0.4,
    figsize=(8,8),
)
fig, ax = project.meshing.plot_mesh(vertices, tetra_facets, **volume_style)

# hide the grid, the background panes and the axes so only the mesh is visible
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_pane_color((1,1,1,0))
ax.set_axis_off()

In [None]:
def convert_to_fe_mesh(points, cells):
    '''
    Convert from meshio to fenics mesh.

    The mesh is written to an XDMF file with meshio and read back with
    FEniCS. The file lives in a temporary directory that is removed before
    returning, so no `temp.xdmf` (or companion data file) is left behind in
    the working directory.

    Args:
        points: Vertex coordinates, passed to meshio unchanged.
        cells: A meshio cell block; its `.type` and `.data` are written.
    Returns:
        The `fe.Mesh` read back from the file.
    '''
    import tempfile

    # NOTE(review): each process gets its own temporary directory, which is
    # fine for a serial notebook kernel — revisit before running under MPI.
    with tempfile.TemporaryDirectory() as temp_dir:
        mesh_file = os.path.join(temp_dir, 'temp.xdmf')
        meshio.write_points_cells(mesh_file, points, [(cells.type, cells.data)])
        fe_mesh = fe.Mesh()
        # the XDMF file is closed before the directory is cleaned up
        with fe.XDMFFile(MPI.COMM_WORLD, mesh_file) as f:
            f.read(fe_mesh)
    return fe_mesh

# cells[1] is the same cell block used for the tetrahedra in the cells above
ee_lung_fe_mesh = convert_to_fe_mesh(ee_lung_mesh.points, ee_lung_mesh.cells[1])
#ee_lung_fe_mesh

In [None]:
%%time
# Convert the registration displacement image into dof values of a
# first-order 'P' vector function space defined on the lung mesh.
u_tensor = torch.as_tensor(reg_disp_image.values)
print(u_tensor.shape)

u_func_space = fe.VectorFunctionSpace(ee_lung_fe_mesh, 'P', 1)
# NOTE(review): presumably evaluates the image at the dof locations, using `resolution` as voxel size — confirm
u_func_dofs = project.interpolate.image_to_dofs(u_tensor, resolution, u_func_space)
print(u_func_dofs.shape)

In [None]:
%%time
# Map the displacement dofs back onto an image grid of shape `case.shape`
# (the inverse direction of image_to_dofs), to check the round trip.
u_interp = project.interpolate.dofs_to_image(u_func_dofs, u_func_space, case.shape, resolution)
print(u_interp.shape)

In [None]:
# Compare the round-tripped displacement with the original registration field.
# `reg_disp_image * 0 + u_interp` borrows the dims/coords of reg_disp_image
# NOTE(review): assumes u_interp has the same array layout as reg_disp_image — confirm
project.visual.XArrayViewer((reg_disp_image * 0 + u_interp) * ee_overlay, x='x', y='z').update_index(y=140, component=2)
project.visual.XArrayViewer(reg_disp_image * ee_overlay, x='x', y='z').update_index(y=140, component=2)

In [None]:
%%time
# Convert the density image into dof values of a first-order 'P' scalar
# function space defined on the lung mesh.
rho_tensor = torch.as_tensor(ee_rho_image.values)
print(rho_tensor.shape)

rho_func_space = fe.FunctionSpace(ee_lung_fe_mesh, 'P', 1)
# NOTE(review): presumably evaluates the image at the dof locations, using `resolution` as voxel size — confirm
rho_func_dofs = project.interpolate.image_to_dofs(rho_tensor, resolution, rho_func_space)
print(rho_func_dofs.shape)

In [None]:
%%time
# Map the density dofs back onto an image grid of shape `case.shape`,
# to check the round trip.
rho_interp = project.interpolate.dofs_to_image(rho_func_dofs, rho_func_space, case.shape, resolution)
print(rho_interp.shape)

In [None]:
# Compare the round-tripped density with the original density image.
# `ee_rho_image * 0 + rho_interp` borrows the dims/coords of ee_rho_image
# NOTE(review): assumes rho_interp has the same array layout as ee_rho_image — confirm
project.visual.XArrayViewer((ee_rho_image * 0 + rho_interp) * ee_overlay, x='x', y='z', vmin=0, cmap='Greys_r').update_index(y=140)
project.visual.XArrayViewer(ee_rho_image * ee_overlay, x='x', y='z', vmin=0, cmap='Greys_r').update_index(y=140)

## Physical model

Conservation of momentum
$$
    \nabla\cdot\sigma + \mathbf{b} = 0
$$

Isotropic linear elasticity
$$
    \sigma = 2 \mu \epsilon + \lambda \operatorname{tr}(\epsilon) \mathbf{I}
$$

Strain-displacement relation
$$
    \epsilon = \tfrac{1}{2} \left( \nabla \mathbf{u} + \nabla \mathbf{u}^\top \right)
$$

Weak formulation: multiply by a test function $\mathbf{v}$, integrate over $\Omega$, and integrate by parts. Because $\sigma$ is symmetric, $\sigma : \nabla\mathbf{v} = \sigma : \epsilon(\mathbf{v})$, and $\mathbf{t} = \sigma\mathbf{n}$ denotes the traction on the boundary.
$$
\begin{align*}
    \nabla\cdot\sigma + \mathbf{b} &= 0 \\
    -\nabla\cdot\sigma &= \mathbf{b} \\
    -\int_\Omega (\nabla\cdot\sigma)\cdot\mathbf{v} \,d\mathbf{x} &= \int_\Omega \mathbf{b}\cdot\mathbf{v} \,d\mathbf{x} \\
    \int_\Omega \sigma : \nabla\mathbf{v} \,d\mathbf{x} &= \int_\Omega \mathbf{b}\cdot\mathbf{v} \,d\mathbf{x} + \int_{\partial \Omega} (\sigma\mathbf{n})\cdot\mathbf{v} \,d\mathbf{s} \\
    \int_\Omega \sigma(\mathbf{u}) : \epsilon(\mathbf{v}) \,d\mathbf{x} &= \int_\Omega \mathbf{b}\cdot\mathbf{v} \,d\mathbf{x} + \int_{\partial \Omega} \mathbf{t}\cdot\mathbf{v} \,d\mathbf{s}
\end{align*}
$$



In [None]:
%%time
class PDESolver(torch_fenics.FEniCSModule):
    '''
    Linear-elasticity solve on a fixed mesh, wrapped as a torch_fenics module.

    Given a displacement field used as Dirichlet data on the whole boundary,
    a shear modulus field `mu` and a density field `rho`, `solve` returns the
    displacement satisfying the isotropic linear-elastic weak form with a
    body force proportional to `rho * g`.

    Args:
        mesh: FEniCS mesh on which the first-order 'P' spaces are built.
        g: Gravitational acceleration (m/s^2).
        nu: Poisson's ratio (unitless). Must lie in (-1, 0.5); at 0.5 the
            expression for Lame's first parameter divides by zero.

    Raises:
        ValueError: If `nu` is outside (-1, 0.5).
    '''

    def __init__(self, mesh, g=9.8, nu=0.4):
        super().__init__()

        if not (-1.0 < nu < 0.5):
            raise ValueError(f'nu must be in (-1, 0.5), got {nu}')

        # physical constants shared by every solve
        self.g = g
        self.nu = nu

        # initialize function spaces
        self.scalar_space = fe.FunctionSpace(mesh, 'P', 1)
        self.vector_space = fe.VectorFunctionSpace(mesh, 'P', 1)

    def input_templates(self):
        '''
        Return one template function per `solve` argument, in argument order.
        '''
        return (
            fa.Function(self.vector_space), # u_true
            fa.Function(self.scalar_space), # mu
            fa.Function(self.scalar_space), # rho
        )

    def solve(self, u_true, mu, rho):
        '''
        Solve the elasticity problem for one set of inputs.

        Args:
            u_true: Vector function whose boundary values are imposed as the
                Dirichlet condition.
            mu: Scalar function, shear modulus (Pa).
            rho: Scalar function, density.
        Returns:
            The solved displacement as a function on the vector space.
        '''
        # define physical parameters
        g = self.g   # gravitational acceleration (m/s^2)
        nu = self.nu # Poisson's ratio (unitless)

        lam = 2*mu*nu/(1 - 2*nu) # Lame's first parameter (Pa)

        # set displacement boundary condition
        u_bc = fa.DirichletBC(self.vector_space, u_true, 'on_boundary')

        # body force (along the negative first axis) and zero traction
        b = fe.as_vector([-rho * g / 1000, 0, 0]) # m -> mm
        t = fa.Constant([0, 0, 0])

        # define stress and strain
        def epsilon(u):
            return (fe.grad(u) + fe.grad(u).T) / 2

        def sigma(u):
            I = fe.Identity(u.geometric_dimension())
            return lam * fe.div(u) * I + 2 * mu * epsilon(u)

        # weak PDE formulation
        u = fe.TrialFunction(self.vector_space)
        v = fe.TestFunction(self.vector_space)

        a = fe.inner(sigma(u), epsilon(v)) * fe.dx
        L = fe.dot(b, v) * fe.dx + fe.dot(t, v) * fe.ds

        # solve for displacement
        u_pred = fa.Function(self.vector_space)
        fa.solve(a == L, u_pred, u_bc)

        return u_pred

# Build the solver on the lung mesh and run one forward solve.
pde_solver = PDESolver(ee_lung_fe_mesh)

# Request float64 for every solver input so that all three fields share the
# same dtype (the displacement tensor previously inherited the image's dtype).
u_tensor = torch.as_tensor(reg_disp_image.values, dtype=torch.float64)
mu_tensor = torch.ones(case.shape, dtype=torch.float64) * 3e3 # uniform initial modulus
rho_tensor = torch.as_tensor(ee_rho_image.values, dtype=torch.float64)

# displacement and density are fixed data; only the modulus is differentiated
u_func_dofs = project.interpolate.image_to_dofs(u_tensor, case.resolution, pde_solver.vector_space)
u_func_dofs.requires_grad = False

mu_func_dofs = project.interpolate.image_to_dofs(mu_tensor, case.resolution, pde_solver.scalar_space)
mu_func_dofs.requires_grad = True

rho_func_dofs = project.interpolate.image_to_dofs(rho_tensor, case.resolution, pde_solver.scalar_space)
rho_func_dofs.requires_grad = False

# each input gets a leading batch dimension of size 1; [0] removes it again
u_pred_func_dofs = pde_solver.forward(
    u_func_dofs.unsqueeze(0), mu_func_dofs.unsqueeze(0), rho_func_dofs.unsqueeze(0)
)[0]

In [None]:
# compute PDE-constrained loss and mu sensitivity

# mean squared error between the solved and the registration displacement dofs
L = F.mse_loss(u_pred_func_dofs, u_func_dofs)
print(L)
L.backward()
# u_func_dofs was marked requires_grad = False, so no gradient is expected for it
print(u_func_dofs.grad)
# gradient of the loss with respect to the modulus dofs
print(mu_func_dofs.grad)

In [None]:
# PDE-constrained optimization

def relative_error(u_pred, u_true, eps=1e-12):
    '''
    Mean relative squared error between predicted and true vectors.

    The squared norm of the difference is divided by the squared norm of the
    true vector along the last dimension, then averaged over all remaining
    entries.

    Args:
        u_pred: Predicted vectors, last dimension indexes components.
        u_true: True vectors, same shape as u_pred.
        eps: Lower bound applied to the squared true norm. Entries whose true
            vector is (near) zero would otherwise produce inf or nan and
            poison the mean; larger denominators are left untouched.
    Returns:
        Scalar tensor holding the mean relative squared error.
    '''
    u_diff = u_pred - u_true
    u_diff_norm2 = (u_diff**2).sum(dim=-1)
    u_true_norm2 = (u_true**2).sum(dim=-1)
    return torch.mean(u_diff_norm2 / u_true_norm2.clamp(min=eps))

# Fit the modulus dofs with Adam by differentiating through the PDE solve.
mu_func_dofs.requires_grad = True
optimizer = torch.optim.Adam([mu_func_dofs], lr=1e1)

# the optimizer updates mu_func_dofs in place, so this tuple stays current
solver_inputs = (u_func_dofs, mu_func_dofs, rho_func_dofs)

for step in range(1, 1001):
    optimizer.zero_grad()

    # add a batch dimension of size 1 to each input, strip it from the output
    batched_inputs = [dofs.unsqueeze(0) for dofs in solver_inputs]
    u_pred_func_dofs = pde_solver.forward(*batched_inputs)[0]

    loss = relative_error(u_pred_func_dofs, u_func_dofs)
    print(step, loss.item())

    loss.backward()
    optimizer.step()


In [None]:
%%time
# Map the solved displacement dofs (from the last solve) onto the image grid.
# NOTE(review): u_pred_func_dofs still carries autograd history — confirm dofs_to_image handles that
u_pred_image = project.interpolate.dofs_to_image(u_pred_func_dofs, pde_solver.vector_space, case.shape, case.resolution)

In [None]:
%%time
# Map the optimized modulus dofs onto the image grid.
# NOTE(review): mu_func_dofs has requires_grad = True — confirm dofs_to_image handles that
mu_pred_image = project.interpolate.dofs_to_image(mu_func_dofs, pde_solver.scalar_space, case.shape, case.resolution)

In [None]:
# Show the solved displacement, the registration displacement and the fitted
# modulus, each multiplied by the lung mask (the earlier cells use
# `ee_overlay` instead of the raw mask).
# NOTE(review): assumes u_pred_image / mu_pred_image broadcast against the xarray images — confirm
project.visual.XArrayViewer((reg_disp_image * 0 + u_pred_image) * ee_lung_mask, x='x', y='z').update_index(y=140, component=2)
project.visual.XArrayViewer(reg_disp_image * ee_lung_mask, x='x', y='z').update_index(y=140, component=2)
project.visual.XArrayViewer(mu_pred_image * ee_lung_mask, x='x', y='z', cmap=project.visual.mre_color_map(), vmin=-2e4, vmax=2e4).update_index(y=140)

In [None]:
def estimate_limit(x, expand=0.1):
    '''
    Return (lower, upper) limits that enclose x with some extra margin.

    The data range is widened by `expand` times its span in total, half of
    which is added on each side.
    '''
    lower, upper = np.min(x), np.max(x)
    margin = expand * (upper - lower) / 2
    return lower - margin, upper + margin

def plot_func(f):
    '''
    Draw a 3-component function as arrows at the vertices of its mesh.

    Also prints the mean vector magnitude over all vertices.

    Args:
        f: Function exposing `function_space().mesh()` and
            `compute_vertex_values(mesh)`, with three components.
    Returns:
        The (fig, ax) pair of the created 3D plot.
    '''
    mesh = f.function_space().mesh()

    vertices = mesh.coordinates()
    n_vertices = mesh.num_vertices()
    values = f.compute_vertex_values(mesh)

    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111, projection='3d', aspect='equal')

    X = [vertices[:,i] for i in range(3)]
    # vertex values are sliced component by component: n_vertices entries each
    U = [values[i * n_vertices:(i + 1) * n_vertices] for i in range(3)]

    # U stacks to shape (3, n_vertices), so the per-vertex vector norm runs
    # over axis 0; the mean is then taken across vertices
    print(np.mean(np.linalg.norm(U, axis=0)))

    ax.quiver(*X, *U)

    # one shared range for all three axes, computed from all coordinates
    shared_limit = estimate_limit(vertices.flatten())
    ax.set_xlim(shared_limit)
    ax.set_ylim(shared_limit)
    ax.set_zlim(shared_limit)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    return fig, ax

# NOTE(review): `disp_fe_func` is not defined anywhere in the visible notebook,
# so this call raises NameError on a fresh kernel — define it or remove the call.
plot_func(disp_fe_func)

In [None]:
# NOTE(review): `sim_disp_image` is not defined anywhere in the visible notebook;
# presumably a stale name for the predicted displacement image — confirm and rename.
project.visual.view((reg_disp_image * 0 + sim_disp_image) * ee_lung_mask, x='x', y='z')