# Wave Equation Tutorial with Physics-Informed Neural Networks

This notebook provides a comprehensive tutorial on solving the 1D wave equation using PINNs.

## Learning Objectives
- Understand the wave equation and its physical meaning
- Implement a PINN to solve the wave equation
- Analyze the results and understand the physics
- Explore different initial and boundary conditions

In [None]:
# Setup
import sys
import os
sys.path.append('..')

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import our PINN implementation
from src.model import WavePINN
from src.losses import PhysicsInformedLoss
from src.train import train_model
from src.evaluate import evaluate_model, compute_analytical_solution
from src.visualization import plot_wave_evolution

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed PyTorch and NumPy with the same fixed value so runs are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. The Wave Equation

The 1D wave equation describes the propagation of waves (sound, vibrations, etc.):

$$\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}$$

Where:
- $u(x,t)$ is the wave amplitude (displacement)
- $c$ is the wave speed
- $x$ is the spatial coordinate
- $t$ is time

### Physical Interpretation

In [None]:
# Plot the standing wave sin(pi x) * cos(pi t) at four instants, one panel each
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Spatial grid and the snapshot times shown in the 2x2 grid
x = np.linspace(0, 1, 100)
times = [0, 0.25, 0.5, 0.75]

for idx, t in enumerate(times):
    row, col = divmod(idx, 2)  # panels are filled row by row
    ax = axes[row, col]
    u = np.cos(np.pi * t) * np.sin(np.pi * x)

    # Wave profile with the area between the curve and the rest position shaded
    ax.plot(x, u, 'b-', linewidth=2)
    ax.fill_between(x, 0, u, alpha=0.3)
    ax.set(
        xlim=(0, 1),
        ylim=(-1.2, 1.2),
        xlabel='Position x',
        ylabel='Amplitude u(x,t)',
        title=f'Wave at t = {t}',
    )
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)

plt.suptitle('Wave Evolution Over Time', fontsize=16)
plt.tight_layout()
plt.show()

print("The wave oscillates up and down over time, with fixed endpoints (boundary conditions).")

## 2. Problem Setup

We'll solve the wave equation with:
- **Domain**: $x \in [0, 1]$, $t \in [0, 1]$
- **Initial condition**: $u(x, 0) = \sin(\pi x)$ (initial shape)
- **Initial velocity**: $\frac{\partial u}{\partial t}(x, 0) = 0$ (starts from rest)
- **Boundary conditions**: $u(0, t) = u(1, t) = 0$ (fixed endpoints)
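- **Wave speed**: $c = 1$

With these conditions the exact solution is the standing wave $u(x, t) = \sin(\pi x)\cos(\pi t)$, which has period $T = 2$. It serves as the analytical reference for the PINN in the sections below.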

In [None]:
# Two panels: the initial displacement (left) and the space-time domain (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: initial profile with the two fixed endpoints highlighted
x = np.linspace(0, 1, 100)
u_initial = np.sin(np.pi * x)

ax1.plot(x, u_initial, 'b-', linewidth=3, label='u(x,0) = sin(πx)')
ax1.scatter([0, 1], [0, 0], c='red', s=100, zorder=5, label='Fixed endpoints')
ax1.set(xlabel='x', ylabel='u(x,0)', title='Initial Condition')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: shaded unit square in the (x, t) plane with its outline
ax2.add_patch(plt.Rectangle((0, 0), 1, 1, fill=True, facecolor='lightblue', alpha=0.5))
ax2.plot([0, 1, 1, 0, 0], [0, 0, 1, 1, 0], 'k-', linewidth=2)

# Edges where conditions are imposed: (x coords, t coords, line format, legend label)
constrained_edges = [
    ([0, 1], [0, 0], 'r-', 'Initial condition (t=0)'),
    ([0, 0], [0, 1], 'g-', 'Boundary (x=0)'),
    ([1, 1], [0, 1], 'g-', 'Boundary (x=1)'),
]
for edge_x, edge_t, fmt, edge_label in constrained_edges:
    ax2.plot(edge_x, edge_t, fmt, linewidth=4, label=edge_label)

ax2.set(xlabel='x', ylabel='t', title='Problem Domain')
ax2.legend()
ax2.set_xlim(-0.1, 1.1)
ax2.set_ylim(-0.1, 1.1)

plt.tight_layout()
plt.show()

## 3. Building the PINN

Let's create and examine our Physics-Informed Neural Network:

In [None]:
# Instantiate the network and move it onto the selected device
model = WavePINN().to(device)

# Print the module structure, framed by separator lines, then the total parameter count
separator = "=" * 50
print("PINN Architecture for Wave Equation:")
print(separator)
print(model)
print("\n" + separator)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Break the parameter count down by named parameter tensor
def count_parameters_per_layer(model):
    """Return a list of ``(name, element_count)`` pairs, one per entry of
    ``model.named_parameters()``, in the order that iterator yields them."""
    return [(name, param.numel()) for name, param in model.named_parameters()]

# Print one aligned row per parameter tensor: name padded to 20 chars, count to 6 digits
params_per_layer = count_parameters_per_layer(model)
for layer_name, n_elements in params_per_layer:
    print(f"{layer_name:20s}: {n_elements:6d} parameters")

## 4. Understanding the Physics-Informed Loss

The loss function enforces the physics by penalizing the PDE residual, whose derivatives are computed with automatic differentiation:

In [None]:
# Demonstrate how we compute the PDE residual
def demonstrate_pde_residual(model, x_sample, t_sample):
    """Print and return the wave-equation residual of ``model`` at a single point.

    The residual is ``u_tt - c**2 * u_xx`` with ``c = 1``, where the derivatives of
    ``u = model(x, t)`` are obtained with ``torch.autograd.grad``. Each intermediate
    quantity is printed along the way.

    Args:
        model: Callable taking ``(x, t)`` tensors of shape ``(1, 1)`` and returning
            ``u`` as a single-element tensor.
        x_sample: Spatial coordinate (Python float) at which to evaluate.
        t_sample: Time coordinate (Python float) at which to evaluate.

    Returns:
        The residual as a Python float.
    """
    # Create the inputs on the same device as the model's weights. Building them on
    # the CPU (the torch.tensor default) fails with a device-mismatch error when the
    # model has been moved to CUDA. Callables without parameters fall back to the CPU.
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    x = torch.tensor([[x_sample]], dtype=torch.float32, device=target_device,
                     requires_grad=True)
    t = torch.tensor([[t_sample]], dtype=torch.float32, device=target_device,
                     requires_grad=True)

    # Forward pass
    u = model(x, t)
    print(f"u({x_sample}, {t_sample}) = {u.item():.4f}")

    # First derivatives; create_graph=True keeps them differentiable for the next step
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
                              create_graph=True, retain_graph=True)[0]
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u),
                              create_graph=True, retain_graph=True)[0]

    print(f"∂u/∂x = {u_x.item():.4f}")
    print(f"∂u/∂t = {u_t.item():.4f}")

    # Second derivatives, obtained by differentiating the first derivatives again
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x),
                               create_graph=True, retain_graph=True)[0]
    u_tt = torch.autograd.grad(u_t, t, grad_outputs=torch.ones_like(u_t),
                               create_graph=True, retain_graph=True)[0]

    print(f"∂²u/∂x² = {u_xx.item():.4f}")
    print(f"∂²u/∂t² = {u_tt.item():.4f}")

    # Wave equation residual
    c = 1.0  # wave speed
    residual = u_tt - c**2 * u_xx

    print(f"\nWave equation residual: ∂²u/∂t² - c²∂²u/∂x² = {residual.item():.4f}")
    print("(This should be close to 0 when trained)")

    return residual.item()

# Baseline: evaluate the residual at the domain midpoint before any training
x_probe, t_probe = 0.5, 0.5
print("Before training (random weights):")
residual_before = demonstrate_pde_residual(model, x_probe, t_probe)

## 5. Training the PINN

Now let's train the network to satisfy the wave equation:

In [None]:
# Re-create the network so training starts from newly initialised weights
model = WavePINN().to(device)

# Announce the run before the training call starts
print("Training the Physics-Informed Neural Network...")
print("This will take about 1-2 minutes...\n")

# Keyword arguments forwarded unchanged to train_model
training_options = dict(
    epochs=3000,
    lr=1e-3,
    device=device,
    log_every=500,
    verbose=True,
)
history = train_model(model, **training_options)

In [None]:
# Four-panel summary of the recorded training history
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top-left: total loss per epoch on a logarithmic axis
ax = axes[0, 0]
ax.semilogy(history['loss'], 'b-', linewidth=2)
ax.set(xlabel='Epoch', ylabel='Total Loss', title='Total Loss Evolution')
ax.grid(True, alpha=0.3)

# Top-right: the three loss components overlaid
ax = axes[0, 1]
component_curves = (
    ('pde', 'PDE'),
    ('ic', 'Initial Condition'),
    ('bc', 'Boundary Condition'),
)
for history_key, curve_label in component_curves:
    ax.semilogy(history[history_key], label=curve_label, linewidth=2)
ax.set(xlabel='Epoch', ylabel='Loss', title='Individual Loss Components')
ax.legend()
ax.grid(True, alpha=0.3)

# Bottom-left: total loss relative to its value at the first recorded epoch
ax = axes[1, 0]
reduction = np.array(history['loss']) / history['loss'][0]
ax.plot(reduction, 'g-', linewidth=2)
ax.set(
    xlabel='Epoch',
    ylabel='Loss / Initial Loss',
    title=f'Loss Reduction: {1/reduction[-1]:.0f}x',
    yscale='log',
)
ax.grid(True, alpha=0.3)

# Bottom-right: last recorded value of each component as a bar chart
ax = axes[1, 1]
final_losses = {
    bar_label: history[history_key][-1]
    for bar_label, history_key in (('PDE', 'pde'), ('IC', 'ic'), ('BC', 'bc'))
}
bars = ax.bar(final_losses.keys(), final_losses.values())
ax.set(ylabel='Final Loss Value', title='Final Loss Components', yscale='log')
for bar, val in zip(bars, final_losses.values()):
    # Write the value above the bar; the 1.5x factor lifts it clear on the log axis
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() * 1.5
    ax.text(label_x, label_y, f'{val:.2e}', ha='center', va='bottom')

plt.suptitle('Training History Analysis', fontsize=16)
plt.tight_layout()
plt.show()

## 6. Evaluating the Results

Let's see how well our PINN learned the wave equation:

In [None]:
# Re-evaluate the PDE residual at the same point, now with the trained model
print("After training:")
residual_after = demonstrate_pde_residual(model, 0.5, 0.5)

# Compare magnitudes. A residual of exactly 0.0 would make the ratio raise
# ZeroDivisionError, so report an unbounded improvement in that case instead.
before_mag = abs(residual_before)
after_mag = abs(residual_after)
improvement = f"{before_mag / after_mag:.0f}x" if after_mag > 0 else "∞"
print(f"\nResidual reduction: {before_mag:.4f} → {after_mag:.4f} ({improvement} improvement)")

In [None]:
# Run the evaluation helper and print the metrics it returns
metrics = evaluate_model(model, device=device)

print("Model Evaluation Metrics:")
print("=" * 40)

# (printed label, key in the metrics dict, float format spec)
summary_rows = (
    ("Mean Squared Error", 'mse', ".6f"),
    ("Root Mean Squared Error", 'rmse', ".6f"),
    ("Mean Absolute Error", 'mae', ".6f"),
    ("Maximum Error", 'max_error', ".4f"),
)
for row_label, metric_key, spec in summary_rows:
    print(f"{row_label}: {metrics[metric_key]:{spec}}")

print("\nSpecific Test Cases:")
print("-" * 40)
for case_name, case_error in metrics['test_cases'].items():
    print(f"{case_name:20s}: error = {case_error:.6f}")

## 7. Visualizing the Solution

Let's visualize how the wave propagates over time:

In [None]:
# Compare the PINN with the analytical solution at nine evenly spaced times.
#
# For the setup used in this notebook (u(x,0) = sin(πx), zero initial velocity,
# c = 1) the exact solution is u(x,t) = sin(πx)·cos(πt), whose period is T = 2.
# The window t ∈ [0, 1] therefore spans HALF a period: the string starts at its
# maximum, passes through the flat position at t = 0.5 and reaches its minimum
# at t = 1.
times = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
axes = axes.flatten()

x = torch.linspace(0, 1, 200, device=device).reshape(-1, 1)
x_np = x.cpu().numpy()

for idx, t_val in enumerate(times):
    ax = axes[idx]

    # PINN prediction
    t = torch.full_like(x, t_val)
    with torch.no_grad():
        u_pred = model(x, t).cpu().numpy()

    # Analytical solution
    u_true = compute_analytical_solution(x_np, t_val)

    # Plot
    ax.plot(x_np, u_true, 'k--', linewidth=2, label='Analytical', alpha=0.7)
    ax.plot(x_np, u_pred, 'b-', linewidth=2.5, label='PINN')
    ax.fill_between(x_np.squeeze(), 0, u_pred.squeeze(), alpha=0.3, color='blue')

    # Formatting
    ax.set_xlim(0, 1)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel('x')
    ax.set_ylabel('u(x,t)')
    ax.set_title(f't = {t_val:.3f}')
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)

    if idx == 0:
        ax.legend(loc='upper right')

    # Phase indicator, derived from the exact solution's time dependence:
    # displacement ∝ cos(πt) and velocity ∝ -sin(πt). The velocity sign gives the
    # direction of motion; where it vanishes the string is at a turning point and
    # the displacement sign says which one.
    displacement_factor = np.cos(np.pi * t_val)
    velocity_factor = -np.sin(np.pi * t_val)
    if np.isclose(velocity_factor, 0.0, atol=1e-9):
        if displacement_factor > 0:
            phase_text = "At maximum"
            color = 'darkgreen'
        else:
            phase_text = "At minimum"
            color = 'darkred'
    elif velocity_factor < 0:
        phase_text = "Moving down ↓"
        color = 'red'
    else:
        phase_text = "Moving up ↑"
        color = 'green'

    ax.text(0.95, 0.95, phase_text, transform=ax.transAxes,
            ha='right', va='top', fontsize=10, color=color,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

# The plotted window is half of the T = 2 period, not a complete one
plt.suptitle('Wave Evolution: Half Period (t = 0 to 1)', fontsize=16)
plt.tight_layout()
plt.show()

## 8. Creating an Animation

Let's create an animation to see the wave in motion:

In [None]:
# Build the figure and the empty artists that the animation callbacks update
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots(figsize=(10, 6))

# Spatial grid as a column tensor for the model, plus a NumPy copy for plotting
x = torch.linspace(0, 1, 200, device=device).reshape(-1, 1)
x_np = x.cpu().numpy()

# Empty placeholders: reference curve, prediction curve, shaded area, time label
(line_true,) = ax.plot([], [], 'k--', linewidth=2, label='Analytical', alpha=0.7)
(line_pred,) = ax.plot([], [], 'b-', linewidth=3, label='PINN')
fill = None
time_box_style = dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7)
time_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12,
                    bbox=time_box_style)

# Fixed axes so the view does not rescale between frames
ax.set_xlim(0, 1)
ax.set_ylim(-1.5, 1.5)
ax.set_xlabel('Position x', fontsize=12)
ax.set_ylabel('Amplitude u(x,t)', fontsize=12)
ax.set_title('Wave Propagation Animation', fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right')

def init():
    """Clear both curves and the time label, then return those three artists."""
    for curve in (line_true, line_pred):
        curve.set_data([], [])
    time_text.set_text('')
    return line_true, line_pred, time_text

def animate(frame):
    """Redraw the curves, shaded area and time label for frame index ``frame``.

    Returns the updated artists as ``(line_true, line_pred, fill, time_text)``.
    """
    global fill

    # Dividing by 50 makes frame 50 correspond to t = 1
    t_val = frame / 50.0

    # Model output on the spatial grid at this instant, evaluated without autograd
    with torch.no_grad():
        t_column = torch.full_like(x, t_val)
        u_pred = model(x, t_column).cpu().numpy()

    # Reference curve at the same instant
    u_true = compute_analytical_solution(x_np, t_val)

    # Push the new data into the two existing line artists
    line_true.set_data(x_np, u_true)
    line_pred.set_data(x_np, u_pred)

    # Drop the previous shaded region (if any) before drawing the new one
    if fill is not None:
        fill.remove()
    fill = ax.fill_between(x_np.squeeze(), 0, u_pred.squeeze(), alpha=0.3, color='blue')

    time_text.set_text(f't = {t_val:.3f}')

    return line_true, line_pred, fill, time_text

# Assemble the animation: 51 frames, 50 ms apart, full redraw on every frame
anim = FuncAnimation(
    fig,
    animate,
    frames=51,
    interval=50,
    init_func=init,
    blit=False,
)

# Close the figure so only the interactive player is shown, not a static plot
plt.close()
HTML(anim.to_jshtml())

## 9. Exploring Different Scenarios

Let's explore what happens with different initial conditions:

In [None]:
# Compare standing-wave modes at t = 0 and at a later snapshot (fixed ends, c = 1)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

x = np.linspace(0, 1, 100)
t_snapshot = 0.25


def standing_wave(x, t, terms):
    """Evaluate sum_n a_n * sin(n*pi*x) * cos(n*pi*t) for ``terms = [(a_n, n), ...]``.

    This is the solution of the unit-speed wave equation on [0, 1] with fixed ends
    and zero initial velocity whose initial shape is sum_n a_n * sin(n*pi*x).
    """
    return sum(amp * np.sin(n * np.pi * x) * np.cos(n * np.pi * t) for amp, n in terms)


# Each scenario is described once, by its (amplitude, mode number) terms. Both the
# initial shape and the later snapshot are derived from this single description, so
# they cannot drift apart and no logic depends on the display title.
mode_terms = [
    ('First mode: sin(πx)', [(1.0, 1)]),
    ('Second mode: sin(2πx)', [(1.0, 2)]),
    ('Third mode: sin(3πx)', [(1.0, 3)]),
    ('Combination: sin(πx) + 0.5sin(3πx)', [(1.0, 1), (0.5, 3)]),
]

# Same (title, callable) pairs as before, kept for any code that reads this name.
# ``terms=terms`` binds each scenario's terms at definition time.
initial_conditions = [
    (title, lambda x, terms=terms: standing_wave(x, 0.0, terms))
    for title, terms in mode_terms
]

for idx, (title, terms) in enumerate(mode_terms):
    ax = axes[idx // 2, idx % 2]

    # Initial shape
    u0 = standing_wave(x, 0.0, terms)
    ax.plot(x, u0, 'b-', linewidth=2.5, label='t=0')

    # Shape at t = 0.25. Mode n has period 2/n, so this is a quarter period only
    # for the second mode (which is flat at that instant).
    u_quarter = standing_wave(x, t_snapshot, terms)
    ax.plot(x, u_quarter, 'r--', linewidth=2, label=f't={t_snapshot}', alpha=0.7)

    ax.set_xlim(0, 1)
    ax.set_ylim(-1.5, 1.5)
    ax.set_xlabel('x')
    ax.set_ylabel('u(x,t)')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)

plt.suptitle('Different Wave Modes', fontsize=16)
plt.tight_layout()
plt.show()

print("Note: Higher modes oscillate faster in both space and time.")
print("The frequency is proportional to the mode number.")

## 10. Error Analysis

Let's analyze where and when the PINN makes errors: