# Notebook \#2 - Implementation of Fourier Neural Operator

In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
from functools import partial
from torch.fft import fft, ifft
from data import MultiFunctionDatasetODE, custom_collate_ODE_fn

class FourierLayer(nn.Module):
    def __init__(self, in_channels, out_channels, modes):
        super(FourierLayer, self).__init__()
        self.modes = modes  # Number of Fourier modes to keep
        self.scale = 1 / (in_channels * out_channels)
        self.weights = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    def forward(self, x):
        x_ft = fft(x, dim=-1)  # Compute FFT
        x_ft = x_ft.clone()  # Fix: Clone to prevent in-place modification
        x_ft[:, :, :self.modes] = x_ft[:, :, :self.modes] * self.weights  # Apply learned filter
        x_out = ifft(x_ft, dim=-1)  # Compute inverse FFT
        return x_out.real  # Return real part

class FNO(nn.Module):
    def __init__(self, modes, width):
        super(FNO, self).__init__()
        self.fc0 = nn.Linear(2, width)  # Expecting (t, u) as input
        self.fourier = FourierLayer(width, width, modes)  # Fourier transform layer
        self.fc1 = nn.Linear(width, width)  
        self.fc2 = nn.Linear(width, 1)  

    def forward(self, t, u):
        x = torch.cat((t.unsqueeze(-1), u.unsqueeze(-1)), dim=-1)  # Shape: (batch, seq_len, 2)
        x = self.fc0(x)  # Shape: (batch, seq_len, width)
        x = x.permute(0, 2, 1)  # Shape: (batch, width, seq_len) for Fourier transform
        x = self.fourier(x)  # Fourier Layer output (batch, width, seq_len)
        x = x.permute(0, 2, 1)  # Convert back to (batch, seq_len, width)
        x = self.fc1(x)  # Shape: (batch, seq_len, width)
        x = self.fc2(x)  # Shape: (batch, seq_len, 1)
        return x.squeeze(-1)  # Final shape: (batch, seq_len)

def compute_loss(model, t, u):
    t = t.clone().detach().requires_grad_(True).cuda()  # Fix: Ensure requires_grad is set correctly

    x_pred = model(t, u)  # Predict x(t)

    # Compute d(x)/dt using autograd
    dx_dt = torch.autograd.grad(x_pred, t, grad_outputs=torch.ones_like(x_pred), create_graph=True)[0]
    
    # Compute PDE residual
    residual = dx_dt - (x_pred - u)
    loss_pde = torch.mean(residual**2)

    # Initial condition loss
    loss_ic = torch.mean((x_pred[0] - 1) ** 2)

    return loss_pde + loss_ic

def train(model, dataset, epochs=1000, lr=0.001, batch_size=64):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0.0
        for u, _, _, _, time_domain, _ in dataloader:
            # Fix: Clone tensors to prevent in-place modification issues
            u, time_domain = u.clone().detach().float().cuda(), time_domain.clone().detach().float().cuda()

            optimizer.zero_grad()
            loss = compute_loss(model, time_domain, u)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader)}")

# Hyperparameters
modes = 12
width = 64
epochs = 2000
lr = 0.001
batch_size = 64
m = 200   

# Initialize model
model = FNO(modes, width).cuda()

# Dataset parameters
n_functions = 10000
grf_lb = 0.02
grf_ub = 0.5
end_time = 1.0
num_domain = 200
num_initial = 20

dataset = MultiFunctionDatasetODE(
    m=m,
    n_functions=n_functions,
    function_types=['grf', 'linear', 'sine', 'polynomial','constant'],
    end_time=end_time,
    num_domain=num_domain,
    num_initial=num_initial,
    grf_lb=grf_lb,
    grf_ub=grf_ub
)

# Train the model
train(model, dataset, epochs=epochs, lr=lr, batch_size=batch_size)


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDAComplexFloatType [64, 64, 12]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).