In [1]:
# notebook setup: interactive (nbagg) matplotlib figures and module auto-reloading
%matplotlib notebook
%load_ext autoreload
# mode 1: only modules registered via %aimport (see next cell) are reloaded before each cell runs
%autoreload 1
# record which host and working directory this notebook was executed on
!hostname
!pwd

v016.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/MRE-PINN/notebooks


In [2]:
import sys, os
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import tqdm

# select the deepxde backend; this must happen before deepxde is imported
# NOTE(review): recent DeepXDE documentation names this variable DDE_BACKEND;
# DDEBACKEND is the older spelling — confirm the installed version still honors it
os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

# make the local mre_pinn package importable and register it with %autoreload
sys.path.append('..')
%aimport mre_pinn

# report whether torch can see a CUDA device
torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch

Loading /ocean/projects/asc170022p/mtragoza/MRE-PINN/notebooks/../mre_pinn/__init__.py


True

In [60]:
# load the true wave image
data_root = '../data/BIOQIC'
wave_base = 'four_target_phantom.mat'
wave_file = os.path.join(data_root, wave_base)
wave_data, _ = mre_pinn.data.load_mat_data(wave_file)

# convert to xarray and add metadata
# after .T the raw array is laid out as (frequency, component, z, x, y), per u_dims below
u_raw = wave_data['u_ft'].T
u_dims = ['frequency', 'component', 'z', 'x', 'y']
dx = 1e-3 # spatial resolution in meters
u_coords = {
    'frequency': np.linspace(50, 100, u_raw.shape[0]), # Hz
    'x': np.arange(u_raw.shape[3]) * dx,
    'y': np.arange(u_raw.shape[4]) * dx,
    'z': np.arange(u_raw.shape[2]) * dx,
    'component': ['y', 'x', 'z'],
}
# NOTE(review): the extra factor of dx presumably rescales the stored displacement units — confirm
u_true = xr.DataArray(u_raw, dims=u_dims, coords=u_coords) * dx
u_true = u_true.transpose('frequency', 'x', 'y', 'z', 'component')

# downsampling (block mean over ds x ds x ds voxels; coordinates are averaged too,
# so the grid spacing becomes ds * dx)
ds = 1
u_true = u_true.coarsen(x=ds, y=ds, z=ds).mean()

freq = 80
z = 0
y = 75 * dx

# Spatial coordinates are floats that change under coarsening, so the spatial
# positions are picked with nearest-neighbour lookup instead of exact matches.
# Frequency and component labels are still matched exactly.

# single frequency 1D
u_true = u_true.sel(frequency=[freq], component=['z']).sel(y=y, z=z, method='nearest')

# single frequency 2D
#u_true = u_true.sel(frequency=[freq], component=['x', 'y']).sel(z=z, method='nearest')

# single frequency 3D
#u_true = u_true.sel(frequency=[freq], component=['x', 'y', 'z'])

# multifrequency 3D
#u_true = u_true.sel(component=['x', 'y', 'z'])

# number of sample points in the selected field
print(u_true.size)
u_true

80


In [61]:
# load the true elastogram
elast_base = 'fem_box_ground_truth.npy'
elast_file = os.path.join(data_root, elast_base)

print(f'Loading {elast_file}')
mu_raw = np.load(elast_file)
print(mu_raw.shape, mu_raw.dtype)

# convert to xarray and add metadata
# x/y are sampled at dx/4 and shifted by 3/8 dx: the mean of samples 4k..4k+3
# is then (k*dx + 3/8 dx) - 3/8 dx = k*dx, i.e. each 4-sample block lands on
# the wave-field grid after the coarsening below
mu_dims = ['frequency', 'z', 'x', 'y']
mu_coords = {
    'frequency': np.linspace(50, 100, mu_raw.shape[0]), # Hz
    'x': np.arange(mu_raw.shape[2]) * dx / 4 - dx * 3 / 8,
    'y': np.arange(mu_raw.shape[3]) * dx / 4 - dx * 3 / 8,
    'z': np.arange(mu_raw.shape[1]) * dx,
}
mu_true = xr.DataArray(mu_raw, dims=mu_dims, coords=mu_coords) # Pa
mu_true = mu_true.transpose('frequency', 'x', 'y', 'z')

