# PINN for 1D Seismic Wave Equation

This notebook implements a Physics-Informed Neural Network to solve the 1D wave equation:

$$\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}$$

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from tqdm.notebook import tqdm

# Reproducibility: seed both random number generators used in this notebook
np.random.seed(42)
torch.manual_seed(42)

# Everything runs on the CPU; cap PyTorch at 4 threads
torch.set_num_threads(4)
device = torch.device('cpu')

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## Problem Setup

**Domain:** $x \in [0, 1]$, $t \in [0, 1]$

**Wave speed:** $c = 1.0$

**Initial conditions:**
- Displacement: $u(x, 0) = \exp(-200(x - 0.5)^2)$ — Gaussian pulse at center (width $\sigma = 0.05$, so $1/(2\sigma^2) = 200$)
- Velocity: $\frac{\partial u}{\partial t}(x, 0) = 0$ — Initially at rest

**Boundary conditions:**
- $u(0, t) = 0$ — Fixed left end
- $u(1, t) = 0$ — Fixed right end

In [None]:
class WavePINN(nn.Module):
    """
    Physics-Informed Neural Network for the 1D wave equation.

    Architecture: fully connected MLP with tanh activations
    Input: (x, t) -> 2 features
    Output: u(x,t) -> 1 value

    Args:
        hidden_layers: widths of the hidden layers, in order. Any iterable
            of ints works (list or tuple); an empty iterable gives a single
            linear layer from 2 inputs to 1 output.
    """
    # The default is a tuple rather than a list so that it cannot be
    # mutated and shared between instances.
    def __init__(self, hidden_layers=(48, 48, 48)):
        super().__init__()

        layers = []
        input_dim = 2  # (x, t)

        # Each hidden layer is Linear followed by Tanh
        for hidden_dim in hidden_layers:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.Tanh())
            input_dim = hidden_dim

        # Linear read-out to the scalar displacement (no activation)
        layers.append(nn.Linear(input_dim, 1))
        self.network = nn.Sequential(*layers)

        # Initialize weights (Xavier)
        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-normal weights and zero biases to every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, t):
        """
        Forward pass.

        Args:
            x: spatial coordinate, shape (N, 1)
            t: time coordinate, shape (N, 1)

        Returns:
            u: displacement field, shape (N, 1)
        """
        # Stack the coordinates column-wise into an (N, 2) input
        inputs = torch.cat([x, t], dim=1)
        return self.network(inputs)


# Smoke test: push a small random batch through a freshly built network
model_test = WavePINN()
x_test, t_test = torch.rand(5, 1), torch.rand(5, 1)  # x is drawn before t
u_test = model_test(x_test, t_test)
print(f"Model output shape: {u_test.shape}")
print(f"Total parameters: {sum(map(torch.numel, model_test.parameters())):,}")

In [None]:
def compute_derivative(u, var, order=1):
    """
    Differentiate u with respect to var by repeated autograd calls.

    Args:
        u: tensor to differentiate
        var: tensor to differentiate with respect to
        order: number of times to differentiate (1 or 2)

    Returns:
        tensor holding the derivative; u itself if order is 0
    """
    result = u
    for _ in range(order):
        # grad_outputs of ones: vector-Jacobian product with an all-ones vector
        seed = torch.ones_like(result)
        (result,) = torch.autograd.grad(
            outputs=result,
            inputs=var,
            grad_outputs=seed,
            create_graph=True,  # keep the result differentiable for the next pass
            retain_graph=True,
        )
    return result

In [None]:
def physics_loss(model, x, t, c):
    """
    Mean squared PDE residual of the wave equation, u_tt - c² u_xx.

    Args:
        model: callable mapping (x, t) to u
        x: spatial sample points; gradient tracking is switched on in place
        t: time sample points; gradient tracking is switched on in place
        c: wave speed

    Returns:
        scalar tensor, the mean of the squared residual over all points
    """
    # requires_grad_ works in place, so the inputs become differentiable leaves
    x.requires_grad_(True)
    t.requires_grad_(True)

    u = model(x, t)

    # Second derivatives in time and in space, each obtained by differentiating twice
    u_tt = compute_derivative(u, t, order=2)
    u_xx = compute_derivative(u, x, order=2)

    # Zero wherever u satisfies the wave equation exactly
    residual = u_tt - c**2 * u_xx
    return residual.pow(2).mean()


