In [3]:
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import deepxde as dde
import numpy as np


def gen_testdata(path="../dataset/Burgers.npz"):
    """Load the reference Burgers solution and flatten it into (point, value) pairs.

    Args:
        path: Location of the ``.npz`` archive holding the arrays ``t``, ``x``
            and ``usol``. Defaults to the dataset path used by this example.

    Returns:
        X: array of shape ``(t.size * x.size, 2)``; column 0 is x, column 1 is t.
        y: array of shape ``(t.size * x.size, 1)``; the reference solution at
            the corresponding row of ``X``.
    """
    # An NpzFile keeps the archive open until it is closed; the context manager
    # releases the handle once the arrays have been read into memory.
    with np.load(path) as data:
        t, x, exact = data["t"], data["x"], data["usol"].T
    # Assumes usol is laid out (x, t); after the transpose it is (t, x), which
    # matches meshgrid's default "xy" grid (rows follow t, columns follow x),
    # so the row-major flatten below lines up with the raveled coordinates.
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y


def pde(x, y):
    """Residual of the viscous Burgers equation u_t + u * u_x - nu * u_xx, nu = 0.01 / pi.

    x: network inputs; column 0 is the spatial coordinate, column 1 is time.
    y: network output u(x, t).
    Returns the pointwise residual, which training drives towards zero.
    """
    nu = 0.01 / np.pi  # viscosity coefficient
    u_x = dde.grad.jacobian(y, x, i=0, j=0)
    u_t = dde.grad.jacobian(y, x, i=0, j=1)
    u_xx = dde.grad.hessian(y, x, i=0, j=0)
    convection = y * u_x
    return u_t + convection - nu * u_xx


# Seed every RNG DeepXDE uses (python `random`, numpy, and the active backend)
# before any collocation sampling or weight initialisation, so runs are reproducible.
SEED = 0
dde.config.set_random_seed(SEED)

# Space-time domain: x in [-1, 1], t in [0, 0.99].
geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

# Dirichlet condition u = 0 on the spatial boundary.
bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
# Initial condition u(x, 0) = -sin(pi * x); the 0:1 slice keeps the result 2-D.
ic = dde.icbc.IC(
    geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)

# Number of Adam steps before switching to L-BFGS; lower it for a quick smoke run.
ADAM_ITERATIONS = 15000

data = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
# Layer sizes [2, 20, 20, 20, 1]: inputs (x, t) -> 3 tanh hidden layers -> output u.
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

# Two-stage optimisation: Adam first, then L-BFGS (no explicit iteration count,
# so it runs with DeepXDE's configured L-BFGS stopping criteria).
model.compile("adam", lr=1e-3)
model.train(iterations=ADAM_ITERATIONS)
model.compile("L-BFGS")
losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

# Compare the trained network against the reference solution.
X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)  # PDE residual at every test point
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
# Columns written: x, t, reference u, predicted u.
# NOTE(review): dde.saveplot(issave=True) above also writes a file named
# "test.dat" by default; this call overwrites it — confirm that is intended.
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))

Using backend: pytorch
Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.
paddle supports more examples now and is recommended.


Compiling model...
'compile' took 0.559996 s

Training model...

Step      Train loss                        Test loss                         Test metric
0         [5.33e-01, 1.66e-01, 4.62e-01]    [5.33e-01, 1.66e-01, 4.62e-01]    []  
1000      [4.34e-02, 2.26e-03, 6.42e-02]    [4.34e-02, 2.26e-03, 6.42e-02]    []  
2000      [3.20e-02, 2.64e-04, 4.78e-02]    [3.20e-02, 2.64e-04, 4.78e-02]    []  
3000      [2.19e-02, 1.76e-04, 3.24e-02]    [2.19e-02, 1.76e-04, 3.24e-02]    []  
4000      [1.57e-02, 6.62e-05, 1.53e-02]    [1.57e-02, 6.62e-05, 1.53e-02]    []  
5000      [8.77e-03, 3.61e-05, 6.88e-03]    [8.77e-03, 3.61e-05, 6.88e-03]    []  
6000      [6.58e-03, 1.91e-05, 5.76e-03]    [6.58e-03, 1.91e-05, 5.76e-03]    []  
7000      [6.91e-03, 1.11e-05, 5.45e-03]    [6.91e-03, 1.11e-05, 5.45e-03]    []  
8000      [6.05e-03, 1.01e-05, 4.98e-03]    [6.05e-03, 1.01e-05, 4.98e-03]    []  
9000      [4.58e-03, 5.34e-06, 4.45e-03]    [4.58e-03, 5.34e-06, 4.45e-03]    []  
10000     [4.43

KeyboardInterrupt: 