In [None]:
import warnings

import h5py
import numpy as np
import pandas as pd
import torch

from models import create_model, load_checkpoint
from utils import vis_phase_picking


def normalize(data: np.ndarray, mode: str):
    """Remove the per-row mean and optionally rescale each row.

    The input array is never modified; a new array is returned.

    Args:
        data: 2-D array; statistics are computed along axis 1.
        mode: ``"max"`` divides each demeaned row by its maximum,
            ``"std"`` divides each demeaned row by its standard deviation,
            ``""`` only removes the mean.

    Returns:
        A new array with the same shape as ``data``. Floating (and complex)
        dtypes are preserved; any other dtype is promoted to ``float64`` so
        the arithmetic is well defined.

    Raises:
        ValueError: If ``mode`` is not one of ``"max"``, ``"std"`` or ``""``.
    """
    # Validate before touching any data so a bad mode has no side effects.
    if mode not in ("max", "std", ""):
        raise ValueError(f"Supported mode: 'max','std', got '{mode}'")

    data = np.asarray(data)
    work_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    # Work on a private copy: the caller's array (or the view it passed in)
    # must stay untouched.
    out = data.astype(work_dtype, copy=True)
    out -= np.mean(out, axis=1, keepdims=True)

    if mode == "max":
        scale = np.max(out, axis=1, keepdims=True)
    elif mode == "std":
        scale = np.std(out, axis=1, keepdims=True)
    else:
        return out

    # Constant rows have zero scale after demeaning; divide by 1 instead so
    # they stay all-zero rather than becoming NaN.
    scale[scale == 0] = 1
    out /= scale
    return out


def load_data(
    data_path: str = "./datasets/STEAD/chunk2.hdf5",
    trace_name: str = "B087.PB_20111102050415_EV",
):
    """Read one trace from an HDF5 file and return it as a float32 array.

    The dataset stored under ``data/<trace_name>`` is loaded into memory,
    cast to ``float32`` and transposed.

    Args:
        data_path: Path to the HDF5 file.
        trace_name: Name of the dataset inside the ``data`` group.

    Returns:
        The transposed ``float32`` array.
        NOTE(review): the caller slices ``[:, :6000]`` and reshapes to
        ``(1, 3, -1)``, so the result is presumably (channels, samples) --
        confirm against the dataset layout.

    Raises:
        KeyError: If ``data/<trace_name>`` does not exist in the file.
    """
    dataset_key = f"data/{trace_name}"

    # Read HDF5
    with h5py.File(data_path, "r") as f:
        dataset = f.get(dataset_key)
        # Group.get returns None for a missing key; fail loudly here instead
        # of letting None flow into the array conversion below.
        if dataset is None:
            raise KeyError(f"Dataset '{dataset_key}' not found in '{data_path}'")
        data = np.array(dataset).astype(np.float32).T

    # Report whether the trace contains any NaN values
    print("是否存在 nan 值:", np.isnan(data).any())

    return data

In [None]:
def load_model(
    model_name: str,
    ckpt_path: str,
    device: torch.device,
    in_channels: int = 3,
    in_samples: int = 6000,
):
    """Build a model, load matching checkpoint weights and move it to ``device``.

    Only checkpoint entries whose key exists in the model's state dict are
    loaded (non-strict). Checkpoint keys that are skipped, and model
    parameters the checkpoint does not provide, are reported through a
    warning so that a partially initialised model does not go unnoticed.

    Args:
        model_name: Name passed to ``create_model``.
        ckpt_path: Path passed to ``load_checkpoint``.
        device: Device the checkpoint is loaded for and the model is moved to.
        in_channels: Forwarded to ``create_model``.
        in_samples: Forwarded to ``create_model``.

    Returns:
        The model, after loading weights and calling ``model.to(device)``.
    """
    # Model init
    model = create_model(
        model_name=model_name, in_channels=in_channels, in_samples=in_samples
    )

    # Load parameters; a checkpoint may wrap the weights under "model_dict".
    ckpt = load_checkpoint(ckpt_path, device=device)
    ckpt_state_dict = ckpt["model_dict"] if "model_dict" in ckpt else ckpt

    # Build the key set once instead of re-creating the state dict per key.
    expected_keys = set(model.state_dict())
    filtered_state_dict = {
        key: value for key, value in ckpt_state_dict.items() if key in expected_keys
    }
    skipped_keys = [key for key in ckpt_state_dict if key not in expected_keys]

    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    # getattr keeps this tolerant of a load_state_dict that returns nothing.
    missing_keys = list(getattr(load_result, "missing_keys", None) or [])

    if skipped_keys or missing_keys:
        warnings.warn(
            f"Checkpoint '{ckpt_path}' only partially matches model '{model_name}': "
            f"{len(skipped_keys)} checkpoint key(s) skipped {skipped_keys}, "
            f"{len(missing_keys)} model key(s) left uninitialised {missing_keys}"
        )

    model.to(device)

    return model

