# Task4+Bonus

In [92]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from torch.optim import AdamW
import matplotlib.pyplot as plt
import math

torch.manual_seed(0)
np.random.seed(0)

In [93]:
# Spectral (or Fourier) layer in 1d
class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    def compl_mul1d(self, input, weights):
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft(x)
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x
    
# Time-conditional BN:
class FILM(torch.nn.Module):
    def __init__(self,
                channels,
                use_bn = True):
        super(FILM, self).__init__()
        self.channels = channels

        self.inp2scale = nn.Linear(in_features=1, out_features=channels, bias=True)
        self.inp2bias = nn.Linear(in_features=1, out_features=channels, bias=True)

        self.inp2scale.weight.data.fill_(0)
        self.inp2scale.bias.data.fill_(1)
        self.inp2bias.weight.data.fill_(0)
        self.inp2bias.bias.data.fill_(0)

        if use_bn:
          self.norm = nn.BatchNorm1d(channels)
        else:
          self.norm = nn.Identity()

    def forward(self, x, time):

        x = self.norm(x)
        time = time.reshape(-1,1).type_as(x)
        scale     = self.inp2scale(time)
        bias      = self.inp2bias(time)
        scale = scale.unsqueeze(2).expand_as(x)
        bias  = bias.unsqueeze(2).expand_as(x)

        return x * scale + bias


class SpatialPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, max_length=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length

    def forward(self, x):
        position = x.unsqueeze(-1)  # [batch_size, grid_size, 1, 1]
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
            (-math.log(self.max_length) / self.embedding_dim)
        )
        embeddings = torch.zeros(*position.shape[:-1], self.embedding_dim, device=x.device)
        embeddings[..., 0::2] = torch.sin(position * div_term)
        embeddings[..., 1::2] = torch.cos(position * div_term)
        return embeddings  # [batch_size, grid_size, embedding_dim]


