TODOs (intentionally left unimplemented for you to complete):
- `build_training_data`
- `class PINN` (`__init__` and `forward`)
- `model_loss`
- `train`


In [None]:
import math
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from scipy.integrate import quad


def build_training_data(
    num_bc: Tuple[int, int] = (25, 25),
    num_ic: int = 50,
    num_f: int = 10_000,
    seed: int = 0,
):
    """Assemble the PINN training samples (exercise stub, not implemented yet).

    Intended behavior, matching the original script:
      * boundary samples on x = -1 and x = +1 with t in [0, 1] and target u = 0
        (NOTE(review): presumably ``num_bc`` holds the sample count for each
        of the two sides; confirm);
      * initial-condition samples at t = 0 with x in [-1, 1] and target
        u = -sin(pi * x) (presumably ``num_ic`` of them);
      * interior collocation points (x, t) in (-1, 1) x (0, 1)
        (presumably ``num_f`` of them), with ``seed`` presumably making the
        sampling reproducible.

    The finished function should return ``(xf, tf, x0, t0, u0)``.
    NOTE(review): presumably the boundary and initial samples are merged into
    ``x0, t0, u0``; confirm.

    Raises:
        NotImplementedError: always, until the TODO is completed.
    """
    message = "TODO: implement build_training_data"
    raise NotImplementedError(message)


class PINN(nn.Module):
    """Fully connected Tanh network for the Burgers PINN (exercise stub).

    ``evaluate`` calls it on an ``(N, 2)`` tensor whose columns are ``x`` and
    ``t`` and flattens the output to ``N`` predicted ``u`` values.
    """

    def __init__(self, in_dim: int = 2, hidden_dim: int = 20, num_hidden: int = 8, out_dim: int = 1):
        """TODO : Implement the fully connected Tanh network.

        NOTE(review): presumably ``in_dim`` -> ``num_hidden`` Tanh layers of
        width ``hidden_dim`` -> linear ``out_dim`` output; confirm.

        Raises:
            NotImplementedError: always, until the TODO is completed.
        """
        # The docstring must precede this call: a string placed after it is a
        # no-op expression statement and leaves ``__init__.__doc__`` empty.
        super().__init__()
        raise NotImplementedError("TODO: implement PINN.__init__")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """TODO : Implement forward pass.

        Raises:
            NotImplementedError: always, until the TODO is completed.
        """
        raise NotImplementedError("TODO: implement PINN.forward")


def model_loss(net: nn.Module, xf: torch.Tensor, tf: torch.Tensor, x0: torch.Tensor, t0: torch.Tensor, u0: torch.Tensor):
    """PINN training loss (exercise stub, not implemented yet).

    NOTE(review): presumably combines the PDE residual of ``net`` at the
    collocation points ``(xf, tf)`` with the misfit between ``net`` evaluated
    at ``(x0, t0)`` and the targets ``u0``; confirm against the original script.

    Raises:
        NotImplementedError: always, until the TODO is completed.
    """
    message = "TODO: implement model_loss"
    raise NotImplementedError(message)


