Set up the dataloaders for FDTD simulation data loaded from file.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import numpy as np

DATA_FILE = "fdtdData_random_smooth_nonlinear_SET1.npz"
# Lower bound on every normalization scale: a feature that is constant across
# the loaded samples (std == 0) would otherwise turn its channel into NaN/inf.
NORM_EPS = 1e-12

# Load data. allow_pickle=True is required because "samples" is an object
# array; only use it on trusted files, since unpickling can execute code.
# The context manager closes the underlying .npz archive once it is read.
with np.load(DATA_FILE, allow_pickle=True) as data:
    samples = data["samples"]
sample_count = 1000
samples = samples[:sample_count]
if len(samples) == 0:
    raise ValueError(f"No samples found in {DATA_FILE}")
prediction_index = 8  # which entry of each field history is the target snapshot

all_inputs = []
all_outputs = []


def _mean_std(values):
    """Return (mean, std) of ``values`` with the std floored at NORM_EPS."""
    return values.mean(), max(values.std(), NORM_EPS)


# Compute means and stds for normalization.
# NOTE(review): statistics are taken over every loaded sample, so the
# validation split created later also contributes to them.
# Inputs
epsR_all = np.concatenate([s["epsR"] for s in samples])
src0_all = np.concatenate([s["source_prop"][0] for s in samples])
src1_all = np.concatenate([s["source_prop"][1] for s in samples])

epsR_mean, epsR_std = _mean_std(epsR_all)
src0_mean, src0_std = _mean_std(src0_all)
src1_mean, src1_std = _mean_std(src1_all)

# Outputs: only the snapshot at prediction_index is used as the target
ez_all = np.concatenate([s["ez_history"][prediction_index] for s in samples])
hy_all = np.concatenate([s["hy_history"][prediction_index] for s in samples])

ez_mean, ez_std = _mean_std(ez_all)
hy_mean, hy_std = _mean_std(hy_all)

for sample in samples:
    # Normalized input features stacked as channels: (3, N)
    sample_inputs = np.stack([
        (sample["epsR"] - epsR_mean) / epsR_std,
        (sample["source_prop"][0] - src0_mean) / src0_std,
        (sample["source_prop"][1] - src1_mean) / src1_std,
    ], axis=0)
    all_inputs.append(sample_inputs)

    # Normalized target snapshots stacked as channels: (2, ...)
    ez_pred_snaps = np.asarray(sample["ez_history"][prediction_index])
    hy_pred_snaps = np.asarray(sample["hy_history"][prediction_index])
    sample_outputs = np.stack((
        (ez_pred_snaps - ez_mean) / ez_std,
        (hy_pred_snaps - hy_mean) / hy_std,
    ), axis=0)
    all_outputs.append(sample_outputs)

# Convert lists to arrays (np.stack requires equal shapes across samples)
all_inputs = np.stack(all_inputs, axis=0)    # (num_samples, 3, N)
all_outputs = np.stack(all_outputs, axis=0)  # (num_samples, 2, ...) — presumably same length N

# Grid size shared by every sample; later cells use it to build the model.
N = all_inputs.shape[-1]

class FieldDataset(Dataset):
    """In-memory dataset of (input, target) field pairs held as float32 tensors.

    Args:
        inputs: NumPy array of shape (num_samples, C_in, L).
        outputs: NumPy array of shape (num_samples, C_out, L).

    Item ``idx`` is the tuple ``(inputs[idx], outputs[idx])``.
    """

    def __init__(self, inputs, outputs):
        # from_numpy shares memory; the float32 cast copies unless already float32.
        self.inputs, self.outputs = (
            torch.from_numpy(arr).to(torch.float32) for arr in (inputs, outputs)
        )

    def __len__(self):
        # Number of samples = size of the leading dimension.
        return len(self.inputs)

    def __getitem__(self, idx):
        return (self.inputs[idx], self.outputs[idx])
    
dataset = FieldDataset(all_inputs, all_outputs)

# 80/20 train/validation split. A dedicated, seeded generator makes the
# partition reproducible across sessions, so a reloaded checkpoint is
# evaluated on the same held-out samples it never trained on.
SPLIT_SEED = 0
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 20

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)



Fourier layer definition

In [None]:
import numpy as np
import torch
import torch.nn as nn



from functools import reduce
from functools import partial


import torch
import torch.nn as nn

