# Task 3: All2All Training

In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torch.nn.functional as F
import matplotlib.pyplot as plt

## Model Definition

In [2]:
from FNO_bn import FNO1d_bn

In [3]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Verify CUDA setup
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    print(f"CUDA memory reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")

CUDA available: True
CUDA device count: 1
Current CUDA device: 0
CUDA device name: NVIDIA GeForce RTX 3070 Laptop GPU
CUDA memory allocated: 0.00 MB
CUDA memory reserved: 0.00 MB


## Import Data

In [5]:
torch.manual_seed(0)
np.random.seed(0)

In [6]:
n_train = 1024 # number of training samples

# train dataset shape: (1024, 5, 128)
# 1024: number of trajectories
# 5: time snapthots of the solution: t= 0, 0.25, 0.5, 0.75, 1.0
# 128: spatial resolution of the data

train_dataset = torch.from_numpy(np.load("data/data_train_128.npy")).type(torch.float32)
test_dataset = torch.from_numpy(np.load("data/data_test_128.npy")).type(torch.float32)
# add time as input feature
time_train = torch.linspace(0, 1, train_dataset.shape[1]).reshape(1, -1, 1)
time_test = torch.linspace(0, 1, test_dataset.shape[1]).reshape(1, -1, 1)
train_dataset = torch.cat([train_dataset, time_train.repeat(train_dataset.shape[0], 1, 1)], dim=-1)
test_dataset = torch.cat([test_dataset, time_test.repeat(test_dataset.shape[0], 1, 1)], dim=-1)

# add grid coordinates as input feature
grid_train = torch.linspace(0, 1, train_dataset.shape[2]).reshape(1, 1, -1)
grid_test = torch.linspace(0, 1, test_dataset.shape[2]).reshape(1, 1, -1)
train_dataset = torch.cat([train_dataset, grid_train.repeat(train_dataset.shape[0], train_dataset.shape[1], 1)], dim=-1)
test_dataset = torch.cat([test_dataset, grid_test.repeat(test_dataset.shape[0], test_dataset.shape[1], 1)], dim=-1)

# Move data to device
train_dataset = train_dataset.to(device)
test_dataset = test_dataset.to(device)

batch_size = 20
train_loader = DataLoader(TensorDataset(train_dataset), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(test_dataset), batch_size=batch_size, shuffle=False)

In [7]:
class PDEDataset(Dataset):
    def __init__(self,
                 which="training",
                 training_samples = 256,
                 resolution = 128,
                 device='cpu'):

        self.resolution = resolution
        self.device = device
        self.data = np.load(f"data/data_train_{resolution}.npy")

        self.T = 5
        # Precompute all possible (t_initial, t_final) pairs within the specified range.
        self.time_pairs = [(i, j) for i in range(0, self.T) for j in range(i + 1, self.T)]
        self.len_times  = len(self.time_pairs)

        # Total samples available in the dataset
        total_samples = self.data.shape[0]
        self.n_val = 32
        self.n_test = 32

        if which == "training":
            self.length = training_samples * self.len_times
            self.start_sample = 0
        elif which == "validation":
            self.length = self.n_val * self.len_times
            self.start_sample = total_samples - self.n_val - self.n_test
        elif which == "test":
            self.length = self.n_test * self.len_times
            self.start_sample = total_samples - self.n_test

        self.mean = 0.018484
        self.std  = 0.685405
        
        # Pre-create grid to avoid recreating it each time
        self.grid = torch.linspace(0, 1, 128, dtype=torch.float32).reshape(128, 1).to(device)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        sample_idx = self.start_sample + index // self.len_times
        time_pair_idx = index % self.len_times
        t_inp, t_out = self.time_pairs[time_pair_idx]
        time = torch.tensor((t_out - t_inp)/4. + float(np.random.rand(1)[0]/10**6), dtype=torch.float32, device=self.device)

        inputs = torch.from_numpy(self.data[sample_idx, t_inp]).type(torch.float32).reshape(128, 1).to(self.device)
        inputs = (inputs - self.mean)/self.std #Normalize
        
        # Add grid coordinates (already on correct device and dtype)
        inputs = torch.cat((inputs, self.grid), dim=-1)  # (128, 2)

        outputs = torch.from_numpy(self.data[sample_idx, t_out]).type(torch.float32).reshape(128).to(self.device)
        outputs = (outputs - self.mean)/self.std #Normalize

        return time, inputs, outputs

