In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 1
!hostname
!pwd

dv001.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks


In [2]:
import sys, os, pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

sys.path.append('..')
%aimport mre_pinn

torch.cuda.is_available()

Using backend: pytorch



True

In [5]:
%autoreload

data, test_data = mre_pinn.data.load_bioqic_dataset(
    data_root='../data/BIOQIC', data_name='phantom', frequency=80, xyz_slice='2D')
data

Loading ../data/BIOQIC/phantom_unwrapped_dejittered.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    info: <class 'numpy.ndarray'> (1, 1) [('dx_m', 'O'), ('dy_m', 'O'), ('dz_m', 'O'), ('frequencies_Hz', 'O'), ('index_description', 'O'), ('size', 'O')]
    magnitude: <class 'numpy.ndarray'> (80, 128, 25, 8, 3, 8) uint16
    phase_unwrap_noipd: <class 'numpy.ndarray'> (80, 128, 25, 8, 3, 8) float64
Loading ../data/BIOQIC/phantom_elastogram.npy
     <class 'numpy.ndarray'> (8, 128, 80, 25) complex128
Loading ../data/BIOQIC/phantom_regions.npy
     <class 'numpy.ndarray'> (128, 80, 25) int64
Preprocessing data
Single frequency 2D
<xarray.Dataset>
Dimensions:         (frequency: 1, component: 2, x: 128, y: 80)
Coordinates:
  * frequency       (frequency) float64 80.0
  * component       (component) <U1 'z' 'y'
    z               float64 0.0
  * x               (x) float64 0.0 0.0015 0.003 0.0045 ... 0.1875 0.189 0.1905
  * y            

In [6]:
%autoreload

# configure color maps
anat_kws = mre_pinn.visual.get_color_kws(data.a)
wave_kws = mre_pinn.visual.get_color_kws(data.u)
laplace_kws = mre_pinn.visual.get_color_kws(data.Lu)
elast_kws = mre_pinn.visual.get_color_kws(data.mu)

# display true wave field and elastogram
y = 'y' if 'y' in data.field.spatial_dims else None
hue = None if 'y' in data.field.spatial_dims else 'part'
mre_pinn.visual.XArrayViewer(data.a,  col='domain', y=y, hue=None, ax_width=2, **anat_kws)
mre_pinn.visual.XArrayViewer(data.u,  col='domain', y=y, hue=None, ax_width=2, **wave_kws)
mre_pinn.visual.XArrayViewer(data.Lu, col='domain', y=y, hue=None, ax_width=2, **laplace_kws)
mre_pinn.visual.XArrayViewer(data.Mu.mean('frequency'), col='domain', y=y, hue=None, ax_width=2, **elast_kws)
mre_pinn.visual.XArrayViewer(data.mu.mean('frequency'), col='domain', y=y, hue=None, ax_width=2, **elast_kws)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<mre_pinn.visual.XArrayViewer at 0x15311da109a0>

In [78]:
%autoreload

net = mre_pinn.pinn.ParallelPINN(
    n_inputs=[data.field.n_spatial_dims + 1, data.field.n_spatial_dims],
    n_outputs=[data.field.n_spatial_dims, 1],
    omega0=16,
    n_layers=5,
    n_hidden=128,
    activ_fn='s',
    dense=True,
    dtype=torch.float32
)
net

ParallelPINN(
  (net0): PINN(
    (input_scaler): InputScaler()
    (linear0): Linear(in_features=5, out_features=128, bias=True)
    (linear1): Linear(in_features=133, out_features=128, bias=True)
    (linear2): Linear(in_features=261, out_features=128, bias=True)
    (linear3): Linear(in_features=389, out_features=128, bias=True)
    (linear4): Linear(in_features=517, out_features=2, bias=True)
    (output_scaler): OutputScaler()
  )
  (net1): PINN(
    (input_scaler): InputScaler()
    (linear0): Linear(in_features=5, out_features=128, bias=True)
    (linear1): Linear(in_features=133, out_features=128, bias=True)
    (linear2): Linear(in_features=261, out_features=128, bias=True)
    (linear3): Linear(in_features=389, out_features=128, bias=True)
    (linear4): Linear(in_features=517, out_features=1, bias=True)
    (output_scaler): OutputScaler()
  )
)

In [None]:
%autoreload

pde = mre_pinn.pde.WaveEquation.from_name('helmholtz', detach=True)
model = mre_pinn.training.PINNModel(data, net, pde, batch_size=128)
model.compile(
    optimizer='adam',
    lr=1e-4,
    loss_weights=[1e0, 1e-8],
    loss=mre_pinn.training.standardized_msae_loss_fn(data.u.values)
)
deepxde.display.training_display = mre_pinn.training.SummaryDisplay()

test_eval = mre_pinn.testing.TestEvaluator(
    test_data, model, batch_size=1024, test_every=100, save_every=1000, interact=True
)
try:
    model.train(100000, display_every=10, callbacks=[test_eval])
except KeyboardInterrupt as e:
    print('Interrupt', file=sys.stderr)

Compiling model...
'compile' took 0.000214 s

Training model...



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

Time spent testing: 67.33%
Time spent testing: 58.28%
Time spent testing: 55.17%
Time spent testing: 53.25%
Time spent testing: 52.06%
Time spent testing: 50.72%
Time spent testing: 50.21%
Time spent testing: 49.78%
Time spent testing: 49.46%
Time spent testing: 48.96%
Time spent testing: 48.72%
Time spent testing: 48.59%
Time spent testing: 48.22%
Time spent testing: 48.15%
Time spent testing: 48.08%
Time spent testing: 47.83%
Time spent testing: 47.79%
Time spent testing: 47.76%
Time spent testing: 47.73%
Time spent testing: 47.73%
Time spent testing: 48.22%
Time spent testing: 48.22%
Time spent testing: 48.22%
Time spent testing: 48.21%
Time spent testing: 48.22%
Time spent testing: 48.21%
Time spent testing: 48.21%
Time spent testing: 48.08%
Time spent testing: 47.96%
Time spent testing: 47.96%
Time spent testing: 47.97%
Time spent testing: 47.87%
Time spent testing: 47.84%
Time spent testing: 47.85%
Time spent testing: 47.87%
Time spent testing: 47.89%
Time spent testing: 47.90%
T

In [None]:
%autoreload
# Run the evaluator once, on demand, outside the training loop.
# NOTE(review): this passes the training `data`, whereas `test_eval` was
# constructed with `test_data` — confirm which dataset is intended here.
test_eval.test(data)