class SpectralConv1d(nn.Module):
    """1D Fourier layer: rFFT -> complex channel mixing on selected modes -> inverse rFFT.

    Only the Fourier modes marked True in ``mask`` get learnable weights; every
    other mode of the output spectrum is zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mask: 1-D boolean tensor over rfft modes (length N//2+1 for a signal of
            length N); True marks a mode to keep.
    """

    def __init__(self, in_channels, out_channels, mask):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Non-persistent buffers follow the module through .to(device) but are
        # not written to state_dict, so previously saved checkpoints still load.
        self.register_buffer("mask", mask, persistent=False)
        # Indices of the kept modes, computed once instead of on every forward.
        self.register_buffer("mask_indices", mask.nonzero(as_tuple=True)[0], persistent=False)
        self.num = torch.sum(mask)     # number of modes to keep
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights for selected modes: (in_channels, out_channels, num)
        self.weights = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, int(self.num), dtype=torch.cfloat)
        )

        # NOTE: defined but never used in forward(). Kept so parameter counts and
        # state_dict keys ("w.weight", "w.bias") match existing checkpoints.
        self.w = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=True)

    def compl_mul1d(self, input, weights):
        # input: (B, c_in, num_modes), weights: (c_in, c_out, num_modes)
        # output: (B, c_out, num_modes) — independent channel mixing per mode
        return torch.einsum("bcn,con->bon", input, weights)

    def forward(self, x):
        """Apply the layer to x of shape (B, in_channels, N); returns (B, out_channels, N)."""
        batchsize, _, n_points = x.shape
        x_ft = torch.fft.rfft(x, dim=-1)
        out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-1], dtype=torch.cfloat, device=x.device)
        idx = self.mask_indices
        out_ft[:, :, idx] = self.compl_mul1d(x_ft[:, :, idx], self.weights)
        # n= restores the original length (rfft/irfft is ambiguous for odd N).
        return torch.fft.irfft(out_ft, n=n_points, dim=-1)

Model definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Learnable1DPosEmbedding(nn.Module):
    """Append a learnable (2*num_freqs, N) positional map to the input channels.

    The map is a free parameter initialized from a standard normal; ``num_freqs``
    only sets the channel count, and no sinusoids are computed.
    """

    def __init__(self, N, num_freqs):
        super().__init__()
        n_channels = 2 * num_freqs
        self.embedding = nn.Parameter(torch.randn(n_channels, N))

    def forward(self, x):
        # x: (B, C_in, N) -> (B, C_in + 2*num_freqs, N)
        batch = x.size(0)
        # Broadcast the shared map over the batch; cat materializes the result.
        pos = self.embedding[None].expand(batch, -1, -1)
        return torch.cat((x, pos), dim=1)