def solve_burgers(x: np.ndarray, t: float, nu: float) -> np.ndarray:
    """Closed-form reference solution u(x, t) of the viscous Burgers equation.

    Uses the Cole-Hopf representation for the initial condition
    u(x, 0) = -sin(pi x):

        u(x, t) = -I1 / I0
        I1 = integral over R of sin(pi (x - eta)) f(x - eta) g(eta) d eta
        I0 = integral over R of f(x - eta) g(eta) d eta
        f(y) = exp(-cos(pi y) / (2 pi nu)),   g(eta) = exp(-eta**2 / (4 nu t))

    Both integrals are evaluated with ``scipy.integrate.quad``. Points with
    |x| == 1 are not integrated and keep the boundary value 0.

    Args:
        x: evaluation points; any array-like (scalar, 1-D or N-D) is accepted.
        t: time, must be > 0.
        nu: viscosity, must be > 0.

    Returns:
        float64 array of u values with the shape of ``np.asarray(x)``.

    Raises:
        ValueError: if ``t <= 0`` or ``nu <= 0``.
    """
    if t <= 0:
        raise ValueError("t must be > 0 for this reference solution.")
    if nu <= 0:
        raise ValueError("nu must be > 0 for this reference solution.")

    inv_2pi_nu = 1.0 / (2.0 * np.pi * nu)
    inv_4nu_t = 1.0 / (4.0 * nu * t)
    rtol = 1.49e-8  # quad's default epsrel, kept explicit

    def weight(eta, xi):
        # f(xi - eta) * g(eta) divided by the constant f(xi). The constant cancels
        # in I1 / I0, but it makes the integrand exactly 1 at the Gaussian peak
        # eta = 0 instead of exp(-cos(pi xi) / (2 pi nu)), which ranges from
        # exp(-1/(2 pi nu)) to exp(+1/(2 pi nu)) and under/overflows for small nu.
        return np.exp(
            (np.cos(np.pi * xi) - np.cos(np.pi * (xi - eta))) * inv_2pi_nu
            - eta * eta * inv_4nu_t
        )

    def weighted_sin(eta, xi):
        return np.sin(np.pi * (xi - eta)) * weight(eta, xi)

    x_arr = np.asarray(x, dtype=np.float64)
    u = np.zeros(x_arr.shape, dtype=np.float64)
    for idx, xi in np.ndenumerate(x_arr):
        if abs(xi) == 1.0:
            continue  # boundary points keep u = 0
        # I0 > 0. With epsabs=0 only the relative tolerance applies, so a small
        # I0 is still resolved (the default epsabs would accept a crude estimate).
        den = quad(weight, -np.inf, np.inf, args=(xi,), limit=200,
                   epsabs=0.0, epsrel=rtol)[0]
        # Only the error of I1 relative to I0 matters for u = -I1 / I0.
        num = -quad(weighted_sin, -np.inf, np.inf, args=(xi,), limit=200,
                    epsabs=rtol * den, epsrel=rtol)[0]
        u[idx] = num / den
    return u


def train():
    """Train the PINN (exercise stub, not implemented yet).

    NOTE(review): the usage cell (``model = train(); evaluate(model)``) implies
    the finished function should return the trained ``nn.Module``.

    Raises:
        NotImplementedError: always, until the TODO is completed.
    """
    message = "TODO: implement train"
    raise NotImplementedError(message)


def evaluate(net: nn.Module):
    """Compare ``net`` with the Burgers reference solution and plot both.

    Evaluates ``net`` on 1001 evenly spaced x in [-1, 1] at t = 0.25, 0.5,
    0.75 and 1.0. It prints the relative L2 error against ``solve_burgers``
    (nu = 0.01 / pi) and shows a 2x2 grid of prediction-vs-target plots.
    Side effect: switches ``net`` to eval mode.

    Args:
        net: module called on an ``(N, 2)`` tensor with columns ``(x, t)``;
            its output is flattened to ``N`` values.

    Returns:
        The relative L2 error as a float (it is also printed).
    """
    import matplotlib.pyplot as plt

    net.eval()
    nu = 0.01 / math.pi

    t_test = [0.25, 0.5, 0.75, 1.0]
    x_test = np.linspace(-1.0, 1.0, 1001)

    # Build inputs in the network's own dtype/device. torch.from_numpy alone
    # gives a float64 CPU tensor, which fails against float32 or GPU weights.
    param = next(net.parameters(), None)
    dtype = param.dtype if param is not None else torch.get_default_dtype()
    device = param.device if param is not None else torch.device("cpu")

    u_pred = []
    u_true = []

    with torch.no_grad():
        x_torch = torch.as_tensor(x_test, dtype=dtype, device=device).view(-1, 1)
        for t in t_test:
            t_torch = torch.full_like(x_torch, t)  # inherits dtype and device
            xt = torch.cat([x_torch, t_torch], dim=1)
            u_p = net(xt).cpu().numpy().reshape(-1)
            u_t = solve_burgers(x_test, t, nu)
            u_pred.append(u_p)
            u_true.append(u_t)

    u_pred = np.stack(u_pred, axis=0)
    u_true = np.stack(u_true, axis=0)
    err = np.linalg.norm(u_pred - u_true) / np.linalg.norm(u_true)
    print(f"Relative L2 error: {err:.4f}")

    fig, axes = plt.subplots(2, 2, figsize=(10, 7), constrained_layout=True)
    axes = axes.ravel()
    for i, t in enumerate(t_test):
        axes[i].plot(x_test, u_pred[i], "-", lw=2, label="Prediction")
        axes[i].plot(x_test, u_true[i], "--", lw=2, label="Target")
        axes[i].set_ylim(-1.1, 1.1)
        axes[i].set_xlabel("x")
        axes[i].set_ylabel(f"u(x, {t})")
    axes[0].legend()
    plt.show()
    return float(err)


In [None]:
# Run after implementing all TODOs:
# model = train()
# evaluate(model)
