<a href="https://colab.research.google.com/github/jaysulk/AFNO-transformer/blob/master/Experiment_1_Hartlrey_vs_Fourier_Neural_Operator_on_2D_Burger's_PDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Experiment 1: Hartlrey Neural Operator on 2D Burger's PDE


###Imports

In [None]:
from google.colab import drive

import os
import math
import pickle
from datetime import timedelta
from timeit import default_timer as timer

import numpy as np
import matplotlib.pyplot as plt
from skimage.filters import gaussian
from tqdm import tqdm
import psutil  # For CPU memory tracking

import torch
from torch import nn, Tensor, vmap
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from typing import List, Optional

###Set CUDA

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
device

device(type='cuda', index=0)

###Mount Google Drive for data and figure saving

In [None]:
drive.mount('/content/drive')

Mounted at /content/drive


## Neural Network

In [None]:
class SpectralConv3d(nn.Module):
    """
    A spectral convolution layer that applies the Discrete Hartley Transform (DHT)
    to perform convolution in the frequency domain for 3D tensors.

    This layer takes 3D input tensors and applies a convolution operation using a
    specified number of Hartley modes, incorporating a phase factor W[k].

    Attributes:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of Hartley modes to use in the first dimension.
        modes2 (int): Number of Hartley modes to use in the second dimension.
        modes3 (int): Number of Hartley modes to use in the third dimension.
        scale (float): Scaling factor for the weights initialization.
        weights1 (nn.Parameter): Learnable weights for the convolution, initialized randomly.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        """
        Initialize the SpectralConv3d layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            modes1 (int): Number of Hartley modes in the first dimension.
            modes2 (int): Number of Hartley modes in the second dimension.
            modes3 (int): Number of Hartley modes in the third dimension.
        """
        super(SpectralConv3d, self).__init__()  # Call to the parent class constructor
        self.in_channels = in_channels  # Store the number of input channels
        self.out_channels = out_channels  # Store the number of output channels
        self.modes1 = modes1  # Store the number of modes for the first dimension
        self.modes2 = modes2  # Store the number of modes for the second dimension
        self.modes3 = modes3  # Store the number of modes for the third dimension

        self.scale = (1 / (in_channels * out_channels))  # Compute scaling factor
        self.weights1 = nn.Parameter(  # Initialize learnable weights
            self.scale * torch.rand(
                in_channels, out_channels, self.modes1, self.modes2, self.modes3
            )
        )

    def dht_3d(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the Discrete Hartley Transform (DHT) on a 3D tensor.

        Args:
            x (torch.Tensor): Input tensor of arbitrary shape. The last three dimensions
                              are assumed to be the dimensions to transform.

        Returns:
            torch.Tensor: Transformed tensor after applying the DHT.
        """
        # Transform over the last three dimensions
        transform_dims = [-3, -2, -1]
        # Compute the FFT without normalization
        X = torch.fft.fftn(x, dim=transform_dims)
        # Compute the DHT by combining real and imaginary parts
        X = X.real - X.imag
        return X

    def idht_3d(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the Inverse Discrete Hartley Transform (IDHT) on a 3D tensor.

        Args:
            x (torch.Tensor): Input tensor of arbitrary shape. The last three dimensions
                              are assumed to be the dimensions to transform.

        Returns:
            torch.Tensor: Inverse transformed tensor after applying the IDHT.
        """
        # The inverse DHT is the same as the DHT
        X = self.dht_3d(x)
        # Multiply by the scaling factor
        scaling_factor = x.shape[-3] * x.shape[-2] * x.shape[-1]
        X = X / scaling_factor
        return X

    def compl_mul3d(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Perform Hartley convolution by flipping, rolling, and combining in the Hartley domain.

        This method follows the convolution approach:
            Z = 0.5 * (X * (Y + Y_flip) + X_flip * (Y - Y_flip))

        Args:
            x1 (torch.Tensor): Hartley transformed input tensor of shape (batch_size, in_channels, D, H, W).
            x2 (torch.Tensor): Hartley transformed weights tensor of shape (in_channels, out_channels, D, H, W).

        Returns:
            torch.Tensor: Hartley transformed output tensor of shape (batch_size, out_channels, D, H, W).
        """
        # Flip and roll x1 and x2 along the spatial dimensions
        # Flip along depth, height, and width (last three dimensions)
        x1_flip = torch.roll(torch.flip(x1, dims=[-3, -2, -1]), shifts=(1, 1, 1), dims=[-3, -2, -1])
        x2_flip = torch.roll(torch.flip(x2, dims=[-3, -2, -1]), shifts=(1, 1, 1), dims=[-3, -2, -1])

        # Compute Y_plus and Y_minus
        Y_plus = x2 + x2_flip  # Shape: (in_channels, out_channels, D, H, W)
        Y_minus = x2 - x2_flip  # Shape: (in_channels, out_channels, D, H, W)

        # Perform element-wise multiplication and sum over in_channels
        # Expand dimensions for broadcasting
        # x1: (batch_size, in_channels, D, H, W) -> (batch_size, in_channels, 1, D, H, W)
        # Y_plus, Y_minus: (in_channels, out_channels, D, H, W) -> (1, in_channels, out_channels, D, H, W)
        X = x1.unsqueeze(2)  # (batch_size, in_channels, 1, D, H, W)
        X_flip = x1_flip.unsqueeze(2)  # (batch_size, in_channels, 1, D, H, W)

        # Compute the products
        term1 = X * Y_plus  # (batch_size, in_channels, out_channels, D, H, W)
        term2 = X_flip * Y_minus  # (batch_size, in_channels, out_channels, D, H, W)

        # Combine the terms
        Z = 0.5 * (term1 + term2)  # (batch_size, in_channels, out_channels, D, H, W)

        # Sum over the input channels to get the final output
        Z = Z.sum(dim=1)  # (batch_size, out_channels, D, H, W)

        return Z


#    def compl_mul3d(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Perform the complex multiplication of two 3D tensors in the frequency domain.

        This function computes the product of two complex-valued tensors using a specific
        method that involves flipping the tensors and performing Einstein summation.

        Args:
            x1 (torch.Tensor): The first input tensor of shape (batch_size, in_channels, depth, height, width).
            x2 (torch.Tensor): The second input tensor of the same shape as x1.

        Returns:
            torch.Tensor: The result of the complex multiplication, with shape
                          (batch_size, out_channels, depth, height, width).
        """
#        X1_H_k = x1  # Forward tensor for x1
#        X2_H_k = x2  # Forward tensor for x2
#        X1_H_neg_k = self.flip_periodic_3d(x1)  # Flipped tensor for x1
#        X2_H_neg_k = self.flip_periodic_3d(x2)  # Flipped tensor for x2

        # Calculate the result using Einstein summation for complex multiplication
#        result = 0.5 * (
#            torch.einsum('bixyz,ioxyz->boxyz', X1_H_k, X2_H_k) -  # Term 1: Forward multiplication
#            torch.einsum('bixyz,ioxyz->boxyz', X1_H_neg_k, X2_H_neg_k) +  # Term 2: Flipped multiplication
#            torch.einsum('bixyz,ioxyz->boxyz', X1_H_k, X2_H_neg_k) +  # Term 3: Mixed multiplication
#            torch.einsum('bixyz,ioxyz->boxyz', X1_H_neg_k, X2_H_k)  # Term 4: Mixed multiplication
#        )

#        return result  # Return the result of the complex multiplication



    def flip_periodic_3d(self, x: torch.Tensor) -> torch.Tensor:
        """
        Flip the input tensor along all three dimensions, maintaining periodicity.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, depth, height, width).

        Returns:
            torch.Tensor: Flipped tensor of the same shape.
        """
        # Flipping operation to maintain periodicity
        return torch.flip(x, dims=[2, 3, 4])  # Flip along depth, height, and width

    def forward(self, x):
        """
        Forward pass of the spectral convolution layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, depth, height, width).

        The forward pass consists of the following steps:
            1. Compute the batch size and input tensor sizes.
            2. Apply the Discrete Hartley Transform (DHT) to the input tensor to obtain the Hartley coefficients.
            3. Compute the phase factor W[k] for each relevant frequency mode.
            4. Apply the phase factor to the Hartley coefficients.
            5. Initialize the output tensor for the transformed coefficients.
            6. Perform complex multiplication of the relevant Hartley modes using the learned weights.
            7. Apply the Inverse DHT (IDHT) to transform the output back to physical space.
        """
        batchsize = x.shape[0]  # Get the batch size
        size1, size2, size3 = x.shape[-3], x.shape[-2], x.shape[-1]  # Get the sizes of the input tensor

        # Compute Hartley coefficients
        x_ht = self.dht_3d(x)  # Apply the DHT to obtain Hartley coefficients

        # Compute the phase factor W[k] for each mode in all three dimensions
        # W[k] = cos(2 * pi * k / N) + sin(2 * pi * k / N)
        # where k is the mode index and N is the size in that dimension

        # Compute mode indices
        k1 = torch.arange(self.modes1, device=x.device).unsqueeze(1).unsqueeze(2)  # Shape: (modes1, 1, 1)
        k2 = torch.arange(self.modes2, device=x.device).unsqueeze(0).unsqueeze(2)  # Shape: (1, modes2, 1)
        k3 = torch.arange(self.modes3, device=x.device).unsqueeze(0).unsqueeze(1)  # Shape: (1, 1, modes3)

        # Compute theta for each dimension
        theta1 = 2 * np.pi * k1 / size1  # Shape: (modes1, 1, 1)
        theta2 = 2 * np.pi * k2 / size2  # Shape: (1, modes2, 1)
        theta3 = 2 * np.pi * k3 / size3  # Shape: (1, 1, modes3)

        # Compute W for each dimension
        W1 = torch.cos(theta1) + torch.sin(theta1)  # Shape: (modes1, 1, 1)
        W2 = torch.cos(theta2) + torch.sin(theta2)  # Shape: (1, modes2, 1)
        W3 = torch.cos(theta3) + torch.sin(theta3)  # Shape: (1, 1, modes3)

        # Combine W for all three dimensions using broadcasting
        W = W1 * W2 * W3  # Shape: (modes1, modes2, modes3)

        # Reshape W to match the Hartley coefficients tensor shape
        W = W.view(1, 1, self.modes1, self.modes2, self.modes3)  # Shape: (1, 1, modes1, modes2, modes3)

        # Apply the phase factor to the Hartley coefficients
        x_ht = x_ht[:, :, :self.modes1, :self.modes2, :self.modes3] * W  # Element-wise multiplication

        # Initialize the output tensor for the transformed coefficients
        out_ht = torch.zeros(
            batchsize,
            self.out_channels,
            size1,
            size2,
            size3,
            device=x.device,  # Ensure the output tensor is on the same device as the input
            dtype=x.dtype  # Ensure the output tensor has the same data type as the input
        )

        # Perform complex multiplication on the relevant Hartley modes
        out_ht[:, :, :self.modes1, :self.modes2, :self.modes3] = self.compl_mul3d(
            x_ht,  # Hartley coefficients after applying phase factor
            self.weights1  # Learnable weights
        )

        # Return to physical space using the Inverse DHT
        x = self.idht_3d(out_ht)  # Apply the IDHT to get the output in physical space

        return x  # Return the output

In [None]:
class FNN3d(nn.Module):
    """
    A 3D Feedforward Neural Network (FNN) that utilizes spectral convolutions
    and fully connected layers to process 3D input data.

    Attributes:
        modes1 (list of int): List of maximal modes for the first dimension for each layer.
        modes2 (list of int): List of maximal modes for the second dimension for each layer.
        modes3 (list of int): List of maximal modes for the third dimension for each layer.
        in_dim (int): Input dimension.
        out_dim (int): Output dimension.
        padding (tuple): Padding values for each dimension.
        layers (list of int): List defining the number of channels in each layer.
        fc0 (nn.Linear): Fully connected layer for initial input processing.
        sp_convs (nn.ModuleList): List of spectral convolution layers.
        ws (nn.ModuleList): List of 1D convolution layers for processing the output of spectral convolutions.
        fc1 (nn.Linear): Fully connected layer for intermediate output processing.
        fc2 (nn.Linear): Final fully connected layer to produce output.
        activations (nn.ModuleList or callable): Activation functions to be used in the network.
    """

    def __init__(self, modes1, modes2, modes3, width=16, fc_dim=128, layers=None, in_dim=4, out_dim=1, activation='tanh', pad_x=0, pad_y=0, pad_z=0):
        """
        Initialize the FNN3d network.

        Args:
            modes1 (list of int): First dimension maximal modes for each layer.
            modes2 (list of int): Second dimension maximal modes for each layer.
            modes3 (list of int): Third dimension maximal modes for each layer.
            width (int): Width of the layers if layers is not provided.
            fc_dim (int): Dimension of the fully connected layers.
            layers (list of int): List of integers defining channels for each layer.
            in_dim (int): Input dimension.
            out_dim (int): Output dimension.
            activation (str): Activation function to be used ('tanh', 'gelu', 'relu', 'elu', 'swish', 'leaky_relu', 'prelu', 'sine', 'sinc').
            pad_x (int): Padding for the x dimension.
            pad_y (int): Padding for the y dimension.
            pad_z (int): Padding for the z dimension.
        """
        super(FNN3d, self).__init__()  # Call to the parent class constructor
        self.modes1 = modes1  # Store the modes for the first dimension
        self.modes2 = modes2  # Store the modes for the second dimension
        self.modes3 = modes3  # Store the modes for the third dimension
        self.in_dim = in_dim  # Store input dimension
        self.out_dim = out_dim  # Store output dimension
        self.padding = (0, 0, 0, pad_z, 0, pad_y, 0, pad_x)  # Define padding for each dimension

        # Initialize layers
        if layers is None:
            self.layers = [width] * 3  # Default to a list of three layers if none provided
        else:
            self.layers = layers  # Use provided layers
        self.fc0 = nn.Linear(self.in_dim, self.layers[0])  # First fully connected layer

        # Create spectral convolution layers
        self.sp_convs = nn.ModuleList([SpectralConv3d(
            in_size, out_size, mode1_num, mode2_num, mode3_num)
            for in_size, out_size, mode1_num, mode2_num, mode3_num
            in zip(self.layers, self.layers[1:], self.modes1, self.modes2, self.modes3)])  # List of spectral convolution layers

        # Create 1D convolution layers
        self.ws = nn.ModuleList([nn.Conv1d(in_size, out_size, 1)
                                 for in_size, out_size in zip(self.layers, self.layers[1:])])  # List of 1D convolutions for residual connections

        self.fc1 = nn.Linear(self.layers[-1], fc_dim)  # Fully connected layer for intermediate processing
        self.fc2 = nn.Linear(fc_dim, self.out_dim)  # Final fully connected layer to produce output

        # Initialize activations
        if activation == 'tanh':
            self.activation = F.tanh  # Tanh activation
        elif activation == 'gelu':
            self.activation = F.gelu  # GELU activation
        elif activation == 'relu':
            self.activation = F.relu  # ReLU activation
        elif activation == 'elu':
            self.activation = F.elu  # ELU activation
        elif activation == 'swish':
            self.activation = self.swish  # Swish activation
        elif activation == 'leaky_relu':
            self.activation = F.leaky_relu  # Leaky ReLU activation
        elif activation == 'prelu':
            self.activation = nn.PReLU()  # PReLU activation
        elif activation == 'sine':
            # Create a ModuleList of SineActivationLearnable for each layer
            self.activations = nn.ModuleList([SineActivationLearnable(num_features=layer) for layer in self.layers[1:]])
        elif activation == 'sinc':
            self.activation = self.sinc  # Sinc activation
        else:
            raise ValueError(f'{activation} is not supported')  # Raise error for unsupported activation functions

    def swish(self, x):
        """Swish activation function: x * sigmoid(x)"""
        return x * torch.sigmoid(x)

    def sinc(self, x):
        """Sinc activation function: sinc(x) = sin(x) / x (with sinc(0) = 1)"""
        return torch.where(x != 0, torch.sin(x) / x, torch.ones_like(x))

    def forward(self, x):
        """
        Forward pass of the FNN3d network.

        Args:
            x (torch.Tensor): Input tensor of shape (batchsize, x_grid, y_grid, z_grid, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batchsize, x_grid, y_grid, z_grid, out_dim).

        The forward pass consists of the following steps:
            1. Determine the batch size and dimensions of the input tensor.
            2. Apply padding to the input tensor.
            3. Process the input through the first fully connected layer.
            4. Permute the dimensions of the tensor for compatibility with spectral convolutions.
            5. Sequentially apply spectral convolutions and 1D convolutions,
               with activation functions in between.
            6. Permute the output back to original dimensions.
            7. Pass through fully connected layers to produce final output.
            8. Reshape the output tensor to ensure dimensions are as expected.
        """
        length = len(self.ws)  # Get the number of 1D convolution layers
        batchsize = x.shape[0]  # Get the batch size
        nx, ny, nz = x.shape[1], x.shape[2], x.shape[3]  # Get the dimensions of the input tensor
        x = F.pad(x, self.padding, "constant", 0)  # Pad the input tensor
        size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3]  # Get the new dimensions after padding

        x = self.fc0(x)  # Process through the first fully connected layer
        x = x.permute(0, 4, 1, 2, 3)  # Permute dimensions for spectral convolutions

        # Sequentially apply spectral and 1D convolutions
        for i, (speconv, w) in enumerate(zip(self.sp_convs, self.ws)):
            x1 = speconv(x)  # Apply spectral convolution
            x2 = w(x.view(batchsize, self.layers[i], -1)).view(batchsize, self.layers[i+1], size_x, size_y, size_z)  # Apply 1D convolution
            x = x1 + x2  # Combine the results
            if i != length - 1:
                if hasattr(self, 'activations'):
                    x = self.activations[i](x)  # Apply SineActivationLearnable
                else:
                    x = self.activation(x)  # Apply standard activation function

        x = x.permute(0, 2, 3, 4, 1)  # Permute back to original dimensions
        x = self.fc1(x)  # Process through the first fully connected layer
        if hasattr(self, 'activations'):
            # If activations are defined as ModuleList (e.g., 'sine'), apply activations to fc1 output
            # Assuming fc1 output has shape (batchsize, x_grid, y_grid, z_grid, fc_dim)
            # We'll apply activation across the last dimension (fc_dim)
            # Reshape to (batchsize * x_grid * y_grid * z_grid, fc_dim) for activation
            x = x.view(-1, x.shape[-1])
            x = self.activations[-1](x)
            x = x.view(batchsize, size_x, size_y, size_z, -1)
        else:
            x = self.activation(x)  # Apply activation function
        x = self.fc2(x)  # Process through the final fully connected layer
        x = x.reshape(batchsize, size_x, size_y, size_z, self.out_dim)  # Ensure output dimensions are correct
        x = x[..., :nx, :ny, :nz]  # Remove any padding before returning the output
        return x  # Return the final output tensor


In [None]:
def FDM_Burgers2D(u, D=1, nu=0.01):
    """
    Compute the right-hand side of the 2D Burgers' equation using finite difference methods.

    This function calculates the spatial and temporal derivatives of the input tensor
    representing the velocity field in a 2D Burgers' equation, including the effects
    of nonlinearity and diffusion.

    Args:
        u (torch.Tensor): Input tensor of shape (batchsize, nx, ny, nt) representing the velocity field.
        D (float): Domain size (length) of the spatial grid.
        nu (float): Viscosity coefficient for the diffusion term.

    Returns:
        torch.Tensor: Derivative tensor representing the changes in the velocity field.
    """
    batchsize = u.size(0)  # Number of samples in the batch
    nx = u.size(1)  # Number of grid points in the x direction
    ny = u.size(2)  # Number of grid points in the y direction
    nt = u.size(3)  # Number of time steps
    u = u.reshape(batchsize, nx, ny, nt)  # Reshape u to 4D tensor

    dt = D / (nt - 1)  # Time step size
    dx = D / nx  # Grid spacing in the x direction
    dy = D / ny  # Grid spacing in the y direction

    # Calculate spatial derivatives using torch.gradient
    u_x, u_y = torch.gradient(u, dim=[1, 2])  # First derivatives in x and y directions

    # Calculate second derivatives (Laplacian)
    u_xx, _ = torch.gradient(u_x, dim=[1, 2])  # Second derivative in x
    _, u_yy = torch.gradient(u_y, dim=[1, 2])  # Second derivative in y

    # Time derivative (central difference)
    ut = (u[..., 2:] - u[..., :-2]) / (2 * dt)  # Central difference in time

    # Compute the nonlinear term and diffusion
    Du = ut + (0.5 * (u_x ** 2 + u_y ** 2) - nu * (u_xx + u_yy))[..., 1:-1]  # Combine results

    return Du  # Return the computed derivatives

In [None]:
def PINO_loss_burgers2D(u, u0, nu=0.01):
    """
    Calculate the loss for the PINO model applied to the 2D Burgers' equation.

    This function computes the initial condition loss and the loss associated with
    the right-hand side of the Burgers' equation.

    Args:
        u (torch.Tensor): Predicted tensor of shape (batchsize, nx, ny, nt) representing the velocity field.
        u0 (torch.Tensor): Initial condition tensor of shape (batchsize, nx, ny) to compare against.
        nu (float): Viscosity coefficient for the diffusion term.

    Returns:
        tuple: A tuple containing the initial condition loss and the loss for the right-hand side of the equation.
    """
    batchsize = u.size(0)  # Number of samples in the batch
    nx = u.size(1)  # Number of grid points in the x direction
    ny = u.size(2)  # Number of grid points in the y direction
    nt = u.size(3)  # Number of time steps
    u = u.reshape(batchsize, nx, ny, nt)  # Reshape u to 4D tensor

    lploss = F.mse_loss  # Use Mean Squared Error loss for initial condition
    u_ic = u[..., 0]  # Extract initial condition from predicted velocity field
    loss_ic = lploss(u_ic, u0)  # Compute loss against the true initial condition

    # Compute the derivatives using the finite difference method
    Du = FDM_Burgers2D(u, nu=nu)  # Get the derivative tensor for the current velocity field
    f = torch.zeros(Du.shape, device=u.device)  # Target tensor for the PDE right-hand side
    loss_f = F.mse_loss(Du, f)  # Compute the loss for the right-hand side of the equation

    return loss_ic, loss_f  # Return both losses

In [None]:
"""
Contains classes for solving the 2D Burgers' equation using finite difference methods.
The Burgers' equation is a fundamental partial differential equation that describes various physical phenomena,
including fluid dynamics and shock wave formation.