In [8]:
# Check actual data statistics
train_data = np.load("data/data_train_128.npy")
print(f"Training data shape: {train_data.shape}")
print(f"Training data mean: {train_data.mean():.6f}")
print(f"Training data std: {train_data.std():.6f}")
print(f"Training data min: {train_data.min():.6f}")
print(f"Training data max: {train_data.max():.6f}")
print(f"\nCurrently using:")
print(f"  mean = 0")
print(f"  std = 0.3835")

Training data shape: (1024, 5, 128)
Training data mean: 0.018484
Training data std: 0.685405
Training data min: -3.095698
Training data max: 3.086819

Currently using:
  mean = 0
  std = 0.3835


### Instantiate Model

In [17]:
n_train = 1024 # Number of TRAJECTORIES for training
batch_size = 128

training_set = DataLoader(PDEDataset("training", n_train, device=device), batch_size=batch_size, shuffle=True)
testing_set = DataLoader(PDEDataset("validation", device=device), batch_size=batch_size, shuffle=False)

learning_rate = 0.001
epochs = 50
step_size = 10
gamma = 0.5

modes = 12
width = 64 # 64
fno = FNO1d_bn(modes, width).to(device)  # model

optimizer = torch.optim.Adam(fno.parameters(), lr=learning_rate, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min',           # Minimize validation loss
    factor=0.5,          # Multiply LR by 0.5
    patience=5,          # Wait 5 epochs before reducing
    min_lr=1e-6         # Don't go below this
)

In [18]:
# Define the error function
def relative_l2_error(pred, true):
    diff_norm = torch.norm(pred - true, p=2, dim=1)
    true_norm = torch.norm(true, p=2, dim=1)
    return torch.mean(diff_norm / true_norm) * 100

l = nn.MSELoss()  
freq_print = 1

