In [1]:
# Interactive (widget-style) matplotlib figures inside the notebook.
%matplotlib notebook
# Autoreload mode 1: only modules registered with %aimport are reloaded
# (mre_pinn is registered in the import cell below).
%load_ext autoreload
%autoreload 1
# Record where the notebook was executed.
# NOTE(review): the saved output exposes the cluster host name and an
# absolute project path - consider clearing it before sharing.
!hostname
!pwd

dv001.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks


In [2]:
import sys, os, pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Select the DeepXDE backend. This is set before `import deepxde`; the cell
# output below reports "Using backend: pytorch".
os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

# Make the parent directory importable so the local mre_pinn package is found,
# and register it for autoreload (mode 1 only reloads %aimport-ed modules).
sys.path.append('..')
%aimport mre_pinn

# Last expression: displays whether a CUDA device is visible to torch.
torch.cuda.is_available()

Using backend: pytorch



True

In [3]:
# Reload mre_pinn so that edits to the package are picked up.
%autoreload

# Load the BIOQIC 'fem_box' data set (all frequencies, 2D slice). The loader
# returns two objects; `data` is displayed below as an xarray.Dataset.
# NOTE(review): how `test_data` differs from `data` is not visible here - confirm
# against mre_pinn.data.load_bioqic_dataset.
data, test_data = mre_pinn.data.load_bioqic_dataset('../data/BIOQIC', data_name='fem_box', frequency='multi', xyz_slice='2D')
data

Loading ../data/BIOQIC/four_target_phantom.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    u_ft: <class 'numpy.ndarray'> (100, 80, 10, 3, 6) complex128
Loading ../data/BIOQIC/fem_box_elastogram.npy
     <class 'numpy.ndarray'> (6, 10, 80, 100) complex128
Loading ../data/BIOQIC/fem_box_regions.npy
     <class 'numpy.ndarray'> (10, 80, 100) int64
Multi frequency 2D
<xarray.Dataset>
Dimensions:         (frequency: 6, x: 80, y: 100, component: 2)
Coordinates:
  * frequency       (frequency) float64 50.0 60.0 70.0 80.0 90.0 100.0
  * x               (x) float64 0.0 0.001 0.002 0.003 ... 0.077 0.078 0.079
  * y               (y) float64 0.0 0.001 0.002 0.003 ... 0.097 0.098 0.099
    z               float64 0.0
  * component       (component) <U1 'z' 'y'
