In [1]:
# Interactive (nbagg) matplotlib figures for this notebook.
%matplotlib notebook
# Autoreload mode 1: only modules registered with %aimport are reloaded
# automatically before each cell; a bare `%autoreload` (used in later cells)
# triggers an immediate reload.
%load_ext autoreload
%autoreload 1
# Record which machine and working directory this run used (provenance).
!hostname
!pwd

dv002.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks


In [2]:
import sys, os, pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

sys.path.append('..')
%aimport mre_pinn

torch.cuda.is_available()

Using backend: pytorch



True

In [3]:
%autoreload

data, test_data = mre_pinn.data.load_bioqic_dataset(
    data_root='../data/BIOQIC',
    data_name='phantom',
    xyz_slice='2D',
    frequency=80
)
data

Loading ../data/BIOQIC/phantom_unwrapped_dejittered.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    info: <class 'numpy.ndarray'> (1, 1) [('dx_m', 'O'), ('dy_m', 'O'), ('dz_m', 'O'), ('frequencies_Hz', 'O'), ('index_description', 'O'), ('size', 'O')]
    magnitude: <class 'numpy.ndarray'> (80, 128, 25, 8, 3, 8) uint16
    phase_unwrap_noipd: <class 'numpy.ndarray'> (80, 128, 25, 8, 3, 8) float64
Loading ../data/BIOQIC/phantom_elastogram.npy
     <class 'numpy.ndarray'> (8, 128, 80, 25) complex128
Loading ../data/BIOQIC/phantom_regions.npy
     <class 'numpy.ndarray'> (128, 80, 25) int64
Preprocessing data
Single frequency 2D
<xarray.Dataset>
Dimensions:         (frequency: 1, component: 2, x: 128, y: 80)
Coordinates:
  * frequency       (frequency) float64 80.0
  * component       (component) <U1 'z' 'y'
    z               float64 0.0
  * x               (x) float64 0.0 0.0015 0.003 0.0045 ... 0.1875 0.189 0.1905
  * y            

In [4]:
%autoreload

# configure color maps
anat_kws = mre_pinn.visual.get_color_kws(data.a)
wave_kws = mre_pinn.visual.get_color_kws(data.u)
laplace_kws = mre_pinn.visual.get_color_kws(data.Lu)
elast_kws = mre_pinn.visual.get_color_kws(data.mu)

# display true wave field and elastogram
y = 'y' if 'y' in data.field.spatial_dims else None
hue = None if 'y' in data.field.spatial_dims else 'part'
mre_pinn.visual.XArrayViewer(data.a,  col='domain', hue=None, ax_width=2, **anat_kws)
mre_pinn.visual.XArrayViewer(data.u,  col='domain', hue=None, ax_width=2, **wave_kws)
#mre_pinn.visual.XArrayViewer(data.Lu, col='domain', y=y, hue=None, ax_width=2, **laplace_kws)
mre_pinn.visual.XArrayViewer(data.Mu.mean('frequency'), col='domain', hue=None, ax_width=2, **elast_kws)
mre_pinn.visual.XArrayViewer(data.mu.mean('frequency'), col='domain', hue=None, ax_width=2, **elast_kws)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<mre_pinn.visual.XArrayViewer at 0x152204e37c40>

In [5]:
%autoreload

net = mre_pinn.pinn.ParallelPINN(
    n_inputs=[data.field.n_spatial_dims + 1, data.field.n_spatial_dims],
    n_outputs=[data.field.n_spatial_dims, 1],
    omega0=1,
    n_layers=5,
    n_hidden=128,
    activ_fn='s',
    dense=True,
    polar=False,
    conditional=False,
    dtype=torch.float32
)
net

ParallelPINN(
  (net0): PINN(
    (input_scaler): InputScaler()
    (linear0): Linear(in_features=3, out_features=128, bias=True)
    (linear1): Linear(in_features=131, out_features=128, bias=True)
    (linear2): Linear(in_features=259, out_features=128, bias=True)
    (linear3): Linear(in_features=387, out_features=128, bias=True)
    (linear4): Linear(in_features=515, out_features=4, bias=True)
    (output_scaler): OutputScaler()
  )
  (net1): PINN(
    (input_scaler): InputScaler()
    (linear0): Linear(in_features=3, out_features=128, bias=True)
    (linear1): Linear(in_features=131, out_features=128, bias=True)
    (linear2): Linear(in_features=259, out_features=128, bias=True)
    (linear3): Linear(in_features=387, out_features=128, bias=True)
    (linear4): Linear(in_features=515, out_features=2, bias=True)
    (output_scaler): OutputScaler()
  )
)

