# Introduction to Fourier Neural Operators (FNOs)
---
**Authors**: Xuesong (Cedar) Ma, Bernard Chang, and Masa Prodanovic

Last Updated: Apr. 30, 2025

## Objective
In this notebook, we introduce the components of a ***Fourier Neural Operator (FNO)***. Both 2D and 3D versions of the FNO are presented.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
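import matplotlib.pyplot as plt
import pytorch_lightning as pl  # provides pl.LightningModule used by FNO3D; assumes the pytorch_lightning package is installed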

## Neural Operators

Neural operators learn mappings between function spaces, represented in practice by discretized functions, and aim to solve entire parametric families of partial differential equations (PDEs) rather than a single instance. In this workshop, we will demonstrate their capabilities through two key applications:

* Predicting static solution fields from input coefficients (e.g., material properties or source terms)
* Mapping from an initial condition to the solution function at a later temporal point

### Background
In mathematical terms, integral operators can be written as,

$(\mathcal{G}a)(x) = \int_{\Omega} K(x, y)a(y)dy$.

Here $a$ is the input function defined on the domain $\Omega$ and $K(x, y)$ is the kernel. Neural operators learn to approximate the action of this kernel integral. Neural operator architectures differ mainly in how they represent the integral operator.
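
To make the integral concrete, the sketch below discretizes $(\mathcal{G}a)(x)$ on a uniform 1D grid, where the integral becomes a matrix-vector product scaled by the grid spacing. The Gaussian kernel and the names `kernel_integral`, `n_points`, and `length_scale` are illustrative assumptions and are not part of the FNO code in this notebook.

```python
import torch

def kernel_integral(kernel, a, grid):
    """Approximate (G a)(x_i) = integral of K(x_i, y) a(y) dy with a Riemann sum."""
    dy = grid[1] - grid[0]                        # uniform grid spacing
    K = kernel(grid[:, None], grid[None, :])      # (n, n) matrix of K(x_i, y_j)
    return (K @ a) * dy                           # sum_j K(x_i, y_j) a(y_j) dy

# Hypothetical smooth kernel chosen only for illustration
length_scale = 0.1

def gaussian(x, y):
    return torch.exp(-((x - y) ** 2) / (2 * length_scale ** 2))

n_points = 128
grid = torch.linspace(0, 1, n_points, dtype=torch.float64)
a = torch.sin(2 * torch.pi * grid)                # input function a(y)
u = kernel_integral(gaussian, a, grid)            # output function (G a)(x)
print(u.shape)
```

A dense kernel matrix costs $O(n^2)$ per evaluation on $n$ grid points, which is one motivation for the FFT-based representation used by FNOs.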

<img src="images/neuraloperator.svg" />



Examples of Neural Operators include:

* DeepONet [[1]](https://doi.org/10.1038/s42256-021-00302-5)
* Fourier Neural Operators [[2]](https://doi.org/10.48550/arXiv.2010.08895)
* Graph Neural Operators [[3]](https://doi.org/10.48550/arXiv.2003.03485)
* Transformer Neural Operators [[4]](https://doi.org/10.48550/arXiv.2405.19166)
* ...

***Note***: Neural operators do not explicitly parameterize $K(x,y)$. Rather, they learn a computational representation that acts like applying the kernel.


**Fourier Neural Operators** are a special class of Neural Operators that leverage the Fast Fourier Transform (FFT) to perform spectral convolution, enabling them to capture long-range dependencies and global structures in the input function space.

FNOs notably enable **zero-shot super-resolution**: a model trained at one resolution can be evaluated on a higher-resolution discretization at inference time without retraining.
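
One way to see why this can work: the low-frequency Fourier coefficients of a smooth function do not depend on how finely it is sampled, so weights attached to those coefficients can be reused at another resolution. The sketch below checks this for a hand-picked periodic test function; `low_modes` is a hypothetical helper, and `norm="forward"` is used only so that the coefficients are directly comparable across grid sizes.

```python
import math
import torch

def low_modes(n_points, n_modes=4):
    """First `n_modes` Fourier coefficients of a fixed test function sampled on n_points."""
    x = torch.arange(n_points, dtype=torch.float64) / n_points    # periodic grid on [0, 1)
    f = torch.sin(2 * math.pi * x) + 0.5 * torch.cos(2 * math.pi * 3 * x)
    # norm="forward" divides by n_points, so coefficients do not scale with resolution
    return torch.fft.rfft(f, norm="forward")[:n_modes]

coarse = low_modes(32)
fine = low_modes(64)
max_diff = (coarse - fine).abs().max().item()
print(max_diff < 1e-10)  # True
```

The layers in this notebook use the default FFT normalization; there the scale factors of the forward and inverse transforms cancel over the round trip.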

<img src="images/fno_architecture.png" width="800"/>


## FNO Components
The FNO architecture is not significantly different from that of traditional neural networks. It comprises three main components:

1. **Lifting (Encoding) Block**: Maps the input function to a higher-dimensional space.
    - This is most commonly an affine transformation (nn.Linear).
    - It *could* be more complex (e.g., MLP, Convolutions, etc.)
2. **Fourier Blocks**: Perform spectral convolutions and learn key information in latent space.
3. **Projection (Decoding) Block**: Maps the latent space back to the original function space.
    - This is commonly two affine transformations (with an activation in between).
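
As a minimal sketch of the lifting and projection blocks described above (Fourier blocks omitted), the snippet below applies `nn.Linear` along the channel axis of a channels-last tensor. The sizes `width=32` and `hidden=128` are illustrative choices.

```python
import torch
import torch.nn as nn

width, hidden = 32, 128                        # illustrative sizes

lifting = nn.Linear(3, width)                  # (a(x, y), x, y) -> width channels
projection = nn.Sequential(                    # two affine maps with an activation in between
    nn.Linear(width, hidden),
    nn.GELU(),
    nn.Linear(hidden, 1),
)

x = torch.randn(2, 16, 16, 3)                  # (batch, x, y, channels)
v = lifting(x)                                 # (2, 16, 16, 32)
u = projection(v)                              # (2, 16, 16, 1)
print(v.shape, u.shape)
```

Because `nn.Linear` acts on the last dimension only, the same weights are applied at every grid point, independent of the grid size.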

Here, we will start building the FNO architecture by implementing the spectral branch.

### Fourier Blocks

A Fourier block is a layer that has two branches:

1. *Spectral branch* - Learns global features via FFT
    - Transforms the input into Fourier space via FFT.
    - Applies learned weights in frequency space (spectral convolution)
    - Transforms the output back into the original space via inverse FFT.
2. *Spatial branch* - Learns local features via convolution

The outputs of the two branches are summed and passed through a nonlinear activation function (e.g., ReLU; the implementation below uses GELU).
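
Before the 2D implementation, the two branches can be sketched in one dimension. This is a simplified illustration under assumed names (`FourierBlock1DSketch` is hypothetical and not used elsewhere in the notebook); it mirrors the structure described above rather than reproducing the 2D code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FourierBlock1DSketch(nn.Module):
    """Illustrative 1D Fourier block: spectral branch + pointwise spatial branch."""

    def __init__(self, width, modes):
        super().__init__()
        self.modes = modes
        scale = 1 / (width * width)
        self.weights = nn.Parameter(scale * torch.rand(width, width, modes, dtype=torch.cfloat))
        self.w = nn.Conv1d(width, width, 1)            # spatial branch: pointwise convolution

    def forward(self, x):                              # x: (batch, width, n)
        x_ft = torch.fft.rfft(x)                       # to Fourier space: (batch, width, n//2 + 1)
        out_ft = torch.zeros_like(x_ft)                # modes above `modes` stay zero
        out_ft[..., :self.modes] = torch.einsum(
            "bix,iox->box", x_ft[..., :self.modes], self.weights)
        x1 = torch.fft.irfft(out_ft, n=x.size(-1))     # back to physical space
        x2 = self.w(x)
        return F.gelu(x1 + x2)                         # sum the branches, then activate

block = FourierBlock1DSketch(width=8, modes=4)
y = block(torch.randn(2, 8, 32))
y_fine = block(torch.randn(2, 8, 64))                  # same weights, finer grid
print(y.shape, y_fine.shape)
```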

#### Spectral Convolution


<img src="images/spectral_conv.png" width="800"/>
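
The spectral convolution keeps only the low-frequency modes. As a small illustration of where those modes sit in the output of `torch.fft.rfft2`, the sketch below lists the integer wavenumbers along each axis of an 8x8 grid; the grid size and the value of `modes1` are arbitrary.

```python
import torch

n, modes1 = 8, 2

# Integer wavenumbers along the first axis (full FFT) and the last axis (real FFT)
k_first = torch.fft.fftfreq(n, d=1.0 / n)     # 0, 1, 2, 3, -4, -3, -2, -1
k_last = torch.fft.rfftfreq(n, d=1.0 / n)     # 0, 1, 2, 3, 4

print(k_first[:modes1])     # lowest non-negative frequencies: 0, 1
print(k_first[-modes1:])    # lowest negative frequencies: -2, -1
print(k_last[:modes1])      # last axis stores non-negative frequencies only: 0, 1

x_ft = torch.fft.rfft2(torch.randn(1, 1, n, n))
print(x_ft.shape)           # torch.Size([1, 1, 8, 5])
```

This layout is why the 2D layer below uses two weight tensors: one for the first `modes1` rows and one for the last `modes1` rows.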

In [None]:
class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        """
        Spectral convolution layer. Performs FFT, linear transform, and Inverse FFT.

        Parameters:
        ---
        in_channels : int,
            Number of layer input channels
        out_channels : int
            Number of layer output channels
        modes1 : int
            Number of Fourier modes to keep in the first dimension
        modes2 : int
            Number of Fourier modes to keep in the second dimension
        """
        super(SpectralConv2d, self).__init__()

        # Number of input channels (features)
        self.in_channels = in_channels
        # Number of output channels
        self.out_channels = out_channels

        # Number of Fourier modes to multiply in each dimension. Maximum floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2

        # Initialize parameter weights
        self.scale = (1 / (self.in_channels * self.out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(self.in_channels, self.out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(self.in_channels, self.out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batchsize = x.shape[0]
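        # Compute the 2D real FFT.
        # Shape changes from (batch_size, channels, x, y) to (batch_size, channels, x, y//2 + 1):
        # for real input the spectrum is Hermitian symmetric, so rfft2 stores only the non-negative frequencies of the last dimension.
        # The first spatial dimension keeps the full spectrum, so its low-frequency modes sit in the first modes1 rows
        # (non-negative frequencies) and the last modes1 rows (negative frequencies).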
        x_ft = torch.fft.rfft2(x)
        # Initialize output FFT tensor
        out_ft = torch.zeros(batchsize, self.out_channels,
                             x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        # Perform complex multiplication on lower Fourier modes
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.complex_multiplication2d(
                x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.complex_multiplication2d(
                x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x

    @staticmethod
    def complex_multiplication2d(a, b):
        """
        Complex 2D multiplication between input tensor `a` and weight tensor `b`.

        Parameters:
        ---
        a : torch.Tensor
            Complex input tensor of shape (batch_size, in_channels, x, y)
        b : torch.Tensor
            Complex weight tensor of shape (in_channels, out_channels, x, y)

        Returns:
        ---
        torch.Tensor
            Complex output tensor of shape (batch_size, out_channels, x, y)
        """
        # Sum over in_channels dimension
        return torch.einsum('bixy,ioxy->boxy', a, b)
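
The `einsum` above applies an independent `in_channels x out_channels` matrix to every retained Fourier mode. The sketch below checks that reading against an explicit loop; the tensor sizes are arbitrary and the variable names are illustrative.

```python
import torch

batch, c_in, c_out, m1, m2 = 2, 3, 4, 5, 5
a = torch.randn(batch, c_in, m1, m2, dtype=torch.cfloat)    # truncated Fourier coefficients
b = torch.randn(c_in, c_out, m1, m2, dtype=torch.cfloat)    # per-mode complex weights

out = torch.einsum("bixy,ioxy->boxy", a, b)

# Same computation written as an explicit matrix product for every mode (x, y)
ref = torch.zeros(batch, c_out, m1, m2, dtype=torch.cfloat)
for x in range(m1):
    for y in range(m2):
        ref[:, :, x, y] = a[:, :, x, y] @ b[:, :, x, y]     # (batch, c_in) @ (c_in, c_out)

max_err = (out - ref).abs().max().item()
print(max_err < 1e-4)
```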

#### Fourier Layer

<img src="images/fno_block.png" width="800"/>

In [None]:
class FourierBlock2D(nn.Module):
    """
    Single FNO block with spectral convolution, Conv2D, and activation function.

    This block performs the following operations:
    1. Applies a Fourier-based spectral convolution.
    2. Applies a 1x1 convolution in the spatial domain.
    3. Sums the outputs of the two branches.
    4. Applies a GELU activation function.

    Parameters:
    ---
        width: int,
            Number of block input/output channels.
        modes1: int,
            Number of Fourier modes to use in the first dimension.
        modes2: int,
            Number of Fourier modes to use in the second dimension.
    """

    def __init__(self, width, modes1, modes2):
        super(FourierBlock2D, self).__init__()
        self.spectral_conv = SpectralConv2d(width, width, modes1, modes2)
        self.w = nn.Conv2d(width, width, 1)

    def forward(self, x):
        x1 = self.spectral_conv(x)
        x2 = self.w(x)
        x = x1 + x2
        x = F.gelu(x)
        return x
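
The spatial branch uses a convolution with kernel size 1, which is the same as applying one linear layer to the channel vector at every grid point. The sketch below verifies this by copying the convolution parameters into an `nn.Linear`; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

width = 4
conv = nn.Conv2d(width, width, kernel_size=1)
linear = nn.Linear(width, width)

# Copy the convolution parameters into the linear layer
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(width, width))      # (out, in, 1, 1) -> (out, in)
    linear.bias.copy_(conv.bias)

x = torch.randn(2, width, 8, 8)                               # (batch, channels, x, y)
y_conv = conv(x)
y_lin = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)     # apply along the channel axis
print(torch.allclose(y_conv, y_lin, atol=1e-5))
```

Because this branch has no spatial extent, it does not tie the block to a particular grid resolution.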

### Full 2D Architecture

In [None]:
class FNO2D(nn.Module):
    """
    Full FNO network. It contains `num_layers` FNO blocks.

    This network performs the following operations:
    1. Lift the input channels to the desired number of channels
    2. Perform `num_layers` layers of the integral operators v' = (W + K)(v)
    3. Project the channel space to the output space

    Input:
    ---
        torch.Tensor, shape (batch_size, x, y, channels=1)
            Coefficients or initial condition a(x, y). The grid coordinates (x, y) are
            appended inside `forward`, giving 3 channels (a(x, y), x, y) before lifting.
    Output:
    ---
        torch.Tensor, shape (batch_size, x, y, channels=1)
            Predicted solution
    """

    def __init__(self, net_name="FNO2D", width=32, num_layers=4, modes1=8, modes2=8, lr=5e-4, hidden_p_channels=128):
        """
        Parameters:
        ---
            net_name: str,
                Name of the network. Default = "FNO2D".
            width: int,
                Number of higher-dimensional channels. Default = 32.
            num_layers: int,
                Number of FNO blocks in the network. Default = 4.
            modes1: int,
                Number of Fourier modes to use in the first dimension. Default = 8.
            modes2: int,
                Number of Fourier modes to use in the second dimension. Default = 8.
            lr: float,
                Learning rate, stored as an attribute. Default = 5e-4.
            hidden_p_channels: int,
                Number of channels for the hidden layer in the projecting step. Default = 128.
        """
        super(FNO2D, self).__init__()
        self.net_name = net_name
        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.num_layers = num_layers
        self.hidden_p_channels = hidden_p_channels
        self.lr = lr
        self.padding = 6

        # Define affine transformation to lift 3 channels to `width` channels
        self.p = nn.Linear(3, self.width)

        # Define a list of FourierBlock2D layers
        self.fno_blocks = nn.ModuleList([
            FourierBlock2D(self.width, self.modes1, self.modes2) for _ in range(self.num_layers)
        ])

        # Define affine transformations to project the channel space to the output space
        self.q1 = nn.Linear(self.width, self.hidden_p_channels)
        self.q2 = nn.Linear(self.hidden_p_channels, 1)

    def forward(self, x):
        # Get the grid of x
        grid = self.get_grid(x.shape, x.device)
        # Concatenate the grid to the input tensor
        x = torch.cat((x, grid), dim=-1)

        # Lift input
        x = self.p(x)
        # Permute the dimensions from (batch_size, x, y, channels) to (batch_size, channels, x, y)
        # nn.Linear operates on the last dimension
        x = x.permute(0, 3, 1, 2)

        # Perform Fourier-based spectral convolution and activation
        for layer in self.fno_blocks:
            x = layer(x)

        # Permute the dimensions back from (batch_size, channels, x, y) to (batch_size, x, y, channels)
        x = x.permute(0, 2, 3, 1)

        # Project the channel space to the output space
        x = self.q1(x)
        x = F.gelu(x)
        x = self.q2(x)

        return x

    @staticmethod
    def get_grid(shape, device):
        batchsize, size_x, size_y = shape[:-1]
        # Create grid for x and y coordinates using PyTorch
        gridx = torch.linspace(0, 1, steps=size_x).reshape(
            1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.linspace(0, 1, steps=size_y).reshape(
            1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)
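
The grid construction can also be written with `torch.meshgrid`. The sketch below is an alternative formulation under an assumed helper name (`make_grid_2d` is hypothetical), shown together with the concatenation step that produces the 3-channel input to the lifting layer.

```python
import torch

def make_grid_2d(batchsize, size_x, size_y):
    """Hypothetical helper: (batch, x, y, 2) tensor of coordinates in [0, 1]."""
    xs = torch.linspace(0, 1, steps=size_x)
    ys = torch.linspace(0, 1, steps=size_y)
    gx, gy = torch.meshgrid(xs, ys, indexing="ij")          # each (size_x, size_y)
    grid = torch.stack((gx, gy), dim=-1)                    # (size_x, size_y, 2)
    return grid.unsqueeze(0).expand(batchsize, -1, -1, -1)  # broadcast over the batch

a = torch.randn(4, 16, 32, 1)                               # (batch, x, y, 1) input field
grid = make_grid_2d(*a.shape[:-1])
x = torch.cat((a, grid), dim=-1)                            # (batch, x, y, 3), ready for lifting
print(x.shape)
```

Appending coordinates lets the pointwise lifting layer distinguish grid locations, which it could not do from the field value alone.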

## 3D Architecture

This architecture applies to both 2D + time and 3D static problems.
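
For a 2D + time problem, the third spatial axis is used for time. The notebook does not prescribe a data layout, so the sketch below shows only one possible arrangement: the first `t_in` snapshots are treated as channels and repeated along a time axis of length `t_out`. The names `t_in` and `t_out` and their values are assumptions for illustration.

```python
import torch

batch, size_x, size_y = 2, 16, 16
t_in, t_out = 10, 20        # assumed numbers of input and predicted time steps

u_in = torch.randn(batch, size_x, size_y, t_in)              # first t_in snapshots
# Repeat the input snapshots along a new time axis of length t_out,
# so that they become the channels seen at every output time.
x = u_in.unsqueeze(3).repeat(1, 1, 1, t_out, 1)              # (batch, x, y, t, channels=t_in)
print(x.shape)  # torch.Size([2, 16, 16, 20, 10])
```

With this layout, a tensor of shape `(batch, x, y, t, channels)` matches the channels-last input that the 3D model's `forward` expects.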

### Spectral Convolution (3D)

In [None]:
class SpectralConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        """
        Spectral convolution layer. Performs FFT, linear transform, and Inverse FFT.

        Parameters:
        ---
        in_channels : int
            Number of layer input channels
        out_channels : int
            Number of layer output channels
        modes1 : int
            Number of Fourier modes to keep in the first dimension
        modes2 : int
            Number of Fourier modes to keep in the second dimension
        modes3 : int
            Number of Fourier modes to keep in the third dimension
        """
        super(SpectralConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Number of Fourier modes to multiply in each dimension. Maximum floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))

    # Complex multiplication
    @staticmethod
    def complex_multiplication3d(a, b):
        # (batch, in_channel, x,y,z), (in_channel, out_channel, x,y,z) -> (batch, out_channel, x,y,z)
        return torch.einsum("bixyz,ioxyz->boxyz", a, b)

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute the 3D real FFT.
        # Shape changes from (batch_size, channels, x, y, z) to (batch_size, channels, x, y, z//2 + 1).
        # The full spectrum is kept along the first two spatial dimensions, so their low-frequency modes sit in the first
        # and last modes1/modes2 entries; along the last dimension only the first modes3 entries are needed.
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            self.complex_multiplication3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            self.complex_multiplication3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            self.complex_multiplication3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            self.complex_multiplication3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x
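
The four weight tensors correspond to the four low-frequency corners of the `rfftn` output. The sketch below marks those corners in a boolean mask and checks that they coincide with the set of small integer wavenumbers; the grid size `n` and mode count `m` are arbitrary.

```python
import torch

n, m = 16, 4                                   # grid size per axis and retained modes (illustrative)
x_ft = torch.fft.rfftn(torch.randn(1, 1, n, n, n), dim=[-3, -2, -1])
print(x_ft.shape)                              # torch.Size([1, 1, 16, 16, 9])

# Mark the four blocks that the layer multiplies by weights1..weights4
mask = torch.zeros(n, n, n // 2 + 1, dtype=torch.bool)
mask[:m, :m, :m] = True
mask[-m:, :m, :m] = True
mask[:m, -m:, :m] = True
mask[-m:, -m:, :m] = True

# Same region expressed with integer wavenumbers: -m <= kx, ky < m and 0 <= kz < m
kx = torch.fft.fftfreq(n, d=1.0 / n)[:, None, None]
ky = torch.fft.fftfreq(n, d=1.0 / n)[None, :, None]
kz = torch.fft.rfftfreq(n, d=1.0 / n)[None, None, :]
low = (kx >= -m) & (kx < m) & (ky >= -m) & (ky < m) & (kz < m)
print(torch.equal(mask, low))                  # True
```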

### FNO Layer (3D)

In [None]:
class FourierBlock3D(nn.Module):
    """
    Single FNO block with spectral convolution, Conv3D, and activation function.

    This block performs the following operations:
    1. Applies a Fourier-based spectral convolution.
    2. Applies a 1x1x1 convolution in the spatial domain.
    3. Sums the outputs of the two branches.
    4. Applies a GELU activation function.

    Parameters:
    ---
        width: int,
            Number of block input/output channels.
        modes1: int,
            Number of Fourier modes to use in the first dimension.
        modes2: int,
            Number of Fourier modes to use in the second dimension.
        modes3: int,
            Number of Fourier modes to use in the third dimension.
    """
    def __init__(self, width, modes1, modes2, modes3):
        super(FourierBlock3D, self).__init__()
        self.spectral_conv = SpectralConv3d(width, width, modes1, modes2, modes3)
        self.w = nn.Conv3d(width, width, 1)

    def forward(self, x):
        x1 = self.spectral_conv(x)
        x2 = self.w(x)
        x = x1 + x2
        x = F.gelu(x)
        return x

### Full 3D Architecture

In [None]:
class FNO3D(pl.LightningModule):
    def __init__(self,
                 net_name='FNO3DModel',
                 in_channels=10,
                 out_channels=3,
                 modes1=8,
                 modes2=8,
                 modes3=8,
                 width=20,
                 num_layers=4,
                 lr=1e-3,
                 ):

        super(FNO3D, self).__init__()
        self.net_name = net_name

        self.input_channels = in_channels
        self.output_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.num_layers = num_layers
        self.lr = lr

        self.padding = 6

        self.p = nn.Linear(self.input_channels + 3, self.width)  # input features: in_channels values plus the 3 grid coordinates (x, y, z)

        self.fno_blocks = nn.ModuleList([
            FourierBlock3D(self.width, self.modes1, self.modes2, self.modes3) for _ in range(self.num_layers)
        ])

        self.q1 = nn.Linear(self.width, self.width * 4)
        self.q2 = nn.Linear(self.width * 4, self.output_channels)

        self.save_hyperparameters()

    def forward(self, x):
        # Get the grid of x
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        # Lift input
        x = self.p(x)

        # Permute the dimensions from (batch_size, x, y, z, channels) to (batch_size, channels, x, y, z)
        # nn.Linear operates on the last dimension
        x = x.permute(0, 4, 1, 2, 3)
        # Zero-pad the end of the last dimension (z or time) before the Fourier blocks
        x = F.pad(x, [0, self.padding])

        for layer in self.fno_blocks:
            x = layer(x)

        # Remove the padding added before the Fourier blocks
        x = x[..., :-self.padding]
        x = x.permute(0, 2, 3, 4, 1)
        x = self.q1(x)
        x = F.gelu(x)
        x = self.q2(x)

        return x.float()

    @staticmethod
    def get_grid(shape, device):
        batchsize, size_x, size_y, size_z = shape[:-1]
        gridx = torch.linspace(0, 1, steps=size_x, dtype=torch.float, device=device)
        gridy = torch.linspace(0, 1, steps=size_y, dtype=torch.float, device=device)
        gridz = torch.linspace(0, 1, steps=size_z, dtype=torch.float, device=device)

        gridx = gridx.view(1, size_x, 1, 1, 1).expand(batchsize, -1, size_y, size_z, -1)
        gridy = gridy.view(1, 1, size_y, 1, 1).expand(batchsize, size_x, -1, size_z, -1)
        gridz = gridz.view(1, 1, 1, size_z, 1).expand(batchsize, size_x, size_y, -1, -1)
        return torch.cat((gridx, gridy, gridz), dim=-1)
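
The `forward` method pads the last dimension before the Fourier blocks and crops it afterwards. A common motivation is that the FFT treats its input as periodic, so padding helps when the last axis (for example, time) is not periodic. The sketch below isolates the pad and crop steps on a small tensor to show which axis is affected; the shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

padding = 6
x = torch.randn(1, 2, 4, 4, 5)                  # (batch, channels, x, y, z)

x_pad = F.pad(x, [0, padding])                  # pad only the end of the last dimension
print(x_pad.shape)                              # torch.Size([1, 2, 4, 4, 11])

x_crop = x_pad[..., :-padding]                  # remove the padded slices again
print(torch.equal(x_crop, x))                   # True
```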