# Navier-Stokes Equation
## A 2+1 Dimensional Numerical Experiment of FNOs

This notebook walks through the Fourier Neural Operator for a 2D problem such as the Navier-Stokes discussed in Section 5.3 in the paper [Fourier Neural Operator for
Parametric Partial Differential Equations](https://arxiv.org/pdf/2010.08895.pdf) which uses a recurrent structure to propagates in time.

In [1]:
from typing import Any, Generic, NamedTuple, Optional, TypeVar

import yaml
import numpy as np
import matplotlib.pyplot as plt
from timeit import default_timer

from torch.nn.functional import mse_loss
import torch

from neuralop import count_params
from neuralop.datasets import load_navier_stokes_temporal_pt
from neuralop.layers import SpectralConv2d
from neuralop.models import FNO2d
from neuralop.training import LpLoss

torch.manual_seed(0)
np.random.seed(0)

In [2]:
#################################################
# Utilities
#################################################
device = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
################################################################
# fourier layer
################################################################
print(
"""
The overall network. It contains 4 layers of the Fourier layer.
1. Lift the input to the desire channel dimension by self.fc0 .
2. 4 layers of the integral operators u' = (W + K)(u).
    W defined by self.w; K defined by self.conv .
3. Project from the channel space to the output space by self.fc1 and self.fc2 .

input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)
input shape: (batchsize, x=64, y=64, c=12)
output: the solution of the next timestep
output shape: (batchsize, x=64, y=64, c=1)
"""
)


The overall network. It contains 4 layers of the Fourier layer.
1. Lift the input to the desire channel dimension by self.fc0 .
2. 4 layers of the integral operators u' = (W + K)(u).
    W defined by self.w; K defined by self.conv .
3. Project from the channel space to the output space by self.fc1 and self.fc2 .

input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)
input shape: (batchsize, x=64, y=64, c=12)
output: the solution of the next timestep
output shape: (batchsize, x=64, y=64, c=1)



In [26]:
class Config(NamedTuple):
    train_path: str
    test_path: str
    n_train: int
    n_test: int
    train_batch_size: int
    test_batch_size: int
    learning_rate: float
    epochs: int
    iterations: int
    modes: int
    width: int
    subsampling_rate: int
    s: int
    history_length: int
    future_duration: int
    step: int
        
    @staticmethod
    def from_yaml(config_path: str):
        with open(config_path, 'r') as f:
            cfg = yaml.load(f)

        config = Config(
            train_path=cfg['train_path'],
            test_path=cfg['test_path'],
            n_train=cfg['n_train'],
            n_test=cfg['n_test'],
            train_batch_size=cfg['train_batch_size'],
            test_batch_size=cfg['test_batch_size'],
            learning_rate=cfg['learning_rate'],
            epochs=cfg['epochs'],
            iterations=cfg['iterations'],
            modes=cfg['modes'],
            width=cfg['width'],
            subsampling_rate=cfg['subsampling_rate'],
            s=cfg['s'],
            history_length=cfg['history_length'],
            future_duration=cfg['future_duration'],
            step=cfg['step'],
        )        
        return config


config = Config.from_yaml('fourier_2d_time_V1e-3.yaml')

config.epochs: 500


In [6]:
################################################################
# load data and data normalization
################################################################

# Assumes there exist training and testing data .pt files at
# `neuralop/data/ns_data_V100_N1000_T50_1.pt`, and
# `neuralop/data/ns_data_V100_N1000_T50_1.pt`
train_loader, test_loader, output_encoder = load_navier_stokes_temporal_pt(
    config.train_path, # Currently, the same path is used for both training and testing data.
    config.n_train,
    config.n_test,
    config.history_length,
    config.future_duration,
    config.train_batch_size,
    config.test_batch_size,
)

UnitGaussianNormalizer init on 1000, reducing over [0, 1, 2, 3], samples of shape [10, 64, 64].
   Mean and std of shape torch.Size([1, 1, 1]), eps=1e-05
UnitGaussianNormalizer init on 1000, reducing over [0, 1, 2, 3], samples of shape [40, 64, 64].
   Mean and std of shape torch.Size([1, 1, 1]), eps=1e-05


