# Forward PINN (2D Acoustic) — **(using c=1, consistent ICs; hard IC/BC; dataset-driven)**

This notebook trains a forward Physics-Informed Neural Network with:
- constant wavespeed, $c \equiv 1$,
- **two IC snapshots** from the same $\phi^*$ at $t=0$ and $t=\tau$ (here $\tau=0.05$), **enforced as hard constraints**,
- a Neumann **top boundary** condition $\partial\phi/\partial z = 0$ at $z=0$ (which $\phi^*$ satisfies), **hard-enforced**,
- **limited dataset-based training samples** drawn from the analytical $\phi^*$ and $s(x,z,t)$ for supervised consistency.

**Domain:** $x,z\in[0,1]$, training window $t\in[0,0.5]$.  
**Goal:** the PDE residual and data loss decrease together; IC/BC are satisfied by construction; and the snapshot error remains small (relative $L^2$).


In [1]:
# Notebook setup: numerical stack, device selection, and Google Drive mount.
import torch, torch.nn as nn, torch.optim as optim
import numpy as np, matplotlib.pyplot as plt, math, time
from dataclasses import dataclass
print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
# Module-level default device; ForwardPINN falls back to this global when no
# device_override is passed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression: only displays the value when run as a notebook cell
import os
from typing import Tuple
# Colab-only: mounts the user's Drive so the seismogram files are readable
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')



PyTorch: 2.8.0+cu126 | CUDA available: True
Mounted at /content/drive


In [2]:
# Import the data from google drive
def load_component(
    base_dir: str,
    component: str,                # 'Z' or 'X'
    t_s: float,                    # end time to keep (same meaning as your code)
    t_st: float,                   # start time to keep
    subsample: int = 100,          # like l_f in your code
    device: torch.device | str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor, list[str]]:
    """
    Returns:
      t_sub: (l_sub, 1) torch float32
      S:     (n_seis * l_sub, 1) torch float32 stacked like your code
      files: list of file names used (sorted)
    """
    files = sorted([f for f in os.listdir(base_dir) if len(f) >= 6 and f[-6] == component])
    if not files:
        raise FileNotFoundError(f"No {component}-component files found in {base_dir}")
    arrays = [np.loadtxt(os.path.join(base_dir, f)) for f in files]
    # Specfem time shift to start at 0, same as: -arr[0,0] + arr[:,0]
    t_spec = arrays[0][:, 0] - arrays[0][0, 0]
    # Keep only [t_st, t_s]
    mask = (t_spec <= t_s) & (t_spec >= t_st)
    if not np.any(mask):
        raise ValueError("No samples in the requested time window; check t_st and t_s.")
    first = np.argmax(mask)                       # first True
    last  = len(mask) - 1 - np.argmax(mask[::-1]) # last True
    index = np.arange(first, last + 1, subsample)
    t_sub = t_spec[index].reshape(-1, 1)
    # Stack station traces vertically, column 1 is amplitude (same as your [:,1])
    S = np.concatenate([a[index, 1:2] for a in arrays], axis=0)
    # Convert to torch
    t_sub = torch.from_numpy(t_sub.astype(np.float32)).to(device)
    S     = torch.from_numpy(S.astype(np.float32)).to(device)
    return t_sub, S, files
# Time window (in shifted trace time) kept from every seismogram.
t_st, t_s = 0, 0.75
# Both components are read from the same Drive folder with identical settings.
_seis_dir = "/content/drive/MyDrive/seismograms"
t_sub_Z, Sz, _ = load_component(_seis_dir, "Z", t_s=t_s, t_st=t_st, subsample=100)
t_sub_X, Sx, _ = load_component(_seis_dir, "X", t_s=t_s, t_st=t_st, subsample=100)



In [3]:
def build_X_S(t_sub: torch.Tensor,
              z0_s: float, zl_s: float, n_seis: int,
              ax: float, Lx: float = 1.0, Lz: float = 1.0,
              device: str | torch.device = "cpu") -> torch.Tensor:
    """
    Replicates your TF layout:
      For i=0..n_seis-1:
        x = ax/Lx (fixed), z = (z0_s - i*d_s)/Lz, t = t_sub
    Returns: X_S of shape (n_seis*l_sub, 3)
    """
    t_sub = t_sub.to(device).reshape(-1, 1)         # (l_sub, 1)
    l_sub = t_sub.shape[0]
    d_s = abs(zl_s - z0_s) / max(n_seis - 1, 1)
    rows = []
    for i in range(n_seis):
        x_col = torch.full((l_sub, 1), ax / Lx, dtype=t_sub.dtype, device=device)
        z_col = torch.full((l_sub, 1), (z0_s - i * d_s) / Lz, dtype=t_sub.dtype, device=device)
        rows.append(torch.cat([x_col, z_col, t_sub], dim=1))
    return torch.vstack(rows)