for epoch in range(epochs):
    fno.train()
    train_mse = 0.0
    for step, (time_batch, input_batch, output_batch) in enumerate(training_set):
        optimizer.zero_grad()
        output_pred_batch = fno(input_batch, time_batch).squeeze(-1)
        loss_f = l(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_mse += loss_f.item()
    train_mse /= len(training_set)


    # Validation with CORRECT L2 error
    with torch.no_grad():
        fno.eval()
        test_relative_l2 = 0.0
        for step, (time_batch, input_batch, output_batch) in enumerate(testing_set):
            output_pred_batch = fno(input_batch, time_batch).squeeze(-1)
            # Use the correct relative L2 error
            loss_f = relative_l2_error(output_pred_batch, output_batch)
            test_relative_l2 += loss_f.item()
        test_relative_l2 /= len(testing_set)
    scheduler.step(test_relative_l2)


    if epoch % freq_print == 0: print("######### Epoch:", epoch, " ######### Train Loss:", train_mse, " ######### Relative L2 Test Norm:", test_relative_l2)

######### Epoch: 0  ######### Train Loss: 0.069055861083325  ######### Relative L2 Test Norm: 30.784061431884766
######### Epoch: 1  ######### Train Loss: 0.013917891695746221  ######### Relative L2 Test Norm: 25.36121940612793
######### Epoch: 2  ######### Train Loss: 0.01021944034146145  ######### Relative L2 Test Norm: 23.08533541361491
######### Epoch: 3  ######### Train Loss: 0.009151038801064715  ######### Relative L2 Test Norm: 21.494194984436035
######### Epoch: 4  ######### Train Loss: 0.007858064035826829  ######### Relative L2 Test Norm: 18.88886610666911
######### Epoch: 5  ######### Train Loss: 0.006972958079131786  ######### Relative L2 Test Norm: 18.218523025512695
######### Epoch: 6  ######### Train Loss: 0.00868938402272761  ######### Relative L2 Test Norm: 20.032304445902508
######### Epoch: 7  ######### Train Loss: 0.006984457839280367  ######### Relative L2 Test Norm: 17.835336685180664
######### Epoch: 8  ######### Train Loss: 0.0074571165634552015  ######### Relat

In [12]:
# Check GPU memory before training
if torch.cuda.is_available():
    print(f"GPU memory before training: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    print(f"Model on device: {next(fno.parameters()).device}")
    print(f"Training data on device: {next(iter(training_set))[1].device}")

GPU memory before training: 47.62 MB
Model on device: cuda:0
Training data on device: cuda:0


In [None]:
# save trained model
# torch.save(fno.state_dict(), "fno1d_bn_model.pth")


### Test model

In [13]:
# Load test data at resolution 128
test_data_raw = np.load("data/data_test_128.npy")
spatial_resolution = 128

# Extract initial conditions (t=0) and final time solutions (t=1)
initial_conditions = test_data_raw[:, 0, :]  # Shape: (n_samples, 128)
final_time_ground_truth = test_data_raw[:, -1, :]  # Shape: (n_samples, 128)

# Prepare initial conditions with grid coordinates
grid = torch.linspace(0, 1, spatial_resolution, dtype=torch.float32).reshape(spatial_resolution, 1)
initial_conditions_tensor = torch.from_numpy(initial_conditions).type(torch.float32).reshape(-1, spatial_resolution, 1)

# Add grid coordinates to each sample
initial_conditions_with_grid = torch.cat([initial_conditions_tensor, grid.repeat(initial_conditions_tensor.shape[0], 1, 1)], dim=-1)

# Normalize initial conditions
mean = 0.018484
std = 0.685405
initial_conditions_with_grid[:, :, 0] = (initial_conditions_with_grid[:, :, 0] - mean) / std

# Prepare time tensor (full time span from 0 to 1)
time_full = torch.ones(initial_conditions_with_grid.shape[0], dtype=torch.float32) * 1.0

# Move to device
initial_conditions_with_grid = initial_conditions_with_grid.to(device)
time_full = time_full.to(device)

print(f"Test samples: {initial_conditions_with_grid.shape[0]}")
print(f"Input shape: {initial_conditions_with_grid.shape}")
print(f"Time values: {time_full[0].item()}")

Test samples: 128
Input shape: torch.Size([128, 128, 2])
Time values: 1.0


In [14]:
# Test the model: predict from t=0 to t=1
batch_size_test = 20
n_test_samples = initial_conditions_with_grid.shape[0]

all_predictions = []
fno.eval()
with torch.no_grad():
    for i in range(0, n_test_samples, batch_size_test):
        batch_end = min(i + batch_size_test, n_test_samples)
        batch_input = initial_conditions_with_grid[i:batch_end]
        batch_time = time_full[i:batch_end]
        
        # Predict final time
        predictions = fno(batch_input, batch_time).squeeze(-1)
        all_predictions.append(predictions)

# Concatenate all predictions
all_predictions = torch.cat(all_predictions, dim=0)

# Denormalize predictions
all_predictions_denorm = all_predictions * std + mean

# Convert ground truth to tensor
final_time_ground_truth_tensor = torch.from_numpy(final_time_ground_truth).type(torch.float32).to(device)

# Calculate relative L2 error
relative_l2_error = torch.mean(torch.norm(all_predictions_denorm - final_time_ground_truth_tensor, dim=1) / 
                                torch.norm(final_time_ground_truth_tensor, dim=1)) * 100

print(f"Relative L2 Error: {relative_l2_error:.4f}%")


Relative L2 Error: 38.5847%
