# Fourier Neural Operator for 2-Dimensional Navier-Stokes Equations

<!-- @script made by: Kai Chang -->

This script takes a close look at the dataset for the 2-dimensional Navier-Stokes equations used in [[Li et al., 2020]](https://arxiv.org/abs/2010.08895) and walks through an implementation of the 3-D Fourier neural operator (FNO3d).

## Problem Statement
Here we try to solve

\begin{align}
\partial_t w(x,t) + u(x,t)\cdot\nabla w(x,t) &= \nu\Delta w(x,t) + f(x) \quad &x\in(0,1)\times(0,1),t\in(0,T]\\
\nabla \cdot u(x,t) &= 0 \quad &x\in(0,1)\times(0,1),t\in(0,T]\\
w(x,0) &= w_0(x) \quad &x\in(0,1)\times(0,1)
\end{align}

where $u(x,t)$, the velocity field, is a continuous function in $t$ and for each fixed $t$, $u(\cdot,t)\in H_{per}^r((0,1)\times(0,1);\mathbb{R}^2)$ for any $r > 0$, $w = \nabla \times u$ is the vorticity, $w_0 \in L^2_{per}((0,1)\times(0,1);\mathbb{R})$ is the initial vorticity, $\nu\in\mathbb{R}_+$ is the viscosity coefficient, and $f\in L^2_{per}((0,1)\times(0,1);\mathbb{R})$ is the forcing function.

## Goal
### Learn an operator that maps the vorticity at the first 10 time steps to the vorticity at the next 40 time steps.

In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utilities import *
import operator
from functools import reduce
from functools import partial
from timeit import default_timer
from Adam import Adam

# Fix the torch and NumPy random seeds so weight initialisation and any
# NumPy-based sampling are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## FNO Implementation

In [2]:
class FourierLayer3d(nn.Module):
    """
    3D Fourier layer.

    Computes  irFFT( R * rFFT(v) ) + W v,  where the rFFT runs over the three
    trailing axes of ``v``, R multiplies a block of low-frequency modes by
    learned complex weights (every other mode is set to zero), and W is a
    1x1x1 convolution applied in physical space.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the output.
        modes1, modes2: number of modes kept at each end (indices ``[:m]`` and
            ``[-m:]``) of the first and second transformed axes.
        modes3: number of modes kept (indices ``[:m]``) along the last axis.
            After the real FFT that axis only stores ``N // 2 + 1``
            frequencies, so at most floor(N/2) + 1 modes can be used.

    Shapes: input (batch, in_channels, n1, n2, n3),
            output (batch, out_channels, n1, n2, n3).
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))

        def spectral_weights():
            # Complex tensor of shape (in, out, modes1, modes2, modes3); torch.rand
            # draws real and imaginary parts from [0, 1), then scaled.
            return Parameter(self.scale * torch.rand(in_channels, out_channels,
                                                     modes1, modes2, modes3,
                                                     dtype=torch.cfloat))

        # One weight tensor per retained corner of the spectrum, chosen by the
        # index block ([:m] = low non-negative, [-m:] = negative frequencies)
        # along axes 1 and 2. Created in the order weights1..weights4 so the
        # random initialisation sequence is unchanged.
        self.weights1 = spectral_weights()  # axis1 [:m1],  axis2 [:m2]
        self.weights2 = spectral_weights()  # axis1 [-m1:], axis2 [:m2]
        self.weights3 = spectral_weights()  # axis1 [:m1],  axis2 [-m2:]
        self.weights4 = spectral_weights()  # axis1 [-m1:], axis2 [-m2:]

        # Pointwise linear map in physical space (the "W v" term).
        self.w = nn.Conv3d(in_channels, out_channels, 1)

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        """Contract the channel axis: (b, i, x, y, t) x (i, o, x, y, t) -> (b, o, x, y, t)."""
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        n1, n2, n3 = x.shape[-3:]
        m1, m2, m3 = self.modes1, self.modes2, self.modes3

        # Real FFT over the three trailing axes. Because the input is real its
        # spectrum is Hermitian, so the last axis only keeps n3 // 2 + 1 entries.
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        out_ft = torch.zeros(x.shape[0], self.out_channels, n1, n2, n3 // 2 + 1,
                             dtype=torch.cfloat, device=x.device)

        low1, high1 = slice(None, m1), slice(-m1, None)
        low2, high2 = slice(None, m2), slice(-m2, None)
        keep3 = slice(None, m3)
        corners = (
            (low1, low2, self.weights1),
            (high1, low2, self.weights2),
            (low1, high2, self.weights3),
            (high1, high2, self.weights4),
        )
        # Multiply only the retained modes; everything else stays zero.
        for s1, s2, weights in corners:
            region = (slice(None), slice(None), s1, s2, keep3)
            out_ft[region] = self.compl_mul3d(x_ft[region], weights)

        # Back to physical space, then add the pointwise bypass.
        spectral = torch.fft.irfftn(out_ft, s=(n1, n2, n3))
        return spectral + self.w(x)

In [3]:
class FNO3d(nn.Module):
    """
    The overall network.

        1. Lift the input to a higher dimension (``fc0``).
        2. Apply 4 Fourier layers (``conv0`` .. ``conv3``), with GELU between them.
        3. Project from the channel space to the output space (``fc1``, ``fc2``).

    input: the solution of the first 10 timesteps
           (u(1, x, y), ..., u(10, x, y)). These channels are constant along
           the time axis. The 3 coordinate channels (x, y, t) are appended
           inside ``forward`` by ``get_grid``; only t varies along the time axis.

    input shape: (batchsize, resolution_x, resolution_y, resolution_t, 10)

    output: the solution of the next 40 timesteps
    output shape: (batchsize, resolution_x, resolution_y, resolution_t, 1)

    Args:
        modes1, modes2, modes3: Fourier modes kept per axis in every Fourier layer.
        width: channel width of the lifted representation.
        padding: number of zeros appended to the time axis before the Fourier
            layers and removed afterwards (pad the domain if the input is
            non-periodic). Defaults to 6; 0 disables padding.
    """

    def __init__(self, modes1, modes2, modes3, width, padding=6):
        super().__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.padding = padding  # zeros appended to the time axis in forward
        # 13 input channels: the solution at the first 10 timesteps plus the
        # 3 coordinates (x, y, t) appended by get_grid.
        self.fc0 = nn.Linear(13, self.width)

        self.conv0 = FourierLayer3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = FourierLayer3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = FourierLayer3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = FourierLayer3d(self.width, self.width, self.modes1, self.modes2, self.modes3)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        # shape (batch, resolution_x, resolution_y, resolution_t, 13)

        x = self.fc0(x)
        # shape (batch, resolution_x, resolution_y, resolution_t, width)

        x = x.permute(0, 4, 1, 2, 3)
        # shape (batch, width, resolution_x, resolution_y, resolution_t)

        if self.padding > 0:
            # F.pad with a 2-element list pads only the last (time) axis.
            x = F.pad(x, [0, self.padding])

        # Each Fourier layer is applied exactly once (the previous version
        # reused conv1 three times, leaving conv2/conv3 unused).
        x = F.gelu(self.conv0(x))
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = self.conv3(x)  # no activation after the last Fourier layer

        if self.padding > 0:
            # Guarded: x[..., :-0] would return an empty tensor.
            x = x[..., :-self.padding]
        x = x.permute(0, 2, 3, 4, 1)
        # shape (batch, resolution_x, resolution_y, resolution_t, width)

        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        # shape (batch, resolution_x, resolution_y, resolution_t, 1)

        return x

    def get_grid(self, shape, device):
        """
        Return uniform [0, 1] coordinates for the x, y and time axes.

        ``shape`` is read as (batch, size_x, size_y, size_z, ...); the result
        has shape (batch, size_x, size_y, size_z, 3), with channels (x, y, t).
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
        # shape (batch, resolution_x, resolution_y, resolution_t, 1); encode x-coordinate

        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
        # shape (batch, resolution_x, resolution_y, resolution_t, 1); encode y-coordinate

        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
        # shape (batch, resolution_x, resolution_y, resolution_t, 1); encode time-index

        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)
               # shape (batch, resolution_x, resolution_y, resolution_t, 3)

## Dataset

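Let's take a look at the dataset to better understand some of the technicalities in the implementation. It contains many solutions of the 2-d Navier-Stokes equations in vorticity form (the vorticity $w$ is the curl of the velocity field $u$), each starting from a different initial vorticity $w_0$. The dataset is generated by classical PDE solvers. The raw data is a dictionary with three keys, `'a'`, `'t'` and `'u'` (see the output below). In neural operator training, `'a'` is usually the input and `'u'` the ground truth; in this notebook, however, both the input and the target are sliced from `'u'` along the time axis.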

In this case, the domain of the map we want to approximate is the space of vorticities up to some time $T_0$; the codomain is the space of vorticities from $T_0$ up to some later time $T > T_0$.

In [5]:
# Open the Navier-Stokes dataset with the project's MatReader helper.
# NOTE(review): "V1e-3" in the file name presumably means viscosity 1e-3.
file_path = "../data/NS/ns_V1e-3_N5000_T50.mat"
dataloader = MatReader(file_path)
# List the variables stored in the file.
print(dataloader.data.keys())

<KeysViewHDF5 ['a', 't', 'u']>


In [6]:
time_data_raw = dataloader.read_field('t')
print('time size: ', time_data_raw.size())
T_0 = 10
T = 50

# Both a_data and u_data come from the same field 'u':
# 'a' is the vorticity at the first T_0 (=10) time steps, while
# 'u' is the vorticity at the following T - T_0 (=40) time steps.
# We are therefore learning a map from the early dynamics to the later dynamics.
# The field is read from disk once and sliced. Slices are views that keep the
# whole tensor alive, so reading it twice would hold two full copies in memory.
vorticity = dataloader.read_field('u')
a_data = vorticity[:, :, :, :T_0]
print('a size: ', a_data.size())
u_data = vorticity[:, :, :, T_0:T]
print('u size: ', u_data.size())

time size:  torch.Size([1, 50])
a size:  torch.Size([5000, 64, 64, 10])
u size:  torch.Size([5000, 64, 64, 40])


In [7]:
n_train = 1000
n_test = 200

sub_train = 2**0  # spatial subsampling factor of the training set
sub_test = 2**0   # spatial subsampling factor of the test set
full_res = a_data.size(-2)

# Training tensors use sub_train, test tensors use sub_test.
a_train = a_data[:n_train, ::sub_train, ::sub_train, :]
u_train = u_data[:n_train, ::sub_train, ::sub_train, :]
a_test = a_data[-n_test:, ::sub_test, ::sub_test, :]
u_test = u_data[-n_test:, ::sub_test, ::sub_test, :]

# Resolution after subsampling, read from the sliced tensors:
# x[::s] has ceil(N / s) entries, which differs from N // s when s does not divide N.
train_res = a_train.size(-2)
test_res = a_test.size(-2)

runtime = np.zeros(2, )
t1 = default_timer()

assert (train_res == u_train.shape[-2])
assert (T-T_0 == u_train.shape[-1])

# Normalization statistics are fitted on the training set only. The test
# targets (u_test) stay in physical units; predictions are decoded before
# the test loss is computed.
a_normalizer = UnitGaussianNormalizer(a_train)
a_train = a_normalizer.encode(a_train)
a_test = a_normalizer.encode(a_test)
u_normalizer = UnitGaussianNormalizer(u_train)
u_train = u_normalizer.encode(u_train)

# Repeat the T_0 input frames along a new time axis of length T - T_0, so every
# output time step sees the whole input history as channels:
# (n, res, res, T_0) -> (n, res, res, T - T_0, T_0)
a_train = a_train.reshape(n_train,train_res,train_res,1,T_0).repeat([1,1,1,T-T_0,1])
a_test = a_test.reshape(n_test,test_res,test_res,1,T_0).repeat([1,1,1,T-T_0,1])

print("a_train shape: ", a_train.size())
print("a_test shape: ", a_test.size())
print("u_train shape: ", u_train.size())
print("u_test shape: ", u_test.size())

####### load data
batch_size = 10
batch_size2 = batch_size

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_train, u_train),
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_test, u_test),
                                          batch_size=batch_size,
                                          shuffle=False)
t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)

a_train shape:  torch.Size([1000, 64, 64, 40, 10])
a_test shape:  torch.Size([200, 64, 64, 40, 10])
u_train shape:  torch.Size([1000, 64, 64, 40])
u_test shape:  torch.Size([200, 64, 64, 40])
preprocessing finished, time used: 4.997746458999998


# Training and Testing

### Input: vorticity of the first $T_0$ (in this case 10) time steps evaluated on some discretized grids of the spatial domain and the grid information

### Output: vorticity of the next $T-T_0$ (in this case 40) time steps evaluated on some discretized grids of the spatial domain

In [None]:
%matplotlib notebook

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

###################################################
################## model configs ##################
###################################################

print('training data resolution: ', train_res)
print('testing data resolution: ', test_res)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 300
epoch_step = 25
learning_rate = 0.001
step_size = 100
gamma = 0.5
modes = 8
width = 20

model = FNO3d(modes, modes, modes, width).to(device)
print(count_params(model))

optimizer = Adam(model.parameters(), 
                 lr=learning_rate, 
                 weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
                                            step_size=scheduler_step, 
                                            gamma=gamma)
myloss = LpLoss(size_average=False)
train_MSE_Loss = []
train_L2_Loss = []
test_L2_Loss = []

####################################################
################# plotting configs #################
####################################################

# change sample by changing rows and cols
rows,cols = 2,2
num_plots = rows * cols

sampled_batch_idx = random.sample(range(n_test//batch_size),num_plots)
sampled_plotting_idx = np.random.randint(batch_size, size = num_plots)
fig = plt.figure(figsize=(16, 16))

# config for saving figures
saveFig = False
fig_folder = 'ep' + str(epochs) + \
             '-lr' + str(learning_rate) + \
             '-decay' + str(gamma) + \
             '-trainRes' + str(train_res) + \
             '-testRes' + str(test_res)

fig_dir = '../figs/ns/' + fig_folder
if not os.path.isdir(fig_dir):
    os.makedirs(fig_dir)
plot_pos = 1


for ep in range(epochs):
    fig.clear()
    fig.suptitle('epoch: #' + str(ep),fontsize = 50)
    
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x).view(batch_size, train_res, train_res, T-T_0)

        mse = F.mse_loss(out.detach(), y.detach(), reduction='mean')
        # mse.backward()

        y = u_normalizer.decode(y)
        out = u_normalizer.decode(out)
        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            out = model(x).view(batch_size, test_res, test_res, T-T_0)
            out = u_normalizer.decode(out)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
            
            ##### Plotting #####
            if batch in sampled_batch_idx:
                idx = sampled_plotting_idx[plot_pos-1]
                ax = fig.add_subplot(rows,cols,plot_pos)
                ax.clear()
#                 ax.get_xaxis().set_visible(False)
                ax.set_ylim((-1.2,1.2))
                u_sample = out[idx,:,:].reshape((-1,1))
                u_sample = u_sample.detach().cpu().numpy()
                u_truth = u[idx,:].reshape((-1,1))
                u_truth = u_truth.detach().cpu().numpy()
                ax.plot(xvals_test, u_sample, 'r', label='FNO1d')
                ax.plot(xvals_test, u_truth, 'b--', label='truth')
                ax.set_title(f"sample #{plot_pos}")
                ax.set_xlabel('x')
                ax.set_ylabel('u(x,1)')
                ax.legend()
                plot_pos += 1
        fig.canvas.draw()
        if saveFig:
            if fig_dir[-1] != '/':
                fig_dir = fig_dir + '/'
            fig_name = fig_dir + 'epoch-' + str(ep) + '.png'
            fig.savefig(fig_name)
        plot_pos = 1

            
    # update loss
    train_mse /= len(train_loader)
    train_MSE_Loss.append(train_mse)
    train_l2 /= n_train
    train_L2_Loss.append(train_l2)
    test_l2 /= n_test
    test_L2_Loss.append(test_l2)

    t2 = default_timer()
    if (ep % epoch_step)==0 or (ep==(epochs-1)):
        print('epoch \t\t t2-t1 \t\t train-MSE-Error \t train-L2-Error \t test-L2-Error')
        print('{0:d}\t\t{1:.5f} \t\t {2:.5f}\t\t{3:.5f} \t {4:.5f}' \
              .format(ep, t2-t1, train_mse, train_l2, test_l2))
        print('\n')

# pred = torch.zeros(u_test.shape)
# index = 0
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_test, u_test),
#                                           batch_size=1, 
#                                           shuffle=False)
# with torch.no_grad():
#     for x, y in test_loader:
#         test_l2 = 0
#         x, y = x.to(device), y.to(device)

#         out = model(x)
#         out = u_normalizer.decode(out)
#         pred[index] = out

#         test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
#         print(index, test_l2)
#         index = index + 1

training data resolution:  64
testing data resolution:  64
6558377


In [None]:
# Plot the loss curves recorded during training (plot_loss comes from utilities).
# Path and title refer to this 2-D Navier-Stokes experiment; they were
# copy-pasted from the 1-D Burgers notebook before.
savefig = False
figpath = '../figs/ns/loss/'
figname = 'ep' + str(epochs) + \
          '-lr' + str(learning_rate) + \
          '-decay' + str(gamma) + \
          '-trainRes' + str(train_res) + \
          '-testRes' + str(test_res)

figtitle = '2-D Navier-Stokes \n Learning Rate: ' + str(learning_rate) + '\n' + \
           'Decay Rate: ' + str(gamma) + ' per ' + str(step_size) + ' epochs'
plot_loss(figtitle = figtitle,
          figpath = figpath,
          figname = figname,
          savefig = savefig,
          train_MSE = train_MSE_Loss,
          train_L2 = train_L2_Loss,
          test_L2 = test_L2_Loss)