# PINN for 2D Seismic Wave Equation

This notebook extends the 1D wave-equation PINN to two spatial dimensions, training a network to satisfy

$$\frac{\partial^2 u}{\partial t^2} = c^2 \left(\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}\right)$$

In [None]:
import os

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm.notebook import tqdm

# Run everything on the CPU and cap PyTorch's intra-op thread pool at 4 threads.
device = torch.device("cpu")
torch.set_num_threads(4)

# Seed both random generators so sampled points and initial weights repeat run to run.
np.random.seed(42)
torch.manual_seed(42)

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2D Wave Equation

**Domain:** $x, y \in [0, 1]$, $t \in [0, 0.5]$

**PDE:** $\frac{\partial^2 u}{\partial t^2} = c^2 \left(\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}\right)$

**Wave speed:** $c = 1.0$

**Initial conditions:**
- $u(x, y, 0) = \exp\left(-\frac{(x-0.5)^2 + (y-0.5)^2}{2\sigma^2}\right)$ with $\sigma = 0.1$, i.e. $\exp\left(-50\left((x-0.5)^2 + (y-0.5)^2\right)\right)$ — a 2D Gaussian centered in the domain
- $\frac{\partial u}{\partial t}(x, y, 0) = 0$

**Boundary conditions:**
- $u = 0$ on all edges ($x=0$, $x=1$, $y=0$, $y=1$)

In [None]:
class WavePINN2D(nn.Module):
    """PINN for 2D Wave Equation.

    A fully connected network mapping the coordinates (x, y, t) to a single
    scalar u(x, y, t). Every hidden layer is ``Linear -> Tanh``; the output
    layer is a plain ``Linear`` with one unit.

    Args:
        hidden_layers: widths of the hidden layers, in order. Any iterable of
            ints (list, tuple, generator). Defaults to four layers of 32 units.
            The default is an immutable tuple so it cannot be shared and
            mutated across instances.
    """

    def __init__(self, hidden_layers=(32, 32, 32, 32)):
        super().__init__()

        layers = []
        in_features = 3  # (x, y, t)

        for width in hidden_layers:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.Tanh())
            in_features = width

        layers.append(nn.Linear(in_features, 1))  # scalar field u
        self.network = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, y, t):
        """Evaluate u at the given points.

        Args:
            x, y, t: column tensors of shape (N, 1). They are concatenated
                along dim 1 into the network's 3 input features.

        Returns:
            Tensor of shape (N, 1) with the predicted u(x, y, t).
        """
        inputs = torch.cat([x, y, t], dim=1)
        return self.network(inputs)


# Smoke test: push a random batch of five (x, y, t) points through a fresh,
# untrained network and report the output shape and parameter count.
model_test = WavePINN2D()
x_t, y_t, t_t = (torch.rand(5, 1) for _ in range(3))
u_t = model_test(x_t, y_t, t_t)
print(f"Model output shape: {u_t.shape}")
print(f"Total parameters: {sum(param.numel() for param in model_test.parameters()):,}")

In [None]:
def compute_derivative(u, var, order=1):
    """
    Differentiate ``u`` with respect to ``var`` ``order`` times using autograd.

    Each step differentiates the previous result with an all-ones
    ``grad_outputs``, i.e. it returns d(sum u)/d var, which for pointwise
    networks is the per-sample derivative. ``create_graph=True`` keeps the
    result differentiable so higher derivatives and the training loss can
    backpropagate through it; the graph is also retained (autograd's default
    when ``create_graph`` is set) so ``u`` can be differentiated again later.

    Args:
        u: tensor to differentiate; must be connected to ``var`` in the graph.
        var: tensor with ``requires_grad=True``.
        order: number of successive derivatives; 0 returns ``u`` unchanged.

    Returns:
        The ``order``-th derivative, same shape as ``var``.
    """
    result = u
    for _step in range(order):
        grads = torch.autograd.grad(
            outputs=result,
            inputs=var,
            grad_outputs=torch.ones_like(result),
            create_graph=True,
        )
        result = grads[0]
    return result

