
# 环境安装

本案例要求 MindSpore >= 2.2.12 版本，以调用如下接口：mindspore.nn、mindspore.jit、mindspore.jit_class、mindspore.data_sink；同时需要安装 mindflow_ascend。具体请查看 MindSpore 安装指南。

## 项目初始化

In [None]:
import os
import time
import numpy as np
from mindspore import nn, Tensor, context, ops, jit, set_seed, data_sink, save_checkpoint
from mindspore import dtype as mstype
from mindspore.nn import L1Loss
from mindflow.common import get_warmup_cosine_annealing_lr
from mindflow.utils import load_yaml_config, print_log
from src.utils import Trainer, init_model, check_file_path, count_params, plot_image, plot_image_first
from src.dataset import init_dataset

设置随机种子与图模式（GRAPH_MODE），将代码配置为在 Ascend 设备上运行，并检查设备目标是否设置成功。

In [None]:
def train(device_target="Ascend", device_id=0):
    """Train the "unet2d" model configured by ./configs/combined_methods.yaml.

    Args:
        device_target: MindSpore backend to run on. The default "Ascend" selects
            the float16 + dynamic-loss-scaling path; any other backend selects
            the float32 path. Previously this was hard-coded, which made every
            non-Ascend branch of this function unreachable.
        device_id: Index of the device to run on.
    """
    set_seed(0)
    np.random.seed(0)

    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=False,
                        device_target=device_target,
                        device_id=device_id)
    # Derive the flag from the context so every later `use_ascend` branch
    # agrees with the backend MindSpore was configured with.
    use_ascend = context.get_context("device_target") == "Ascend"
    print(use_ascend)

加载训练配置：从 YAML 配置文件中读取数据（data）、模型（model）、优化器（optimizer）与日志（summary）四部分参数。

In [None]:
    # Every hyper-parameter lives in one YAML file, split into four sections.
    config = load_yaml_config("./configs/combined_methods.yaml")
    data_params, model_params, optimizer_params, summary_params = (
        config[section] for section in ("data", "model", "optimizer", "summary"))

准备数据集

In [None]:
    # means/stds are passed to the Trainer further down (presumably the
    # normalisation statistics of the data — confirm in src.dataset).
    split_and_stats = init_dataset(data_params)
    train_dataset, test_dataset, means, stds = split_and_stats
    print('train_dataset', train_dataset)

模型构建

In [None]:
    compute_dtype = mstype.float16 if use_ascend else mstype.float32
    if use_ascend:
        # Mixed-precision helpers are only imported on the Ascend path;
        # all_finite is also used by train_step below.
        from mindspore.amp import DynamicLossScaler, all_finite, auto_mixed_precision
        # Positional args: scale_value=1024, scale_factor=2, scale_window=100.
        loss_scaler = DynamicLossScaler(1024, 2, 100)
    else:
        context.set_context(enable_graph_kernel=False)
        loss_scaler = None
    model = init_model("unet2d", data_params, model_params, compute_dtype=compute_dtype)
    if use_ascend:
        auto_mixed_precision(model, optimizer_params["amp_level"]["unet2d"])

损失函数与优化器

In [None]:
    loss_fn = L1Loss()
    # Output layout: <summary_dir>/Exp_datadriven/unet2d/ckpt_dir[/img]
    summary_dir = os.path.join(summary_params["summary_dir"], "Exp_datadriven", "unet2d")
    ckpt_dir = os.path.join(summary_dir, "ckpt_dir")
    for directory in (ckpt_dir, os.path.join(ckpt_dir, 'img')):
        check_file_path(directory)

    print_log('model parameter count:', count_params(model.trainable_params()))
    base_lr = optimizer_params["lr"]["unet2d"]
    print_log(f'learning rate: {base_lr}, '
              f'T_in: {data_params["T_in"]}, T_out: {data_params["T_out"]}')

    # The LR schedule is generated per step, hence it needs the batch count per epoch.
    steps_per_epoch = train_dataset.get_dataset_size()
    lr = get_warmup_cosine_annealing_lr(base_lr,
                                        steps_per_epoch,
                                        optimizer_params["epochs"],
                                        optimizer_params["warm_up_epochs"])
    optimizer = nn.AdamWeightDecay(model.trainable_params(),
                                   learning_rate=Tensor(lr),
                                   weight_decay=optimizer_params["weight_decay"])
    trainer = Trainer(model, data_params, loss_fn, means, stds)