# downsampling onto the (possibly coarsened) wave-field grid
mu_true = mu_true.coarsen(x=4*ds, y=4*ds, z=ds).mean()

# the averaged coordinates are only approximately k*dx in floating point,
# so spatial positions are picked with nearest-neighbour lookup

# single frequency 1D
mu_true = mu_true.sel(frequency=[freq]).sel(y=y, z=z, method='nearest')

# single frequency 2D
#mu_true = mu_true.sel(frequency=[freq]).sel(z=z, method='nearest')

mu_true

Loading ../data/BIOQIC/fem_box_ground_truth.npy
(6, 10, 320, 400) complex128


In [70]:
%autoreload

# display true wave field, elastogram, and direct inversion result
w_map = mre_pinn.visual.wave_color_map()
w_max = 1e-5
wave_kws = dict(cmap=w_map, vmin=-w_max, vmax=w_max)

L_max = 1e0
laplace_kws = dict(cmap=w_map, vmin=-L_max, vmax=L_max)

e_map = mre_pinn.visual.elast_color_map()
e_max = 15e3
elast_kws = dict(cmap=e_map, vmin=0, vmax=e_max)

Lu_true = mre_pinn.discrete.laplacian(u_true, resolution=dx, dim=1)
Lu_true = xr.DataArray(Lu_true, dims=u_true.dims, coords=u_true.coords)

omega = u_true.frequency.to_numpy().reshape([(-1, 1)[i > 0] for i in range(u_true.ndim)])
mu_u_Lu = mre_pinn.discrete.helmholtz_inversion(u_true, Lu_true, omega)

#mre_pinn.visual.NDArrayViewer(u_true.real, dpi=50 / ds, **wave_kws)
mre_pinn.visual.NDArrayViewer(u_true.real, y=None, dpi=50 / ds, **wave_kws)
mre_pinn.visual.NDArrayViewer(Lu_true.real, y=None, dpi=50 / ds, **laplace_kws)
mre_pinn.visual.NDArrayViewer(mu_u_Lu.real, y=None, dpi=50 / ds, **elast_kws)
mre_pinn.visual.NDArrayViewer(mu_true.real, y=None, dpi=50 / ds, **elast_kws)

(1, 80, 1) ['frequency', 'x', 'component']
2 1


<IPython.core.display.Javascript object>

(1, 80, 1) ['frequency', 'x', 'component']
2 1


  ax,


<IPython.core.display.Javascript object>

(1, 80) ['frequency', 'x']
1 1


<IPython.core.display.Javascript object>

(1, 80) ['frequency', 'x']
1 1


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x1530903408b0>

In [118]:
%autoreload

batch_size = 80 # training data points
num_domain = 48 # PDE domain samples

# create point set boundary condition
bc = mre_pinn.data.NDArrayBC(u_true, batch_size=batch_size)
x = bc.points
x.shape, bc.values.shape, x.shape[0] / batch_size

((80, 2), torch.Size([80, 1]), 1.0)

In [156]:
%autoreload

# set up PDE with geometry and boundary condition
wave_eq = mre_pinn.pde.WaveEquation(detach=True)

# for single frequency, we need to add eps to the geometry range to avoid zero volume
eps = np.zeros(x.shape[1])
eps[0] = 1e-5
geometry = deepxde.geometry.Hypercube(x.min(axis=0), x.max(axis=0) + eps)

pde = deepxde.data.PDE(geometry, wave_eq, bc, num_domain=num_domain)
train_x = np.array(pde.train_x)
train_x.shape



