In [1]:
%matplotlib notebook
%load_ext autoreload
%pwd

'/ocean/projects/asc170022p/mtragoza/lung-project/notebooks'

In [2]:
import sys, os, pathlib
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import xarray as xr
import nibabel as nib
import pygalmesh
from mpi4py import MPI
import fenics as fe
import fenics_adjoint as fa
import torch
import torch.nn.functional as F
import torch_fenics
import tqdm
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

sys.path.append('..')
import project

--------------------------------------------------------------------------

  Local host:   dv004
  Local device: mlx5_0
--------------------------------------------------------------------------


In [3]:
%autoreload
emory4dct = project.imaging.Emory4DCT('../data/Emory-4DCT')

In [4]:
examples = []
for case in emory4dct.cases:
    for fixed_phase in emory4dct.phases:
        moving_phase = (fixed_phase + 10) % 100
        
        anat_file = case.nifti_file(fixed_phase)
        disp_file = case.disp_file(moving_phase, fixed_phase)
        mask_file = case.mask_file(fixed_phase, roi='lung_combined_mask')
        mesh_file = case.mesh_file(fixed_phase, radius=20)
        
        example = (anat_file, disp_file, mask_file, mesh_file)
        examples.append(example)
        
len(examples)

100

In [5]:
class Dataset(torch.utils.data.Dataset):
     
    def __init__(self, examples, dtype=torch.float32, device='cuda'):
        super().__init__()

        self.examples = examples
        self.dtype = dtype
        self.device = device

        self.cache = [None] * len(examples)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        if self.cache[idx] is None:
            self.cache[idx] = self.load_example(idx)
        return self.cache[idx]
    
    def load_example(self, idx):
        anat_file, disp_file, mask_file, mesh_file = self.examples[idx]    
        example_name = anat_file.stem
        
        # load images from NIFTI files
        anat = load_nii_file(anat_file)
        disp = load_nii_file(disp_file)
        mask = load_nii_file(mask_file)
        
        # load mesh from xdmf file
        mesh = load_mesh_file(mesh_file)

        # get image spatial resolution
        resolution = anat.header.get_zooms()

        # convert arrays to tensors with shape (c,x,y,z)
        anat = torch.as_tensor(anat.get_fdata(), dtype=self.dtype, device=self.device).unsqueeze(0)
        disp = torch.as_tensor(disp.get_fdata(), dtype=self.dtype, device=self.device).permute(3,0,1,2)
        mask = torch.as_tensor(mask.get_fdata(), dtype=self.dtype, device=self.device).unsqueeze(0)

        return anat, disp, mask, mesh, resolution, example_name
    
def load_nii_file(nii_file):
    print(f'Loading {nii_file}... ', end='')
    nifti = nib.load(nii_file)
    print(nifti.header.get_data_shape())
    return nifti

def load_mesh_file(mesh_file):
    print(f'Loading {mesh_file}... ', end='')
    mesh = fe.Mesh()
    with fe.XDMFFile(MPI.COMM_WORLD, str(mesh_file)) as f:
        f.read(mesh)
    print(mesh.num_vertices())
    return mesh

def collate_fn(batch):
    # we need a custom collate_fn bc mesh is not a tensor
    anat = torch.stack([ex[0] for ex in batch])
    mask = torch.stack([ex[1] for ex in batch])
    disp = torch.stack([ex[2] for ex in batch])
    resolution = [ex[3] for ex in batch]
    mesh = [ex[4] for ex in batch]
    name = [ex[5] for ex in batch]
    return anat, mask, disp, resolution, mesh, name

dataset = Dataset(examples, dtype=torch.float32, device='cuda')
example = dataset[0]

Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T00.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case1Pack/CorrField/case1_T10_T00.nii.gz... (256, 256, 94, 3)
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T00/lung_combined_mask.nii.gz... (256, 256, 94)
Loading ../data/Emory-4DCT/Case1Pack/pygalmesh/case1_T00_20.xdmf... 346