class FNO1d(nn.Module):
    """1D Fourier Neural Operator-style network.

    Data flow (B = batch, N = grid points):
        (B, in_channels, N)
        -> optionally append learnable positional channels
        -> pointwise lifting MLP to ``width`` channels
        -> optional right zero-padding by ``padding`` points
        -> ``layers`` x [SpectralConv1d + 1x1 Conv1d, BatchNorm1d, ReLU,
                         optional pointwise channel MLP]
        -> crop the padding
        -> pointwise projection MLP
        -> (B, out_channels, N)

    Args:
        in_channels: input channels (e.g. epsR, source_prop).
        out_channels: output channels (e.g. Ez, Hy).
        width: hidden channel count of the latent space.
        layers: number of Fourier blocks.
        mask: boolean tensor of allowed Fourier modes (length N//2+1), shared
            by every SpectralConv1d.
        N: grid length; the positional embedding is sized to it, so inputs
            must have exactly N points when it is enabled.
        padding: zero-padding appended on the right before the Fourier blocks.
        num_pos_freqs: the positional embedding adds 2*num_pos_freqs channels.
        use_learnable_pe: whether to append the learnable positional embedding.
        use_channel_mlp: whether each Fourier block ends with a channel MLP.
    """

    def __init__(self, in_channels, out_channels, width, layers, mask, N, padding=0,
                 num_pos_freqs=4, use_learnable_pe=True, use_channel_mlp=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.layers = layers
        self.padding = padding
        self.mask = mask
        self.use_channel_mlp = use_channel_mlp

        # Submodules are built in a fixed order (embedding, lifting, spectral,
        # residual, norm, channel MLPs, projection). Keep it: it determines how
        # the global RNG is consumed during parameter initialization.
        pe_channels = 2 * num_pos_freqs if use_learnable_pe else 0
        if use_learnable_pe:
            self.pos_embedding = Learnable1DPosEmbedding(N, num_pos_freqs)
        else:
            self.pos_embedding = None

        self.lifting = self._pointwise_mlp(in_channels + pe_channels, width, width)

        self.conv_layers = nn.ModuleList(
            SpectralConv1d(width, width, mask) for _ in range(layers)
        )
        self.res_layers = nn.ModuleList(
            nn.Conv1d(width, width, kernel_size=1) for _ in range(layers)
        )
        self.bn_layers = nn.ModuleList(nn.BatchNorm1d(width) for _ in range(layers))

        if use_channel_mlp:
            self.channel_mlps = nn.ModuleList(
                self._pointwise_mlp(width, 2 * width, width) for _ in range(layers)
            )
        else:
            self.channel_mlps = [None] * layers

        self.projection = self._pointwise_mlp(width, 2 * width, out_channels)

    @staticmethod
    def _pointwise_mlp(n_in, n_hidden, n_out):
        """Build a Linear -> ReLU -> Linear MLP acting on the last (channel) axis."""
        return nn.Sequential(
            nn.Linear(n_in, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_out),
        )

    def forward(self, x):
        """Map x of shape (B, C_in, N) to (B, out_channels, N)."""
        if self.pos_embedding is not None:
            x = self.pos_embedding(x)  # (B, C_in + PE, N)

        # nn.Linear acts on the last axis, so channels are moved there and back.
        h = self.lifting(x.transpose(1, 2)).transpose(1, 2)  # (B, width, N)

        if self.padding > 0:
            h = F.pad(h, (0, self.padding))

        blocks = zip(self.conv_layers, self.res_layers, self.bn_layers, self.channel_mlps)
        for spectral, residual, norm, mlp in blocks:
            h = F.relu(norm(spectral(h) + residual(h)))
            if self.use_channel_mlp and mlp is not None:
                h = mlp(h.transpose(1, 2)).transpose(1, 2)

        if self.padding > 0:
            h = h[..., :-self.padding]

        return self.projection(h.transpose(1, 2)).transpose(1, 2)  # (B, out_channels, N)

New data-driven Fourier mode selection scheme

In [None]:
import numpy as np
import torch

def compute_spectral_mask(all_outputs, theta=0.9):
    """Compute a data-driven low-pass mask over rfft modes for a 1D FNO (similar to the Wu paper).

    The magnitude spectrum of every sample and channel is averaged. Modes are
    then kept, from low to high frequency, while the normalized cumulative
    magnitude stays at or below ``theta``.

    Args:
        all_outputs: np.array of shape (num_samples, out_channels, N); the FFT
            is taken along the last axis.
        theta: cumulative-magnitude threshold in [0, 1].

    Returns:
        torch.bool tensor of shape (N//2+1,). At least the lowest (DC) mode is
        always kept, so spectral layers never end up with zero modes.
    """
    # FFT along spatial dimension
    spectra = np.fft.rfft(all_outputs, axis=-1)
    magnitudes = np.abs(spectra)

    # Average over all samples and channels -> one spectrum of length N//2+1
    mean_spectrum = magnitudes.mean(axis=(0, 1))

    cumulative = np.cumsum(mean_spectrum)
    total = cumulative[-1]
    if total > 0:
        # Dividing by the last cumsum entry makes the final ratio exactly 1.0,
        # so theta >= 1 keeps every mode regardless of rounding.
        mask = (cumulative / total) <= theta
    else:
        # All-zero data: the ratio is undefined (0/0), so select nothing here.
        mask = np.zeros(mean_spectrum.shape, dtype=bool)

    if not mask.any():
        # A DC share above theta (or all-zero data) would leave no modes and
        # make every spectral layer output zeros; keep the lowest mode instead.
        mask[0] = True
    return torch.tensor(mask, dtype=torch.bool)

Initialize model

In [None]:
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Data-driven low-pass mask over the rfft modes of the targets.
mask = compute_spectral_mask(all_outputs, theta=0.4)
num_modes_kept = mask.sum().item()  # convert tensor to integer
print("Number of Fourier modes kept:", num_modes_kept)
if num_modes_kept == 0:
    print("Warning: no Fourier modes kept; every SpectralConv1d will output zeros.")

in_channels = all_inputs.shape[1]    # input feature channels
out_channels = all_outputs.shape[1]  # predicted field channels
grid_size = all_inputs.shape[-1]     # spatial points per sample
width = 15
layers = 5
# Optional right zero-padding for the Fourier blocks: 5% of the grid. Derived
# from the data itself instead of `N`, which was only a loop variable left over
# from the data-loading cell.
padding = int(0.05 * grid_size)

model = FNO1d(in_channels, out_channels, width, layers, mask, N=grid_size, padding=padding).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Relative L2 loss: per-sample ||pred - true|| / ||true||, averaged over the batch.
class RelativeL2Loss1D(nn.Module):
    """Batch mean of the relative L2 error.

    For each sample the L2 norm is taken jointly over channels and grid points:
    ``||y_pred - y_true|| / (||y_true|| + eps)``. ``eps`` guards against
    all-zero targets.
    """

    def __init__(self, eps=1e-8):
        super(RelativeL2Loss1D, self).__init__()
        self.eps = eps

    def forward(self, y_pred, y_true):
        """y_pred, y_true: (B, C, N) -> scalar tensor."""
        err = torch.linalg.vector_norm(y_pred - y_true, ord=2, dim=(1, 2))
        ref = torch.linalg.vector_norm(y_true, ord=2, dim=(1, 2)) + self.eps
        return (err / ref).mean()

# Objective, optimizer and LR schedule: relative L2 loss, Adam at 1e-3, and the
# learning rate halved whenever validation loss fails to improve for 10 epochs.
criterion = RelativeL2Loss1D()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10
)