In [4]:
@dataclass
class Domain:
    """Extents of the space-time box from which collocation points are sampled."""
    ax: float = 1.0     # x-extent: x is sampled in [0, ax]
    az: float = 1.0     # z-extent: z is sampled in [0, az]
    t_max: float = 0.5  # training time window: t is sampled in [0, t_max]
@dataclass
class TrainCfg:
    """Optimizer settings read by ForwardPINN.train."""
    batch_size: int = 20000      # PDE collocation points per step (top boundary uses batch_size // 4)
    adam_lr: float = 1e-3        # Adam learning rate
    adam_epochs: int = 1000      # number of Adam steps
    use_lbfgs: bool = True       # run an L-BFGS refinement after Adam
    lbfgs_max_iter: int = 250    # max_iter handed to torch.optim.LBFGS
    seed: int = 1234             # passed to torch.manual_seed at the start of training
@dataclass
class Weights:
    """Scalar multipliers for the individual terms of the training loss."""
    w_pde: float = 1.0         # PDE residual term
    w_ic1: float = 1.0         # initial-condition term
    w_ic2: float = 1.0         # presumably the second IC snapshot — verify the training loop reads it
    w_seis: float = 1.0        # seismogram (data) misfit term
    w_bc_top_neu: float = 1.0  # top-boundary Neumann term (dphi/dz = 0 at z = 0)



In [5]:
class MLP(nn.Module):
    """
    Fully connected network: `layers` blocks of (Linear -> activation)
    followed by a final Linear projection to `out_dim`.

    Every Linear layer gets Xavier-uniform weights and zero biases.
    `act` is an activation *class* (instantiated once per hidden block).
    """

    def __init__(self, in_dim=3, out_dim=1, hidden=128, layers=6, act=nn.Tanh):
        super().__init__()
        # Feature widths seen by the hidden blocks: in_dim -> hidden -> ... -> hidden.
        widths = [in_dim] + [hidden] * layers
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(act())
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

        # Re-initialise every Linear layer found anywhere in the module tree.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        return self.net(x)



In [6]:

import torch
import torch.nn as nn

