In [1]:
# interactive matplotlib figures (nbagg backend)
%matplotlib notebook
# auto-reload only the modules registered with %aimport (mre_pinn, in the next cell)
%load_ext autoreload
%autoreload 1
# record where the notebook runs: later cells use paths relative to this directory
!hostname
!pwd

v003.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/MRE-PINN/notebooks


In [2]:
import sys, os
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt

# select the DeepXDE backend; this must happen before deepxde is imported below
# NOTE(review): current DeepXDE documents the variable as `DDE_BACKEND`;
# `DDEBACKEND` is the older spelling -- confirm the installed version honors it
os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

# make the local mre_pinn package importable, then register it for %autoreload 1
sys.path.append('..')
%aimport mre_pinn

# sanity check that a GPU is visible to torch
torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch

Loading /ocean/projects/asc170022p/mtragoza/MRE-PINN/notebooks/../mre_pinn/__init__.py


True

In [100]:
# load the MATLAB file containing the phantom wave field
data_root = '../data/BIOQIC'
mat_base = 'four_target_phantom.mat'
mat_file = f'{data_root}/{mat_base}'

data, rev_axes = mre_pinn.data.load_mat_data(mat_file, verbose=True)
rev_axes

Loading ../data/BIOQIC/four_target_phantom.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    u_ft: <class 'numpy.ndarray'> (100, 80, 10, 3, 6) complex128


True

In [408]:
# convert to xarray and add metadata
# .T reverses the axis order of the raw array, giving (frequency, component, z, x, y)
u_true = data['u_ft'].T
u_dims = ['frequency', 'component', 'z', 'x', 'y']
u_coords = {
    'frequency': 50 + np.arange(u_true.shape[0]) * 10, # Hz
    'x': np.arange(u_true.shape[3]) * 1, # mm
    'y': np.arange(u_true.shape[4]) * 1, # mm
    'z': np.arange(u_true.shape[2]) * 1, # mm
    'component': ['y', 'x', 'z'],
}
u_true = xr.DataArray(u_true, dims=u_dims, coords=u_coords) # mm?
u_true = u_true.transpose('frequency', 'x', 'y', 'z', 'component')

# downsampling: ds must evenly divide the x, y and z sizes
# (coarsen's default boundary='exact' raises otherwise)
ds = 1
u_true = u_true.coarsen(x=ds, y=ds, z=ds).mean()

# single frequency 2D
# take the first z slice by *position*: coarsening also averages the coordinates,
# so for ds > 1 the z labels become e.g. 0.5, 2.5, ... and sel(z=0) would raise KeyError
u_true = u_true.sel(frequency=[60], component=['x', 'y']).isel(z=0)

# single frequency 3D
#u_true = u_true.sel(frequency=[60], component=['x', 'y', 'z'])

# multifrequency 3D
#u_true = u_true.sel(component=['x', 'y', 'z'])

print(np.prod(u_true.shape))
u_true

16000


In [172]:
# display the true wave field: real and imaginary parts on one symmetric color scale
w_map = mre_pinn.visual.wave_color_map()
w_max = 1e-3
wave_kws = {'cmap': w_map, 'vmin': -w_max, 'vmax': w_max}

mre_pinn.visual.NDArrayViewer(u_true.real, dpi=50/ds, **wave_kws)
mre_pinn.visual.NDArrayViewer(u_true.imag, dpi=50/ds, **wave_kws)

<IPython.core.display.Javascript object>

  #slider.poly.set_visible(False)


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x154cc490d220>

In [321]:
%autoreload

batch_size = 128 # collocation points
num_domain = 128 # PDE domain samples

# convert to point set boundary condition
bc = mre_pinn.data.NDArrayBC(u_true, batch_size=batch_size)
x = bc.points
# last value = points / batch_size; a non-integer (e.g. 62.5) presumably means the
# final batch is partial -- confirm against NDArrayBC's batching
x.shape, bc.values.shape, x.shape[0] / batch_size

((8000, 3), torch.Size([8000, 2]), 62.5)

In [322]:
%autoreload

# set up PDE with geometry and boundary condition
residual = mre_pinn.pde.HelmholtzPDE(detach=False)

