# Meta-PINNs Algorithm

## Environment Setup

This notebook requires **MindSpore version >= 2.0.0** to use new APIs such as *mindspore.jit*, *mindspore.jit_class* and *mindspore.data_sink*. See [MindSpore Installation](https://www.mindspore.cn/install/en) for details.

In addition, **MindFlow version >=0.1.0** is also required. If it has not been installed in your environment, please select the right version and hardware, then install it as follows.

In [None]:
mindflow_version = "0.1.0"  # update if needed
# GPU: comment out the following two lines if you are using an NPU (Ascend).
!pip uninstall -y mindflow-gpu
!pip install mindflow-gpu==$mindflow_version

# NPU (Ascend): uncomment the following two lines if needed.
# !pip uninstall -y mindflow-ascend
# !pip install mindflow-ascend==$mindflow_version
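After installation, you can check the version requirements above in code. The following is a minimal sketch that uses only the standard library. `version_at_least` is a hypothetical helper and is not part of MindSpore or MindFlow. It compares only the numeric release parts and ignores pre-release suffixes.

```python
import re


def version_at_least(version: str, minimum: str) -> bool:
    """Return True if `version` >= `minimum`, comparing numeric release parts only.

    Pre-release or local suffixes (e.g. "rc1", "+cu118") are ignored, so
    "2.0.0rc1" counts as 2.0.0. This is a deliberate simplification.
    """
    def numeric_parts(text):
        release = re.match(r"\d+(?:\.\d+)*", text)
        if release is None:
            raise ValueError(f"not a version string: {text!r}")
        return [int(part) for part in release.group(0).split(".")]

    current, required = numeric_parts(version), numeric_parts(minimum)
    # pad with zeros so that "2.0" compares equal to "2.0.0"
    width = max(len(current), len(required))
    current += [0] * (width - len(current))
    required += [0] * (width - len(required))
    # lists compare element by element, like version numbers
    return current >= required
```

For example, `version_at_least(mindspore.__version__, "2.0.0")` should return `True` after `import mindspore`. For MindFlow, the installed distribution version can be read with `importlib.metadata.version("mindflow-gpu")` (or `"mindflow-ascend"`).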

## Overview

In different physical application scenarios, choosing an appropriate PINN loss function still relies mainly on experience and manual design. To address this, Apostolos F. Psaros et al. proposed the Meta-PINNs algorithm. During training, it uses gradient descent to update the hyperparameters of the loss function (here, the weights of its individual loss terms). The result is a single set of hyperparameters that applies to a whole family of related partial differential equations.
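The sketch below is deliberately tiny and framework-free. It illustrates the idea of updating loss hyperparameters by gradient descent, but it is not the MindFlow implementation. All of the following are simplifying assumptions:

- the scalar "model" `theta`;
- the three quadratic loss terms standing in for the PDE, initial-condition and boundary-condition residuals;
- the unweighted outer objective;
- the exact one-step unrolled meta-gradient.

In the sketch, the inner step updates `theta` on the weighted loss, and the outer loop differentiates the unweighted loss through that step with respect to the weights.

```python
# Toy illustration of gradient-based loss-weight meta-learning (hypothetical setup):
# a scalar parameter theta and three loss terms r_k(theta) = theta - c_k.

def inner_step(theta, lamdas, targets, alpha):
    """One SGD step on the weighted loss sum_k lamda_k * (theta - c_k)**2."""
    grad = sum(2.0 * lam * (theta - c) for lam, c in zip(lamdas, targets))
    return theta - alpha * grad


def outer_loss(theta, targets):
    """Assumed meta-objective: the unweighted sum of squared residuals."""
    return sum((theta - c) ** 2 for c in targets)


def meta_gradient(theta, lamdas, targets, alpha):
    """Exact d outer_loss(inner_step(theta, lamdas)) / d lamdas for a one-step unroll."""
    theta_new = inner_step(theta, lamdas, targets, alpha)
    dloss_dtheta_new = sum(2.0 * (theta_new - c) for c in targets)
    # inner_step is linear in each weight: d theta_new / d lamda_k = -2 * alpha * (theta - c_k)
    return [dloss_dtheta_new * (-2.0 * alpha * (theta - c)) for c in targets]


def meta_train(lamdas, targets, theta0=0.0, alpha=0.05, beta=0.05, outer_iters=200):
    """Outer loop: plain gradient descent on the loss weights."""
    lamdas = list(lamdas)
    for _ in range(outer_iters):
        grads = meta_gradient(theta0, lamdas, targets, alpha)
        lamdas = [lam - beta * g for lam, g in zip(lamdas, grads)]
    return lamdas


targets = [1.0, 2.0, 3.0]
initial = [1.0, 1.0, 1.0]
learned = meta_train(initial, targets)
before = outer_loss(inner_step(0.0, initial, targets, 0.05), targets)
after = outer_loss(inner_step(0.0, learned, targets, 0.05), targets)
print(f"outer loss before: {before:.4f}, after: {after:.4f}")
```

In this toy, the learned weights are not unique: any weights that move one inner step to `theta = 2` are optimal. The point is only the mechanics of differentiating through the inner update. The notebook's training code uses SGD for the inner loop and Adam for the outer loop.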


## Technical Path

The specific process for solving this problem is as follows:

1. Create the dataset.
2. Define the model and optimizer.
3. Define the forward propagation function and gradient function for the inner loop and outer loop.
4. Define the training steps for the inner loop and outer loop.
5. Set training parameters, such as learning rate and number of iterations.
6. Train the model, evaluating it at regular intervals of outer-loop iterations.

## Importing Code Packages

In [None]:
import argparse
import os
import time

import numpy as np

from mindspore import context, nn, get_seed, set_seed, data_sink

from mindflow.cell import MultiScaleFCSequential
from mindflow.utils import load_yaml_config

The following `src` package can be downloaded at [research/meta_pinns/src](https://gitee.com/mindspore/mindscience/tree/master/MindFlow/applications/research/meta_pinns/src).

In [None]:
from src import create_train_dataset, create_problem, create_trainer, create_normal_params
from src import re_initialize_model, evaluate, plot_l2_error, plot_l2_comparison_error
from src import WorkspaceConfig, TrainerInfo

set_seed(0)
np.random.seed(0)

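Configure the command-line parameters. `--case` is optional and can be set to `"burgers"`, `"l_burgers"`, `"periodic_hill"` or `"cylinder_flow"`. The default configuration file is `./configs/burgers.yaml`, so pass a matching `--config_file_path` when you select another case.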

In [None]:
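parser = argparse.ArgumentParser(description="meta-pinns")
parser.add_argument("--case", type=str, default="burgers", choices=["burgers", "l_burgers", "cylinder_flow", "periodic_hill"],
                    help="The case (equation) to train")
parser.add_argument("--mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"],
                    help="Run in GRAPH_MODE or PYNATIVE_MODE")
# save_graphs / save_graphs_path are read by context.set_context below
parser.add_argument("--save_graphs", action="store_true",
                    help="Whether to save intermediate compilation graphs")
parser.add_argument("--save_graphs_path", type=str, default="./graphs",
                    help="Directory in which compilation graphs are saved")
parser.add_argument("--device_target", type=str, default="Ascend", choices=["GPU", "Ascend"],
                    help="The target device to run, support 'Ascend', 'GPU'")
parser.add_argument("--device_id", type=int, default=0,
                    help="ID of the target device")
parser.add_argument("--config_file_path", type=str,
                    default="./configs/burgers.yaml",
                    help="Path to the YAML configuration file of the chosen case")
# parse_known_args ignores the extra arguments Jupyter passes to the kernel
input_args, _ = parser.parse_known_args()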

context.set_context(mode=context.GRAPH_MODE if input_args.mode.upper().startswith("GRAPH")
                    else context.PYNATIVE_MODE,
                    save_graphs=input_args.save_graphs,
                    save_graphs_path=input_args.save_graphs_path,
                    device_target=input_args.device_target,
                    device_id=input_args.device_id)
print(
    f"Running in {input_args.mode.upper()} mode, using device id: {input_args.device_id}.")
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
print("use_ascend:", use_ascend)
print("pid:", os.getpid())
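`parser.parse_args()` reads `sys.argv`. Inside Jupyter, it can therefore fail on the kernel's own arguments (such as `-f <connection file>`). One option is to pass the arguments as an explicit list and derive the config path from the case.

This is a minimal sketch. `parse_case_args` and `config_path_for_case` are hypothetical helpers. The `./configs/<case>.yaml` naming is an assumption: only `./configs/burgers.yaml` appears in this notebook, so check the repository for the real file names.

```python
import argparse

CASES = ["burgers", "l_burgers", "cylinder_flow", "periodic_hill"]


def config_path_for_case(case: str) -> str:
    # Hypothetical naming convention; verify against the repository's configs/ folder.
    return f"./configs/{case}.yaml"


def parse_case_args(argv):
    parser = argparse.ArgumentParser(description="meta-pinns (sketch)")
    parser.add_argument("--case", type=str, default="burgers", choices=CASES)
    parser.add_argument("--config_file_path", type=str, default=None)
    # parse_known_args returns unrecognized strings (e.g. Jupyter's "-f ...") instead of failing
    args, _ = parser.parse_known_args(argv)
    if args.config_file_path is None:
        args.config_file_path = config_path_for_case(args.case)
    return args


args = parse_case_args(["--case", "l_burgers"])
print(args.case, args.config_file_path)
```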

Determine the equation to be trained and load the YAML file.

In [None]:
# load configurations
case_name = input_args.case
config = load_yaml_config(input_args.config_file_path)
model_config = config["model"]
test_config = config["meta_test"]
summary_config = config["summary"]
lamda_config = config["lamda"]
meta_train_config = config["meta_train"]
initial_lr = config["optimizer"]["initial_lr"]
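The cells in this notebook read a fixed set of nested keys from the YAML configuration. A quick check after `load_yaml_config` can surface a missing key before training starts. This is a sketch: `REQUIRED_KEYS` lists only the keys this notebook reads directly, and the `src` helpers (for example `create_train_dataset`) probably read more.

```python
# Nested keys read directly by the cells in this notebook.
REQUIRED_KEYS = [
    ("model", "in_channels"), ("model", "out_channels"), ("model", "layers"),
    ("model", "neurons"), ("model", "residual"), ("model", "activation"),
    ("optimizer", "initial_lr"),
    ("lamda", "initial_lamda"), ("lamda", "lamda_min"), ("lamda", "lamda_max"),
    ("lamda", "eva_lamda"),
    ("meta_train", "inner_loop", "iterations"), ("meta_train", "outer_loop", "iterations"),
    ("meta_train", "reinit_lamda"), ("meta_train", "reinit_epoch"),
    ("meta_train", "eva_interval_outer"), ("meta_train", "eva_loop", "iterations"),
    ("meta_test", "iterations"), ("meta_test", "cal_l2_interval"),
    ("summary", "visual_dir"),
]


def missing_keys(config):
    """Return the dotted paths from REQUIRED_KEYS that are absent from `config`."""
    missing = []
    for path in REQUIRED_KEYS:
        node = config
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing
```

After loading, `missing_keys(config)` should return an empty list.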

## Create Dataset

Training in this case does not use reference values of the equation's solution. It only needs points sampled from the interior of the domain, from the boundary, and at the initial time.

`inner_train_dataset` is used for training in the inner loop, and `outer_train_dataset` for training in the outer loop. The two datasets are created with different random seeds.

In [None]:
# create dataset
inner_train_dataset = create_train_dataset(
    case_name, config, get_seed() + 1)
outer_train_dataset = create_train_dataset(
    case_name, config, get_seed() + 2)
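The point sampling described above can be sketched in plain Python as follows. This is not `create_train_dataset`; `sample_burgers_points` is a hypothetical helper. It assumes the classic Burgers benchmark domain, `x` in [-1, 1] and `t` in [0, 1], with uniform sampling. In the notebook, the actual domain, point counts and sampling scheme come from the YAML file and the `src` package.

```python
import random


def sample_burgers_points(n_interior, n_boundary, n_initial, seed,
                          x_range=(-1.0, 1.0), t_range=(0.0, 1.0)):
    """Uniformly sample (x, t) collocation points; no solution values are needed."""
    rng = random.Random(seed)  # a per-dataset seed gives independent point sets
    x_lo, x_hi = x_range
    t_lo, t_hi = t_range
    interior = [(rng.uniform(x_lo, x_hi), rng.uniform(t_lo, t_hi)) for _ in range(n_interior)]
    # boundary points: x fixed at either end of the spatial interval
    boundary = [(rng.choice((x_lo, x_hi)), rng.uniform(t_lo, t_hi)) for _ in range(n_boundary)]
    # initial points: t fixed at the start time
    initial = [(rng.uniform(x_lo, x_hi), t_lo) for _ in range(n_initial)]
    return {"interior": interior, "boundary": boundary, "initial": initial}


# mirror the notebook: different seeds for the inner and outer datasets
inner_points = sample_burgers_points(1000, 200, 200, seed=1)
outer_points = sample_burgers_points(1000, 200, 200, seed=2)
```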

## Build Model

This case uses a simple fully connected network (`MultiScaleFCSequential` with `num_scales=1`). Its shape is set by the `model` section of the YAML file.

In [None]:
model = MultiScaleFCSequential(in_channels=model_config["in_channels"],
                               out_channels=model_config["out_channels"],
                               layers=model_config["layers"],
                               neurons=model_config["neurons"],
                               residual=model_config["residual"],
                               act=model_config["activation"],
                               num_scales=1)
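To see how the `model` settings determine the network size, the sketch below counts the weights and biases of a plain fully connected network. It assumes that `layers` counts all linear layers (input, hidden and output), i.e. `in -> neurons`, then `layers - 2` hidden `neurons -> neurons` layers, then `neurons -> out`. It also assumes that residual connections add no parameters. Both are assumptions about `MultiScaleFCSequential`, and `fc_param_count` is a hypothetical helper.

```python
def fc_param_count(in_channels, out_channels, layers, neurons):
    """Weights + biases of a plain fully connected network.

    Assumes `layers` counts all linear layers: in->neurons,
    (layers - 2) x neurons->neurons, and neurons->out.
    """
    if layers < 2:
        raise ValueError("need at least an input and an output layer")
    first = in_channels * neurons + neurons
    hidden = (layers - 2) * (neurons * neurons + neurons)
    last = neurons * out_channels + out_channels
    return first + hidden + last
```

If the assumption holds, the result should agree with `sum(p.size for p in model.trainable_params())` for the model built above.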

## Model Training

For **MindSpore >= 2.0.0**, the functional programming paradigm can be used to train neural networks.

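In this case, meta-learning is used to learn the weights of the terms in the PINN loss function during training. The inner loop updates the network weights with SGD, and the outer loop updates the loss parameters (`problem.get_params()`) with Adam.

In addition, the current loss parameters are evaluated on an equation with an unseen value of `lamda` every `eva_interval_outer` outer-loop iterations.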
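The periodic actions in the outer loop below depend on several `meta_train` settings. Summarized, they are as follows:

- A new `lamda` is drawn every epoch but is only assigned to `problem.lamda` when `epoch % reinit_lamda == 0`.
- The model is re-initialized when `epoch % reinit_epoch == 0`.
- Evaluation runs when `epoch % eva_interval_outer == 0`.

The sketch below lists these actions for a given epoch. `outer_loop_actions` is hypothetical, and the config values in the example are made up.

```python
def outer_loop_actions(epoch, meta_train_cfg, steps_per_epoch):
    """List, in order, what one (1-based) outer-loop epoch of the training cell does."""
    cfg = meta_train_cfg
    actions = []
    if epoch % cfg["reinit_lamda"] == 0:
        actions.append("assign newly sampled lamda to the problem")
    if epoch % cfg["reinit_epoch"] == 0:
        actions.append("re-initialize model weights")
    inner_steps = cfg["inner_loop"]["iterations"] * steps_per_epoch
    actions.append(f"inner loop: {inner_steps} SGD step(s) on model weights")
    actions.append(f"outer loop: {steps_per_epoch} Adam step(s) on loss parameters")
    if epoch % cfg["eva_interval_outer"] == 0:
        actions.append("evaluate on the unseen lamda")
    return actions


example_cfg = {"reinit_lamda": 2, "reinit_epoch": 5, "eva_interval_outer": 5,
               "inner_loop": {"iterations": 20}}  # hypothetical values
for epoch in (1, 2, 5, 10):
    print(epoch, outer_loop_actions(epoch, example_cfg, steps_per_epoch=1))
```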

In [None]:
lamda = lamda_config["initial_lamda"]
problem = create_problem(lamda, case_name, model, config)
inner_optimizer = nn.SGD(model.trainable_params(),
                         initial_lr)
outer_optimizer = nn.Adam(problem.get_params(),
                          initial_lr)

if use_ascend:
    from mindspore.amp import DynamicLossScaler, auto_mixed_precision
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O1')
else:
    loss_scaler = None

inner_trainer = create_trainer(TrainerInfo(case_name, model, inner_optimizer, problem,
                                           use_ascend, loss_scaler, config, False, True))
outer_trainer = create_trainer(TrainerInfo(case_name, model, outer_optimizer,
                                           problem, use_ascend, loss_scaler, config, True, True))

inner_train_step = inner_trainer.train_step
outer_train_step = outer_trainer.train_step

iteration_str = "iterations"
inner_iters = meta_train_config["inner_loop"][iteration_str]
outer_iters = meta_train_config["outer_loop"][iteration_str]

steps_per_epochs = inner_train_dataset.get_dataset_size()
inner_sink_process = data_sink(
    inner_train_step, inner_train_dataset, sink_size=1)
outer_sink_process = data_sink(
    outer_train_step, outer_train_dataset, sink_size=1)

lamda_min = lamda_config["lamda_min"]
lamda_max = lamda_config["lamda_max"]

used_lamda = [lamda_config["eva_lamda"]]
best_params = problem.get_params()
best_l2 = 1e10

# starting meta training
eva_l2_errors = []
for epoch in range(1, 1 + outer_iters):
    # train
    lamda = np.random.uniform(lamda_min, lamda_max)
    if lamda not in used_lamda:
        used_lamda.append(lamda)
    time_beg = time.time()
    model.set_train(True)

    if epoch % meta_train_config["reinit_lamda"] == 0:
        problem.lamda = lamda
    if epoch % meta_train_config["reinit_epoch"] == 0:
        re_initialize_model(model, epoch)

    for _ in range(1, 1 + inner_iters):
        for _ in range(steps_per_epochs):
            inner_sink_process()
    for _ in range(steps_per_epochs):
        cur_loss = outer_sink_process()

    print(f"epoch: {epoch} loss: {cur_loss} epoch time: "
          f"{(time.time() - time_beg) * 1000:.3f}ms")

    if epoch % meta_train_config["eva_interval_outer"] == 0:
        # evaluate current model on unseen lamda
        print(f"learned params are: {problem.get_params(if_value=True)}")
        eva_iter = meta_train_config["eva_loop"][iteration_str]
        eva_l2_error = evaluate(WorkspaceConfig(epoch, case_name, config, problem.get_params(), eva_iter,
                                                eva_iter, use_ascend, loss_scaler, False, True, False, None))
        eva_l2_errors.append(eva_l2_error[0])
        if eva_l2_error[0] < best_l2:
            best_l2 = eva_l2_error[0]
            best_params = problem.get_params()

print(best_l2)
for param in best_params:
    print(param.asnumpy())

plot_l2_error(case_name, summary_config["visual_dir"],
              meta_train_config["eva_interval_outer"], eva_l2_errors)

# compare the meta-learned loss parameters against normal training
test_iter = test_config[iteration_str]
test_interval = test_config["cal_l2_interval"]

# start meta testing: train with the best learned loss parameters
meta_l2_errors = evaluate(WorkspaceConfig(None, case_name, config, best_params,
                                          test_iter, test_interval,
                                          use_ascend, loss_scaler, False, True,
                                          True, f"{case_name}_meta_testing"))
# end meta testing

# start normal training
normal_params = create_normal_params(case_name)

normal_l2_errors = evaluate(WorkspaceConfig(None, case_name, config, normal_params,
                                            test_iter, test_interval,
                                            use_ascend, loss_scaler, False, False,
                                            True, f"{case_name}_normal_training"))
# end normal training

plot_l2_comparison_error(
    case_name, summary_config["visual_dir"], test_interval, meta_l2_errors, normal_l2_errors)

epoch: 1 loss: [0.86779684] epoch time: 23640.816ms
epoch: 2 loss: [0.9413994] epoch time: 1189.008ms
epoch: 3 loss: [0.9189246] epoch time: 1087.861ms
epoch: 4 loss: [0.9275309] epoch time: 1123.993ms
...
epoch: 50 l2_error: 1.0040625699565209 epoch time: 8788.293ms
...
epoch: 1000 loss: [0.9453384] epoch time: 1885.815ms
learned params are: pde weight=[1.865994] ic weight=[1.8658787] bc_weight=[1.8648634]
...
-------------------------------------start meta testing-------------------------------------
epoch: 0 l2_error: 1.0029983755797225 epoch time: 12976.881ms
epoch: 100 l2_error: 0.8387432039920448 epoch time: 111.595ms
epoch: 200 l2_error: 0.5762822379368656 epoch time: 111.311ms
...
epoch: 19800 l2_error: 0.033477298861820166 epoch time: 76.523ms
epoch: 19900 l2_error: 0.027350145606692883 epoch time: 79.059ms
-------------------------------------end meta testing-------------------------------------
-------------------------------------start normal training-----------------------

![burgers_l2](./images/burgers_l2.png)
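The log and the figure report an `l2_error`. The sketch below assumes it is the usual relative L2 error, `||u_pred - u_ref||_2 / ||u_ref||_2`. This is an assumption about `evaluate`, although it is consistent with the value of about 1.0 at epoch 0 that an untrained network with near-zero output would give. `relative_l2_error` is a hypothetical helper that operates on flattened value lists.

```python
import math


def relative_l2_error(prediction, reference):
    """||prediction - reference||_2 / ||reference||_2 over flattened values."""
    if len(prediction) != len(reference):
        raise ValueError("prediction and reference must have the same length")
    diff = math.sqrt(sum((p - r) ** 2 for p, r in zip(prediction, reference)))
    norm = math.sqrt(sum(r * r for r in reference))
    if norm == 0.0:
        raise ValueError("reference has zero norm")
    return diff / norm
```

A prediction of all zeros gives an error of exactly 1.0, which is why values near 1.0 mark an essentially untrained model.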