(128, 2)

In [357]:
%autoreload

# number of displacement components
n_u = u_true.shape[-1]
print(n_u)

# define model architecture
parallel = True
n_outputs = [n_u, 1] # u and mu
idxs = [0] + list(np.cumsum(n_outputs))

if parallel:
    net = mre_pinn.model.Parallel([
        mre_pinn.model.PINN(
            n_input=x.shape[1],
            n_layers=5,
            n_hidden=8,
            n_output=n_output,
            activ_fn=torch.sin,
            complex=True,
            dense=True,
            omega0=8
        ) for n_output in n_outputs
    ])
else:
    net = mre_pinn.model.PINN(
        n_input=x.shape[1],
        n_layers=10,
        n_hidden=128,
        n_output=sum(n_outputs),
        activ_fn=torch.sin,
        complex=True,
        dense=False,
        omega0=64
    )
net

1


Parallel(
  (0): PINN(
    (linear0): Linear(in_features=2, out_features=8, bias=True)
    (linear1): Linear(in_features=10, out_features=8, bias=True)
    (linear2): Linear(in_features=18, out_features=8, bias=True)
    (linear3): Linear(in_features=26, out_features=8, bias=True)
    (linear4): Linear(in_features=34, out_features=2, bias=True)
  )
  (1): PINN(
    (linear0): Linear(in_features=2, out_features=8, bias=True)
    (linear1): Linear(in_features=10, out_features=8, bias=True)
    (linear2): Linear(in_features=18, out_features=8, bias=True)
    (linear3): Linear(in_features=26, out_features=8, bias=True)
    (linear4): Linear(in_features=34, out_features=2, bias=True)
  )
)

In [358]:
import seaborn as sns

# establish data and model weight distributions

# standardize inputs to [-1, 1]
x = bc.points
x_loc = np.mean(x, axis=0)
x_scale = (np.max(x, axis=0) - np.min(x, axis=0)) / 2

# avoid division by zero along any degenerate input axis
# (e.g. the frequency axis, column 0, when a single frequency is selected)
x_scale[x_scale == 0] = 1

# normalize outputs with mean and std
# (as_real presumably splits each complex column into real/imag columns)
u = mre_pinn.model.as_real(bc.values).cpu().numpy()
u_loc = np.mean(u, axis=0)
u_scale = np.std(u, axis=0)
print(u_loc, u_scale)

# the grid spacing after coarsening by `ds` is ds * dx
Lu = mre_pinn.discrete.laplacian(u_true, resolution=dx * ds, dim=1).reshape(-1, n_u)
Lu = mre_pinn.model.as_real(torch.as_tensor(Lu)).cpu().numpy()
Lu_loc = np.mean(Lu, axis=0)
Lu_scale = np.std(Lu, axis=0)
print(Lu_loc, Lu_scale)

mu = torch.as_tensor(mu_true.to_numpy()).reshape(-1, 1)
mu = mre_pinn.model.as_real(mu).cpu().numpy()
mu_loc   = np.mean(mu, axis=0)
mu_scale = np.std(mu, axis=0)
print(mu_loc, mu_scale)
# NOTE(review): a constant column of mu gives a near-zero std here, which is then
# passed to init_weights as an output scale — confirm this is intended

# the non-parallel model predicts u and mu jointly (see n_outputs), so its
# output statistics must come from u and mu, not from the Laplacian
output_loc = np.append(u_loc, mu_loc)
output_scale = np.append(u_scale, mu_scale)

# initialize model weights
if parallel:
    net[0].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=u_loc, output_scale=u_scale)
    net[1].init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=mu_loc, output_scale=mu_scale)
else:
    net.init_weights(input_loc=x_loc, input_scale=x_scale, output_loc=output_loc, output_scale=output_scale)
    print(output_loc.shape)

# plot distributions of model input and output