The module includes two classes:

1. BurgersEq2D: This class implements the 2D Burgers' equation using scalar velocity.
   - It initializes the spatial domain, grid points, and physical parameters such as viscosity.
   - It provides methods for computing spatial derivatives using central differencing,
   calculating the right-hand side (RHS) of the equation, and performing time-stepping via
   the Runge-Kutta 4th order (RK4) method.
   - The class includes a method for plotting the velocity field at each time step.

   Inputs:
   - xmin: Minimum x-coordinate of the spatial domain (default 0)
   - xmax: Maximum x-coordinate of the spatial domain (default 1)
   - ymin: Minimum y-coordinate of the spatial domain (default 0)
   - ymax: Maximum y-coordinate of the spatial domain (default 1)
   - Nx: Number of grid points in the x-direction (default 100)
   - Ny: Number of grid points in the y-direction (default 100)
   - nu: Viscosity coefficient (default 0.01)
   - dt: Time step (default 1e-3)
   - tend: End time for the simulation (default 1.0)
   - device: Device (CPU/GPU) to use (default None)
   - dtype: Data type for tensors (default torch.float64)

   Outputs:
   - A tensor containing the time evolution of the velocity field.

2. BurgersEq2D_Vec: This class extends the BurgersEq2D class to handle vector fields,
   allowing for the simulation of two velocity components (u and v) in the 2D space.
   - It includes similar methods for computing derivatives, updating fields, and
   plotting the results.

   Inputs and Outputs are analogous to those in BurgersEq2D, but handle two velocity fields.

