In [None]:
import shutil
from pathlib import Path
from time import time

import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable

from kgi import apply_kgi_to_layer

In [None]:
# Enable LaTeX text rendering only when a `latex` executable is on PATH.
# Otherwise fall back to matplotlib's built-in mathtext, which still renders
# the `$...$` labels used below. Set this to True/False by hand to override.
plt.rcParams['text.usetex'] = shutil.which('latex') is not None

# Problem

In [None]:
# data
# 101 points per edge matches the grid of the stored reference solution
# (required by the reshape below); reusing it as the number of evaluation
# points for the input function is a convenience, not a requirement
n_pts_edge = 101
eval_pts = np.linspace(0, 1, num=n_pts_edge)[:, None]
xv, yv = np.meshgrid(eval_pts[:, 0], eval_pts[:, 0], indexing='ij')
sol_true = np.load('./datasets/stokes.npz')['arr_0']
# assumes rows of sol_true follow the 'ij' meshgrid ordering above, with
# columns (u, v, p) as plotted below -- TODO confirm against the data generator
sol_true_grid = sol_true.reshape(n_pts_edge, n_pts_edge, 3)

# plot solution: pressure as a colour map, velocity as arrows
fig, ax = plt.subplots(dpi=200, figsize=(3, 3))
p_plt = ax.imshow(sol_true_grid[:, :, 2].T, origin='lower',
                  vmin=-0.03, vmax=0.03, cmap="turbo", alpha=.6)
# imshow places pixels at integer positions 0..n_pts_edge-1, so coordinates
# in [0, 1] are scaled by n_pts_edge - 1 to overlay the arrows on the image
px_scale = n_pts_edge - 1
vec_space = 4  # draw every 4th arrow in each direction
v_plt = ax.quiver(xv[::vec_space, ::vec_space] * px_scale,
                  yv[::vec_space, ::vec_space] * px_scale,
                  sol_true_grid[::vec_space, ::vec_space, 0],
                  sol_true_grid[::vec_space, ::vec_space, 1],
                  color="k", scale=.5, headwidth=5)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
# legend and colorbar
qk = ax.quiverkey(v_plt, 1.22, .85, 0.1, '$|\\mathbf{u}|=0.1$',
                  labelpos='N', coordinates='axes')
cb_ax = fig.add_axes([.82, 0.11, 0.03, 0.4])
cbar = fig.colorbar(p_plt, cax=cb_ax, label='$p\\ (10^{-2})$')
cbar.set_ticks(np.arange(-.03, .031, 0.01))
cbar.ax.yaxis.set_major_formatter(FuncFormatter(lambda x_, _: f'{round(x_ * 100)}'))

# boundary condition: horizontal velocity along the top edge
divider = make_axes_locatable(ax)
ax2 = divider.append_axes("top", size=0.8, pad=-0.1)
ax2.plot(sol_true_grid[:, -1, 0], c='k')
ax2.set_aspect(100)
ax2.set_xlim(0, px_scale)
ax2.set_xticks([])
ax2.set_yticks([0, .1, .2])
# backslash escaped like the other labels (a bare "\h" is an invalid escape)
ax2.set_ylabel("$\\hat u_1$")

# make sure the output directory exists before saving
fig_dir = Path("figs")
fig_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / "deeponet_problem.pdf", bbox_inches='tight', pad_inches=0.01)
plt.show()

# PDE components