def initial_condition_loss(model, x, u0_func):
    """
    Loss for initial displacement: u(x, 0) = u0(x)
    and initial velocity: u_t(x, 0) = 0

    Args:
        model: callable mapping (x, t) to u
        x: spatial sample points at which the initial state is enforced
        u0_func: callable returning the target initial displacement u0(x)

    Returns:
        scalar tensor: mean squared displacement error plus mean squared
        initial velocity
    """
    # Only t needs gradient tracking (for u_t). x is deliberately left
    # untouched: no x-derivative is taken here, so tracking it would only
    # mutate the caller's tensor and add backward work.
    t = torch.zeros_like(x).requires_grad_(True)

    u = model(x, t)
    u_target = u0_func(x)

    # Initial displacement loss
    loss_u0 = torch.mean((u - u_target)**2)

    # Initial velocity loss (u_t = 0)
    u_t = compute_derivative(u, t, order=1)
    loss_v0 = torch.mean(u_t**2)

    return loss_u0 + loss_v0


def boundary_condition_loss(model, t):
    """
    Loss for boundary conditions: u(0,t) = u(1,t) = 0

    Args:
        model: callable mapping (x, t) to u
        t: time sample points, shape (N, 1)

    Returns:
        scalar tensor: mean squared displacement at x = 0 plus mean squared
        displacement at x = 1
    """
    # Build the boundary coordinates from t so they share its shape, dtype
    # and device instead of always being default-dtype CPU tensors.
    x_left = torch.zeros_like(t)   # left boundary: x = 0
    x_right = torch.ones_like(t)   # right boundary: x = 1

    u_left = model(x_left, t)
    u_right = model(x_right, t)

    return torch.mean(u_left**2) + torch.mean(u_right**2)

In [None]:
def sample_collocation_points(n_points, x_range=(0, 1), t_range=(0, 1)):
    """Draw n_points uniformly random (x, t) pairs in the domain for the physics loss."""
    x_lo, x_hi = x_range[0], x_range[1]
    t_lo, t_hi = t_range[0], t_range[1]
    # x is drawn before t, so the random stream is consumed in a fixed order
    x = x_lo + (x_hi - x_lo) * torch.rand(n_points, 1)
    t = t_lo + (t_hi - t_lo) * torch.rand(n_points, 1)
    return x, t


def sample_initial_points(n_points, x_range=(0, 1)):
    """Draw n_points uniformly random x positions for the initial condition at t=0."""
    lo, hi = x_range[0], x_range[1]
    return lo + (hi - lo) * torch.rand(n_points, 1)


def sample_boundary_points(n_points, t_range=(0, 1)):
    """Draw n_points uniformly random time values for the boundary condition."""
    lo, hi = t_range[0], t_range[1]
    return lo + (hi - lo) * torch.rand(n_points, 1)

In [None]:
def gaussian_pulse(x, center=0.5, width=0.05):
    """
    Gaussian bump used as the initial displacement.

    u0(x) = exp(-((x - center)^2 / (2 * width^2)))

    The value is 1 at x = center.
    """
    offset = x - center
    denom = 2 * width**2
    exponent = -(offset**2) / denom
    return torch.exp(exponent)


# Preview the initial displacement profile on a uniform 200-point grid
x_vis = torch.linspace(0, 1, 200).unsqueeze(1)
u0_vis = gaussian_pulse(x_vis)

plt.figure(figsize=(8, 3))
plt.plot(x_vis.numpy(), u0_vis.numpy(), 'b-', linewidth=2)
plt.title('Initial Condition: Gaussian Pulse')
plt.xlabel('x')
plt.ylabel('u(x, 0)')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Hyperparameters (sized so that training is feasible on a CPU)
config = dict(
    hidden_layers=[48, 48, 48],
    learning_rate=3e-3,
    epochs=5000,
    wave_speed=1.0,

    # Number of sample points per loss term (kept small for CPU)
    n_collocation=2000,
    n_initial=500,
    n_boundary=500,

    # Relative weight of each loss term in the total loss
    lambda_physics=1.0,
    lambda_ic=100.0,
    lambda_bc=100.0,
)

print("Training Configuration:")
for name in config:
    print(f"  {name}: {config[name]}")