Usage:
- Instantiate the desired class with the appropriate parameters,
  and call the `burgers_driver` method with an initial condition to run the simulation.
"""

class BurgersEq2D():
    def __init__(self,
                 xmin=0,
                 xmax=1,
                 ymin=0,
                 ymax=1,
                 Nx=100,
                 Ny=100,
                 nu=0.01,
                 dt=1e-3,
                 tend=1.0,
                 device=None,
                 dtype=torch.float64,
                 ):
        # Initialize parameters for the 2D Burgers' equation
        self.xmin = xmin  # Minimum x-coordinate
        self.xmax = xmax  # Maximum x-coordinate
        self.ymin = ymin  # Minimum y-coordinate
        self.ymax = ymax  # Maximum y-coordinate
        self.Nx = Nx  # Number of grid points in the x-direction
        self.Ny = Ny  # Number of grid points in the y-direction
        # Create linearly spaced grids in the x and y directions
        x = torch.linspace(xmin, xmax, Nx + 1, device=device, dtype=dtype)[:-1]
        y = torch.linspace(ymin, ymax, Ny + 1, device=device, dtype=dtype)[:-1]
        self.x = x  # x-coordinates
        self.y = y  # y-coordinates
        self.dx = x[1] - x[0]  # Grid spacing in the x-direction
        self.dy = y[1] - y[0]  # Grid spacing in the y-direction
        self.X, self.Y = torch.meshgrid(x, y, indexing='ij')  # Create a meshgrid for plotting
        self.nu = nu  # Viscosity coefficient
        self.u = torch.zeros_like(self.X, device=device)  # Initialize velocity field
        self.u0 = torch.zeros_like(self.u, device=device)  # Initial condition for velocity
        self.dt = dt  # Time step
        self.tend = tend  # End time
        self.t = 0  # Current time
        self.it = 0  # Time step counter
        self.U = []  # List to store results
        self.T = []  # List to store time points
        self.device = device  # Device (CPU/GPU) to use

    # Central differencing for first derivative in specified axis
    def CD_i(self, data, axis, dx):
        data_m2 = torch.roll(data, shifts=2, dims=axis)
        data_m1 = torch.roll(data, shifts=1, dims=axis)
        data_p1 = torch.roll(data, shifts=-1, dims=axis)
        data_p2 = torch.roll(data, shifts=-2, dims=axis)
        data_diff_i = (data_m2 - 8.0 * data_m1 + 8.0 * data_p1 - data_p2) / (12.0 * dx)
        return data_diff_i

    # Central differencing for second derivatives in two axes
    def CD_ij(self, data, axis_i, axis_j, dx, dy):
        data_diff_i = self.CD_i(data, axis_i, dx)
        data_diff_ij = self.CD_i(data_diff_i, axis_j, dy)
        return data_diff_ij

    # Central differencing for second derivative in specified axis
    def CD_ii(self, data, axis, dx):
        data_m2 = torch.roll(data, shifts=2, dims=axis)
        data_m1 = torch.roll(data, shifts=1, dims=axis)
        data_p1 = torch.roll(data, shifts=-1, dims=axis)
        data_p2 = torch.roll(data, shifts=-2, dims=axis)
        data_diff_ii = (-data_m2 + 16.0 * data_m1 - 30.0 * data + 16.0 * data_p1 - data_p2) / (12.0 * dx ** 2)
        return data_diff_ii

    # First derivative in the x-direction
    def Dx(self, data):
        data_dx = self.CD_i(data=data, axis=0, dx=self.dx)
        return data_dx

    # First derivative in the y-direction
    def Dy(self, data):
        data_dy = self.CD_i(data=data, axis=1, dx=self.dy)
        return data_dy

    # Second derivative in the x-direction
    def Dxx(self, data):
        data_dxx = self.CD_ii(data, axis=0, dx=self.dx)
        return data_dxx

    # Second derivative in the y-direction
    def Dyy(self, data):
        data_dyy = self.CD_ii(data, axis=1, dx=self.dy)
        return data_dyy

    # Calculate the right-hand side (RHS) of the Burgers' equation
    def burgers_calc_RHS(self, u):
        u_xx = self.Dxx(u)  # Second derivative in x
        u_yy = self.Dyy(u)  # Second derivative in y
        u2 = u ** 2.0  # Square of the velocity
        u2_x = self.Dx(u2)  # First derivative of the square in x
        u2_y = self.Dy(u2)  # First derivative of the square in y
        u_RHS = -0.5 * (u2_x + u2_y) + self.nu * (u_xx + u_yy)  # Calculate RHS
        return u_RHS

    # Update the velocity field based on the RHS and step fraction
    def update_field(self, field, RHS, step_frac):
        field_new = field + self.dt * step_frac * RHS
        return field_new

    # Merge RHS results for RK4 integration
    def rk4_merge_RHS(self, field, RHS1, RHS2, RHS3, RHS4):
        field_new = field + self.dt / 6.0 * (RHS1 + 2 * RHS2 + 2.0 * RHS3 + RHS4)
        return field_new

    # Runge-Kutta 4th order integration step
    def burgers_rk4(self, u, t=0):
        u_RHS1 = self.burgers_calc_RHS(u)  # Calculate RHS at t
        u1 = self.update_field(u, u_RHS1, step_frac=0.5)  # Update for half-step

        u_RHS2 = self.burgers_calc_RHS(u1)  # Calculate RHS at half-step
        u2 = self.update_field(u, u_RHS2, step_frac=0.5)  # Update for half-step

        u_RHS3 = self.burgers_calc_RHS(u2)  # Calculate RHS at t + dt/2
        u3 = self.update_field(u, u_RHS3, step_frac=1.0)  # Update for full-step

        u_RHS4 = self.burgers_calc_RHS(u3)  # Calculate RHS at t + dt

        u_new = self.rk4_merge_RHS(u, u_RHS1, u_RHS2, u_RHS3, u_RHS4)  # Merge results
        return u_new


    # Main driver function to run the simulation
    def burgers_driver(self, u0, save_interval=10):
        self.u0 = u0[:self.Nx, :self.Ny]
        self.u = self.u0
        self.t = 0
        self.it = 0
        self.T = []
        self.U = []

        if save_interval != 0 and self.it % save_interval == 0:
            self.U.append(self.u)
            self.T.append(self.t)

        while self.t < self.tend:
            self.u = self.burgers_rk4(self.u, self.t)  # Only update self.u
            self.t += self.dt  # Increment time

            self.it += 1
            if save_interval != 0 and self.it % save_interval == 0:
                self.U.append(self.u)
                self.T.append(self.t)

        return torch.stack(self.U)

In [None]:
def save_checkpoint(path, name, model, optimizer=None):
    """
    Save the model and optimizer state to a checkpoint file.

    This function creates a directory for checkpoints if it doesn't exist,
    and then saves the state dictionaries of the model and optimizer (if provided)
    to a specified file.

    Args:
        path (str): The directory path where the checkpoint will be saved.
        name (str): The name of the checkpoint file.
        model (nn.Module): The neural network model whose state will be saved.
        optimizer (torch.optim.Optimizer, optional): The optimizer whose state will be saved.
                                                     Defaults to None, which skips saving the optimizer state.

    Returns:
        None
    """
    ckpt_dir = 'checkpoints/%s/' % path  # Define the checkpoint directory based on the provided path
    if not os.path.exists(ckpt_dir):  # Check if the directory exists
        os.makedirs(ckpt_dir)  # Create the directory if it does not exist

    try:
        model_state_dict = model.module.state_dict()  # Get the state dict for a model wrapped in DataParallel
    except AttributeError:
        model_state_dict = model.state_dict()  # Get the state dict for a standard model

    # Get the optimizer state if provided, otherwise set to a default value
    if optimizer is not None:
        optim_dict = optimizer.state_dict()  # Get the state dict for the optimizer
    else:
        optim_dict = 0.0  # Default value if no optimizer is provided

    # Save the model and optimizer state dictionaries to the specified checkpoint file
    torch.save({
        'model': model_state_dict,  # Save the model state
        'optim': optim_dict  # Save the optimizer state
    }, ckpt_dir + name)  # Combine directory and file name for saving

    print('Checkpoint is saved at %s' % (ckpt_dir + name))  # Confirm that the checkpoint has been saved

In [None]:
class LpLoss(nn.Module):
    """
    Loss function that computes relative and/or absolute Lp losses,
    and optionally includes Total Variation (TV) loss.

    This class implements the Lp loss which can be used to measure the
    distance between predicted and true values in various norms. Additionally,
    it can compute the Total Variation loss to promote smoothness in the solution.

    Args:
        d (int, optional): Number of dimensions for the loss scaling. Default is 2.
        p (float, optional): Type of Lp-norm (e.g., 1 for L1, 2 for L2). Default is 2.
        size_average (bool, optional): If True, the losses are averaged over each loss element. Default is True.
        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default is 'mean'.
        loss_type (str, optional): Type of loss to compute: 'rel', 'abs', 'both', or 'all'.
                                   'all' includes relative, absolute, and TV losses. Default is 'rel'.
        tv_weight (float, optional): Weighting factor for the Total Variation loss. Relevant only if TV loss is included. Default is 0.0.
        input_coords (torch.Tensor, optional): Tensor containing the input coordinates corresponding to `x`. Required for TV loss computation.
                                               Should have gradients enabled.
    """
    def __init__(self,
                 d: int = 2,
                 p: float = 2.0,
                 size_average: bool = True,
                 reduction: str = 'mean',
                 loss_type: str = 'rel',
                 tv_weight: float = 0.0,
                 input_coords: torch.Tensor = None):
        super(LpLoss, self).__init__()

        # Validate inputs
        if d <= 0:
            raise ValueError("Number of dimensions 'd' must be positive.")
        if p <= 0:
            raise ValueError("Lp-norm type 'p' must be positive.")
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError("Reduction must be one of 'none', 'mean', or 'sum'.")
        if loss_type not in ['rel', 'abs', 'both', 'all']:
            raise ValueError("Loss type must be one of 'rel', 'abs', 'both', or 'all'.")
        if loss_type == 'all' and tv_weight <= 0.0:
            raise ValueError("For 'all' loss_type, 'tv_weight' must be positive to include TV loss.")
        if loss_type == 'all' and input_coords is None:
            raise ValueError("For 'all' loss_type, 'input_coords' must be provided for TV loss computation.")
        if loss_type != 'all' and tv_weight != 0.0:
            raise ValueError("tv_weight should be 0.0 unless loss_type is 'all'.")

        self.d = d
        self.p = p
        self.size_average = size_average
        self.reduction = reduction
        self.loss_type = loss_type
        self.tv_weight = tv_weight
        self.input_coords = input_coords  # Coordinates should have requires_grad=True

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the specified Lp loss and optionally the Total Variation loss between predicted and true values.

        Args:
            x (torch.Tensor): Predicted values. Shape: (batch_size, ..., ...)
            y (torch.Tensor): True values. Shape: same as x.

        Returns:
            torch.Tensor: Computed loss value.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x shape {x.shape} vs y shape {y.shape}")

        # Compute Lp losses
        loss_components = []

        # Flatten the tensors except for the batch dimension
        diff = x.view(x.size(0), -1) - y.view(y.size(0), -1)
        diff_norm = torch.norm(diff, p=self.p, dim=1)  # Shape: (batch_size,)
        y_norm = torch.norm(y.view(y.size(0), -1), p=self.p, dim=1)  # Shape: (batch_size,)

        # Compute absolute loss
        h = 1.0 / (x.size(1) - 1.0)  # Assuming uniform mesh along the second dimension
        abs_loss = (h ** (self.d / self.p)) * diff_norm
        loss_components.append(abs_loss)

        # Compute relative loss with epsilon to prevent division by zero
        epsilon = 1e-12
        rel_loss = diff_norm / (y_norm + epsilon)
        loss_components.append(rel_loss)

        # Initialize total loss
        total_loss = 0.0

        # Handle different loss types
        if self.loss_type == 'rel':
            total_loss = rel_loss
        elif self.loss_type == 'abs':
            total_loss = abs_loss
        elif self.loss_type == 'both':
            total_loss = rel_loss + abs_loss
        elif self.loss_type == 'all':
            # Compute TV loss
            tv_loss = self.compute_total_variation(x)
            total_loss = rel_loss + abs_loss + self.tv_weight * tv_loss
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        # Apply reduction
        if self.reduction == 'mean':
            return total_loss.mean()
        elif self.reduction == 'sum':
            return total_loss.sum()
        else:  # 'none'
            return total_loss

    def compute_total_variation(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Total Variation (TV) loss for the predicted tensor.

        Args:
            x (torch.Tensor): Predicted values. Shape: (batch_size, channels, ...)

        Returns:
            torch.Tensor: TV loss value per batch.
        """
        if self.input_coords is None:
            raise ValueError("Input coordinates must be provided for TV loss computation.")

        # Ensure input_coords requires gradients
        if not self.input_coords.requires_grad:
            raise ValueError("input_coords must have requires_grad=True for TV loss computation.")

        # Compute gradients with respect to input coordinates
        grads = torch.autograd.grad(
            outputs=x,
            inputs=self.input_coords,
            grad_outputs=torch.ones_like(x),
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]  # Shape: same as input_coords

        # Compute the L1 norm of gradients (Total Variation)
        # Assuming input_coords has spatial dimensions; sum over spatial dims
        # For example, input_coords shape: (batch_size, num_dims, H, W) for 2D
        tv_loss = torch.sum(torch.abs(grads), dim=1)  # Sum over the spatial dimensions

        # Depending on the desired reduction, return appropriate value
        # Here, we return the mean TV loss per batch
        return tv_loss.mean()