In [6]:
def my_tensor_repr(t):
    shape = tuple(t.shape)
    is_nan = t.float().isnan()
    not_nan = ~is_nan
    num_nan = is_nan.sum()
    mean = t[not_nan].mean()
    std = t[not_nan].std() if not_nan.sum() > 1 else np.nan # hide the torch warning
    return f'Tensor(shape={shape}, μ={mean:.4f}, σ={std:.4f}, #nan={num_nan}, dtype={t.dtype}, device={t.device})'

torch.Tensor.__repr__ = my_tensor_repr

example

(Tensor(shape=(1, 256, 256, 94), μ=-487.4489, σ=477.8683, #nan=0, dtype=torch.float32, device=cuda:0),
 Tensor(shape=(3, 256, 256, 94), μ=0.0061, σ=0.2449, #nan=0, dtype=torch.float32, device=cuda:0),
 Tensor(shape=(1, 256, 256, 94), μ=0.1570, σ=0.3638, #nan=0, dtype=torch.float32, device=cuda:0),
 <dolfin.cpp.mesh.Mesh at 0x1552edaf0400>,
 (0.97, 0.97, 2.5),
 'case1_T00.nii')

In [7]:
class ConvUnit(torch.nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.norm = torch.nn.BatchNorm3d(in_channels)
        self.conv = torch.nn.Conv3d(in_channels, out_channels, kernel_size, padding='same', padding_mode='replicate')
        self.relu = torch.nn.LeakyReLU(inplace=True)
        
    def forward(self, x):
        x = self.norm(x)
        x = self.conv(x)
        x = self.relu(x)
        return x


class ConvBlock(torch.nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size, num_conv_layers, hidden_channels=None):
        super().__init__()
        
        if not hidden_channels:
            hidden_channels = out_channels
        elif num_conv_layers < 2:
            print('Warning: hidden_channels argument only used if num_conv_layers >= 2')

        for i in range(num_conv_layers):
            layer = ConvUnit(
                in_channels=(hidden_channels if i > 0 else in_channels),
                out_channels=(hidden_channels if i < num_conv_layers - 1 else out_channels),
                kernel_size=kernel_size
            )
            self.add_module(f'conv_unit{i}', layer)
            

class Upsample(torch.nn.Module):
    
    def __init__(self, mode):
        super().__init__()
        self.mode = mode
        
    def __repr__(self):
        return f'{type(self).__name__}(mode={self.mode})'
        
    def forward(self, x, size):
        return F.interpolate(x, size=size, mode=self.mode)


class EncoderBlock(torch.nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        conv_kernel_size,
        num_conv_layers,
        hidden_channels=None,
        apply_pooling=True,
        pool_kernel_size=2,
        pool_type='max'
    ):
        super().__init__()
        assert pool_type in {'max', 'avg'}

        if apply_pooling:
            if pool_type == 'max':
                self.pooling = torch.nn.MaxPool3d(kernel_size=pool_kernel_size)
            elif pool_type == 'avg':
                self.pooling = torch.nn.AvgPool3d(kernel_size=pool_kernel_size)
        else:
            self.pooling = None
            
        self.conv_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=conv_kernel_size,
            num_conv_layers=num_conv_layers,
            hidden_channels=hidden_channels
        )
        
    def forward(self, x):
        if self.pooling:
            x = self.pooling(x)
        x = self.conv_block(x)
        return x


class DecoderBlock(torch.nn.Module):
    
    def __init__(
        self,
        in_channels,
        out_channels,
        conv_kernel_size,
        num_conv_layers,
        hidden_channels=None,
        upsample_mode='nearest'
    ):
        super().__init__()

        self.upsample = Upsample(mode=upsample_mode)

        self.conv_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=conv_kernel_size,
            num_conv_layers=num_conv_layers,
            hidden_channels=hidden_channels,
        )

    def forward(self, x, encoder_feats):
        x = self.upsample(x, size=encoder_feats.shape[2:])
        x = torch.cat([x, encoder_feats], dim=1)
        x = self.conv_block(x)
        return x


In [8]:
class UNet3D(torch.nn.Module):
    
    def __init__(
        self,
        in_channels,
        out_channels,
        num_levels,
        num_conv_layers,
        conv_channels,
        conv_kernel_size,
        pool_kernel_size=2,
        pool_type='max',
        upsample_mode='trilinear',
    ):
        super().__init__()
        assert num_levels > 0
        
        curr_channels = in_channels
        next_channels = conv_channels
        
        self.encoder = torch.nn.Sequential()
        for i in range(num_levels):
        
            encoder_block = EncoderBlock(
                in_channels=curr_channels,
                out_channels=next_channels,
                conv_kernel_size=conv_kernel_size,
                num_conv_layers=num_conv_layers,
                apply_pooling=(i > 0),
                pool_kernel_size=pool_kernel_size,
                pool_type=pool_type
            )
            self.encoder.add_module(f'level{i}', encoder_block)

            curr_channels = next_channels
            next_channels = curr_channels * 2
        
        next_channels = curr_channels // 2
        
        self.decoder = torch.nn.Sequential()
        for i in reversed(range(num_levels - 1)):

            decoder_block = DecoderBlock(
                in_channels=curr_channels + next_channels,
                out_channels=next_channels,
                conv_kernel_size=conv_kernel_size,
                num_conv_layers=num_conv_layers,
                upsample_mode=upsample_mode
            )
            self.decoder.add_module(f'level{i}', decoder_block)
            
            curr_channels = next_channels
            next_channels = curr_channels // 2
        
        self.final_conv = torch.nn.Conv3d(curr_channels, out_channels, kernel_size=1)

    def forward(self, x):
        
        # encoder part
        encoder_feats = []
        for i, encoder in enumerate(self.encoder):
            x = encoder(x)
            encoder_feats.append(x)
        
        # reverse encoder features to align with decoder
        encoder_feats = encoder_feats[::-1]
    
        # decoder part
        for i, decoder in enumerate(self.decoder):
            x = decoder(x, encoder_feats[i+1])

        return self.final_conv(x)


model = UNet3D(in_channels=1, out_channels=1, num_levels=3, num_conv_layers=2, conv_channels=4, conv_kernel_size=3)
model.cuda()

UNet3D(
  (encoder): Sequential(
    (level0): EncoderBlock(
      (conv_block): ConvBlock(
        (conv_unit0): ConvUnit(
          (norm): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv): Conv3d(1, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same, padding_mode=replicate)
          (relu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
        (conv_unit1): ConvUnit(
          (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv): Conv3d(4, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same, padding_mode=replicate)
          (relu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
    )
    (level1): EncoderBlock(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv_block): ConvBlock(
        (conv_unit0): ConvUnit(
          (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
class LinearElasticPDE(torch_fenics.FEniCSModule):
    
    def __init__(self, mesh):
        super().__init__()
        self.mesh = mesh
        self.S = fe.FunctionSpace(mesh, 'P', 1)
        self.V = fe.VectorFunctionSpace(mesh, 'P', 1)
        
    def __repr__(self):
        return f'{type(self).__name__}({self.mesh})'
        
    def input_templates(self):
        scalar_f = fa.Function(self.S)
        vector_f = fa.Function(self.V)
        return vector_f, scalar_f, scalar_f
    
    def solve(self, u_true, mu, rho):

        # define physical parameters
        g  = 9.8e-3 # gravitational acc (mm/s^2)
        nu = 0.4    # Poisson's ratio (unitless)

        # Lame's first parameter (Pa)
        lam = 2*mu*nu/(1 - 2*nu)

        # set displacement boundary condition
        u_bc = fa.DirichletBC(self.V, u_true, 'on_boundary')

        # body force and traction
        #b = fe.as_vector([0, rho*g, 0])
        b = fa.Constant([0, 0, 0])
        t = fa.Constant([0, 0, 0])

        # define stress and strain
        def epsilon(u):
            return (fe.grad(u) + fe.grad(u).T) / 2

        def sigma(u):
            I = fe.Identity(u.geometric_dimension())
            return lam*fe.div(u)*I + 2*mu*epsilon(u)

        # weak formulation
        u = fe.TrialFunction(self.V)
        v = fe.TestFunction(self.V)

        a = fe.inner(sigma(u), epsilon(v)) * fe.dx
        L = fe.dot(b, v)*fe.dx + fe.dot(t, v)*fe.dx

        u_pred = fa.Function(self.V)
        fa.solve(a == L, u_pred, u_bc)

        return u_pred


In [10]:
def as_xarray(a, dims=None, coords=None, name=None):
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().numpy()
    if dims is None:
        dims = [f'dim{i}' for i in range(a.ndim)]
    if coords is None:
        coords = {d: np.arange(a.shape[i]) for i, d in enumerate(dims)}
    return xr.DataArray(a, dims=dims, coords=coords, name=name)

#project.visual.view(as_xarray(output_image[0], dims=['component', 'x', 'y', 'z']), cmap='seismic')

In [11]:
anat_image, u_true_image, mask, mesh, resolution, example_name = example
print(example_name)
print(anat_image)

project.visual.view(as_xarray(anat_image, dims=['channel', 'x', 'y', 'z'], name='CT')).update_index(channel=0, z=45)

case1_T00.nii
Tensor(shape=(1, 256, 256, 94), μ=-487.4489, σ=477.8683, #nan=0, dtype=torch.float32, device=cuda:0)


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='channel', options=((0, 0),), value=0), SelectionSlider(desc…

In [12]:
mu_pred_image = model.forward(anat_image.unsqueeze(0))[0]
mu_pred_image = torch.nn.functional.softplus(mu_pred_image) * 1000
print(mu_pred_image)

project.visual.view(as_xarray(mu_pred_image * mask, dims=['channel', 'x', 'y', 'z'], name='mu'), vmax=1e4).update_index(channel=0, z=45)

  return F.conv3d(


Tensor(shape=(1, 256, 256, 94), μ=788.2083, σ=86.8659, #nan=0, dtype=torch.float32, device=cuda:0)


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='channel', options=((0, 0),), value=0), SelectionSlider(desc…

In [13]:
rho_image = (1 + anat_image/1000) * 1000
print(rho_image)

project.visual.view(as_xarray(rho_image * mask, dims=['channel', 'x', 'y', 'z']), cmap='Greys_r', vmin=0, vmax=1000).update_index(channel=0, z=45)

Tensor(shape=(1, 256, 256, 94), μ=512.5510, σ=477.8683, #nan=0, dtype=torch.float32, device=cuda:0)


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='channel', options=((0, 0),), value=0), SelectionSlider(desc…

In [14]:
project.visual.view(as_xarray(u_true_image * mask, dims=['channel', 'x', 'y', 'z'])).update_index(channel=2, z=45)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='channel', options=((0, 0), (1, 1), (2, 2)), value=0), Selec…

In [15]:
pde = LinearElasticPDE(mesh)

In [17]:
u_true_dofs = project.interpolate.image_to_dofs(u_true_image, resolution, pde.V).cpu()
u_true_dofs

Tensor(shape=(346, 3), μ=0.0055, σ=0.1595, #nan=0, dtype=torch.float64, device=cpu)

In [18]:
mu_pred_dofs = project.interpolate.image_to_dofs(mu_pred_image, resolution, pde.S).cpu()
mu_pred_dofs

Tensor(shape=(346,), μ=804.6316, σ=80.7819, #nan=0, dtype=torch.float64, device=cpu)

In [19]:
rho_dofs = project.interpolate.image_to_dofs(rho_image, resolution, pde.S).cpu()
rho_dofs

Tensor(shape=(346,), μ=515.6650, σ=211.0198, #nan=0, dtype=torch.float64, device=cpu)

In [None]:
u_pred_dofs = pde.forward(
    u_true_dofs.unsqueeze(0),
    mu_pred_dofs.unsqueeze(0),
    rho_dofs.unsqueeze(0),
)[0]

u_pred_dofs

In [None]:
%%time
u_pred_image = project.interpolate.dofs_to_image(u_pred_dofs, pde.V, u_true_image.shape[-3:], resolution)
u_pred_image

In [None]:
project.visual.view(as_xarray(u_pred_image * mask.permute(1,2,3,0).detach().cpu().numpy(), dims=['x', 'y', 'z', 'c'])).update_index(c=0, z=45)

In [None]:
def compute_norm(u):
    u_norm2 = (u**2).sum(dim=-1)
    return torch.mean(u_norm2)

def compute_loss(u_pred, u_true, eps=1e-8):
    u_diff = (u_pred - u_true)
    u_diff_norm2 = (u_diff**2).sum(dim=-1)
    u_true_norm2 = (u_true**2).sum(dim=-1) + eps
    return torch.mean(u_diff_norm2 / u_true_norm2)

compute_loss(u_pred_dofs, u_true_dofs)

In [None]:
class Trainer(object):
    
    def __init__(self, model, dataset, batch_size, learning_rate):
        self._model = model
        self.train_loader = torch.utils.data.DataLoader(
            dataset, batch_size, shuffle=True, collate_fn=collate_fn
        )
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.epoch = 0
        
        index_cols = ['epoch', 'batch', 'example', 'phase']
        self.metrics = pd.DataFrame(columns=index_cols)
        self.metrics.set_index(index_cols, inplace=True)
        
    @property
    def model(self):
        return self._model
    
    @property
    def dataset(self):
        return self.train_loader.dataset
        
    @property
    def batch_size(self):
        return self.train_loader.batch_sampler.batch_size
    
    @property
    def learning_rate(self):
        return self.optimizer.param_groups[0]['lr']
    
    def __repr__(self):
        if self.epoch > 0:
            loss = self.metrics.loc[self.epoch, 'loss'].mean()
        else:
            loss = None
        return f'{type(self).__name__}(epoch={self.epoch}, loss={loss})'
        
    def train(self, num_epochs):
        
        start_epoch = self.epoch
        stop_epoch = self.epoch + num_epochs

        print('Training...')
        for i in range(start_epoch, stop_epoch):
            print(f'Epoch {i+1}/{stop_epoch}')
            
            for j, batch in enumerate(self.train_loader):
                anat_image, u_true_image, mask, mesh, resolution, example_name = batch

                # predict elasticity from anatomical image
                mu_pred_image = self.model.forward(anat_image)
                mu_pred_image = torch.nn.functional.softplus(mu_pred_image) * 1000
                rho_image = (1 + anat_image/1000) * 1000

                # physical FEM simulation
                loss = 0
                for k in range(self.batch_size):          
                    pde = LinearElasticPDE(mesh[k])

                    # convert tensors to FEM basis coefficients
                    u_true_dofs = project.interpolate.image_to_dofs(u_true_image[k], resolution[k], pde.V).cpu()
                    mu_pred_dofs = project.interpolate.image_to_dofs(mu_pred_image[k], resolution[k], pde.S).cpu()
                    rho_dofs = project.interpolate.image_to_dofs(rho_image[k], resolution[k], pde.S).cpu()
    
                    # solve FEM for simulated displacement coefficients
                    u_pred_dofs = pde.forward(
                        u_true_dofs.unsqueeze(0),
                        mu_pred_dofs.unsqueeze(0),
                        rho_dofs.unsqueeze(0),
                    )[0]
    
                    # compare to true displacement coefficients
                    loss_k = compute_loss(u_pred_dofs, u_true_dofs)
                    loss += loss_k
                
                    # compute additional metrics
                    key = (i+1, j+1, example_name[k], 'train')
                    self.metrics.loc[key, 'loss'] = loss_k.item()
                    self.metrics.loc[key, 'mu_pred_norm'] = compute_norm(mu_pred_dofs).item()
                    self.metrics.loc[key, 'u_pred_norm'] = compute_norm(u_pred_dofs).item()
                    self.metrics.loc[key, 'u_true_norm'] = compute_norm(u_true_dofs).item()
            
                loss /= self.batch_size
                print(f'{example_name} loss = {loss:.4f}')
                
                loss.backward()
                self.optimizer.step()
            
            self.epoch += 1
    
    def test(self, example):
        anat_image, u_true_image, mask, mesh, resolution, example_name = example
        
        # predict elasticity from anatomical image
        mu_pred_image = self.model.forward(anat_image.unsqueeze(0))[0]
        mu_pred_image = torch.nn.functional.softplus(mu_pred_image) * 1000
        rho_image = (1 + anat_image/1000) * 1000
        
        # physical FEM simulation
        pde = LinearElasticPDE(mesh)
        
        # convert tensors to FEM basis coefficients
        u_true_dofs = project.interpolate.image_to_dofs(u_true_image, resolution, pde.V).cpu()
        mu_pred_dofs = project.interpolate.image_to_dofs(mu_pred_image, resolution, pde.S).cpu()
        rho_dofs = project.interpolate.image_to_dofs(rho_image, resolution, pde.S).cpu()

        # solve FEM for simulated displacement coefficients
        u_pred_dofs = pde.forward(
            u_true_dofs.unsqueeze(0),
            mu_pred_dofs.unsqueeze(0),
            rho_dofs.unsqueeze(0),
        )[0]

        # compare to true displacement coefficients
        loss = compute_loss(u_pred_dofs, u_true_dofs)
        
        # convert simulated displacement field to image domain      
        u_pred_image = project.interpolate.dofs_to_image(u_pred_dofs, pde.V, u_true_image.shape[-3:], resolution)
        u_pred_image = torch.as_tensor(u_pred_image).permute(3,0,1,2)
        
        return mu_pred_image, u_pred_image

trainer = Trainer(model, dataset, batch_size=4, learning_rate=1e-5)
trainer

In [None]:
%%time
trainer.train(100)

In [None]:
trainer.metrics

In [None]:
trainer.metrics.groupby('epoch').mean().reset_index().plot(y='loss', x='epoch')

In [None]:
trainer.metrics.groupby('epoch').mean().reset_index().plot(y='mu_pred_norm', x='epoch')

In [None]:
trainer.metrics.groupby('epoch').mean().reset_index().plot(y='u_pred_norm', x='epoch')

In [None]:
trainer.metrics.groupby('epoch').mean().reset_index().plot(y='u_true_norm', x='epoch')

In [None]:
mu_pred_image, u_pred_image = trainer.test(dataset[0])
mu_pred_image

In [None]:
u_pred_image

In [None]:
project.visual.view(as_xarray(u_true_image * mask, dims=['c', 'x', 'y', 'z'], name='u'))

In [None]:
anat_image, u_true_image, mask, mesh, resolution, example_name = dataset[0]
shape = tuple(anat_image.shape[1:])
shape, resolution

In [None]:
project.visual.view(as_xarray(
    u_pred_image * mask.cpu(),
    dims=['c', 'x', 'y', 'z'],
    coords={
        'c': ['x', 'y', 'z'],
        'x': np.arange(shape[0]) * resolution[0],
        'y': np.arange(shape[1]) * resolution[1],
        'z': np.arange(shape[2]) * resolution[2],
    },
    name='u'
), y='z')

In [None]:
anat_image = dataset[0][0]
mu_pred_image = trainer.model.forward(anat_image.unsqueeze(0))[0]
mu_pred_image = torch.nn.functional.softplus(mu_pred_image) * 1000

In [None]:
project.visual.view(as_xarray(
    mu_pred_image * mask,
    dims=['c', 'x', 'y', 'z'],
    coords={
        'c': [0],
        'x': np.arange(shape[0]) * resolution[0],
        'y': np.arange(shape[1]) * resolution[1],
        'z': np.arange(shape[2]) * resolution[2],
    },
    name='mu'
), y='z')

In [None]:
mu_pred_image = mu_pred_image.detach()
mu_pred_image.requires_grad = True
mu_pred_image

In [None]:
pde = LinearElasticPDE(mesh)

mu_pred_dofs = project.interpolate.image_to_dofs(mu_pred_image, resolution, pde.S)
mu_pred_dofs

In [None]:
L = mu_pred_dofs.sum()
L

In [None]:
L.backward()

In [None]:
project.visual.view(as_xarray(
    mu_pred_image.grad,
    dims=['c', 'x', 'y', 'z'],
    coords={
        'c': [0],
        'x': np.arange(shape[0]) * resolution[0],
        'y': np.arange(shape[1]) * resolution[1],
        'z': np.arange(shape[2]) * resolution[2],
    }
), y='z', cmap='seismic')