In [3]:
%matplotlib notebook
%load_ext autoreload
%pwd

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'/ocean/projects/asc170022p/mtragoza/lung-project/notebooks'

In [4]:
import sys, os, pathlib
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import xarray as xr
import nibabel as nib
import pygalmesh
from mpi4py import MPI
import fenics as fe
import fenics_adjoint as fa
import torch
import torch.nn.functional as F
import torch_fenics

import matplotlib as mpl
import matplotlib.pyplot as plt

sys.path.append('..')
import project

--------------------------------------------------------------------------

  Local host:   dv001
  Local device: mlx5_0
--------------------------------------------------------------------------


In [5]:
%autoreload
emory4dct = project.imaging.Emory4DCT('../data/Emory-4DCT')

In [6]:
examples = []
for case in emory4dct.cases:
    for fixed_phase in emory4dct.phases:
        moving_phase = (fixed_phase + 10) % 100
        
        anat_file = case.nifti_file(fixed_phase)
        mask_file = case.mask_file(fixed_phase, roi='lung_combined_mask')
        disp_file = case.disp_file(moving_phase, fixed_phase)
        mesh_file = case.mesh_file(fixed_phase, radius=5)
        
        example = (anat_file, mask_file, disp_file, mesh_file)
        examples.append(example)
        
len(examples)

100

In [13]:
def load_nii_file(nii_file):
    print(f'Loading {nii_file}... ', end='')
    nifti = nib.load(nii_file)
    shape = nifti.header.get_data_shape()
    print(f'{shape}')
    return nifti

def load_mesh_file(mesh_file):
    print(f'Loading {mesh_file}... ', end='')
    mesh = fe.Mesh()
    with fe.XDMFFile(MPI.COMM_WORLD, str(mesh_file)) as f:
        f.read(mesh)
    n_vertices = mesh.num_vertices()
    print(f'{n_vertices} vertices')
    return mesh

class Dataset(torch.utils.data.Dataset):
     
    def __init__(self, examples):
        self.examples = examples
    
    def __len__(self):
        return len(self.examples)
        
    def __getitem__(self, idx): 
        anat_file, mask_file, disp_file, mesh_file = self.examples[idx]
        
        # load arrays from NIFTI files
        anat = load_nii_file(anat_file)
        mask = load_nii_file(mask_file)
        disp = load_nii_file(disp_file)

        # get array spatial resolution
        resolution = anat.header.get_zooms()
        
        # load mesh from xdmf file
        mesh = load_mesh_file(mesh_file)
        
        # convert to float32 tensors with shape (c,x,y,z)
        anat = torch.as_tensor(anat.get_fdata(), dtype=torch.float32).unsqueeze(0)
        mask = torch.as_tensor(mask.get_fdata(), dtype=torch.float32).unsqueeze(0)
        disp = torch.as_tensor(disp.get_fdata(), dtype=torch.float32).permute(3,0,1,2)

        resolution = torch.as_tensor(resolution, dtype=torch.float32)

        return anat, mask, disp, resolution, mesh
    
    def collate_fn(self, batch):
        anat = torch.stack([item[0] for item in batch])
        mask = torch.stack([item[1] for item in batch])
        disp = torch.stack([item[2] for item in batch])
        resolution = torch.stack([item[3] for item in batch])
        mesh = [item[4] for item in batch]
        return anat, mask, disp, resolution, mesh

dataset = Dataset(examples)
dataset[0]

Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T00.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T00/lung_combined_mask.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case1Pack/CorrField/case1_T10_T00.nii.gz... (256, 256, 94, 3)
Loading ../data/Emory-4DCT/Case1Pack/pygalmesh/case1_T00_5.xdmf... 15273 vertices