In [None]:
def physics_loss_2d(model, x, y, t, c):
    """Mean squared PDE residual: u_tt - c^2 (u_xx + u_yy) = 0.

    All derivatives are taken with autograd through ``model`` at the given
    collocation points.

    Note: ``requires_grad_`` is applied in place, so the caller's ``x``, ``y``
    and ``t`` tensors are left with ``requires_grad=True``.

    Args:
        model: network called as ``model(x, y, t)``.
        x, y, t: collocation coordinates.
        c: wave speed.

    Returns:
        Scalar tensor: mean of the squared residual.
    """
    for coord in (x, y, t):
        coord.requires_grad_(True)

    u = model(x, y, t)

    # Second derivatives via two successive first derivatives.
    u_tt = compute_derivative(compute_derivative(u, t), t)
    u_xx = compute_derivative(compute_derivative(u, x), x)
    u_yy = compute_derivative(compute_derivative(u, y), y)

    residual = u_tt - c**2 * (u_xx + u_yy)
    return residual.pow(2).mean()

In [None]:
def boundary_loss_2d(model, t, n_per_edge=100):
    """Loss for u=0 on all four edges of the unit square.

    Each edge (y=0, y=1, x=0, x=1) gets its own ``n_per_edge`` uniformly
    sampled positions along the edge and its own block of the supplied time
    samples. The four edges are therefore constrained at independent
    (position, time) pairs, and a ``t`` of length ``4 * n_per_edge`` is used
    in full.

    Args:
        model: network called as ``model(x, y, t)`` on (N, 1) column tensors.
        t: time samples; flattened to a column. Ideally it holds at least
            ``4 * n_per_edge`` values (one distinct block per edge). Shorter
            inputs are reused cyclically.
        n_per_edge: number of boundary points per edge (must be >= 1).

    Returns:
        Scalar tensor: the sum over the four edges of the mean squared
        prediction on that edge.

    Raises:
        ValueError: if ``n_per_edge < 1`` or ``t`` is empty.
    """
    if n_per_edge < 1:
        raise ValueError(f"n_per_edge must be >= 1, got {n_per_edge}")
    t = t.reshape(-1, 1)
    if t.shape[0] == 0:
        raise ValueError("t must contain at least one time sample")

    n = n_per_edge
    # Consecutive blocks of n times per edge; wrap around when fewer than 4*n
    # samples were supplied instead of failing on a shape mismatch.
    idx = torch.arange(4 * n, device=t.device) % t.shape[0]
    t_edges = t[idx]

    # Build boundary coordinates with the same dtype/device as the time samples.
    opts = {"dtype": t.dtype, "device": t.device}
    along = torch.rand(4 * n, 1, **opts)  # independent free coordinate per edge
    zeros = torch.zeros(n, 1, **opts)
    ones = torch.ones(n, 1, **opts)

    # Edge order: bottom (y=0), top (y=1), left (x=0), right (x=1).
    x_b = torch.cat([along[:n], along[n:2 * n], zeros, ones])
    y_b = torch.cat([zeros, ones, along[2 * n:3 * n], along[3 * n:]])

    # One forward pass for all edges; row k of the reshaped output is edge k.
    u = model(x_b, y_b, t_edges).reshape(4, n)
    # Mean per edge, summed over edges: same weighting as four separate means.
    return (u ** 2).mean(dim=1).sum()

In [None]:
def gaussian_2d(x, y, center_x=0.5, center_y=0.5, width=0.1):
    """Isotropic 2D Gaussian bump with peak value 1.

    Computes exp(-((x - center_x)^2 + (y - center_y)^2) / (2 * width^2)).
    With the default width 0.1 this is exp(-50 r^2).

    Args:
        x, y: coordinate tensors (broadcastable against each other).
        center_x, center_y: location of the peak.
        width: standard deviation of the Gaussian.
    """
    dx = x - center_x
    dy = y - center_y
    return torch.exp(-(dx**2 + dy**2) / (2 * width**2))


