In [1]:
# Notebook environment setup (IPython magics and shell commands only).
# Interactive matplotlib figures via the `notebook` (nbagg) backend.
# NOTE(review): nbagg is not available in JupyterLab / Notebook 7; `%matplotlib widget`
# (ipympl) is the modern equivalent — confirm which frontend this notebook targets.
%matplotlib notebook
%load_ext autoreload
# Mode 1: automatically reload only modules registered with %aimport (mre_pinn, next cell).
%autoreload 1
# Record the compute node and working directory for provenance.
!hostname
!pwd

dv003.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks


In [2]:
import sys, os, pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

sys.path.append('..')
%aimport mre_pinn
xarray = mre_pinn.utils.as_xarray

torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch

Loading /ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks/../mre_pinn/__init__.py


True

In [3]:
%autoreload

# load the FEM box data set
data = mre_pinn.data.load_bioqic_fem_box_data('../data/BIOQIC')

# select data subset
data, ndim = mre_pinn.data.select_data_subset(
    data,
    downsample=False,
    frequency=80,
    x_slice=None,
    y_slice=75,
    z_slice=0
)

u  = data['u']
mu = data['mu']

data

Single frequency
1D


In [4]:
%autoreload

# direct Helmholtz inversion via discrete laplacian

Lu = mre_pinn.discrete.laplacian(u, resolution=1e-3, dim=1)
data['Lu'] = Lu

#omega = u.frequency.expand_dims(u.dims[1:], axis=range(1, u.ndim)).to_numpy()
Mu = mre_pinn.discrete.helmholtz_inversion(u, Lu, u.frequency.to_numpy())
data['Mu'] = Mu


In [62]:
%autoreload

# configure color maps
pct = 100

w_map = mre_pinn.visual.wave_color_map()
w_max = np.percentile(np.abs(u), pct) * 1.1
wave_kws = dict(cmap=w_map, vmin=-w_max, vmax=w_max)

L_max = np.percentile(np.abs(Lu), pct) * 1.1
laplace_kws = dict(cmap=w_map, vmin=-L_max, vmax=L_max)

e_map = mre_pinn.visual.elast_color_map()
e_max = np.percentile(np.abs(mu), pct) * 1.1
elast_kws = dict(cmap=e_map, vmin=-e_max, vmax=e_max)

# display true wave field and elastogram

x = 'x'
y = 'y' if ndim > 1 else None
hue = 'part'

mre_pinn.visual.XArrayViewer(u,  y=y, hue=hue, **wave_kws)
mre_pinn.visual.XArrayViewer(Lu, y=y, hue=hue, **laplace_kws)
mre_pinn.visual.XArrayViewer(Mu, y=y, hue=hue, **elast_kws)
mre_pinn.visual.XArrayViewer(mu, y=y, hue=hue, **elast_kws)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<mre_pinn.visual.XArrayViewer at 0x145fecfd1100>

In [63]:
%autoreload

# create point set boundary condition

batch_size = 80 # training data points
u_bc = mre_pinn.data.XArrayBC(u, batch_size=batch_size)

#bc.points = bc.points.astype(np.complex128)

x = u_bc.points
u_true = u_bc.values
print('x     ', type(x), tuple(x.shape), x.dtype)
print('u_true', type(u_true), tuple(u_true.shape), u_true.dtype)

x      <class 'numpy.ndarray'> (80, 2) float32
u_true <class 'torch.Tensor'> (80, 1) torch.complex64


In [64]:
%autoreload

# set up PDE with geometry and boundary condition
wave_eq = mre_pinn.pde.WaveEquation(detach=True, homogeneous=True)

# need to add eps to singleton dimensions to avoid zero volume
eps = np.where(np.array(u.shape[:-1]) > 1, 0, 1e-5)

geometry = deepxde.geometry.Hypercube(x.min(axis=0), x.max(axis=0) + eps)
pde = deepxde.data.PDE(geometry, wave_eq, u_bc, num_domain=batch_size)

train_x = np.array(pde.train_x)
train_x.shape

(160, 2)

In [65]:
# descriptive statistics of model inputs and outputs

# standardize inputs to [-1, 1]: center on the mean, divide by half the range
x = torch.as_tensor(u_bc.points)
x_loc   = x.mean(dim=0, keepdim=True)
x_scale = (x.max(dim=0, keepdim=True).values - x.min(dim=0, keepdim=True).values) / 2

print('x', tuple(x.shape))
print(x_loc)
print(x_scale)

x_scale[x_scale == 0] = 1 # avoid division by zero on constant (singleton) axes

