In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 1
%pwd

'/mnt/c/Users/mtr22/code/mre-pinn/notebooks'

In [2]:
import sys, os
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt

os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

sys.path.append('..')
%aimport mre_pinn

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch

Loading /mnt/c/Users/mtr22/code/mre-pinn/notebooks/../mre_pinn/__init__.py


In [3]:
# load the MATLAB file
data_root = '../data/BIOQIC'
mat_base = 'four_target_phantom.mat'

# os.path.join instead of manual '/' concatenation (identical result on POSIX paths)
mat_file = os.path.join(data_root, mat_base)
data, rev_axes = mre_pinn.data.load_mat_data(mat_file, verbose=True)

# NOTE(review): rev_axes is only displayed here; the next cell transposes
# `u_ft` unconditionally, presumably because rev_axes is True (axes stored in
# reversed, MATLAB-style order). Confirm against load_mat_data.
rev_axes

Loading ../data/BIOQIC/four_target_phantom.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    u_ft: <class 'numpy.ndarray'> (100, 80, 10, 3, 6) complex128


True

In [4]:
# convert to xarray and add metadata
# after .T the axis order is (frequency, component, z, x, y), as labelled below
u_true = data['u_ft'].T
u_dims = ['frequency', 'component', 'z', 'x', 'y']
u_coords = {
    'frequency': 50 + np.arange(u_true.shape[0]) * 10, # Hz
    'x': np.arange(u_true.shape[3]) * 1e-3, # m
    'y': np.arange(u_true.shape[4]) * 1e-3, # m
    'z': np.arange(u_true.shape[2]) * 1e-3, # m
    'component': ['y', 'x', 'z'],
}
u_true = xr.DataArray(u_true, dims=u_dims, coords=u_coords)
u_true = u_true.transpose('frequency', 'x', 'y', 'z', 'component')

# downsampling: average over ds x ds x ds windows. boundary='trim' drops the
# trailing samples when ds does not divide a dimension; the default 'exact'
# raises ValueError (e.g. ds=3 on the 80-sample x axis)
ds = 1
u_true = u_true.coarsen(x=ds, y=ds, z=ds, boundary='trim').mean()

# single frequency 2D
# the first z slice is selected by position: coarsening with ds > 1 replaces the
# z coordinates with window means, so a label-based sel(z=0) would raise KeyError
u_true = u_true.isel(z=0).sel(frequency=[60], component=['x', 'y'])

# single frequency 3D
#u_true = u_true.sel(frequency=[60], component=['x', 'y', 'z'])

# multifrequency 3D
#u_true = u_true.sel(component=['x', 'y', 'z'])

print(u_true.size)
u_true

16000


In [8]:
# show the real part of the measured wave field on a symmetric color range
w_max = 0.001
w_map = mre_pinn.visual.wave_color_map()
wave_kws = {'cmap': w_map, 'vmin': -w_max, 'vmax': w_max}

mre_pinn.visual.NDArrayViewer(u_true.real, dpi=50/ds, **wave_kws)

<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x7f79313052b0>

In [123]:
%autoreload

num_domain = 128 # PDE domain samples
batch_size = 128 # batch size handed to NDArrayBC (collocation points)

# wrap the measured wave field as a point-set boundary condition
bc = mre_pinn.data.NDArrayBC(u_true, batch_size=batch_size)
x = bc.points

# point array shape, value shape, and number of batches per pass over the points
(x.shape, bc.values.shape, x.shape[0] / batch_size)

((8000, 3), torch.Size([8000, 2]), 62.5)

In [124]:
# With a single frequency, the first input dimension has zero extent. Pad it
# by eps so the geometry built from these bounds has nonzero volume.
eps = np.zeros(x.shape[1])
eps[0] = 1e-5
x.max(axis=0) - x.min(axis=0) + eps

array([1.00000000e-05, 7.90000036e-02, 9.89999995e-02])

In [125]:
%autoreload

# set up PDE with geometry and boundary condition
residual = mre_pinn.pde.HelmholtzPDE(detach=False)

# axis-aligned bounding box of the points, widened by eps/2 on each side
geometry = deepxde.geometry.Hypercube(
    x.min(axis=0) - eps / 2,
    x.max(axis=0) + eps / 2,
)
pde = deepxde.data.PDE(geometry, residual, bc, num_domain=num_domain)
pde.train_x.shape



(256, 3)

In [126]:
%autoreload

# initialize neural network
parallel = False

# hyperparameters shared by both network variants below
common_kws = dict(
    n_input=x.shape[1],
    n_layers=5,
    activ_fn=torch.sin,
    complex=True,
    dense=True,
    transform=False,
)
if parallel:
    # two separate networks: one for the wave field u, one for the modulus mu
    net = mre_pinn.model.Parallel([
        mre_pinn.model.PINN(n_hidden=64, n_output=n_output, omega0=8, **common_kws)
        for n_output in [u_true.shape[-1], 1] # u and mu
    ])