class LearnableTimeEmbedding(nn.Module):
    def __init__(self, input_dim=1, embedding_dim=8, hidden_dim=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

    def forward(self, t):
        batch_size, grid_size, input_dim = t.shape
        t_flat = t.reshape(-1, input_dim)  # [batch_size * grid_size, input_dim]
        embeddings = self.mlp(t_flat)  # [batch_size * grid_size, embedding_dim]
        embeddings = embeddings.reshape(batch_size, grid_size, -1)  # [batch_size, grid_size, embedding_dim]
        return embeddings

def h1_loss_finite_diff(u_pred, u_target, delta_x=1/64):
    l2_loss = torch.mean((u_pred - u_target)**2)

    # Compute first derivative using central differences
    grad_u_pred = (u_pred[:, 2:] - u_pred[:, :-2]) / (2 * delta_x)
    grad_u_target = (u_target[:, 2:] - u_target[:, :-2]) / (2 * delta_x)

    # Pad gradients to match original shape
    grad_u_pred = torch.nn.functional.pad(grad_u_pred, (1, 1), mode='replicate')
    grad_u_target = torch.nn.functional.pad(grad_u_target, (1, 1), mode='replicate')

    # Compute gradient L2 norm
    grad_loss = torch.mean((grad_u_pred - grad_u_target)**2)

    h1_loss = l2_loss + 0.1*grad_loss
    return h1_loss



In [94]:
class FNO1d(nn.Module):
    def __init__(self, modes, width, use_bn=True):
        super(FNO1d, self).__init__()

        self.modes1 = modes
        self.width = width
        self.layers = 128
        self.padding = 1  # pad the domain if input is non-periodic

        self.embedding_dim = 2
        self.spatial_embedding = SpatialPositionalEmbedding(self.embedding_dim)

        self.time_embedding_dim = 2
        self.time_embedding = LearnableTimeEmbedding(embedding_dim=self.time_embedding_dim)

        self.linear_p = nn.Linear(1 + self.embedding_dim + self.time_embedding_dim, self.width)

        # Spectral layers
        self.spect1 = SpectralConv1d(self.width, self.width, self.modes1)
        self.spect2 = SpectralConv1d(self.width, self.width, self.modes1)
        self.spect3 = SpectralConv1d(self.width, self.width, self.modes1)
        self.spect4 = SpectralConv1d(self.width, self.width, self.modes1)

        # Conv1D layers
        self.lin0 = nn.Conv1d(self.width, self.width, 1)
        self.lin1 = nn.Conv1d(self.width, self.width, 1)
        self.lin2 = nn.Conv1d(self.width, self.width, 1)
        self.lin3 = nn.Conv1d(self.width, self.width, 1)

        # Batch normalization layers
        self.bn1 = FILM(self.width, use_bn)
        self.bn2 = FILM(self.width, use_bn)
        self.bn3 = FILM(self.width, use_bn)
        self.bn4 = FILM(self.width, use_bn)

        # Linear transformations
        self.linear_q = nn.Linear(self.width, self.layers)
        self.output_layer = nn.Linear(self.layers, 1)

        self.activation = torch.nn.GELU()

        # Learnable weights for skip connections
        self.skip_weights = nn.Parameter(torch.ones(4)) 

    def fourier_layer(self, x, spectral_layer, conv_layer, batch_norm, time, skip_weight, skip=True):
        x_old = x
        x = spectral_layer(x) + conv_layer(x)
        x = batch_norm(x, time)  
        if skip:
            x = self.activation(skip_weight * x_old + x) 
        else:
            x = self.activation(x)
        return x

    def linear_layer(self, x, linear_transformation):
        return self.activation(linear_transformation(x))

    def forward(self, x, time):
        space = self.spatial_embedding(x[:, :, 1])
        time_embed = self.time_embedding(x[:, :, 2].unsqueeze(-1))
        x = torch.cat((x[:, :, 0].unsqueeze(-1), space, time_embed), dim=-1)
        x = self.linear_p(x)
        x = x.permute(0, 2, 1) 

        x = self.fourier_layer(x, self.spect1, self.lin0, self.bn1, time, self.skip_weights[0], skip=False)
        x = self.fourier_layer(x, self.spect2, self.lin1, self.bn2, time, self.skip_weights[1])
        #x = self.fourier_layer(x, self.spect3, self.lin2, self.bn3, time, self.skip_weights[2])
        #x = self.fourier_layer(x, self.spect4, self.lin3, self.bn4, time, self.skip_weights[3])

        x = x.permute(0, 2, 1)  # Back to (batch, grid, width)
        x = self.linear_layer(x, self.linear_q)
        x = self.output_layer(x)
        return x


In [95]:
#ALL to ALL dataset
def generate_all_to_all(dataset):
    mesh_size = 64
    input_data = []
    target_data = []
    time_data = []

    # Iterate through each sequence in the dataset
    for sequence in dataset:
        num_states = 5 

        # Iterate over all possible input states (u_t)
        for t in range(num_states):
            # Iterate over all possible future states (u_{t+k})
            for k in range(t, num_states):
                # Compute the time difference between t and k
                time = (k - t) * 0.25
                time_data.append([time])  #

                # Retrieve the state at time t (u_t)
                u_t = sequence[t, :]  

                # Generate spatial coordinates for the mesh
                x = torch.linspace(0, 1, mesh_size)
                # Generate a time difference tensor
                t_tensor = time * torch.ones(mesh_size) 

                # Combine u_t and x and t_tensor as the input pair
                input_tensor = torch.stack([u_t, x, t_tensor], dim=0)  
                input_data.append(input_tensor)  

                # Target is the state at time k 
                target_data.append(sequence[k, :])  

    time_data = torch.tensor(time_data)  
    input_data = torch.stack(input_data).permute(0, 2, 1)  
    target_data = torch.stack(target_data).unsqueeze(-1)  

    # Return the generated training pairs
    return time_data, input_data, target_data


In [96]:
num_elements = 128
training_sample = 64
mesh_size = 64

# Load the data
path_train = "FNO - Wave Equation/train_sol.npy"
data = torch.from_numpy(np.load(path_train)).type(torch.float32)

#split into training and validation
random_indices = np.random.choice(num_elements, training_sample, replace=False)
training_dataset_task1 = data[random_indices, :, :]

validation_indices = np.setdiff1d(np.arange(num_elements), random_indices)
validation_dataset_task1 = data[validation_indices, :, :]

time_train, input_train, target_train = generate_all_to_all(training_dataset_task1)
time_validation, input_validation, target_validation = generate_all_to_all(validation_dataset_task1)

print("train input in dataloader", input_train.shape)
print("train target in dataloader", target_train.shape)
print("time train in dataloader", time_train.shape)
print()
print("validation input in dataloader", input_validation.shape)
print("validation target in dataloader", target_validation.shape)
print("time validation in dataloader", time_validation.shape)



train input in dataloader torch.Size([960, 64, 3])
train target in dataloader torch.Size([960, 64, 1])
time train in dataloader torch.Size([960, 1])

validation input in dataloader torch.Size([960, 64, 3])
validation target in dataloader torch.Size([960, 64, 1])
time validation in dataloader torch.Size([960, 1])


In [97]:
#hyperparameters
batch_size = 32
training_set = DataLoader(TensorDataset(time_train,input_train, target_train), batch_size=batch_size, shuffle=True)
validation_set = DataLoader(TensorDataset(time_validation,input_validation, target_validation), batch_size=batch_size, shuffle=False)

learning_rate = 0.01
epochs = 10000 

modes = 16
width = 128
fno = FNO1d(modes, width)

In [98]:
# Initialize optimizer and scheduler
optimizer = AdamW(fno.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-3, max_lr=1e-2, step_size_up=2000, mode='triangular')


# Frequency of printing
freq_print = 10

# Variables to track the best model
best_val_loss = float('inf')
best_model_state = None
best_epoch = 0

# Training loop
for epoch in range(epochs):
    train_mse = 0.0
    fno.train() 
    for step, (time_batch, input_batch, output_batch) in enumerate(training_set):
        optimizer.zero_grad()  
        output_pred_batch = fno(input_batch, time_batch).squeeze(2)  

        loss_f = h1_loss_finite_diff(output_pred_batch, output_batch.squeeze(2))  # H1 loss

        loss_f.backward() 
        optimizer.step() 
        train_mse += loss_f.item() 

    train_mse /= len(training_set)
    scheduler.step() 

    # Validation loop
    with torch.no_grad():
        fno.eval()  
        test_relative_l2 = 0.0
        for step, (time_batch, input_batch, output_batch) in enumerate(validation_set):
            output_pred_batch = fno(input_batch, time_batch).squeeze(2)  
            # Compute relative L2 norm
            loss_f = (torch.mean(torch.norm(output_pred_batch - output_batch.squeeze(2), p=2)) /
                      torch.norm(output_batch.squeeze(2), p=2)) * 100
            test_relative_l2 += loss_f.item()  
        test_relative_l2 /= len(validation_set)  

    # Track the best model
    if test_relative_l2 < best_val_loss:
        best_val_loss = test_relative_l2
        best_model_state = {
            'epoch': epoch,
            'model_state_dict': fno.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss
        }
        best_epoch = epoch

    # Print progress every freq_print epochs
    if epoch % freq_print == 0:
        print(f"######### Epoch: {epoch} ######### Train Loss: {train_mse:.8f} ######### Relative L2 Test Norm: {test_relative_l2:.6f}")



######### Epoch: 0 ######### Train Loss: 0.37727687 ######### Relative L2 Test Norm: 55.247107


KeyboardInterrupt: 

In [None]:
#torch.save(best_model_state, "best_fno_model_t4.pth")
#print(f"Best model saved from epoch {best_epoch} with validation loss {best_val_loss:.4f} as 'best_fno_model_t4.pth'.")

In [99]:
# Carica lo stato salvato
checkpoint = torch.load("best_fno_model_t4.pth", weights_only=True)

# Ripristina lo stato del modello
fno.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

## Testset preparation

In [100]:
test_path = "FNO - Wave Equation/test_sol.npy"
data_test = torch.from_numpy(np.load(test_path)).type(torch.float32)
num_test = len(data_test)
# Generate test data
time_test, input_test, target_test = generate_all_to_all(data_test)
test_set = DataLoader(TensorDataset(time_test, input_test, target_test), batch_size=1, shuffle=False)


#### ALL2ALL testing

In [101]:
# Initialize evaluation
with torch.no_grad():
    fno.eval()  

    num_time_steps = 5
    test_relative_l2_vet = np.zeros(num_time_steps)
    counter_time = np.zeros(num_time_steps)


    # Iterate through test dataset
    for step, (time, input, output) in enumerate(test_set):
        output_pred = fno(input, time).squeeze(2)  
        #loss_f = (torch.mean(torch.norm(output_pred - output.squeeze(2), p=2)) / torch.norm(output.squeeze(2),p=2)) * 100
        loss_f =torch.norm(output_pred - output.squeeze(2), p=2) / torch.norm(output.squeeze(2),p=2) * 100

        # Update counters
        time_index = int(time[0, 0] * 4)
        counter_time[time_index] += 1
        test_relative_l2_vet[time_index] += loss_f.item()

    # Compute mean relative L2 for each time step
    mean_relative_l2 = test_relative_l2_vet / counter_time
    print("samples over time", counter_time)
    for i in range(5):
        print("Mean relative L2 loss at t = ", 0.25*i,":", mean_relative_l2[i])


samples over time [640. 512. 384. 256. 128.]
Mean relative L2 loss at t =  0.0 : 4.252252138219774
Mean relative L2 loss at t =  0.25 : 39.09376021102071
Mean relative L2 loss at t =  0.5 : 36.505533856650196
Mean relative L2 loss at t =  0.75 : 27.7706618309021
Mean relative L2 loss at t =  1.0 : 20.241275161504745


#### OOD testing


In [102]:
test_path = "FNO - Wave Equation/test_sol_OOD.npy"
data_test_OOD = torch.from_numpy(np.load(test_path)).type(torch.float32)
num_test_OOD = len(data_test_OOD)

print("Test OOD dataset shape:", data_test_OOD.shape)

# Split initial and final condition
input_test_OOD = data_test_OOD[:, 0, :].unsqueeze(-1)

target_test_OOD = data_test_OOD[:, -1, :]

# Add spatial mesh to the tensor
x_OOD = torch.linspace(0, 1, mesh_size).repeat(num_test, 1).unsqueeze(-1)  
t_OOD = torch.ones(mesh_size).unsqueeze(0).repeat(num_test_OOD, 1).unsqueeze(-1)
input_test_withtime_OOD = torch.cat((input_test_OOD, x_OOD,t_OOD), dim=-1)



Test OOD dataset shape: torch.Size([128, 2, 64])


In [104]:

with torch.no_grad():
    output_function_test_pred_OOD = fno(input_test_withtime_OOD, torch.ones((num_test_OOD, 1))).squeeze(2)

err_OOD = (torch.mean(torch.norm(output_function_test_pred_OOD - target_test_OOD, p=2)) / torch.norm(target_test_OOD,p=2)) * 100
print("Relative L2 error (OOD):", err_OOD.item())


Relative L2 error (OOD): 46.40522766113281