In [7]:
model = FNO2d(
    config.modes,  # modes_width
    config.modes,  # modes_height
    config.width,  # width of all hidden layers
    # input channels are 12: the solution of the previous 10 timesteps + 2 location encodings
    # i.e: (u(t-10, x, y), ..., u(t-1, x, y),  x, y)
    in_channels=2 + config.history_length,
    out_channels=1,  # output channel is 1: u(t, x, y)
    n_layers=4,
    # domain_padding=8,
    domain_padding=None,
    # domain_padding_mode='one-sided',
    use_mlp=True,
).cuda()
print(f"Model parameter count: {count_params(model):,d}")

Model parameter count: 239773


In [None]:
################################################################
# training and evaluation
################################################################
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.iterations
)
    
lp_loss = LpLoss()  # By default, does not do size averaging
output_encoder.cuda()
for ep in range(config.epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for train_data in train_loader:
        # print({k: v.shape for k, v in train_data.items()})
        xx = train_data['x'].to(device)
        yy = train_data['y'].to(device)
        loss = 0

        # For each time step
        for t in range(0, config.future_duration, config.step):
            # We want the operator to learn the next S time steps:
            y = yy[:, t:t + config.step, ...]
            # use recurrent structure to propagate in time:
            im = model(xx)
            # print('y.shape=', y.shape)
            # print('im.shape=', im.shape)
            loss += lp_loss(
                im.reshape(config.train_batch_size, -1),
                y.reshape(config.train_batch_size, -1))

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), -1)

            # Advance the recurrent input by one time step; i.e:
            # [t_n, ..., t_{n+m}] --> [t_{n+s}, ..., t_{n+s+m}]
            # for starting time N, future duraiton M, and time step S

            # xx0 = xx[..., config.step:-2, :, :]  # [t_{n+s}, ..., t_{n+m}]
            # xx1 = im                             # [t_{n+m+1}, ..., t_{n+m+s}]
            # xx2 = xx[..., -2:, :, :]              # positional encoding
            # print('xx0.shape=', xx0.shape, '[t_{n+s}, ..., t_{n+m}]')
            # print('xx1.shape=', xx1.shape, '[t_{n+m+1}, ..., t_{n+m+s}]')
            # print('xx2.shape=', xx2.shape, '(positional encoding)')
            xx = torch.cat((
                xx[..., config.step:-2, :, :],  # [t_{n+s}, ..., t_{n+m}]
                im,                             # [t_{n+m+1}, ..., t_{n+m+s}]
                xx[..., -2:, :, :]              # positional encoding
            ), dim=1)

        # Compare the full, operator-predicted future with the ground truth.
        # This is not used in training, but surfaces how well 
        # the operator has done overall in this epoch.
        train_l2_step += loss.item()
        l2_full = lp_loss(
            pred.reshape(config.train_batch_size, -1), 
            yy.reshape(config.train_batch_size, -1))
        train_l2_full += l2_full.item()

        # Backprop based only on the loss for each step-by-step in time.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    model.eval()
    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for test_data in test_loader:
            xx = test_data['x'].to(device)
            yy = test_data['y'].to(device)
            loss = 0

            for t in range(0, config.future_duration, config.step):
                y = yy[:, t:t + config.step, ...]
                im = model(xx)
                loss += lp_loss(
                    im.reshape(config.test_batch_size, -1), 
                    y.reshape(config.test_batch_size, -1))

                if t == 0:
                    pred = im
                else:
                    pred = torch.cat((pred, im), -1)

                xx = torch.cat((
                    xx[..., config.step:-2, :, :],  # [t_{n+s}, ..., t_{n+m}]
                    im,                             # [t_{n+m+1}, ..., t_{n+m+s}]
                    xx[..., -2:, :, :]              # positional encoding
                ), dim=1)

            test_l2_step += loss.item()
            test_l2_full += lp_loss(
                pred.reshape(config.test_batch_size, -1), 
                yy.reshape(config.test_batch_size, -1)).item()

    t2 = default_timer()
    print(
        f'Epoch # {ep:03d} / {config.epochs}',
        f'Duration:           {t2 - t1:9.5f}',
        f'Training L2 (step): {train_l2_step / config.n_train / (config.future_duration / config.step):9.5f}',
        f'Training L2 (full): {train_l2_full / config.n_train:9.5f}',
        f'Testing L2 (step):  {test_l2_step / config.n_test / (config.future_duration / config.step):9.5f}',
        f'Testing L2 (full):  {test_l2_full / config.n_test:9.5f}',
        '=' * 32,
        sep='\n',
    )

