# DeepONet for Acoustic Wave Propagation

**PhD Autumn School - Scientific Machine Learning**

## Learning Objectives

In this 1-hour hands-on exercise, you will:
1. Understand the DeepONet architecture for operator learning
2. Explore the acoustic wave data
3. Compare different activation functions (ReLU vs. sine)
4. Investigate the impact of Fourier feature expansions

## Introduction to DeepONet

**Deep Operator Networks (DeepONet)** learn operators that map functions to functions, rather than just vectors to vectors like traditional neural networks.

### The Problem

In acoustic wave propagation, we want to learn the operator $\mathcal{G}$ that maps:
- **Input function** $u(x)$: Initial pressure distribution (source configuration)
- **Output function** $s(x, y, t)$: Pressure field at any location $(x, y)$ and time $t$

Mathematically: $s = \mathcal{G}(u)$

### DeepONet Architecture

```
                    ┌─────────────┐
u(x) ─────────────> │ Branch Net  │ ───> [b₁, b₂, ..., bₚ]
  (function)        └─────────────┘            │
                                               │
                                               ├──> Inner Product ──> s(y)
                                               │
y = (x,y,t) ────> ┌─────────────┐            │
  (coordinates)   │  Trunk Net  │ ───> [t₁, t₂, ..., tₚ]
                  └─────────────┘
```

- **Branch network**: Encodes the input function $u$ into a latent representation
- **Trunk network**: Encodes the query coordinates $(x, y, t)$
- **Inner product**: Combines both representations at a query point $y$ (the coordinate triple $(x, y, t)$): $s(y) = \sum_{i=1}^{p} b_i(u) \cdot t_i(y) + b_0$, where $b_0$ is a learnable bias
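
For a single input function, the combination step is a dot product of two length-$p$ vectors plus a scalar bias. The NumPy sketch below uses made-up values for $b$, $t$ and $b_0$ (illustrative only, not outputs of a trained network) and also shows the batched form over several query points.

```python
import numpy as np

# Illustrative latent vectors with p = 3 (hypothetical values, not from a trained model)
b = np.array([0.5, -1.0, 2.0])   # branch output b(u)
t = np.array([1.0, 0.25, 0.5])   # trunk output t(y) for one query point y
b0 = 0.1                         # scalar bias

# s(y) = sum_i b_i(u) * t_i(y) + b0
s = np.dot(b, t) + b0
print(s)

# Batched version: one input function u, several query points (rows of T)
T = np.array([[1.0, 0.25, 0.5],
              [0.0, 1.0, 0.0],
              [2.0, 0.0, -1.0]])  # shape [n_points, p]
s_batch = T @ b + b0              # shape [n_points]
print(s_batch)
```

The same branch vector $b(u)$ is reused for every query point, which is why the trunk can be evaluated at arbitrary coordinates.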

### Key Insight

By learning this decomposition, the network can:
- Generalize to new source configurations $u$ (never seen during training)
- Evaluate at arbitrary query locations $(x, y, t)$
- Predict much faster than traditional PDE solvers once trained, since each prediction is a forward pass rather than a full simulation

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Import project data handlers
from deeponet_acoustics.datahandlers.datagenerators import (
    DataH5Compact,
    DatasetStreamer,
    pytorch_collate  # PyTorch collator for data loading
)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 1. Data Loading

In [None]:
# NOTE: Update these paths to point to your data
# Simple 2D domain
train_data_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect2x2_freq_indep_ppw_2_6_2_5_train"  
data_val_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect2x2_freq_indep_ppw_2_4_2_val"
test_data_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect2x2_freq_indep_ppw_2_4_2_test"

# furnished room
# train_data_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect3x3_furn_freq_indep_ppw_2_6_2_3_train"  
# data_val_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect3x3_furn_freq_indep_ppw_2_4_2_val"
# test_data_path = "/Users/nikolasborrel/data/deeponet/data/input2D/rect3x3_furn_freq_indep_ppw_2_4_2_val"

# Load training data
data_train = DataH5Compact(
    train_data_path,
    t_norm=343.0,    # Speed of sound for normalization
    flatten_ic=True, # Flatten initial conditions (for MLPs)
    norm_data=True,  # Normalize spatial coordinates
    u_p_range=(-2.0, 2.0), # Model was tuned with this branch normalization in the paper
)

# Load validation data
data_val = DataH5Compact(
    data_val_path,
    t_norm=343.0,
    flatten_ic=True,
    norm_data=True,
    u_p_range=(-2.0, 2.0),
)

print(f"Training sources: {data_train.N}")
print(f"Val sources: {data_val.N}")
print(f"Mesh points: {data_train.P_mesh}")
print(f"Time steps: {len(data_train.tsteps)}")
print(f"Input (u) shape: {data_train.u_shape}")

In [None]:
# Create datasets and dataloaders
batch_size_branch = 64  # Number of different sources per batch
batch_size_coord = 200  # Number of coordinate points per batch

dataset_train = DatasetStreamer(data_train, batch_size_coord=batch_size_coord)
dataset_val = DatasetStreamer(data_val, batch_size_coord=batch_size_coord)


# Use pytorch_collate to convert to PyTorch tensors
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size_branch,
    shuffle=True,
    collate_fn=pytorch_collate,
    drop_last=True,
)

dataloader_val = DataLoader(
    dataset_val,
    batch_size=1,
    shuffle=False,
    collate_fn=pytorch_collate,
)

# Get a sample batch to understand dimensions
sample_batch = next(iter(dataloader_train))
(u_sample, y_sample), s_sample, _, _ = sample_batch

print(f"\nBatch shapes:")
print(f"  u (branch input): {u_sample.shape}  # [batch_branch, u_dim]")
print(f"  y (trunk input):  {y_sample.shape}  # [batch_branch, batch_coord, coord_dim (including time dim)]")
print(f"  s (output):       {s_sample.shape}  # [batch_branch, batch_coord]")
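
Each row of the trunk input is one space-time query point $(x, y, t)$. To illustrate how such points can be formed from a spatial mesh and a list of time steps, the sketch below builds the full Cartesian product with NumPy. The tiny mesh and time values are synthetic, and the time-major ordering is an assumption chosen for illustration; `DatasetStreamer` may order or subsample points differently (with `batch_size_coord=200`, each batch carries 200 coordinate points per source, as in the shapes printed above).

```python
import numpy as np

# Synthetic example: 4 spatial points (x, y) and 3 time steps
mesh = np.array([[0.0, 0.0],
                 [1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])        # [P, 2]
tsteps = np.array([0.0, 0.5, 1.0])   # [T]

P, T = len(mesh), len(tsteps)

# Time-major ordering: all mesh points at t0, then all mesh points at t1, ...
xy = np.tile(mesh, (T, 1))                 # [T*P, 2]
t = np.repeat(tsteps, P)[:, None]          # [T*P, 1]
y_query = np.concatenate([xy, t], axis=1)  # [T*P, 3], rows are (x, y, t)

print(y_query.shape)
# Row k corresponds to time index k // P and mesh index k % P
k = 6
print(y_query[k], mesh[k % P], tsteps[k // P])
```

Stacking such arrays for several sources gives the `[batch_branch, batch_coord, coord_dim]` layout expected by the trunk network.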

### Visualizations

Let's define helper functions to visualize the training and validation data and get a feel for the acoustic wave propagation problem.

In [4]:
from deeponet_acoustics.datahandlers.datagenerators import _normalize_spatial


def plot_initial_condition(data: DataH5Compact, source_idx=0, figsize=(8, 6)):
    """
    Plot the initial pressure distribution (source configuration).
    
    Args:
        data: DataH5Compact instance
        source_idx: Index of source to visualize
        figsize: Figure size
    """
    # Get initial condition
    dataset = data.datasets[source_idx]
    u = dataset[data.tag_ufield][:]
    
    # Get umesh coordinates (the mesh for initial conditions - different from simulation mesh!)
    umesh = dataset['/umesh'][:]    
    
    # Create plot
    fig, ax = plt.subplots(figsize=figsize)
    scatter = ax.scatter(umesh[:, 0], umesh[:, 1], c=u, 
                        cmap='RdBu_r', s=20, vmin=-2, vmax=2)
    
    # Get source position if available
    if 'source_position' in dataset:
        x0 = dataset['source_position'][:]
        ax.plot(x0[0], x0[1], 'k*', markersize=20, label='Source')
        ax.legend()
    
    plt.colorbar(scatter, ax=ax, label='Initial Pressure')
    ax.set_xlabel('x [m]')
    ax.set_ylabel('y [m]')
    ax.set_title(f'Initial Condition (Source {source_idx})')
    ax.set_aspect('equal')
    plt.tight_layout()
    plt.show()
    
    print("NOTE: the initial-condition coordinates (umesh) are not used by the model and are therefore not normalized.")
    print(f"Initial condition range: [{np.min(u):.3f}, {np.max(u):.3f}]")
    print(f"Number of IC points (umesh): {len(u)}")
    print(f"Number of simulation points (mesh): {len(data.mesh)}")


def plot_pressure_field_snapshots(data: DataH5Compact, source_idx=0, time_indices=[0, 30, 60, 90], figsize=(16, 4)):
    """
    Plot pressure field at multiple time steps.
    
    Args:
        data: DataH5Compact instance
        source_idx: Index of source to visualize
        time_indices: List of time step indices to plot
        figsize: Figure size
    """
    dataset = data.datasets[source_idx]
    
    # Get pressure field
    pressure_field = dataset[data.tags_field[0]][:, ::data.data_prune]  # [time, space]
    
    # Create subplots
    n_plots = len(time_indices)
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)
    
    if n_plots == 1:
        axes = [axes]
    
    # Global colorbar limits
    vmin, vmax = -0.5, 0.5
    
    for i, (ax, t_idx) in enumerate(zip(axes, time_indices)):
        # Get pressure at this time step
        p_t = pressure_field[t_idx, :]
        
        # Plot
        scatter = ax.scatter(data.mesh[:, 0], data.mesh[:, 1], c=p_t,
                           cmap='RdBu_r', s=10, vmin=vmin, vmax=vmax)
        
        # Mark source position if available
        if 'source_position' in dataset:
            x0 = dataset['source_position'][:]
            # Source coordinates are not normalized (not used for training)
            if data.normalize_data:
                x0 = _normalize_spatial(x0, data.xmin, data.xmax)
            ax.plot(x0[0], x0[1], 'k*', markersize=15)
        
        # Get actual time value
        time_val = data.tsteps[t_idx]
        
        ax.set_title(f't_norm = {time_val:.2f}')
        ax.set_xlabel('x [m]')
        if i == 0:
            ax.set_ylabel('y [m]')
        ax.set_aspect('equal')
    
    # Add shared colorbar
    fig.colorbar(scatter, ax=axes, label='Pressure', pad=0.02)
    fig.suptitle(f'Pressure Field Evolution (Source {source_idx})', y=1.02, fontsize=14)
    plt.tight_layout()
    plt.show()

    print(f"t_norm min/max: {data.tsteps[0]}/{data.tsteps[-1]}")


def plot_multiple_sources(data: DataH5Compact, source_indices=[0, 1, 2], time_idx=50, figsize=(15, 4)):
    """
    Compare pressure fields from different source positions at the same time.
    
    Args:
        data: DataH5Compact instance
        source_indices: List of source indices to compare
        time_idx: Time step index to visualize
        figsize: Figure size
    """    

    mesh_vis = data.mesh
    
    # Create subplots
    n_plots = len(source_indices)
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)
    
    if n_plots == 1:
        axes = [axes]
    
    # Global colorbar limits
    vmin, vmax = -0.5, 0.5
    
    for i, (ax, src_idx) in enumerate(zip(axes, source_indices)):
        if src_idx >= data.N:
            ax.text(0.5, 0.5, f'Source {src_idx}\nnot available', 
                   ha='center', va='center', transform=ax.transAxes)
            ax.set_xticks([])
            ax.set_yticks([])
            continue
            
        dataset = data.datasets[src_idx]
        pressure_field = dataset[data.tags_field[0]][:, ::data.data_prune]
        p_t = pressure_field[time_idx, :]
        
        # Plot
        scatter = ax.scatter(mesh_vis[:, 0], mesh_vis[:, 1], c=p_t,
                           cmap='RdBu_r', s=10, vmin=vmin, vmax=vmax)
        
        # Mark source position if available
        if 'source_position' in dataset:
            x0 = dataset['source_position'][:]
            # Source coordinates are not normalized (not used for training)
            if data.normalize_data:
                x0 = _normalize_spatial(x0, data.xmin, data.xmax)
            ax.plot(x0[0], x0[1], 'k*', markersize=15)
        
        ax.set_title(f'Source {src_idx}')
        ax.set_xlabel('x [m]')
        if i == 0:
            ax.set_ylabel('y [m]')
        ax.set_aspect('equal')
    
    # Add shared colorbar
    fig.colorbar(scatter, ax=axes, label='Pressure', pad=0.02)
    
    time_val = data.tsteps[time_idx]
    fig.suptitle(f'Pressure Fields at t_norm = {time_val:.4f}', y=1.02, fontsize=14)
    plt.tight_layout()
    plt.show()

## Visualizing the 2D Acoustic Data

Before training, let's use these functions to inspect the data.

In [None]:
# Visualize initial condition from training data (only 1 source available)
print("=" * 60)
print("TRAINING DATA - Initial Condition")
print("=" * 60)
plot_initial_condition(data_train, source_idx=0)

# Visualize pressure field evolution from val data
print("\n" + "=" * 60)
print("VALIDATION DATA - Pressure Field Evolution")
print("=" * 60)
plot_pressure_field_snapshots(data_val, source_idx=0, time_indices=[10, 40, 70, 95])

# Compare different source positions from validation data
print("\n" + "=" * 60)
print("VALIDATION DATA - Multiple Source Positions")
print("=" * 60)
plot_multiple_sources(data_val, source_indices=[0, 1, 2], time_idx=50)

## 2. Fourier Feature Expansion

In [None]:
def fourier_feature_expansion(freqs=[]):
    """
    Create a Fourier feature expansion function.
    
    Args:
        freqs: List of frequencies for Fourier features
               Empty list means no Fourier features (just return input)
    
    Returns:
        Function that applies Fourier feature expansion
    """
    if len(freqs) == 0:
        return lambda y: y
    
    def expand(y):
        # y shape: [batch, coord_dim] or [batch, n_points, coord_dim]
        features = [y]
        for f in freqs:
            features.append(np.cos(2 * np.pi * f * y))
            features.append(np.sin(2 * np.pi * f * y))
        return np.concatenate(features, axis=-1)
    
    return expand

# Example: no Fourier features
feat_fn_none = fourier_feature_expansion(freqs=[])

# Example: with Fourier features at specific frequencies
feat_fn_fourier = fourier_feature_expansion(freqs=[1.0, 2.0])

# Test
test_input = np.random.rand(4, 3)  # [batch=4, coord_dim=3]
print(f"Input shape: {test_input.shape}")
print(f"Without Fourier features: {feat_fn_none(test_input).shape}")
print(f"With Fourier features: {feat_fn_fourier(test_input).shape}  # 3 + 2*2*3 = 15")
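
Written out, the expansion above maps a coordinate vector $\mathbf{y} \in \mathbb{R}^d$ and frequencies $f_1, \dots, f_K$ to

$$
\gamma(\mathbf{y}) = \big[\, \mathbf{y},\; \cos(2\pi f_1 \mathbf{y}),\; \sin(2\pi f_1 \mathbf{y}),\; \dots,\; \cos(2\pi f_K \mathbf{y}),\; \sin(2\pi f_K \mathbf{y}) \,\big] \in \mathbb{R}^{d(1 + 2K)},
$$

with $\cos$ and $\sin$ applied element-wise, so each of the $d$ coordinates (time included) gets its own cosine/sine pair per frequency. For $d = 3$ and $K = 2$ this gives $3(1 + 2 \cdot 2) = 15$, matching the printed shape; with the three frequencies used in the Fourier training branch it gives $3(1 + 2 \cdot 3) = 21$.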

## 3. Network Components

### Sinusoidal Weight Initialization

For sine activations, proper weight initialization is **critical**:
standard initializations (Xavier, He) don't work well with periodic activations.

Key parameters, taken from the project's JAX implementation (`networks_flax.py`):
- First layer: uniform(-1/d_in, 1/d_in)
- Hidden layers: uniform(-sqrt(6/d_in)/30, sqrt(6/d_in)/30)
- Angular frequency: 30.0 (scales the input to the first layer)

Reference: SIREN paper (https://arxiv.org/abs/2006.09661)
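
To get a feel for the scales involved, the sketch below evaluates the two bounds listed above for the layer sizes used in the training cell further down (trunk input dimension 3, hidden width 2048, angular frequency 30). It only does the arithmetic from the formulas; it does not call the notebook's own initializer.

```python
import math

def first_layer_bound(d_in):
    # First layer: uniform(-1/d_in, 1/d_in)
    return 1.0 / d_in

def hidden_layer_bound(d_in, angular_freq=30.0):
    # Hidden layers: uniform(-sqrt(6/d_in)/omega, sqrt(6/d_in)/omega)
    return math.sqrt(6.0 / d_in) / angular_freq

print(f"trunk first layer (d_in=3): {first_layer_bound(3):.4f}")
print(f"hidden layer (d_in=2048):   {hidden_layer_bound(2048):.6f}")
```

Because the networks multiply the first-layer input by the angular frequency, the effective first-layer weights of the trunk span roughly $\pm 30/3 = \pm 10$, while the hidden-layer weights stay below about $0.0018$ in magnitude.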

In [7]:
import math

def sinusoidal_init_(weights: torch.Tensor, is_first_layer=False, angular_freq=30) -> None:
    """
    Initialize weights for networks with sine activations.

    Args:
        weights: PyTorch weight tensor with shape [out_features, in_features]
        is_first_layer: Use different range for first layer
        angular_freq: Angular frequency (default: 30)
    """
    with torch.no_grad():
        # IMPORTANT: PyTorch Linear weights have shape [out_features, in_features],
        # so shape[1] is the input dimension (d_in), NOT shape[0] as for JAX/Flax kernels ([in, out])!
        d_in = weights.shape[1]

        if is_first_layer:
            bound = 1.0 / d_in
        else:
            bound = math.sqrt(6.0 / d_in) / angular_freq

        weights.uniform_(-bound, bound)

### BranchNet and TrunkNet Classes with Sine Initialization

In [8]:
class BranchNet(nn.Module):
    """Branch network with optional sine activation support."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers,
                 activation='relu', angular_freq=30.0):
        super().__init__()

        self.angular_freq = angular_freq if activation == 'sine' else 1.0

        layers = []

        # First layer
        first_layer = nn.Linear(input_dim, hidden_dim)
        if activation == 'sine':
            sinusoidal_init_(first_layer.weight, is_first_layer=True)
        layers.append(first_layer)

        # Hidden layers
        for _ in range(num_hidden_layers - 1):
            layer = nn.Linear(hidden_dim, hidden_dim)
            if activation == 'sine':
                sinusoidal_init_(layer.weight, is_first_layer=False, angular_freq=angular_freq)
            layers.append(layer)

        # Output layer  
        output_layer = nn.Linear(hidden_dim, output_dim)
        if activation == 'sine':
            sinusoidal_init_(output_layer.weight, is_first_layer=False, angular_freq=angular_freq)
        layers.append(output_layer)

        self.layers = nn.ModuleList(layers)

        # Select activation
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sine':
            self.activation = lambda x: torch.sin(x)
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, u):
        x = u

        for i, layer in enumerate(self.layers):
            # Apply angular frequency scaling to FIRST layer input (for sine networks)
            if i == 0:
                x = layer(self.angular_freq * x)
            else:
                x = layer(x)

            # Apply activation (except last layer)
            if i < len(self.layers) - 1:
                x = self.activation(x)

        return x

In [9]:
class TrunkNet(nn.Module):
    """Trunk network with optional sine activation support."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers,
                 activation='relu', angular_freq=30.0):
        super().__init__()

        self.angular_freq = angular_freq if activation == 'sine' else 1.0

        layers = []

        # First layer
        first_layer = nn.Linear(input_dim, hidden_dim)
        if activation == 'sine':
            sinusoidal_init_(first_layer.weight, is_first_layer=True)
        layers.append(first_layer)

        # Hidden layers
        for _ in range(num_hidden_layers - 1):
            layer = nn.Linear(hidden_dim, hidden_dim)
            if activation == 'sine':
                sinusoidal_init_(layer.weight, is_first_layer=False, angular_freq=angular_freq)
            layers.append(layer)

        # Output layer  
        output_layer = nn.Linear(hidden_dim, output_dim)
        if activation == 'sine':
            sinusoidal_init_(output_layer.weight, is_first_layer=False, angular_freq=angular_freq)
        layers.append(output_layer)

        self.layers = nn.ModuleList(layers)

        # Select activation
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sine':
            self.activation = lambda x: torch.sin(x)
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, y):
        x = y

        for i, layer in enumerate(self.layers):
            # Apply angular frequency scaling to FIRST layer input (for sine networks)
            if i == 0:
                x = layer(self.angular_freq * x)
            else:
                x = layer(x)

            # Apply activation (except last layer)
            if i < len(self.layers) - 1:
                x = self.activation(x)

        return x

## 4. DeepONet Model

In [10]:
class DeepONet(nn.Module):
    """
    Deep Operator Network combining branch and trunk networks.
    """
    def __init__(self, branch_net, trunk_net):
        """
        Args:
            branch_net: Branch network (BranchNet instance)
            trunk_net: Trunk network (TrunkNet instance)
        """
        super().__init__()
        self.branch_net = branch_net
        self.trunk_net = trunk_net
        
        # Learnable bias term
        self.b0 = nn.Parameter(torch.zeros(1))
    
    def forward(self, u, y):
        """
        Forward pass through DeepONet.
        
        Args:
            u: Branch input [batch_branch, u_dim]
            y: Trunk input [batch_branch, n_points, coord_dim]
        
        Returns:
            Predictions [batch_branch, n_points]
        """

        # Get branch and trunk latent representations
        branch_output = self.branch_net(u)  # [batch_branch, p]
        trunk_output = self.trunk_net(y)    # [batch_branch, n_points, p]
        
        # Inner product operation
        # We need to compute: sum over p dimension of (branch * trunk)
        # branch_output: [batch_branch, p]
        # trunk_output: [batch_branch, n_points, p]
        # Goal: [batch_branch, n_points]
        
        # Expand branch dimensions and element-wise multiply        
        branch_expanded = branch_output.unsqueeze(1)  # [batch_branch, 1, p]
        s_pred = torch.sum(branch_expanded * trunk_output, dim=-1)  # [batch_branch, n_points]
        
        # Add bias
        s_pred = s_pred + self.b0
        
        return s_pred
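
The unsqueeze-multiply-sum in `forward` is a batched dot product over the latent dimension $p$. As a cross-check, the sketch below compares it with an equivalent `torch.einsum` contraction on random tensors; either form could be used inside `forward`.

```python
import torch

torch.manual_seed(0)
batch_branch, n_points, p = 4, 50, 40

branch_output = torch.randn(batch_branch, p)           # [batch_branch, p]
trunk_output = torch.randn(batch_branch, n_points, p)  # [batch_branch, n_points, p]

# Formulation used in DeepONet.forward
s_broadcast = torch.sum(branch_output.unsqueeze(1) * trunk_output, dim=-1)

# Same contraction with einsum: sum over the shared latent index p
s_einsum = torch.einsum("bp,bnp->bn", branch_output, trunk_output)

print(s_broadcast.shape, torch.allclose(s_broadcast, s_einsum, atol=1e-5))
```

The einsum string makes the index bookkeeping explicit: `b` (source), `n` (query point) and `p` (latent), with `p` summed out.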

### Test Implementation

In [None]:
# Create test networks
test_branch = BranchNet(input_dim=100, hidden_dim=64, output_dim=40, num_hidden_layers=2, activation='relu')
test_trunk = TrunkNet(input_dim=3, hidden_dim=64, output_dim=40, num_hidden_layers=2, activation='relu')
test_model = DeepONet(test_branch, test_trunk)

# Create random test data
test_u = torch.randn(4, 100)   # [batch_branch=4, u_dim=100]
test_y = torch.randn(4, 50, 3) # [batch_branch=4, n_points=50, coord_dim=3]

# Forward pass
test_output = test_model(test_u, test_y)

print(f"Branch output shape: {test_branch(test_u).shape}")
print(f"Trunk output shape: {test_trunk(test_y).shape}")
print(f"DeepONet output shape: {test_output.shape}")
print(f"Expected output shape: [4, 50]")

assert test_output.shape == (4, 50), "Output shape is incorrect!"
print("\n✓ Shape test passed.")

## 5. Training Loop

In [12]:
def train_deeponet(model, dataloader_train, dataloader_val, num_epochs, learning_rate, device='cpu', grad_clip=None, gamma=0.995, warmup_epochs=5):
    """
    Train DeepONet model.
    
    Args:
        model: DeepONet instance
        dataloader_train: Training data loader
        dataloader_val: Validation data loader
        num_epochs: Number of training epochs
        learning_rate: Initial learning rate for optimizer
        device: 'cpu' or 'cuda'
        grad_clip: Gradient clipping value (None = no clipping, 0.01 recommended for sine networks)
        gamma: Decay rate for exponential learning rate schedule (default: 0.995)
        warmup_epochs: Number of epochs for learning rate warmup (default: 5)
    
    Returns:
        Dictionary with training history
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    
    # Setup learning rate scheduler with warmup
    use_warmup = False  # set to True for a linear warmup before the exponential decay
    if use_warmup:
        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                # Linear warmup from 0.1 to 1.0
                return 0.1 + 0.9 * (epoch / warmup_epochs)
            else:
                # Exponential decay after warmup
                return gamma ** (epoch - warmup_epochs)    
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    else:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    
    history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
    
    step = 0
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss_epoch = 0
        num_batches = 0
        
        for batch in dataloader_train:
            (u, y), s_true, _, _ = batch
            u, y, s_true = u.to(device), y.to(device), s_true.to(device)
            
            # Forward pass
            s_pred = model(u, y)
            loss = criterion(s_pred, s_true)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping (recommended for sine networks); skipped while epoch <= warmup_epochs
            if grad_clip is not None and epoch > warmup_epochs:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            
            optimizer.step()            
            
            train_loss_epoch += loss.item()
            num_batches += 1
            step += 1
        
        scheduler.step()
        
        train_loss_epoch /= num_batches
        history['train_loss'].append(train_loss_epoch)
        
        # Record current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rate'].append(current_lr)
        
        if (epoch + 1) % 10 == 0 or epoch == 0:
            # Validating
            model.eval()
            val_loss_epoch = 0
            num_val_batches = 0
            with torch.no_grad():
                for batch in dataloader_val:
                    (u, y), s_true, _, _ = batch
                    u, y, s_true = u.to(device), y.to(device), s_true.to(device)
                    
                    s_pred = model(u, y)
                    loss = criterion(s_pred, s_true)
                    
                    val_loss_epoch += loss.item()
                    num_val_batches += 1
        
            val_loss_epoch /= num_val_batches
            history['val_loss'].append(val_loss_epoch)
            history.setdefault('val_epoch', []).append(epoch)  # epochs at which validation was run

            print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss_epoch:.6f}, Val Loss: {val_loss_epoch:.6f}, LR: {current_lr:.6f}")
        else:
            print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss_epoch:.6f}, LR: {current_lr:.6f}")
    
    return history
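
To see what the learning-rate schedule inside `train_deeponet` does, the sketch below reproduces the multiplier from `lr_lambda` (warmup branch) and the plain exponential decay as standalone functions. Note that `use_warmup` is hard-coded to `False`, so by default only the exponential decay is active. It assumes the defaults `gamma=0.995` and `warmup_epochs=5`, and a base learning rate of `1e-3` as in the hyperparameter cell.

```python
def warmup_then_decay(epoch, warmup_epochs=5, gamma=0.995):
    # Same rule as lr_lambda in train_deeponet: linear 0.1 -> 1.0, then exponential decay
    if epoch < warmup_epochs:
        return 0.1 + 0.9 * (epoch / warmup_epochs)
    return gamma ** (epoch - warmup_epochs)

def exponential_decay(epoch, gamma=0.995):
    # ExponentialLR multiplies the LR by gamma on every scheduler.step() (once per epoch here)
    return gamma ** epoch

base_lr = 1e-3
print([round(warmup_then_decay(e), 3) for e in range(7)])
print(f"LR after 100 epochs (no warmup): {base_lr * exponential_decay(100):.2e}")
```

With `gamma=0.995`, 100 epochs shrink the learning rate to roughly 60% of its initial value, so the decay is gentle over the default training length.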

In [13]:
def plot_training_history(history, title="Training History"):
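    """Plot training and validation loss over epochs."""
    plt.figure(figsize=(10, 5))
    plt.semilogy(history['train_loss'], label='Train Loss', linewidth=2)
    # Validation loss is only computed every 10 epochs; plot it at those epochs if recorded
    val_epochs = history.get('val_epoch', range(len(history['val_loss'])))
    plt.semilogy(val_epochs, history['val_loss'], label='Validation Loss', linewidth=2)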
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('MSE Loss', fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [17]:
# Hyperparameters
input_dim_branch = data_train.u_shape[0]  # From data
input_dim_trunk = 3  # x, y, t coordinates
hidden_dim = 2048
output_dim = 100  # latent dimension p
num_hidden_layers = 2
learning_rate = 1e-3
num_epochs = 100

grad_clip = None

use_fourier_expansions = False
use_sine = True # otherwise relu
angular_freq = 30

device = 'cuda' if torch.cuda.is_available() else 'cpu'
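
With `hidden_dim = 2048` the networks are fairly large. The sketch below counts weights and biases for the layer stack that `BranchNet`/`TrunkNet` build (one input layer, `num_hidden_layers - 1` hidden-to-hidden layers, one output layer), here for the trunk with and without Fourier features. The branch count depends on `data_train.u_shape[0]` and is not computed here.

```python
def mlp_param_count(input_dim, hidden_dim, output_dim, num_hidden_layers):
    # Mirrors the layer stack in BranchNet/TrunkNet:
    # Linear(input, hidden), (num_hidden_layers - 1) x Linear(hidden, hidden), Linear(hidden, output)
    dims = [input_dim] + [hidden_dim] * num_hidden_layers + [output_dim]
    # Each nn.Linear(d_in, d_out) has d_in * d_out weights and d_out biases
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))

trunk_params = mlp_param_count(3, 2048, 100, 2)
trunk_params_fourier = mlp_param_count(21, 2048, 100, 2)  # 3 + 2*3*3 Fourier inputs
print(f"Trunk parameters: {trunk_params:,}")
print(f"Trunk parameters with Fourier features: {trunk_params_fourier:,}")
```

The single 2048 x 2048 hidden layer (about 4.2 million parameters) dominates, so adding Fourier features to the trunk input changes the parameter count only slightly.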

In [18]:
def plot_losses(history, label):
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.semilogy(history['train_loss'], label=f'{label} - Train', linewidth=2, linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
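    val_epochs = history.get('val_epoch', range(len(history['val_loss'])))  # epochs at which validation ran
    plt.semilogy(val_epochs, history['val_loss'], label=f'{label} - Validation', linewidth=2, linestyle='--')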
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Validation Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print final results
    print(f"\nFinal Validation Loss:")
    print(f"  {label}: {history['val_loss'][-1]:.6f}")

In [None]:
# Select the activation; angular_freq is only used by sine networks (ignored for ReLU)
activation = "sine" if use_sine else "relu"
label_str = "Sine" if use_sine else "ReLU"

if not use_fourier_expansions:
    feat_fn_fourier = None  # no coordinate features; also reused by the inference cells

    # Create model with the selected activation
    branch = BranchNet(input_dim_branch, hidden_dim, output_dim, num_hidden_layers, activation=activation, angular_freq=angular_freq)
    trunk = TrunkNet(input_dim_trunk, hidden_dim, output_dim, num_hidden_layers, activation=activation, angular_freq=angular_freq)
    model = DeepONet(branch, trunk)

    print(f"\nTraining model with {label_str} activation...")
    history_sine = train_deeponet(model, dataloader_train, dataloader_val, num_epochs, learning_rate, device, grad_clip=grad_clip)

    plot_losses(history_sine, label_str)
else:
    # Create the feature expansion function using normalized frequencies (f / c)
    c_phys = 343
    freq_norm = [166.7 / c_phys, 250 / c_phys, 500 / c_phys]
    feat_fn_fourier = fourier_feature_expansion(freqs=freq_norm)

    # Create datasets that apply the feature expansion to the trunk coordinates
    dataset_train_fourier = DatasetStreamer(data_train, batch_size_coord=batch_size_coord, y_feat_extract_fn=feat_fn_fourier)
    dataset_val_fourier = DatasetStreamer(data_val, batch_size_coord=-1, y_feat_extract_fn=feat_fn_fourier)

    # Create dataloaders
    dataloader_train_fourier = DataLoader(dataset_train_fourier, batch_size=batch_size_branch, shuffle=True,
                                          collate_fn=pytorch_collate, drop_last=True)
    dataloader_val_fourier = DataLoader(dataset_val_fourier, batch_size=1, shuffle=False, collate_fn=pytorch_collate)

    # Get the actual trunk input dimension from a sample batch
    sample_fourier = next(iter(dataloader_train_fourier))
    input_dim_trunk_fourier = sample_fourier[0][1].shape[-1]  # Should be 3 + 2*3*3 = 21

    print(f"Trunk input dim with Fourier features: {input_dim_trunk_fourier}")

    # Create model with Fourier features and the selected activation
    branch_fourier = BranchNet(input_dim_branch, hidden_dim, output_dim, num_hidden_layers, activation=activation, angular_freq=angular_freq)
    trunk_fourier = TrunkNet(input_dim_trunk_fourier, hidden_dim, output_dim, num_hidden_layers, activation=activation, angular_freq=angular_freq)
    model = DeepONet(branch_fourier, trunk_fourier)

    print(f"\nTraining {label_str} model WITH Fourier features...")
    history_fourier = train_deeponet(model, dataloader_train_fourier, dataloader_val_fourier, num_epochs, learning_rate, device, grad_clip=grad_clip)

    plot_losses(history_fourier, f"{label_str} + Fourier")

## Summary

This notebook demonstrates:
- ✓ Complete DeepONet implementation in PyTorch
- ✓ Branch and trunk network architectures
- ✓ Inner product operation for operator learning
- ✓ Comparison of ReLU vs Sine activations
- ✓ Impact of Fourier feature expansions
- ✓ Full training and evaluation pipeline

## 7. Model Inference and Evaluation

Now let's use the trained model to make predictions on test data and compare against ground truth.

**Important**: The inference code below passes `feat_fn_fourier` as the coordinate feature function. The training cell sets it to `None` when `use_fourier_expansions = False` and to the Fourier expansion otherwise, so inference uses the same coordinate features as training, provided `model` is the most recently trained model.

### Key Observations

The inference results demonstrate:

1. **Full Wavefield Prediction**: The model can predict the entire pressure field across space and time for new source configurations not seen during training.

2. **Impulse Response Accuracy**: By snapping receiver positions to the nearest points of the simulation mesh using `utils.getNearestFromCoordinates`, we can compare predictions directly with the ground truth stored at those points.

3. **Model Generalization**: The error metrics show how well the trained DeepONet generalizes to unseen test data, which is crucial for practical applications.

4. **Computational Efficiency**: Once trained, the DeepONet predicts the full acoustic field with forward passes instead of running a PDE solver, which can be much faster and makes it attractive for real-time applications or design optimization.
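
The error metrics themselves are not defined in this section. A common choice for comparing a predicted impulse response $\hat{s}$ with ground truth $s$ is the relative $L^2$ error per receiver, $\|\hat{s} - s\|_2 / \|s\|_2$. The helper below is a hypothetical sketch (its name and the `[time_steps, n_receivers]` layout are assumptions chosen to match the arrays returned by `predict_impulse_response`), not necessarily the metric used elsewhere in this project.

```python
import numpy as np

def relative_l2_error(pred, true, eps=1e-12):
    """Relative L2 error per receiver (hypothetical helper).

    pred, true: arrays of shape [time_steps, n_receivers].
    Returns an array of shape [n_receivers].
    """
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    num = np.linalg.norm(pred - true, axis=0)  # error norm over time, per receiver
    den = np.linalg.norm(true, axis=0)         # reference norm over time, per receiver
    return num / np.maximum(den, eps)          # eps guards against an all-zero reference

# Toy check: receiver 0 predicted exactly, receiver 1 scaled by 1.1 (10% error)
true = np.array([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]])
pred = true * np.array([1.0, 1.1])
print(relative_l2_error(pred, true))
```

Once `ir_pred` and `ir_true` are available, the same call (`relative_l2_error(ir_pred, ir_true)`) would give one error value per receiver.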

In [None]:
data_test = DataH5Compact(
    test_data_path,
    t_norm=343.0,
    flatten_ic=True,
    norm_data=True,
    u_p_range=(-2.0, 2.0),
)

source_idx = 0

print(f"Test sources: {data_test.N}")
print(f"Test source_idx: {source_idx}")
print(f"Mesh points: {data_test.P_mesh}")
print(f"Time steps: {len(data_test.tsteps)}")

In [None]:
import deeponet_acoustics.utils.utils as utils


def predict_impulse_response(model, data, source_idx, receiver_positions, device='cpu', snap_to_grid=True, y_feat_extract_fn=None):
    """
    Predict impulse response at specific receiver locations.
    
    Args:
        model: Trained DeepONet model
        data: DataH5Compact instance
        source_idx: Index of source
        receiver_positions: List of receiver positions [[x1, y1], [x2, y2], ...]
        device: Device to run predictions on
        snap_to_grid: Whether to snap receiver positions to nearest grid points
        y_feat_extract_fn: Optional feature extraction function for coordinates (e.g., Fourier features)
    
    Returns:
        ir_pred: Predicted impulse responses [time_steps, n_receivers]
        ir_true: Ground truth impulse responses [time_steps, n_receivers] (None if snap_to_grid=False)
        receiver_pos_actual: Actual receiver positions (snapped if snap_to_grid=True)
    """
    model.eval()
    
    # Get data for this source - but we'll manually create the coordinate array
    dataset_single = DatasetStreamer(data, batch_size_coord=-1)
    (u, _), s_true, _, x0 = dataset_single[source_idx]
    
    # Get mesh in physical coordinates for snapping
    mesh_phys = data.denormalize_spatial(data.mesh)
    
    # Snap to grid if requested
    if snap_to_grid:
        receiver_positions_list = [receiver_positions]  # Wrap for utils function
        receiver_pos_actual, receiver_indices = utils.getNearestFromCoordinates(
            mesh_phys, receiver_positions_list
        )
        receiver_pos_actual = receiver_pos_actual[0]
        receiver_indices = receiver_indices[0]
    else:
        receiver_pos_actual = np.array(receiver_positions)
        receiver_indices = None
    
    # Normalize receiver positions for model input
    receiver_pos_norm = data.normalize_spatial(receiver_pos_actual)
    
    # Create coordinate array: repeat each receiver position for all time steps
    n_receivers = len(receiver_pos_norm)
    tdim = len(data.tsteps)
    
    y_rcvs = np.repeat(receiver_pos_norm, tdim, axis=0)  # [n_receivers*tdim, 2]
    tsteps_rcvs = np.tile(data.tsteps, n_receivers)  # [n_receivers*tdim]
    y_input = np.concatenate((y_rcvs, np.expand_dims(tsteps_rcvs, 1)), axis=1)  # [n_receivers*tdim, 3]
    
    # Apply feature extraction if provided (e.g., Fourier features)
    if y_feat_extract_fn is not None:
        y_input = y_feat_extract_fn(y_input)
    
    # Convert to torch tensors
    u_torch = torch.from_numpy(u).float().unsqueeze(0).to(device)  # [1, u_dim]
    y_torch = torch.from_numpy(y_input).float().unsqueeze(0).to(device)  # [1, n_receivers*tdim, coord_dim]
    
    # Predict
    with torch.no_grad():
        ir_pred_flat = model(u_torch, y_torch)  # [1, n_receivers*tdim]
        ir_pred = ir_pred_flat.cpu().numpy().reshape(n_receivers, tdim).T  # [tdim, n_receivers]
    
    # Extract ground truth if snapped to grid
    if snap_to_grid and receiver_indices is not None:
        s_true_reshaped = s_true.reshape(tdim, -1)
        ir_true = s_true_reshaped[:, receiver_indices]  # [tdim, n_receivers]
    else:
        ir_true = None
    
    return ir_pred, ir_true, receiver_pos_actual


# Define receiver positions (in physical coordinates)
receiver_positions = [
    [1.0, 1.0],   # Center
    [0.5, 1.5],   # Top left
    [1.5, 0.5],   # Bottom right
]

# Predict impulse responses with the trained model and the same Fourier feature function used during training
ir_pred, ir_true, receiver_pos_actual = predict_impulse_response(
    model, data_test, source_idx, receiver_positions, device, snap_to_grid=True, y_feat_extract_fn=feat_fn_fourier
)

print(f"Predicted IR shape: {ir_pred.shape}")
print(f"Ground truth IR shape: {ir_true.shape}")
print("\nReceiver positions (snapped to grid):")
for i, pos in enumerate(receiver_pos_actual):
    print(f"  Receiver {i+1}: [{pos[0]:.3f}, {pos[1]:.3f}]")
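The query array built in `predict_impulse_response` is receiver-major: `np.repeat` repeats each receiver position `tdim` times, while `np.tile` cycles through the time steps once per receiver. Reshaping the flat prediction to `(n_receivers, tdim)` and transposing therefore yields one column per receiver. The sketch below checks this ordering on toy data and pairs it with a brute-force nearest-neighbour lookup as a stand-in for grid snapping. `nearest_mesh_indices` and `build_queries` are hypothetical helpers, and `utils.getNearestFromCoordinates` may differ in details such as tie-breaking.

```python
import numpy as np


def nearest_mesh_indices(mesh, points):
    """Hypothetical helper: index of the closest mesh node for each point (brute force)."""
    # Pairwise squared distances, shape [n_points, n_mesh]
    d2 = np.sum((points[:, None, :] - mesh[None, :, :]) ** 2, axis=-1)
    return np.argmin(d2, axis=1)


def build_queries(receivers, tsteps):
    """Receiver-major (x, y, t) query array, mirroring predict_impulse_response."""
    n_receivers, tdim = len(receivers), len(tsteps)
    xy = np.repeat(receivers, tdim, axis=0)          # [n_receivers*tdim, 2]
    t = np.tile(tsteps, n_receivers)                 # [n_receivers*tdim]
    return np.concatenate((xy, t[:, None]), axis=1)  # [n_receivers*tdim, 3]


# Toy 3x3 grid on [0, 2]^2 and two receivers slightly off-grid
gx, gy = np.meshgrid(np.linspace(0.0, 2.0, 3), np.linspace(0.0, 2.0, 3))
mesh = np.stack([gx.ravel(), gy.ravel()], axis=1)
receivers = np.array([[0.9, 1.1], [1.8, 0.1]])
idx = nearest_mesh_indices(mesh, receivers)
snapped = mesh[idx]

tsteps = np.array([0.0, 0.1, 0.2, 0.3])
queries = build_queries(snapped, tsteps)

# Fake "model": pressure = 100*x + 10*y + t, evaluated on the flat query list
flat_pred = 100 * queries[:, 0] + 10 * queries[:, 1] + queries[:, 2]
ir = flat_pred.reshape(len(snapped), len(tsteps)).T  # [tdim, n_receivers]
print(snapped)
print(ir)
```

Each column of `ir` equals `100*x + 10*y + t` for exactly one snapped receiver, which confirms that the reshape-then-transpose step undoes the repeat/tile layout.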

In [None]:
def plot_impulse_responses(data, ir_pred, ir_true, receiver_positions, figsize=(15, 10)):
    """
    Plot impulse responses at receiver locations.
    
    Args:
        data: DataH5Compact instance
        ir_pred: Predicted impulse responses [time_steps, n_receivers]
        ir_true: Ground truth impulse responses [time_steps, n_receivers]
        receiver_positions: Receiver positions [[x1, y1], [x2, y2], ...]
        figsize: Figure size
    """
    n_receivers = ir_pred.shape[1]
    tsteps_phys = data.denormalize_temporal(data.tsteps / 343.0)  # Convert to physical time
    
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, n_receivers, height_ratios=[1, 2], hspace=0.3, wspace=0.3)
    
    # Top row: Receiver positions on mesh
    ax_mesh = fig.add_subplot(gs[0, :])
    
    # Plot mesh
    ax_mesh.scatter(data.mesh[:, 0], data.mesh[:, 1], c='lightgray', s=1, alpha=0.3)
    
    # Plot receiver positions
    colors = plt.cm.tab10(np.linspace(0, 1, n_receivers))
    for i, (pos, color) in enumerate(zip(receiver_positions, colors)):
        # Normalize for plotting
        pos_norm = data.normalize_spatial(np.array([pos]))[0]
        ax_mesh.scatter(pos_norm[0], pos_norm[1], c=[color], s=200, marker='o', 
                       edgecolors='black', linewidths=2, label=f'Receiver {i+1}', zorder=5)
    
    ax_mesh.set_xlabel('x (normalized)')
    ax_mesh.set_ylabel('y (normalized)')
    ax_mesh.set_title('Receiver Positions')
    ax_mesh.set_aspect('equal')
    ax_mesh.legend(loc='upper right')
    ax_mesh.grid(True, alpha=0.3)
    
    # Bottom row: Impulse responses
    for i in range(n_receivers):
        ax = fig.add_subplot(gs[1, i])
        
        # Plot ground truth and prediction
        if ir_true is not None:
            ax.plot(tsteps_phys, ir_true[:, i], 'k-', linewidth=2, label='Ground Truth', alpha=0.7)
        ax.plot(tsteps_phys, ir_pred[:, i], '--', color=colors[i], linewidth=2, label='Prediction')
        
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Pressure')
        ax.set_title(f'Receiver {i+1}: [{receiver_positions[i][0]:.2f}, {receiver_positions[i][1]:.2f}]')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Compute error metrics if ground truth available
        if ir_true is not None:
            mse = np.mean((ir_pred[:, i] - ir_true[:, i])**2)
            mae = np.mean(np.abs(ir_pred[:, i] - ir_true[:, i]))
            ax.text(0.02, 0.98, f'MSE: {mse:.6f}\nMAE: {mae:.6f}', 
                   transform=ax.transAxes, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.show()


# Plot impulse responses
plot_impulse_responses(data_test, ir_pred, ir_true, receiver_pos_actual)
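`plot_impulse_responses` computes MSE and MAE one receiver at a time inside the plotting loop. Because both arrays are laid out as `[time_steps, n_receivers]`, the same metrics can be obtained for all receivers at once by reducing over axis 0. A minimal sketch on made-up arrays; `receiver_metrics` is a hypothetical name, and the demo arrays are deliberately named so they do not overwrite the notebook's `ir_pred`/`ir_true`.

```python
import numpy as np


def receiver_metrics(pred, true):
    """Per-receiver MSE and MAE for arrays shaped [time_steps, n_receivers]."""
    err = pred - true
    mse = np.mean(err ** 2, axis=0)     # one value per receiver
    mae = np.mean(np.abs(err), axis=0)  # one value per receiver
    return mse, mae


# Toy example: 4 time steps, 2 receivers
demo_true = np.zeros((4, 2))
demo_pred = np.array([[0.1, 0.0],
                      [-0.1, 0.0],
                      [0.1, 0.2],
                      [-0.1, -0.2]])
mse_r, mae_r = receiver_metrics(demo_pred, demo_true)
print(mse_r, mae_r)
```

Applied to the notebook's arrays, `receiver_metrics(ir_pred, ir_true)` should reproduce the values printed in each subplot.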

### Full Wavefield Prediction

Next, let's predict the pressure field over the entire mesh for the same test source and compare it with ground truth at several time steps.

In [None]:
def predict_full_wavefield(model, data, source_idx=0, device='cpu', y_feat_extract_fn=None):
    """
    Predict full wavefield for a single source.
    
    Args:
        model: Trained DeepONet model
        data: DataH5Compact instance
        source_idx: Index of source to predict
        device: Device to run predictions on
        y_feat_extract_fn: Optional feature extraction function for coordinates (e.g., Fourier features)
    
    Returns:
        s_pred: Predicted pressure field [time_steps, mesh_points]
        s_true: Ground truth pressure field [time_steps, mesh_points]
        u: Initial condition
        x0: Source position
    """
    model.eval()
    
    # Get data for this source - apply same feature extraction as during training
    dataset_single = DatasetStreamer(data, batch_size_coord=-1, y_feat_extract_fn=y_feat_extract_fn)
    (u, y), s_true, _, x0 = dataset_single[source_idx]
    
    # Convert to torch tensors
    u = torch.from_numpy(u).float().unsqueeze(0).to(device)  # [1, u_dim]
    y = torch.from_numpy(y).float().unsqueeze(0).to(device)  # [1, n_points, coord_dim]
    
    tdim = len(data.tsteps)
    n_mesh = data.P_mesh
    
    # Predict for all time steps
    with torch.no_grad():
        s_pred = model(u, y)  # [1, n_points]
        s_pred = s_pred.cpu().numpy().reshape(tdim, n_mesh)
    
    s_true = s_true.reshape(tdim, n_mesh)
    
    return s_pred, s_true, u.cpu().numpy().squeeze(), x0


# Predict for first test source
source_idx = 0
s_pred, s_true, u_test, x0 = predict_full_wavefield(model, data_test, source_idx, device, y_feat_extract_fn=feat_fn_fourier)

print(f"Predicted wavefield shape: {s_pred.shape}")
print(f"Ground truth wavefield shape: {s_true.shape}")
print(f"Source position: {x0}")
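Conceptually, `model(u, y)` evaluates the branch net once per source and the trunk net once per query point, then forms the inner product $s(y) = \sum_{i=1}^{p} b_i(u)\, t_i(y) + b_0$ for every point in one batched operation. The NumPy sketch below reproduces only that combination step with random stand-in outputs. The latent width `p`, the array names and the use of `np.einsum` are illustrative assumptions, not the internals of the notebook's model class.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_points, p = 1, 5, 8  # hypothetical sizes: 1 source, 5 query points, latent width p

b = rng.normal(size=(batch, p))            # branch output b(u): [batch, p]
t = rng.normal(size=(batch, n_points, p))  # trunk output t(y): [batch, n_points, p]
b0 = 0.5                                   # scalar output bias

# s[k, j] = sum_i b[k, i] * t[k, j, i] + b0  ->  [batch, n_points]
s = np.einsum("bp,bnp->bn", b, t) + b0

# Same result with an explicit loop over query points (batch of one), for clarity
s_loop = np.array([[b[0] @ t[0, j] + b0 for j in range(n_points)]])
print(s.shape)
```

Because the trunk output does not depend on `u`, a DeepONet can in principle evaluate the trunk once on a fixed grid and reuse it for many sources; whether the model used here does so is not shown in this notebook.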

In [None]:
def plot_wavefield_comparison(data, s_pred, s_true, source_idx=0, time_indices=(20, 40, 60, 80), figsize=(16, 8)):
    """
    Plot comparison between predicted and ground truth wavefields.
    
    Args:
        data: DataH5Compact instance
        s_pred: Predicted pressure field [time_steps, mesh_points]
        s_true: Ground truth pressure field [time_steps, mesh_points]
        source_idx: Index of source
        time_indices: List of time step indices to plot
        figsize: Figure size
    """
    n_plots = len(time_indices)
    # squeeze=False keeps axes 2-D, so axes[row, i] also works for a single time index
    fig, axes = plt.subplots(3, n_plots, figsize=figsize, squeeze=False)

    vmin, vmax = -0.5, 0.5

    # Source position, normalized like data.mesh (read once, outside the loop)
    dataset = data.datasets[source_idx]
    x0_plot = None
    if 'source_position' in dataset:
        x0_plot = dataset['source_position'][:]
        if data.normalize_data:
            from deeponet_acoustics.datahandlers.datagenerators import _normalize_spatial
            x0_plot = _normalize_spatial(x0_plot, data.xmin, data.xmax)

    for i, t_idx in enumerate(time_indices):
        # Ground truth
        scatter1 = axes[0, i].scatter(data.mesh[:, 0], data.mesh[:, 1], c=s_true[t_idx, :],
                                      cmap='RdBu_r', s=10, vmin=vmin, vmax=vmax)
        axes[0, i].set_title(f't_norm = {data.tsteps[t_idx]:.2f}')
        axes[0, i].set_aspect('equal')
        if i == 0:
            axes[0, i].set_ylabel('Ground Truth\ny (normalized)')

        # Prediction
        scatter2 = axes[1, i].scatter(data.mesh[:, 0], data.mesh[:, 1], c=s_pred[t_idx, :],
                                      cmap='RdBu_r', s=10, vmin=vmin, vmax=vmax)
        axes[1, i].set_aspect('equal')
        if i == 0:
            axes[1, i].set_ylabel('Prediction\ny (normalized)')

        # Absolute error (color scale clipped at 0.1)
        error = np.abs(s_pred[t_idx, :] - s_true[t_idx, :])
        scatter3 = axes[2, i].scatter(data.mesh[:, 0], data.mesh[:, 1], c=error,
                                      cmap='hot', s=10, vmin=0, vmax=0.1)
        axes[2, i].set_xlabel('x (normalized)')
        axes[2, i].set_aspect('equal')
        if i == 0:
            axes[2, i].set_ylabel('Absolute Error\ny (normalized)')

        # Mark the source position in all three rows
        if x0_plot is not None:
            for ax_row in axes[:, i]:
                ax_row.plot(x0_plot[0], x0_plot[1], 'k*', markersize=10)
    
    # Add colorbars
    fig.colorbar(scatter1, ax=axes[0, :], label='Pressure', pad=0.02, fraction=0.046)
    fig.colorbar(scatter2, ax=axes[1, :], label='Pressure', pad=0.02, fraction=0.046)
    fig.colorbar(scatter3, ax=axes[2, :], label='|Error|', pad=0.02, fraction=0.046)
    
    fig.suptitle(f'Wavefield Comparison (Source {source_idx})', y=0.995, fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Compute and print metrics
    mse = np.mean((s_pred - s_true)**2)
    mae = np.mean(np.abs(s_pred - s_true))
    max_error = np.max(np.abs(s_pred - s_true))
    
    print("\nPrediction Metrics:")
    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")
    print(f"  Max Error: {max_error:.6f}")


# Plot comparison
plot_wavefield_comparison(data_test, s_pred, s_true, source_idx, time_indices=[20, 40, 60, 80])
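`plot_wavefield_comparison` reports MSE, MAE and maximum error aggregated over all time steps and mesh points. Because `s_pred` and `s_true` are shaped `[time_steps, mesh_points]`, reducing over `axis=1` instead gives one value per time step, which shows whether the error grows as the wavefront propagates. A minimal sketch on synthetic arrays; `error_vs_time` is a hypothetical helper, not part of the course code.

```python
import numpy as np


def error_vs_time(pred, true):
    """MSE, MAE and max abs error per time step for arrays shaped [time_steps, mesh_points]."""
    err = pred - true
    return (np.mean(err ** 2, axis=1),
            np.mean(np.abs(err), axis=1),
            np.max(np.abs(err), axis=1))


# Synthetic example: error that increases linearly with the time index
n_t, n_x = 5, 100
demo_true = np.zeros((n_t, n_x))
demo_pred = demo_true + 0.01 * np.arange(n_t)[:, None]
mse_t, mae_t, max_t = error_vs_time(demo_pred, demo_true)
print(mae_t)
```

Averaging `mse_t` over time recovers the global MSE, since every time step has the same number of mesh points; calling `error_vs_time(s_pred, s_true)` on the notebook's arrays and plotting the result against `data.tsteps` is one way to inspect error growth.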

### Discussion Questions

After running the experiments, discuss:

1. **Activation Functions**:
   - Which activation (ReLU vs Sine) performed better? Why might this be?
   - How did convergence speed differ between the two?
   - What physical properties of acoustic waves might favor one activation over another?

2. **Fourier Features**:
   - Did Fourier features improve performance? By how much?
   - What is the computational cost of Fourier features (hint: look at input dimensions)?
   - When would you recommend using Fourier features?

3. **DeepONet Architecture**:
   - Why is the operator learning approach useful for PDEs?
   - What are the advantages of DeepONet vs traditional PDE solvers?
   - What are the limitations?