Data variables:
    u               (frequency, x, y, component) complex128 (7.43590828448788...
    mu              (frequency, x, y) complex128 (3000+314.1592653589793j) .....
    spat

In [4]:
%autoreload

# configure color maps
wave_kws = mre_pinn.visual.get_color_kws(data.u)
laplace_kws = mre_pinn.visual.get_color_kws(data.Lu)
elast_kws = mre_pinn.visual.get_color_kws(data.mu)

# display true wave field and elastogram
y = 'y' if 'y' in data.field.spatial_dims else None
hue = None if 'y' in data.field.spatial_dims else 'part'
mre_pinn.visual.XArrayViewer(data.u,  col='domain', y=y, hue=None, ax_width=2, **wave_kws)
mre_pinn.visual.XArrayViewer(data.Lu, col='domain', y=y, hue=None, ax_width=2, **laplace_kws)
mre_pinn.visual.XArrayViewer(data.Mu.mean('frequency'), col='domain', y=y, hue=None, ax_width=2, **elast_kws)
mre_pinn.visual.XArrayViewer(data.mu.mean('frequency'), col='domain', y=y, hue=None, ax_width=2, **elast_kws)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<mre_pinn.visual.XArrayViewer at 0x1516f468a110>

In [5]:
# Flatten the xarray fields into plain NumPy arrays: point coordinates as
# float32, wave field and elasticity values as complex64.
x  = data.u.field.points().astype(np.float32)
u  = data.u.field.values().astype(np.complex64)
mu = data.mu.field.values().astype(np.complex64)

# One summary line (type, shape, dtype) per array.
print(
    *(
        f'{tag} {type(arr)} {arr.shape} {arr.dtype}'
        for tag, arr in (('x ', x), ('u ', u), ('mu', mu))
    ),
    sep='\n'
)

x  <class 'numpy.ndarray'> (48000, 3) float32
u  <class 'numpy.ndarray'> (48000, 2) complex64
mu <class 'numpy.ndarray'> (48000, 1) complex64


In [6]:
%autoreload

# initialize the PDE, geometry, and boundary conditions
# NOTE(review): the meaning of 'hetero' and detach=True is defined in
# mre_pinn.pde.WaveEquation and is not visible in this notebook.
pde = mre_pinn.pde.WaveEquation.from_name('hetero', detach=True)
# Axis-aligned bounding box of the data coordinates. The upper corner is padded
# by 1e-5 - presumably so that points on the maximum face count as inside;
# TODO confirm.
geom = deepxde.geometry.Hypercube(x.min(axis=0), x.max(axis=0) + 1e-5)
# Alternative geometry kept for reference (disabled):
#geom = deepxde.geometry.PointCloud(x)
# Constraint built from the measured wave field u at the coordinates x.
bc = mre_pinn.fields.VectorFieldBC(x, u)

In [7]:
%autoreload

# define model architecture
# The repr printed below shows what these settings produce: an InputScaler,
# two parallel FFNN branches (4 and 2 real outputs), a RealToComplex layer and
# an OutputScaler.

net = mre_pinn.model.MREPINN(
    input=x,            # coordinates; NOTE(review): presumably used to fit the input scaler
    outputs=[u, mu],    # target fields; NOTE(review): presumably used to fit the output scaler
    omega0=16,
    n_layers=5,         # repr shows linear0 .. linear4 per branch
    n_hidden=128,       # repr shows out_features=128 for the hidden layers
    activ_fn='t',       # repr shows hidden layers named linear*_tanh
    parallel=True,      # repr shows one FFNN per output field
    dense=True,         # repr shows in_features growing 3, 131, 259, 387, 515
    dtype=torch.float32
)

# Last expression: display the module structure.
net

MREPINN(
  (0): InputScaler()
  (1): Parallel(
    (0): FFNN(
      (linear0_tanh): Linear(in_features=3, out_features=128, bias=True)
      (linear1_tanh): Linear(in_features=131, out_features=128, bias=True)
      (linear2_tanh): Linear(in_features=259, out_features=128, bias=True)
      (linear3_tanh): Linear(in_features=387, out_features=128, bias=True)
      (linear4): Linear(in_features=515, out_features=4, bias=True)
    )
    (1): FFNN(
      (linear0_tanh): Linear(in_features=3, out_features=128, bias=True)
      (linear1_tanh): Linear(in_features=131, out_features=128, bias=True)
      (linear2_tanh): Linear(in_features=259, out_features=128, bias=True)
      (linear3_tanh): Linear(in_features=387, out_features=128, bias=True)
      (linear4): Linear(in_features=515, out_features=2, bias=True)
    )
  )
  (2): RealToComplex()
  (3): OutputScaler()
)

In [8]:
# Bundle network, PDE, geometry and data constraint into a trainable model.
# NOTE(review): the exact meaning of the sampling arguments is defined in
# mre_pinn.training.MREPINNModel; the comments below are assumptions to confirm.
model = mre_pinn.training.MREPINNModel(
    net, pde, geom, bc,
    batch_size=128,
    num_domain=128,                 # presumably PDE collocation points drawn inside geom
    num_boundary=0,                 # no points sampled on the geometry boundary
    train_distribution='pseudo',    # presumably pseudo-random sampling of domain points
    anchors=None
)

In [None]:
# descriptive statistics

%autoreload

kws = dict(axis=0, keepdims=True)

# standardize input to [-1, 1]
x_loc = x.mean(**kws)
x_scale = (x.max(**kws) - x.min(**kws)) / 2
print('x', x_loc, x_scale)

# avoid division by zero
x_scale[x_scale == 0] = 1
x_s = (x - x_loc) / x_scale

# test forward pass
u_pred, lu_pred, mu_pred, f_trac, f_body = model.predict(x, batch_size=128)
pde_res = f_trac + f_body

u_pred = u_pred.detach().cpu().numpy()
lu_pred = lu_pred.detach().cpu().numpy()
mu_pred = mu_pred.detach().cpu().numpy()

# normalize outputs via mean and std
u_loc = u.mean(**kws)
u_scale = u.std(**kws)
print('u', u_loc, u_scale)

u_s = (u - u_loc) / u_scale
u_pred_s = (u_pred - u_loc) / u_scale

mu_loc = mu.mean(**kws)
mu_scale = mu.std(**kws)
print('mu', mu_loc, mu_scale)

mu_s = (mu - mu_loc) / mu_scale
mu_pred_s = (mu_pred - mu_loc) / mu_scale

# compute laplacian statistics
lu_loc = lu_pred.mean(**kws)
lu_scale = lu_pred.std(**kws)
print('lu', lu_loc, lu_scale)

# display input and output distributions

def hex_to_rgb(h):
    """Convert a hexadecimal color string to an ``(r, g, b)`` tuple.

    Args:
        h: Color as ``'rrggbb'`` or ``'#rrggbb'``. Characters after the first
            six hex digits (e.g. an alpha suffix) are ignored.

    Returns:
        Tuple of three ints in the range 0-255.

    Raises:
        ValueError: If the string has fewer than six hex digits or contains
            non-hexadecimal characters.
    """
    # Matplotlib color cycles yield '#rrggbb'; accept the leading '#' instead
    # of failing on it.
    digits = h.lstrip('#')
    return tuple(int(digits[i:i+2], 16) for i in (0, 2, 4))

def plot_hist(ax, a, xlabel, hue=None):
    """Draw one semi-transparent histogram per column of ``a`` on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        a: Array of shape (n_samples, n_columns); a 1-D array is treated as a
            single column. Complex input is first converted with
            ``mre_pinn.utils.as_real``.
        xlabel: Label for the x axis.
        hue: Optional sequence of legend labels, one per column. If it is
            shorter than the number of columns, only the labelled columns are
            drawn. If omitted, every column is drawn without a label and no
            legend is added.

    All columns share one bin width, 1/20 of the value range of the whole
    array, so their bars are directly comparable.
    """
    if np.iscomplexobj(a):
        a = mre_pinn.utils.as_real(a)
    a = np.asarray(a)
    columns = np.atleast_2d(a.T)
    labels = [None] * len(columns) if hue is None else list(hue)

    bin_width = (a.max() - a.min()) / 20
    if not bin_width > 0:
        # all values identical: fall back to a unit-width bin
        bin_width = 1.0

    for column, label in zip(columns, labels):
        lo, hi = column.min(), column.max()
        bins = np.arange(lo, hi + bin_width, bin_width)
        if len(bins) < 2:
            # a constant column yields a single edge; a histogram needs two
            bins = np.array([lo, lo + bin_width])
        # get_next_color() advances the axes' color cycle; the former
        # `prop_cycler` attribute is not available in newer Matplotlib.
        color = ax._get_lines.get_next_color()
        # '80' is appended as a hex alpha channel (assumes '#rrggbb' colors)
        ax.hist(column, bins=bins, label=label, edgecolor='0.2', fc=color + '80')

    ax.set_xlabel(xlabel)
    if any(label is not None for label in labels):
        ax.legend(frameon=True, edgecolor='0.2')

# One row per quantity; left column raw values, right column standardized.
fig, axes = plt.subplots(6, 2, figsize=(8, 10))

# LaTeX labels are raw strings so that backslashes reach Matplotlib verbatim.
# (The former '$\nabla^2 u_x$' contained an unescaped '\n', i.e. a newline.)
ndim = len(data.u.field.spatial_dims)
x_hues = [r'$\omega$', '$x$', '$y$', '$z$'][:ndim+1]
plot_hist(axes[0,0], x,   hue=x_hues, xlabel='x')
plot_hist(axes[0,1], x_s, hue=x_hues, xlabel='x_s')

# real and imaginary part of each displacement component
u_hues = ['$u_z$', '$u_y$', '$u_x$'][:ndim]
u_hues = [p.format(h) for h in u_hues for p in ['Re[{}]', 'Im[{}]']]
plot_hist(axes[1,0], u_pred,   hue=u_hues, xlabel='u_pred')
plot_hist(axes[1,1], u_pred_s, hue=u_hues, xlabel='u_pred_s')

plot_hist(axes[2,0], u,   hue=u_hues, xlabel='u_true')
plot_hist(axes[2,1], u_s, hue=u_hues, xlabel='u_true_s')

mu_hues = [r'Re[$\mu$]', r'Im[$\mu$]']
plot_hist(axes[3,0], mu_pred,   hue=mu_hues, xlabel='mu_pred')
plot_hist(axes[3,1], mu_pred_s, hue=mu_hues, xlabel='mu_pred_s')

plot_hist(axes[4,0], mu,   hue=mu_hues, xlabel='mu_true')
plot_hist(axes[4,1], mu_s, hue=mu_hues, xlabel='mu_true_s')

lu_hues = [r'$\nabla^2 u_z$', r'$\nabla^2 u_y$', r'$\nabla^2 u_x$'][:ndim]
#lu_hues = [p.format(h) for h in lu_hues for p in ['Re[{}]', 'Im[{}]']]
plot_hist(axes[5,0], lu_pred, hue=lu_hues, xlabel='lu_pred')

fig.tight_layout()

In [9]:
%autoreload

# Compile with Adam. Two loss terms are weighted 1e-8 and 1.
# NOTE(review): which term each weight applies to (PDE residual vs. data fit)
# is not visible here - confirm against MREPINNModel.
# NOTE(review): the loss is built from the true wave field u; presumably it
# uses u's statistics for standardization - confirm.
model.compile(
    optimizer='adam',
    lr=1e-4,
    loss_weights=[1e-8, 1],
    loss=mre_pinn.training.standardized_msae_loss_fn(u)
)
# Replace DeepXDE's default training printout with the project's display.
deepxde.display.training_display = mre_pinn.training.SummaryDisplay()
# NOTE(review): the recorded output reports "Time spent testing" between about
# 50% and 76%; if that line comes from TestEvaluation, a larger test_every
# would reduce the overhead.
callbacks = [
    mre_pinn.training.TestEvaluation(test_data, batch_size=256, test_every=100, save_every=1000, interact=True),
    mre_pinn.training.PDEResampler(period=1),
]

# Interrupting the kernel ends training early without a traceback; the model
# keeps whatever state it had reached.
try:
    model.train(250000, display_every=10, callbacks=callbacks)
except KeyboardInterrupt as e:
    print('Interrupt', file=sys.stderr)

Compiling model...
'compile' took 0.000274 s

Training model...



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

Time spent testing: 59.63%
Time spent testing: 55.54%
Time spent testing: 53.78%
Time spent testing: 53.07%
Time spent testing: 52.44%
Time spent testing: 51.99%
Time spent testing: 51.66%
Time spent testing: 51.40%
Time spent testing: 51.24%
Time spent testing: 51.07%
Time spent testing: 50.96%
Time spent testing: 50.95%
Time spent testing: 50.87%
Time spent testing: 50.81%
Time spent testing: 50.73%
Time spent testing: 50.67%
Time spent testing: 50.61%
Time spent testing: 50.58%
Time spent testing: 50.61%
Time spent testing: 50.58%
Time spent testing: 50.55%
Time spent testing: 50.53%
Time spent testing: 50.51%
Time spent testing: 50.49%
Time spent testing: 50.48%
Time spent testing: 50.47%
Time spent testing: 50.51%
Time spent testing: 50.50%
Time spent testing: 50.49%
Time spent testing: 50.48%
Time spent testing: 50.47%
Time spent testing: 50.46%
Time spent testing: 50.45%
Time spent testing: 50.45%
Time spent testing: 50.48%
Time spent testing: 50.47%
Time spent testing: 50.46%
T

Time spent testing: 55.09%
Time spent testing: 55.10%
Time spent testing: 55.12%
Time spent testing: 55.14%
Time spent testing: 55.15%
Time spent testing: 55.17%
Time spent testing: 55.19%
Time spent testing: 55.20%
Time spent testing: 55.22%
Time spent testing: 55.23%
Time spent testing: 55.25%
Time spent testing: 55.26%
Time spent testing: 55.28%
Time spent testing: 55.30%
Time spent testing: 55.32%
Time spent testing: 55.33%
Time spent testing: 55.35%
Time spent testing: 55.36%
Time spent testing: 55.38%
Time spent testing: 55.39%
Time spent testing: 55.41%
Time spent testing: 55.43%
Time spent testing: 55.45%
Time spent testing: 55.46%
Time spent testing: 55.47%
Time spent testing: 55.49%
Time spent testing: 55.51%
Time spent testing: 55.53%
Time spent testing: 55.54%
Time spent testing: 55.56%
Time spent testing: 55.57%
Time spent testing: 55.59%
Time spent testing: 55.61%
Time spent testing: 55.62%
Time spent testing: 55.64%
Time spent testing: 55.65%
Time spent testing: 55.67%
T

Time spent testing: 59.79%
Time spent testing: 59.80%
Time spent testing: 59.81%
Time spent testing: 59.82%
Time spent testing: 59.83%
Time spent testing: 59.84%
Time spent testing: 59.85%
Time spent testing: 59.86%
Time spent testing: 59.87%
Time spent testing: 59.88%
Time spent testing: 59.89%
Time spent testing: 59.90%
Time spent testing: 59.91%
Time spent testing: 59.93%
Time spent testing: 59.94%
Time spent testing: 59.95%
Time spent testing: 59.96%
Time spent testing: 59.97%
Time spent testing: 59.98%
Time spent testing: 59.99%
Time spent testing: 60.00%
Time spent testing: 60.01%
Time spent testing: 60.02%
Time spent testing: 60.04%
Time spent testing: 60.05%
Time spent testing: 60.06%
Time spent testing: 60.07%
Time spent testing: 60.08%
Time spent testing: 60.09%
Time spent testing: 60.10%
Time spent testing: 60.11%
Time spent testing: 60.12%
Time spent testing: 60.13%
Time spent testing: 60.14%
Time spent testing: 60.15%
Time spent testing: 60.16%
Time spent testing: 60.17%
T

Time spent testing: 62.99%
Time spent testing: 63.00%
Time spent testing: 63.02%
Time spent testing: 63.03%
Time spent testing: 63.04%
Time spent testing: 63.05%
Time spent testing: 63.07%
Time spent testing: 63.08%
Time spent testing: 63.09%
Time spent testing: 63.10%
Time spent testing: 63.12%
Time spent testing: 63.13%
Time spent testing: 63.14%
Time spent testing: 63.15%
Time spent testing: 63.16%
Time spent testing: 63.18%
Time spent testing: 63.19%
Time spent testing: 63.20%
Time spent testing: 63.21%
Time spent testing: 63.23%
Time spent testing: 63.24%
Time spent testing: 63.25%
Time spent testing: 63.26%
Time spent testing: 63.27%
Time spent testing: 63.29%
Time spent testing: 63.30%
Time spent testing: 63.31%
Time spent testing: 63.32%
Time spent testing: 63.34%
Time spent testing: 63.35%
Time spent testing: 63.36%
Time spent testing: 63.37%
Time spent testing: 63.38%
Time spent testing: 63.40%
Time spent testing: 63.41%
Time spent testing: 63.42%
Time spent testing: 63.43%
T

Time spent testing: 66.47%
Time spent testing: 66.48%
Time spent testing: 66.49%
Time spent testing: 66.50%
Time spent testing: 66.51%
Time spent testing: 66.52%
Time spent testing: 66.53%
Time spent testing: 66.54%
Time spent testing: 66.55%
Time spent testing: 66.56%
Time spent testing: 66.57%
Time spent testing: 66.58%
Time spent testing: 66.59%
Time spent testing: 66.60%
Time spent testing: 66.61%
Time spent testing: 66.62%
Time spent testing: 66.63%
Time spent testing: 66.64%
Time spent testing: 66.65%
Time spent testing: 66.66%
Time spent testing: 66.68%
Time spent testing: 66.69%
Time spent testing: 66.70%
Time spent testing: 66.71%
Time spent testing: 66.72%
Time spent testing: 66.73%
Time spent testing: 66.74%
Time spent testing: 66.75%
Time spent testing: 66.76%
Time spent testing: 66.77%
Time spent testing: 66.78%
Time spent testing: 66.79%
Time spent testing: 66.80%
Time spent testing: 66.81%
Time spent testing: 66.82%
Time spent testing: 66.83%
Time spent testing: 66.84%
T

Time spent testing: 69.21%
Time spent testing: 69.21%
Time spent testing: 69.22%
Time spent testing: 69.23%
Time spent testing: 69.24%
Time spent testing: 69.25%
Time spent testing: 69.26%
Time spent testing: 69.27%
Time spent testing: 69.27%
Time spent testing: 69.28%
Time spent testing: 69.29%
Time spent testing: 69.30%
Time spent testing: 69.31%
Time spent testing: 69.32%
Time spent testing: 69.33%
Time spent testing: 69.34%
Time spent testing: 69.34%
Time spent testing: 69.35%
Time spent testing: 69.36%
Time spent testing: 69.37%
Time spent testing: 69.38%
Time spent testing: 69.39%
Time spent testing: 69.40%
Time spent testing: 69.41%
Time spent testing: 69.41%
Time spent testing: 69.42%
Time spent testing: 69.43%
Time spent testing: 69.44%
Time spent testing: 69.45%
Time spent testing: 69.46%
Time spent testing: 69.47%
Time spent testing: 69.47%
Time spent testing: 69.48%
Time spent testing: 69.49%
Time spent testing: 69.50%
Time spent testing: 69.51%
Time spent testing: 69.52%
T

Time spent testing: 71.51%
Time spent testing: 71.52%
Time spent testing: 71.53%
Time spent testing: 71.53%
Time spent testing: 71.54%
Time spent testing: 71.55%
Time spent testing: 71.55%
Time spent testing: 71.56%
Time spent testing: 71.57%
Time spent testing: 71.58%
Time spent testing: 71.58%
Time spent testing: 71.59%
Time spent testing: 71.60%
Time spent testing: 71.60%
Time spent testing: 71.61%
Time spent testing: 71.62%
Time spent testing: 71.62%
Time spent testing: 71.63%
Time spent testing: 71.64%
Time spent testing: 71.64%
Time spent testing: 71.65%
Time spent testing: 71.66%
Time spent testing: 71.66%
Time spent testing: 71.67%
Time spent testing: 71.68%
Time spent testing: 71.68%
Time spent testing: 71.69%
Time spent testing: 71.70%
Time spent testing: 71.71%
Time spent testing: 71.72%
Time spent testing: 71.72%
Time spent testing: 71.73%
Time spent testing: 71.74%
Time spent testing: 71.75%
Time spent testing: 71.76%
Time spent testing: 71.76%
Time spent testing: 71.77%
T

Time spent testing: 73.66%
Time spent testing: 73.67%
Time spent testing: 73.67%
Time spent testing: 73.68%
Time spent testing: 73.69%
Time spent testing: 73.69%
Time spent testing: 73.70%
Time spent testing: 73.71%
Time spent testing: 73.71%
Time spent testing: 73.72%
Time spent testing: 73.73%
Time spent testing: 73.74%
Time spent testing: 73.74%
Time spent testing: 73.75%
Time spent testing: 73.76%
Time spent testing: 73.76%
Time spent testing: 73.77%
Time spent testing: 73.78%
Time spent testing: 73.78%
Time spent testing: 73.79%
Time spent testing: 73.80%
Time spent testing: 73.80%
Time spent testing: 73.81%
Time spent testing: 73.82%
Time spent testing: 73.82%
Time spent testing: 73.83%
Time spent testing: 73.84%
Time spent testing: 73.84%
Time spent testing: 73.85%
Time spent testing: 73.86%
Time spent testing: 73.86%
Time spent testing: 73.87%
Time spent testing: 73.88%
Time spent testing: 73.89%
Time spent testing: 73.89%
Time spent testing: 73.90%
Time spent testing: 73.91%
T

Time spent testing: 75.58%
Time spent testing: 75.59%
Time spent testing: 75.60%
Time spent testing: 75.60%
Time spent testing: 75.61%
Time spent testing: 75.61%
Time spent testing: 75.62%
Time spent testing: 75.62%
Time spent testing: 75.63%
Time spent testing: 75.64%
Time spent testing: 75.64%
Time spent testing: 75.65%
Time spent testing: 75.65%
Time spent testing: 75.66%
Time spent testing: 75.66%
Time spent testing: 75.67%
Time spent testing: 75.68%
Time spent testing: 75.68%
Time spent testing: 75.69%
Time spent testing: 75.69%
Time spent testing: 75.70%
Time spent testing: 75.71%
Time spent testing: 75.71%
Time spent testing: 75.72%
Time spent testing: 75.72%
Time spent testing: 75.73%
Time spent testing: 75.73%
Time spent testing: 75.74%
Time spent testing: 75.75%
Time spent testing: 75.75%
Time spent testing: 75.76%
Time spent testing: 75.76%
Time spent testing: 75.77%
Time spent testing: 75.77%
Time spent testing: 75.78%
Time spent testing: 75.78%
Time spent testing: 75.79%
T

In [14]:
%autoreload
# Evaluate the trained model once, outside the training loop. The evaluation
# object is created without plots (plot=False) but with interactive viewers,
# and the model is attached by hand because no training run sets it here.
# NOTE(review): this evaluates on `data`, whereas the training callback was
# given `test_data`.
test_eval = mre_pinn.training.TestEvaluation(data, batch_size=256, plot=False, view=True, interact=True)
test_eval.model = model
test_eval.test_evaluate(data)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Output…

In [None]:
# try out complex initialization schemes

def runiform(n, scale=1):
    """Draw real values uniformly from [-scale, scale).

    Args:
        n: Size/shape passed to ``torch.rand``.
        scale: Half-width of the sampling interval.

    Returns:
        Tensor of the requested shape.
    """
    centered = 2 * torch.rand(n) - 1  # uniform on [-1, 1)
    return centered * scale

def cuniform(n, scale=1):
    """Draw complex values with uniform magnitude and uniform phase.

    The magnitude is uniform on [0, scale) and the phase uniform on
    [0, 2*pi); the two are drawn independently, magnitude first.

    Args:
        n: Size/shape passed to ``torch.rand``.
        scale: Upper bound of the magnitude.

    Returns:
        Complex tensor of the requested shape.
    """
    magnitude = torch.rand(n) * scale
    phase = torch.rand(n) * 2 * np.pi
    unit_phasor = torch.exp(1j * phase)
    return magnitude * unit_phasor

# experiment size: number of samples, input width, hidden-layer widths
n, d = 10000, 2
layers = [16] * 4

# candidate activation functions
sin = torch.sin

def cis(x):
    """Complex exponential: cos(x) + i*sin(x)."""
    return torch.cos(x) + 1j * torch.sin(x)

def gas(x):
    """Gaussian envelope exp(-x^2)."""
    return torch.exp(-x**2)

def wav(x):
    """Complex exponential multiplied by the Gaussian envelope."""
    return cis(x) * gas(x)

# One row per stage (input, each hidden layer, output); left column real
# network, right column complex network. squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(len(layers) + 2, 2, figsize=(6, 1.5*(len(layers) + 2)), squeeze=False)

def plot_hist(ax, a):
    """Plot a 20-bin histogram of all values in tensor ``a`` on ``ax``.

    Complex tensors are converted with ``mre_pinn.utils.as_real`` (not
    interleaved) before plotting.

    NOTE(review): this redefines the ``plot_hist(ax, a, xlabel, hue)`` helper
    from the descriptive-statistics cell with a different signature; whichever
    cell ran last wins.
    """
    values = a.detach().cpu().numpy().flatten()
    is_complex = np.iscomplexobj(values)
    if is_complex:
        values = mre_pinn.utils.as_real(values, interleave=False)
    sns.histplot(values, bins=20, ax=ax, legend=True)

# Input samples: `a` feeds the real network, `b` the complex one (`+ 0j`
# promotes the real samples to a complex dtype).
# Statement order matters in this cell: every runiform/cuniform call consumes
# the global torch RNG.
a = runiform((n, d))
b = runiform((n, d)) + 0j
plot_hist(axes[0,0], a)
plot_hist(axes[0,1], b)

# Push the samples through randomly initialized layers and plot the
# activations after each one.
for i, w in enumerate(layers):
    if i == 0:
        # first layer: wider weight range, 32 / fan-in
        a_scale = 32 / d
        b_scale = 32 / d
    else:
        # later layers: sqrt(6 / fan-in)
        a_scale = np.sqrt(6 / d)
        b_scale = np.sqrt(6 / d)
    a = sin(a @ runiform((d, w), a_scale))
    b = cis(b @ cuniform((d, w), b_scale))
    plot_hist(axes[i+1,0], a)
    plot_hist(axes[i+1,1], b)
    # this layer's width is the next layer's fan-in
    d = w

# Linear output layer with a single unit (default scale of 1).
a = a @ runiform((d, 1))
b = b @ cuniform((d, 1))
plot_hist(axes[-1,0], a)
plot_hist(axes[-1,1], b)
fig.tight_layout()

In [None]:
# investigate complex derivatives
# NOTE(review): `break` outside a loop is a SyntaxError, so this cell cannot
# run as written - presumably a deliberate guard that stops "Run All" here.
break
# NOTE(review): this `x` is not used by f below (f builds its own grid), and
# assigning it would overwrite the coordinate array `x` defined earlier.
x = np.linspace(-np.pi, np.pi, 100)

# Scalar parameters with gradient tracking enabled.
a = torch.ones(1).requires_grad_(True)
b = torch.ones(1).requires_grad_(True)

def f(a, b):
    """Norm of the complex curve a*cos(t) + i*b*sin(t).

    The curve is sampled at 100 evenly spaced points t in [-pi, pi].

    Args:
        a: Scale of the real part.
        b: Scale of the imaginary part.

    Returns:
        Scalar tensor containing ``torch.norm`` of the sampled curve.
    """
    theta = torch.linspace(-np.pi, np.pi, 100)
    real_part = a * torch.cos(theta)
    imag_part = b * torch.sin(theta)
    curve = torch.complex(real_part, imag_part)
    return torch.norm(curve)

# Gradients of the norm with respect to a and b from autograd.
# create_graph=True keeps the graph so the second grad call can reuse it;
# [0][0] takes the single gradient tensor and its single element.
L = f(a, b)
L_a = torch.autograd.grad(L, a, grad_outputs=torch.ones_like(L), create_graph=True)[0][0]
L_b = torch.autograd.grad(L, b, grad_outputs=torch.ones_like(L), create_graph=True)[0][0]

print(L)
print(L_a, L_b)

def diff(f, a, b, da=0, db=0):
    """Central finite difference of ``f`` at ``(a, b)``.

    Evaluates ``f`` at ``(a + da, b + db)`` and ``(a - da, b - db)`` and
    divides the difference by ``2 * (da + db)``. Pass exactly one non-zero
    step to approximate the partial derivative along that argument.
    """
    forward = f(a + da, b + db)
    backward = f(a - da, b - db)
    step = 2 * (da + db)
    return (forward - backward) / step

# Finite-difference estimates of the same two partial derivatives (step 1e-3),
# for comparison with the autograd values printed above.
d_a = diff(f, a, b, da=1e-3)
d_b = diff(f, a, b, db=1e-3)

print(d_a, d_b)

# Last expression: display the parameters.
a, b

In [55]:
# Compare DeepXDE point-sampling strategies on the square [-1, 1]^2.
# Cell-local names (sampling_geom, points) are used so that the training
# geometry `geom` and the coordinate array `x` defined earlier in the notebook
# are not overwritten when this cell is run.
sampling_geom = deepxde.geometry.Hypercube([-1, -1], [1, 1])

ns = [16, 64, 144]

# (column title, sampler) in display order; the samplers are called in this
# order for every row
samplers = [
    ('Uniform', lambda count: sampling_geom.uniform_points(count)),
    ('Sobol', lambda count: sampling_geom.random_points(count, 'Sobol')),
    ('Pseudo-random', lambda count: sampling_geom.random_points(count, 'pseudo')),
]

fig, axes = plt.subplots(len(ns), len(samplers), figsize=(6, 2*len(ns)), squeeze=False)

for i, n in enumerate(ns):
    axes[i,0].set_ylabel(f'N = {n}')

    for j, (title, sample) in enumerate(samplers):
        if i == 0:
            axes[i,j].set_title(title)

        points = sample(n)
        # Grayscale by sample order (black = first point). The shades are
        # sized from the points actually returned, since a sampler is not
        # guaranteed to return exactly n points.
        shades = [str(k / len(points)) for k in range(len(points))]
        axes[i,j].scatter(*points.T, c=shades)

fig.tight_layout()
fig.savefig('domain_sampling.png', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>