# G-TEAM Earthquake Early Warning Model

## Overview

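An earthquake early warning system aims to issue alerts as early as possible, before destructive shaking arrives, in order to reduce casualties and economic losses. G-TEAM is a data-driven, nationwide earthquake early warning model. It combines a graph neural network (GNN) with a Transformer architecture and can provide the epicenter location, magnitude, and intensity distribution within 3 seconds of an earthquake's occurrence. The model works directly on raw seismic waveforms, which avoids the limitations of hand-crafted feature selection. It also makes full use of data from multiple stations, which improves both prediction accuracy and timeliness.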

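Built on the GNN and Transformer architecture described above, the model takes as input the waveforms recorded by an arbitrary number of seismic stations. It receives seismic signals in real time. From them it rapidly estimates the hypocenter location, the magnitude, and the spatial distribution of shaking intensity, expressed as peak ground acceleration (PGA). Using deep learning, it exploits both the spatial correlations across the seismic network and the temporal features of the waveforms. The goal is to improve warning accuracy and response speed and to provide reliable support for earthquake emergency response and disaster-mitigation decisions.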
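The number of stations that have recorded an event varies from one event to the next. A batched implementation therefore typically pads the station axis to a fixed size and carries a boolean mask, so that later layers can ignore empty slots. The sketch below assumes that approach. It borrows `max_stations` from the data configuration logged later in this notebook (`'max_stations': 5`). The function name `pad_stations`, the three-component layout, and the (lat, lon, elevation) coordinate format are assumptions, not the repository's code.

```python
import numpy as np


def pad_stations(waveforms, coords, max_stations):
    """Hypothetical helper: pack a variable number of stations into fixed-size arrays plus a mask.

    waveforms: list of (n_samples, n_channels) arrays, one per station.
    coords: list of (lat, lon, elevation) tuples, one per station (assumed format).
    """
    n_samples, n_channels = waveforms[0].shape
    wf = np.zeros((max_stations, n_samples, n_channels), dtype=np.float32)
    xyz = np.zeros((max_stations, 3), dtype=np.float32)
    mask = np.zeros(max_stations, dtype=bool)
    # Fill up to max_stations slots; unused slots stay zero and are marked False in the mask.
    for i, (w, c) in enumerate(zip(waveforms[:max_stations], coords[:max_stations])):
        wf[i] = w
        xyz[i] = c
        mask[i] = True
    return wf, xyz, mask


rng = np.random.default_rng(0)
records = [rng.standard_normal((3000, 3)) for _ in range(3)]              # 3 stations, 3 components
locations = [(30.1, 103.2, 0.5), (30.4, 103.9, 1.1), (29.8, 102.7, 0.8)]  # made-up coordinates
wf, xyz, mask = pad_stations(records, locations, max_stations=5)
# wf.shape == (5, 3000, 3); mask is True only for the first three slots.
```

Padding plus a mask keeps every batch the same shape. The alternative, grouping events by station count, avoids wasted computation but complicates batching.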

![](./images/image.png)

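The model predicts PGA from multi-station data with the following architecture. The system receives the locations of several seismic stations and the waveforms they recorded, together with the coordinates of the target sites where PGA is to be estimated. Each station's waveforms are first normalized and then passed through a convolutional neural network (CNN) for feature extraction. A fully connected layer fuses the extracted features, and the result is combined with that station's location to form the station's feature vector.

The coordinates of each PGA target site are passed through a position-encoding module to form a feature vector as well. All feature vectors are fed as a sequence into a Transformer encoder, whose self-attention mechanism captures global dependencies among them. The encoder output is then passed to three separate fully connected heads, which regress the event magnitude, the epicenter location, and the PGA, respectively.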
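To make the data flow above concrete, here is a toy NumPy forward pass with random, untrained weights. It only illustrates tensor shapes and is not the G-TEAM implementation. It rests on several assumptions:

- Waveforms are peak-amplitude normalized, with the log peak appended as a feature.
- A single random convolution layer stands in for the CNN.
- Station coordinates are added to (rather than concatenated with) the waveform features.
- Positions use a sinusoidal encoding. Its per-axis wavelength ranges are borrowed from the `wavelength` entry in the logged model config; reading them this way is our interpretation.
- A placeholder event token feeds the magnitude and location heads, as suggested by the `no_event_token` config flag.
- The attention layers are single-head, without layer normalization or feed-forward sublayers.

All function names are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
D = 24  # toy token width; the logged config uses hidden_dim=1000


def dense(x, out_dim):
    """Fully connected layer with freshly drawn random (untrained) weights."""
    w = rng.standard_normal((x.shape[-1], out_dim)) / np.sqrt(x.shape[-1])
    return x @ w


def sinusoidal_encoding(coords, wavelengths=((0.01, 15), (0.01, 15), (0.01, 10)), dim=D):
    """Encode (N, 3) coordinates with sin/cos features, dim // 3 features per axis."""
    n_freq = dim // 3 // 2
    parts = []
    for j, (lo, hi) in enumerate(wavelengths):
        lam = np.geomspace(lo, hi, n_freq)            # wavelengths spaced geometrically
        phase = 2 * np.pi * coords[:, j:j + 1] / lam  # (N, n_freq)
        parts += [np.sin(phase), np.cos(phase)]
    return np.concatenate(parts, axis=1)              # (N, dim)


def station_features(wf, n_filters=16, k=9):
    """Peak-normalize each station's traces, convolve, apply ReLU, then global max-pool."""
    peak = np.abs(wf).max(axis=(1, 2), keepdims=True) + 1e-8  # (S, 1, 1)
    x = wf / peak
    kernels = rng.standard_normal((n_filters, wf.shape[2], k)) * 0.1
    windows = sliding_window_view(x, k, axis=1)               # (S, T-k+1, C, k)
    conv = np.einsum("stck,fck->stf", windows, kernels)       # (S, T-k+1, F)
    pooled = np.maximum(conv, 0.0).max(axis=1)                # (S, F)
    # Append the log peak amplitude so the absolute scale is not lost by normalization.
    return np.concatenate([pooled, np.log(peak[:, 0])], axis=1)  # (S, F + 1)


def self_attention(tokens):
    """One single-head scaled dot-product self-attention layer with a residual connection."""
    q, k, v = (dense(tokens, D) for _ in range(3))
    scores = q @ k.T / np.sqrt(D)
    scores -= scores.max(axis=1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # softmax over all tokens
    return tokens + weights @ v


def g_team_forward(wf, sta_coords, target_coords, n_layers=2):
    """wf: (S, T, C) waveforms; sta_coords: (S, 3); target_coords: (P, 3)."""
    sta_tokens = dense(station_features(wf), D) + sinusoidal_encoding(sta_coords)
    tgt_tokens = sinusoidal_encoding(target_coords)
    event_token = rng.standard_normal((1, D))  # placeholder for a learned event token
    tokens = np.concatenate([event_token, sta_tokens, tgt_tokens], axis=0)
    for _ in range(n_layers):                  # the logged config uses transformer_layers=6
        tokens = self_attention(tokens)
    n_sta = wf.shape[0]
    magnitude = dense(tokens[:1], 1)[0, 0]     # head 1: magnitude
    location = dense(tokens[:1], 3)[0]         # head 2: location (3 outputs)
    pga = dense(tokens[1 + n_sta:], 1)[:, 0]   # head 3: one PGA value per target site
    return magnitude, location, pga


wf = rng.standard_normal((4, 500, 3))                    # 4 stations, 500 samples, 3 components
sta = rng.uniform([28, 100, 0], [32, 106, 2], (4, 3))    # hypothetical (lat, lon, elevation)
tgt = rng.uniform([28, 100, 0], [32, 106, 0.5], (15, 3)) # 15 target sites (n_pga_targets=15)
mag, loc, pga = g_team_forward(wf, sta, tgt)
print(loc.shape, pga.shape)  # (3,) (15,)
```

Because every token attends to every other token, station and target information are mixed before the three heads read their outputs.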

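The model is trained on [DiTing 2.0 (谛听数据集2.0), a large multi-purpose AI training dataset from the China Seismic Network](http://www.esdc.ac.cn/article/137). The dataset contains waveform records from 1,177 fixed stations of the China Seismic Network in mainland China and adjacent regions (15°-50°N, 65°-140°E), covering March 2020 to February 2023. It includes all local earthquakes with magnitude greater than 0 in the study region, 264,298 events in total. For training, we used only the first-arrival P and S phases. We also kept only events recorded by at least three stations, to keep the data reliable and stable.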
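The two selection rules above can be sketched as a small filtering pass:

- keep only the first-arrival P and S picks;
- drop events seen by fewer than three stations.

This is a minimal illustration under assumptions. The tuple layout `(event_id, station, phase, arrival_time)`, the plain `"P"`/`"S"` phase labels, the station codes, and the name `filter_events` are all hypothetical. The actual DiTing 2.0 fields and phase labels may differ.

```python
from collections import defaultdict


def filter_events(picks, min_stations=3, phases=("P", "S")):
    """Keep first-arrival P/S picks and drop events recorded by fewer than min_stations stations.

    picks: iterable of (event_id, station, phase, arrival_time) tuples (hypothetical layout).
    Returns {event_id: {(station, phase): earliest arrival_time}}.
    """
    first = defaultdict(dict)
    for evt, sta, phase, t in picks:
        if phase not in phases:
            continue
        key = (sta, phase)
        # Keep only the earliest (first-arrival) pick of each phase at each station.
        if key not in first[evt] or t < first[evt][key]:
            first[evt][key] = t
    return {
        evt: sta_picks
        for evt, sta_picks in first.items()
        if len({sta for sta, _ in sta_picks}) >= min_stations
    }


picks = [
    ("e1", "SC.A", "P", 3.2), ("e1", "SC.A", "P", 3.5), ("e1", "SC.A", "S", 5.9),
    ("e1", "SC.B", "P", 4.1), ("e1", "SC.C", "P", 6.0),
    ("e2", "SC.A", "P", 1.0), ("e2", "SC.B", "S", 2.7),
]
kept = filter_events(picks)
print(sorted(kept))               # ['e1']  (e2 has only two stations)
print(kept["e1"][("SC.A", "P")])  # 3.2     (the later 3.5 s P pick is discarded)
```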

At present, only the inference code of the model is open source; inference can be run with the provided checkpoint (ckpt).


```python
import numpy as np

import mindspore as ms
from mindspore import context
```

The `src` package imported below can be downloaded from [GTEAM/src](./src).

```python
from mindearth import load_yaml_config, make_dir

from src.utils import (
    predict_at_time,
    calc_mag_stats,
    calc_loc_stats,
    calc_pga_stats,
    init_model,
    get_logger
)
from src.forcast import GTeamInference
from src.data import load_data
from src.visual import generate_true_pred_plot
```

Model, data, and optimizer parameters can be set in the [configuration file](./config/GTEAM.yaml).

```python
config = load_yaml_config("./config/GTEAM.yaml")
context.set_context(mode=ms.PYNATIVE_MODE)
ms.set_device(device_target="Ascend", device_id=0)
```

```python
save_dir = config["summary"].get("summary_dir", "./summary")
make_dir(save_dir)
logger_obj = get_logger(config)
```

```text
2025-04-05 08:57:36,391 - utils.py[line:179] - INFO: {'hidden_dim': 1000, 'hidden_dropout': 0.0, 'n_heads': 10, 'n_pga_targets': 15, 'output_location_dims': [150, 100, 50, 30, 3], 'output_mlp_dims': [150, 100, 50, 30, 1], 'transformer_layers': 6, 'waveform_model_dims': [500, 500, 500], 'wavelength': [[0.01, 15], [0.01, 15], [0.01, 10]], 'times': [5], 'run_with_less_data': False, 'pga': True, 'mode': 'test', 'no_event_token': False}
2025-04-05 08:57:36,392 - utils.py[line:179] - INFO: {'root_dir': './dataset', 'batch_size': 64, 'max_stations': 5, 'disable_sta
```

## Model Initialization

```python
model = init_model(config)
```

## Dataset Preparation

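The data recorded by different stations are selected according to the time elapsed since the earthquake occurred; the evaluation times are set by `times` in the model configuration. For reference, the `GTeamInference` class (also imported from `src.forcast` above) is listed below. For each evaluation time it calls `predict_at_time`, then computes and logs the magnitude, location, and PGA statistics.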
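This notebook does not show how `predict_at_time` (in `src/utils.py`) selects the data. A minimal sketch of the idea follows. It assumes that a station becomes usable once its P wave has arrived and that later samples are zeroed so the model never sees future data. It also assumes elapsed time is counted from the event origin; it could instead be counted from the first P arrival. `select_at_time` and its arguments are hypothetical. `max_stations` and `sampling_rate` mirror names from the data configuration and from `meta_data` in the class below.

```python
import numpy as np


def select_at_time(waveforms, p_arrivals, t, sampling_rate, max_stations=None):
    """Hypothetical helper: the data available t seconds after the event origin.

    waveforms: (S, T, C) array whose first sample is the origin time.
    p_arrivals: (S,) P-arrival times in seconds after the origin.
    Stations whose P wave has not arrived by t are dropped and samples after t
    are zeroed, so no future data leaks into the prediction.
    """
    triggered = np.flatnonzero(p_arrivals <= t)
    # Earliest-triggered stations first, so truncation keeps the first arrivals.
    triggered = triggered[np.argsort(p_arrivals[triggered], kind="stable")]
    if max_stations is not None:
        triggered = triggered[:max_stations]
    visible = waveforms[triggered]  # fancy indexing returns a copy
    visible[:, int(t * sampling_rate):, :] = 0.0
    return visible, triggered


wf = np.ones((4, 1000, 3))          # 4 stations, 10 s at 100 Hz
p = np.array([2.5, 7.0, 1.2, 4.9])  # P arrivals in seconds after origin
vis, idx = select_at_time(wf, p, t=5, sampling_rate=100)
print(idx)  # [2 0 3]  (station 1 has not triggered by t = 5 s)
```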

```python
class GTeamInference:
    """
    Run G-TEAM inference over the configured evaluation times and collect
    magnitude, location, and PGA statistics.
    """

    def __init__(self, model_ins, cfg, output_dir, logger):
        """
        Args:
            model_ins: The model instance used for inference.
            cfg: Configuration dictionary containing model and data parameters.
            output_dir: Directory to save the output results (e.g. plots).
            logger: Logger used to report the final statistics.
        Attributes:
            model: The model instance for inference.
            cfg: Configuration dictionary.
            output_dir: Directory to save outputs.
            logger: Logger instance.
            pga: Flag indicating if PGA (Peak Ground Acceleration) is enabled.
            generator_params: Parameters for data generation.
            model_params: Parameters specific to the model.
            mag_key: Key for magnitude-related data.
            pos_offset: Position offset for location predictions.
            mag_stats: List to store magnitude prediction statistics.
            loc_stats: List to store location prediction statistics.
            pga_stats: List to store PGA prediction statistics.
        """
        self.model = model_ins
        self.cfg = cfg
        self.output_dir = output_dir
        self.logger = logger
        # Default to the boolean True (not the string "true") when "pga" is not configured.
        self.pga = cfg["model"].get("pga", True)
        self.generator_params = cfg["data"]
        self.model_params = cfg["model"]
        self.mag_key = self.generator_params["key"]
        self.pos_offset = self.generator_params["pos_offset"]
        self.mag_stats = []
        self.loc_stats = []
        self.pga_stats = []

    def _parse_predictions(self, pred):
        """
        Parse the raw predictions into magnitude, location, and PGA components.
        """
        mag_pred = pred[0]
        loc_pred = pred[1]
        pga_pred = pred[2] if self.pga else []
        return mag_pred, loc_pred, pga_pred

    def _process_predictions(
            self, mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true
    ):
        """
        Process the parsed predictions to compute statistics and generate plots.
        """
        # Gather the per-batch predictions into single NumPy arrays.
        mag_pred_np = [t[0].asnumpy() for t in mag_pred]
        mag_pred_reshaped = np.concatenate(mag_pred_np, axis=0)

        loc_pred_np = [t[0].asnumpy() for t in loc_pred]
        loc_pred_reshaped = np.array(loc_pred_np)

        if not self.model_params["no_event_token"]:
            self.mag_stats += calc_mag_stats(
                mag_pred_reshaped, evt_metadata, self.mag_key
            )

            self.loc_stats += calc_loc_stats(
                loc_pred_reshaped, evt_metadata, self.pos_offset
            )

            generate_true_pred_plot(
                mag_pred_reshaped,
                evt_metadata[self.mag_key].values,
                time,
                self.output_dir,
            )

        # pga_pred is empty when the PGA head is disabled; concatenating it would fail.
        if self.pga:
            pga_pred_np = [t.asnumpy() for t in pga_pred]
            pga_pred_reshaped = np.concatenate(pga_pred_np, axis=0)
            # Ground-truth PGA is compared in natural-log space.
            pga_true_reshaped = np.log(
                np.abs(np.concatenate(pga_true, axis=0).reshape(-1, 1))
            )
            # Accumulate like mag_stats/loc_stats so results for every time are kept.
            self.pga_stats += calc_pga_stats(pga_pred_reshaped, pga_true_reshaped)

    def _save_results(self):
        """
        Log the final results (magnitude, location, and PGA statistics).
        """
        times = self.cfg["model"].get("times")
        self.logger.info("times: {}".format(times))
        self.logger.info("mag_stats: {}".format(self.mag_stats))
        self.logger.info("loc_stats: {}".format(self.loc_stats))
        self.logger.info("pga_stats: {}".format(self.pga_stats))

    def test(self):
        """
        Perform inference for all specified times, process predictions, and save results.
        This method iterates over the specified times, performs predictions, processes
        the results, and saves the final statistics.
        """
        data_data, evt_key, evt_metadata, meta_data, data_path = load_data(self.cfg)
        pga_true = data_data["pga"]
        for time in self.cfg["model"].get("times"):
            pred = predict_at_time(
                self.model,
                time,
                data_data,
                data_path,
                evt_key,
                evt_metadata,
                config=self.cfg,
                pga=self.pga,
                sampling_rate=meta_data["sampling_rate"],
            )
            mag_pred, loc_pred, pga_pred = self._parse_predictions(pred)
            self._process_predictions(
                mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true
            )
        self._save_results()
        print("Inference completed and results saved")
```
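The helpers `calc_mag_stats`, `calc_loc_stats`, and `calc_pga_stats` are not shown, so this notebook does not document what each number in the logged statistics means. Some readers may want to compute their own error measures on the collected predictions. The sketch below shows one common choice, with two assumptions:

- Locations are given as (lat, lon) in degrees, that is, after any `pos_offset` has been undone.
- PGA is compared in natural-log space, as in `_process_predictions`.

It is illustrative only and is not how the repository defines its statistics. `haversine_km` and `summarize` are hypothetical names.

```python
import numpy as np


def haversine_km(lat1, lon1, lat2, lon2, radius_km=6371.0):
    """Great-circle distance in km between points given in degrees."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * radius_km * np.arcsin(np.sqrt(a))


def summarize(mag_pred, mag_true, loc_pred, loc_true, pga_pred_log, pga_true):
    """Illustrative error metrics; not the definitions used by the calc_*_stats helpers."""
    mag_err = mag_pred - mag_true
    epi_km = haversine_km(loc_true[:, 0], loc_true[:, 1], loc_pred[:, 0], loc_pred[:, 1])
    # Compare PGA in natural-log space, matching np.log(np.abs(pga_true)) in _process_predictions.
    pga_res = pga_pred_log - np.log(np.abs(pga_true))
    return {
        "mag_rmse": float(np.sqrt(np.mean(mag_err ** 2))),
        "mag_mae": float(np.mean(np.abs(mag_err))),
        "epicentral_error_km_mean": float(np.mean(epi_km)),
        "pga_log_rmse": float(np.sqrt(np.mean(pga_res ** 2))),
    }


stats = summarize(
    mag_pred=np.array([3.1, 4.0]), mag_true=np.array([3.0, 4.2]),
    loc_pred=np.array([[30.0, 103.0], [31.0, 104.0]]),
    loc_true=np.array([[30.0, 103.0], [31.0, 104.1]]),
    pga_pred_log=np.log(np.array([0.02, 0.05])), pga_true=np.array([0.02, -0.05]),
)
```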


## Running Inference

```python
processor = GTeamInference(model, config, save_dir, logger_obj)
processor.test()
```

```text
Data loaded from ./dataset/diting2_2020-2022_sc_abridged_test_filter_pga.pkl

2025-04-05 08:57:42,398 - forcast.py[line:115] - INFO: times: [5]
2025-04-05 08:57:42,399 - forcast.py[line:116] - INFO: mag_stats: [-5.849881172180176, 0.26172267853934106, 0.2561628818511963]
2025-04-05 08:57:42,400 - forcast.py[line:117] - INFO: loc_stats: [5.55861115185705, 5.1707730693636345, 4.317579930843666, 4.128873124004999]
2025-04-05 08:57:42,402 - forcast.py[line:118] - INFO: pga_stats: [0.8641006385570611, 0.4655571071890895, 0.28675066434439034]

Inference completed and results saved
```
