Fourier Neural Operator Code
The goal of this notebook is to implement an FNO for a test dataset. Eventually this code will be used for decoding turbulence simulations

Data Preparation


In [1]:
import torch
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [2]:
data_path = "/Users/carsonmcvay/desktop/gradschool/research/turbulence_encryption/forcing_functions/multi_forcing_simulations_combined_viscneg2.npz"
data = jnp.load(data_path)
print("keys in dataset:", data.keys())

keys in dataset: KeysView(NpzFile '/Users/carsonmcvay/desktop/gradschool/research/turbulence_encryption/forcing_functions/multi_forcing_simulations_combined_viscneg2.npz' with keys: inputs, outputs, metadata)


In [3]:
# Simulated dataset generation
# Inputs: Evolved states (e.g., velocity/pressure fields over time)
# Targets: Forcing functions (spatially/temporally varying fields)


# Example:
data_path = "/Users/carsonmcvay/desktop/gradschool/research/turbulence_encryption/forcing_functions/multi_forcing_simulations_combined_viscneg2.npz"
data = jnp.load(data_path)
inputs = data["inputs"]
inputs = jnp.fft.irfft(inputs, axis=-1)  # Perform inverse FFT along the last axis # adding this to try to debug
outputs = data["outputs"]

print("Transformed Inputs Shape:", inputs.shape)





Transformed Inputs Shape: (21, 256, 256)


Split training and test data

In [4]:
X_train, X_test, y_train, y_test = train_test_split(inputs, outputs, test_size=0.2, random_state=42)

Pass data to the FNO

In [5]:
# Convert data from JAX arrays to NumPy
# doing this before passing to PyTorch
# Convert JAX arrays to NumPy
X_train_np = np.array(jnp.asarray(X_train).block_until_ready())  # Ensures conversion to NumPy
y_train_np = np.array(jnp.asarray(y_train).block_until_ready())
X_test_np = np.array(jnp.asarray(X_test).block_until_ready())
y_test_np = np.array(jnp.asarray(y_test).block_until_ready())



Debugging! (Kill me now)

In [6]:
# Debugging shapes and types
print("X_train_np type:", type(X_train_np), "shape:", X_train_np.shape)
print("y_train_np type:", type(y_train_np), "shape:", y_train_np.shape)


X_train_np type: <class 'numpy.ndarray'> shape: (16, 256, 256)
y_train_np type: <class 'numpy.ndarray'> shape: (16, 256, 256, 2)


Okay they are still jax arrays, so this is bad. Must convert explicitly to NumPy. :(
Yay it is fixed and I do not have to be ritually sacrificed! 

In [7]:
# Convert data to PyTorch tensors
X_train_tensor = torch.tensor(X_train_np, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_np, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test_np, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_np, dtype=torch.float32)

# Create data loaders
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
# FNO Architecture
import torch
import torch.nn as nn
import torch.fft

class FNO2D(nn.Module):
    def __init__(self, modes, width, input_dim, output_dim):
        """
        Args: 
        modes: number of Fourier modes to retain
        width: Feature width for the network
        """
        super(FNO2D, self).__init__()
        self.modes = modes
        self.width = width
        self.height, self.width_grid = 256, 256  # Assuming a fixed grid size

        # # input and output transforations
        # self.input_layer = nn.Linear(input_dim, width)
        # self.output_layer = nn.Linear(width, output_dim)

        # Input and output transformations
        # chatgpt claims this will fix it
        self.input_layer = nn.Linear(input_dim, self.height * self.width_grid)
        self.output_layer = nn.Linear(self.height * self.width_grid, output_dim)

        # Fourier layers
        self.fourier_layers = nn.ModuleList([
            FourierLayer(modes,width) for i in range(4)
        ])

        # Non linearity
        self.activation = nn.ReLU()
    
    def forward(self, x):
        # Flatten input for the fully connected layer
        batch_size = x.shape[0]
        # print(f"Input shape to FNO2D: {x.shape}")
        x = x.view(batch_size, -1)  # Flatten (batch, height*width)
        # print(f"Shape after flattening: {x.shape}")

        # Input transformation
        x = self.input_layer(x)
        # print(f"Shape after input_layer: {x.shape}")
        # a "fix"
        x = x.view(batch_size, self.height, self.width_grid)  # Reshape to 2D grid
        # print(f"Shape after reshaping to 2D grid: {x.shape}")

        # pass through Fourier layers
        for layer in self.fourier_layers:
            x = layer(x)
            # print(f"Shape after Fourier layer: {x.shape}")
            x = self.activation(x)

        
        # Flatten back to 1D before the output layer
        x = x.view(batch_size, -1)  # (batch, height * width)
        # print(f"Shape before output_layer: {x.shape}")
        # output transformation
        x = self.output_layer(x) # (batch, height, width, forcing_channels)
        # print(f"Shape after output_layer: {x.shape}")
        return x.view(batch_size, 256, 256, 2)  # Reshape to match the output (batch, H, W, 2)

class FourierLayer(nn.Module):
    def __init__(self, modes, width):
        super(FourierLayer, self).__init__()
        self.modes = modes
        self.width = width

        # Learnable Foureier filters
        self.weights = nn.Parameter(torch.randn(width, modes, modes, dtype=torch.cfloat))

    def forward(self, x):
        # Reshape input for FFT
        batch_size, seq_len = x.shape  # Flattened input
        height, width = 256, 256  # Replace with your grid dimensions
        assert seq_len == height * width, f"Input size mismatch: got {seq_len}, expected {height * width}"
        
        x = x.view(batch_size, height, width)  # Reshape to 2D grid
        
        # Fourer transform of input
        x_ft = torch.fft.fft2(x, dim=(-2, -1)) # FFT over spatial dimensions

        # Retain only the first modes Fourier modes
        x_ft = x_ft[...,:self.modes, :self.modes] * self.weights

        # Inverse fft
        x = torch.fft.ifft2(x_ft, dim=(-2,-1)).real
        
        

    #    Flatten back to a 1D vector
        return x.view(batch_size, -1)

In [9]:
print("Inputs shape:", inputs.shape)  # Should be (num_samples, 256, 256)
print("Outputs shape:", outputs.shape)  # Should be (num_samples, 256, 256, 2)



Inputs shape: (21, 256, 256)
Outputs shape: (21, 256, 256, 2)


In [None]:
# Training Loop

height, width = 256, 256  # Replace with your grid dimensions
input_dim = height * width  # Flattened input size
output_dim = height * width * 2  # Flattened output size with 2 channels (forcing_x and forcing_y)

# initialize model, optimizer, and loss function
fno_model = FNO2D(modes=16, width=64, input_dim=input_dim, output_dim=output_dim)
optimizer = torch.optim.Adam(fno_model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# training loop
num_epochs = 100
for epoch in range(num_epochs):
    for evolved_state, forcing_function in train_loader: #assuming a data loader
        optimizer.zero_grad()

        # forward pass
        predicted_forcing = fno_model(evolved_state)

        # Compute loss
        loss = loss_fn(predicted_forcing, forcing_function)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, loss: {loss.item()}")



I hate all shapes and matrix multiplicaiton. I must fix this error

In [None]:
# Evaluation

# test on unseen data
with torch.no_grad():
    for evolved_state, true_forcing in test_loader:
        predicted_forcing = fno_model(evolved_state)

        # compute error or visualize
        error = loss_fn(predicted_forcing, true_forcing)
        print(f"Test Error: {error.item()}")