outputs = net.forward(torch.as_tensor(train_x).requires_grad_(False)).cpu().detach().numpy()
u_pred = outputs[:,idxs[0]:idxs[1]]
mu_pred = outputs[:,idxs[1]:idxs[2]]

# interleave real and imaginary parts per component, presumably matching the
# column layout of u_loc / u_scale produced by as_real
u_pred = np.stack([u_pred.real, u_pred.imag], axis=-1).reshape(-1, 2*u_pred.shape[1])
print(outputs.shape, u_pred.shape, mu_pred.shape)

fig, axes = mre_pinn.visual.subplot_grid(3, 2, 2, 3, space=0.3, pad=[0.9, 0.4, 0.5, 0.4])

# x (column 0, the frequency axis, is skipped)
axes[0,0].set_title('x')
axes[0,1].set_title('standardized x')
sns.histplot(x[:,1:], bins=20, label='x', ax=axes[0,0])
sns.histplot((x[:,1:] - x_loc[1:]) / x_scale[1:], bins=20, label='x', ax=axes[0,1])

# u_pred: each column is normalized with its own mean and std
axes[1,0].set_title('u_pred')
axes[1,1].set_title('normalized u_pred')
sns.histplot(u_pred, bins=20, ax=axes[1,0])
sns.histplot((u_pred - u_loc) / u_scale, bins=20, ax=axes[1,1])

# u_true
axes[2,0].set_title('u_true')
axes[2,1].set_title('normalized u_true')
sns.histplot(u, bins=20, ax=axes[2,0])
sns.histplot((u - u_loc) / u_scale, bins=20, ax=axes[2,1])

u_pred.shape

[ 1.27296286e-07 -9.91309226e-07] [4.08122861e-06 4.83839765e-06]
[0.03283572 0.02858284] [0.31549624 0.37491125]
[4688.50331458  502.65482457] [2.91687543e+03 3.41060513e-13]
(128, 2) (128, 2) (128, 1)


<IPython.core.display.Javascript object>

(128, 2)

In [359]:
%autoreload

# create normalized loss functions
u_norm = torch.norm(bc.values, dim=-1).mean().detach()
print(u_norm)

def normalized_L2_loss(norm):
    '''
    Build a loss that averages the per-row L2 error divided by a fixed constant.

    Args:
        norm: Scalar (or broadcastable tensor) that each row's error is divided by.
    Returns:
        A function loss_fn(y_true, y_pred) -> scalar tensor equal to
        mean(||y_true - y_pred||_2 / norm), with the norm taken over the last axis.
    '''
    def loss_fn(y_true, y_pred):
        row_errors = torch.norm(y_true - y_pred, dim=-1)
        return (row_errors / norm).mean()
    return loss_fn

# data loss: mean L2 residual divided by the mean observed displacement norm
loss = normalized_L2_loss(u_norm)

def predicted_L2_norm(idx, n):
    '''
    Build a metric that reports the mean L2 norm of one slice of the predictions.

    Args:
        idx: First column of the slice in y_pred.
        n: Number of columns in the slice.
    Returns:
        A function norm(y_true, y_pred) -> float equal to the mean over rows of
        ||y_pred[:, idx:idx+n]||_2. y_true is accepted for the metric signature
        but not used.
    '''
    def norm(y_true, y_pred):
        return np.mean(np.linalg.norm(y_pred[:, idx:idx + n], axis=-1))
    return norm

# one prediction-norm metric per output head (u first, then mu); zip stops
# after len(n_outputs) entries, so the trailing stop offset in idxs is unused
metrics = [predicted_L2_norm(start, n) for start, n in zip(idxs, n_outputs)]

model = deepxde.Model(pde, net)
# NOTE(review): loss_weights presumably map to [PDE residual, data] in that order — confirm
model.compile(
    optimizer='adam',
    lr=1e-3,
    loss=loss,
    loss_weights=[1e-8, 1],
    metrics=metrics,
)

