## Problem definition

The problem is described mathematically as follows:
![Equation](./resource/equation.png)

where $u$ is the unknown physical quantity, $x$ is the spatial coordinate, $t$ is time, and $c=1$ is the wave speed.
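For reference, here is the problem in LaTeX. It is reconstructed from the constraints defined later in this notebook (PDE residual, boundary condition, initial conditions and sampling ranges), so the figure above may present it slightly differently:

$$
\begin{aligned}
\frac{\partial^2 u}{\partial t^2} &= c^2\,\frac{\partial^2 u}{\partial x^2}, && x \in (0, \pi),\ t \in (0, 2\pi], \\
u(0, t) &= u(\pi, t) = 0, && t \in [0, 2\pi], \\
u(x, 0) &= \sin(x), && x \in [0, \pi], \\
\frac{\partial u}{\partial t}(x, 0) &= \sin(x), && x \in [0, \pi].
\end{aligned}
$$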

For this problem, the analytical solution is $u=\sin(x)(\sin(t) + \cos(t))$.
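The analytical solution can be checked symbolically. The following is a minimal SymPy sketch (SymPy is already a dependency of this notebook); it assumes the PDE, boundary conditions and initial conditions as they are encoded in the constraints later on, with $c = 1$:

```python
from sympy import Symbol, sin, cos, pi, diff, simplify

x, t = Symbol("x"), Symbol("t")
c = 1  # wave speed used in this example
u = sin(x) * (sin(t) + cos(t))  # analytical solution

checks = {
    # PDE residual u_tt - c^2 u_xx should vanish identically
    "pde": simplify(diff(u, t, 2) - c**2 * diff(u, x, 2)),
    # boundary conditions at x = 0 and x = pi
    "bc_left": simplify(u.subs(x, 0)),
    "bc_right": simplify(u.subs(x, pi)),
    # initial conditions u(x, 0) = sin(x) and u_t(x, 0) = sin(x)
    "ic_u": simplify(u.subs(t, 0) - sin(x)),
    "ic_u_t": simplify(diff(u, t).subs(t, 0) - sin(x)),
}
print(checks)
```

Every entry evaluating to zero means the closed-form expression satisfies all four conditions exactly.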

### Objective

Given coordinates $(x, t)$, compute the solution $u$.


## Solution

```python
import os
import warnings

# Optional: select a specific GPU on a multi-GPU machine
# (the index "6" is specific to the machine this notebook was run on)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
```

```python
# Symbolic math
from sympy import Symbol, Eq, Abs, Function, Number, sin
import numpy as np

import modulus.sym

# Hyperparameters / configuration
from modulus.sym.hydra import to_yaml
from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig
from modulus.sym.hydra.utils import compose

# Solver
from modulus.sym.solver import Solver

# Domain
from modulus.sym.domain import Domain

# Geometry
from modulus.sym.geometry.primitives_1d import Line1D

# Constraints
from modulus.sym.domain.constraint import (
    PointwiseBoundaryConstraint,
    PointwiseInteriorConstraint,
)

# Validator
from modulus.sym.domain.validator import PointwiseValidator

# Inferencer
from modulus.sym.domain.inferencer import PointwiseInferencer
from modulus.sym.key import Key

# Equations: abstract base class for custom PDEs
from modulus.sym.eq.pde import PDE

# Post-processing
from modulus.sym.utils.io import (
    csv_to_dict,
    ValidatorPlotter,
    InferencerPlotter,
)
import matplotlib.pyplot as plt
```

```python
cfg = compose(config_path="conf", config_name="config")
cfg.network_dir = "outputs"  # Set the network directory for checkpoints
print(to_yaml(cfg))
```

```
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(

training:
  max_steps: 10000
  grad_agg_freq: 1
  rec_results_freq: 1000
  rec_validation_freq: ${training.rec_results_freq}
  rec_inference_freq: ${training.rec_results_freq}
  rec_monitor_freq: ${training.rec_results_freq}
  rec_constraint_freq: ${training.rec_results_freq}
  save_network_freq: 1000
  print_stats_freq: 100
  summary_freq: 1000
  amp: false
  amp_dtype: float16
  ntk:
    use_ntk: false
    save_name: null
    run_freq: 1000
graph:
  func_arch: false
  func_arch_allow_partial_hessian: true
stop_criterion:
  metric: null
  min_delta: null
  patience: 50000
  mode: min
  freq: 1000
  strict: false
profiler:
  profile: false
  start_step: 0
  end_step: 100
  name: nvtx
network_dir: outputs
initialization_network_dir: ''
save_filetypes: vtk,npz
summary_histograms: false
jit: true
jit_use_nvfuser: true
jit_arch_mode: only_activation
jit_autograd_nodes: false
cuda_graphs: true
cuda_graph_warmup: 20
find_unused_parameters: false
broadcast_buffers: false
device: ''
debug: fal
```

### Defining the required components

#### PDE

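What is special about this example is that the PDE and the boundary conditions are defined by hand.

A custom PDE is implemented by subclassing the abstract base class `modulus.sym.eq.pde.PDE`.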

```python
class WaveEquation1D(PDE):
    """
    Wave equation 1D
    The equation is given as an example for implementing
    your own PDE. A more universal implementation of the
    wave equation can be found by
    `from modulus.sym.eq.pdes.wave_equation import WaveEquation`.

    Parameters
    ==========
    c : float, string
        Wave speed coefficient. If a string then the
        wave speed is input into the equation.
    """

    name = "WaveEquation1D"

    def __init__(self, c=1.0):
        # spatial coordinate
        x = Symbol("x")

        # time coordinate
        t = Symbol("t")

        # input variables
        input_variables = {"x": x, "t": t}

        # output variable u(x, t)
        u = Function("u")(*input_variables)

        # If c is passed as a string, it is treated as a function of the inputs,
        # c(x, t) (e.g. for inverse problems); otherwise it becomes a sympy Number
        if type(c) is str:
            c = Function(c)(*input_variables)
        elif type(c) in [float, int]:
            c = Number(c)

        # build the residual u_tt - (c^2 u_x)_x
        self.equations = {}
        self.equations["wave_equation"] = u.diff(t, 2) - (c**2 * u.diff(x)).diff(x)


we = WaveEquation1D(c=1.0)
```
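To see why the residual is written as `(c**2 * u.diff(x)).diff(x)` rather than `c**2 * u.diff(x, 2)`, the sketch below rebuilds the same expression in plain SymPy, outside Modulus. The helper `wave_residual` is hypothetical and only mirrors the class above: when `c` is a string it becomes a function $c(x, t)$ and differentiation produces an extra product-rule term, while a numeric `c` reduces the residual to $u_{tt} - c^2 u_{xx}$.

```python
from sympy import Symbol, Function, Number, expand

x, t = Symbol("x"), Symbol("t")
u = Function("u")(x, t)


def wave_residual(c=1.0):
    """Hypothetical helper mirroring WaveEquation1D's residual expression."""
    if isinstance(c, str):
        # variable wave speed c(x, t)
        c = Function(c)(x, t)
    else:
        c = Number(c)
    return u.diff(t, 2) - (c**2 * u.diff(x)).diff(x)


const_res = wave_residual(1.0)
var_res = wave_residual("c")

# Expected expansion for variable c: u_tt - c^2 u_xx - 2 c c_x u_x
c_fn = Function("c")(x, t)
expected = u.diff(t, 2) - c_fn**2 * u.diff(x, 2) - 2 * c_fn * c_fn.diff(x) * u.diff(x)

print(const_res)
print(expand(var_res - expected))  # 0: the product-rule term is present
```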

#### Model

```python
# Inputs: spatial and time coordinates; output: u
wave_net = instantiate_arch(
    input_keys=[Key("x"), Key("t")],
    output_keys=[Key("u")],
    cfg=cfg.arch.fully_connected,
)
print(wave_net)

# Graph nodes: the PDE residual node(s) plus the neural network
nodes = we.make_nodes() + [wave_net.make_node(name="wave_network")]
```

```
FullyConnectedArch(
  (_impl): FullyConnectedArchCore(
    (layers): ModuleList(
      (0): FCLayer(
        (linear): WeightNormLinear(in_features=2, out_features=512, bias=True)
      )
      (1-5): 5 x FCLayer(
        (linear): WeightNormLinear(in_features=512, out_features=512, bias=True)
      )
    )
    (final_layer): FCLayer(
      (activation_fn): Identity()
      (linear): Linear(in_features=512, out_features=1, bias=True)
    )
  )
)
```
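The printout shows six hidden layers of width 512 (2 inputs, 1 output), each with a weight-normalized linear map, followed by a plain linear output layer with identity activation. As a rough illustration of what such a forward pass computes, here is a NumPy sketch with the same widths. It assumes the standard row-wise weight-normalization form $w = g\,v/\lVert v\rVert$ and a SiLU activation; the initialization is made up for the sketch, and the actual activation and settings used by Modulus come from `cfg.arch.fully_connected` and may differ.

```python
import numpy as np


def silu(z):
    # SiLU / swish: z * sigmoid(z), written with tanh to avoid overflow
    return z * 0.5 * (1.0 + np.tanh(0.5 * z))


def init_params(widths, seed=0):
    rng = np.random.default_rng(seed)
    hidden = []
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        v = rng.normal(0.0, np.sqrt(1.0 / fan_in), size=(fan_out, fan_in))
        g = np.linalg.norm(v, axis=1, keepdims=True)  # start with w == v
        hidden.append((v, g, np.zeros(fan_out)))
    w_out = rng.normal(0.0, np.sqrt(1.0 / widths[-2]), size=(widths[-1], widths[-2]))
    return hidden, (w_out, np.zeros(widths[-1]))


def forward(params, xt):
    hidden, (w_out, b_out) = params
    h = xt
    for v, g, b in hidden:
        # weight normalization: direction v / ||v|| scaled by a per-row gain g
        w = g * v / np.linalg.norm(v, axis=1, keepdims=True)
        h = silu(h @ w.T + b)
    # final layer: plain linear map, identity activation
    return h @ w_out.T + b_out


widths = [2] + [512] * 6 + [1]  # (x, t) -> 6 hidden layers of 512 -> u
params = init_params(widths)
xt = np.array([[0.0, 0.0], [1.0, 2.0], [np.pi, 2 * np.pi]])
print(forward(params, xt).shape)  # (3, 1)
```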


#### Geo

```python
# Define the computational domain:
# x in [0, pi], t in [0, 2*pi]
x, t_symbol = Symbol("x"), Symbol("t")
L = float(np.pi)
geo = Line1D(0, L)
time_range = {t_symbol: (0, 2 * L)}
```

#### Domain

The constraints and the other components required for training are added to a `Domain`.

```python
# make domain
domain = Domain()
```

Boundary condition: $u = 0$ at both ends of the line ($x = 0$ and $x = \pi$) for all $t$.

```python
# boundary condition
BC = PointwiseBoundaryConstraint(
    nodes=nodes,
    geometry=geo,
    outvar={"u": 0},
    batch_size=cfg.batch_size.BC,
    parameterization=time_range,  # sample t over the whole time range
)
domain.add_constraint(BC, "BC")
```

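Initial conditions

The initial conditions involve a derivative. In Modulus, derivatives are written with a double underscore `__` in the variable key: for example, the first derivative of `u` with respect to $x$ is `u__x`, and the second derivative is `u__x__x`. The initial conditions are imposed as an interior constraint on the line with $t$ fixed to $0$.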

```python
# initial conditions: u(x, 0) = sin(x), u_t(x, 0) = sin(x)
IC = PointwiseInteriorConstraint(
    nodes=nodes,
    geometry=geo,
    outvar={"u": sin(x), "u__t": sin(x)},
    batch_size=cfg.batch_size.IC,
    lambda_weighting={"u": 1.0, "u__t": 1.0},
    parameterization={t_symbol: 0.0},  # fix t = 0
)
domain.add_constraint(IC, "IC")
```
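As an illustration of the naming convention only (this is not how Modulus parses keys internally), a hypothetical helper `key_to_derivative` can translate a key such as `u__t` or `u__x__x` into the corresponding SymPy derivative of $u(x, t)$:

```python
from sympy import Symbol, Function, Derivative


def key_to_derivative(key, inputs=("x", "t")):
    """Hypothetical helper: map 'u', 'u__t' or 'u__x__x' to a SymPy expression."""
    name, *wrt = key.split("__")  # 'u__x__x' -> 'u', ['x', 'x']
    func = Function(name)(*[Symbol(s) for s in inputs])
    if not wrt:
        return func
    # one differentiation per name after the double underscores
    return Derivative(func, *[Symbol(s) for s in wrt])


print(key_to_derivative("u__t"))
print(key_to_derivative("u__x__x"))
```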

Interior: the PDE residual must vanish

```python
# interior
interior = PointwiseInteriorConstraint(
    nodes=nodes,
    geometry=geo,
    outvar={"wave_equation": 0},
    batch_size=cfg.batch_size.interior,
    parameterization=time_range,  # sample t over the whole time range
)
domain.add_constraint(interior, "interior")
```
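Together, the boundary, initial and interior constraints define the training loss. The NumPy sketch below illustrates the generic physics-informed loss under simplifying assumptions that are not Modulus's implementation: points are sampled uniformly, derivatives come from central finite differences instead of automatic differentiation, every term is a plain mean squared error with weight 1 (compare `lambda_weighting` above), and a closed-form candidate function stands in for the network. Modulus's own sampling, loss normalization and weighting may differ.

```python
import numpy as np

L = np.pi
H = 1e-3  # finite-difference step (assumption of this sketch)


def fd_wave_residual(u, x, t, c=1.0):
    # central differences for u_tt and u_xx
    u_tt = (u(x, t + H) - 2 * u(x, t) + u(x, t - H)) / H**2
    u_xx = (u(x + H, t) - 2 * u(x, t) + u(x - H, t)) / H**2
    return u_tt - c**2 * u_xx


def fd_time_derivative(u, x, t):
    return (u(x, t + H) - u(x, t - H)) / (2 * H)


def mse(values, target=0.0):
    return float(np.mean((values - target) ** 2))


def constraint_losses(u, n=1000, seed=0):
    rng = np.random.default_rng(seed)
    xi, ti = rng.uniform(0, L, n), rng.uniform(0, 2 * L, n)  # interior points
    xb, tb = rng.choice([0.0, L], n), rng.uniform(0, 2 * L, n)  # x = 0 or x = pi
    x0, t0 = rng.uniform(0, L, n), np.zeros(n)  # initial line t = 0
    return {
        "interior": mse(fd_wave_residual(u, xi, ti)),
        "BC": mse(u(xb, tb)),
        "IC_u": mse(u(x0, t0), np.sin(x0)),
        "IC_u__t": mse(fd_time_derivative(u, x0, t0), np.sin(x0)),
    }


def exact(x, t):
    return np.sin(x) * (np.sin(t) + np.cos(t))


def wrong(x, t):
    # satisfies the PDE and the BC, but has u_t(x, 0) = 0 instead of sin(x)
    return np.sin(x) * np.cos(t)


for name, candidate in [("exact", exact), ("wrong", wrong)]:
    losses = constraint_losses(candidate)
    print(name, losses, "total:", sum(losses.values()))
```

Only the `IC_u__t` term is large for the `wrong` candidate, which shows how each constraint contributes its own part of the total loss.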

Validator and other required components

```python
# step sizes in x and t
deltaT = 0.01
deltaX = 0.01

# 1D sample arrays (named x_arr/t_arr so the sympy symbol x is not overwritten)
x_arr = np.arange(0, L, deltaX)
t_arr = np.arange(0, 2 * L, deltaT)

# build the grid and flatten it into column vectors
X, T = np.meshgrid(x_arr, t_arr)
X = np.expand_dims(X.flatten(), axis=-1)
T = np.expand_dims(T.flatten(), axis=-1)

# analytical solution
u = np.sin(X) * (np.cos(T) + np.sin(T))

# validator inputs and reference outputs
invar_numpy = {"x": X, "t": T}
outvar_numpy = {"u": u}

# create the validator
validator = PointwiseValidator(
    nodes=nodes,
    invar=invar_numpy,
    true_outvar=outvar_numpy,
    batch_size=128,
    plotter=ValidatorPlotter(),
)
domain.add_validator(validator)
```
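The validator evaluates the network on every grid point and compares the result with the analytical solution. A common scalar summary of such a comparison is the relative $L_2$ error. The sketch below rebuilds the same grid, shows how many points (and batches of 128) the validator has to process, and defines that metric; `relative_l2_error` is a hypothetical helper, not part of the Modulus API, and the scaled copy of the true field is only there to exercise it.

```python
import numpy as np

L = float(np.pi)
delta = 0.01
x_arr = np.arange(0, L, delta)
t_arr = np.arange(0, 2 * L, delta)
X, T = np.meshgrid(x_arr, t_arr)
X, T = X.reshape(-1, 1), T.reshape(-1, 1)
u_true = np.sin(X) * (np.cos(T) + np.sin(T))


def relative_l2_error(u_pred, u_ref):
    """Hypothetical metric helper: ||u_pred - u_ref||_2 / ||u_ref||_2."""
    return float(np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref))


n_points = X.shape[0]
n_batches = -(-n_points // 128)  # ceiling division for batch_size=128
print(n_points, n_batches)  # 198135 1548

# In practice u_pred would be the network prediction on the same grid
print(relative_l2_error(1.01 * u_true, u_true))
```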

### Solver and training

```python
# create the solver
slv = Solver(cfg, domain)
```

Attach a logging handler manually so that the training log is printed in the notebook

```python
import logging
logging.getLogger().addHandler(logging.StreamHandler())
```
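A handler on the root logger can also surface DEBUG messages from other libraries, such as matplotlib's font lookup (`findfont`) messages. A minimal sketch, using only the standard `logging` module, that keeps the handler added above but raises the threshold for matplotlib's logger and its child loggers:

```python
import logging

# Child loggers such as "matplotlib.font_manager" inherit this level
# as long as their own level is left unset (NOTSET)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
```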

Start training

```python
slv.solve()
```

```
Installed PyTorch version 2.2.0a0+81ea7a4 is not TorchScript supported in Modulus. Version 2.1.0a0+4136153 is officially supported.
attempting to restore from: /workspace/01_1D_Wave/outputs
optimizer checkpoint not found
model wave_network.0.pth not found
[step:          0] saved constraint results to outputs
[step:          0] record constraint batch time:  6.485e-02s
...
[step:       3700] loss:  1.142e-05, time/iteration:  1.792e+01 ms
[step:       3800] loss:  8.934e-06, time/iteration:  1.795e+01 ms
[step:       3900] loss:  5.325e-06, time/iteration:  1.792e+01 ms
[step:       4000] saved constraint results to outputs
[step:       4000] record constraint batch time:  7.976e-02s
[step:       4000] saved validator results to outputs
[step:       4000] record validators time:  8.259e+00s
[step:       4000] saved checkpoint to /workspace/01_1D_Wave/outputs
[step:       4000] loss:  4.518e-06, time/iteration:  1.070e+02 ms
[step:       4100] loss:  5.141e-06, time/iteration:  1.829e+01 ms
[step:       4200] loss:  3.471e-06, time/iteration:  1.897e+01 ms
[step:       4300] loss:  4.883e-06, time/iteration:  1.912e+01 ms
[step:       4400] loss:  7.377e-06, ti
```

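### Post-processing and visualization

In Jupyter, matplotlib is the most convenient option.

TensorBoard and ParaView can also be used.

If a `PointwiseValidator` is used, the validation results can be viewed directly: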

![u](./outputs/validators/validator_u.png)
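To produce a similar figure by hand with matplotlib, a sketch along these lines can be used. It only plots the analytical field on the validation grid; the same `imshow` call can be repeated for the network prediction and for the difference once those arrays are available (loading them from the saved results is not covered here).

```python
import matplotlib.pyplot as plt
import numpy as np

L = float(np.pi)
x_arr = np.arange(0, L, 0.01)
t_arr = np.arange(0, 2 * L, 0.01)
X, T = np.meshgrid(x_arr, t_arr)  # rows follow t, columns follow x
u_true = np.sin(X) * (np.cos(T) + np.sin(T))

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(
    u_true,
    origin="lower",           # t increases upwards
    extent=[0, L, 0, 2 * L],  # map pixel indices to (x, t) coordinates
    aspect="auto",
    cmap="viridis",
)
ax.set_xlabel("x")
ax.set_ylabel("t")
ax.set_title("analytical u(x, t)")
fig.colorbar(im, ax=ax, label="u")
```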