epochs = 200
train_losses, val_losses = [], []

for epoch in range(epochs):
    # --- optimization pass over the training split ---
    model.train()
    epoch_sum = 0.0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        # Cap the global gradient norm at 1.0 to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # The loss is a batch mean; re-weight by batch size for a per-sample average.
        epoch_sum += loss.item() * x_batch.size(0)
    train_loss = epoch_sum / len(train_dataset)

    # --- evaluation pass over the validation split (no gradients) ---
    model.eval()
    epoch_sum = 0.0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_pred = model(x_batch)
            loss = criterion(y_pred, y_batch)
            epoch_sum += loss.item() * x_batch.size(0)
    val_loss = epoch_sum / len(val_dataset)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    scheduler.step(val_loss)

    current_lr = optimizer.param_groups[0]['lr']
    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train RelL2: {train_loss:.6f} | "
        f"Val RelL2: {val_loss:.6f} | "
        f"LR: {current_lr:.2e}"
    )

# Learning curves (train first, then validation, matching the legend order)
plt.figure()
for curve in (train_losses, val_losses):
    plt.plot(curve)
plt.xlabel("Epoch")
plt.ylabel("Relative L2 Loss")
plt.title("Training and Validation Loss")
plt.legend(["Train", "Validation"])
plt.show()

Save checkpoint

In [None]:
from pathlib import Path

# Portable relative location. The old string literal used invalid backslash
# escapes and only denoted a subdirectory on Windows.
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # so saving works when the folder is missing
# NOTE(review): the file tag says "linear_smooth" while DATA_FILE is the
# "random_smooth_nonlinear" set — confirm the intended name.
checkpoint_path = checkpoint_dir / f"model_{sample_count}samples_1step_linear_smooth_extralayersrun.pt"
torch.save(model.state_dict(), checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")

Check a prediction

In [None]:
import matplotlib.pyplot as plt

# Inference mode: BatchNorm switches to its running statistics.
model.eval()

# First batch of the (unshuffled) validation loader, on the model's device.
x_batch, y_batch = (t.to(device) for t in next(iter(val_loader)))

with torch.no_grad():
    y_pred = model(x_batch)

# Host-side NumPy copies for matplotlib; x_batchp keeps the inputs for overlays.
y_pred, y_batch, x_batchp = (t.cpu().numpy() for t in (y_pred, y_batch, x_batch))

sample_idx = 0  # which sample of the batch to draw
show_inputs = False  # True overlays eps and source (input channels 0 and 1)

# Channel 0 of the targets is assumed to be Ez.
curves = [
    (y_batch[sample_idx, 0], dict(label='True Ez', color='blue')),
    (y_pred[sample_idx, 0], dict(label='Predicted Ez', color='red', linestyle='--')),
]
if show_inputs:
    curves.append((x_batchp[sample_idx, 0], dict(label='eps', color='green')))
    curves.append((x_batchp[sample_idx, 1], dict(label='source', color='purple')))

plt.figure(figsize=(10,5))
for values, style in curves:
    plt.plot(values, **style)
plt.xlabel('Grid Point Index')
plt.ylabel('Ez Field')
plt.title('Model Prediction vs True Ez')
plt.legend()
plt.show()

Plot more samples

In [None]:
# One figure per sample (at most 10) comparing true and predicted Ez.
show_inputs = False  # True overlays eps and source from the input channels
for i, (ez_true, ez_hat) in enumerate(zip(y_batch[:10, 0], y_pred[:10, 0])):
    plt.figure(figsize=(10,4))
    plt.plot(ez_true, label='True Ez')
    plt.plot(ez_hat, label='Predicted Ez', linestyle='--')
    if show_inputs:
        plt.plot(x_batchp[i, 0], label='eps', color='grey')
        plt.plot(x_batchp[i, 1], label='source', color='purple')
    plt.title(f'Sample {i}')
    plt.legend()
    plt.show()