Epoch # 000 / 500
Duration:      25.06834
Training L2 (step):   0.55194
Training L2 (full):   1.29020
Testing L2 (step):    0.53796
Testing L2 (full):    1.33653
Epoch # 001 / 500
Duration:      24.81722
Training L2 (step):   0.45042
Training L2 (full):   1.33538
Testing L2 (step):    0.35723
Testing L2 (full):    1.41073
Epoch # 002 / 500
Duration:      24.78823
Training L2 (step):   0.28598
Training L2 (full):   1.39070
Testing L2 (step):    0.26935
Testing L2 (full):    1.45236
Epoch # 003 / 500
Duration:      25.47932
Training L2 (step):   0.23150
Training L2 (full):   1.39994
Testing L2 (step):    0.22702
Testing L2 (full):    1.44580
Epoch # 004 / 500
Duration:      25.01301
Training L2 (step):   0.19770
Training L2 (full):   1.40148
Testing L2 (step):    0.20557
Testing L2 (full):    1.45377
Epoch # 005 / 500
Duration:      25.07817
Training L2 (step):   0.17342
Training L2 (full):   1.40091
Testing L2 (step):    0.17774
Testing L2 (full):    1.43436
Epoch # 006 / 500
Duration: 

## Visualization

Visualize the error in the trained model against a subsample of testing data points (i.e. field values from time t=0 to be mapped to t=1). Also visualize the error (squared to be non-negative).

In [None]:
pass
# test_samples = test_loader.dataset
# n_rows = 3
# n_cols = 5

# fig = plt.figure(figsize=(13,  # width (inches)
#                           9))  # height (inches)
# for index in range(n_rows):
#     data = test_samples[index]
#     # Input x
#     x = data['x'].cuda()
#     # Ground-truth
#     y = data['y'].cuda()
#     # Model prediction
#     out = model(x.unsqueeze(0))
#     error = (out - y).square()
#     error2 = (out / 2 - y).square()
#     vmin = min(y.min(), out.min(), error.min(), error2.min())
#     vmax = max(y.max(), out.max(), error.max(), error2.max())

#     ax1 = fig.add_subplot(n_rows, n_cols, index * n_cols + 1)
#     im1 = ax1.imshow(x[0].cpu(), cmap='gray')
#     if index == 0:
#         ax1.set_title('Input x')
#     plt.xticks([], [])
#     plt.yticks([], [])
#     # fig.colorbar(im1, ax=ax1)

#     ax2 = fig.add_subplot(n_rows, n_cols, index * n_cols + 2)
#     im2 = ax2.imshow(y.squeeze().cpu(), cmap='magma', vmin=vmin, vmax=vmax)
#     if index == 0:
#         ax2.set_title('Ground-truth y')
#     plt.xticks([], [])
#     plt.yticks([], [])
#     fig.colorbar(im2, ax=ax2)

#     ax3 = fig.add_subplot(n_rows, n_cols, index * n_cols + 3)
#     im3 = ax3.imshow(
#         out.squeeze().detach().cpu(),
#         cmap='magma',
#         vmin=vmin,
#         vmax=vmax
#     )
#     if index == 0:
#         ax3.set_title('Model prediction')
#     plt.xticks([], [])
#     plt.yticks([], [])
#     fig.colorbar(im3, ax=ax3)

#     ax4 = fig.add_subplot(n_rows, n_cols, index * n_cols + 4)
#     im4 = ax4.imshow(
#         error.squeeze().detach().cpu(),
#         cmap='magma',
#         vmin=vmin,
#         vmax=vmax
#     )
#     if index == 0:
#         ax4.set_title('Error')
#     plt.xticks([], [])
#     plt.yticks([], [])
#     fig.colorbar(im4, ax=ax4)
    
#     ax5 = fig.add_subplot(n_rows, n_cols, index * n_cols + 5)
#     im5 = ax5.imshow(
#         error2.squeeze().detach().cpu(),
#         cmap='magma',
#         vmin=vmin,
#         vmax=vmax
#     )
#     if index == 0:
#         ax5.set_title('Error (2)')
#     plt.xticks([], [])
#     plt.yticks([], [])
#     fig.colorbar(im5, ax=ax5)
    
    
# fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
# plt.tight_layout()
# fig.show()