定义前向函数 forward_fn：执行模型前向传播并返回损失值（在 Ascend 上会对损失进行缩放）；再通过 ops.value_and_grad 构造 grad_fn，用于同时计算该损失及其相对于模型参数的梯度。

In [None]:
    def forward_fn(inputs, labels):
        """Return the training loss for one batch, loss-scaled on Ascend."""
        # Only the first of the seven values from trainer.get_loss is optimised.
        loss, _, _, _, _, _, _ = trainer.get_loss(inputs, labels)
        scaled_loss = loss_scaler.scale(loss) if use_ascend else loss
        return scaled_loss

    # Differentiate w.r.t. the optimizer's parameters only (no input gradients).
    grad_fn = ops.value_and_grad(forward_fn,
                                 grad_position=None,
                                 weights=optimizer.parameters,
                                 has_aux=False)

定义单步训练函数 train_step：计算当前批次的损失与梯度，并由优化器更新模型参数，以最小化损失函数。

In [None]:
    @jit
    def train_step(inputs, labels):
        """Run one optimisation step on a batch.

        Returns the unscaled loss together with the batch itself; the epoch
        loop reuses the last training batch for trainer.get_loss.
        """
        loss, grads = grad_fn(inputs, labels)
        if use_ascend:
            # forward_fn multiplied the loss by the current scale; undo it for reporting.
            loss = loss_scaler.unscale(loss)
            is_finite = all_finite(grads)
            if is_finite:
                grads = loss_scaler.unscale(grads)
                # Step only on finite gradients: applying overflowed (inf/nan)
                # fp16 gradients would corrupt the weights.
                loss = ops.depend(loss, optimizer(grads))
            # Feed the overflow status back so the dynamic scaler can shrink or
            # grow its scale; without this call the scale never changes.
            loss_scaler.adjust(is_finite)
        else:
            loss = ops.depend(loss, optimizer(grads))
        return loss, inputs, labels

定义测试函数 test_step：直接调用 trainer.get_loss，执行模型前向传播，计算预测输出与真实标签之间的差异并返回损失及相关输出；该函数不会更新模型参数。

In [None]:
    def test_step(inputs, labels):
        """Evaluate one batch without touching the optimizer.

        Returns every value produced by trainer.get_loss, unchanged.
        """
        eval_outputs = trainer.get_loss(inputs, labels)
        return eval_outputs

获取训练集与测试集的大小，即各自包含的批次总数，用于确定训练与测试循环的迭代次数；随后通过 data_sink 将 train_step / test_step 与对应数据集绑定（sink_size=1，即每次调用处理一个批次）。

In [None]:
    # Batches per epoch for each split; these bound the train/eval loops below.
    train_size, test_size = (
        split.get_dataset_size() for split in (train_dataset, test_dataset))
    # Bind each step function to its dataset; one batch per call (sink_size=1).
    train_sink = data_sink(train_step, train_dataset, sink_size=1)
    test_sink = data_sink(test_step, test_dataset, sink_size=1)
    test_interval, save_ckpt_interval = (
        summary_params[key] for key in ("test_interval", "save_ckpt_interval"))

模型训练