def initial_loss_2d(model, x, y, u0_func):
    """Initial-condition loss at t = 0.

    Sum of two mean-squared errors: the predicted displacement against
    ``u0_func(x, y)``, and the predicted du/dt (which should be zero, since
    the problem starts with zero initial velocity).

    Args:
        model: network called as ``model(x, y, t)``.
        x, y: spatial sample points on the t = 0 slice.
        u0_func: callable giving the target displacement ``u0_func(x, y)``.

    Returns:
        Scalar tensor: displacement MSE + velocity MSE.
    """
    # A zero time tensor that tracks gradients, so du/dt can be taken at t = 0.
    t0 = torch.zeros_like(x).requires_grad_(True)

    u = model(x, y, t0)
    displacement_err = torch.mean((u - u0_func(x, y)) ** 2)

    velocity_err = torch.mean(compute_derivative(u, t0, order=1) ** 2)

    return displacement_err + velocity_err


# Preview the initial displacement u(x, y, 0) on a 50 x 50 grid over [0, 1]^2.
x_vis = torch.linspace(0, 1, 50)
y_vis = torch.linspace(0, 1, 50)
X, Y = torch.meshgrid(x_vis, y_vis, indexing='ij')
U0 = gaussian_2d(X, Y).numpy()

fig, ax = plt.subplots(figsize=(6, 5))
# Arrays are indexed [x, y] ('ij'), so transpose to put y on the vertical axis.
im = ax.imshow(U0.T, extent=[0, 1, 0, 1], origin='lower', cmap='RdBu_r')
plt.colorbar(im, label='u(x, y, 0)')
ax.set(xlabel='x', ylabel='y', title='2D Gaussian Initial Condition')
plt.tight_layout()
plt.show()

In [None]:
# Sampling functions for 2D

def sample_collocation_2d(n_points, t_max=0.5):
    """Draw uniform interior points (x, y, t) from [0, 1]^2 x [0, t_max).

    Returns:
        Tuple of three (n_points, 1) tensors: x, y, t (drawn in that order).
    """
    shape = (n_points, 1)
    xs = torch.rand(shape)
    ys = torch.rand(shape)
    ts = t_max * torch.rand(shape)
    return xs, ys, ts


def sample_initial_2d(n_points):
    """Draw uniform spatial points (x, y) in [0, 1]^2 for the t = 0 slice.

    Returns:
        Tuple of two (n_points, 1) tensors: x, y (drawn in that order).
    """
    return torch.rand(n_points, 1), torch.rand(n_points, 1)


def sample_boundary_time(n_points, t_max=0.5):
    """Draw uniform times in [0, t_max) as an (n_points, 1) tensor."""
    return t_max * torch.rand(n_points, 1)

In [None]:
# Hyperparameters for the 2D run, sized to train on a CPU.
config_2d = dict(
    # Network and optimizer
    hidden_layers=[32, 32, 32, 32],
    learning_rate=3e-3,
    epochs=8000,  # 2D needs more epochs than the 1D case
    # Problem
    wave_speed=1.0,
    t_max=0.5,  # shorter time domain keeps training tractable
    # Points resampled every epoch
    n_collocation=3000,
    n_initial=500,
    n_boundary=400,  # split evenly: 100 per edge
    # Loss weights
    lambda_physics=1.0,
    lambda_ic=100.0,
    lambda_bc=100.0,
)

print("2D Training Configuration:")
for name, value in config_2d.items():
    print(f"  {name}: {value}")