In [None]:
def train_burgers2d(
    model,
    dataset,
    train_loader,
    optimizer,
    scheduler,
    nu=0.01,                   # Viscosity coefficient for the Burgers' equation
    rank=0,
    log=False,
    use_tqdm=True
):
    """
    Train the PINO model for the 2D Burgers' equation.

    Args:
        model (nn.Module): The neural network model to train.
        dataset: Dataset containing the training data.
        train_loader: DataLoader for batching the training data.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        scheduler: Learning rate scheduler for adjusting the learning rate.
        nu (float): Viscosity coefficient for the Burgers' equation.
        rank (int): Rank of the process (for distributed training).
        log (bool): Whether to log metrics.
        use_tqdm (bool): Whether to use tqdm for progress visualization.

    Returns:
        None
    """
    # Extract weights for different loss components
    data_weight = 10.0
    f_weight = 1.0
    ic_weight = 10.0
    ckpt_freq = 25

    model.train()  # Set the model to training mode
    myloss = LpLoss(size_average=True)  # Loss function for data
    S, T = dataset.S, dataset.T  # Extract spatial and temporal information from the dataset

    # Progress bar setup
    pbar = range(150)  # config['train']['epochs']
    if use_tqdm:
        pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.1)

    # Initialize max memory tracking
    max_memory_usage = 0

    for e in pbar:
        train_loss = 0.0  # Initialize cumulative training loss

        for x, y in train_loader:  # Iterate over batches of data
            x, y = x.to(rank), y.to(rank)  # Move data to the appropriate device
            out = model(x).reshape(y.shape)  # Model output reshaped to match target shape

            # Calculate data loss
            data_loss = myloss(out, y)

            # Compute the initial condition and PDE loss
            loss_ic, loss_f = PINO_loss_burgers2D(out, x[..., 0, -1], nu=nu)
            total_loss = loss_ic * ic_weight + loss_f * f_weight + data_loss * data_weight  # Total loss

            optimizer.zero_grad()  # Zero the gradients
            total_loss.backward()  # Backpropagate to compute gradients
            optimizer.step()  # Update model parameters

            # Accumulate losses for reporting
            train_loss += total_loss.item()

            # Track memory usage during training
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # Ensure all computations are done
                max_memory_usage = max(max_memory_usage, torch.cuda.max_memory_allocated(rank))
            else:
                process = psutil.Process()
                max_memory_usage = max(max_memory_usage, process.memory_info().rss)

        scheduler.step()  # Step the learning rate scheduler

        # Calculate average losses for reporting
        train_loss /= len(train_loader)

        if use_tqdm:
            pbar.set_description(f'Epoch {e}, train loss: {train_loss:.5f}')

        # Save checkpoints at specified intervals
        if e % ckpt_freq == 0:
            save_checkpoint(
                'Burgers2D',
                'Burgers2D-0001.pt'.replace('.pt', f'_{e}.pt'),
                model, optimizer
            )

    # Save the final model checkpoint
    save_checkpoint(
        'Burgers2D',
        'Burgers2D-0001.pt',
        model, optimizer
    )

    # Print the maximum memory usage at the end
    if torch.cuda.is_available():
        print(f'Maximum GPU memory allocated during training: {max_memory_usage / (1024 ** 2):.2f} MB')
    else:
        print(f'Maximum CPU memory used during training: {max_memory_usage / (1024 ** 2):.2f} MB')

    print('Training complete!')


