# Solving the Navier-Stokes Equation with the Fourier Neural Operator

## Overview

Computational fluid dynamics (CFD) is one of the most important techniques in 21st-century fluid mechanics. Solving the governing equations of fluid mechanics numerically makes it possible to analyze, predict and control flows. However, traditional methods such as the finite element method (FEM) and the finite difference method (FDM) are inefficient: the simulation workflow is complex (physical modeling, meshing, numerical discretization, iterative solution, etc.) and the computational cost is high. AI-based methods therefore offer a way to make fluid simulation more efficient.

Machine learning methods provide a new paradigm for scientific computing by offering fast solvers that fill the same role as traditional numerical methods. Classical neural networks learn mappings between finite-dimensional spaces, so the solutions they learn are tied to a specific discretization. In contrast, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It learns the mapping from function-valued parameters (such as an initial condition) directly to the solution, so a single model can solve a whole class of partial differential equations, which gives it stronger generalization capability. For more details, see the paper [Fourier Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2010.08895).
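The core of each FNO layer is a spectral convolution: transform the input to Fourier space, keep only the lowest `modes` frequencies, multiply them by learned complex weights, and transform back. The NumPy sketch below illustrates this idea under simplifying assumptions (a single layer, no bias or pointwise term, random weights standing in for trained ones). The function name `spectral_conv2d` is hypothetical, and this is not MindFlow's `FNO2D` implementation. Because the weights act on Fourier modes rather than on grid points, the same weights can be applied to inputs sampled at different resolutions.

```python
import numpy as np

def spectral_conv2d(x, w_pos, w_neg, modes):
    """Spectral convolution of one Fourier layer (illustrative sketch, not MindFlow's FNO2D).

    x:     real array (c_in, n, n), samples of the input function on a uniform n x n grid
    w_pos: complex weights (c_in, c_out, modes, modes) for non-negative frequencies on the first spatial axis
    w_neg: complex weights (c_in, c_out, modes, modes) for negative frequencies on the first spatial axis
    """
    n = x.shape[-1]
    c_out = w_pos.shape[1]
    x_ft = np.fft.rfft2(x)                                   # (c_in, n, n // 2 + 1)
    out_ft = np.zeros((c_out, n, n // 2 + 1), dtype=complex)
    # Keep only the lowest `modes` frequencies and mix channels with the learned weights;
    # all higher frequencies are set to zero.
    out_ft[:, :modes, :modes] = np.einsum("ixy,ioxy->oxy", x_ft[:, :modes, :modes], w_pos)
    out_ft[:, -modes:, :modes] = np.einsum("ixy,ioxy->oxy", x_ft[:, -modes:, :modes], w_neg)
    return np.fft.irfft2(out_ft, s=(n, n))                  # (c_out, n, n), back on the grid

rng = np.random.default_rng(0)
modes, c_in, c_out = 8, 1, 2
shape = (c_in, c_out, modes, modes)
w_pos = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
w_neg = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

def sample(n):
    """A band-limited test function sampled on an n x n grid over [0, 1)^2, shape (1, n, n)."""
    g = np.arange(n) / n
    X, Y = np.meshgrid(g, g, indexing="ij")
    f = np.sin(2 * np.pi * X) * np.cos(6 * np.pi * Y) + 0.5 * np.cos(4 * np.pi * X)
    return f[None]

# The same weights applied at two resolutions.
out32 = spectral_conv2d(sample(32), w_pos, w_neg, modes)
out64 = spectral_conv2d(sample(64), w_pos, w_neg, modes)
```

For this band-limited input, the 64×64 output sampled at every other grid point agrees with the 32×32 output up to floating-point error. This is a small instance of the discretization independence described above.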

This tutorial describes how to solve the Navier-Stokes equation using the Fourier Neural Operator.

## Navier-Stokes Equation

The Navier-Stokes equation, or N-S equation for short, is a classical equation in computational fluid dynamics. It is a set of partial differential equations describing the conservation of momentum in a fluid. For two-dimensional incompressible flows, its vorticity form is:

$$
\partial_t w(x, t)+u(x, t) \cdot \nabla w(x, t)=\nu \Delta w(x, t)+f(x), \quad x \in(0,1)^2, t \in(0, T]
$$

$$
\nabla \cdot u(x, t)=0, \quad x \in(0,1)^2, t \in[0, T]
$$

$$
w(x, 0)=w_0(x), \quad x \in(0,1)^2
$$

where $u$ is the velocity field, $w=\nabla \times u$ is the vorticity, $w_0(x)$ is the initial vorticity, $\nu$ is the viscosity coefficient, and $f(x)$ is the forcing function.
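The following example makes the relation between $u$ and $w$ concrete. Write $x=(x_1,x_2)$ and $u=(u_1,u_2)$. In two dimensions $w=\partial_{x_1}u_2-\partial_{x_2}u_1$, and the constraint $\nabla\cdot u=0$ lets the velocity be expressed through a stream function $\psi$ with $u=(\partial_{x_2}\psi,-\partial_{x_1}\psi)$ and $-\Delta\psi=w$. The NumPy sketch below uses FFTs to recover $u$ from a sampled vorticity field. It assumes periodic boundary conditions on the unit square. This matches the FNO paper's Navier-Stokes benchmark on the unit torus, but this tutorial does not state it. The function name `velocity_from_vorticity` is hypothetical.

```python
import numpy as np

def velocity_from_vorticity(w):
    """Recover the velocity (u1, u2) from vorticity w sampled on an n x n periodic grid over [0, 1)^2.

    Axis 0 of `w` is x1 and axis 1 is x2. Periodic boundaries are an assumption of this sketch.
    """
    n = w.shape[0]
    k = 2 * np.pi * np.fft.fftfreq(n, d=1.0 / n)   # angular wavenumbers for a domain of length 1
    k1, k2 = np.meshgrid(k, k, indexing="ij")
    k_sq = k1**2 + k2**2
    k_sq[0, 0] = 1.0                                # avoid 0/0 for the mean (k = 0) mode
    psi_hat = np.fft.fft2(w) / k_sq                 # -Lap(psi) = w  <=>  |k|^2 psi_hat = w_hat
    psi_hat[0, 0] = 0.0                             # psi is defined up to a constant; choose zero mean
    u1 = np.real(np.fft.ifft2(1j * k2 * psi_hat))   # u1 =  d(psi)/dx2
    u2 = np.real(np.fft.ifft2(-1j * k1 * psi_hat))  # u2 = -d(psi)/dx1
    return u1, u2

# Usage: psi = sin(2 pi x1) sin(2 pi x2) gives w = -Lap(psi) = 8 pi^2 sin(2 pi x1) sin(2 pi x2).
n = 64
g = np.arange(n) / n
X1, X2 = np.meshgrid(g, g, indexing="ij")
w = 8 * np.pi**2 * np.sin(2 * np.pi * X1) * np.sin(2 * np.pi * X2)
u1, u2 = velocity_from_vorticity(w)
```

By construction, the recovered field satisfies $\nabla\cdot u=\partial_{x_1}\partial_{x_2}\psi-\partial_{x_2}\partial_{x_1}\psi=0$, and its curl returns $-\Delta\psi=w$.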


```python
import os
import numpy as np

from mindspore import nn, context, Tensor, set_seed
from mindspore import DynamicLossScaleManager, LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
```


```python
from mindflow import FNO2D, RelativeRMSELoss, Solver, load_yaml_config, get_warmup_cosine_annealing_lr

from src import PredictCallback, create_training_dataset

# fix random seeds for reproducibility
set_seed(0)
np.random.seed(0)
```

```python
# run in graph mode on a GPU
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target='GPU')
```

```python
# load the data, model, optimizer and callback settings from the YAML configuration file
config = load_yaml_config('navier_stokes_2d.yaml')
data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
callback_params = config["callback"]
```
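The later cells read a fixed set of keys from `navier_stokes_2d.yaml`. A small check like the hypothetical `missing_config_keys` helper below can report missing entries before training starts, assuming `load_yaml_config` returns a nested dictionary. The key list is taken only from the code in this tutorial. `create_training_dataset` and `PredictCallback` may read further keys that are not shown here.

```python
# Keys this tutorial's code reads from each section of the YAML config.
REQUIRED_KEYS = {
    "data": ["path"],
    "model": ["name", "in_channels", "out_channels", "input_resolution", "modes", "width", "depth"],
    "optimizer": ["initial_lr", "train_epochs", "warmup_epochs"],
    "callback": ["summary_dir", "save_checkpoint_steps", "keep_checkpoint_max"],
}

def missing_config_keys(config):
    """Return 'section.key' strings for every required entry absent from `config` (a nested dict)."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        params = config.get(section, {})
        missing.extend(f"{section}.{key}" for key in keys if key not in params)
    return missing
```

Typical use would be `assert not missing_config_keys(config)` directly after loading the configuration.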

## Training Dataset Construction

Download the training and test datasets from [data_driven/navier_stokes/dataset](https://download.mindspore.cn/mindscience/mindflow/dataset/applications/data_driven/navier_stokes/dataset/).

In this case, the data are read from the downloaded dataset. `create_training_dataset` builds the shuffled training set at the input resolution given in the model configuration, and the test inputs and labels are loaded directly with NumPy:

```python
train_dataset = create_training_dataset(data_params, input_resolution=model_params["input_resolution"], shuffle=True)
test_input = np.load(os.path.join(data_params["path"], "test/inputs.npy"))
test_label = np.load(os.path.join(data_params["path"], "test/label.npy"))
```

```text
Data preparation finished
```
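`create_training_dataset(..., shuffle=True)` returns a MindSpore dataset built by the local `src` package, whose internals are not shown here. The hypothetical generator below is a rough NumPy stand-in that shows what shuffling and batching mean for sample-first arrays like the test data. It draws a fresh random order for each epoch and yields aligned (input, label) mini-batches. The batch size and the `drop_remainder` behavior are assumed parameters, not values from the configuration.

```python
import numpy as np

def iterate_minibatches(inputs, labels, batch_size, rng, drop_remainder=True):
    """Yield shuffled, aligned (input, label) mini-batches along axis 0 (one pass = one epoch)."""
    if len(inputs) != len(labels):
        raise ValueError("inputs and labels must have the same number of samples")
    order = rng.permutation(len(inputs))  # a fresh random order on every call
    stop = len(order) - len(order) % batch_size if drop_remainder else len(order)
    for start in range(0, stop, batch_size):
        idx = order[start:start + batch_size]
        yield inputs[idx], labels[idx]
```

Indexing inputs and labels with the same permuted indices keeps each input paired with its own label after shuffling.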


## Model Construction

This example uses the `FNO2D` model. Its input and output channels, input resolution, number of Fourier modes, channel width and depth are read from the `model` section of the configuration file.

```python
model = FNO2D(in_channels=model_params["in_channels"],
              out_channels=model_params["out_channels"],
              resolution=model_params["input_resolution"],
              modes=model_params["modes"],
              channels=model_params["width"],
              depth=model_params["depth"]
              )
```
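For reference, the FNO paper builds the operator from a pointwise lifting $P$, a stack of Fourier layers, and a pointwise projection $Q$. Let $a$ be the input function, $v_\ell$ the hidden representation after layer $\ell$, $W$ a learned pointwise linear map, $R_\phi$ learned weights applied to the retained Fourier modes, $\mathcal{F}$ the Fourier transform, $\sigma$ a nonlinear activation and $y$ the predicted output field. With this notation, the paper's layer update reads:

$$
v_0(x)=P\big(a(x)\big),\qquad
v_{\ell+1}(x)=\sigma\Big(W v_\ell(x)+\mathcal{F}^{-1}\big(R_\phi\cdot(\mathcal{F}v_\ell)\big)(x)\Big),\qquad
y(x)=Q\big(v_L(x)\big)
$$

The following mapping of the constructor arguments to this formula is inferred from their names, not taken from MindFlow's documentation. `modes` plausibly sets how many Fourier modes $R_\phi$ acts on, `channels` (set from `width`) the number of channels of $v_\ell$, and `depth` the number of Fourier layers $L$.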

## Optimizer and Loss Function

```python
steps_per_epoch = train_dataset.get_dataset_size()
# one learning-rate value per training step: warmup, then cosine annealing
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params["initial_lr"],
                                    last_epoch=optimizer_params["train_epochs"],
                                    steps_per_epoch=steps_per_epoch,
                                    warmup_epochs=optimizer_params["warmup_epochs"])

optimizer = nn.Adam(model.trainable_params(), learning_rate=Tensor(lr))
# adjust the loss scale automatically when gradients overflow
loss_scale = DynamicLossScaleManager()

# prepare loss function
loss_fn = RelativeRMSELoss()
```
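`get_warmup_cosine_annealing_lr` produces one learning-rate value per training step. The NumPy sketch below shows a common form of such a schedule: a linear warmup to `lr_init`, followed by a cosine decay toward zero. This form is an assumption made for illustration, and MindFlow's exact formula may differ (for example, in its final learning rate or warmup shape). The function name `warmup_cosine_lr` is hypothetical.

```python
import numpy as np

def warmup_cosine_lr(lr_init, last_epoch, steps_per_epoch, warmup_epochs):
    """Per-step learning rates: linear warmup, then cosine annealing to zero."""
    total_steps = last_epoch * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    steps = np.arange(total_steps)
    # Linear warmup: lr rises from lr_init / warmup_steps to lr_init.
    warmup = lr_init * (steps + 1) / max(warmup_steps, 1)
    # Cosine annealing over the remaining steps (progress runs from 0 towards 1).
    progress = (steps - warmup_steps) / max(total_steps - warmup_steps, 1)
    cosine = 0.5 * lr_init * (1.0 + np.cos(np.pi * progress))
    return np.where(steps < warmup_steps, warmup, cosine).astype(np.float32)
```

Precomputing the whole array, as the tutorial does before wrapping it in `Tensor(lr)`, lets the optimizer simply index the schedule by the global step.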

## Define Solver

```python
solver = Solver(model,
                optimizer=optimizer,
                loss_scale_manager=loss_scale,
                loss_fn=loss_fn,
                )
```
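The `DynamicLossScaleManager` passed to the solver multiplies the loss by a scale factor before backpropagation and adjusts that factor during training. The pure-Python sketch below shows the usual dynamic loss-scaling rule: shrink the scale when gradients overflow, and grow it after a window of overflow-free steps. The class `DynamicLossScale` and its default values are illustrative assumptions, not MindSpore's implementation.

```python
class DynamicLossScale:
    """Illustrative dynamic loss-scaling rule (not MindSpore's DynamicLossScaleManager)."""

    def __init__(self, init_scale=2.0**24, factor=2.0, window=2000):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        """Update the scale after one step; return True if that step's update should be applied."""
        if overflow:
            # Gradients contained inf/nan: skip this update and shrink the scale.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.window:
            # A full window without overflow: try a larger scale.
            self.scale *= self.factor
            self.good_steps = 0
        return True
```

The gradients are divided by the current scale before the optimizer step. The scale therefore changes the numerical range used in the backward pass, not the size of the parameter update.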

## Define Callbacks

```python
summary_dir = os.path.join(callback_params["summary_dir"], 'FNO2D')
print(summary_dir)
# evaluate the model on the test set during training
pred_cb = PredictCallback(model=model,
                          inputs=test_input,
                          label=test_label,
                          config=callback_params,
                          summary_dir=summary_dir)

# save a checkpoint every `save_checkpoint_steps` epochs, keeping at most `keep_checkpoint_max` files
ckpt_config = CheckpointConfig(save_checkpoint_steps=callback_params["save_checkpoint_steps"] * steps_per_epoch,
                               keep_checkpoint_max=callback_params["keep_checkpoint_max"])
ckpt_dir = os.path.join(summary_dir, "ckpt")
ckpt_cb = ModelCheckpoint(prefix=model_params["name"],
                          directory=ckpt_dir,
                          config=ckpt_config)
```

```text
/data5/hyzhou/MindFlowtmp/FNO2D
check test dataset shape: (200, 19, 64, 64, 1), (200, 19, 64, 64, 1)
```
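The test inputs and labels loaded earlier have shape `(200, 19, 64, 64, 1)`. A common way to score predictions against such labels is the relative RMSE. Because the means cancel, it equals the ratio of L2 norms $\lVert \hat{y}-y\rVert_2/\lVert y\rVert_2$, here computed per sample and then averaged. The NumPy sketch below assumes this definition. `RelativeRMSELoss` and `PredictCallback` may normalize or average differently, and the name `relative_l2_error` is hypothetical.

```python
import numpy as np

def relative_l2_error(pred, label, eps=1e-12):
    """Mean over samples of ||pred - label||_2 / ||label||_2, flattening all non-batch axes."""
    pred = pred.reshape(pred.shape[0], -1)
    label = label.reshape(label.shape[0], -1)
    err = np.linalg.norm(pred - label, axis=1)   # per-sample error norm
    ref = np.linalg.norm(label, axis=1)          # per-sample label norm
    return float(np.mean(err / (ref + eps)))     # eps guards against all-zero labels
```

Normalizing by the label norm makes samples with large and small vorticity magnitudes contribute comparably to the average.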


## Model Training

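Call `solver.train` to train the model. During training, `LossMonitor` and `TimeMonitor` print the loss and timing, `PredictCallback` evaluates the model on the test set, and `ModelCheckpoint` saves checkpoints.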

```python
solver.train(epoch=optimizer_params["train_epochs"],
             train_dataset=train_dataset,
             callbacks=[LossMonitor(), TimeMonitor(), pred_cb, ckpt_cb],
             dataset_sink_mode=True)
```

```text
epoch: 1 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 2.07)
Train epoch time: 36526.785 ms, per step time: 36.527 ms
epoch: 2 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 2.00379)
Train epoch time: 29215.492 ms, per step time: 29.215 ms
epoch: 3 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.40253)
Train epoch time: 29217.016 ms, per step time: 29.217 ms
epoch: 4 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.79683)
Train epoch time: 29243.756 ms, per step time: 29.244 ms
epoch: 5 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.42917)
Train epoch time: 29197.400 ms, per step time: 29.197 ms
epoch: 6 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.24265)
Train epoch time: 29199.672 ms, per step time: 29.200 ms
epoch: 7 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.48525)
Train epoch time: 29193.341 ms, per step time: 29.193 ms
...
epoch: 11 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.960248)
Train epoch time: 29194.409 ms, per step time: 29.194 ms
epoch: 12 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.889677)
Train epoch time: 29215.664 ms, per step time: 29.216 ms
epoch: 13 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.10265)
Train epoch time: 29377.794 ms, per step time: 29.378 ms
epoch: 14 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.909662)
Train epoch time: 29203.506 ms, per step time: 29.204 ms
epoch: 15 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.01007)
Train epoch time: 29546.233 ms, per step time: 29.546 ms
epoch: 16 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.919893)
Train epoch time: 29189.393 ms, per step time: 29.189 ms
epoch: 17 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.687654)
Train epoch time: 29205.885 ms, per step time: 29.206 ms
...
epoch: 21 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.982546)
Train epoch time: 29358.283 ms, per step time: 29.358 ms
epoch: 22 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.06452)
Train epoch time: 29546.726 ms, per step time: 29.547 ms
epoch: 23 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.04574)
Train epoch time: 29571.793 ms, per step time: 29.572 ms
epoch: 24 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.06391)
Train epoch time: 30104.813 ms, per step time: 30.105 ms
epoch: 25 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.05743)
Train epoch time: 29641.186 ms, per step time: 29.641 ms
epoch: 26 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.681255)
Train epoch time: 29537.563 ms, per step time: 29.538 ms
epoch: 27 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.626822)
Train epoch time: 30246.987 ms, per step time: 30.247 ms
...
epoch: 31 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.941781)
Train epoch time: 29424.146 ms, per step time: 29.424 ms
epoch: 32 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.742477)
Train epoch time: 29319.169 ms, per step time: 29.319 ms
epoch: 33 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.978464)
Train epoch time: 29247.172 ms, per step time: 29.247 ms
epoch: 34 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.769914)
Train epoch time: 29275.349 ms, per step time: 29.275 ms
epoch: 35 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.942254)
Train epoch time: 29213.784 ms, per step time: 29.214 ms
epoch: 36 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.877314)
Train epoch time: 29201.405 ms, per step time: 29.201 ms
epoch: 37 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.06438)
Train epoch time: 29176.718 ms, per step time: 29.177 ms
...
epoch: 41 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.674138)
Train epoch time: 29206.354 ms, per step time: 29.206 ms
epoch: 42 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.713916)
Train epoch time: 29222.487 ms, per step time: 29.222 ms
epoch: 43 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.739123)
Train epoch time: 29227.415 ms, per step time: 29.227 ms
epoch: 44 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.751716)
Train epoch time: 29210.103 ms, per step time: 29.210 ms
epoch: 45 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.753634)
Train epoch time: 29220.145 ms, per step time: 29.220 ms
epoch: 46 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.37903)
Train epoch time: 29242.208 ms, per step time: 29.242 ms
epoch: 47 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.657267)
Train epoch time: 29196.915 ms, per step time: 29.197 ms
...
epoch: 51 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.827644)
Train epoch time: 29212.402 ms, per step time: 29.212 ms
epoch: 52 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.827747)
Train epoch time: 29630.607 ms, per step time: 29.631 ms
epoch: 53 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.756908)
Train epoch time: 29256.933 ms, per step time: 29.257 ms
epoch: 54 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.01792)
Train epoch time: 29282.588 ms, per step time: 29.283 ms
epoch: 55 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.04388)
Train epoch time: 29207.128 ms, per step time: 29.207 ms
epoch: 56 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.603058)
Train epoch time: 29221.402 ms, per step time: 29.221 ms
epoch: 57 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.890517)
Train epoch time: 30039.186 ms, per step time: 30.039 ms
...
epoch: 61 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.87228)
Train epoch time: 30786.011 ms, per step time: 30.786 ms
epoch: 62 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.687325)
Train epoch time: 30764.955 ms, per step time: 30.765 ms
epoch: 63 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.60158)
Train epoch time: 29716.448 ms, per step time: 29.716 ms
epoch: 64 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.675408)
Train epoch time: 31927.397 ms, per step time: 31.927 ms
epoch: 65 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.692979)
Train epoch time: 29857.370 ms, per step time: 29.857 ms
epoch: 66 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.618027)
Train epoch time: 30595.117 ms, per step time: 30.595 ms
epoch: 67 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.09081)
Train epoch time: 29381.262 ms, per step time: 29.381 ms
...
epoch: 71 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.780936)
Train epoch time: 29589.539 ms, per step time: 29.590 ms
epoch: 72 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.707816)
Train epoch time: 29268.130 ms, per step time: 29.268 ms
epoch: 73 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.833401)
Train epoch time: 29277.664 ms, per step time: 29.278 ms
epoch: 74 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.55429)
Train epoch time: 29261.874 ms, per step time: 29.262 ms
epoch: 75 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.728071)
Train epoch time: 29297.286 ms, per step time: 29.297 ms
epoch: 76 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.747971)
Train epoch time: 29268.882 ms, per step time: 29.269 ms
epoch: 77 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.880434)
Train epoch time: 29246.762 ms, per step time: 29.247 ms
...
epoch: 81 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.11769)
Train epoch time: 29326.312 ms, per step time: 29.326 ms
epoch: 82 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.564208)
Train epoch time: 29442.480 ms, per step time: 29.442 ms
epoch: 83 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.799814)
Train epoch time: 29185.925 ms, per step time: 29.186 ms
epoch: 84 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.766812)
Train epoch time: 29329.175 ms, per step time: 29.329 ms
epoch: 85 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.728292)
Train epoch time: 29295.092 ms, per step time: 29.295 ms
epoch: 86 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.577499)
Train epoch time: 29306.544 ms, per step time: 29.307 ms
epoch: 87 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.612509)
Train epoch time: 29236.402 ms, per step time: 29.236 ms
...
epoch: 91 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.690332)
Train epoch time: 29265.776 ms, per step time: 29.266 ms
epoch: 92 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.616666)
Train epoch time: 29263.663 ms, per step time: 29.264 ms
epoch: 93 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.814858)
Train epoch time: 29193.680 ms, per step time: 29.194 ms
epoch: 94 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.623177)
Train epoch time: 29343.511 ms, per step time: 29.344 ms
epoch: 95 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.11235)
Train epoch time: 29275.233 ms, per step time: 29.275 ms
epoch: 96 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.639179)
Train epoch time: 29253.858 ms, per step time: 29.254 ms
epoch: 97 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.518209)
Train epoch time: 29313.646 ms, per step time: 29.314 ms
...
epoch: 101 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.584815)
Train epoch time: 29313.208 ms, per step time: 29.313 ms
epoch: 102 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.79604)
Train epoch time: 29314.559 ms, per step time: 29.315 ms
epoch: 103 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 1.02088)
Train epoch time: 29295.960 ms, per step time: 29.296 ms
epoch: 104 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.779254)
Train epoch time: 29297.620 ms, per step time: 29.298 ms
epoch: 105 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.686904)
Train epoch time: 29283.531 ms, per step time: 29.284 ms
epoch: 106 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.659496)
Train epoch time: 29310.753 ms, per step time: 29.311 ms
epoch: 107 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.628445)
Train epoch time: 29311.042 ms, per step time: 29.311 ms
...
epoch: 111 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.702142)
Train epoch time: 29402.377 ms, per step time: 29.402 ms
epoch: 112 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.684048)
Train epoch time: 29327.054 ms, per step time: 29.327 ms
epoch: 113 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.585024)
Train epoch time: 29266.951 ms, per step time: 29.267 ms
epoch: 114 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.662479)
Train epoch time: 29266.174 ms, per step time: 29.266 ms
epoch: 115 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.506824)
Train epoch time: 29254.582 ms, per step time: 29.255 ms
epoch: 116 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.649985)
Train epoch time: 29276.359 ms, per step time: 29.276 ms
epoch: 117 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.777153)
Train epoch time: 29271.347 ms, per step time: 29.271 ms
...
epoch: 121 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.657592)
Train epoch time: 29222.838 ms, per step time: 29.223 ms
epoch: 122 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.674226)
Train epoch time: 29266.756 ms, per step time: 29.267 ms
epoch: 123 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.758586)
Train epoch time: 29237.832 ms, per step time: 29.238 ms
epoch: 124 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.66191)
Train epoch time: 29264.856 ms, per step time: 29.265 ms
epoch: 125 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.572997)
Train epoch time: 29244.422 ms, per step time: 29.244 ms
epoch: 126 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.776778)
Train epoch time: 29309.235 ms, per step time: 29.309 ms
epoch: 127 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.857609)
Train epoch time: 29271.844 ms, per step time: 29.272 ms
...
epoch: 131 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.903385)
Train epoch time: 29186.732 ms, per step time: 29.187 ms
epoch: 132 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.668677)
Train epoch time: 29200.472 ms, per step time: 29.200 ms
epoch: 133 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.511895)
Train epoch time: 29201.250 ms, per step time: 29.201 ms
epoch: 134 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.727842)
Train epoch time: 29196.617 ms, per step time: 29.197 ms
epoch: 135 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.664106)
Train epoch time: 29188.297 ms, per step time: 29.188 ms
epoch: 136 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.782425)
Train epoch time: 29191.706 ms, per step time: 29.192 ms
epoch: 137 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.643643)
Train epoch time: 29192.865 ms, per step time: 29.193 ms
...
epoch: 141 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.667819)
Train epoch time: 29181.800 ms, per step time: 29.182 ms
epoch: 142 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.610858)
Train epoch time: 29203.687 ms, per step time: 29.204 ms
epoch: 143 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.616083)
Train epoch time: 29199.107 ms, per step time: 29.199 ms
epoch: 144 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.609115)
Train epoch time: 29302.156 ms, per step time: 29.302 ms
epoch: 145 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.518936)
Train epoch time: 29234.649 ms, per step time: 29.235 ms
epoch: 146 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.822775)
Train epoch time: 29228.318 ms, per step time: 29.228 ms
epoch: 147 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 0.802282)
Train epoch time: 29231.589 ms, per step time: 29.232 ms
...
```
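To inspect the loss curve, you can parse the lines printed by `LossMonitor`. The regular expression below matches the log format shown in the training output and skips lines that do not contain a complete loss value, such as truncated ones. The helper name `parse_loss_log` is hypothetical.

```python
import re

# Matches lines such as:
# epoch: 1 step: 1000, loss is Tensor(shape=[], dtype=Float32, value= 2.07)
LOSS_LINE = re.compile(r"epoch: (\d+) step: (\d+), loss is Tensor\(.*value=\s*([-+0-9.eE]+)\)")

def parse_loss_log(text):
    """Return a list of (epoch, loss) pairs from LossMonitor output, ignoring all other lines."""
    pairs = []
    for line in text.splitlines():
        match = LOSS_LINE.search(line)
        if match:
            pairs.append((int(match.group(1)), float(match.group(3))))
    return pairs
```

Applied to the log above, the first and last complete entries give a loss of 2.07 at epoch 1 and 0.802282 at epoch 147. The individual per-epoch values fluctuate considerably, so a moving average is often easier to read than the raw curve.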