# Task 2: Resolution Invariance

In [None]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import matplotlib.pyplot as plt

## Load Model

In [None]:
from FNO import FNO1d

# Rebuild the network with the hyperparameters used here (modes=16, width=64)
# and restore the trained weights from the checkpoint file.
model_path = "fno_1d_model.pth"
fno = FNO1d(modes=16, width=64)
# map_location="cpu": this notebook never moves the model or the test tensors
# to an accelerator, so deserialize the weights straight onto the CPU. Without
# it, a checkpoint that was saved from a CUDA device cannot be loaded on a
# CPU-only machine.
fno.load_state_dict(torch.load(model_path, map_location="cpu"))

## Test at Different Resolutions

### Import Test Sets

In [None]:
def _load_resolution(resolution):
    """Load the test set stored at one spatial resolution.

    Reads ``data/data_test_{resolution}.npy`` and prepares everything needed
    to evaluate the model at that resolution.

    Args:
        resolution: Number of spatial points; selects the file and sizes the
            coordinate grid.

    Returns:
        A tuple ``(data, u0, u1, u0_grid, loader)``:
        ``data`` is the full float32 tensor, ``u0`` is the slice at index 0
        of axis 1 (the initial condition), ``u1`` is the slice at index -1 of
        axis 1 (the solution at t=1.0), ``u0_grid`` is ``u0`` with the grid
        coordinate appended as a second channel, shape
        (batch, spatial_points, 2), and ``loader`` is an unshuffled
        DataLoader over ``(u0_grid, u1)``.

    Raises:
        ValueError: If the array is not 3-dimensional or its last axis does
            not have ``resolution`` points.
    """
    data = torch.from_numpy(np.load(f"data/data_test_{resolution}.npy")).type(torch.float32)
    # Fail with a readable message instead of a size mismatch inside torch.cat.
    if data.ndim != 3 or data.shape[-1] != resolution:
        raise ValueError(
            f"data_test_{resolution}.npy has shape {tuple(data.shape)}; "
            f"expected a 3-D array with {resolution} points on the last axis"
        )

    u0 = data[:, 0, :]   # initial condition
    u1 = data[:, -1, :]  # solution at t=1.0

    # Uniform grid on [0, 1] (both endpoints included), repeated for each sample.
    grid = torch.linspace(0, 1, resolution).reshape(1, resolution, 1).repeat(u0.shape[0], 1, 1)
    u0_grid = torch.cat([u0.unsqueeze(-1), grid], dim=-1)

    # batch_size=1 is the DataLoader default, spelled out because the
    # evaluation averages a per-batch relative error: with one sample per
    # batch that average is the mean per-sample relative error.
    loader = DataLoader(TensorDataset(u0_grid, u1), batch_size=1, shuffle=False)
    return data, u0, u1, u0_grid, loader


# One call per resolution; every name the rest of the notebook may rely on is kept.
test_128, u0_test_128, u1_test_128, u0_test_128_grid, test_set_128 = _load_resolution(128)
test_96, u0_test_96, u1_test_96, u0_test_96_grid, test_set_96 = _load_resolution(96)
test_64, u0_test_64, u1_test_64, u0_test_64_grid, test_set_64 = _load_resolution(64)
test_32, u0_test_32, u1_test_32, u0_test_32_grid, test_set_32 = _load_resolution(32)


### Evaluate at Different Resolutions

In [None]:
def _mean_relative_l2(model, loader):
    """Return the relative L2 error, in percent, averaged over the batches of *loader*."""
    running_total = 0.0
    for inputs, targets in loader:
        # squeeze(2) drops axis 2 when it has size 1, so the prediction lines
        # up with the targets (presumably the model returns (batch, points, 1)
        # -- confirm against FNO1d).
        predictions = model(inputs).squeeze(2)
        residual_power = torch.mean((predictions - targets) ** 2)
        target_power = torch.mean(targets ** 2)
        batch_error = (residual_power / target_power) ** 0.5 * 100
        running_total += batch_error.item()
    return running_total / len(loader)


resolutions = [128, 96, 64, 32]
test_loaders = [test_set_128, test_set_96, test_set_64, test_set_32]
test_relative_l2 = []

# Inference only: switch the model to eval mode and disable gradient tracking.
fno.eval()
with torch.no_grad():
    for res, test_loader in zip(resolutions, test_loaders):
        relative_l2 = _mean_relative_l2(fno, test_loader)
        test_relative_l2.append(relative_l2)
        print(f"Resolution {res}: Relative L2 = {relative_l2:.2f}%")

# Error as a function of resolution.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(resolutions, test_relative_l2, marker='o')
ax.set_xlabel('Resolution (Number of Spatial Points)')
ax.set_ylabel('Relative L2 Error (%)')
ax.set_xticks(resolutions)
ax.set_title('FNO Performance Across Different Resolutions')
# Dashed red line marks the training resolution (128).
ax.axvline(x=128, color='r', linestyle='--')
plt.show()