In [None]:
def eval_burgers2D(model,
                   dataloader,
                   device,
                   nu=0.01,
                   use_tqdm=True):
    """
    Evaluate the PINO model for the 2D Burgers' equation.

    This function evaluates the model's performance on a given dataset, calculating
    the average L2 error and the error associated with the Burgers' equation. It
    operates in evaluation mode to disable gradient calculations.

    Args:
        model (nn.Module): The trained neural network model to evaluate.
        dataloader: DataLoader containing the test data.
        config (dict): Configuration dictionary for the evaluation settings.
        device (torch.device): Device to perform the computation (CPU or GPU).
        nu (float): Viscosity coefficient for the Burgers' equation.
        use_tqdm (bool): Whether to use tqdm for progress visualization.

    Returns:
        None
    """
    model.eval()  # Set the model to evaluation mode
    myloss = LpLoss(size_average=True)  # Define the loss function for data evaluation

    # Set up progress bar
    pbar = tqdm(dataloader, dynamic_ncols=True, smoothing=0.05) if use_tqdm else dataloader

    test_err = []  # List to store data loss for evaluation
    f_err = []  # List to store equation loss for evaluation

    with torch.no_grad():  # Disable gradient calculations
        for x, y in pbar:
            x, y = x.to(device), y.to(device)  # Move data to the specified device
            out = model(x).reshape(y.shape)  # Get model output and reshape to match target shape

            # Calculate data loss
            data_loss = myloss(out, y)

            # Compute the initial condition and PDE loss
            loss_ic, f_loss = PINO_loss_burgers2D(out, x[..., 0, -1], nu=nu)
            test_err.append(data_loss.item())  # Store data loss
            f_err.append(f_loss.item())  # Store PDE loss

    # Calculate mean and standard deviation of errors
    mean_f_err = np.mean(f_err)  # Mean of equation errors
    std_f_err = np.std(f_err, ddof=1) / np.sqrt(len(f_err))  # Standard error of equation errors

    mean_err = np.mean(test_err)  # Mean of data errors
    std_err = np.std(test_err, ddof=1) / np.sqrt(len(test_err))  # Standard error of data errors

    # Print the evaluation results
    print(f'==Averaged relative L2 error mean: {mean_err:.5f}, std error: {std_err:.5f}==\n'
          f'==Averaged equation error mean: {mean_f_err:.5f}, std error: {std_f_err:.5f}==')