tensor(5.4517e-06, dtype=torch.float64)
0 1
1 1
Compiling model...
'compile' took 0.000169 s



In [360]:
%autoreload

# route deepxde's training progress display through mre_pinn's TrainingPlot,
# labelling the two loss terms and the two metrics compiled above
deepxde.display.training_display = mre_pinn.visual.TrainingPlot(
    losses=['pde_loss', 'data_loss'],
    metrics=['u_norm', 'mu_norm']
)

class BatchResampler(deepxde.callbacks.Callback):
    '''
    Training callback that asks the data object for fresh training points
    after every batch.
    '''
    def on_batch_end(self):
        data = self.model.data
        # NOTE(review): clearing train_x_bc presumably forces the boundary-condition
        # points to be re-drawn as well — confirm against deepxde.data.PDE
        data.train_x_bc = None
        data.resample_train_points()

batch_resampler = BatchResampler()
try:
    model.train(1000, display_every=10, callbacks=[batch_resampler])
except KeyboardInterrupt:
    # a kernel interrupt stops training but lets the cell finish, leaving the
    # partially trained model available to later cells
    print('Interrupt', file=sys.stderr)

Training model...



<IPython.core.display.Javascript object>




Best model at step 950:
  train loss: 1.59e+00
  test loss: 2.96e+00
  test metric: [5.42e-06, 4.41e+03]

'train' took 24.321509 s



In [361]:
def minibatch(func, batch_size):
    '''
    Wrap func so it is evaluated on consecutive row-slices of its positional args.

    Every positional argument is sliced [start:start+batch_size] along axis 0
    (the row count is taken from the first one); keyword arguments are passed
    through unchanged. The per-slice results are concatenated along axis 0.
    '''
    def wrapper(*args, **kwargs):
        n_rows = args[0].shape[0]
        pieces = [
            func(*(arg[start:start + batch_size] for arg in args), **kwargs)
            for start in tqdm.tqdm(range(0, n_rows, batch_size))
        ]
        return np.concatenate(pieces, axis=0)
    return wrapper

def laplacian_u(x, outputs):
    '''
    Operator for model.predict: Laplacian of the displacement columns of the output.

    The u columns are selected with the global idxs offsets; deepxde's gradient
    cache is cleared afterwards so the next batch starts fresh.
    '''
    lap = mre_pinn.pde.laplacian(outputs[:, idxs[0]:idxs[1]], x, dim=1)
    deepxde.gradients.clear()
    return lap

batch_predict = minibatch(model.predict, batch_size)

# evaluate the trained network at the observed sample points
x = bc.points
outputs = batch_predict(x)
# split the stacked outputs into displacement and modulus, restoring grid shapes
# (mu has no component axis, hence the trailing axis is dropped)
u_pred, mu_pred = (
    outputs[:, idxs[0]:idxs[1]].reshape(u_true.shape),
    outputs[:, idxs[1]:idxs[2]].reshape(u_true.shape[:-1]),
)
# Laplacian of the predicted displacement, computed through the model by autograd
lu_pred = batch_predict(x, operator=laplacian_u).reshape(u_true.shape)

100%|██████████| 1/1 [00:00<00:00, 731.22it/s]
100%|██████████| 1/1 [00:00<00:00, 153.78it/s]


In [362]:
# display wave fields: real part of the true and predicted displacement on the same color scale
mre_pinn.visual.NDArrayViewer(u_true.real, y=None, labels=u_true.dims, dpi=50/ds, **wave_kws)
mre_pinn.visual.NDArrayViewer(u_pred.real, y=None, labels=u_true.dims, dpi=50/ds, **wave_kws)
# residual view (true - predicted), currently disabled:
#mre_pinn.visual.NDArrayViewer(u_true.real - u_pred.real, y=None, labels=u_true.dims, dpi=50/ds,cmap=w_map)

(1, 80, 1) ['frequency', 'x', 'component']
2 1