In [None]:
def get_phase_times(metadata_path: str, trace_name: str):
    """Look up the P and S arrival samples of a trace in a metadata CSV.

    Args:
        metadata_path: Path to the CSV file. It must contain the columns
            ``trace_name``, ``p_arrival_sample`` and ``s_arrival_sample``.
        trace_name: Value to match in the ``trace_name`` column.

    Returns:
        Tuple ``(p_arrival_sample, s_arrival_sample)`` taken from the first
        row whose ``trace_name`` matches.

    Raises:
        ValueError: If no row matches ``trace_name``.
    """
    # Read the metadata CSV from the path the caller asked for.
    metadata = pd.read_csv(metadata_path)

    # Find the row(s) corresponding to the trace_name
    trace_metadata = metadata[metadata["trace_name"] == trace_name]

    if trace_metadata.empty:
        raise ValueError(f"Trace name '{trace_name}' not found in metadata.")

    # Extract the arrival samples from the first matching row
    p_arrival_sample = trace_metadata["p_arrival_sample"].values[0]
    s_arrival_sample = trace_metadata["s_arrival_sample"].values[0]

    return p_arrival_sample, s_arrival_sample

In [None]:
import pandas as pd

if __name__ == "__main__":
    # Demo inputs, named once so the waveform, its metadata lookup and the
    # model configuration cannot drift apart.
    data_path = "./datasets/STEAD/chunk2.hdf5"
    metadata_path = "./datasets/STEAD/chunk2.csv"
    trace_name = "B087.PB_20111102050415_EV"
    in_channels = 3
    in_samples = 6000
    # Raw strings: "\h" is not a valid escape sequence in a normal literal.
    pred_phase_labels = [r"$\hat{D}$", r"$\hat{P}$", r"$\hat{S}$"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step.1 - Load Model
    model = load_model(
        model_name="GPT4EQ",
        ckpt_path="./logs/example/checkpoints/GPT4EQ.pth",
        device=device,
        in_channels=in_channels,
        in_samples=in_samples,
    )
    # Inference only: switch layers such as dropout / batch-norm to eval mode.
    model.eval()

    # Step.2 - Load waveforms
    waveform_ndarray = load_data(data_path=data_path, trace_name=trace_name)
    # Keep only as many samples as the model was built for.
    waveform_ndarray = waveform_ndarray[:, :in_samples]
    waveform_ndarray = normalize(waveform_ndarray, mode="std")
    waveform_tensor = (
        torch.from_numpy(waveform_ndarray).reshape(1, in_channels, -1).to(device)
    )

    print(f"waveform shape: {waveform_ndarray.shape}")

    p_arrival_sample, s_arrival_sample = get_phase_times(
        metadata_path=metadata_path,
        trace_name=trace_name,
    )
    true_phase_idxs = [p_arrival_sample, s_arrival_sample]

    # Step.3 - Inference
    # No gradients are needed for prediction; skipping autograd saves memory.
    with torch.no_grad():
        preds_tensor = model(waveform_tensor)

    print(f"Preds tensor shape: {preds_tensor.shape}")

    # One row of predictions per predicted label.
    preds_ndarray = (
        preds_tensor.detach().cpu().numpy().reshape(len(pred_phase_labels), -1)
    )

    print(f"Preds shape: {preds_ndarray.shape}")

    # Step.4 - Visualization
    vis_phase_picking(
        waveforms=waveform_ndarray,
        waveforms_labels=["E", "N", "Z"],
        preds=preds_ndarray,
        true_phase_idxs=true_phase_idxs,
        true_phase_labels=["P", "S"],
        pred_phase_labels=pred_phase_labels,
        sampling_rate=None,
        save_name="demo_prediction",
        save_dir="./",
        formats=["png"],
    )