In [None]:
class DataLoader2D(object):
    """
    A custom DataLoader for handling 2D data for training and evaluation of models
    related to the Burgers' equation or similar tasks.

    This class allows for subsampling of the input data and provides functionality
    to create DataLoader instances for batching the data.

    Attributes:
        sub (int): Factor for subsampling in the spatial dimensions.
        sub_t (int): Factor for subsampling in the temporal dimension.
        S (int): Effective size of the spatial grid after subsampling.
        T (int): Effective size of the temporal grid after subsampling.
        data (torch.Tensor): Processed tensor containing the input data.
    """

    def __init__(self, data, nx=128, nt=100, sub=1, sub_t=1):
        """
        Initialize the DataLoader2D instance.

        Args:
            data (torch.Tensor): Input tensor of shape (batch_size, nt, nx, nx).
            nx (int): Size of the spatial dimension.
            nt (int): Size of the temporal dimension.
            sub (int): Subsampling factor for the spatial dimensions.
            sub_t (int): Subsampling factor for the temporal dimension.
        """
        self.sub = sub  # Store the spatial subsampling factor
        self.sub_t = sub_t  # Store the temporal subsampling factor
        s = nx  # Initialize size variable with the spatial dimension size

        # Ensure nx is even for consistent processing
        if (s % 2) == 1:
            s = s - 1  # Reduce size by 1 if odd

        # Calculate effective sizes after subsampling
        self.S = s // sub  # Effective spatial dimension after subsampling
        self.T = nt // sub_t  # Effective temporal dimension after subsampling
        self.T += 1  # Increment T to include the initial condition

        # Subsample the input data tensor
        data = data[:, 0:self.T:sub_t, 0:self.S:sub, 0:self.S:sub]
        self.data = data.permute(0, 2, 3, 1)  # Rearrange dimensions to (batch_size, S, S, T)

    def make_loader(self, n_sample, batch_size, start=0, train=True):
        """
        Create a DataLoader for training or evaluation.

        Args:
            n_sample (int): Number of samples to include in the DataLoader.
            batch_size (int): Number of samples per batch.
            start (int): Starting index for the samples in the dataset.
            train (bool): Whether to shuffle the data for training or not.

        Returns:
            DataLoader: A DataLoader instance for the specified samples.
        """
        # Extract and reshape the data for the DataLoader
        a_data = self.data[start:start + n_sample, :, :, 0].reshape(n_sample, self.S, self.S)  # Extract auxiliary data
        u_data = self.data[start:start + n_sample].reshape(n_sample, self.S, self.S, self.T)  # Extract target data

        # Generate the grid for 3D data
        gridx, gridy, gridt = get_grid3d(self.S, self.T)

        # Expand auxiliary data dimensions and repeat for the time dimension
        a_data = a_data.reshape(n_sample, self.S, self.S, 1, 1).repeat([1, 1, 1, self.T, 1])

        # Concatenate grid information with auxiliary data
        a_data = torch.cat((gridx.repeat([n_sample, 1, 1, 1, 1]),
                            gridy.repeat([n_sample, 1, 1, 1, 1]),
                            gridt.repeat([n_sample, 1, 1, 1, 1]),
                            a_data), dim=-1)

        # Create a TensorDataset from the auxiliary and target data
        dataset = torch.utils.data.TensorDataset(a_data, u_data)

        # Create a DataLoader with shuffling for training or without shuffling for evaluation
        if train:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        else:
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

        return loader  # Return the DataLoader instance