# for single frequency, we need to add eps to the geometry range to avoid zero volume
eps = np.zeros(x.shape[1])
eps[0] = 1e-5
geometry = deepxde.geometry.Hypercube(x.min(axis=0), x.max(axis=0) + eps)

pde = deepxde.data.PDE(geometry, residual, bc, num_domain=num_domain)
pde.train_x.shape



(256, 3)

In [384]:
%autoreload

# define model architecture
parallel = False

if parallel:
    net = mre_pinn.model.Parallel([
        mre_pinn.model.PINN(
            n_input=x.shape[1],
            n_layers=5,
            n_hidden=8,
            n_output=n_output,
            activ_fn=torch.sin,
            complex=True,
            dense=True,
            omega0=8
        ) for n_output in [u_true.shape[-1], 1] # u and mu
    ])
else:
    net = mre_pinn.model.PINN(
        n_input=x.shape[1],
        n_layers=5,
        n_hidden=128,
        n_output=u_true.shape[-1] + 1, # u and mu
        activ_fn=torch.sin,
        complex=True,
        dense=True,
        omega0=8
    )
net

PINN(
  (linear0): Linear(in_features=3, out_features=128, bias=True)
  (linear1): Linear(in_features=131, out_features=128, bias=True)
  (linear2): Linear(in_features=259, out_features=128, bias=True)
  (linear3): Linear(in_features=387, out_features=128, bias=True)
  (linear4): Linear(in_features=515, out_features=6, bias=True)
)

In [430]:
import seaborn as sns

# establish data and model weight distribution

# standardize inputs to [-1, 1]
x = bc.points
x_loc = np.mean(x, axis=0)
x_scale = (np.max(x, axis=0) - np.min(x, axis=0)) / 2

# avoid division by zero on any axis with zero extent (e.g. a single frequency)
x_scale[x_scale == 0] = 1

# normalize outputs with mean and std
u = mre_pinn.model.as_real(bc.values).cpu().numpy()
u_loc = np.mean(u, axis=0)
u_scale = np.std(u, axis=0)
print(u_loc.shape)

# NOTE(review): assumed prior location/scale for the stiffness output (two entries,
# presumably real and imaginary parts of mu) -- confirm units
mu_loc = [5e6, 5e6]
mu_scale = [5e6, 5e6]

output_loc = np.append(u_loc, mu_loc)
output_scale = np.append(u_scale, mu_scale)

# initialize model weights
if parallel:
    net[0].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=u_loc, output_scale=u_scale)
    net[1].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=mu_loc, output_scale=mu_scale)
else:
    net.init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=output_loc, output_scale=output_scale)
    print(output_loc.shape)

# investigate data and model output distribution

outputs = net.forward(torch.as_tensor(x).requires_grad_(False)).cpu().detach().numpy()
u_pred  = outputs[:,:-1]
# interleave real/imag parts: columns become [re_0, im_0, re_1, im_1, ...]
# NOTE(review): confirm this matches the column order produced by as_real(),
# otherwise the per-column loc/scale below are paired with the wrong columns
u_pred  = np.stack([u_pred.real, u_pred.imag], axis=-1).reshape(-1, 2*u_pred.shape[1])
mu_pred = outputs[:,-1:]
print(outputs.shape, u_pred.shape)

fig, axes = mre_pinn.visual.subplot_grid(3, 2, 2, 3, space=0.3, pad=[0.9, 0.4, 0.5, 0.4])

# x: raw (left) and standardized (right), excluding the frequency column
sns.histplot(x[:,1:], bins=20, label='x', ax=axes[0,0])
sns.histplot((x[:,1:] - x_loc[1:]) / x_scale[1:], bins=20, label='x', ax=axes[0,1])

# u_pred: standardize every column with its own loc/scale, exactly as for u_true below
sns.histplot(u_pred, bins=20, ax=axes[1,0])
sns.histplot((u_pred - u_loc) / u_scale, bins=20, ax=axes[1,1])

# u_true
sns.histplot(u, bins=20, ax=axes[2,0])
sns.histplot((u - u_loc) / u_scale, bins=20, ax=axes[2,1])

