# Task 3: All2All Training

In [64]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torch.nn.functional as F
import matplotlib.pyplot as plt

## Model Definition

In [65]:
from FNO_bn import FNO1d_bn

In [66]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


## Import Data

In [67]:
# Seed both random number generators used in this notebook (NumPy and PyTorch)
# with the same value so that repeated runs draw the same random numbers.
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(0)

In [68]:
n_train = 1024 # number of training samples

# Layout of the raw arrays, as noted for the training file: (1024, 5, 128)
# 1024: number of trajectories
# 5: time snapshots of the solution: t = 0, 0.25, 0.5, 0.75, 1.0
# 128: spatial resolution of the data

train_dataset = torch.from_numpy(np.load("data/data_train_128.npy")).type(torch.float32)
test_dataset = torch.from_numpy(np.load("data/data_test_128.npy")).type(torch.float32)

# Time and grid coordinates in [0, 1], shaped to broadcast against
# (trajectory, snapshot, point). Both are derived from the RAW shapes, before any
# feature is attached, so the grid has exactly one coordinate per spatial point.
time_train = torch.linspace(0, 1, train_dataset.shape[1]).reshape(1, -1, 1)
time_test = torch.linspace(0, 1, test_dataset.shape[1]).reshape(1, -1, 1)
grid_train = torch.linspace(0, 1, train_dataset.shape[2]).reshape(1, 1, -1)
grid_test = torch.linspace(0, 1, test_dataset.shape[2]).reshape(1, 1, -1)

# Attach time and grid as input features on a new trailing channel axis:
# (n_traj, n_snap, n_points) -> (n_traj, n_snap, n_points, 3), channels = [u, t, x].
# Concatenating them along the last (spatial) axis instead would append the time value
# as a fake 129th spatial point and produce a grid with the wrong number of points.
train_dataset = torch.stack(
    [train_dataset, time_train.expand_as(train_dataset), grid_train.expand_as(train_dataset)],
    dim=-1,
)
test_dataset = torch.stack(
    [test_dataset, time_test.expand_as(test_dataset), grid_test.expand_as(test_dataset)],
    dim=-1,
)

# Move data to device
train_dataset = train_dataset.to(device)
test_dataset = test_dataset.to(device)

batch_size = 20
train_loader = DataLoader(TensorDataset(train_dataset), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(test_dataset), batch_size=batch_size, shuffle=False)