class ForwardPINN:
    """
    Physics-Informed Neural Network for the 2D acoustic wave equation (c = 1).

    The wrapped `model` maps rows (x, z, t) to a scalar potential phi. The
    training loss combines:
      - the PDE residual  phi_tt - phi_xx - phi_zz - s(x, z, t),
      - a soft Neumann condition dphi/dz = 0 on the top boundary z = 0,
      - optional soft IC snapshots supplied via `set_ic_targets`,
      - optional seismogram misfit supplied via `set_seismo_data` / `train`.

    - Optional source function can be provided via `source_fn`.
    - Optional hard constraints can be applied outside this class via an
      output transform on the model.
    - If no IC targets / seismogram data are set, those terms are zero.
    """

    def __init__(self, dom: "Domain", model: nn.Module, source_fn=None, device_override=None):
        """
        Args:
            dom: object exposing `ax`, `az`, `t_max` (sampling extents).
            model: network mapping (N, 3) inputs to (N, 1) outputs.
            source_fn: optional callable (x, z, t) -> tensor shaped like phi.
            device_override: device to use; when None the notebook-level
                global `device` is used.
        """
        self.dom = dom
        self.device = device if device_override is None else device_override
        self.model = model.to(self.device)
        self.source_fn = source_fn  # callable: (x,z,t)->tensor or None
        self._X_S = None
        self._Sx = None
        self._Sz = None

        # Optional IC targets (soft constraints). If unset, they are ignored.
        self._ic_grid = None     # tuple (Xg, Zg)
        self._ic_times = None    # list of t tensors
        self._ic_targets = None  # list of (Ux_tgt, Uz_tgt)

    @staticmethod
    def grad(y, x):
        """Return dy/dx (summed over y's entries), keeping the graph for higher derivatives."""
        return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),
                                   create_graph=True, retain_graph=True)[0]

    @staticmethod
    def _with_grad(pts):
        """
        Return a tensor that autograd can differentiate with respect to.

        A tensor that already requires grad is used as-is (so callers who set
        it up themselves keep their graph). Otherwise a detached view is
        flagged instead of the input, so the caller's tensor (e.g. the stored
        seismogram coordinates) is not mutated and does not accumulate `.grad`.
        """
        if pts.requires_grad:
            return pts
        return pts.detach().requires_grad_(True)

    def pde_residual(self, pts):
        """
        Residual for phi_tt - phi_xx - phi_zz = s(x,z,t), with c = 1.

        Args:
            pts: (N, 3) tensor with columns x, z, t.
        Returns:
            (eq, phi_z): the (N, 1) residual and dphi/dz at the same points.
        Raises:
            ValueError: `source_fn` returned a tensor whose shape differs from phi.
        """
        pts = self._with_grad(pts)
        phi = self.model(pts)
        x, z, t = pts[:, 0:1], pts[:, 1:2], pts[:, 2:3]
        d = self.grad(phi, pts)
        phi_x, phi_z, phi_t = d[:, 0:1], d[:, 1:2], d[:, 2:3]
        phi_xx = self.grad(phi_x, pts)[:, 0:1]
        phi_zz = self.grad(phi_z, pts)[:, 1:2]
        phi_tt = self.grad(phi_t, pts)[:, 2:3]

        if self.source_fn is None:
            src = torch.zeros_like(phi)
        else:
            src = self.source_fn(x, z, t)
            if src.shape != phi.shape:
                raise ValueError("source_fn must return a tensor with shape matching phi. "
                                 f"Got {src.shape}, expected {phi.shape}.")

        eq = phi_tt - phi_xx - phi_zz - src   # c = 1
        return eq, phi_z

    # Samplers
    def sample_pde(self, n):
        """Draw n uniform points in [0, ax] x [0, az] x [0, t_max]; returns (n, 3)."""
        x = torch.rand(n, 1, device=self.device) * self.dom.ax
        z = torch.rand(n, 1, device=self.device) * self.dom.az
        t = torch.rand(n, 1, device=self.device) * self.dom.t_max
        return torch.cat([x, z, t], dim=1)

    def sample_top(self, n):
        """Draw n uniform points on the top boundary z = 0; returns (n, 3)."""
        x = torch.rand(n, 1, device=self.device) * self.dom.ax
        z = torch.zeros(n, 1, device=self.device)  # top boundary at z=0
        t = torch.rand(n, 1, device=self.device) * self.dom.t_max
        return torch.cat([x, z, t], dim=1)

    def sample_ic_grid(self, n):
        """Return flattened (n*n, 1) X and Z coordinates of a regular n x n grid."""
        xs = torch.linspace(0, self.dom.ax, n, device=self.device)
        zs = torch.linspace(0, self.dom.az, n, device=self.device)
        X, Z = torch.meshgrid(xs, zs, indexing='ij')
        return X.reshape(-1, 1), Z.reshape(-1, 1)

    def set_seismo_data(self, X_S: torch.Tensor, Sx: torch.Tensor, Sz: torch.Tensor):
        """Store receiver coordinates (N, 3) and the observed x/z traces (N, 1) on the device."""
        self._X_S = X_S.to(self.device).reshape(-1, 3)
        self._Sx = Sx.to(self.device).reshape(-1, 1)
        self._Sz = Sz.to(self.device).reshape(-1, 1)

    def set_ic_targets(self, Xg, Zg, times, targets):
        """
        Provide soft initial-condition targets:
        - Xg, Zg: flattened coordinate grids (N,1) tensors on device
        - times: list of t tensors, each (N,1) on device
        - targets: list of (Ux_tgt, Uz_tgt), each (N,1) pair
        """
        self._ic_grid = (Xg, Zg)
        self._ic_times = times
        self._ic_targets = targets

    # displacement-only seismogram loss
    def seismo_loss(self):
        """MSE between predicted and stored displacement traces (zero when no data is set)."""
        if self._X_S is None:
            return torch.tensor(0.0, device=self.device)
        ux, uz = self.predict_displacement_components(self._X_S)
        return torch.mean((ux - self._Sx)**2) + torch.mean((uz - self._Sz)**2)

    def predict_displacement_components(self, pts):
        """Return (dphi/dx, dphi/dz) at `pts`; displacement is taken as the spatial gradient of phi."""
        pts = self._with_grad(pts)
        phi = self.model(pts)
        d = self.grad(phi, pts)
        ux, uz = d[:, 0:1], d[:, 1:2]
        return ux, uz

    def _ic_losses(self):
        """Return one displacement-MSE loss per IC snapshot ([] when no targets are set)."""
        if self._ic_grid is None or self._ic_times is None or self._ic_targets is None:
            return []
        Xg, Zg = self._ic_grid
        losses = []
        for tvec, (Ux_tgt, Uz_tgt) in zip(self._ic_times, self._ic_targets):
            ux, uz = self.predict_displacement_components(torch.cat([Xg, Zg, tvec], dim=1))
            losses.append(torch.mean((ux - Ux_tgt)**2) + torch.mean((uz - Uz_tgt)**2))
        return losses

    def _loss_terms(self, w, pde_pts, top_pts):
        """
        Assemble the weighted total loss on the given collocation points.

        Returns (loss, L_pde, L_bc, L_ic, L_seis); the last four are the
        unweighted terms (L_ic is the plain sum over snapshots) for logging.
        """
        eq, _ = self.pde_residual(pde_pts)
        L_pde = torch.mean(eq**2)

        # Only dphi/dz is needed on the boundary, so skip the second derivatives.
        _, phi_z_top = self.predict_displacement_components(top_pts)
        L_bc = torch.mean(phi_z_top**2)

        zero = torch.tensor(0.0, device=self.device)
        ic_losses = self._ic_losses()
        L_ic = sum(ic_losses) if ic_losses else zero
        # First snapshot is weighted by w_ic1, every later one by w_ic2
        # (falling back to w_ic1 for weight objects that do not define w_ic2).
        ic_weights = (w.w_ic1, getattr(w, "w_ic2", w.w_ic1))
        L_ic_weighted = zero
        for k, ic_loss in enumerate(ic_losses):
            L_ic_weighted = L_ic_weighted + ic_weights[min(k, 1)] * ic_loss

        L_seis = self.seismo_loss()

        loss = w.w_pde*L_pde + w.w_bc_top_neu*L_bc + L_ic_weighted + w.w_seis*L_seis
        return loss, L_pde, L_bc, L_ic, L_seis

    def train(self, cfg: "TrainCfg", w: "Weights", X_S: torch.Tensor=None, Sx: torch.Tensor=None, Sz: torch.Tensor=None):
        """
        Train the model: Adam for `cfg.adam_epochs` steps, then optional L-BFGS.

        Args:
            cfg: training settings (batch_size, adam_lr, adam_epochs,
                use_lbfgs, lbfgs_max_iter, seed).
            w: loss weights (w_pde, w_bc_top_neu, w_ic1, w_ic2, w_seis).
            X_S, Sx, Sz: optional seismogram supervision; used only when all
                three are given, otherwise previously stored data (if any) is kept.

        Adam draws fresh collocation points every step. L-BFGS uses ONE fixed
        set of points: its strong-Wolfe line search re-evaluates the closure
        several times per iteration and needs the same objective each time.
        """
        torch.manual_seed(cfg.seed)
        opt = torch.optim.Adam(self.model.parameters(), lr=cfg.adam_lr)

        # Optional: seismogram supervision
        if X_S is not None and Sx is not None and Sz is not None:
            self.set_seismo_data(X_S, Sx, Sz)

        n_top = max(1, cfg.batch_size//4)

        # Adam phase
        for epoch in range(1, cfg.adam_epochs+1):
            opt.zero_grad()
            loss, L_pde, L_bc, L_ic, L_seis = self._loss_terms(
                w, self.sample_pde(cfg.batch_size), self.sample_top(n_top))
            loss.backward()
            opt.step()

            if epoch % 200 == 0 or epoch == 1:
                print(f"[Adam {epoch:4d}] total={loss.item():.3e} | pde={L_pde.item():.3e} | bc={L_bc.item():.3e} | ic={L_ic.item():.3e} | seis={L_seis.item():.3e}")

        # L-BFGS refine (optional)
        if getattr(cfg, "use_lbfgs", False):
            print("\nSwitching to L-BFGS...")
            lbfgs = torch.optim.LBFGS(self.model.parameters(),
                                      max_iter=cfg.lbfgs_max_iter,
                                      line_search_fn="strong_wolfe")
            # Frozen collocation set so every closure call sees the same objective.
            pde_pts = self.sample_pde(cfg.batch_size)
            top_pts = self.sample_top(n_top)

            def closure():
                lbfgs.zero_grad()
                loss = self._loss_terms(w, pde_pts, top_pts)[0]
                loss.backward()
                return loss

            lbfgs.step(closure)
            print("L-BFGS done.")


In [7]:
# Baseline run: plain MLP with soft (penalty) top-boundary and seismogram terms.
dom = Domain()
model = MLP(hidden=128, layers=6)
solver = ForwardPINN(dom, model)
cfg = TrainCfg(adam_epochs=3000, lbfgs_max_iter=500, batch_size=32000)
# PDE residual is down-weighted relative to the boundary and data terms.
wts = Weights(w_pde=0.01, w_ic1=1.0, w_ic2=1.0, w_bc_top_neu=1.0,w_seis=1.0)
# NOTE(review): build_X_S steps z downward from z0_s, so with z0_s=0.01 < zl_s=1
# every receiver after the first gets a negative z — confirm the intended geometry.
# NOTE(review): n_seis must equal the number of files loaded for Sx/Sz, otherwise
# the row counts of X_S and the traces will not match — verify.
# NOTE(review): t_sub_Z covers times up to t_s (0.75) while PDE points are sampled
# only up to dom.t_max (0.5) — confirm this is intended.
X_S = build_X_S(
    t_sub=t_sub_Z,         # use the same t grid you used for Sz/Sx
    z0_s=0.01, zl_s=1,
    n_seis=20,
    ax=dom.ax, Lx=3, Lz=3,
    device=device
)
solver.train(cfg, wts, X_S, Sx, Sz)



  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[Adam    1] total=1.519e-02 | pde=2.637e-04 | bc=5.981e-03 | ic=0.000e+00 | seis=9.205e-03
[Adam  200] total=4.577e-07 | pde=3.227e-06 | bc=5.437e-09 | ic=0.000e+00 | seis=4.200e-07
[Adam  400] total=1.503e-07 | pde=3.840e-06 | bc=1.818e-09 | ic=0.000e+00 | seis=1.101e-07
[Adam  600] total=5.481e-08 | pde=2.104e-06 | bc=4.385e-09 | ic=0.000e+00 | seis=2.939e-08
[Adam  800] total=2.775e-08 | pde=1.563e-06 | bc=3.900e-09 | ic=0.000e+00 | seis=8.223e-09
[Adam 1000] total=1.845e-08 | pde=1.281e-06 | bc=2.402e-09 | ic=0.000e+00 | seis=3.235e-09
[Adam 1200] total=1.397e-08 | pde=1.076e-06 | bc=1.318e-09 | ic=0.000e+00 | seis=1.889e-09
[Adam 1400] total=1.131e-08 | pde=9.052e-07 | bc=8.102e-10 | ic=0.000e+00 | seis=1.449e-09
[Adam 1600] total=9.756e-09 | pde=7.795e-07 | bc=6.597e-10 | ic=0.000e+00 | seis=1.301e-09
[Adam 1800] total=8.353e-09 | pde=6.485e-07 | bc=6.285e-10 | ic=0.000e+00 | seis=1.239e-09
[Adam 2000] total=7.164e-09 | pde=5.334e-07 | bc=6.230e-10 | ic=0.000e+00 | seis=1.207e-09

In [8]:
import torch
from torch import nn
# --- Hard-constraint utilities (no MMS) ---
L = 1.0  # domain length in x (set to your domain)
H = 1.0  # domain height in z (set to your domain)


def Bx(x: torch.Tensor) -> torch.Tensor:
    """
    Envelope in x: (x/L) * (1 - x/L).

    Vanishes at x=0 and x=L, so a term multiplied by it cannot alter the
    Dirichlet values imposed on those two edges.
    """
    return (x / L) * (1.0 - x / L)
def Bz(z: torch.Tensor, both_ends: bool = False) -> torch.Tensor:
    """
    Envelope in z, with s = z/H:
      - both_ends=False: s**2                (value and z-derivative vanish at z=0)
      - both_ends=True:  s**2 * (1 - s)**2   (value and z-derivative vanish at z=0 and z=H)
    """
    s = z / H
    envelope = s ** 2
    if both_ends:
        envelope = envelope * ((1.0 - s) ** 2)
    return envelope
class ConstrainedModel(nn.Module):
    """
    Output transform that builds boundary behaviour into the network:

        φ(x,z,t) = D(x) + Bx(x) * Bz(z) * uθ(x,z,t)

      - D(x) = ((L-x)/L)*d1 + (x/L)*d2 interpolates between the constant
        Dirichlet values d1 (at x=0) and d2 (at x=L); both default to 0.
      - Bx(x) * Bz(z) multiplies the raw network output uθ, so the learned
        part vanishes at x=0 and x=L and inherits Bz's behaviour in z.
      - Time is passed through to uθ unchanged; nothing is constrained in t.
    """

    def __init__(self, base_net: nn.Module, both_ends_neumann: bool = False,
                 d1_value: float = 0.0, d2_value: float = 0.0):
        super().__init__()
        self.base = base_net
        self.both_ends_neumann = both_ends_neumann
        # Buffers (not parameters): they follow .to(device) but are never trained.
        for buf_name, value in (("d1_const", d1_value), ("d2_const", d2_value)):
            self.register_buffer(buf_name, torch.tensor(float(value)))

    def forward(self, xzt: torch.Tensor) -> torch.Tensor:
        x, z = xzt[:, 0:1], xzt[:, 1:2]
        # Linear blend of the two Dirichlet constants across x.
        left = self.d1_const.expand_as(x)
        right = self.d2_const.expand_as(x)
        lift = ((L - x) / L) * left + (x / L) * right
        # Envelope that switches the learned correction off on the constrained edges.
        envelope = Bx(x) * Bz(z, both_ends=self.both_ends_neumann)
        return lift + envelope * self.base(xzt)



In [9]:
# Second run: wrap a fresh MLP in ConstrainedModel so the boundary behaviour
# is built into the network output instead of being penalised.
base_net = MLP(hidden=128, layers=6)
model = ConstrainedModel(base_net).to(device)
solver = ForwardPINN(dom, model)
# Boundary/IC penalty weights are zeroed; only the PDE and seismogram terms drive training.
# NOTE(review): ConstrainedModel leaves t unconstrained and no IC targets are set
# in the visible cells, so no initial condition appears to be imposed here — confirm.
cfg = TrainCfg(adam_epochs=3000, lbfgs_max_iter=500, batch_size=32000)  # keep your values if different
wts = Weights(w_pde=0.01, w_ic1=0.0, w_ic2=0.0, w_bc_top_neu=0.0,w_seis=1.0)
# Same receiver layout as the soft-constraint run.
# NOTE(review): with z0_s=0.01 < zl_s=1, build_X_S produces negative z for all
# receivers after the first — confirm the intended geometry.
X_S = build_X_S(
    t_sub=t_sub_Z,         # use the same t grid you used for Sz/Sx
    z0_s=0.01, zl_s=1,
    n_seis=20,
    ax=dom.ax, Lx=3, Lz=3,
    device=device
)
solver.train(cfg, wts, X_S, Sx, Sz)



[Adam    1] total=3.228e-04 | pde=1.051e-02 | bc=0.000e+00 | ic=0.000e+00 | seis=2.177e-04
[Adam  200] total=1.800e-09 | pde=1.149e-07 | bc=0.000e+00 | ic=0.000e+00 | seis=6.512e-10
[Adam  400] total=9.716e-10 | pde=4.106e-08 | bc=0.000e+00 | ic=0.000e+00 | seis=5.609e-10
[Adam  600] total=6.948e-10 | pde=1.548e-08 | bc=0.000e+00 | ic=0.000e+00 | seis=5.400e-10
[Adam  800] total=6.064e-10 | pde=7.091e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.355e-10
[Adam 1000] total=5.674e-10 | pde=3.354e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.339e-10
[Adam 1200] total=5.549e-10 | pde=2.194e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.329e-10
[Adam 1400] total=5.499e-10 | pde=1.742e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.325e-10
[Adam 1600] total=5.480e-10 | pde=1.600e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.320e-10
[Adam 1800] total=5.463e-10 | pde=1.468e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.317e-10
[Adam 2000] total=5.445e-10 | pde=1.310e-09 | bc=0.000e+00 | ic=0.000e+00 | seis=5.314e-10