u_pred.shape

(4,)
(6,)
(8000, 3) (8000, 4)


<IPython.core.display.Javascript object>

(8000, 4)

In [431]:
# normalization constants for the relative-error losses:
# u_norm is the mean per-point magnitude of the measured wave field; hh_norm scales it
# by (2*pi*f)^2, roughly the size of the (2*pi*f)^2 * u term in the Helmholtz residual.
# The frequency is read from the data (highest selected one) instead of being hard-coded.
hh_freq = float(u_true.frequency.max())  # Hz
u_norm = torch.norm(bc.values, dim=-1).mean()
hh_norm = (2*np.pi*hh_freq)**2 * u_norm
u_norm, hh_norm

(tensor(0.0011, dtype=torch.float64), tensor(154.4801, dtype=torch.float64))

In [455]:
# reload modules now (IPython autoreload) so the latest mre_pinn code is used below
%autoreload

def my_loss(norm):
    '''
    Build a mean relative L2 error loss.

    Args:
        norm: reference magnitude that the mean per-point error is divided by
            (a Python scalar or 0-d tensor).
    Returns:
        loss(y_true, y_pred) -> scalar tensor: the L2 norm of the error at each
        point (taken over the last axis), averaged over points, divided by `norm`.
        For 1-D inputs the last axis is the only axis, so this is the total norm.
    '''
    def loss(y_true, y_pred):
        # reduce over the last axis only, so the mean is a genuine per-point average
        # (a norm over the whole tensor is a single scalar and the mean would be a no-op)
        point_err = torch.norm(y_true - y_pred, p=2, dim=-1)
        return torch.mean(point_err) / norm
    return loss

# PDE residual is measured relative to hh_norm, data misfit relative to u_norm
pde_loss, data_loss = my_loss(hh_norm), my_loss(u_norm)

def u_pred_norm(y_true, y_pred):
    '''
    Mean L2 norm of the predicted displacements.

    Takes the norm of each row over all columns except the last one
    (the stiffness output), then averages over rows. `y_true` is unused.
    '''
    return np.mean(np.linalg.norm(y_pred[:,:-1], axis=-1))

def mu_pred_norm(y_true, y_pred):
    '''
    Mean magnitude (L2 norm) of the predicted stiffness.

    Uses only the last column of `y_pred`, so this is the mean of its
    absolute values. `y_true` is unused.
    '''
    return np.mean(np.linalg.norm(y_pred[:,-1:], axis=-1))

model = deepxde.Model(pde, net)
# data misfit is weighted 100x relative to the PDE residual
model.compile(
    optimizer='adam',
    lr=1e-5,
    loss=[pde_loss, data_loss],
    loss_weights=[1, 100],
    metrics=[mu_pred_norm, u_pred_norm],
)

Compiling model...
'compile' took 0.000162 s



In [None]:
%autoreload
deepxde.display.training_display = mre_pinn.visual.TrainingPlot(
    losses=['pde_loss', 'data_loss'],
    metrics=['mu_norm', 'u_norm']
)
try:
    model.train(1000000, display_every=10)
except KeyboardInterrupt as e:
    print('Interrupt', file=sys.stderr)

Training model...



<IPython.core.display.Javascript object>

In [None]:
# model predictions at the data points, reshaped back onto the wave-field grid
x = bc.points
outputs = model.predict(x)
u_pred = outputs[:, :-1].reshape(u_true.shape)        # wave field components
mu_pred = outputs[:, -1:].reshape(u_true.shape[:-1])  # stiffness (last output)

def laplacian_of_u(inputs, preds):
    '''Laplacian of the predicted wave field (all outputs but the last) w.r.t. the inputs.'''
    # NOTE(review): dim=1 presumably skips input column 0 (frequency) -- confirm in mre_pinn.pde
    return mre_pinn.pde.laplacian(preds[:, :-1], inputs, dim=1)

lu_pred = model.predict(x, operator=laplacian_of_u).reshape(u_true.shape)

In [450]:
%autoreload

# display wave field
w_map = mre_pinn.visual.wave_color_map()
w_max = 0.001
wave_kws = dict(cmap=w_map, vmin=-w_max, vmax=w_max)