In [None]:
# PDE equation
def pde(xy, uvp, _):
    """Residuals of the steady Stokes equations with viscosity 0.01.

    Returns ``(momentum_x, momentum_y, continuity)`` where
    ``momentum_x = mu * (u_xx + u_yy) - p_x``,
    ``momentum_y = mu * (v_xx + v_yy) - p_y`` and
    ``continuity = u_x + v_y``.

    Args:
        xy: coordinates differentiated through ``dde.zcs.LazyGrad``.
        uvp: network output whose last axis holds (u, v, p).
        _: unused third argument (presumably the input-function values
           passed by the operator data object -- confirm).
    """
    viscosity = 0.01
    vel_x, vel_y, pressure = (uvp[..., k:k + 1] for k in range(3))
    lazy_vx = dde.zcs.LazyGrad(xy, vel_x)
    lazy_vy = dde.zcs.LazyGrad(xy, vel_y)
    lazy_p = dde.zcs.LazyGrad(xy, pressure)

    # derivative orders are given as (order in x, order in y)
    vx_x = lazy_vx.compute((1, 0))
    vy_y = lazy_vy.compute((0, 1))
    p_x = lazy_p.compute((1, 0))
    p_y = lazy_p.compute((0, 1))

    vx_xx = lazy_vx.compute((2, 0))
    vx_yy = lazy_vx.compute((0, 2))
    vy_xx = lazy_vy.compute((2, 0))
    vy_yy = lazy_vy.compute((0, 2))

    momentum_x = viscosity * (vx_xx + vx_yy) - p_x
    momentum_y = viscosity * (vy_xx + vy_yy) - p_y
    continuity = vx_x + vy_y
    return momentum_x, momentum_y, continuity


# Geometry: the unit square [0, 1] x [0, 1]
geom = dde.geometry.Rectangle([0, 0], [1, 1])


# Boundary condition
# only the top edge is imposed through `bc_slip_top`; the other boundary
# conditions are enforced by the output transform
def bc_slip_top_func(x, aux_var):
    """Prescribed horizontal velocity on the top edge.

    Evaluates ``(aux_var / 10 + 1) * x * (1 - x)``: a parabolic profile in
    the x-coordinate, scaled by the perturbation ``aux_var``.

    Args:
        x: boundary points; column 0 is the x-coordinate.
        aux_var: perturbation values (presumably the sampled input function
            at these points -- confirm with the operator data object).
    """
    x_coord = x[:, 0:1]
    scale = aux_var / 10 + 1.
    profile = dde.backend.as_tensor(x_coord * (1 - x_coord))
    return scale * profile


bc_slip_top = dde.icbc.DirichletBC(
    geom=geom,
    func=bc_slip_top_func,
    # keep only points DeepXDE already flags as boundary points AND lie on y = 1
    # (the original lambda ignored its `on_boundary` argument)
    on_boundary=lambda x, on_boundary: on_boundary and np.isclose(x[1], 1.),
    component=0)  # constrains output component 0 (horizontal velocity u)

# PDE object
# NOTE: this rebinds `pde` from the residual function to the data object;
# the function is passed to dde.data.PDE before the name is overwritten
pde = dde.data.PDE(
    geom,
    pde,
    bcs=[bc_slip_top],
    num_domain=5000,
    num_boundary=4000,  # sampling a bit more points on boundary (1000 on top bc)
    num_test=500,
)

# Function space: Gaussian random field used to sample the perturbations
func_space = dde.data.GRF(length_scale=0.2)

# Data
data = dde.zcs.PDEOperatorCartesianProd(
    pde, func_space, eval_pts, num_function=1000,
    function_variables=[0], num_test=100, batch_size=50
)


# Output transform for zero boundary conditions
def out_transform(inputs, outputs):
    """Multiply raw network outputs by factors that vanish on selected edges.

    - u is multiplied by x(1-x)y, which is zero on the left, right and bottom edges.
    - v is multiplied by x(1-x)y(1-y), which is zero on all four edges.
    - p is multiplied by y, which is zero on the bottom edge.

    Args:
        inputs: pair whose second entry holds the (x, y) coordinates in its
            first two columns.
        outputs: raw outputs indexed as ``outputs[:, :, k]`` for k = 0, 1, 2
            (presumably (num_functions, num_points, 3) -- confirm).

    Returns:
        The transformed (u, v, p) stacked along axis 2.
    """
    coords = inputs[1]
    x, y = coords[:, 0], coords[:, 1]
    bubble_x = x * (1 - x)
    mask_u = (bubble_x * y)[None, :]
    mask_v = (bubble_x * y * (1 - y))[None, :]
    mask_p = y[None, :]
    u = outputs[:, :, 0] * mask_u
    v = outputs[:, :, 1] * mask_v
    p = outputs[:, :, 2] * mask_p
    return dde.backend.stack((u, v, p), axis=2)  # noqa