(tensor([[[[-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           ...,
           [ -626.,  -632.,  -680.,  ...,  -716.,  -664.,  -796.],
           [ -545.,  -551.,  -562.,  ...,  -388.,  -458.,  -494.],
           [ -399.,  -405.,  -384.,  ...,  -381.,  -347.,  -422.]],
 
          [[-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           ...,
           [ -616.,  -624.,  -670.,  ...,  -724.,  -672.,  -809.],
           [ -548.,  -533.,  -547.,  ...,  -380.,  -452.,  -499.],
           [ -402.,  -398.,  -367.,  ...,  -375.,  -353.,  -414.]],
 
          [[-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -1000., -1000.,  ..., -1000., -1000., -1000.],
           [-1000., -100

In [42]:
class CNN(torch.nn.Module):
    
    def __init__(self, n_inputs, n_outputs, n_filters, kernel_size, activ_fn):
        super().__init__()
        self.conv1 = conv3d(n_inputs, n_filters, kernel_size)
        self.conv2 = conv3d(n_filters, n_filters, kernel_size)
        self.conv3 = conv3d(n_filters, n_outputs, kernel_size)
        self.activ_fn = get_activ_fn(activ_fn)
        
    def forward(self, a):
        z1 = self.activ_fn(self.conv1(a))
        z2 = self.activ_fn(self.conv2(z1))
        mu = self.conv3(z2)
        return mu


def conv3d(n_inputs, n_outputs, kernel_size):
    k = kernel_size
    return torch.nn.Conv3d(n_inputs, n_outputs, k, padding='same')

def get_activ_fn(name):
    try:
        return getattr(torch.nn.functional, name)
    except AttributeError:
        return getattr(torch, name)

model = CNN(n_inputs=1, n_outputs=1, n_filters=32, kernel_size=3, activ_fn='leaky_relu')
model

CNN(
  (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
  (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
  (conv3): Conv3d(32, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
)

In [43]:
class PDE(torch.autograd.Function):

    @staticmethod
    def forward(ctx, u_true, mu, rho, mesh):
        scalar_space = fe.FunctionSpace(mesh, 'P', 1)
        vector_space = fe.VectorFunctionSpace(mesh, 'P', 1)
        
        scalar_template = fa.Function(scalar_space)
        vector_template = fa.Function(vector_space)
        
        u_true = torch_fenics.numpy_fenics.numpy_to_fenics(u_true.detach().numpy(), vector_template)
        mu = torch_fenics.numpy_fenics.numpy_to_fenics(mu.detach().numpy(), scalar_template)
        rho = torch_fenics.numpy_fenics.numpy_to_fenics(rho.detach().numpy(), scalar_template)
        
        tape = fa.Tape()
        fa.set_working_tape(tape)
        
        u_pred = PDE.solve(u_true, mu, rho)
        
        ctx.pde_inputs = [u_true, mu, rho, mesh]
        ctx.u_pred = u_pred
        ctx.tape = tape
        
        u_pred = torch.from_numpy(torch_fenics.numpy_fenics.fenics_to_numpy(u_pred))
        return u_pred

    def backward(ctx, grad_output):   
        u_true, mu, rho, mesh = ctx.pde_inputs

        adj_value = torch_fenics.numpy_fenics.numpy_to_fenics(grad_output.numpy(), u_true).vector()
        
        controls = [
            fa.Control(c) for c,g in zip(ctx.pde_inputs, ctx.needs_input_grad) if g
        ]    
        grads = fa.compute_gradient(ctx.u_pred, controls, tape=ctx.tape, adj_value=adj_value)

        for i, g in enumerate(grads):
            if g is not None:
                grads[i] = torch.from_numpy(torch_fenics.numpy_fenics.fenics_to_numpy(g))
    
        grad_iter = iter(grads)
        return tuple(None if not g else next(grad_iter) for g in ctx.needs_input_grad)
        
    def solve(u_true, mu, rho):
        scalar_space = mu.function_space()
        vector_space = u_true.function_space()

        # define physical parameters
        g  = 9.8e-3 # gravitational acc (mm/s^2)
        nu = 0.4    # Poisson's ratio (unitless)

        # Lame's first parameter (Pa)
        lam = 2*mu*nu/(1 - 2*nu)

        # set displacement boundary condition
        u_bc = fa.DirichletBC(vector_space, u_true, 'on_boundary')

        # body force and traction
        b = fe.as_vector([-rho*g, 0, 0])
        t = fa.Constant([0, 0, 0])

        # define stress and strain
        def epsilon(u):
            return (fe.grad(u) + fe.grad(u).T) / 2

        def sigma(u):
            I = fe.Identity(u.geometric_dimension())
            return lam*fe.div(u)*I + 2*mu*epsilon(u)

        # weak formulation
        u = fe.TrialFunction(vector_space)
        v = fe.TestFunction(vector_space)

        a = fe.inner(sigma(u), epsilon(v)) * fe.dx
        L = fe.dot(b, v)*fe.dx + fe.dot(t, v)*fe.dx
        
        # solve for displacement
        u_pred = fa.Function(vector_space)
        fa.solve(a == L, u_pred, u_bc)

        return u_pred

pde = PDE()
pde

  pde = PDE()


<__main__.PDE at 0x1480c4cc7de0>

In [None]:
%%time
import tqdm

def relative_error(u_pred, u_true):
    u_diff = u_pred - u_true
    u_diff_norm2 = (u_diff**2).sum(axis=-1)
    u_true_norm2 = (u_true**2).sum(axis=-1) + 1e-8
    return torch.mean(u_diff_norm2 / u_true_norm2)

def train(dataset, model, pde, batch_size, n_epochs, learning_rate):
    
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, collate_fn=dataset.collate_fn
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(n_epochs):
        print(f'Epoch {epoch+1}/{n_epochs}')
        
        for i, batch in enumerate(pbar:=tqdm.tqdm(data_loader)):

            anat_image, mask, u_image, res, mesh = batch
        
            mu_image = model.forward(anat_image)
            rho_image = 1000*(1 + anat_image/1000)
            
            loss = 0
            for j in range(batch_size):
                scalar_space = fe.FunctionSpace(mesh[j], 'P', 1)
                vector_space = fe.VectorFunctionSpace(mesh[j], 'P', 1)
    
                u_dofs = project.interpolate.image_to_dofs(u_image[j], res[j], vector_space)
                mu_dofs = project.interpolate.image_to_dofs(mu_image[j], res[j], scalar_space)
                rho_dofs = project.interpolate.image_to_dofs(rho_image[j], res[j], scalar_space)
                
                u_sim_dofs = pde.apply(u_dofs, mu_dofs, rho_dofs, mesh[j])   
                loss = loss + relative_error(u_sim_dofs, u_dofs)

            loss = loss / batch_size
            pbar.set_description(f'loss = {loss.item():.4f}')
            
            loss.backward()
            optimizer.step()

train(dataset, model, pde, batch_size=1, n_epochs=10, learning_rate=1e-5)

Epoch 1/10


  0%|          | 0/100 [00:00<?, ?it/s]

Loading ../data/Emory-4DCT/Case3Pack/NIFTI/case3_T50.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case3Pack/TotalSegment/case3_T50/lung_combined_mask.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case3Pack/CorrField/case3_T60_T50.nii.gz... (256, 256, 94, 3)
Loading ../data/Emory-4DCT/Case3Pack/pygalmesh/case3_T50_5.xdmf... 23094 vertices


loss = nan:   0%|          | 0/100 [00:10<?, ?it/s]