In [None]:
class PDEDataset(Dataset):
    """All-to-all dataset of (time gap, input snapshot, target snapshot) triples.

    Every trajectory in ``data/data_train_{resolution}.npy`` contributes one item
    for each ordered snapshot pair ``(t_inp, t_out)`` with ``t_inp < t_out`` among
    the first ``T = 5`` snapshots, i.e. 10 items per trajectory.

    Splits are all taken from the same file:
      * "training":   the first ``training_samples`` trajectories,
      * "validation": the ``n_val`` trajectories just before the last ``n_test``,
      * "test":       the last ``n_test`` trajectories.

    Args:
        which: one of "training", "validation", "test".
        training_samples: number of trajectories in the training split
            (ignored by the other splits).
        resolution: number of spatial points per snapshot; also selects the file.
        device: device on which the returned tensors are placed.

    Raises:
        ValueError: if ``which`` is not a known split, or the file holds too few
            trajectories for the requested split.
    """

    def __init__(self,
                 which="training",
                 training_samples = 256,
                 resolution = 128,
                 device='cpu'):
        import warnings

        # Reject unknown splits before touching the disk; otherwise `length` and
        # `start_sample` would stay undefined and fail much later.
        if which not in ("training", "validation", "test"):
            raise ValueError(
                f"which must be 'training', 'validation' or 'test', got {which!r}"
            )

        self.resolution = resolution
        self.device = device
        self.data = np.load(f"data/data_train_{resolution}.npy")

        self.T = 5
        # Precompute all possible (t_initial, t_final) pairs within the specified range.
        self.time_pairs = [(i, j) for i in range(0, self.T) for j in range(i + 1, self.T)]
        self.len_times  = len(self.time_pairs)

        # Total samples available in the dataset
        total_samples = self.data.shape[0]
        self.n_val = 32
        self.n_test = 32

        if which == "training":
            if training_samples > total_samples:
                raise ValueError(
                    f"training_samples={training_samples} exceeds the "
                    f"{total_samples} trajectories available"
                )
            if training_samples > total_samples - self.n_val - self.n_test:
                # The validation/test splits are the tail of this same file.
                warnings.warn(
                    f"training_samples={training_samples} overlaps the last "
                    f"{self.n_val + self.n_test} trajectories reserved for "
                    "validation/test; metrics on those splits will be optimistic."
                )
            self.length = training_samples * self.len_times
            self.start_sample = 0
        elif which == "validation":
            self.length = self.n_val * self.len_times
            self.start_sample = total_samples - self.n_val - self.n_test
        else:  # which == "test"
            self.length = self.n_test * self.len_times
            self.start_sample = total_samples - self.n_test

        if self.start_sample < 0:
            raise ValueError(
                f"only {total_samples} trajectories available, too few for the "
                f"{which!r} split"
            )

        # Fixed normalisation constants applied to inputs and outputs.
        self.mean = 0
        self.std  = 0.3835

        # Pre-create grid to avoid recreating it each time; one coordinate per point.
        self.grid = torch.linspace(0, 1, resolution, dtype=torch.float32).reshape(resolution, 1).to(device)

    def __len__(self):
        """Number of (trajectory, snapshot pair) items in this split."""
        return self.length

    def __getitem__(self, index):
        """Return ``(time, inputs, outputs)`` for one (trajectory, snapshot pair).

        Returns:
            time: 0-d float32 tensor, the snapshot-index gap divided by ``T`` plus a
                random jitter below 1e-6.
            inputs: float32 tensor of shape (resolution, 2) holding the normalised
                input snapshot and the grid coordinate.
            outputs: float32 tensor of shape (resolution,), the normalised target.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        # Without this check an out-of-range index would silently read trajectories
        # belonging to another split, and a negative one the wrong trajectory.
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f"index out of range for dataset of length {self.length}")

        sample_idx = self.start_sample + index // self.len_times
        time_pair_idx = index % self.len_times
        t_inp, t_out = self.time_pairs[time_pair_idx]
        # NOTE(review): the gap is scaled by T (= 5) and jittered for every split,
        # validation included; kept as in the original -- confirm this is intended.
        jitter = float(np.random.rand(1)[0]/10**6)
        time = torch.tensor((t_out - t_inp)/self.T + jitter, dtype=torch.float32, device=self.device)

        inputs = torch.from_numpy(self.data[sample_idx, t_inp]).type(torch.float32).reshape(self.resolution, 1).to(self.device)
        inputs = (inputs - self.mean)/self.std #Normalize

        # Add grid coordinates (already on correct device and dtype)
        inputs = torch.cat((inputs, self.grid), dim=-1)  # (resolution, 2)

        outputs = torch.from_numpy(self.data[sample_idx, t_out]).type(torch.float32).reshape(self.resolution).to(self.device)
        outputs = (outputs - self.mean)/self.std #Normalize

        return time, inputs, outputs

### Instantiate Model

In [None]:
batch_size = 20

# The "validation" split is read from the tail of the same file as the training
# split, so build it first to learn how many trailing trajectories are held out.
validation_data = PDEDataset("validation", device=device)
n_heldout = validation_data.n_val + validation_data.n_test

# Number of TRAJECTORIES for training: at most 1024, and never reaching into the
# held-out tail. Training on those trajectories would make the validation metric
# reported during training a measure of fit on seen data.
n_train = min(1024, validation_data.data.shape[0] - n_heldout)
if n_train <= 0:
    raise ValueError("not enough trajectories left for training after holding out validation/test")

training_set = DataLoader(PDEDataset("training", n_train, device=device), batch_size=batch_size, shuffle=True)
# NOTE: despite its name, `testing_set` iterates over the validation split.
testing_set = DataLoader(validation_data, batch_size=batch_size, shuffle=False)

# Optimisation hyper-parameters: Adam with a step learning-rate decay
# (multiply by `gamma` every `step_size` epochs).
learning_rate = 0.001
epochs = 5
step_size = 2
gamma = 0.5

# Model hyper-parameters passed to FNO1d_bn.
modes = 16
width = 64
fno = FNO1d_bn(modes, width).to(device)  # model

optimizer = torch.optim.Adam(fno.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Training objective: mean absolute error between prediction and target.
l = nn.L1Loss()
# Report metrics every `freq_print` epochs.
freq_print = 1
for epoch in range(epochs):
    # ---- optimisation pass over the training loader ----
    fno.train()
    # NOTE: despite its name, this accumulates the L1 training loss.
    train_mse = 0.0
    for step, batch in enumerate(training_set):
        time_batch, input_batch, output_batch = batch
        optimizer.zero_grad()
        output_pred_batch = fno(input_batch, time_batch).squeeze(-1)
        loss_f = l(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_mse += loss_f.item()
    # Average over the number of batches.
    train_mse = train_mse / len(training_set)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # ---- evaluation pass, no gradients tracked ----
    with torch.no_grad():
        fno.eval()
        # NOTE: despite its name, this is a relative L1 error in percent,
        # computed per batch and then averaged over batches.
        test_relative_l2 = 0.0
        for step, batch in enumerate(testing_set):
            time_batch, input_batch, output_batch = batch
            output_pred_batch = fno(input_batch, time_batch).squeeze(-1)
            mean_abs_error = (output_pred_batch - output_batch).abs().mean()
            mean_abs_target = output_batch.abs().mean()
            loss_f = (mean_abs_error / mean_abs_target) * 100
            test_relative_l2 += loss_f.item()
        test_relative_l2 = test_relative_l2 / len(testing_set)

    if epoch % freq_print != 0:
        continue
    print("######### Epoch:", epoch, " ######### Train Loss:", train_mse, " ######### Relative L1 Test Norm:", test_relative_l2)

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float