mre_pinn.visual.NDArrayViewer(u_true.real, labels=u_true.dims, dpi=50/ds, **wave_kws)
mre_pinn.visual.NDArrayViewer(u_pred.real, labels=u_true.dims, dpi=50/ds, **wave_kws)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x154d0c9ee8b0>

In [451]:
# compute discrete laplacian of wave field

def discrete_laplacian(u, resolution=1, dim=0):
    '''
    Finite-difference Laplacian of a vector field.

    The last axis of `u` indexes vector components. Every axis from `dim` up to
    (but not including) the last is treated as a spatial axis. Each second
    derivative is approximated by applying np.gradient twice and dividing by
    resolution**2; these are summed separately for each component.

    Args:
        u: array-like field with components on the last axis.
        resolution: grid spacing, assumed equal along all spatial axes.
        dim: index of the first spatial axis (leading axes are left alone).
    Returns:
        np.ndarray with the same shape as `u` (when there is at least one
        spatial axis).
    '''
    spatial_axes = range(dim, len(u.shape) - 1)

    def component_laplacian(field):
        total = 0
        for axis in spatial_axes:
            total += np.gradient(np.gradient(field, axis=axis), axis=axis) / resolution**2
        return total

    return np.stack(
        [component_laplacian(u[..., c]) for c in range(u.shape[-1])],
        axis=-1,
    )

# grid spacing (mm) of the possibly downsampled wave field, read from its x coordinates
# so it stays correct when ds > 1 (coarsening averages the coordinates as well);
# x and y are built with the same step and coarsened by the same ds, so one scalar is used
dx = float(u_true.x[1] - u_true.x[0])
Lu_true = discrete_laplacian(u_true, resolution=dx, dim=1)
Lu_pred = discrete_laplacian(u_pred, resolution=dx, dim=1)

# display wave field laplacians
L_max = 50e-6
laplace_kws = dict(cmap=w_map, vmin=-L_max, vmax=L_max)

mre_pinn.visual.NDArrayViewer(Lu_true.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)
#mre_pinn.visual.NDArrayViewer(Lu_pred.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)
mre_pinn.visual.NDArrayViewer(lu_pred.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x154cb28f42e0>

In [452]:
from scipy.ndimage import gaussian_filter

# compare elastograms

def direct_inversion(u, Lu, omega, rho=1):
    '''
    Algebraic Helmholtz inversion for the shear modulus.

    Computes -rho * (2*pi*omega)**2 * u / Lu elementwise and averages the
    estimate over the last (component) axis.

    Args:
        u: wave field, components on the last axis.
        Lu: Laplacian of `u`, same shape.
        omega: frequency in Hz -- the 2*pi factor is applied here.
        rho: density.
    Returns:
        Estimate with the last axis of `u` reduced away. Points where Lu == 0
        yield inf/nan.
    '''
    angular_freq = 2*np.pi*omega
    ratio = -rho * angular_freq**2 * u / Lu
    return ratio.mean(axis=-1)

# frequency in Hz (direct_inversion applies the 2*pi factor);
# float() requires exactly one selected frequency
omega = float(u_true.frequency)
mu_data = direct_inversion(u_true, Lu_true, omega)  # measured field, discrete Laplacian
mu_u_Lu = direct_inversion(u_pred, Lu_pred, omega)  # predicted field, discrete Laplacian
mu_u_lu = direct_inversion(u_pred, lu_pred, omega)  # predicted field, Laplacian via model.predict

# display elastograms on a shared [0, e_max] color scale
e_map = mre_pinn.visual.elast_color_map()
e_max = 25e6
elast_kws = {'cmap': e_map, 'vmin': 0, 'vmax': e_max}

mre_pinn.visual.NDArrayViewer(mu_data.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_Lu.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_lu.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)
mre_pinn.visual.NDArrayViewer(mu_pred.real, labels=u_true.dims[:-1], dpi=50/ds, **elast_kws)

<IPython.core.display.Javascript object>

  slider = matplotlib.widgets.Slider(


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x154c9fd0f3a0>