In [None]:
def get_grid3d(S, T, time_scale=1.0, device='cpu'):
    """
    Generate 3D grids for spatial and temporal coordinates.

    This function creates 3D grids representing the spatial dimensions (gridx, gridy)
    and the temporal dimension (gridt) used in simulations or numerical methods related
    to 2D problems. The grids are constructed based on the specified sizes for the spatial
    and temporal dimensions and can be placed on a specified device (CPU or GPU).

    Args:
        S (int): Size of the spatial grid (number of spatial points).
        T (int): Size of the temporal grid (number of time points).
        time_scale (float, optional): Scale factor for the time dimension (default: 1.0).
        device (str, optional): Device to store the generated tensors (default: 'cpu').

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - gridx (torch.Tensor): 3D tensor representing the x-coordinates.
            - gridy (torch.Tensor): 3D tensor representing the y-coordinates.
            - gridt (torch.Tensor): 3D tensor representing the time-coordinates.
    """
    # Create a 1D grid for the x-coordinates, excluding the last point
    gridx = torch.tensor(np.linspace(0, 1, S + 1)[:-1], dtype=torch.float, device=device)
    # Reshape and repeat to create a 3D grid for x-coordinates
    gridx = gridx.reshape(1, S, 1, 1, 1).repeat([1, 1, S, T, 1])

    # Create a 1D grid for the y-coordinates, excluding the last point
    gridy = torch.tensor(np.linspace(0, 1, S + 1)[:-1], dtype=torch.float, device=device)
    # Reshape and repeat to create a 3D grid for y-coordinates
    gridy = gridy.reshape(1, 1, S, 1, 1).repeat([1, S, 1, T, 1])

    # Create a 1D grid for the time-coordinates scaled by time_scale
    gridt = torch.tensor(np.linspace(0, 1 * time_scale, T), dtype=torch.float, device=device)
    # Reshape and repeat to create a 3D grid for time-coordinates
    gridt = gridt.reshape(1, 1, 1, T, 1).repeat([1, S, S, 1, 1])

    return gridx, gridy, gridt  # Return the generated grids

In [None]:
dim = 2
N = 128
Nx = N
Ny = N
Nt = 100 + 1
Ntest = 5
l = 0.1
L = 1.0
sigma = 0.2 #2.0
nu = 0.01
Nu = None #2.0
Nsamples = 50
jitter = 1e-12
dt=1.0e-4

In [None]:
with open('drive/MyDrive/U02.pkl', 'rb') as f:
  U0 = pickle.load(f)

In [None]:
burgers_eq = BurgersEq2D(Nx=Nx, Ny=Ny, dt=dt, nu=nu, device=device)

In [None]:
U = vmap(burgers_eq.burgers_driver, in_dims=(0, None))(U0, 100)

In [None]:
a = U0.cpu().float()
u = U.cpu().float()

In [None]:
dataset = DataLoader2D(
    u,
    nx=128,
    nt=100,
    sub=1,
    sub_t=1
)

train_loader = dataset.make_loader(
    n_sample=45,
    batch_size=1,
    start=0,
    train=True
)

test_loader = dataset.make_loader(
    n_sample=5,
    batch_size=1,
    start=45,
    train=False
)

In [None]:
log = False

# Create the model using direct values from the configuration
model = FNN3d(
    modes1=[8, 8, 8, 8],
    modes2=[8, 8, 8, 8],
    modes3=[8, 8, 8, 8],
    fc_dim=500,
    layers=[128, 128, 128, 128],
    activation='prelu'
).to(device)

# Initialize the optimizer with direct values
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.5e-3,
    weight_decay=0.5e-3
)

# Initialize the learning rate scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[25, 50, 75, 100],
    gamma=0.5
)

In [24]:
start = timer()

# Call the train_burgers2d function with direct values
train_burgers2d(
    model,
    dataset,
    train_loader,
    optimizer,
    scheduler,
    nu=0.01,
    rank=0,
    log=log,
    use_tqdm=True
)

end = timer()
print(timedelta(seconds=end - start))