In [6]:
%autoreload
# Build the PDE constraint from its registered name; for 'hetero' the factory
# returns a HeteroEquation (as this cell's output shows).
# NOTE(review): 'hetero' presumably means a heterogeneous-medium wave equation,
# and `detach=True` presumably stops gradients through some PDE term --
# confirm both in mre_pinn.pde.
pde = mre_pinn.pde.WaveEquation.from_name('hetero', detach=True)
pde

<mre_pinn.pde.HeteroEquation at 0x152216f426e0>

In [7]:
%autoreload

model = mre_pinn.training.PINNModel(data, net, pde, batch_size=128)
model.compile(
    optimizer='adam',
    lr=1e-4,
    loss_weights=[1, 1e-8],
    loss=mre_pinn.training.standardized_msae_loss_fn(data.u.values)
)
deepxde.display.training_display = mre_pinn.training.SummaryDisplay()

test_eval = mre_pinn.testing.TestEvaluator(
    test_data, model, batch_size=None, test_every=100, save_every=1000, interact=True
)
try:
    model.train(50000, display_every=10, callbacks=[test_eval])
except KeyboardInterrupt as e:
    print('Interrupt', file=sys.stderr)

  data = torch.as_tensor(data, dtype=self.dtype)



Compiling model...
'compile' took 0.000493 s

Training model...



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

Time spent testing: 39.96%
Time spent testing: 34.12%
Time spent testing: 31.85%
Time spent testing: 30.99%
Time spent testing: 30.15%
Time spent testing: 29.56%
Time spent testing: 29.14%
Time spent testing: 28.83%
Time spent testing: 28.59%
Time spent testing: 28.39%
Time spent testing: 28.42%
Time spent testing: 28.28%
Time spent testing: 28.19%
Time spent testing: 28.10%
Time spent testing: 28.01%
Time spent testing: 27.94%
Time spent testing: 27.88%
Time spent testing: 27.93%
Time spent testing: 27.89%
Time spent testing: 27.84%
Time spent testing: 27.90%
Time spent testing: 27.87%
Time spent testing: 27.85%
Time spent testing: 27.91%
Time spent testing: 27.90%
Time spent testing: 27.96%
Time spent testing: 28.01%
Time spent testing: 28.06%
Time spent testing: 28.10%
Time spent testing: 28.14%
Time spent testing: 28.25%
Time spent testing: 28.29%
Time spent testing: 28.32%
Time spent testing: 28.36%
Time spent testing: 28.39%
Time spent testing: 28.42%
Time spent testing: 28.51%
T

Time spent testing: 33.80%
Time spent testing: 33.82%
Time spent testing: 33.84%
Time spent testing: 33.86%
Time spent testing: 33.88%
Time spent testing: 33.89%
Time spent testing: 33.91%
Time spent testing: 33.93%
Time spent testing: 33.95%
Time spent testing: 33.97%
Time spent testing: 33.98%
Time spent testing: 34.01%
Time spent testing: 34.02%
Time spent testing: 34.04%
Time spent testing: 34.06%
Time spent testing: 34.08%
Time spent testing: 34.10%
Time spent testing: 34.12%
Time spent testing: 34.13%
Time spent testing: 34.16%
Time spent testing: 34.17%
Time spent testing: 34.19%
Time spent testing: 34.21%
Time spent testing: 34.23%
Time spent testing: 34.25%
Time spent testing: 34.27%
Time spent testing: 34.29%
Time spent testing: 34.30%
Time spent testing: 34.33%
Time spent testing: 34.34%
Time spent testing: 34.36%
Time spent testing: 34.38%
Time spent testing: 34.40%
Time spent testing: 34.42%
Time spent testing: 34.44%
Time spent testing: 34.45%
Time spent testing: 34.47%
T

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