<IPython.core.display.Javascript object>

(1, 80, 1) ['frequency', 'x', 'component']
2 1


  slider = matplotlib.widgets.Slider(


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x153033657640>

In [363]:
# display wave field Laplacians

# discrete Laplacians use the actual grid spacing, which is ds * dx after coarsening
Lu_true = mre_pinn.discrete.laplacian(u_true, resolution=dx * ds, dim=1)
Lu_pred = mre_pinn.discrete.laplacian(u_pred, resolution=dx * ds, dim=1)

mre_pinn.visual.NDArrayViewer(Lu_true.real, y=None, labels=u_true.dims, dpi=50/ds, **laplace_kws)
#mre_pinn.visual.NDArrayViewer(Lu_pred.real, y=None, labels=u_true.dims, dpi=50/ds, **laplace_kws)
# autograd Laplacian of the model output (lu_pred) for comparison with the discrete one
mre_pinn.visual.NDArrayViewer(lu_pred.real, y=None, labels=u_true.dims, dpi=50/ds, **laplace_kws)
#mre_pinn.visual.NDArrayViewer(Lu_pred.real - lu_pred.real, labels=u_true.dims, dpi=50/ds, cmap=w_map)
#mre_pinn.visual.NDArrayViewer(lu_model.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)
#mre_pinn.visual.NDArrayViewer(lu_model.real - lu_pred.real, labels=u_true.dims, dpi=50/ds, **laplace_kws)

(1, 80, 1) ['frequency', 'x', 'component']
2 1


<IPython.core.display.Javascript object>

(1, 80, 1) ['frequency', 'x', 'component']
2 1


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x153032b82bb0>

In [364]:
# display reconstructed elastograms

# frequency reshaped to (n_freq, 1, ..., 1) so it broadcasts against the wave
# arrays for any number of spatial/component axes (1D, 2D or 3D selections)
# NOTE(review): these are frequencies in Hz; confirm whether helmholtz_inversion
# expects angular frequency (2*pi*f)
omega = u_true.frequency.to_numpy().reshape([-1] + [1] * (u_true.ndim - 1))
mu_data = mre_pinn.discrete.helmholtz_inversion(u_true, Lu_true, omega) #.mean(axis=0)
mu_u_Lu = mre_pinn.discrete.helmholtz_inversion(u_pred, Lu_pred, omega) #.mean(axis=0)
mu_u_lu = mre_pinn.discrete.helmholtz_inversion(u_pred, lu_pred, omega) #.mean(axis=0)
#mu_pred = mre_pinn.discrete.helmholtz_inversion(u_pred, lu_model, omega)

mre_pinn.visual.NDArrayViewer(mu_data.real, y=None, labels=u_true.dims[0:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_Lu.real, y=None, labels=u_true.dims[0:-1], dpi=50/ds, **elast_kws)
#mre_pinn.visual.NDArrayViewer(mu_u_lu.real, y=None, labels=u_true.dims[0:-1], dpi=50/ds, **elast_kws)
mre_pinn.visual.NDArrayViewer(mu_pred.real, y=None, labels=u_true.dims[0:-1], dpi=50/ds, **elast_kws)
mre_pinn.visual.NDArrayViewer(mu_true.real, y=None, labels=u_true.dims[0:-1], dpi=50/ds, **elast_kws)

#mre_pinn.visual.NDArrayViewer(mu_data.real - mu_pred.real, labels=u_true.dims[1:-1], dpi=50/ds, cmap=w_map)

(1, 80) ['frequency', 'x']
1 1


<IPython.core.display.Javascript object>

(1, 80) ['frequency', 'x']
1 1


  slider = matplotlib.widgets.Slider(


<IPython.core.display.Javascript object>

(1, 80) ['frequency', 'x']
1 1


<IPython.core.display.Javascript object>

<mre_pinn.visual.NDArrayViewer at 0x15303336aa90>