else:
    # one network whose outputs are the u components followed by mu
    net = mre_pinn.model.PINN(
        n_hidden=128,
        n_output=u_true.shape[-1] + 1, # u and mu
        omega0=32,
        **common_kws
    )

# Standardize inputs to [-1, 1]: center each input on the midpoint of its
# range and divide by the half-range. The mean equals the midpoint only for
# symmetric point sets, so centering on the mean would not guarantee [-1, 1].
x = bc.points
x_min = np.min(x, axis=0)
x_max = np.max(x, axis=0)
x_loc = (x_max + x_min) / 2
x_scale = (x_max - x_min) / 2

# A constant input has zero range (e.g. the frequency axis for a single
# frequency). Use unit scale for it so standardization never divides by zero.
x_scale[x_scale == 0] = 1

# Normalize outputs. u is complex, so split it into real and imaginary
# columns and standardize each column separately.
u = bc.values.cpu().numpy()
u = np.concatenate([u.real, u.imag], axis=1)
u_loc = np.mean(u, axis=0)
u_scale = np.std(u, axis=0)
u_scale[u_scale == 0] = 1 # guard constant columns (e.g. all-zero imag part)
assert u_scale.shape == (2 * u_true.shape[-1],), u_scale.shape

# NOTE(review): fixed location/scale for mu; presumably (real, imag) parts,
# chosen from the expected stiffness range -- confirm
mu_loc = [5, 5]
mu_scale = [5, 5]

# Initialize model weights.
# NOTE(review): output_loc/scale is ordered [Re u..., Im u..., mu...]. This
# assumes init_weights expects that output layout -- confirm in mre_pinn.model.
if parallel:
    net[0].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=u_loc, output_scale=u_scale)
    net[1].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=mu_loc, output_scale=mu_scale)
else:
    output_loc = np.append(u_loc, mu_loc)
    output_scale = np.append(u_scale, mu_scale)
    net.init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=output_loc, output_scale=output_scale)

net

(4,)


PINN(
  (linear0): Linear(in_features=3, out_features=128, bias=True)
  (linear1): Linear(in_features=131, out_features=128, bias=True)
  (linear2): Linear(in_features=259, out_features=128, bias=True)
  (linear3): Linear(in_features=387, out_features=128, bias=True)
  (linear4): Linear(in_features=515, out_features=6, bias=True)
)

In [127]:
%autoreload

def msae_loss(y_true, y_pred):
    '''
    Mean squared absolute error: mean(|y_true - y_pred|^2).

    Taking the magnitude before squaring makes the loss real-valued for
    complex tensors as well as real ones.
    '''
    residual = y_true - y_pred
    squared_magnitude = residual.abs() ** 2
    return squared_magnitude.mean()

# deepxde orders the loss terms as [PDE residual(s)..., boundary conditions...].
# There are two terms here: the Helmholtz residual, then the wave-field data fit.
pde_loss_weight = 1e-10 # 1e12 times smaller than the data weight: PDE has negligible influence
bc_loss_weight = 100    # fit to the measured wave field dominates training

model = deepxde.Model(pde, net)
model.compile(
    optimizer='adam',
    lr=1e-3,
    loss=msae_loss,
    loss_weights=[pde_loss_weight, bc_loss_weight]
)

Compiling model...
'compile' took 0.000140 s



In [128]:
# swap deepxde's training display for mre_pinn's live TrainingPlot
deepxde.display.training_display = mre_pinn.visual.TrainingPlot()

# The iteration budget is effectively unbounded. Stop training manually with
# Kernel -> Interrupt and keep the partially trained model.
max_iterations = 1000000
try:
    model.train(max_iterations, display_every=10)
except KeyboardInterrupt:
    # no `as e`: the exception object is unused, and binding it would also
    # delete any pre-existing notebook variable named `e` when the clause ends
    print('Interrupt', file=sys.stderr)

Training model...



<IPython.core.display.Javascript object>

Interrupt


In [138]:
# Training losses from the last recorded step, one entry per loss term
# (PDE residual first, then the boundary/data term, in deepxde's ordering).
# NOTE(review): presumably already multiplied by loss_weights -- confirm.
model.train_state.loss_train

array([7.75794347e-13, 1.00523031e-11])

In [139]:
# evaluate the trained network on the same points used for the data fit
x = bc.points
outputs = model.predict(x)

# The final output column is mu; the preceding columns are the wave-field
# components. NOTE(review): the reshapes assume bc.points enumerates the grid
# in the same C-order layout as u_true -- confirm in NDArrayBC.
u_pred, mu_pred = (
    outputs[:, :-1].reshape(u_true.shape),
    outputs[:, -1:].reshape(u_true.shape[:-1]),
)