# Normalize outputs using per-column mean and std of the real-valued view
# (as_real doubles the last axis: (80, 1) complex -> (80, 2) per the printed shapes).
u_true  = mre_pinn.utils.as_real(u_bc.values)
u_loc   = u_true.mean(dim=0, keepdim=True)
u_scale = u_true.std(dim=0, keepdim=True)

print('\nu_true', tuple(u_true.shape))
print(u_loc)
print(u_scale)

mu_true  = mre_pinn.utils.as_real(torch.as_tensor(mu.to_numpy()).reshape(-1, 1))
mu_loc   = mu_true.mean(dim=0, keepdim=True)
mu_scale = mu_true.std(dim=0, keepdim=True)

print('\nmu_true', tuple(mu_true.shape))
print(mu_loc)
print(mu_scale)

# A constant output column (e.g. the second column of mu_true, whose printed std
# is 0) has zero std. Dividing by it gives inf/NaN when standardizing, and
# multiplying by it in the output transform pins that column to its mean no matter
# what the network outputs. Use a unit scale instead, as done for x_scale above.
u_scale[u_scale == 0] = 1
mu_scale[mu_scale == 0] = 1

u_mu_loc   = torch.cat([u_loc, mu_loc], dim=1)
u_mu_scale = torch.cat([u_scale, mu_scale], dim=1)

x (80, 2)
tensor([[8.0000e+01, 3.9500e-02]])
tensor([[0.0000, 0.0395]])

u_true (80, 2)
tensor([[ 1.2730e-07, -9.9131e-07]])
tensor([[4.1070e-06, 4.8689e-06]])

mu_true (80, 2)
tensor([[4688.4461,  502.6548]], dtype=torch.float64)
tensor([[2935.2449,    0.0000]], dtype=torch.float64)


In [103]:
%autoreload

# define model architecture

n_input   = x.shape[1] # frequency and spatial dims
n_outputs = [ndim, 1]  # u and mu, which are both complex
parallel  = True       # separate nets for each output

# input transformation
input_fn = lambda x: (x - x_loc) / x_scale

# hidden layers and activation function(s)

def gauss(x):
    return torch.exp(-x**2)

omega0   = 32
n_layers = 2
n_hidden = 8
activ_fn = [torch.sin, gauss]

# output transformations

if parallel:
    net_outputs = [n * 2 for n in n_outputs]
    output_fns = [
        lambda x: mre_pinn.utils.as_complex(x * u_scale + u_loc),
        lambda x: mre_pinn.utils.as_complex(x * mu_scale + mu_loc)
    ]
else:
    net_outputs = [sum(n_outputs) * 2]
    output_fns = [
        lambda x: mre_pinn.utils.as_complex(x * u_mu_scale + u_mu_loc)
    ]

# store slice indices for model outputs
idxs = [0] + list(np.cumsum(n_outputs))

# construct the network(s)
nets = [
    mre_pinn.model.PINN(
        n_input=n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
        n_output=n_output,
        activ_fn=activ_fn,
        input_fn=input_fn,
        output_fn=output_fn,
        dense=False
    ) for n_output, output_fn in zip(net_outputs, output_fns)
]

# initialize model weights
for n in nets:
    n.init_weights(omega0)

if parallel:
    net = mre_pinn.model.Parallel(nets)
else:
    net = nets[0]
    
net.regularizer = None

net