Epoch 0, train loss: 10.58889:   1%|          | 1/150 [00:24<1:01:47, 24.89s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_0.pt


Epoch 25, train loss: 2.14360:  17%|█▋        | 26/150 [09:55<47:15, 22.87s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_25.pt


Epoch 50, train loss: 1.79612:  34%|███▍      | 51/150 [19:26<37:44, 22.87s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_50.pt


Epoch 75, train loss: 1.35685:  51%|█████     | 76/150 [28:57<28:12, 22.87s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_75.pt


Epoch 100, train loss: 1.10235:  67%|██████▋   | 101/150 [38:28<18:41, 22.88s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_100.pt


Epoch 125, train loss: 0.99471:  84%|████████▍ | 126/150 [48:00<09:09, 22.91s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_125.pt


Epoch 149, train loss: 0.90784: 100%|██████████| 150/150 [57:08<00:00, 22.86s/it]


Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001.pt
Maximum GPU memory allocated during training: 18677.72 MB
Training complete!
0:57:09.014705


In [25]:
start = timer()

# Call the eval_burgers2D function with direct values
eval_burgers2D(
    model,
    test_loader,
    device,
    nu=0.01,
    use_tqdm=True
)

end = timer()
print(timedelta(seconds=end - start))

100%|██████████| 5/5 [00:01<00:00,  4.84it/s]

==Averaged relative L2 error mean: 0.19749, std error: 0.00891==
==Averaged equation error mean: 0.00600, std error: 0.00064==
0:00:01.037934





In [None]:
class SpectralConv3d(nn.Module):
    """
    A spectral convolution layer that applies the Real Fast Fourier Transform (FFT)
    to perform convolution in the frequency domain for 3D tensors.

    This layer takes 3D input tensors and applies a convolution operation using a
    specified number of FFT modes.

    Attributes:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of FFT modes to use in the first dimension.
        modes2 (int): Number of FFT modes to use in the second dimension.
        modes3 (int): Number of FFT modes to use in the third dimension.
        scale (float): Scaling factor for the weights initialization.
        weights1 (nn.Parameter): Learnable complex weights for the convolution.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        """
        Initialize the SpectralConv3d layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            modes1 (int): Number of FFT modes in the first dimension.
            modes2 (int): Number of FFT modes in the second dimension.
            modes3 (int): Number of FFT modes in the third dimension.
        """
        super(SpectralConv3d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        # Initialize real and imaginary parts separately
        self.weights1_real = nn.Parameter(
            self.scale * torch.randn(in_channels, out_channels, modes1, modes2, modes3)
        )
        self.weights1_imag = nn.Parameter(
            self.scale * torch.randn(in_channels, out_channels, modes1, modes2, modes3)
        )

    def fft_forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the forward real FFT on the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, depth, height, width).

        Returns:
            torch.Tensor: FFT-transformed tensor with complex dtype.
        """
        return torch.fft.rfftn(x, dim=(-3, -2, -1))

    def fft_inverse(self, x: torch.Tensor, original_size: tuple) -> torch.Tensor:
        """
        Perform the inverse real FFT to return to the physical space.

        Args:
            x (torch.Tensor): FFT-transformed tensor with complex dtype.
            original_size (tuple): The original size (depth, height, width) of the input tensor.

        Returns:
            torch.Tensor: Inverse FFT-transformed tensor of shape (batch_size, channels, depth, height, width).
        """
        return torch.fft.irfftn(x, s=original_size, dim=(-3, -2, -1))

    def compl_mul3d(self, input_fft: torch.Tensor, weights_fft: torch.Tensor) -> torch.Tensor:
        """
        Perform complex multiplication between input FFT and weights FFT.

        Args:
            input_fft (torch.Tensor): FFT of input with shape (batch_size, in_channels, modes1, modes2, modes3).
            weights_fft (torch.Tensor): FFT weights with shape (in_channels, out_channels, modes1, modes2, modes3).

        Returns:
            torch.Tensor: Result of complex multiplication with shape (batch_size, out_channels, modes1, modes2, modes3).
        """
        return torch.einsum("bixyz,ioxyz->boxyz", input_fft, weights_fft)

    def forward(self, x):
        """
        Forward pass of the spectral convolution layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, depth, height, width).
        """
        batchsize = x.shape[0]
        depth, height, width = x.shape[-3], x.shape[-2], x.shape[-1]

        # Compute FFT of the input
        x_fft = self.fft_forward(x)  # Shape: (batch_size, in_channels, modes1, modes2, modes3)

        # Initialize FFT of the output with complex dtype
        out_fft = torch.zeros(
            batchsize,
            self.out_channels,
            self.modes1,
            self.modes2,
            self.modes3,
            dtype=torch.cfloat,
            device=x.device,
        )

        # Combine real and imaginary parts to form complex weights
        weights_fft = torch.complex(self.weights1_real, self.weights1_imag)  # Shape: (in_channels, out_channels, modes1, modes2, modes3)

        # Perform complex multiplication on the relevant FFT modes
        out_fft[:, :, :self.modes1, :self.modes2, :self.modes3] = self.compl_mul3d(
            x_fft[:, :, :self.modes1, :self.modes2, :self.modes3],
            weights_fft
        )

        # Inverse FFT to return to physical space
        x = self.fft_inverse(out_fft, original_size=(depth, height, width))  # Shape: (batch_size, out_channels, depth, height, width)

        return x

In [None]:
log = False

# Create the model using direct values from the configuration
model = FNN3d(
    modes1=[8, 8, 8, 8],
    modes2=[8, 8, 8, 8],
    modes3=[8, 8, 8, 8],
    fc_dim=256,
    layers=[128, 128, 128, 128],
    activation='elu'
).to(device)

# Initialize the optimizer with direct values
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3
)

# Initialize the learning rate scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[25, 50, 75, 100],
    gamma=0.5
)

In [None]:
start = timer()

# Call the train_burgers2d function with direct values
train_burgers2d(
    model,
    dataset,
    train_loader,
    optimizer,
    scheduler,
    nu=0.01,
    rank=0,
    log=log,
    use_tqdm=True
)

end = timer()
print(timedelta(seconds=end - start))

Epoch 0, train loss: 13.21085:   1%|          | 1/150 [00:16<41:33, 16.74s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_0.pt


Epoch 25, train loss: 1.28888:  17%|█▋        | 26/150 [06:41<32:00, 15.49s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_25.pt


Epoch 50, train loss: 0.49538:  34%|███▍      | 51/150 [13:06<25:32, 15.48s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_50.pt


Epoch 75, train loss: 0.28754:  51%|█████     | 76/150 [19:31<19:05, 15.48s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_75.pt


Epoch 100, train loss: 0.20959:  67%|██████▋   | 101/150 [25:56<12:40, 15.51s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_100.pt


Epoch 125, train loss: 0.18869:  84%|████████▍ | 126/150 [32:21<06:11, 15.49s/it]

Checkpoint is saved at checkpoints/Burgers2D/Burgers2D-0001_125.pt


Epoch 128, train loss: 0.18308:  86%|████████▌ | 129/150 [33:07<05:24, 15.45s/it]

In [None]:
start = timer()

# Call the eval_burgers2D function with direct values
eval_burgers2D(
    model,
    test_loader,
    device,
    nu=0.01,
    use_tqdm=True
)

end = timer()
print(timedelta(seconds=end - start))