# Training

In [None]:
def _apply_kgi_to_subnets(subnets, first_knots, hidden_knots=(-0.8, 0.8),
                          perturb_factor=0.2):
    """Apply KGI to every linear layer of each sub-network in ``subnets``.

    The first linear layer of each sub-network uses ``first_knots``. Every
    later layer, including the last one, uses ``hidden_knots``.
    """
    for subnet in subnets:
        for idx, layer in enumerate(subnet.linears):
            knot_low, knot_high = first_knots if idx == 0 else hidden_knots
            apply_kgi_to_layer(layer, knot_low=knot_low, knot_high=knot_high,
                               perturb_factor=perturb_factor, kgi_by_bias=True)


def train(kgi, seed=0, iterations=10000):
    """Train one DeepONet on the Stokes operator problem and evaluate it.

    Args:
        kgi: if True, re-initialize branch and trunk layers with KGI.
        seed: random seed for both torch and numpy.
        iterations: number of Adam iterations.

    Returns:
        ``(loss_history, rel_err, train_state)``. ``rel_err`` maps "vx", "vy"
        and "p" to the L2 relative error against the reference solution,
        evaluated for a zero perturbation.
    """
    torch.manual_seed(seed)
    # DeepXDE's data sampling (e.g. batch shuffling) is presumably numpy-based,
    # so seed numpy too; otherwise runs with the same seed are not reproducible
    np.random.seed(seed)

    # Net
    net = dde.nn.DeepONetCartesianProd(  # noqa
        [n_pts_edge, 128, 128, 128],
        [2, 128, 128, 128],
        "tanh",
        "Glorot normal",
        num_outputs=3,
        multi_output_strategy="independent"
    )
    net.apply_output_transform(out_transform)

    # KGI: the first layer's knots follow the input range of each sub-network
    if kgi:
        _apply_kgi_to_subnets(net.branch, first_knots=(-1.5, 1.5))
        _apply_kgi_to_subnets(net.trunk, first_knots=(0.2, 0.8))

    # Model
    model = dde.zcs.Model(data, net)
    model.compile("adam", lr=0.001, decay=("inverse time", 10000, 0.5))
    loss_history, train_state = model.train(iterations=iterations, display_every=20)

    # Evaluation: the reference solution uses zero perturbation, so the branch
    # input is all zeros (no need to draw a random sample and zero it out)
    v = np.zeros((1, len(eval_pts)), dtype=dde.config.real(np))
    xy = np.vstack((np.ravel(xv), np.ravel(yv))).T
    sol_pred = model.predict((v, xy))[0]
    rel_err = {
        key: dde.metrics.l2_relative_error(sol_true[:, k], sol_pred[:, k])  # noqa
        for k, key in enumerate(("vx", "vy", "p"))
    }
    return loss_history, rel_err, train_state

In [None]:
# train all models
seeds = list(range(10))  # use `seeds = [0]` for fast test
epochs = 50000  # training iterations per model; use a smaller one for fast test
out_dir = Path("results/deeponet_paper")

out_dir.mkdir(exist_ok=True, parents=True)
for seed_ in seeds:
    for kgi_ in [False, True]:
        name_ = f"{seed_}_{kgi_}"
        # results act as a cache: skip models that were already trained
        if (out_dir / name_).exists():
            print(f"{name_} exists")
            continue
        t0 = time()
        hist, err, _ = train(kgi_, seed_, epochs)
        # rows: logged steps; columns: individual loss terms
        # (presumably 3 PDE residuals + 1 boundary-condition term)
        hist_train = np.array(hist.loss_train)
        np.savetxt(out_dir / name_, hist_train, header=f"{err['vx']}, {err['vy']}, {err['p']}")
        # report the total loss (sum of all terms), not only the last term
        print(f"{name_} trained in {(time() - t0) / 60:.1f} min, loss={hist_train[-1].sum():.2e}")

# Analysis