In [None]:
### Install DeepXDE (the physics-informed neural network library used below)
!pip install deepxde

In [None]:
import deepxde as dde
dde.backend.set_default_backend("pytorch")
from deepxde.backend import torch, backend_name
import torch

import numpy as np
import matplotlib.pyplot as plt


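'''
Backend note: this notebook targets DeepXDE's PyTorch backend.
set_default_backend() only records the preference in DeepXDE's config file;
it takes effect the next time deepxde is imported, so restart the kernel
after the first run.
If you switch to a TensorFlow backend, replace any torch.* calls
(e.g. torch.sin) with their tf.* equivalents (e.g. tf.sin).
'''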

# # For reproducibility, uncomment the next line to fix the random seeds used by DeepXDE and its backend (PyTorch here)
# dde.config.set_random_seed(1234)

In [None]:
# PDE: 1D viscous Burgers' equation
#   u_t + u * u_x - nu * u_xx = 0,  with nu = 0.01 / pi
def pde(x, y):
    """Return the pointwise residual of the viscous Burgers' equation.

    Args:
        x: network inputs; column 0 is the spatial coordinate and column 1
            is time, i.e. points (x, t).
        y: network output u(x, t).

    Returns:
        u_t + u * u_x - nu * u_xx evaluated at every input point; the
        training loss penalizes this residual.
    """
    nu = 0.01 / np.pi  # viscosity coefficient

    # Spatial and temporal first derivatives of the single output component.
    u_x = dde.grad.jacobian(y, x, i=0, j=0)
    u_t = dde.grad.jacobian(y, x, i=0, j=1)
    # Second spatial derivative (diffusion term).
    u_xx = dde.grad.hessian(y, x, i=0, j=0)

    return u_t + y * u_x - nu * u_xx


# ---- Boundary condition helpers ----
def bc_fn(x, on_boundary):
    """Select points on the spatial boundary x = -1 or x = 1 (at any t).

    Args:
        x: a single point [x, t].
        on_boundary: boundary flag DeepXDE supplies for this point.

    Returns:
        False when ``on_boundary`` is falsy; otherwise whether the spatial
        coordinate x[0] is numerically close to -1 or to 1.
    """
    if on_boundary:
        coord = x[0]
        # Keep only the spatial ends of the domain; t is unrestricted.
        return np.isclose(coord, -1) or np.isclose(coord, 1)
    return False


def bc_val(x):
    """Dirichlet value on the spatial boundary: u(±1, t) = 0 for every t.

    The boundary points ``x`` are not needed because the value is the same
    constant everywhere, so a scalar zero is returned.
    """
    del x  # the boundary value does not depend on the point
    return 0


# ---- Initial condition helpers ----
def ic_fn(x, on_initial):
    """Select points on the initial line t = 0 (any x).

    Args:
        x: a single point [x, t].
        on_initial: initial-time flag DeepXDE supplies for this point.

    Returns:
        ``on_initial`` itself when it is falsy; otherwise whether the time
        coordinate x[1] is close to 0 according to ``dde.utils.isclose``.
    """
    if not on_initial:
        return on_initial
    # Redundant safety check on the time coordinate itself.
    return dde.utils.isclose(x[1], 0)


def ic_val(x):
    """Initial profile u(x, 0) = -sin(pi * x).

    Args:
        x: array of points with shape (N, 2); column 0 is the spatial
            coordinate.

    Returns:
        Array of shape (N, 1) with the initial value at each point.
    """
    x_coord = x[:, :1]  # keep a 2-D column so the result is (N, 1)
    return -np.sin(np.pi * x_coord)


# ---- Space-time domain: (x, t) in [-1, 1] x [0, 1] ----
geom = dde.geometry.Interval(-1, 1)  # spatial interval for x
timedomain = dde.geometry.TimeDomain(0, 1)  # time interval for t
geomtime = dde.geometry.GeometryXTime(geom, timedomain)  # tensor product


# ---- Boundary and initial conditions ----
bc = dde.icbc.DirichletBC(geomtime, bc_val, bc_fn)  # u(±1, t) = 0
ic = dde.icbc.IC(geomtime, ic_val, ic_fn)  # u(x, 0) = -sin(pi x)


# ---- Training data: collocation points for the PDE, BC and IC ----
data = dde.data.TimePDE(
    geomtime,
    pde,
    [bc, ic],
    num_domain=2540,  # interior points in (x, t)
    num_boundary=80,  # points on the spatial boundary x = ±1
    num_initial=160,  # points on the initial line t = 0
)

# ---- Network: (x, t) -> u, three tanh hidden layers of width 20 ----
layer_size = [2, 20, 20, 20, 1]
activation, initializer = "tanh", "Glorot uniform"
net = dde.nn.FNN(layer_size, activation, initializer)

# Couple the network with the training data.
model = dde.Model(data, net)


# Custom metric: relative L2 error of the current network on the external
# reference grid (X_ref, Y_ref), reported during training.
def l2_rel_err_ext(_, y_pred):
    """Relative L2 error of ``model`` on the external reference data.

    The metric is called with two arguments because the DeepXDE metric API
    requires them, but both are ignored. Instead, the function predicts on
    the global ``X_ref`` and compares against the global ``Y_ref``.

    NOTE(review): X_ref and Y_ref are not defined in the notebook cells
    shown; confirm they are created before ``model.train`` is called.
    """
    reference_pred = model.predict(X_ref)
    error = dde.metrics.l2_relative_error(Y_ref, reference_pred)
    return error


# ---- Compile and train with Adam ----
# NOTE(review): X_ref / Y_ref (external reference points and solution) are
# not defined in the cells shown. They must exist before training starts,
# because the l2_rel_err_ext metric reads them at every evaluation.
model.compile(
    "adam",
    lr=1e-3,
    metrics=[l2_rel_err_ext],  # report external L2 error during training
)

losshistory, train_state = model.train(iterations=15000)


# ---- Final evaluation on external test data ----
y_pred = model.predict(X_ref)
final_l2_error = dde.metrics.l2_relative_error(Y_ref, y_pred)

# PDE residual u_t + u u_x - nu u_xx at the test points. Its mean absolute
# value shows how well the trained network satisfies the equation itself.
# Previously it was computed and then discarded.
f = model.predict(X_ref, operator=pde)
mean_abs_residual = np.mean(np.abs(f))

print("Final L2 relative error (on external test data): {:.2e}".format(final_l2_error))
print(
    "Mean absolute PDE residual (on external test data): {:.2e}".format(
        mean_abs_residual
    )
)