In [None]:
def train_pinn_2d(config):
    """
    Train a WavePINN2D with Adam and cosine learning-rate annealing.

    Every epoch draws fresh interior, initial-slice and boundary samples. It
    then minimizes the weighted sum of the PDE, initial-condition and
    boundary losses.

    Args:
        config: dict with keys 'hidden_layers', 'learning_rate', 'epochs',
            'wave_speed', 't_max', 'n_collocation', 'n_initial',
            'n_boundary', 'lambda_physics', 'lambda_ic', 'lambda_bc'.

    Returns:
        model: trained model
        history: dict of per-epoch lists 'total_loss', 'physics_loss',
            'ic_loss', 'bc_loss' and 'lr' (read after the scheduler step).
    """
    model = WavePINN2D(hidden_layers=config['hidden_layers'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    n_epochs = config['epochs']
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs, eta_min=1e-6
    )

    history = {
        key: []
        for key in ('total_loss', 'physics_loss', 'ic_loss', 'bc_loss', 'lr')
    }

    progress = tqdm(range(n_epochs), desc='Training 2D')

    for epoch in progress:
        optimizer.zero_grad()

        # Fresh random points every epoch.
        x_col, y_col, t_col = sample_collocation_2d(
            config['n_collocation'], t_max=config['t_max']
        )
        x_ic, y_ic = sample_initial_2d(config['n_initial'])
        t_bc = sample_boundary_time(config['n_boundary'], t_max=config['t_max'])

        loss_physics = physics_loss_2d(
            model, x_col, y_col, t_col, config['wave_speed']
        )
        loss_ic = initial_loss_2d(model, x_ic, y_ic, gaussian_2d)
        loss_bc = boundary_loss_2d(
            model, t_bc, n_per_edge=config['n_boundary'] // 4
        )

        total_loss = (
            config['lambda_physics'] * loss_physics
            + config['lambda_ic'] * loss_ic
            + config['lambda_bc'] * loss_bc
        )

        total_loss.backward()
        optimizer.step()
        scheduler.step()

        current_lr = scheduler.get_last_lr()[0]
        epoch_values = {
            'total_loss': total_loss.item(),
            'physics_loss': loss_physics.item(),
            'ic_loss': loss_ic.item(),
            'bc_loss': loss_bc.item(),
            'lr': current_lr,
        }
        for key, value in epoch_values.items():
            history[key].append(value)

        # Refreshing the progress-bar text every epoch is needless overhead.
        if epoch % 200 == 0:
            progress.set_postfix({
                'loss': f"{epoch_values['total_loss']:.2e}",
                'physics': f"{epoch_values['physics_loss']:.2e}",
                'lr': f"{current_lr:.2e}",
            })

    return model, history


# ============================================
# UNCOMMENT THE LINE BELOW TO START TRAINING
# ============================================
# model_2d, history_2d = train_pinn_2d(config_2d)

In [None]:
def plot_loss_history_2d(history, save_path='../results/2d_loss_curves.png'):
    """Plot training loss curves for 2D case.

    Draws a 2x2 grid of log-scale curves: total, physics (PDE residual),
    initial-condition and boundary-condition loss.

    Args:
        history: dict with list-valued keys 'total_loss', 'physics_loss',
            'ic_loss' and 'bc_loss' (as returned by ``train_pinn_2d``).
        save_path: file to save the figure to. Its parent directory is created
            if missing, so ``savefig`` does not fail on a fresh checkout.
            Pass ``None`` to skip saving.
    """
    panels = [
        ('total_loss', 'Total Loss'),
        ('physics_loss', 'Physics Loss (PDE Residual)'),
        ('ic_loss', 'Initial Condition Loss'),
        ('bc_loss', 'Boundary Condition Loss'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # axes.flat walks the grid row by row: (0,0), (0,1), (1,0), (1,1).
    for ax, (key, title) in zip(axes.flat, panels):
        ax.semilogy(history[key])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_loss_history_2d(history_2d)

In [None]:
def plot_2d_snapshots(model, times=(0, 0.1, 0.25, 0.4), resolution=50,
                      save_path='../results/2d_snapshots.png'):
    """Plot 2D wave as heatmaps at different times.

    Args:
        model: network called as ``model(x, y, t)`` on (N, 1) column tensors.
            It is switched to eval mode.
        times: sequence of times, one heatmap panel each. A single time is
            supported. The default is an immutable tuple.
        resolution: grid points per spatial axis on [0, 1].
        save_path: file to save the figure to. Its parent directory is created
            if missing. Pass ``None`` to skip saving.
    """
    model.eval()

    x = torch.linspace(0, 1, resolution)
    y = torch.linspace(0, 1, resolution)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    X_flat = X.reshape(-1, 1)
    Y_flat = Y.reshape(-1, 1)

    # squeeze=False always returns a 2D array of Axes; without it a single
    # panel comes back as a bare Axes and indexing it fails.
    fig, axes = plt.subplots(1, len(times), figsize=(16, 4), squeeze=False)

    with torch.no_grad():
        for ax, t_val in zip(axes[0], times):
            t = torch.full_like(X_flat, t_val)
            u = model(X_flat, Y_flat, t).reshape(resolution, resolution).numpy()

            # Grid is indexed [x, y] ('ij'); transpose so y runs vertically.
            im = ax.imshow(
                u.T, extent=[0, 1, 0, 1], origin='lower',
                cmap='RdBu_r', vmin=-1, vmax=1
            )
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_title(f't = {t_val:.2f}')
            plt.colorbar(im, ax=ax, fraction=0.046)

    plt.tight_layout()
    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_2d_snapshots(model_2d)

In [None]:
def create_2d_animation(model, n_frames=50, resolution=50, t_max=0.5):
    """Build a heatmap animation of the predicted field u(x, y, t).

    The model is evaluated (in eval mode, without autograd) on a
    ``resolution`` x ``resolution`` grid over [0, 1]^2 at ``n_frames`` evenly
    spaced times in [0, t_max]. The figure is closed before returning so a
    notebook does not also render a static copy.

    Returns:
        matplotlib ``FuncAnimation`` (e.g. display with ``to_jshtml()``).
    """
    model.eval()

    x = torch.linspace(0, 1, resolution)
    y = torch.linspace(0, 1, resolution)
    grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
    xs = grid_x.reshape(-1, 1)
    ys = grid_y.reshape(-1, 1)

    frame_times = np.linspace(0, t_max, n_frames)

    def predict(t_val):
        # Field on the whole grid at one time, shaped (resolution, resolution).
        with torch.no_grad():
            field = model(xs, ys, torch.full_like(xs, t_val))
        return field.reshape(resolution, resolution).numpy()

    fig, ax = plt.subplots(figsize=(6, 5))

    # Grid is indexed [x, y] ('ij'); transpose so y runs vertically.
    im = ax.imshow(
        predict(0.0).T, extent=[0, 1, 0, 1], origin='lower',
        cmap='RdBu_r', vmin=-1, vmax=1, animated=True
    )
    plt.colorbar(im, ax=ax, label='u(x,y,t)')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    title = ax.set_title('t = 0.000')

    def update(frame):
        t_val = frame_times[frame]
        im.set_array(predict(t_val).T)
        title.set_text(f't = {t_val:.3f}')
        return [im, title]

    anim = FuncAnimation(
        fig, update, frames=n_frames, interval=100, blit=True
    )

    plt.close()
    return anim


# Uncomment after training:
# anim_2d = create_2d_animation(model_2d)
# HTML(anim_2d.to_jshtml())

# To save as GIF:
# anim_2d.save('../results/2d_wave.gif', writer='pillow', fps=10)

In [None]:
def plot_cross_sections(model, t_val=0.25, resolution=100,
                        save_path='../results/2d_cross_sections.png'):
    """
    Plot cross-sections of the 2D solution along x and y axes.

    The left panel shows u along the line y = 0.5, and the right panel shows
    u along x = 0.5, both at time ``t_val``.

    Args:
        model: network called as ``model(x, y, t)`` on (N, 1) column tensors.
            It is switched to eval mode.
        t_val: time at which to slice.
        resolution: number of sample points along each line.
        save_path: file to save the figure to. Its parent directory is created
            if missing. Pass ``None`` to skip saving.
    """
    model.eval()

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    line = torch.linspace(0, 1, resolution).reshape(-1, 1)
    fixed = torch.full_like(line, 0.5)
    t = torch.full_like(line, t_val)

    # (axes, free-axis label, fixed-axis label, line style, (x, y) inputs)
    panels = [
        (axes[0], 'x', 'y', 'b-', (line, fixed)),
        (axes[1], 'y', 'x', 'r-', (fixed, line)),
    ]

    for ax, free_name, fixed_name, style, (xs, ys) in panels:
        with torch.no_grad():
            u = model(xs, ys, t).numpy()

        ax.plot(line.numpy(), u, style, linewidth=2)
        ax.set_xlabel(free_name)
        ax.set_ylabel('u')
        ax.set_title(f'Cross-section at {fixed_name}=0.5, t={t_val}')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_cross_sections(model_2d, t_val=0.25)