# Laplacian of the predicted wave field, evaluated through model.predict's
# operator hook; dim=1 presumably skips input dimension 0 (frequency)
lu_pred = model.predict(
    x,
    operator=lambda x, y: mre_pinn.pde.laplacian(y[:, :-1], x, dim=1),
)
lu_pred = lu_pred.reshape(u_true.shape)

In [140]:
%autoreload

# compare true and predicted wave fields (real part) on a shared color scale
w_max = 0.001
w_map = mre_pinn.visual.wave_color_map()
wave_kws = {'cmap': w_map, 'vmin': -w_max, 'vmax': w_max}

mre_pinn.visual.NDArrayViewer(
    u_true.real, labels=u_true.dims, dpi=50/ds, **wave_kws
)
mre_pinn.visual.NDArrayViewer(
    u_pred.real, labels=u_true.dims, dpi=50/ds, **wave_kws
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x7f78fb3c2160>

In [134]:
# compute discrete laplacian of wave field

def discrete_laplacian(u, resolution=1, dim=0):
    '''
    Discrete Laplacian operator, applied separately to each component.

    The last axis of `u` indexes field components. Axes `dim` through
    `len(u.shape) - 2` are treated as spatial axes with uniform grid spacing
    `resolution`. Each second derivative is approximated by applying
    np.gradient twice, so interior points effectively use a stencil of width
    2 * resolution.

    Returns an ndarray with the same shape as `u`, provided at least one
    spatial axis is selected.
    '''
    spatial_axes = range(dim, len(u.shape) - 1)

    def second_derivative(field, axis):
        first = np.gradient(field, axis=axis)
        return np.gradient(first, axis=axis) / resolution**2

    per_component = [
        sum((second_derivative(u[..., c], axis) for axis in spatial_axes), 0)
        for c in range(u.shape[-1])
    ]
    return np.stack(per_component, axis=-1)

# Grid spacing is taken from the coordinates, so it stays correct when the wave
# field was downsampled (ds > 1). A hard-coded 1e-3 would be wrong by a factor of ds.
dx = float(u_true.x[1] - u_true.x[0])

# discrete_laplacian uses a single spacing for every spatial axis; check it holds
for d in ('y', 'z'):
    if d in u_true.dims and u_true.sizes[d] > 1:
        assert np.isclose(float(u_true[d][1] - u_true[d][0]), dx), d

# dim=1 skips axis 0 (frequency); the remaining non-component axes are spatial
Lu_true = discrete_laplacian(u_true, resolution=dx, dim=1)
Lu_pred = discrete_laplacian(u_pred, resolution=dx, dim=1)

# display wave field laplacians
L_max = 10 #1e3
laplace_kws = dict(cmap=w_map, vmin=-L_max, vmax=L_max)

mre_pinn.visual.NDArrayViewer(Lu_true.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)
#mre_pinn.visual.NDArrayViewer(Lu_pred.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)
mre_pinn.visual.NDArrayViewer(lu_pred.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x7f78f9c14970>

In [135]:
from scipy.ndimage import gaussian_filter

# compare elastograms

def direct_inversion(u, Lu, omega, rho=1):
    '''
    Pointwise algebraic inversion: solve mu * Lu = -rho * (2*pi*omega)^2 * u for mu.

    `omega` is multiplied by 2*pi here, so it is expected as a frequency in Hz
    rather than an angular frequency. Estimates from the components along the
    last axis are averaged.
    '''
    angular_freq = 2*np.pi*omega
    mu_per_component = -rho * angular_freq**2 * u / Lu
    return mu_per_component.mean(axis=-1)

# float() requires a single selected frequency; the value is in Hz
# (direct_inversion multiplies it by 2*pi)
omega = float(u_true.frequency)

# Shear modulus estimates:
# - mu_data: measured field with the finite-difference Laplacian
# - mu_u_Lu: predicted field with the finite-difference Laplacian
# - mu_u_lu: predicted field with the operator-based Laplacian
mu_data = direct_inversion(u_true, Lu_true, omega)
mu_u_Lu = direct_inversion(u_pred, Lu_pred, omega)
mu_u_lu = direct_inversion(u_pred, lu_pred, omega)

# display elastogram: data-based estimate vs. the network's mu output
e_max = 25
e_map = mre_pinn.visual.elast_color_map()
elast_kws = {'cmap': e_map, 'vmin': 0, 'vmax': e_max}

mre_pinn.visual.NDArrayViewer(mu_data.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_Lu.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_lu.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
mre_pinn.visual.NDArrayViewer(mu_pred.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)

<IPython.core.display.Javascript object>

  slider = matplotlib.widgets.Slider(


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x7f78f9c0bfd0>