Parallel(
  (0): PINN(
    (linear0_sin): Linear(in_features=2, out_features=8, bias=True)
    (linear0_gauss): Linear(in_features=2, out_features=8, bias=True)
    (linear1): Linear(in_features=8, out_features=2, bias=True)
  )
  (1): PINN(
    (linear0_sin): Linear(in_features=2, out_features=8, bias=True)
    (linear0_gauss): Linear(in_features=2, out_features=8, bias=True)
    (linear1): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [104]:
# Test a forward pass of the (untrained) network and compare the distributions
# of raw vs. standardized inputs and outputs.

x   = torch.as_tensor(u_bc.points).requires_grad_(False)
x_s = (x - x_loc) / x_scale

# Diagnostic only: skip building an autograd graph. Call the module itself
# (not .forward) so registered hooks run, matching OutputViewer's self.model.net(x).
with torch.no_grad():
    outputs = net(x)

u_pred   = mre_pinn.utils.as_real(outputs[:,idxs[0]:idxs[1]])
u_pred_s = (u_pred - u_loc) / u_scale
print(u_pred.shape, u_loc.shape, u_scale.shape)

u_true   = mre_pinn.utils.as_real(u_bc.values)
u_true_s = (u_true - u_loc) / u_scale

mu_pred   = mre_pinn.utils.as_real(outputs[:,idxs[1]:idxs[2]])
mu_pred_s = (mu_pred - mu_loc) / mu_scale

mu_true   = mre_pinn.utils.as_real(torch.as_tensor(mu.to_numpy()).reshape(-1, 1))
mu_true_s = (mu_true - mu_loc) / mu_scale

# display input and output distributions

def plot_hist(ax, a, label):
    '''Plot a 20-bin histogram of tensor `a` (one series per column) on `ax`.'''
    sns.histplot(a.detach().cpu().numpy(), bins=20, ax=ax)
    ax.set_xlabel(label)
    # Axes.get_legend() returns None when no legend was drawn (e.g. a single
    # unlabeled series); guard instead of raising AttributeError.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_frame_on(False)

hist_rows = [
    (x,       'x',       x_s,       'x_s'),
    (u_pred,  'u_pred',  u_pred_s,  'u_pred_s'),
    (u_true,  'u_true',  u_true_s,  'u_true_s'),
    (mu_pred, 'mu_pred', mu_pred_s, 'mu_pred_s'),
    (mu_true, 'mu_true', mu_true_s, 'mu_true_s'),
]

fig, axes = plt.subplots(len(hist_rows), 2, figsize=(8, 8))
for (ax_raw, ax_scaled), (raw, raw_label, scaled, scaled_label) in zip(axes, hist_rows):
    plot_hist(ax_raw, raw, label=raw_label)
    plot_hist(ax_scaled, scaled, label=scaled_label)

fig.tight_layout()

torch.Size([80, 2]) torch.Size([1, 2]) torch.Size([1, 2])


<IPython.core.display.Javascript object>

In [110]:
%autoreload

pde.pde = mre_pinn.pde.WaveEquation(detach=True, homogeneous=False)

# create normalized loss functions
u_norm = torch.norm(u_bc.values, dim=-1).mean().detach()
print(u_norm)

def normalized_L2_loss(norm):
    def loss_fn(y_true, y_pred):
        return torch.mean(
            (torch.norm(y_true - y_pred, dim=-1) / norm)
        )
    return loss_fn

loss = normalized_L2_loss(u_norm)

def predicted_L2_norm(idx, n):
    def norm(y_true, y_pred):
        return np.mean(
            np.linalg.norm(y_pred[:,idx:idx+n], axis=-1)
        )
    return norm

metrics = [
    predicted_L2_norm(idxs[i], n) for i, n in enumerate(n_outputs)
]

model = deepxde.Model(pde, net)
model.compile(
    optimizer='adam',
    lr=1e-2,
    loss_weights=[1e-8, 1],
    loss=loss,
    metrics=metrics,
)
batch_predict = mre_pinn.utils.minibatch(model.predict, batch_size)
batch_size

tensor(5.4517e-06)
Compiling model...
'compile' took 0.000137 s



80

In [106]:
%autoreload

deepxde.display.training_display = mre_pinn.visual.TrainingPlot(
    losses=['pde_loss', 'data_loss'],
    metrics=['u_norm', 'mu_norm']
)

class PDEPointResampler(deepxde.callbacks.Callback):

    def on_batch_end(self):
        self.model.data.train_x_all = None
        self.model.data.train_x_bc = None
        self.model.data.resample_train_points()


class OutputViewer(deepxde.callbacks.Callback):
    
    def __init__(self, update_every, spectrum=False, residual=False):
        self.update_every = update_every
        self.spectrum = spectrum
        self.residual = residual
        super().__init__()
        
    def get_outputs(self):
        x = torch.as_tensor(self.model.data.bcs[0].points).requires_grad_(True)
        outputs = self.model.net(x)
        pde = model.data.pde(x, outputs).reshape(mu.shape).detach().cpu().numpy()
        u_pred  = outputs[:,idxs[0]:idxs[1]].reshape(u.shape).detach().cpu().numpy()
        mu_pred = outputs[:,idxs[1]:idxs[2]].reshape(mu.shape).detach().cpu().numpy()
        
        u_pred  = mre_pinn.utils.as_xarray(u_pred,  like=u)
        mu_pred = mre_pinn.utils.as_xarray(mu_pred, like=mu)
        pde = mre_pinn.utils.as_xarray(pde, like=mu)
         
        if self.spectrum: # TODO what if it's 1D?
            u_pred  = np.fft.fftshift(np.fft.fftn(u_pred,  axes=(1,2)), axes=(1,2))
            mu_pred = np.fft.fftshift(np.fft.fftn(mu_pred, axes=(1,2)), axes=(1,2))
            raise NotImplementedError

        elif self.residual:
            u_pred  = xr.concat([u_pred,  u - u_pred], dim='residual')
            mu_pred = xr.concat([mu_pred, pde, mu - mu_pred], dim='residual')

        return u_pred, mu_pred

    def on_train_begin(self):
        u_pred, mu_pred = self.get_outputs()
        self.u_viewer  = mre_pinn.visual.XArrayViewer(u_pred,  y=y, hue='residual', **wave_kws)
        self.mu_viewer = mre_pinn.visual.XArrayViewer(mu_pred, y=y, hue='residual', **elast_kws)

    def on_batch_end(self):
        if self.model.train_state.step % self.update_every != 0:
            return

        u_pred, mu_pred = self.get_outputs()
        self.u_viewer.update_array(u_pred)
        self.mu_viewer.update_array(mu_pred)

callbacks = [PDEPointResampler(), OutputViewer(update_every=1, spectrum=False, residual=True)]
try:
    model.train(1000, display_every=10, callbacks=callbacks)
except KeyboardInterrupt as e:
    print('Interrupt', file=sys.stderr)

Training model...



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>




Best model at step 80:
  train loss: 1.58e+00
  test loss: 2.41e+00
  test metric: [4.26e-06, 2.29e+03]

'train' took 71.038503 s



In [125]:
def laplacian_u(x, outputs):
    '''model.predict operator: Laplacian (dim=1) of the predicted wave-field columns.'''
    u_out = outputs[:, idxs[0]:idxs[1]]
    lap = mre_pinn.pde.laplacian(u_out, x, dim=1)
    # release deepxde's cached derivatives before returning
    deepxde.gradients.clear()
    return lap

# Model predictions on the BC points, converted to xarrays shaped like the data.
x = u_bc.points
outputs = batch_predict(x)
u_pred  = mre_pinn.utils.as_xarray(outputs[:, idxs[0]:idxs[1]].reshape(u.shape), like=u)
mu_pred = mre_pinn.utils.as_xarray(outputs[:, idxs[1]:idxs[2]].reshape(mu.shape), like=mu)

lu_values = batch_predict(x, operator=laplacian_u).reshape(u.shape)
lu_pred = mre_pinn.utils.as_xarray(lu_values, like=u)

In [126]:
%autoreload

# display wave fields
mre_pinn.visual.XArrayViewer(
    xr.concat([u_pred, u - u_pred, u], dim='which'), y=y, hue='which', **wave_kws
)

<IPython.core.display.Javascript object>

<mre_pinn.visual.XArrayViewer at 0x145fb1a2eee0>

In [127]:
# display wave field Laplacians: autograd Laplacian of the model (lu_pred) vs.
# finite-difference Laplacian of the measured field

# Differentiate the xarray wave field `u`. `u_true` is not usable here: by this
# point it is the torch tensor as_real(u_bc.values), whose axis 1 holds the
# real/imag parts rather than a spatial axis.
Lu_true = mre_pinn.discrete.laplacian(u, resolution=1e-3, dim=1)
Lu_pred = mre_pinn.discrete.laplacian(u_pred, resolution=1e-3, dim=1)

mre_pinn.visual.XArrayViewer(
    xr.concat([lu_pred, lu_pred - Lu_true, Lu_true], dim='which'), y=y, hue='which', **laplace_kws
)

<IPython.core.display.Javascript object>

<mre_pinn.visual.XArrayViewer at 0x145fed7bd850>

In [130]:
%autoreload

# display reconstructed elastograms

omega = u.frequency.to_numpy().reshape(-1, 1, 1)
mu_data = mre_pinn.discrete.helmholtz_inversion(u, Lu_true, omega) #.mean(axis=0)
mu_u_Lu = mre_pinn.discrete.helmholtz_inversion(u_pred, Lu_pred, omega) #.mean(axis=0)
mu_u_lu = mre_pinn.discrete.helmholtz_inversion(u_pred, lu_pred, omega) #.mean(axis=0)
#mu_pred = mre_pinn.discrete.helmholtz_inversion(u_pred, lu_model, omega)

mre_pinn.visual.XArrayViewer(
    xr.concat([mu_pred, mu_pred - mu, mu], dim='which'), y=y, hue='which', **elast_kws
)

<IPython.core.display.Javascript object>

<mre_pinn.visual.XArrayViewer at 0x145fb162be50>