In [None]:
    for epoch in range(1, optimizer_params["epochs"] + 1):
        time_beg = time.time()
        train_l1 = 0.0
        model.set_train()
        for _ in range(train_size):
            loss_train, inputs, labels = train_sink()
            train_l1 += loss_train.asnumpy()
        train_loss = train_l1 / train_size
        if epoch >= trainer.hatch_extent:
            # Re-evaluate the last training batch and hand its two partial
            # losses to the Trainer (presumably to re-balance the combined
            # loss terms — see src.utils.Trainer).
            _, loss1, loss2, _, _, _, _ = trainer.get_loss(inputs, labels)
            trainer.renew_loss_lists(loss1, loss2)
            trainer.adjust_hatchs()
        print_log(
            f"epoch: {epoch}, "
            f"step time: {(time.time() - time_beg) / steps_per_epoch:>7f}, "
            f"loss: {train_loss:>7f}")

        if epoch % test_interval == 0:
            model.set_train(False)
            # Time the evaluation on its own: measuring from time_beg would
            # include the whole training epoch, and the per-step average must
            # be taken over test batches, not training batches.
            test_beg = time.time()
            test_l1 = 0.0
            for _ in range(test_size):
                loss_test, loss1, loss2, inputs, pred, labels, _ = test_sink()
                test_l1 += loss_test.asnumpy()
            test_loss = test_l1 / test_size
            print_log(
                f"epoch: {epoch}, "
                f"step time: {(time.time() - test_beg) / test_size:>7f}, "
                f"loss: {test_loss:>7f}")

            # Visualise the last evaluated batch.
            plot_image(inputs, 0)
            plot_image_first(inputs, 0)
            plot_image(pred, 0)
            plot_image(labels, 0)

        if epoch % save_ckpt_interval == 0:
            # Overwrites the same file each time, so only the latest weights are kept.
            save_checkpoint(model, ckpt_file_name=os.path.join(ckpt_dir, 'model_data.ckpt'))

    print("Training Finished!!")

运行

In [None]:
if __name__ == "__main__":
    # Script entry point: run the full training procedure defined above.
    train()

epoch: 1, step time: 7.204487, loss: 30.159577
epoch: 2, step time: 0.050003, loss: 16.109943
epoch: 3, step time: 23.267395, loss: 14.340383
epoch: 4, step time: 0.059521, loss: 11.345826
epoch: 5, step time: 0.059669, loss: 8.936566
epoch: 6, step time: 0.060332, loss: 7.976031
epoch: 7, step time: 0.059778, loss: 7.400826
epoch: 8, step time: 0.057766, loss: 7.384783
epoch: 9, step time: 0.058148, loss: 6.274703
epoch: 10, step time: 0.058232, loss: 7.144862
epoch: 10, step time: 0.930520, loss: 0.105834
epoch: 11, step time: 0.056829, loss: 6.353740
epoch: 12, step time: 0.055501, loss: 7.061435
epoch: 13, step time: 0.056191, loss: 5.679731
epoch: 14, step time: 0.056355, loss: 6.572665
epoch: 15, step time: 0.056130, loss: 5.429127
epoch: 16, step time: 0.056419, loss: 5.232585
epoch: 17, step time: 0.056626, loss: 4.742438
epoch: 18, step time: 0.056156, loss: 5.042458
epoch: 19, step time: 0.056799, loss: 4.441929
epoch: 20, step time: 0.055955, loss: 4.645259
epoch: 20, step time: 0.061364, loss: 0.061257
epoch: 21, step time: 0.061526, loss: 4.051717
epoch: 22, step time: 0.057296, loss: 4.084571
epoch: 23, step time: 0.056526, loss: 3.812290
epoch: 24, step time: 0.056620, loss: 4.036302
epoch: 25, step time: 0.056961, loss: 4.224667
epoch: 26, step time: 0.056612, loss: 3.680945
epoch: 27, step time: 0.056421, loss: 3.704518
epoch: 28, step time: 0.056409, loss: 3.288220
epoch: 29, step time: 0.056580, loss: 3.349201
epoch: 30, step time: 0.056552, loss: 4.562499
epoch: 30, step time: 0.062123, loss: 0.056036
epoch: 31, step time: 0.058193, loss: 5.577391
epoch: 32, step time: 0.056697, loss: 4.279838
epoch: 33, step time: 0.055909, loss: 4.410978
epoch: 34, step time: 0.056718, loss: 3.254215
epoch: 35, step time: 0.056718, loss: 3.463492
epoch: 36, step time: 0.056144, loss: 3.151621
epoch: 37, step time: 0.056268, loss: 2.815826
epoch: 38, step time: 0.055588, loss: 2.828014
epoch: 39, step time: 0.060484, loss: 2.759430
epoch: 40, step time: 0.056569, loss: 2.598535
epoch: 40, step time: 0.062142, loss: 0.030793

## 结果展示

![plot1](images/plot1.png)![plot2](images/plot2.png)
第一张为输入温度场图片，第二张为模型预测图片。