# PreDiff: Precipitation Nowcasting with Latent Diffusion Models

## Overview

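Traditional weather forecasting relies on complex physical models that are computationally expensive and demand deep domain expertise. Over the past decade, however, the explosive growth of spatiotemporal Earth observation data has given deep learning a new path toward data-driven forecasting models. These models have shown great promise across many Earth system forecasting tasks, but they still struggle to handle uncertainty and to incorporate domain-specific prior knowledge. As a result, their predictions are often blurry or physically implausible.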

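To address these challenges, Gao Zhihan of the Hong Kong University of Science and Technology implemented the **PreDiff** model for probabilistic spatiotemporal forecasting. The pipeline combines a conditional latent diffusion model with an explicit knowledge alignment mechanism. Its goal is to produce forecasts that respect domain-specific physical constraints while accurately capturing spatiotemporal variability, and with it we expect to substantially improve the accuracy and reliability of Earth system forecasts.
Building on this, the model generates refined outputs that form the final precipitation forecast. The model framework is shown in the figure below (figure source: the paper [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS)).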

![prediff](images/train.jpg)

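During training, a variational autoencoder (VAE) first compresses the data into a latent space that retains the key information. A diffusion timestep is then sampled at random, noise for that timestep is generated, and the latent data are corrupted with it. The noisy latents are fed into Earthformer-UNet for denoising. Earthformer-UNet uses a UNet architecture with cuboid attention, and it drops the cross-attention that connects the encoder and decoder in the original Earthformer. Finally, the VAE decoder maps the result back to obtain the denoised data. In short, the diffusion model learns the data distribution by reversing a predefined noising process that gradually corrupts the original data.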
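Assuming PreDiff follows the standard DDPM formulation, the noising step described above is `z_t = sqrt(abar_t) * z_0 + sqrt(1 - abar_t) * eps`, and the denoiser is trained to predict `eps` with an MSE loss. The NumPy sketch below illustrates both pieces. The linear beta schedule (1000 steps, betas from 1e-4 to 2e-2), the toy tensor sizes, and all function names are hypothetical; PreDiff's actual schedule and parameterization live in `p_losses` and the YAML config, which are not shown here.

```python
import numpy as np

def linear_alphas_cumprod(num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product abar_t of (1 - beta_t) for a linear beta schedule (hypothetical defaults)."""
    betas = np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float64)
    return np.cumprod(1.0 - betas)

def q_sample(z0, t, noise, alphas_cumprod):
    """Forward (noising) process: z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * noise.

    z0, noise: arrays of shape (N, ...); t: integer array of shape (N,), one timestep per sample.
    """
    # Broadcast abar_t over every non-batch axis
    abar = alphas_cumprod[t].reshape(-1, *([1] * (z0.ndim - 1)))
    return np.sqrt(abar) * z0 + np.sqrt(1.0 - abar) * noise

def eps_mse_loss(eps_pred, noise):
    """Epsilon-prediction objective: mean squared error between predicted and true noise."""
    return float(np.mean((eps_pred - noise) ** 2))

rng = np.random.default_rng(0)
alphas_cumprod = linear_alphas_cumprod()
z0 = rng.standard_normal((2, 6, 16, 16, 4))           # toy latent, (N, T, H, W, C)
t = rng.integers(0, len(alphas_cumprod), size=(2,))   # random timestep per sample
noise = rng.standard_normal(z0.shape)
zt = q_sample(z0, t, noise, alphas_cumprod)
# A real denoiser predicts `noise` from (zt, t, condition); here an all-zero prediction stands in
print(zt.shape, eps_mse_loss(np.zeros_like(noise), noise))
```

Because the noise has unit variance, an uninformed all-zero prediction scores a loss close to 1.0. Training should push the loss below that value.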

## Workflow

MindSpore Earth solves this problem in the following steps:

1. Create the dataset
2. Build the model
3. Define the loss function
4. Train the model
5. Evaluate and visualize the model

Download the dataset from [PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip) and save it locally.

```python
import time
import os
import random
import json
from typing import Sequence, Union
import numpy as np
from einops import rearrange

import mindspore as ms
from mindspore import set_seed, context, ops, nn, mint
from mindspore.experimental import optim
from mindspore.train.serialization import save_checkpoint
```



The `src` package used below can be downloaded from [PreDiff/src](./src).

```python
from mindearth.utils import load_yaml_config

from src import (
    prepare_output_directory,
    configure_logging_system,
    prepare_dataset,
    init_model,
    PreDiffModule,
    DiffusionTrainer,
    DiffusionInferrence
)
from src.sevir_dataset import SEVIRDataset
from src.visual import vis_sevir_seq
from src.utils import warmup_lambda
```

```python
set_seed(0)
np.random.seed(0)
random.seed(0)
```

Model, data, optimizer, and other parameters are set in the [configuration file](./configs/diffusion.yaml).
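Before a long run, it can help to check that the configuration provides everything `DiffusionTrainer` (defined later in this notebook) reads. The sketch below gathers those settings with the same fallback defaults the trainer uses. The function name is hypothetical, it assumes `load_yaml_config` returns a plain nested dict, and the usage values are toy inputs rather than the real `diffusion.yaml` contents.

```python
# Sections the trainer indexes directly (a missing one raises KeyError there)
REQUIRED_SECTIONS = ("summary", "eval", "optim")

def trainer_settings(config: dict) -> dict:
    """Collect the settings DiffusionTrainer reads, mirroring its fallback defaults."""
    missing = [key for key in REQUIRED_SECTIONS if key not in config]
    if missing:
        raise KeyError(f"config is missing sections: {missing}")
    if "max_epochs" not in config["optim"]:
        raise KeyError("config['optim'] needs 'max_epochs'")
    return {
        "rescale_method": config.get("rescale_method", "01"),
        "clip_norm": config.get("clip_norm", 2.0),
        "summary_dir": config["summary"].get("summary_dir", "./summary"),
        "keep_ckpt_max": config["summary"].get("keep_ckpt_max", 100),
        "learn_logvar": config.get("model", {}).get("diffusion", {}).get("learn_logvar", False),
        "max_epochs": config["optim"]["max_epochs"],
    }

# Toy config; in the notebook you would pass the object returned by load_yaml_config
settings = trainer_settings(
    {"summary": {"summary_dir": "./summary/prediff"}, "eval": {}, "optim": {"max_epochs": 1}}
)
print(settings["clip_norm"], settings["keep_ckpt_max"])  # 2.0 100
```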

```python
config = load_yaml_config("./configs/diffusion.yaml")
context.set_context(mode=ms.PYNATIVE_MODE)
ms.set_device(device_target="Ascend", device_id=0)
```

## Model Construction

Model initialization mainly consists of loading the VAE checkpoint and initializing the Earthformer component.
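The Earthformer component is built around cuboid attention. The NumPy sketch below shows the simplest local form of that idea, and it is not the Earthformer or PreDiff implementation. A (T, H, W, C) tensor is split into non-overlapping cuboids, and single-head self-attention runs independently inside each one. The cuboid size, the random projection weights, and all function names are hypothetical. The real layers add multiple heads and further decomposition options, all omitted here.

```python
import numpy as np

def cuboid_partition(x, cuboid):
    """Split x of shape (T, H, W, C) into cuboids: returns (num_cuboids, bt*bh*bw, C)."""
    T, H, W, C = x.shape
    bt, bh, bw = cuboid
    assert T % bt == 0 and H % bh == 0 and W % bw == 0, "cuboid size must divide the input"
    x = x.reshape(T // bt, bt, H // bh, bh, W // bw, bw, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # (nT, nH, nW, bt, bh, bw, C)
    return x.reshape(-1, bt * bh * bw, C)

def cuboid_merge(cubes, shape, cuboid):
    """Inverse of cuboid_partition."""
    T, H, W, C = shape
    bt, bh, bw = cuboid
    x = cubes.reshape(T // bt, H // bh, W // bw, bt, bh, bw, C)
    x = x.transpose(0, 3, 1, 4, 2, 5, 6)  # (nT, bt, nH, bh, nW, bw, C)
    return x.reshape(T, H, W, C)

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def cuboid_self_attention(x, cuboid, wq, wk, wv):
    """Single-head scaled dot-product attention computed only within each cuboid.

    Assumes square (C, C) projections so the output has the same shape as x.
    """
    cubes = cuboid_partition(x, cuboid)                        # (n, L, C)
    q, k, v = cubes @ wq, cubes @ wk, cubes @ wv               # (n, L, C)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (n, L, L)
    out = softmax(scores) @ v                                  # (n, L, C)
    return cuboid_merge(out, x.shape, cuboid)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8, 16))  # toy (T, H, W, C)
w = [rng.standard_normal((16, 16)) * 0.1 for _ in range(3)]
y = cuboid_self_attention(x, (2, 4, 4), *w)
print(y.shape)  # (4, 8, 8, 16)
```

Attention never crosses cuboid boundaries, so the cost scales with T·H·W times the cuboid volume instead of with (T·H·W) squared. That saving is the motivation for the decomposition.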

```python
main_module = PreDiffModule(oc_file="./configs/diffusion.yaml")
main_module = init_model(module=main_module, config=config, mode="train")
output_dir = prepare_output_directory(config, "0")
logger = configure_logging_system(output_dir, config)
```

```text
2025-04-07 10:32:11,466 - utils.py[line:820] - INFO: Process ID: 2231351
2025-04-07 10:32:11,467 - utils.py[line:821] - INFO: {'summary_dir': './summary/prediff/single_device0', 'eval_interval': 10, 'save_ckpt_epochs': 1, 'keep_ckpt_max': 100, 'ckpt_path': '/home/lry/202542测试/PreDiff/ckpt/diffusion.ckpt', 'load_ckpt': False}
NoisyCuboidTransformerEncoder param_not_load: []
Cleared previous output directory: ./summary/prediff/single_device0
```


## Creating the Dataset

Download the [sevir-lr](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip) dataset into the `./dataset` directory.
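If you prefer to fetch the archive from Python, a standard-library sketch follows. It downloads the zip once, skipping the download if the file already exists, and extracts it into `./dataset`. The function name is hypothetical. The sketch does not check the archive's internal folder layout or whether it matches the data path configured in `diffusion.yaml`.

```python
import os
import urllib.request
import zipfile

SEVIR_LR_URL = "https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip"

def download_and_extract(url=SEVIR_LR_URL, dest_dir="./dataset"):
    """Download a zip archive into dest_dir (if not already there) and extract it.

    Returns the list of member names in the archive.
    """
    os.makedirs(dest_dir, exist_ok=True)
    archive = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)  # plain HTTP(S) or file:// download
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
        return zf.namelist()
```

Calling `download_and_extract()` with no arguments targets the sevir-lr URL above and `./dataset`. The archive is large, so the first call may take a while.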

```python
dm, total_num_steps = prepare_dataset(config, PreDiffModule)
```

```text
train
                                                        vil_filename  \
id
R18020113057733 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113057811 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113057875 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113057888 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113057982 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113058079 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113058477 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020113058635 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020306327357 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5
R18020306327410 0  vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5

                   vil_index
id
R18020113057733 0       1310
R18020113057811 0       1311
R1802
```

## Loss Function

PreDiff training uses MSE as the loss and applies gradient clipping. The whole procedure is wrapped in `DiffusionTrainer`.
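In the trainer, gradient clipping goes through `ops.clip_by_global_norm`. The NumPy sketch below shows the usual global-norm rule, `g * clip_norm / max(global_norm, clip_norm)`. Every gradient is scaled by one common factor, so their joint L2 norm never exceeds `clip_norm`, which defaults to 2.0 in the trainer. This is an illustration under that assumption, and the function name is hypothetical.

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most clip_norm.

    Returns (clipped_grads, global_norm_before_clipping).
    """
    global_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    # Equals 1.0 (no change) when the gradients are already within the limit
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # joint norm is 5
clipped, norm = clip_by_global_norm(grads, clip_norm=2.0)
```

Clipping by the global norm keeps the direction of the full update and only shrinks its length. Clipping each tensor separately would also change the direction.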

```python
class DiffusionTrainer(nn.Cell):
    """
    Class managing the training pipeline for diffusion models. Handles dataset processing,
    optimizer configuration, gradient clipping, checkpoint saving, and logging.
    """
    def __init__(self, main_module, dm, logger, config):
        """
        Initialize trainer with model, data module, logger, and configuration.
        Args:
            main_module: Main diffusion model to be trained
            dm: Data module providing training dataset
            logger: Logging utility for training progress
            config: Configuration dictionary containing hyperparameters
        """
        super().__init__()
        self.main_module = main_module
        self.traindataset = dm.sevir_train
        self.logger = logger
        # Converts raw SEVIR VIL batches into rescaled NHWT tensors
        self.datasetprocessing = SEVIRDataset(
            data_types=["vil"],
            layout="NHWT",
            rescale_method=config.get("rescale_method", "01"),
        )
        self.example_save_dir = config["summary"].get("summary_dir", "./summary")
        self.fs = config["eval"].get("fs", 20)
        self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5])
        self.label_avg_int = config["eval"].get("label_avg_int", False)
        self.current_epoch = 0
        self.learn_logvar = (
            config.get("model", {}).get("diffusion", {}).get("learn_logvar", False)
        )
        self.logvar = main_module.logvar
        self.maeloss = nn.MAELoss()
        self.optim_config = config["optim"]
        self.clip_norm = config.get("clip_norm", 2.0)
        self.ckpt_dir = os.path.join(self.example_save_dir, "ckpt")
        self.keep_ckpt_max = config["summary"].get("keep_ckpt_max", 100)
        self.ckpt_history = []
        self.grad_clip_fn = ops.clip_by_global_norm
        # Fixed-learning-rate Adam; train() does not use the AdamW + warmup/cosine setup in _get_optimizer
        self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), learning_rate=0.00001)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def train(self, total_steps: int):
        """Execute complete training pipeline."""
        self.main_module.main_model.set_train(True)
        self.logger.info("Initializing training process...")
        # optimizer, lr_scheduler = self._get_optimizer(total_steps)
        loss_processor = Trainonestepforward(self.main_module)
        # Returns (loss, gradients w.r.t. the trainable parameters of main_model)
        grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params())
        # An iterator object is always truthy, so check the dataset size instead
        if self.traindataset.get_dataset_size() == 0:
            raise ValueError("training dataset is empty")
        for epoch in range(self.optim_config["max_epochs"]):
            epoch_loss = 0.0
            epoch_start = time.time()

            iterator = self.traindataset.create_dict_iterator()
            batch_idx = 0
            for batch_idx, batch in enumerate(iterator):
                processed_data = self.datasetprocessing.process_data(batch["vil"])
                loss_value, gradients = grad_func(processed_data)
                # Rescale gradients so their global L2 norm does not exceed clip_norm
                clipped_grads = self.grad_clip_fn(gradients, self.clip_norm)
                self.optimizer(clipped_grads)
                # lr_scheduler.step()
                epoch_loss += loss_value.asnumpy()
                self.logger.info(
                    f"epoch: {epoch} step: {batch_idx}, loss: {loss_value}"
                )
            self._save_ckpt(epoch)
            epoch_time = time.time() - epoch_start
            self.logger.info(
                f"Epoch {epoch} completed in {epoch_time:.2f}s | "
                f"Avg Loss: {epoch_loss/(batch_idx+1):.4f}"
            )

    def _get_optimizer(self, total_steps: int):
        """Configure optimization components"""
        trainable_params = list(self.main_module.main_model.trainable_params())
        if self.learn_logvar:
            self.logger.info("Including log variance parameters")
            trainable_params.append(self.logvar)
        optimizer = optim.AdamW(
            trainable_params,
            lr=self.optim_config["lr"],
            betas=tuple(self.optim_config["betas"]),
        )
        warmup_steps = int(self.optim_config["warmup_percentage"] * total_steps)
        scheduler = self._create_lr_scheduler(optimizer, total_steps, warmup_steps)

        return optimizer, scheduler

    def _create_lr_scheduler(self, optimizer, total_steps: int, warmup_steps: int):
        """Build learning rate scheduler: linear warmup followed by cosine annealing"""
        warmup_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=warmup_lambda(
                warmup_steps=warmup_steps,
                min_lr_ratio=self.optim_config["warmup_min_lr_ratio"],
            ),
        )

        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_steps - warmup_steps,
            eta_min=self.optim_config["min_lr_ratio"] * self.optim_config["lr"],
        )

        return optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

    def _save_ckpt(self, epoch: int):
        """Save model ckpt, keeping at most keep_ckpt_max files (oldest removed first)"""
        ckpt_file = f"diffusion_epoch{epoch}.ckpt"
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)

        save_checkpoint(self.main_module.main_model, ckpt_path)
        self.ckpt_history.append(ckpt_path)

        if len(self.ckpt_history) > self.keep_ckpt_max:
            removed_ckpt = self.ckpt_history.pop(0)
            os.remove(removed_ckpt)
            self.logger.info(f"Removed outdated ckpt: {removed_ckpt}")


class Trainonestepforward(nn.Cell):
    """A neural network cell that performs one training step forward pass for a diffusion model.
    This class encapsulates the forward pass computation for training a diffusion model,
    handling the input processing, latent space encoding, conditioning, and loss calculation.
    Args:
        model (nn.Cell): The main diffusion model containing the necessary submodules
                         for encoding, conditioning, and loss computation.
    """

    def __init__(self, model):
        super().__init__()
        self.main_module = model

    def construct(self, inputs):
        """Perform one forward training step and compute the loss."""
        x, condition = self.main_module.get_input(inputs)
        # (N, T, H, W, C) -> (N, T, C, H, W); fold time into batch so the VAE encodes each frame
        x = x.transpose(0, 1, 4, 2, 3)
        n, t_, c_, h, w = x.shape
        x = x.reshape(n * t_, c_, h, w)
        z = self.main_module.encode_first_stage(x)
        # Unfold to (N, T, C_z, H_z, W_z), then move channels last: (N, T, H_z, W_z, C_z)
        _, c_z, h_z, w_z = z.shape
        z = z.reshape(n, -1, c_z, h_z, w_z)
        z = z.transpose(0, 1, 3, 4, 2)
        # One random diffusion timestep per sample
        t = ops.randint(0, self.main_module.num_timesteps, (n,)).long()
        zc = self.main_module.cond_stage_forward(condition)
        loss = self.main_module.p_losses(z, zc, t, noise=None)
        return loss
```
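`_create_lr_scheduler` chains a linear warmup (`LambdaLR` driven by `warmup_lambda`) with cosine annealing (`CosineAnnealingLR`). `train` does not call `_get_optimizer` at present and uses `nn.Adam` with a fixed learning rate of 1e-5. If you re-enable the schedule, the closed-form sketch below approximates the learning rate at each step. The body of `warmup_lambda` in `src.utils` is not shown, so the linear ramp from `warmup_min_lr_ratio` to 1.0 is an assumption. Behavior at the step where `SequentialLR` switches schedulers may also differ by a step. The numbers in the usage line are toy values, not the config's.

```python
import math

def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Hypothetical linear warmup factor: min_lr_ratio at step 0, rising to 1.0 at warmup_steps."""
    def factor(step):
        if step <= warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * step / warmup_steps
        return 1.0
    return factor

def lr_at(step, base_lr, total_steps, warmup_percentage, warmup_min_lr_ratio, min_lr_ratio):
    """Approximate learning rate for linear warmup followed by cosine annealing."""
    warmup_steps = int(warmup_percentage * total_steps)
    if step < warmup_steps:
        return base_lr * warmup_lambda(warmup_steps, warmup_min_lr_ratio)(step)
    eta_min = min_lr_ratio * base_lr
    t_max = max(total_steps - warmup_steps, 1)
    # Progress through the cosine phase, clamped at the end of training (a simplification)
    progress = min(step - warmup_steps, t_max) / t_max
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2

schedule = [
    lr_at(s, base_lr=1e-3, total_steps=1000, warmup_percentage=0.1,
          warmup_min_lr_ratio=0.1, min_lr_ratio=1e-3)
    for s in range(1001)
]
```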

## Model Training

In this tutorial, we train the model with `DiffusionTrainer`.

```python
trainer = DiffusionTrainer(
    main_module=main_module, dm=dm, logger=logger, config=config
)
trainer.train(total_steps=total_num_steps)
```

2025-04-07 10:32:36,351 - 4106154625.py[line:46] - INFO: Initializing training process...


.........

2025-04-07 10:34:09,378 - 4106154625.py[line:64] - INFO: epoch: 0 step: 0, loss: 1.0008465


.

2025-04-07 10:34:16,871 - 4106154625.py[line:64] - INFO: epoch: 0 step: 1, loss: 1.0023363
2025-04-07 10:34:18,724 - 4106154625.py[line:64] - INFO: epoch: 0 step: 2, loss: 1.0009086


.

2025-04-07 10:34:20,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 3, loss: 0.99787366
2025-04-07 10:34:22,280 - 4106154625.py[line:64] - INFO: epoch: 0 step: 4, loss: 0.9979043
2025-04-07 10:34:24,072 - 4106154625.py[line:64] - INFO: epoch: 0 step: 5, loss: 0.99897844
2025-04-07 10:34:25,864 - 4106154625.py[line:64] - INFO: epoch: 0 step: 6, loss: 1.0021904
2025-04-07 10:34:27,709 - 4106154625.py[line:64] - INFO: epoch: 0 step: 7, loss: 0.9984627
2025-04-07 10:34:29,578 - 4106154625.py[line:64] - INFO: epoch: 0 step: 8, loss: 0.9952746


.

2025-04-07 10:34:31,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 9, loss: 1.0003254
2025-04-07 10:34:33,402 - 4106154625.py[line:64] - INFO: epoch: 0 step: 10, loss: 1.0020428
2025-04-07 10:34:35,218 - 4106154625.py[line:64] - INFO: epoch: 0 step: 11, loss: 0.99563503
2025-04-07 10:34:37,149 - 4106154625.py[line:64] - INFO: epoch: 0 step: 12, loss: 0.99336195
2025-04-07 10:34:38,949 - 4106154625.py[line:64] - INFO: epoch: 0 step: 13, loss: 1.0023757


.

2025-04-07 10:34:40,962 - 4106154625.py[line:64] - INFO: epoch: 0 step: 14, loss: 1.0007098
2025-04-07 10:34:43,332 - 4106154625.py[line:64] - INFO: epoch: 0 step: 15, loss: 0.99492
2025-04-07 10:34:45,177 - 4106154625.py[line:64] - INFO: epoch: 0 step: 16, loss: 0.99957407
2025-04-07 10:34:47,040 - 4106154625.py[line:64] - INFO: epoch: 0 step: 17, loss: 0.99685913
2025-04-07 10:34:48,823 - 4106154625.py[line:64] - INFO: epoch: 0 step: 18, loss: 0.9956614


.

2025-04-07 10:34:50,720 - 4106154625.py[line:64] - INFO: epoch: 0 step: 19, loss: 0.9934994
2025-04-07 10:34:52,552 - 4106154625.py[line:64] - INFO: epoch: 0 step: 20, loss: 0.99108785
2025-04-07 10:34:54,389 - 4106154625.py[line:64] - INFO: epoch: 0 step: 21, loss: 0.99182785
2025-04-07 10:34:56,159 - 4106154625.py[line:64] - INFO: epoch: 0 step: 22, loss: 0.99136275
2025-04-07 10:34:58,118 - 4106154625.py[line:64] - INFO: epoch: 0 step: 23, loss: 0.9886243
2025-04-07 10:35:00,045 - 4106154625.py[line:64] - INFO: epoch: 0 step: 24, loss: 0.9947286


.

2025-04-07 10:35:01,964 - 4106154625.py[line:64] - INFO: epoch: 0 step: 25, loss: 0.99265075
2025-04-07 10:35:03,818 - 4106154625.py[line:64] - INFO: epoch: 0 step: 26, loss: 0.98734057
2025-04-07 10:35:05,604 - 4106154625.py[line:64] - INFO: epoch: 0 step: 27, loss: 0.9867786
2025-04-07 10:35:07,383 - 4106154625.py[line:64] - INFO: epoch: 0 step: 28, loss: 0.98637533
2025-04-07 10:35:09,311 - 4106154625.py[line:64] - INFO: epoch: 0 step: 29, loss: 0.98799324
2025-04-07 10:35:11,054 - 4106154625.py[line:64] - INFO: epoch: 0 step: 30, loss: 0.9851307


.

2025-04-07 10:35:12,883 - 4106154625.py[line:64] - INFO: epoch: 0 step: 31, loss: 0.98547524
2025-04-07 10:35:14,629 - 4106154625.py[line:64] - INFO: epoch: 0 step: 32, loss: 0.9783558
2025-04-07 10:35:16,444 - 4106154625.py[line:64] - INFO: epoch: 0 step: 33, loss: 0.9851396
2025-04-07 10:35:18,122 - 4106154625.py[line:64] - INFO: epoch: 0 step: 34, loss: 0.98461366
2025-04-07 10:35:20,102 - 4106154625.py[line:64] - INFO: epoch: 0 step: 35, loss: 0.9879103
2025-04-07 10:35:22,232 - 4106154625.py[line:64] - INFO: epoch: 0 step: 36, loss: 0.9743713


.

2025-04-07 10:35:24,417 - 4106154625.py[line:64] - INFO: epoch: 0 step: 37, loss: 0.98045284
2025-04-07 10:35:26,435 - 4106154625.py[line:64] - INFO: epoch: 0 step: 38, loss: 0.97129095
2025-04-07 10:35:28,351 - 4106154625.py[line:64] - INFO: epoch: 0 step: 39, loss: 0.98204684
2025-04-07 10:35:30,122 - 4106154625.py[line:64] - INFO: epoch: 0 step: 40, loss: 0.97880834
2025-04-07 10:35:31,760 - 4106154625.py[line:64] - INFO: epoch: 0 step: 41, loss: 0.96932787


.

2025-04-07 10:35:33,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 42, loss: 0.9717276
2025-04-07 10:35:35,276 - 4106154625.py[line:64] - INFO: epoch: 0 step: 43, loss: 0.9716038
2025-04-07 10:35:37,238 - 4106154625.py[line:64] - INFO: epoch: 0 step: 44, loss: 0.9686392
2025-04-07 10:35:39,268 - 4106154625.py[line:64] - INFO: epoch: 0 step: 45, loss: 0.99201906
2025-04-07 10:35:41,141 - 4106154625.py[line:64] - INFO: epoch: 0 step: 46, loss: 0.977281
2025-04-07 10:35:43,166 - 4106154625.py[line:64] - INFO: epoch: 0 step: 47, loss: 0.96613944


.

2025-04-07 10:35:45,249 - 4106154625.py[line:64] - INFO: epoch: 0 step: 48, loss: 0.9612762
2025-04-07 10:35:47,142 - 4106154625.py[line:64] - INFO: epoch: 0 step: 49, loss: 0.9577536
2025-04-07 10:35:49,114 - 4106154625.py[line:64] - INFO: epoch: 0 step: 50, loss: 0.95175207
2025-04-07 10:35:51,080 - 4106154625.py[line:64] - INFO: epoch: 0 step: 51, loss: 0.95729643
2025-04-07 10:35:53,116 - 4106154625.py[line:64] - INFO: epoch: 0 step: 52, loss: 0.960687


.

2025-04-07 10:35:55,202 - 4106154625.py[line:64] - INFO: epoch: 0 step: 53, loss: 0.9575224
2025-04-07 10:35:57,168 - 4106154625.py[line:64] - INFO: epoch: 0 step: 54, loss: 0.9500365
2025-04-07 10:35:59,015 - 4106154625.py[line:64] - INFO: epoch: 0 step: 55, loss: 0.94735086
2025-04-07 10:36:01,016 - 4106154625.py[line:64] - INFO: epoch: 0 step: 56, loss: 0.97874105
2025-04-07 10:36:02,904 - 4106154625.py[line:64] - INFO: epoch: 0 step: 57, loss: 0.9451903


.

2025-04-07 10:36:04,717 - 4106154625.py[line:64] - INFO: epoch: 0 step: 58, loss: 0.94447565
2025-04-07 10:36:06,499 - 4106154625.py[line:64] - INFO: epoch: 0 step: 59, loss: 0.94874763
2025-04-07 10:36:08,260 - 4106154625.py[line:64] - INFO: epoch: 0 step: 60, loss: 0.9672854
2025-04-07 10:36:10,146 - 4106154625.py[line:64] - INFO: epoch: 0 step: 61, loss: 0.9565505
2025-04-07 10:36:12,112 - 4106154625.py[line:64] - INFO: epoch: 0 step: 62, loss: 0.9480209
2025-04-07 10:36:13,989 - 4106154625.py[line:64] - INFO: epoch: 0 step: 63, loss: 0.94844496


.

2025-04-07 10:36:15,759 - 4106154625.py[line:64] - INFO: epoch: 0 step: 64, loss: 0.94463414
2025-04-07 10:36:17,409 - 4106154625.py[line:64] - INFO: epoch: 0 step: 65, loss: 0.9484377
2025-04-07 10:36:19,103 - 4106154625.py[line:64] - INFO: epoch: 0 step: 66, loss: 0.93955624
2025-04-07 10:36:21,005 - 4106154625.py[line:64] - INFO: epoch: 0 step: 67, loss: 0.9357619
2025-04-07 10:36:22,738 - 4106154625.py[line:64] - INFO: epoch: 0 step: 68, loss: 0.9534744
2025-04-07 10:36:24,626 - 4106154625.py[line:64] - INFO: epoch: 0 step: 69, loss: 0.970679


.

2025-04-07 10:36:26,527 - 4106154625.py[line:64] - INFO: epoch: 0 step: 70, loss: 0.9313204
2025-04-07 10:36:28,335 - 4106154625.py[line:64] - INFO: epoch: 0 step: 71, loss: 0.927449
2025-04-07 10:36:30,082 - 4106154625.py[line:64] - INFO: epoch: 0 step: 72, loss: 0.9536683
2025-04-07 10:36:31,761 - 4106154625.py[line:64] - INFO: epoch: 0 step: 73, loss: 0.92975646
2025-04-07 10:36:33,780 - 4106154625.py[line:64] - INFO: epoch: 0 step: 74, loss: 0.9387269


.

2025-04-07 10:36:35,917 - 4106154625.py[line:64] - INFO: epoch: 0 step: 75, loss: 0.9491191
2025-04-07 10:36:37,922 - 4106154625.py[line:64] - INFO: epoch: 0 step: 76, loss: 0.9263407
2025-04-07 10:36:39,572 - 4106154625.py[line:64] - INFO: epoch: 0 step: 77, loss: 0.95135903
2025-04-07 10:36:41,209 - 4106154625.py[line:64] - INFO: epoch: 0 step: 78, loss: 0.92555064
2025-04-07 10:36:42,827 - 4106154625.py[line:64] - INFO: epoch: 0 step: 79, loss: 0.93047976
2025-04-07 10:36:44,649 - 4106154625.py[line:64] - INFO: epoch: 0 step: 80, loss: 0.9445814


.

2025-04-07 10:36:46,501 - 4106154625.py[line:64] - INFO: epoch: 0 step: 81, loss: 0.92167306
2025-04-07 10:36:48,449 - 4106154625.py[line:64] - INFO: epoch: 0 step: 82, loss: 0.9199027
2025-04-07 10:36:50,603 - 4106154625.py[line:64] - INFO: epoch: 0 step: 83, loss: 0.95979875
2025-04-07 10:36:52,662 - 4106154625.py[line:64] - INFO: epoch: 0 step: 84, loss: 0.94403404
2025-04-07 10:36:54,314 - 4106154625.py[line:64] - INFO: epoch: 0 step: 85, loss: 0.91954345


.

2025-04-07 10:36:55,996 - 4106154625.py[line:64] - INFO: epoch: 0 step: 86, loss: 0.92873365
2025-04-07 10:36:57,701 - 4106154625.py[line:64] - INFO: epoch: 0 step: 87, loss: 0.91166925
2025-04-07 10:36:59,362 - 4106154625.py[line:64] - INFO: epoch: 0 step: 88, loss: 0.92743254
2025-04-07 10:37:01,139 - 4106154625.py[line:64] - INFO: epoch: 0 step: 89, loss: 0.9097767
2025-04-07 10:37:03,120 - 4106154625.py[line:64] - INFO: epoch: 0 step: 90, loss: 0.918455
2025-04-07 10:37:05,260 - 4106154625.py[line:64] - INFO: epoch: 0 step: 91, loss: 0.9123219


.

2025-04-07 10:37:06,972 - 4106154625.py[line:64] - INFO: epoch: 0 step: 92, loss: 0.9185343
2025-04-07 10:37:08,881 - 4106154625.py[line:64] - INFO: epoch: 0 step: 93, loss: 0.9153005
2025-04-07 10:37:10,973 - 4106154625.py[line:64] - INFO: epoch: 0 step: 94, loss: 0.90332276
2025-04-07 10:37:13,070 - 4106154625.py[line:64] - INFO: epoch: 0 step: 95, loss: 0.90544885
2025-04-07 10:37:14,777 - 4106154625.py[line:64] - INFO: epoch: 0 step: 96, loss: 0.92892224


.

2025-04-07 10:37:16,919 - 4106154625.py[line:64] - INFO: epoch: 0 step: 97, loss: 0.92682004
2025-04-07 10:37:18,923 - 4106154625.py[line:64] - INFO: epoch: 0 step: 98, loss: 0.9004317
2025-04-07 10:37:20,940 - 4106154625.py[line:64] - INFO: epoch: 0 step: 99, loss: 0.908974
2025-04-07 10:37:22,739 - 4106154625.py[line:64] - INFO: epoch: 0 step: 100, loss: 0.8956867
2025-04-07 10:37:24,509 - 4106154625.py[line:64] - INFO: epoch: 0 step: 101, loss: 0.8987319


.

2025-04-07 10:37:26,159 - 4106154625.py[line:64] - INFO: epoch: 0 step: 102, loss: 0.9083508
2025-04-07 10:37:27,783 - 4106154625.py[line:64] - INFO: epoch: 0 step: 103, loss: 0.89505464
2025-04-07 10:37:29,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 104, loss: 0.9006442
2025-04-07 10:37:31,031 - 4106154625.py[line:64] - INFO: epoch: 0 step: 105, loss: 0.8925739
2025-04-07 10:37:32,688 - 4106154625.py[line:64] - INFO: epoch: 0 step: 106, loss: 0.8919925
2025-04-07 10:37:34,278 - 4106154625.py[line:64] - INFO: epoch: 0 step: 107, loss: 0.8901893
2025-04-07 10:37:35,874 - 4106154625.py[line:64] - INFO: epoch: 0 step: 108, loss: 0.8947307


.

2025-04-07 10:37:37,562 - 4106154625.py[line:64] - INFO: epoch: 0 step: 109, loss: 0.89940923
2025-04-07 10:37:39,124 - 4106154625.py[line:64] - INFO: epoch: 0 step: 110, loss: 0.88965017
2025-04-07 10:37:40,773 - 4106154625.py[line:64] - INFO: epoch: 0 step: 111, loss: 0.8835504
2025-04-07 10:37:42,345 - 4106154625.py[line:64] - INFO: epoch: 0 step: 112, loss: 0.8785033
2025-04-07 10:37:43,921 - 4106154625.py[line:64] - INFO: epoch: 0 step: 113, loss: 0.8814548
2025-04-07 10:37:45,600 - 4106154625.py[line:64] - INFO: epoch: 0 step: 114, loss: 0.8877945


.

2025-04-07 10:37:47,338 - 4106154625.py[line:64] - INFO: epoch: 0 step: 115, loss: 0.88197625
2025-04-07 10:37:48,996 - 4106154625.py[line:64] - INFO: epoch: 0 step: 116, loss: 0.8941308
2025-04-07 10:37:50,679 - 4106154625.py[line:64] - INFO: epoch: 0 step: 117, loss: 0.88495713
2025-04-07 10:37:52,603 - 4106154625.py[line:64] - INFO: epoch: 0 step: 118, loss: 0.90219486
2025-04-07 10:37:54,497 - 4106154625.py[line:64] - INFO: epoch: 0 step: 119, loss: 0.89262724
2025-04-07 10:37:56,103 - 4106154625.py[line:64] - INFO: epoch: 0 step: 120, loss: 0.8879415


.

2025-04-07 10:37:57,735 - 4106154625.py[line:64] - INFO: epoch: 0 step: 121, loss: 0.878676
2025-04-07 10:37:59,364 - 4106154625.py[line:64] - INFO: epoch: 0 step: 122, loss: 0.8715365
2025-04-07 10:38:00,946 - 4106154625.py[line:64] - INFO: epoch: 0 step: 123, loss: 0.8677654
2025-04-07 10:38:02,558 - 4106154625.py[line:64] - INFO: epoch: 0 step: 124, loss: 0.8684499
2025-04-07 10:38:04,199 - 4106154625.py[line:64] - INFO: epoch: 0 step: 125, loss: 0.8848672
2025-04-07 10:38:05,816 - 4106154625.py[line:64] - INFO: epoch: 0 step: 126, loss: 0.8611082


.

2025-04-07 10:38:07,435 - 4106154625.py[line:64] - INFO: epoch: 0 step: 127, loss: 0.87677616
2025-04-07 10:38:09,051 - 4106154625.py[line:64] - INFO: epoch: 0 step: 128, loss: 0.8892087
2025-04-07 10:38:10,675 - 4106154625.py[line:64] - INFO: epoch: 0 step: 129, loss: 0.87242335
2025-04-07 10:38:12,362 - 4106154625.py[line:64] - INFO: epoch: 0 step: 130, loss: 0.86540776
2025-04-07 10:38:13,976 - 4106154625.py[line:64] - INFO: epoch: 0 step: 131, loss: 0.9510796
2025-04-07 10:38:15,605 - 4106154625.py[line:64] - INFO: epoch: 0 step: 132, loss: 0.8619976


.

2025-04-07 10:38:17,224 - 4106154625.py[line:64] - INFO: epoch: 0 step: 133, loss: 0.8630925
2025-04-07 10:38:18,780 - 4106154625.py[line:64] - INFO: epoch: 0 step: 134, loss: 0.85540855
2025-04-07 10:38:20,350 - 4106154625.py[line:64] - INFO: epoch: 0 step: 135, loss: 0.85183513
2025-04-07 10:38:21,884 - 4106154625.py[line:64] - INFO: epoch: 0 step: 136, loss: 0.8917813
2025-04-07 10:38:23,435 - 4106154625.py[line:64] - INFO: epoch: 0 step: 137, loss: 0.8526528
2025-04-07 10:38:24,950 - 4106154625.py[line:64] - INFO: epoch: 0 step: 138, loss: 0.8536273
2025-04-07 10:38:26,598 - 4106154625.py[line:64] - INFO: epoch: 0 step: 139, loss: 0.8565655


.

2025-04-07 10:38:28,139 - 4106154625.py[line:64] - INFO: epoch: 0 step: 140, loss: 0.8921677
2025-04-07 10:38:29,688 - 4106154625.py[line:64] - INFO: epoch: 0 step: 141, loss: 0.86149573
2025-04-07 10:38:31,311 - 4106154625.py[line:64] - INFO: epoch: 0 step: 142, loss: 0.8502701
2025-04-07 10:38:32,945 - 4106154625.py[line:64] - INFO: epoch: 0 step: 143, loss: 0.84761256
2025-04-07 10:38:34,574 - 4106154625.py[line:64] - INFO: epoch: 0 step: 144, loss: 0.8530063
2025-04-07 10:38:36,196 - 4106154625.py[line:64] - INFO: epoch: 0 step: 145, loss: 0.89813197


.

2025-04-07 10:38:37,836 - 4106154625.py[line:64] - INFO: epoch: 0 step: 146, loss: 0.86497414
2025-04-07 10:38:39,461 - 4106154625.py[line:64] - INFO: epoch: 0 step: 147, loss: 0.86043245
2025-04-07 10:38:41,038 - 4106154625.py[line:64] - INFO: epoch: 0 step: 148, loss: 0.8537921
2025-04-07 10:38:42,593 - 4106154625.py[line:64] - INFO: epoch: 0 step: 149, loss: 0.84643245
2025-04-07 10:38:44,350 - 4106154625.py[line:64] - INFO: epoch: 0 step: 150, loss: 0.84086126
2025-04-07 10:38:45,982 - 4106154625.py[line:64] - INFO: epoch: 0 step: 151, loss: 0.8376725


.

2025-04-07 10:38:47,621 - 4106154625.py[line:64] - INFO: epoch: 0 step: 152, loss: 0.8443006
2025-04-07 10:38:49,414 - 4106154625.py[line:64] - INFO: epoch: 0 step: 153, loss: 0.87024367
2025-04-07 10:38:51,379 - 4106154625.py[line:64] - INFO: epoch: 0 step: 154, loss: 0.8439486
2025-04-07 10:38:53,492 - 4106154625.py[line:64] - INFO: epoch: 0 step: 155, loss: 0.8428738
2025-04-07 10:38:55,505 - 4106154625.py[line:64] - INFO: epoch: 0 step: 156, loss: 0.8446244


.

2025-04-07 10:38:57,390 - 4106154625.py[line:64] - INFO: epoch: 0 step: 157, loss: 0.82819533
2025-04-07 10:38:59,154 - 4106154625.py[line:64] - INFO: epoch: 0 step: 158, loss: 0.8346045
2025-04-07 10:39:00,861 - 4106154625.py[line:64] - INFO: epoch: 0 step: 159, loss: 0.91556245
2025-04-07 10:39:02,460 - 4106154625.py[line:64] - INFO: epoch: 0 step: 160, loss: 0.8365531
2025-04-07 10:39:03,994 - 4106154625.py[line:64] - INFO: epoch: 0 step: 161, loss: 0.82283574
2025-04-07 10:39:05,550 - 4106154625.py[line:64] - INFO: epoch: 0 step: 162, loss: 0.83937204


.

2025-04-07 10:39:07,130 - 4106154625.py[line:64] - INFO: epoch: 0 step: 163, loss: 0.82220745
2025-04-07 10:39:08,702 - 4106154625.py[line:64] - INFO: epoch: 0 step: 164, loss: 0.8206043
2025-04-07 10:39:10,286 - 4106154625.py[line:64] - INFO: epoch: 0 step: 165, loss: 0.82163304
2025-04-07 10:39:11,858 - 4106154625.py[line:64] - INFO: epoch: 0 step: 166, loss: 0.9156118
2025-04-07 10:39:13,664 - 4106154625.py[line:64] - INFO: epoch: 0 step: 167, loss: 0.8271665
2025-04-07 10:39:15,392 - 4106154625.py[line:64] - INFO: epoch: 0 step: 168, loss: 0.8538544
2025-04-07 10:39:17,448 - 4106154625.py[line:64] - INFO: epoch: 0 step: 169, loss: 0.81377554


.

2025-04-07 10:39:19,382 - 4106154625.py[line:64] - INFO: epoch: 0 step: 170, loss: 0.82164574
2025-04-07 10:39:21,451 - 4106154625.py[line:64] - INFO: epoch: 0 step: 171, loss: 0.8611313
2025-04-07 10:39:23,232 - 4106154625.py[line:64] - INFO: epoch: 0 step: 172, loss: 0.910937
2025-04-07 10:39:25,157 - 4106154625.py[line:64] - INFO: epoch: 0 step: 173, loss: 0.81960344
2025-04-07 10:39:27,027 - 4106154625.py[line:64] - INFO: epoch: 0 step: 174, loss: 0.8318243


.

2025-04-07 10:39:28,879 - 4106154625.py[line:64] - INFO: epoch: 0 step: 175, loss: 0.8163141
2025-04-07 10:39:30,569 - 4106154625.py[line:64] - INFO: epoch: 0 step: 176, loss: 0.81251186
2025-04-07 10:39:32,357 - 4106154625.py[line:64] - INFO: epoch: 0 step: 177, loss: 0.8562678
2025-04-07 10:39:34,015 - 4106154625.py[line:64] - INFO: epoch: 0 step: 178, loss: 0.815516
2025-04-07 10:39:35,701 - 4106154625.py[line:64] - INFO: epoch: 0 step: 179, loss: 0.8176594
2025-04-07 10:39:37,351 - 4106154625.py[line:64] - INFO: epoch: 0 step: 180, loss: 0.81118274


.

2025-04-07 10:39:38,946 - 4106154625.py[line:64] - INFO: epoch: 0 step: 181, loss: 0.80203724
2025-04-07 10:39:40,642 - 4106154625.py[line:64] - INFO: epoch: 0 step: 182, loss: 0.87345916
2025-04-07 10:39:42,321 - 4106154625.py[line:64] - INFO: epoch: 0 step: 183, loss: 0.81266487
2025-04-07 10:39:43,999 - 4106154625.py[line:64] - INFO: epoch: 0 step: 184, loss: 0.80216926
2025-04-07 10:39:45,764 - 4106154625.py[line:64] - INFO: epoch: 0 step: 185, loss: 0.80834883
2025-04-07 10:39:47,643 - 4106154625.py[line:64] - INFO: epoch: 0 step: 186, loss: 0.8091302


.

2025-04-07 10:39:49,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 187, loss: 0.85867965
2025-04-07 10:39:51,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 188, loss: 0.83379465
2025-04-07 10:39:53,201 - 4106154625.py[line:64] - INFO: epoch: 0 step: 189, loss: 0.8088391
2025-04-07 10:39:54,998 - 4106154625.py[line:64] - INFO: epoch: 0 step: 190, loss: 0.80790806
2025-04-07 10:39:57,158 - 4106154625.py[line:64] - INFO: epoch: 0 step: 191, loss: 0.8407364


.

2025-04-07 10:39:59,182 - 4106154625.py[line:64] - INFO: epoch: 0 step: 192, loss: 0.8151839
2025-04-07 10:40:00,872 - 4106154625.py[line:64] - INFO: epoch: 0 step: 193, loss: 0.78970444
2025-04-07 10:40:02,929 - 4106154625.py[line:64] - INFO: epoch: 0 step: 194, loss: 0.79682875
2025-04-07 10:40:04,755 - 4106154625.py[line:64] - INFO: epoch: 0 step: 195, loss: 0.82242036
2025-04-07 10:40:06,438 - 4106154625.py[line:64] - INFO: epoch: 0 step: 196, loss: 0.7956406


.

2025-04-07 10:40:08,369 - 4106154625.py[line:64] - INFO: epoch: 0 step: 197, loss: 0.8161787
2025-04-07 10:40:10,191 - 4106154625.py[line:64] - INFO: epoch: 0 step: 198, loss: 0.8084446
2025-04-07 10:40:11,973 - 4106154625.py[line:64] - INFO: epoch: 0 step: 199, loss: 0.8210702
2025-04-07 10:40:13,663 - 4106154625.py[line:64] - INFO: epoch: 0 step: 200, loss: 0.80087566
2025-04-07 10:40:15,493 - 4106154625.py[line:64] - INFO: epoch: 0 step: 201, loss: 0.87920845
2025-04-07 10:40:17,323 - 4106154625.py[line:64] - INFO: epoch: 0 step: 202, loss: 0.8160571


.

2025-04-07 10:40:19,189 - 4106154625.py[line:64] - INFO: epoch: 0 step: 203, loss: 0.7799623
2025-04-07 10:40:21,020 - 4106154625.py[line:64] - INFO: epoch: 0 step: 204, loss: 0.81907594
2025-04-07 10:40:22,823 - 4106154625.py[line:64] - INFO: epoch: 0 step: 205, loss: 0.78082323
2025-04-07 10:40:24,593 - 4106154625.py[line:64] - INFO: epoch: 0 step: 206, loss: 0.7767377
2025-04-07 10:40:26,411 - 4106154625.py[line:64] - INFO: epoch: 0 step: 207, loss: 0.78217006


.

2025-04-07 10:40:28,204 - 4106154625.py[line:64] - INFO: epoch: 0 step: 208, loss: 0.78541696
2025-04-07 10:40:30,055 - 4106154625.py[line:64] - INFO: epoch: 0 step: 209, loss: 0.788193
2025-04-07 10:40:31,905 - 4106154625.py[line:64] - INFO: epoch: 0 step: 210, loss: 0.77395964
2025-04-07 10:40:33,954 - 4106154625.py[line:64] - INFO: epoch: 0 step: 211, loss: 0.7963271
2025-04-07 10:40:35,947 - 4106154625.py[line:64] - INFO: epoch: 0 step: 212, loss: 0.77294105
2025-04-07 10:40:37,721 - 4106154625.py[line:64] - INFO: epoch: 0 step: 213, loss: 0.7669926


.

2025-04-07 10:40:39,729 - 4106154625.py[line:64] - INFO: epoch: 0 step: 214, loss: 0.79589576
2025-04-07 10:40:41,758 - 4106154625.py[line:64] - INFO: epoch: 0 step: 215, loss: 0.7651855
2025-04-07 10:40:43,662 - 4106154625.py[line:64] - INFO: epoch: 0 step: 216, loss: 0.820046
2025-04-07 10:40:45,532 - 4106154625.py[line:64] - INFO: epoch: 0 step: 217, loss: 0.7689292
2025-04-07 10:40:47,505 - 4106154625.py[line:64] - INFO: epoch: 0 step: 218, loss: 0.81641614


.

2025-04-07 10:40:49,338 - 4106154625.py[line:64] - INFO: epoch: 0 step: 219, loss: 0.76227266
2025-04-07 10:40:51,284 - 4106154625.py[line:64] - INFO: epoch: 0 step: 220, loss: 0.85349905
2025-04-07 10:40:53,122 - 4106154625.py[line:64] - INFO: epoch: 0 step: 221, loss: 0.8078137
2025-04-07 10:40:54,912 - 4106154625.py[line:64] - INFO: epoch: 0 step: 222, loss: 0.7646342
2025-04-07 10:40:56,772 - 4106154625.py[line:64] - INFO: epoch: 0 step: 223, loss: 0.7557045
2025-04-07 10:40:58,621 - 4106154625.py[line:64] - INFO: epoch: 0 step: 224, loss: 0.76513314
2025-04-07 10:41:00,458 - 4106154625.py[line:64] - INFO: epoch: 0 step: 225, loss: 0.7822351
2025-04-07 10:41:02,231 - 4106154625.py[line:64] - INFO: epoch: 0 step: 226, loss: 0.7729878
2025-04-07 10:41:04,074 - 4106154625.py[line:64] - INFO: epoch: 0 step: 227, loss: 0.75777054
2025-04-07 10:41:05,926 - 4106154625.py[line:64] - INFO: epoch: 0 step: 228, loss: 0.7532151
2025-04-07 10:41:07,785 - 4106154625.py[line:64] - INFO: epoch: 0 step: 229, loss: 0.795061
2025-04-07 10:41:09,631 - 4106154625.py[line:64] - INFO: epoch: 0 step: 230, loss: 0.7710381
2025-04-07 10:41:11,459 - 4106154625.py[line:64] - INFO: epoch: 0 step: 231, loss: 0.7682188
2025-04-07 10:41:13,288 - 4106154625.py[line:64] - INFO: epoch: 0 step: 232, loss: 0.7783369
2025-04-07 10:41:15,137 - 4106154625.py[line:64] - INFO: epoch: 0 step: 233, loss: 0.7680697
2025-04-07 10:41:17,048 - 4106154625.py[line:64] - INFO: epoch: 0 step: 234, loss: 0.75664115
2025-04-07 10:41:18,831 - 4106154625.py[line:64] - INFO: epoch: 0 step: 235, loss: 0.7511877
2025-04-07 10:41:20,764 - 4106154625.py[line:64] - INFO: epoch: 0 step: 236, loss: 0.7427261
2025-04-07 10:41:22,569 - 4106154625.py[line:64] - INFO: epoch: 0 step: 237, loss: 0.8036304
2025-04-07 10:41:24,487 - 4106154625.py[line:64] - INFO: epoch: 0 step: 238, loss: 0.76217574
2025-04-07 10:41:26,373 - 4106154625.py[line:64] - INFO: epoch: 0 step: 239, loss: 0.7397079
2025-04-07 10:41:28,139 - 4106154625.py[line:64] - INFO: epoch: 0 step: 240, loss: 0.8942822
2025-04-07 10:41:30,037 - 4106154625.py[line:64] - INFO: epoch: 0 step: 241, loss: 0.74506545
2025-04-07 10:41:32,130 - 4106154625.py[line:64] - INFO: epoch: 0 step: 242, loss: 0.7901791
2025-04-07 10:41:34,077 - 4106154625.py[line:64] - INFO: epoch: 0 step: 243, loss: 0.74124205
2025-04-07 10:41:35,979 - 4106154625.py[line:64] - INFO: epoch: 0 step: 244, loss: 0.7894727
2025-04-07 10:41:37,959 - 4106154625.py[line:64] - INFO: epoch: 0 step: 245, loss: 0.83756655
2025-04-07 10:41:39,831 - 4106154625.py[line:64] - INFO: epoch: 0 step: 246, loss: 0.7398231
2025-04-07 10:41:41,763 - 4106154625.py[line:64] - INFO: epoch: 0 step: 247, loss: 0.76385504
2025-04-07 10:41:43,700 - 4106154625.py[line:64] - INFO: epoch: 0 step: 248, loss: 0.7347469
2025-04-07 10:41:45,518 - 4106154625.py[line:64] - INFO: epoch: 0 step: 249, loss: 0.8313259
2025-04-07 10:41:47,373 - 4106154625.py[line:64] - INFO: epoch: 0 step: 250, loss: 0.8136975
2025-04-07 10:41:49,420 - 4106154625.py[line:64] - INFO: epoch: 0 step: 251, loss: 0.7310439

......

2025-04-07 13:39:55,859 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1247, loss: 0.021378823
2025-04-07 13:39:57,754 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1248, loss: 0.01565772
2025-04-07 13:39:59,606 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1249, loss: 0.012067624
2025-04-07 13:40:01,396 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1250, loss: 0.017700804
2025-04-07 13:40:03,181 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1251, loss: 0.06254268
2025-04-07 13:40:04,945 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1252, loss: 0.013293369
2025-04-07 13:40:06,770 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1253, loss: 0.026906993
2025-04-07 13:40:08,644 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1254, loss: 0.18210539
2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894
2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517
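The log above prints one line per training step, which makes the loss trend hard to read at a glance. The following is a minimal sketch, not part of MindSpore Earth, that extracts `(epoch, step, loss)` from lines in the `epoch: <n> step: <n>, loss: <value>` format shown above and averages the loss per epoch. The helper names `parse_step_losses` and `mean_loss_per_epoch` are hypothetical, and the regular expression assumes exactly this log format.

```python
import re
from collections import defaultdict

# Matches the per-step lines of the training log, e.g.
# "... - INFO: epoch: 0 step: 201, loss: 0.87920845"
STEP_PATTERN = re.compile(r"epoch: (\d+) step: (\d+), loss: ([0-9.eE+-]+)")


def parse_step_losses(lines):
    """Return (epoch, step, loss) tuples for every matching log line."""
    records = []
    for line in lines:
        match = STEP_PATTERN.search(line)
        if match:  # progress dots, blank lines and epoch summaries are skipped
            epoch, step, loss = match.groups()
            records.append((int(epoch), int(step), float(loss)))
    return records


def mean_loss_per_epoch(records):
    """Average the logged step losses of each epoch."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for epoch, _, loss in records:
        sums[epoch] += loss
        counts[epoch] += 1
    return {epoch: sums[epoch] / counts[epoch] for epoch in sums}


if __name__ == "__main__":
    sample_log = [
        "2025-04-07 10:40:15,493 - 4106154625.py[line:64] - INFO: epoch: 0 step: 201, loss: 0.87920845",
        ".",
        "2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894",
        "2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517",
    ]
    print(mean_loss_per_epoch(parse_step_losses(sample_log)))
```

The `step` counter in this log keeps increasing across epochs (step 1255 appears in epoch 4), so the sketch groups by the logged epoch rather than by step ranges.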


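## Model Evaluation and Visualization

After training, we run inference with the fifth checkpoint (`diffusion_epoch4.ckpt`; epochs are numbered from 0). The cells below report the error between the predictions and the ground truth, the evaluation metrics, and a visualization of the results.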
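The evaluation cell below compares unaligned samples with knowledge-aligned ones. For the aligned samples, `get_alignment_kwargs_avg_x` flattens each target sequence, averages it per sample and multiplies the result by 2.0 to obtain `avg_x_gt`. As a shape check, here is a NumPy sketch of that computation; it assumes a `(batch, time, height, width, channel)` layout, as the `b t h w c` pattern used in `_convert_to_bchw` suggests, and the function name `avg_x_target` is hypothetical.

```python
import numpy as np


def avg_x_target(target_seq: np.ndarray, scale: float = 2.0) -> np.ndarray:
    """NumPy sketch of the avg_x alignment target.

    Mirrors get_alignment_kwargs_avg_x: flatten every sample, average over
    all remaining elements, keep a trailing axis of size 1, then scale.
    """
    batch_size = target_seq.shape[0]
    flat = target_seq.reshape(batch_size, -1)         # (B, T*H*W*C)
    avg_intensity = flat.mean(axis=1, keepdims=True)  # (B, 1)
    return avg_intensity * scale


if __name__ == "__main__":
    # Two toy samples shaped (T=2, H=2, W=2, C=1) with constant intensities.
    toy = np.stack([np.full((2, 2, 2, 1), 0.25), np.full((2, 2, 2, 1), 0.5)])
    print(avg_x_target(toy))  # per-sample targets 0.5 and 1.0, shape (2, 1)
```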

In [14]:
def get_alignment_kwargs_avg_x(target_seq):
    """Generate alignment parameters for guided sampling"""
    batch_size = target_seq.shape[0]
    avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True)
    return {"avg_x_gt": avg_intensity * 2.0}


class DiffusionInferrence(nn.Cell):
    """
    Class managing model inference and evaluation processes. Handles loading checkpoints,
    generating predictions, calculating evaluation metrics, and saving visualization results.
    """
    def __init__(self, main_module, dm, logger, config):
        """
        Initialize inference manager with model, data module, logger, and configuration.
        Args:
            main_module: Main diffusion model for inference
            dm: Data module providing test dataset
            logger: Logging utility for evaluation progress
            config: Configuration dictionary containing evaluation parameters
        """
        super().__init__()
        self.num_samples = config["eval"].get("num_samples_per_context", 1)
        self.eval_example_only = config["eval"].get("eval_example_only", True)
        self.alignment_type = (
            config.get("model", {}).get("align", {}).get("alignment_type", "avg_x")
        )
        self.use_alignment = self.alignment_type is not None
        self.eval_aligned = config["eval"].get("eval_aligned", True)
        self.eval_unaligned = config["eval"].get("eval_unaligned", True)
        self.num_samples_per_context = config["eval"].get("num_samples_per_context", 1)
        self.logging_prefix = config["logging"].get("logging_prefix", "PreDiff")
        self.test_example_data_idx_list = [48]
        self.main_module = main_module
        self.testdataset = dm.sevir_test
        self.logger = logger
        self.datasetprocessing = SEVIRDataset(
            data_types=["vil"],
            layout="NHWT",
            rescale_method=config.get("rescale_method", "01"),
        )
        self.example_save_dir = config["summary"].get("summary_dir", "./summary")

        self.fs = config["eval"].get("fs", 20)
        self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5])
        self.label_avg_int = config["eval"].get("label_avg_int", False)

        self.current_epoch = 0

        self.learn_logvar = (
            config.get("model", {}).get("diffusion", {}).get("learn_logvar", False)
        )
        self.logvar = main_module.logvar
        self.maeloss = nn.MAELoss()
        self.test_metrics = {
            "step": 0,
            "mse": 0.0,
            "mae": 0.0,
            "ssim": 0.0,
            "mse_kc": 0.0,
            "mae_kc": 0.0,
        }

    def test(self):
        """Execute complete evaluation pipeline."""
        self.logger.info("============== Start Test ==============")
        self.start_time = time.time()
        for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()):
            self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics)

        self._finalize_test(self.test_metrics)

    def _test_onestep(self, item, batch_idx, metrics):
        """Process one test batch and update evaluation metrics."""
        data_idx = int(batch_idx * 2)
        if not self._should_test_onestep(data_idx):
            return metrics
        data = item.get("vil")
        data = self.datasetprocessing.process_data(data)
        target_seq, cond, context_seq = self._get_model_inputs(data)
        aligned_preds, unaligned_preds = self._generate_predictions(
            cond, target_seq
        )
        metrics = self._update_metrics(
            aligned_preds, unaligned_preds, target_seq, metrics
        )
        self._plt_pred(
            data_idx,
            context_seq,
            target_seq,
            aligned_preds,
            unaligned_preds,
            metrics["step"],
        )

        metrics["step"] += 1
        return metrics

    def _should_test_onestep(self, data_idx):
        """Determine if evaluation should be performed on current data index."""
        return (not self.eval_example_only) or (
            data_idx in self.test_example_data_idx_list
        )

    def _get_model_inputs(self, data):
        """Extract and prepare model inputs from raw data."""
        target_seq, cond, context_seq = self.main_module.get_input(
            data, return_verbose=True
        )
        return target_seq, cond, context_seq

    def _generate_predictions(self, cond, target_seq):
        """Generate both aligned and unaligned predictions from the model."""
        aligned_preds = []
        unaligned_preds = []

        for _ in range(self.num_samples_per_context):
            if self.use_alignment and self.eval_aligned:
                aligned_pred = self._sample_with_alignment(
                    cond, target_seq
                )
                aligned_preds.append(aligned_pred)

            if self.eval_unaligned:
                unaligned_pred = self._sample_without_alignment(cond)
                unaligned_preds.append(unaligned_pred)

        return aligned_preds, unaligned_preds

    def _sample_with_alignment(self, cond, target_seq):
        """Generate predictions using alignment mechanism."""
        alignment_kwargs = get_alignment_kwargs_avg_x(target_seq)
        pred_seq = self.main_module.sample(
            cond=cond,
            batch_size=cond["y"].shape[0],
            return_intermediates=False,
            use_alignment=True,
            alignment_kwargs=alignment_kwargs,
            verbose=False,
        )
        if pred_seq.dtype != ms.float32:
            pred_seq = pred_seq.float()
        return pred_seq

    def _sample_without_alignment(self, cond):
        """Generate predictions without alignment."""
        pred_seq = self.main_module.sample(
            cond=cond,
            batch_size=cond["y"].shape[0],
            return_intermediates=False,
            verbose=False,
        )
        if pred_seq.dtype != ms.float32:
            pred_seq = pred_seq.float()
        return pred_seq

    def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics):
        """Update evaluation metrics with new predictions."""
        for pred in aligned_preds:
            metrics["mse_kc"] += ops.mse_loss(pred, target_seq)
            metrics["mae_kc"] += self.maeloss(pred, target_seq)
            self.main_module.test_aligned_score.update(pred, target_seq)

        for pred in unaligned_preds:
            metrics["mse"] += ops.mse_loss(pred, target_seq)
            metrics["mae"] += self.maeloss(pred, target_seq)
            self.main_module.test_score.update(pred, target_seq)

            pred_bchw = self._convert_to_bchw(pred)
            target_bchw = self._convert_to_bchw(target_seq)
            metrics["ssim"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0]

        return metrics

    def _convert_to_bchw(self, tensor):
        """Convert tensor to batch-channel-height-width format for metrics."""
        return rearrange(tensor.asnumpy(), "b t h w c -> (b t) c h w")

    def _plt_pred(
            self, data_idx, context_seq, target_seq, aligned_preds, unaligned_preds, step
    ):
        """Generate and save visualization of predictions."""
        pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds]
        pred_labels = [
            f"{self.logging_prefix}_aligned_pred_{i}" for i in range(len(aligned_preds))
        ] + [f"{self.logging_prefix}_pred_{i}" for i in range(len(unaligned_preds))]

        self.save_vis_step_end(
            data_idx=data_idx,
            context_seq=context_seq[0].asnumpy(),
            target_seq=target_seq[0].asnumpy(),
            pred_seq=pred_sequences,
            pred_label=pred_labels,
            mode="test",
            suffix=f"_step_{step}",
        )

    def _finalize_test(self, metrics):
        """Complete test process and log final metrics."""
        total_time = (time.time() - self.start_time) * 1000
        self.logger.info(f"test cost: {total_time:.2f} ms")
        self._compute_total_metrics(metrics)
        self.logger.info("============== Test Completed ==============")

    def _compute_total_metrics(self, metrics):
        """Log the metrics averaged over the evaluated steps and the categorical scores."""
        step_count = max(metrics["step"], 1)
        if self.eval_unaligned:
            self.logger.info(f"MSE: {metrics['mse'] / step_count}")
            self.logger.info(f"MAE: {metrics['mae'] / step_count}")
            self.logger.info(f"SSIM: {metrics['ssim'] / step_count}")
            test_score = self.main_module.test_score.eval()
            self.logger.info("SCORE:\n%s", json.dumps(test_score, indent=4))
        if self.use_alignment and self.eval_aligned:
            self.logger.info(f"KC_MSE: {metrics['mse_kc'] / step_count}")
            self.logger.info(f"KC_MAE: {metrics['mae_kc'] / step_count}")
            aligned_score = self.main_module.test_aligned_score.eval()
            self.logger.info("KC_SCORE:\n%s", json.dumps(aligned_score, indent=4))

    def save_vis_step_end(
            self,
            data_idx: int,
            context_seq: np.ndarray,
            target_seq: np.ndarray,
            pred_seq: Union[np.ndarray, Sequence[np.ndarray]],
            pred_label: Union[str, Sequence[str]] = None,
            mode: str = "train",
            prefix: str = "",
            suffix: str = "",
    ):
        """Save visualization of predictions with context and target."""
        example_data_idx_list = self.test_example_data_idx_list
        if isinstance(pred_seq, Sequence):
            seq_list = [context_seq, target_seq] + list(pred_seq)
            label_list = ["context", "target"] + pred_label
        else:
            seq_list = [context_seq, target_seq, pred_seq]
            label_list = ["context", "target", pred_label]
        if data_idx in example_data_idx_list:
            png_save_name = f"{prefix}{mode}_data_{data_idx}{suffix}.png"
            vis_sevir_seq(
                save_path=os.path.join(self.example_save_dir, png_save_name),
                seq=seq_list,
                label=label_list,
                interval_real_time=10,
                plot_stride=1,
                fs=self.fs,
                label_offset=self.label_offset,
                label_avg_int=self.label_avg_int,
            )
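`_convert_to_bchw` folds the time axis into the batch axis so that `test_ssim` receives every frame as a separate image. The einops pattern `b t h w c -> (b t) c h w` is equivalent to a transpose followed by a reshape, as the plain-NumPy sketch below shows; `to_bchw_numpy` is a hypothetical name used only for illustration.

```python
import numpy as np


def to_bchw_numpy(x: np.ndarray) -> np.ndarray:
    """Plain-NumPy equivalent of rearrange(x, "b t h w c -> (b t) c h w")."""
    b, t, h, w, c = x.shape
    # Move channels in front of the spatial axes: (b, t, c, h, w) ...
    x = np.transpose(x, (0, 1, 4, 2, 3))
    # ... then merge batch and time; frame (i, j) lands at row i * t + j.
    return x.reshape(b * t, c, h, w)


if __name__ == "__main__":
    seq = np.arange(2 * 3 * 4 * 5 * 1, dtype=np.float32).reshape(2, 3, 4, 5, 1)
    frames = to_bchw_numpy(seq)
    print(frames.shape)  # (6, 1, 4, 5)
```

The batch axis is the outer one in `(b t)`, so all frames of the first sequence come before those of the second.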

In [15]:
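main_module.main_model.set_train(False)
# Checkpoint saved after the fifth epoch (epoch index 4); replace this path with your own checkpoint location.
params = ms.load_checkpoint("/home/lry/202542测试/PreDiff/summary/prediff/single_device0/ckpt/diffusion_epoch4.ckpt")
# load_param_into_net returns the network parameters that were not loaded and the
# checkpoint parameters that were not loaded; an empty second list means every
# checkpoint parameter was loaded into the network.
param_not_load, ckpt_not_load = ms.load_param_into_net(main_module.main_model, params)
print(ckpt_not_load)
tester = DiffusionInferrence(
    main_module=main_module, dm=dm, logger=logger, config=config
)
tester.test()

[]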

2025-04-07 14:10:31,931 - 2610859736.py[line:201] - INFO: test cost: 375371.60 ms
2025-04-07 14:10:31,937 - 2610859736.py[line:215] - INFO: KC_MSE: 0.0036273836
2025-04-07 14:10:31,939 - 2610859736.py[line:216] - INFO: KC_MAE: 0.017427118
2025-04-07 14:10:31,955 - 2610859736.py[line:218] - INFO: KC_SCORE:
{
    "16": {
        "csi": 0.2715393900871277,
        "pod": 0.5063194632530212,
        "sucr": 0.369321346282959,
        "bias": 3.9119162559509277
    },
    "74": {
        "csi": 0.15696434676647186,
        "pod": 0.17386901378631592,
        "sucr": 0.6175059080123901,
        "bias": 0.16501028835773468
    }
}
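The `KC_SCORE` output lists categorical skill scores for the two VIL thresholds 16 and 74: CSI (critical success index), POD (probability of detection), SUCR (success ratio) and BIAS (frequency bias). The sketch below uses the textbook contingency-table definitions, counting a pixel as an event when its value is at least the threshold. It is an illustration, not the `src` implementation: `categorical_scores` is a hypothetical name, and `pred` and `target` are assumed to be on the same scale as the threshold (with the `01` rescaling used here, the pixel thresholds would presumably have to be divided by 255, or the sequences rescaled back, first).

```python
import numpy as np


def categorical_scores(pred, target, threshold, eps=1e-6):
    """Textbook contingency-table scores for a single intensity threshold.

    A pixel is an "event" when its value is >= threshold. `eps` only guards
    against division by zero when a class never occurs.
    """
    pred_event = np.asarray(pred) >= threshold
    target_event = np.asarray(target) >= threshold
    hits = np.sum(pred_event & target_event)
    misses = np.sum(~pred_event & target_event)
    false_alarms = np.sum(pred_event & ~target_event)
    return {
        "csi": float(hits / (hits + misses + false_alarms + eps)),
        "pod": float(hits / (hits + misses + eps)),
        "sucr": float(hits / (hits + false_alarms + eps)),
        "bias": float((hits + false_alarms) / (hits + misses + eps)),
    }


if __name__ == "__main__":
    # Toy 2x3 fields on the 0-255 VIL pixel scale.
    target = np.array([[0, 20, 80], [10, 30, 90]])
    pred = np.array([[20, 20, 10], [10, 80, 90]])
    for thr in (16, 74):
        print(thr, categorical_scores(pred, target, thr))
```

With these definitions BIAS equals POD / SUCR. The printed BIAS values do not match that ratio of the printed POD and SUCR (for threshold 16, 0.506 / 0.369 ≈ 1.37 versus 3.91), so the evaluation code in `src` presumably computes or aggregates BIAS differently.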
