# Fourier Neural Operator for 1-D Burgers' Equation
<!-- @script made by: Kai Chang -->

This script takes a close look at the dataset for the 1-dimensional Burgers' equation used in [[Li et al., 2020]](https://arxiv.org/abs/2010.08895) and walks through the implementation of a 1-D Fourier neural operator (FNO).

## Problem Statement
We intend to solve the equation

\begin{align}
\partial_t u(x,t) + \partial_x\left(u^2(x,t)/2\right) &= \nu\partial_{xx}u(x,t)\quad & x\in(0,1),t\in(0,1]\\
u(x,0) &= u_0(x) & x\in(0,1)
\end{align}

with periodic boundary conditions, where $u_0 \in L^2\left((0,1);\mathbb{R}\right)$ is the initial condition and $\nu\in \mathbb{R}_+$ is the viscosity coefficient.
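
For smooth $u$, the chain rule gives $\partial_x\left(u^2/2\right) = u\,\partial_x u$, so the same equation can also be read in advective form, which makes the nonlinear transport term explicit:

\begin{align}
\partial_t u(x,t) + u(x,t)\,\partial_x u(x,t) &= \nu\partial_{xx}u(x,t)\quad & x\in(0,1),t\in(0,1]
\end{align}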

## Goal:
### To learn a map from the space of IC functions to the solution at time 1; i.e. $u_0(x) \mapsto u(x,1)$.


# Dataset

Let's first take a look at the dataset, which will help us understand some of the technicalities in the implementation. It contains solutions to the 1-D Burgers' equation for different initial conditions, generated by classical PDE solvers. The raw data is a dictionary. In neural operator training, `'a'` corresponds to the input and `'u'` is the ground truth.

In this case, the domain of the map we want to approximate is the space of initial conditions; the codomain is the space of the solutions. 

In [1]:
import torch
import numpy as np
from utilities import *
import os

# read raw data
file_path = "../data/Burgers/burgers_data_R10.mat"
# file_path = "../data/Burgers/burgers_v100_t100_r1024_N2048.mat"
dataloader = MatReader(file_path)
print(dataloader.data.keys())

dict_keys(['__header__', '__version__', '__globals__', 'a', 'a_smooth', 'a_smooth_x', 'a_x', 'u'])


In [2]:
# separate input and ground truth
a_data = dataloader.read_field('a')
u_data = dataloader.read_field('u')

Each field has size 2048 x 8192: there are 2048 samples, and each one is discretized at 8192 grid points. The dataset does not, however, include the spatial grid itself, which the model needs as part of its input. Therefore, the implementation of the architecture has to construct the grid on its own.
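
Since the grid is not stored in the file, it has to be rebuilt from the array length alone. The sketch below is illustrative only: it assumes the 8192 samples are uniformly spaced and uses a hypothetical helper `make_grid` that mirrors the `np.linspace(0, 1, resolution)` call used later in `get_grid`. It also shows one subtlety: a grid rebuilt at the subsampled resolution includes the endpoint 1, while striding the full-resolution grid does not, so the two differ slightly.

```python
import numpy as np

def make_grid(resolution):
    """Hypothetical helper: uniform grid on [0, 1], both endpoints included."""
    return np.linspace(0.0, 1.0, resolution)

full_res = 8192   # discretization size of the raw data
stride = 8        # keep every 8th sample

rebuilt = make_grid(full_res // stride)   # grid built directly at the coarse resolution
strided = make_grid(full_res)[::stride]   # full-resolution grid, then subsampled

# Both grids have the same length, but their points do not coincide exactly.
max_gap = np.max(np.abs(rebuilt - strided))
print(rebuilt.shape, strided.shape)
print(rebuilt[-1], strided[-1], max_gap)
```

Whether this small offset matters for training is not examined here.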

There is one more preprocessing step: preparing the training and testing data. The paper claims that an FNO can achieve high accuracy on a fine mesh even if it is trained on a coarser one. To test this, we subsample the dataset with a larger stride for training than for testing, and split it into training and testing data.
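
One way to see why this is plausible: a Fourier layer stores its weights per Fourier mode rather than per grid point, so the same weights can be applied to samples of any resolution. The sketch below is a NumPy illustration under simplifying assumptions (a single channel, a band-limited input, a periodic grid $x_j = j/N$, and random weights standing in for learned ones); the function names `spectral_multiply` and `sample` are hypothetical.

```python
import numpy as np

def spectral_multiply(values, weights):
    """Multiply the lowest len(weights) Fourier modes by `weights` and zero the rest."""
    n = values.shape[-1]
    coeffs = np.fft.rfft(values)            # scales with n ...
    out = np.zeros_like(coeffs)
    modes = len(weights)
    out[:modes] = coeffs[:modes] * weights
    return np.fft.irfft(out, n=n)           # ... and irfft divides by n again

def sample(n):
    """Sample a band-limited function (modes 1 and 3) on a periodic grid of n points."""
    x = np.arange(n) / n                    # endpoint 1 excluded
    return np.sin(2 * np.pi * x) + 0.5 * np.cos(6 * np.pi * x)

rng = np.random.default_rng(0)
weights = rng.normal(size=4) + 1j * rng.normal(size=4)   # stand-in for learned weights

coarse = spectral_multiply(sample(64), weights)
fine = spectral_multiply(sample(128), weights)

# The fine-grid result, read off at the coarse grid points, matches the coarse result.
print(np.allclose(fine[::2], coarse))
```

The full network also contains pointwise layers and nonlinearities, so this only illustrates the spectral part of the argument.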

In [3]:
sub_train = 2**3
sub_test = 2**2
n_train = 1000
n_test = 200
full_res = a_data.size(-1)
train_res = full_res // sub_train
test_res = full_res // sub_test

# sample data
a_train = a_data[:n_train,::sub_train].reshape(n_train,train_res,1)
u_train = u_data[:n_train,::sub_train]
a_test = a_data[-n_test:,::sub_test].reshape(n_test,test_res,1)
u_test = u_data[-n_test:,::sub_test]
print(a_train.size())
print(a_test.size())

torch.Size([1000, 1024, 1])
torch.Size([200, 2048, 1])


## FNO Implementation

Now we take a look at the implementation of the architecture.

To begin with, we need the basic building block of an FNO: the Fourier layer.
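
The layer below works with `rfft` rather than the full FFT. As a small standalone check, written in NumPy for brevity (`np.fft.rfft` uses the same default convention as `torch.fft.rfft`), the sketch shows that the DFT of a real signal is conjugate-symmetric, that `rfft` therefore keeps only `n // 2 + 1` coefficients, and that `irfft` recovers the signal from them. The signal is random and purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
signal = rng.normal(size=n)        # a real-valued signal

full = np.fft.fft(signal)          # all n complex coefficients
half = np.fft.rfft(signal)         # non-negative frequencies only

# Negative-frequency coefficients are the conjugates of the positive ones.
symmetric = np.allclose(full[1:], np.conj(full[1:][::-1]))

# rfft returns exactly the first n // 2 + 1 entries of the full transform.
same_prefix = np.allclose(half, full[: n // 2 + 1])

# The half spectrum is enough to reconstruct the signal.
recovered = np.fft.irfft(half, n=n)
print(half.shape, symmetric, same_prefix, np.allclose(recovered, signal))
```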

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
from Adam import Adam
import operator
from functools import reduce
from functools import partial
from timeit import default_timer
import matplotlib.pyplot as plt
import matplotlib as mpl
from plot_lib import plt_set_default, plot_loss
import random
from itertools import cycle

plt_set_default()

In [5]:
# fourier layer

class FourierLayer1d(nn.Module):
    """
    1D Fourier layer
        
    ->input => FFT => multiplied by a (kernel) matrix => inverse FFT => add Wv_t => output
        
    """
    
    def __init__(self, in_channels, out_channels, modes1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = Parameter(self.scale * torch.rand(in_channels,
                                                          out_channels,
                                                          self.modes1,
                                                          dtype=torch.cfloat))
        
        self.w = nn.Conv1d(in_channels, out_channels, 1)

    # Complex multiplication
    def compl_mul1d(self, input_, weights):
        """
        (batch, in_channel, modes), (in_channel, out_channel, modes) -> (batch, out_channel, modes)
        
        implemented with the Einstein summation
        values along i-axis are multiplied and summed up
        
        """
        return torch.einsum("bix,iox->box", input_, weights)

    def forward(self, x):
        """
        input size: (batch, in_channel, resolution)
        """
        batch_size, channel_size, resolution = x.size()
        
        # Compute Fourier coefficients up to factor of e^(- something constant)
        # the first operation in the fourier layer
        # map the input to the fourier domain
        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batch_size,
                             self.out_channels,
                             resolution//2 + 1,
                             device=x.device, dtype=torch.cfloat)
        # out_ft.size(2) = resolution//2 + 1 because we are using rfft:
        # the DFT of a real-valued signal is Hermitian-symmetric,
        # so only half of the entries need to be stored.
        # This is also why the multiplication below is done once instead of twice
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)

        #Return to physical space
        x1 = torch.fft.irfft(out_ft, n=resolution)
        x2 = self.w(x)
        return x1 + x2

Now we are ready to implement the whole architecture. Details including how the size changes are in the comments.

In [6]:
class FNO1d(nn.Module):
    """
    FNO for the 1-D Burgers' equation (or parametric differential equations)
    
    It contains 
        1. A fully connected layer (self.fc0) that lifts the input to the desired channel dimension
        2. 4 layers of the integral operators u' = (W + K)(u), self.conv0 to self.conv3.
           In each FourierLayer1d, W is defined by self.w and K by self.weights1
        3. A projection from the channel space to the output space by self.fc1 and self.fc2.
        
    input: the values of the initial condition a(x); the grid x is appended inside forward
    input shape: (batchsize, resolution, 1), which becomes (batchsize, resolution, c=2)
                  after the grid is appended; c is the channel size
    output: the solution at a later time
    output shape: (batchsize, resolution, 1)
    """
    def __init__(self, modes, width):
        super().__init__()
        self.modes1 = modes
        self.width = width
        self.padding = 2 # pad the domain if input is non-periodic
        self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x)

        self.conv0 = FourierLayer1d(self.width, self.width, self.modes1)
        self.conv1 = FourierLayer1d(self.width, self.width, self.modes1)
        self.conv2 = FourierLayer1d(self.width, self.width, self.modes1)
        self.conv3 = FourierLayer1d(self.width, self.width, self.modes1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """
        input shape: (batch, resolution, 1)
        """
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        # now the shape becomes (batch, resolution, 2)
        
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        # now the shape becomes (batch, width, resolution)
        
        # x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic

        h1 = self.conv0(x)
        x = F.gelu(h1)

        h2 = self.conv1(x)
        x = F.gelu(h2)

        h3 = self.conv2(x)
        x = F.gelu(h3)

        h4 = self.conv3(x)
        x = h4

        # x = x[..., :-self.padding] # remove the padding if the domain was padded above
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """
        retrieve the spatial domain
        
        """
        batchsize, resolution = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, 1, resolution), dtype=torch.float)
        gridx = gridx.reshape(1, resolution, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)
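
Tracing the tensor shapes through `forward` by hand can be tedious. The sketch below walks through the first steps with NumPy arrays standing in for torch tensors (`np.concatenate` for `torch.cat`, a plain matrix product for `nn.Linear`, `transpose` for `permute`). The sizes (batch of 20, resolution 1024, width 64) and the random lifting matrix are illustrative assumptions, not values taken from a trained model.

```python
import numpy as np

batch, resolution, width = 20, 1024, 64
rng = np.random.default_rng(0)

a = rng.normal(size=(batch, resolution, 1))          # sampled initial condition a(x)

# As in get_grid: one grid of shape (1, resolution, 1), repeated along the batch axis
grid = np.linspace(0.0, 1.0, resolution).reshape(1, resolution, 1)
grid = np.repeat(grid, batch, axis=0)

x = np.concatenate((a, grid), axis=-1)               # (batch, resolution, 2)
shape_after_cat = x.shape

lift = rng.normal(size=(2, width))                   # stand-in for the fc0 weight (bias omitted)
x = x @ lift                                         # (batch, resolution, width)
shape_after_lift = x.shape

x = x.transpose(0, 2, 1)                             # (batch, width, resolution)
shape_after_permute = x.shape

x_ft = np.fft.rfft(x)                                # FFT along the last (spatial) axis
shape_in_fourier = x_ft.shape                        # (batch, width, resolution // 2 + 1)

print(shape_after_cat, shape_after_lift, shape_after_permute, shape_in_fourier)
```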

<!-- One may find it tedious to trace the size change in the code. We provide the code snippet below to demonstrate some of the details. 


    batch_size_, discretization_size_ = a_train.size()
    grid_x_ = torch.tensor(np.linspace(0,1,discretization_size_),
                           dtype = torch.float)
    grid_x_ = grid_x_.reshape(1,discretization_size_,1)
    print(grid_x_.size())
    grid_x_ = grid_x_.repeat([batch_size_,1,1])
    print(grid_x_.size())
    xx_ = torch.cat((a_train.reshape(batch_size_,discretization_size_,1),grid_x_),dim=-1)
    print(xx_.size())

 -->

# Training and Testing

### Input: initial condition $u_0(x)$ evaluated on a discretization of the spatial domain, together with the grid points
### Output: solution at time 1, i.e. $u(x,1)$, evaluated on a discretization of the spatial domain (possibly different from that of the input)
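
The training loop below minimizes a relative $L^2$ error computed by `LpLoss` from `utilities`, whose source is not shown here. As a rough sketch of what such a loss might compute, the hypothetical function `relative_l2_sum` assumes that `size_average=False` means the per-sample errors are summed over the batch, which would explain why the loop later divides the accumulated value by `n_train` or `n_test`.

```python
import numpy as np

def relative_l2_sum(pred, truth):
    """Hypothetical helper: sum over the batch of ||pred_i - truth_i||_2 / ||truth_i||_2."""
    pred = pred.reshape(pred.shape[0], -1)
    truth = truth.reshape(truth.shape[0], -1)
    diff_norms = np.linalg.norm(pred - truth, axis=1)
    truth_norms = np.linalg.norm(truth, axis=1)
    return np.sum(diff_norms / truth_norms)

truth = np.array([[3.0, 4.0], [0.0, 2.0]])
pred = np.array([[3.0, 4.0], [0.0, 1.0]])   # first sample exact, second off by half its norm

total = relative_l2_sum(pred, truth)
average = total / len(truth)                 # per-sample average, as the loop does with n_train
print(total, average)
```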

In [7]:
file_path = "../data/Burgers/burgers_data_R10.mat"
# file_path = "./data/Burgers/burgers_v100_t100_r1024_N2048.mat"
dataloader = MatReader(file_path)
a_data = dataloader.read_field('a')
u_data = dataloader.read_field('u')

sub_train = 2**4
sub_test = 2**3
full_res = a_data.size(-1)
train_res = full_res // sub_train
test_res = full_res // sub_test
n_train = 1000
n_test = 200

# sample data
a_train = a_data[:n_train,::sub_train].reshape(n_train,train_res,1)
u_train = u_data[:n_train,::sub_train]
a_test = a_data[-n_test:,::sub_test].reshape(n_test,test_res,1)
u_test = u_data[-n_test:,::sub_test]
print(a_train.size())
print(a_test.size())

torch.Size([1000, 512, 1])
torch.Size([200, 1024, 1])


In [8]:
# training and testing
%matplotlib notebook

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

###################################################
################## model configs ##################
###################################################

print('training data resolution: ', train_res)
print('testing data resolution: ', test_res)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 20
learning_rate = 0.001
epochs = 500

# for scheduler's use
# learning rate decays by a half every step_size (= 100) epochs
step_size = 100
gamma = 0.5

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_train, u_train), 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_test, u_test), 
                                          batch_size=20, 
                                          shuffle=False)
modes = 16
width = 64 # size of hidden layers
model = FNO1d(modes, width).to(device)
print("count_params(model): ", count_params(model))

optimizer = Adam(model.parameters(), 
                 lr=learning_rate, 
                 weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
                                            step_size=step_size, 
                                            gamma=gamma)

myLoss = LpLoss(d=1, p=2, size_average=False) # default is L2 relative error

epoch_step = 50

train_MSE_Loss = []
train_L2_Loss = []
test_L2_Loss = []

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, '\t', param.data.shape)

training data resolution:  512
testing data resolution:  1024
count_params(model):  549569
fc0.weight 	 torch.Size([64, 2])
fc0.bias 	 torch.Size([64])
conv0.weights1 	 torch.Size([64, 64, 16])
conv0.w.weight 	 torch.Size([64, 64, 1])
conv0.w.bias 	 torch.Size([64])
conv1.weights1 	 torch.Size([64, 64, 16])
conv1.w.weight 	 torch.Size([64, 64, 1])
conv1.w.bias 	 torch.Size([64])
conv2.weights1 	 torch.Size([64, 64, 16])
conv2.w.weight 	 torch.Size([64, 64, 1])
conv2.w.bias 	 torch.Size([64])
conv3.weights1 	 torch.Size([64, 64, 16])
conv3.w.weight 	 torch.Size([64, 64, 1])
conv3.w.bias 	 torch.Size([64])
fc1.weight 	 torch.Size([128, 64])
fc1.bias 	 torch.Size([128])
fc2.weight 	 torch.Size([1, 128])
fc2.bias 	 torch.Size([1])
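
The reported total can be checked by hand from the shapes listed above. The arithmetic below assumes that `count_params` counts each complex-valued entry of `weights1` as two real numbers; under that assumption the sum reproduces the printed value.

```python
width, modes = 64, 16

fc0 = 2 * width + width                   # weight (64, 2) + bias (64)
spectral = 2 * width * width * modes      # complex weights (64, 64, 16), two reals each
pointwise = width * width * 1 + width     # Conv1d weight (64, 64, 1) + bias (64)
fourier_layer = spectral + pointwise
fc1 = 128 * width + 128                   # weight (128, 64) + bias (128)
fc2 = 1 * 128 + 1                         # weight (1, 128) + bias (1)

total = fc0 + 4 * fourier_layer + fc1 + fc2
print(total)  # 549569
```

Almost all parameters sit in the spectral weights, which is why the count depends on `modes` and `width` but not on the grid resolution.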

In [9]:
####################################################
################# plotting configs #################
####################################################

# change sample by changing rows and cols
rows,cols = 2,2
num_plots = rows * cols

np.random.seed(0)
sampled_batch_idx = random.sample(range(n_test//batch_size),num_plots)
sampled_plotting_idx = np.random.randint(batch_size, size = num_plots)
fig = plt.figure(figsize=(16, 16))

# config for saving figures
saveFig = False
fig_folder = 'ep' + str(epochs) + \
             '-lr' + str(learning_rate) + \
             '-decay' + str(gamma) + \
             '-trainRes' + str(train_res) + \
             '-testRes' + str(test_res)

fig_dir = '../figs/burgers/' + fig_folder
if saveFig:
    if not os.path.isdir(fig_dir):
        os.makedirs(fig_dir)

xvals = np.linspace(0,1,a_data.size(-1))
xvals_test = xvals[::sub_test]
xvals_train = xvals[::sub_train]
plot_pos = 1

for ep in range(epochs):
    fig.clear()
    fig.suptitle('epoch: #' + str(ep),fontsize = 50)
    
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for batch, (a, u) in enumerate(train_loader):
        a, u = a.to(device), u.to(device) 
        # a.shape = (batch, train_res, 1)
        # u.shape = (batch, train_res)
        # the extra trailing dimension of the model output is flattened by the views below

        optimizer.zero_grad()
        out = model(a)
        # out.shape = (batch, train_res, 1)

        mse = F.mse_loss(out.view(batch_size, -1), u.view(batch_size, -1), reduction='mean')
        l2 = myLoss(out.view(batch_size, -1), u.view(batch_size, -1))
        l2.backward() # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()
        
    if (ep % epoch_step)==0 or (ep==(epochs-1)):
        t2 = default_timer()
    
    scheduler.step()
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for batch, (a, u) in enumerate(test_loader):
            a, u = a.to(device), u.to(device)
            # a.shape = (batch, resolution, 1)
            # u.shape = (batch, resolution)
            
            out = model(a)
            test_l2 += myLoss(out.view(batch_size, -1), u.view(batch_size, -1)).item()
            
            ### plot sampled testing function ###
            
            if batch in sampled_batch_idx:
                idx = sampled_plotting_idx[plot_pos-1]
                ax = fig.add_subplot(rows,cols,plot_pos)
                ax.clear()
#                 ax.get_xaxis().set_visible(False)
                ax.set_ylim((-1.2,1.2))
                u_sample = out[idx,:,:].reshape((-1,1))
                u_sample = u_sample.detach().cpu().numpy()
                u_truth = u[idx,:].reshape((-1,1))
                u_truth = u_truth.detach().cpu().numpy()
                ax.plot(xvals_test, u_sample, 'r', label='FNO1d')
                ax.plot(xvals_test, u_truth, 'b--', label='truth')
                ax.set_title(f"sample #{plot_pos}")
                ax.set_xlabel('x')
                ax.set_ylabel('u(x,1)')
                ax.legend()
                plot_pos += 1
        fig.canvas.draw()
        if saveFig:
            if fig_dir[-1] != '/':
                fig_dir = fig_dir + '/'
            fig_name = fig_dir + 'epoch-' + str(ep) + '.png'
            fig.savefig(fig_name)
        plot_pos = 1

    train_mse /= len(train_loader)
    train_MSE_Loss.append(train_mse)
    
    train_l2 /= n_train
    train_L2_Loss.append(train_l2)
    
    test_l2 /= n_test
    test_L2_Loss.append(test_l2)

    if (ep % epoch_step)==0 or (ep==(epochs-1)):
        print('epoch \t\t t2-t1 \t\t train-MSE-Error \t train-L2-Error \t test-L2-Error')
        print('{0:d}\t\t{1:.5f} \t\t {2:.5f}\t\t{3:.5f} \t {4:.5f}' \
              .format(ep, t2-t1, train_mse, train_l2, test_l2))
        print('\n')
#         print(f'epoch #{ep}')
#         print(f"time since last checked: {t2-t1}")
#         print(f'MSE Training Error: {train_mse}')
#         print(f'Relative L2 Training Error: {train_l2}')
#         print(f'Relative L2 Testing Error: {test_l2}')

# torch.save(model, 'model/ns_fourier_burgers')
pred = torch.zeros(u_test.shape)
index = 0
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_test, u_test), 
                                          batch_size=1, 
                                          shuffle=False)
# with torch.no_grad():
#     for x, y in test_loader:
#         test_l2 = 0
#         x, y = x.to(device), y.to(device)

#         out = model(x).view(-1)
#         pred[index] = out

#         test_l2 += myLoss(out.view(1, -1), y.view(1, -1)).item()
#         print("index, test_l2")
#         print(index, test_l2)
#         index = index + 1
print('FINISHED!')

epoch 		 t2-t1 		 train-MSE-Error 	 train-L2-Error 	 test-L2-Error
0		1.49144 		 0.05360		0.32270 	 0.05408


KeyboardInterrupt: 

In [None]:
# sub_test = 2**5
# test_res = full_res // sub_test
# a_test = a_data[-n_test:,::sub_test].reshape(n_test,test_res,1)
# u_test = u_data[-n_test:,::sub_test]
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(a_test,u_test),
#                                           batch_size = 1, 
#                                           shuffle = False)
# with torch.no_grad():

In [None]:
savefig = False
figpath = '../figs/burgers/loss/'
figname = 'ep' + str(epochs) + \
          '-lr' + "".join(str(learning_rate).split('.')) + \
          '-decay' + "".join(str(gamma).split('.')) + \
          '-trainRes' + str(train_res) + \
          '-testRes' + str(test_res)

figtitle = '1-D Burger\'s \n Learning Rate: ' + str(learning_rate) + \
           '; train_res: ' + str(train_res) + '; test_res: ' + str(test_res) + '\n' + \
           'Decay Rate: ' + str(gamma) + ' per ' + str(step_size) + ' epochs'
plot_loss(figtitle = figtitle,
          figpath = figpath, 
          figname = figname,
          savefig = savefig, 
          train_MSE = train_MSE_Loss, 
          train_L2 = train_L2_Loss, 
          test_L2_res = test_L2_Loss)