In [None]:
def train_pinn(config):
    """
    Train a WavePINN with Adam and a cosine-annealed learning rate.

    Args:
        config: dict of hyperparameters; the keys read here are
            'hidden_layers', 'learning_rate', 'epochs', 'wave_speed',
            'n_collocation', 'n_initial', 'n_boundary',
            'lambda_physics', 'lambda_ic' and 'lambda_bc'

    Returns:
        model: trained model
        history: dict of per-epoch lists with keys 'total_loss',
            'physics_loss', 'ic_loss', 'bc_loss' and 'lr'
    """
    # Model, optimizer and learning-rate schedule
    model = WavePINN(hidden_layers=config['hidden_layers'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['epochs'], eta_min=1e-6
    )

    # One list per tracked quantity, appended to once per epoch
    history = {
        key: []
        for key in ('total_loss', 'physics_loss', 'ic_loss', 'bc_loss', 'lr')
    }

    pbar = tqdm(range(config['epochs']), desc='Training')

    for epoch in pbar:
        optimizer.zero_grad()

        # Fresh random points every epoch: collocation first, then IC, then BC
        x_col, t_col = sample_collocation_points(config['n_collocation'])
        x_ic = sample_initial_points(config['n_initial'])
        t_bc = sample_boundary_points(config['n_boundary'])

        # Individual loss terms
        loss_physics = physics_loss(model, x_col, t_col, config['wave_speed'])
        loss_ic = initial_condition_loss(model, x_ic, gaussian_pulse)
        loss_bc = boundary_condition_loss(model, t_bc)

        # Weighted sum of the three terms
        weighted_physics = config['lambda_physics'] * loss_physics
        weighted_ic = config['lambda_ic'] * loss_ic
        weighted_bc = config['lambda_bc'] * loss_bc
        total_loss = weighted_physics + weighted_ic + weighted_bc

        # Gradient step, then advance the learning-rate schedule
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        # Log scalar values; the learning rate is read after the scheduler step
        current_lr = scheduler.get_last_lr()[0]
        logged = (
            ('total_loss', total_loss),
            ('physics_loss', loss_physics),
            ('ic_loss', loss_ic),
            ('bc_loss', loss_bc),
        )
        for key, value in logged:
            history[key].append(value.item())
        history['lr'].append(current_lr)

        # Refresh the progress-bar summary every 100 epochs
        if not epoch % 100:
            pbar.set_postfix({
                'loss': f"{total_loss.item():.2e}",
                'physics': f"{loss_physics.item():.2e}",
                'lr': f"{current_lr:.2e}"
            })

    return model, history


# ============================================
# UNCOMMENT THE LINE BELOW TO START TRAINING
# ============================================
# model, history = train_pinn(config)

In [None]:
def plot_loss_history(history):
    """
    Plot training loss curves on a 2x2 grid of log-scale panels.

    The figure is saved to '../results/1d_loss_curves.png' and then shown.

    Args:
        history: dict with per-epoch lists under the keys 'total_loss',
            'physics_loss', 'ic_loss' and 'bc_loss'
    """
    # (history key, panel title), in row-major panel order
    panels = (
        ('total_loss', 'Total Loss'),
        ('physics_loss', 'Physics Loss (PDE Residual)'),
        ('ic_loss', 'Initial Condition Loss'),
        ('bc_loss', 'Boundary Condition Loss'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    for ax, (key, title) in zip(axes.flat, panels):
        ax.semilogy(history[key])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories; make sure the target
    # exists so a finished training run is not followed by FileNotFoundError
    out_path = Path('../results/1d_loss_curves.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_loss_history(history)

In [None]:
def plot_solution_snapshots(model, times=(0, 0.25, 0.5, 0.75, 1.0), n_points=200):
    """
    Plot wave solution at different time snapshots.

    The figure is saved to '../results/1d_snapshots.png' and then shown.

    Args:
        model: callable mapping (x, t) to u; switched to eval mode here
        times: sequence of time values, one panel per value
        n_points: number of grid points on x in [0, 1]
    """
    model.eval()

    x = torch.linspace(0, 1, n_points).reshape(-1, 1)

    # squeeze=False always returns a 2-D array of axes, so a single
    # snapshot (len(times) == 1) works as well as several
    fig, axes = plt.subplots(1, len(times), figsize=(15, 3), squeeze=False)

    with torch.no_grad():
        for ax, t_val in zip(axes[0], times):
            t = torch.full_like(x, t_val)
            u = model(x, t).numpy()

            ax.plot(x.numpy(), u, 'b-', linewidth=2)
            ax.set_xlim([0, 1])
            ax.set_ylim([-1.2, 1.2])
            ax.set_xlabel('x')
            ax.set_ylabel('u(x,t)')
            ax.set_title(f't = {t_val:.2f}')
            ax.grid(True, alpha=0.3)

            # Show initial condition for reference
            if t_val == 0:
                u0 = gaussian_pulse(x).numpy()
                ax.plot(x.numpy(), u0, 'r--', label='IC', alpha=0.5)
                ax.legend()

    plt.tight_layout()

    # savefig does not create missing directories, so create it first
    out_path = Path('../results/1d_snapshots.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_solution_snapshots(model)

In [None]:
def create_wave_animation(model, n_frames=100, n_points=200):
    """
    Build a matplotlib animation of the predicted wave for t from 0 to 1.

    Args:
        model: callable mapping (x, t) to u; switched to eval mode here
        n_frames: number of evenly spaced time frames
        n_points: number of grid points on x in [0, 1]

    Returns:
        the FuncAnimation object
    """
    model.eval()

    grid = torch.linspace(0, 1, n_points).reshape(-1, 1)
    frame_times = np.linspace(0, 1, n_frames)

    fig, ax = plt.subplots(figsize=(10, 4))
    (curve,) = ax.plot([], [], 'b-', linewidth=2)
    label = ax.text(0.02, 0.95, '', transform=ax.transAxes)

    ax.set_title('1D Wave Propagation')
    ax.set_xlabel('x')
    ax.set_ylabel('u(x,t)')
    ax.set_xlim([0, 1])
    ax.set_ylim([-1.2, 1.2])
    ax.grid(True, alpha=0.3)

    def _blank():
        # Empty starting state for the animation
        curve.set_data([], [])
        label.set_text('')
        return curve, label

    def _draw(frame):
        # Evaluate the model along the grid at this frame's time
        t_val = frame_times[frame]
        with torch.no_grad():
            u = model(grid, torch.full_like(grid, t_val)).numpy()

        curve.set_data(grid.numpy().flatten(), u.flatten())
        label.set_text(f't = {t_val:.3f}')
        return curve, label

    anim = FuncAnimation(
        fig, _draw, init_func=_blank,
        frames=n_frames, interval=50, blit=True
    )

    # Close the current figure; presumably so the notebook does not also
    # render a static copy next to the animation
    plt.close()
    return anim


# Uncomment after training:
# anim = create_wave_animation(model)
# HTML(anim.to_jshtml())

# To save as GIF:
# anim.save('../results/1d_wave.gif', writer='pillow', fps=20)

In [None]:
def analytical_solution_1d(x, t, c=1.0, center=0.5, width=0.05):
    """
    D'Alembert reference solution for a Gaussian pulse that starts at rest.

    u(x,t) = 0.5 * [u0(x - ct) + u0(x + ct)]

    Only the two traveling half-amplitude pulses are modelled, with no
    boundary reflections, so the result is approximate once a pulse
    reaches a boundary.
    """
    shift = c * t
    # One copy of the initial pulse moves right, the other moves left
    moving_right = gaussian_pulse(x - shift, center=center, width=width)
    moving_left = gaussian_pulse(x + shift, center=center, width=width)
    return 0.5 * (moving_right + moving_left)


def compute_l2_error(model, n_points=200, t_test=0.25, c=1.0):
    """
    Compute L2 relative error vs analytical solution.

    Args:
        model: callable mapping (x, t) to u; switched to eval mode here
        n_points: number of grid points on x in [0.1, 0.9]
        t_test: time at which the two solutions are compared
        c: wave speed used for the analytical reference. It should match
            the speed the model was trained with; the default of 1.0 keeps
            the previous hard-coded behaviour.

    Returns:
        float, ||u_pred - u_exact|| / ||u_exact||
    """
    model.eval()

    x = torch.linspace(0.1, 0.9, n_points).reshape(-1, 1)  # Avoid boundaries
    t = torch.full_like(x, t_test)

    with torch.no_grad():
        u_pred = model(x, t)

    # The reference ignores boundary reflections (see analytical_solution_1d),
    # so the comparison is only meaningful before the pulses reach the walls
    u_exact = analytical_solution_1d(x, t, c=c)

    l2_error = torch.norm(u_pred - u_exact) / torch.norm(u_exact)

    return l2_error.item()


# Uncomment after training:
# error = compute_l2_error(model, t_test=0.25)
# print(f"L2 Relative Error at t=0.25: {error:.4f} ({error*100:.2f}%)")

In [None]:
def plot_comparison_with_analytical(model, times=(0.1, 0.25), n_points=200, c=1.0):
    """
    Compare PINN solution with analytical solution.

    The figure is saved to '../results/1d_comparison.png' and then shown.

    Args:
        model: callable mapping (x, t) to u; switched to eval mode here
        times: sequence of time values, one panel per value
        n_points: number of grid points on x in [0, 1]
        c: wave speed used for the analytical reference. It should match
            the speed the model was trained with; the default of 1.0 keeps
            the previous hard-coded behaviour.
    """
    model.eval()
    x = torch.linspace(0, 1, n_points).reshape(-1, 1)

    # squeeze=False always returns a 2-D array of axes, so a single
    # comparison time (len(times) == 1) works as well as several
    fig, axes = plt.subplots(1, len(times), figsize=(12, 4), squeeze=False)

    for ax, t_val in zip(axes[0], times):
        t = torch.full_like(x, t_val)

        with torch.no_grad():
            u_pred = model(x, t).numpy()

        u_exact = analytical_solution_1d(x, t, c=c).numpy()

        ax.plot(x.numpy(), u_exact, 'r-', linewidth=2, label='Analytical')
        ax.plot(x.numpy(), u_pred, 'b--', linewidth=2, label='PINN')
        ax.set_xlabel('x')
        ax.set_ylabel('u(x,t)')
        ax.set_title(f't = {t_val:.2f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories, so create it first
    out_path = Path('../results/1d_comparison.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150)
    plt.show()


# Uncomment after training:
# plot_comparison_with_analytical(model)