# Task 4: Fine-tuning

In [5]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset
import matplotlib.pyplot as plt
import copy

In [6]:
from FNO_bn import FNO1d_bn

In [7]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

Using device: cuda


In [8]:
# Load the checkpoint of the FNO trained on the original distribution.
# FNO1d_bn is already imported in an earlier cell (`from FNO_bn import FNO1d_bn`).
model_path = "fno1d_bn_model.pth"
fno = FNO1d_bn(modes=16, width=64)
# map_location=device: a checkpoint saved from a GPU still loads on a CPU-only machine.
# weights_only=True: only tensors and plain containers are unpickled (safer than a full
# pickle load); recent PyTorch versions warn when this argument is left unset.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
fno.load_state_dict(state_dict)
fno.to(device)

  fno.load_state_dict(torch.load(model_path))


FNO1d_bn(
  (linear_p): Linear(in_features=3, out_features=64, bias=True)
  (spect1): SpectralConv1d()
  (spect2): SpectralConv1d()
  (spect3): SpectralConv1d()
  (lin0): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (lin1): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (lin2): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (batch_norm1): FILM(
    (inp2scale): Linear(in_features=1, out_features=64, bias=True)
    (inp2bias): Linear(in_features=1, out_features=64, bias=True)
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (batch_norm2): FILM(
    (inp2scale): Linear(in_features=1, out_features=64, bias=True)
    (inp2bias): Linear(in_features=1, out_features=64, bias=True)
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (batch_norm3): FILM(
    (inp2scale): Linear(in_features=1, out_features=64, bias=True)
    (inp2bias): Linear(in_features=1, out_features=64, bias=True)
    (norm):

In [9]:
# Seed both RNGs used below: torch (e.g. DataLoader shuffling) and numpy's global RNG
# (the time jitter drawn in PDEDataset.__getitem__).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## Zero-shot test on the unknown distribution

### Import Data

In [10]:
# Zero-shot evaluation data: for every test trajectory, its first snapshot (model input)
# and its last snapshot (ground truth), sampled at 128 grid points.
test_data_raw = np.load("data/data_test_unknown_128.npy")
spatial_resolution = 128

initial_conditions = test_data_raw[:, 0, :]        # (n_samples, 128): first snapshot
final_time_ground_truth = test_data_raw[:, -1, :]  # (n_samples, 128): last snapshot

# Normalization constants for the solution channel (the same values are hard-coded in PDEDataset).
mean = 0.018484
std = 0.685405

grid = torch.linspace(0, 1, spatial_resolution, dtype=torch.float32).reshape(spatial_resolution, 1)
initial_conditions_tensor = (
    torch.from_numpy(initial_conditions)
    .type(torch.float32)
    .reshape(-1, spatial_resolution, 1)
)

# Channel 0: normalized u(t=0); channel 1: grid coordinate x in [0, 1].
initial_conditions_with_grid = torch.cat(
    [
        (initial_conditions_tensor - mean) / std,
        grid.repeat(initial_conditions_tensor.shape[0], 1, 1),
    ],
    dim=-1,
)

# Every sample is advanced over the full unit time interval in a single step.
time_full = torch.full((initial_conditions_with_grid.shape[0],), 1.0, dtype=torch.float32)

initial_conditions_with_grid = initial_conditions_with_grid.to(device)
time_full = time_full.to(device)

print(f"Test samples: {initial_conditions_with_grid.shape[0]}")
print(f"Input shape: {initial_conditions_with_grid.shape}")
print(f"Time values: {time_full[0].item()}")

Test samples: 128
Input shape: torch.Size([128, 128, 2])
Time values: 1.0


### Run test

In [11]:
# Zero-shot: map every test initial condition from t=0 to t=1 in one step with the
# pre-trained model, processing the test set in chunks of `batch_size_test` samples.
batch_size_test = 20
n_test_samples = initial_conditions_with_grid.shape[0]

fno.eval()
with torch.no_grad():
    all_predictions = torch.cat(
        [
            fno(input_chunk, time_chunk).squeeze(-1)
            for input_chunk, time_chunk in zip(
                initial_conditions_with_grid.split(batch_size_test),
                time_full.split(batch_size_test),
            )
        ],
        dim=0,
    )

# Undo the input normalization so predictions are compared with the raw ground truth.
all_predictions_denorm = all_predictions * std + mean
final_time_ground_truth_tensor = torch.from_numpy(final_time_ground_truth).type(torch.float32).to(device)

# Mean over samples of ||pred - truth||_2 / ||truth||_2, in percent.
relative_l2_error_zero_shot = 100 * torch.mean(
    torch.norm(all_predictions_denorm - final_time_ground_truth_tensor, dim=1)
    / torch.norm(final_time_ground_truth_tensor, dim=1)
)

print(f"Relative L2 Error: {relative_l2_error_zero_shot:.4f}%")

Relative L2 Error: 12.5500%


## Fine-tuning

### Data Preparation

In [12]:
# Summary statistics of the fine-tuning trajectories, to compare against the
# normalization constants applied to the model inputs.
train_data = np.load("data/data_finetune_train_unknown_128.npy")
print(f"Training data shape: {train_data.shape}")
for stat_name, stat_fn in (("mean", np.mean), ("std", np.std), ("min", np.min), ("max", np.max)):
    print(f"Training data {stat_name}: {stat_fn(train_data):.6f}")

Training data shape: (32, 5, 128)
Training data mean: -0.034377
Training data std: 0.358928
Training data min: -1.549701
Training data max: 1.362039


In [28]:
class PDEDataset(Dataset):
    """All (t_in, t_out) snapshot pairs of the fine-tuning trajectories as supervised samples.

    Data is read from ``data/data_finetune_train_unknown_{resolution}.npy`` and indexed as
    ``data[trajectory, snapshot, x]``. For each trajectory, every ordered pair of snapshot
    indices ``t_in < t_out`` becomes one sample, so a trajectory with ``T`` snapshots
    contributes ``T * (T - 1) / 2`` samples.

    Args:
        which: ``"training"``, ``"validation"`` or ``"test"``.
        training_samples: number of trajectories used for training, taken from the start
            of the file. Also used to detect when a validation/test split overlaps them.
        resolution: spatial resolution; selects the file and sizes the coordinate grid.
        device: device every returned tensor is placed on.
        n_val: number of validation trajectories (taken from the end of the file, before
            the test trajectories).
        n_test: number of test trajectories (the last ``n_test`` of the file).

    Raises:
        ValueError: unknown ``which``, data of unexpected shape, fewer than 2 snapshots,
            or a split that needs more trajectories than the file holds.

    ``__getitem__`` returns ``(time, inputs, outputs)``:
        time:    0-dim float32 tensor, ``(t_out - t_in) / (T - 1)`` plus a jitter below 1e-6.
        inputs:  ``(resolution, 2)`` tensor, normalized ``u(t_in)`` and the grid ``x`` in [0, 1].
        outputs: ``(resolution,)`` tensor, normalized ``u(t_out)``.
    """

    def __init__(self,
                 which="training",
                 training_samples = 32,
                 resolution = 128,
                 device='cpu',
                 n_val=32,
                 n_test=32):
        import warnings

        splits = ("training", "validation", "test")
        if which not in splits:
            raise ValueError(f"which must be one of {splits}, got {which!r}")

        self.resolution = resolution
        self.device = device
        self.data = np.load(f"data/data_finetune_train_unknown_{resolution}.npy")
        if self.data.ndim != 3 or self.data.shape[-1] != resolution:
            raise ValueError(
                f"expected data of shape (n_samples, n_times, {resolution}), got {self.data.shape}")

        # Snapshots per trajectory (5 in the 128-point fine-tuning file) and every ordered
        # pair (t_in, t_out) with t_in < t_out.
        self.T = self.data.shape[1]
        if self.T < 2:
            raise ValueError(f"need at least 2 snapshots per trajectory, got {self.T}")
        self.time_pairs = [(i, j) for i in range(self.T) for j in range(i + 1, self.T)]
        self.len_times = len(self.time_pairs)

        total_samples = self.data.shape[0]
        self.n_val = n_val
        self.n_test = n_test

        # Training uses the first `training_samples` trajectories; validation and test come
        # from the end of the file. The start index is clamped at 0 so a short file never
        # relies on negative (wrap-around) indexing.
        if which == "training":
            n_trajectories, self.start_sample = training_samples, 0
        elif which == "validation":
            n_trajectories, self.start_sample = n_val, max(total_samples - n_val - n_test, 0)
        else:  # "test"
            n_trajectories, self.start_sample = n_test, max(total_samples - n_test, 0)

        last = self.start_sample + n_trajectories - 1
        if last >= total_samples:
            raise ValueError(
                f"{which} split needs trajectories {self.start_sample}..{last}, "
                f"but the file only has {total_samples}")
        if which != "training" and self.start_sample < training_samples:
            # Too few trajectories for disjoint splits: the metric is not a held-out score.
            warnings.warn(
                f"{which} split (trajectories {self.start_sample}..{last}) overlaps the first "
                f"{training_samples} training trajectories; it is not a held-out set.",
                stacklevel=2)

        self.length = n_trajectories * self.len_times

        # Normalization constants shared with the zero-shot test inputs of this notebook.
        # The fine-tuning file's own statistics (mean -0.034377, std 0.358928) are not used.
        self.mean = 0.018484
        self.std  = 0.685405

        # Spatial coordinates in [0, 1], built once and appended to every input.
        self.grid = torch.linspace(0, 1, resolution, dtype=torch.float32).reshape(resolution, 1).to(device)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        sample_idx = self.start_sample + index // self.len_times
        t_inp, t_out = self.time_pairs[index % self.len_times]

        # Elapsed time rescaled so the whole trajectory spans 1.0, plus a random jitter below
        # 1e-6 drawn from numpy's global RNG (purpose not documented here — TODO confirm).
        time = torch.tensor((t_out - t_inp) / (self.T - 1) + float(np.random.rand(1)[0] / 10**6),
                            dtype=torch.float32, device=self.device)

        inputs = torch.from_numpy(self.data[sample_idx, t_inp]).type(torch.float32)
        inputs = inputs.reshape(self.resolution, 1).to(self.device)
        inputs = (inputs - self.mean) / self.std
        inputs = torch.cat((inputs, self.grid), dim=-1)  # (resolution, 2): [u, x]

        outputs = torch.from_numpy(self.data[sample_idx, t_out]).type(torch.float32)
        outputs = outputs.reshape(self.resolution).to(self.device)
        outputs = (outputs - self.mean) / self.std

        return time, inputs, outputs

### Instantiate Model

In [29]:
# Fine-tuning hyperparameters.
n_train = 32 # Number of TRAJECTORIES for training
batch_size = 16
learning_rate = 0.001
epochs = 30
step_size = 10  # only used by the StepLR alternative mentioned below
gamma = 0.5     # only used by the StepLR alternative mentioned below

training_set = DataLoader(PDEDataset("training", n_train, device=device), batch_size=batch_size, shuffle=True)
testing_set = DataLoader(PDEDataset("validation", device=device), batch_size=batch_size, shuffle=False)

# Fine-tune an independent copy so the pre-trained `fno` itself is left unmodified.
fno_finetune = copy.deepcopy(fno).to(device)

optimizer = torch.optim.Adam(fno_finetune.parameters(), lr=learning_rate, weight_decay=1e-5)

# Halve the learning rate once the validation metric (passed to scheduler.step) has
# stopped decreasing for more than `patience` epochs, never going below 1e-6.
# Alternative schedule: torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

### Training Loop

In [30]:
# Per-sample relative L2 error, averaged over the batch.
def relative_l2_error(pred, true):
    """Mean relative L2 error between ``pred`` and ``true``, in percent.

    The L2 norm is taken along dim 1 (one value per sample). Each sample's error norm
    is divided by the norm of its target, and the ratios are averaged.

    Args:
        pred: predictions, broadcast-compatible with ``true``.
        true: reference values; a sample whose norm along dim 1 is zero yields inf/nan.

    Returns:
        0-dim tensor holding the mean relative error in percent.
    """
    ratio = (pred - true).norm(p=2, dim=1) / true.norm(p=2, dim=1)
    return ratio.mean() * 100

l = nn.MSELoss()  # training loss on normalized fields
freq_print = 1    # print a progress line every `freq_print` epochs

for epoch in range(epochs):
    # Training pass: one optimizer step per mini-batch, tracking the mean batch MSE.
    fno_finetune.train()
    train_mse = 0.0
    for time_batch, input_batch, output_batch in training_set:
        optimizer.zero_grad()
        output_pred_batch = fno_finetune(input_batch, time_batch).squeeze(-1)
        loss_f = l(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_mse += loss_f.item()
    train_mse = train_mse / len(training_set)

    # Validation pass: mean per-batch relative L2 error (%), which also drives the
    # ReduceLROnPlateau schedule.
    # NOTE(review): the fine-tuning file holds 32 trajectories and the "validation" split
    # selects 32 of them, so this score is measured on the training trajectories.
    fno_finetune.eval()
    test_relative_l2 = 0.0
    with torch.no_grad():
        for time_batch, input_batch, output_batch in testing_set:
            output_pred_batch = fno_finetune(input_batch, time_batch).squeeze(-1)
            loss_f = relative_l2_error(output_pred_batch, output_batch)
            test_relative_l2 += loss_f.item()
    test_relative_l2 = test_relative_l2 / len(testing_set)
    scheduler.step(test_relative_l2)

    if epoch % freq_print == 0:
        print("######### Epoch:", epoch, " ######### Train Loss:", train_mse, " ######### Relative L2 Test Norm:", test_relative_l2)

######### Epoch: 0  ######### Train Loss: 0.03990051739383489  ######### Relative L2 Test Norm: 30.442450523376465
######### Epoch: 1  ######### Train Loss: 0.014307809830643236  ######### Relative L2 Test Norm: 14.91901535987854
######### Epoch: 2  ######### Train Loss: 0.007367103558499366  ######### Relative L2 Test Norm: 14.380645751953125
######### Epoch: 3  ######### Train Loss: 0.005318381142569706  ######### Relative L2 Test Norm: 14.4850914478302
######### Epoch: 4  ######### Train Loss: 0.004313120321603492  ######### Relative L2 Test Norm: 10.085830330848694
######### Epoch: 5  ######### Train Loss: 0.002903283847263083  ######### Relative L2 Test Norm: 7.209062027931213
######### Epoch: 6  ######### Train Loss: 0.003120471228612587  ######### Relative L2 Test Norm: 10.382948684692384
######### Epoch: 7  ######### Train Loss: 0.002937784866662696  ######### Relative L2 Test Norm: 11.992182946205139
######### Epoch: 8  ######### Train Loss: 0.003220819536363706  ######### Rel

### Test the fine-tuned model

In [31]:
# Same zero-shot task (t=0 -> t=1 on every test trajectory), now with the fine-tuned
# model, processed in chunks of `batch_size_test` samples.
fno_finetune.eval()
with torch.no_grad():
    all_predictions = torch.cat(
        [
            fno_finetune(input_chunk, time_chunk).squeeze(-1)
            for input_chunk, time_chunk in zip(
                initial_conditions_with_grid.split(batch_size_test),
                time_full.split(batch_size_test),
            )
        ],
        dim=0,
    )

# Back to physical units before comparing with the raw ground truth.
all_predictions_denorm = all_predictions * std + mean
final_time_ground_truth_tensor = torch.from_numpy(final_time_ground_truth).type(torch.float32).to(device)

# Mean per-sample relative L2 error in percent (same formula as the zero-shot cell).
relative_l2_error_finetuned = relative_l2_error(all_predictions_denorm, final_time_ground_truth_tensor)

print(f"Relative L2 Error: {relative_l2_error_finetuned:.4f}%")

Relative L2 Error: 10.4411%


## Bonus: Train a new model from scratch on the unknown distribution

### Instantiate Model

In [34]:
# Hyperparameters for training a freshly initialized FNO1d_bn on the fine-tuning data only.
n_train = 32 # Number of TRAJECTORIES for training
batch_size = 16
learning_rate = 0.001
epochs = 50
step_size = 10  # only used by the StepLR alternative mentioned below
gamma = 0.5     # only used by the StepLR alternative mentioned below
modes = 16
width = 64

training_set = DataLoader(PDEDataset("training", n_train, device=device), batch_size=batch_size, shuffle=True)
testing_set = DataLoader(PDEDataset("validation", device=device), batch_size=batch_size, shuffle=False)

fno_new = FNO1d_bn(modes, width).to(device)

optimizer = torch.optim.Adam(fno_new.parameters(), lr=learning_rate, weight_decay=1e-5)

# Halve the learning rate once the validation metric (passed to scheduler.step) has
# stopped decreasing for more than `patience` epochs, never going below 1e-6.
# Alternative schedule: torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

### Training loop

In [35]:
# Train the fresh model with the same loss (`l`), metric and print cadence as fine-tuning.
for epoch in range(epochs):
    # Training pass: one optimizer step per mini-batch, tracking the mean batch MSE.
    fno_new.train()
    train_mse = 0.0
    for time_batch, input_batch, output_batch in training_set:
        optimizer.zero_grad()
        output_pred_batch = fno_new(input_batch, time_batch).squeeze(-1)
        loss_f = l(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_mse += loss_f.item()
    train_mse = train_mse / len(training_set)

    # Validation pass: mean per-batch relative L2 error (%), which also drives the
    # ReduceLROnPlateau schedule.
    # NOTE(review): the fine-tuning file holds 32 trajectories and the "validation" split
    # selects 32 of them, so this score is measured on the training trajectories.
    fno_new.eval()
    test_relative_l2 = 0.0
    with torch.no_grad():
        for time_batch, input_batch, output_batch in testing_set:
            output_pred_batch = fno_new(input_batch, time_batch).squeeze(-1)
            loss_f = relative_l2_error(output_pred_batch, output_batch)
            test_relative_l2 += loss_f.item()
    test_relative_l2 = test_relative_l2 / len(testing_set)
    scheduler.step(test_relative_l2)

    if epoch % freq_print == 0:
        print("######### Epoch:", epoch, " ######### Train Loss:", train_mse, " ######### Relative L2 Test Norm:", test_relative_l2)

######### Epoch: 0  ######### Train Loss: 0.07147261807695031  ######### Relative L2 Test Norm: 31.054347133636476
######### Epoch: 1  ######### Train Loss: 0.025299772527068852  ######### Relative L2 Test Norm: 21.186342763900758
######### Epoch: 2  ######### Train Loss: 0.015365667385049164  ######### Relative L2 Test Norm: 16.70710678100586
######### Epoch: 3  ######### Train Loss: 0.013853256381116808  ######### Relative L2 Test Norm: 15.826114940643311
######### Epoch: 4  ######### Train Loss: 0.012427917285822331  ######### Relative L2 Test Norm: 14.757332372665406
######### Epoch: 5  ######### Train Loss: 0.00663308686343953  ######### Relative L2 Test Norm: 15.404463958740234
######### Epoch: 6  ######### Train Loss: 0.01010068696923554  ######### Relative L2 Test Norm: 12.60625729560852
######### Epoch: 7  ######### Train Loss: 0.01023712222231552  ######### Relative L2 Test Norm: 22.39613308906555
######### Epoch: 8  ######### Train Loss: 0.013925889716483652  ######### Relat

### Test new model

In [36]:
# Same zero-shot task (t=0 -> t=1 on every test trajectory), now with the model trained
# from scratch, processed in chunks of `batch_size_test` samples.
fno_new.eval()
with torch.no_grad():
    all_predictions = torch.cat(
        [
            fno_new(input_chunk, time_chunk).squeeze(-1)
            for input_chunk, time_chunk in zip(
                initial_conditions_with_grid.split(batch_size_test),
                time_full.split(batch_size_test),
            )
        ],
        dim=0,
    )

# Back to physical units before comparing with the raw ground truth.
all_predictions_denorm = all_predictions * std + mean
final_time_ground_truth_tensor = torch.from_numpy(final_time_ground_truth).type(torch.float32).to(device)

# Mean per-sample relative L2 error in percent (same formula as the zero-shot cell).
relative_l2_error_new = relative_l2_error(all_predictions_denorm, final_time_ground_truth_tensor)

print(f"Relative L2 Error: {relative_l2_error_new:.4f}%")

Relative L2 Error: 14.0566%
