In [17]:
import numpy as np
import sys
sys.path.append('..')

import torch
from torch.utils.data import DataLoader

import seisbench.generate as sbg
from seisbench.data import WaveformDataset
from seisbench.models import EQTransformer
from seisbench.util import worker_seeding

from utils.augmentations import ChangeChannels, StoreMetadata

from utils.evaluation import calculate_metrics
# from utils import predict
# from evaluation import eval

In [18]:
data_path = "/home/trahn/data/bedretto/"
data = WaveformDataset(data_path)
data.filter(data["trace_ntps"] == 20001)

In [19]:
rng = np.random.default_rng()

In [20]:
def predict(model, dataloader):
    """Convenience function for predicting values in `dataloader' using `model'.  Returns a dictionary with keys 'predicitions' and 'targets'"""
    predictions = []
    targets = []

    model.eval()  # close model for evaluation

    with torch.no_grad():
        for batch in dataloader:
            # TODO: window_borders does not exist for us.
            # window_borders = batch["window_borders"]

            det_pred, p_pred, s_pred = model(batch["X"].to(model.device))


            score_detection = torch.zeros(det_pred.shape[0])
            score_p_or_s = torch.zeros(det_pred.shape[0])
            p_sample = torch.zeros(det_pred.shape[0], dtype=int)
            s_sample = torch.zeros(det_pred.shape[0], dtype=int)
            for i in range(det_pred.shape[0]):
                # TODO In pick-benchmark every batch as a "window_borders" property that is used here, but we do not have that (?) so I am just using the full rage instead.
                # start_sample, end_sample = window_borders[i]
                local_det_pred = det_pred[i, :]
                local_p_pred = p_pred[i, :]
                local_s_pred = s_pred[i, :]

                score_detection[i] = torch.max(local_det_pred)
                score_p_or_s[i] = torch.max(local_p_pred) / torch.max(
                    local_s_pred
                )  # most likely P by most likely S

                p_sample[i] = torch.argmax(local_p_pred)
                s_sample[i] = torch.argmax(local_s_pred)

            # TODO Also see notebook for usage
            predictions.append(
                torch.stack((score_detection, p_sample, s_sample), dim=1).cpu()
            )
            targets.append([True, batch["trace_p_arrival_sample"][0].item(), batch["trace_s_arrival_sample"][0].item()])

    return {"predictions": np.vstack(predictions), "targets": np.concatenate(targets)}


In [21]:
phase_dict = {
    "trace_p_arrival_sample": "P",
    "trace_pP_arrival_sample": "P",
    "trace_P_arrival_sample": "P",
    "trace_P1_arrival_sample": "P",
    "trace_Pg_arrival_sample": "P",
    "trace_Pn_arrival_sample": "P",
    "trace_PmP_arrival_sample": "P",
    "trace_pwP_arrival_sample": "P",
    "trace_pwPm_arrival_sample": "P",
    "trace_s_arrival_sample": "S",
    "trace_S_arrival_sample": "S",
    "trace_S1_arrival_sample": "S",
    "trace_Sg_arrival_sample": "S",
    "trace_SmS_arrival_sample": "S",
    "trace_Sn_arrival_sample": "S",
}

def get_eval_augmentations():
    p_phases = [key for key, val in phase_dict.items() if val == "P"]
    s_phases = [key for key, val in phase_dict.items() if val == "S"]

    detection_labeller = sbg.DetectionLabeller(
        p_phases, s_phases=s_phases, key=("X", "detections")
    )

    return [
        StoreMetadata('trace_p_arrival_sample'),
        StoreMetadata("trace_s_arrival_sample"),
        sbg.RandomWindow(
            low=None,
            high=None,
            windowlen=20000,
            strategy="pad",
        ),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=20, dim=0),
        detection_labeller,
        sbg.ChangeDtype(np.float32, "X"),
        sbg.ChangeDtype(np.float32, "y"),
        sbg.ChangeDtype(np.float32, "detections"),
        ChangeChannels(0),
        StoreMetadata('trace_snr'),
        sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    ]

In [6]:
train, dev, test = data.train_dev_test()

data_generator = sbg.GenericGenerator(test)
data_generator.add_augmentations(get_eval_augmentations())
data_loader = DataLoader(
    data_generator,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    worker_init_fn=worker_seeding,
    # drop_last=True,
)

In [7]:
model = EQTransformer(in_channels=1, in_samples=20000)
checkpoint = torch.load("../../results/trained_models/eqt_bedretto_from_scratch/checkpoint-50.pt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [8]:
print("Run predictions.")
res = predict(model, data_loader)

Run predictions.


In [9]:
res["targets"] = res["targets"].reshape(-1, 3)

In [10]:
targets = res["targets"]
det_true = targets[:, 0]
p_true = targets[:, 1]
s_true = targets[:, 1]

predictions = res["predictions"]
det_pred = predictions[:, 0]
p_pred = predictions[:, 1]
s_pred = predictions[:, 2]

In [11]:
p_nans = np.isnan(p_true)
s_nans = np.isnan(s_true)

p_true = p_true[~p_nans]
s_true = s_true[~s_nans]

p_pred = p_pred[~p_nans]
s_pred = s_pred[~s_nans]

In [30]:
snr = []

print("Build ground truth.")
for batch in data_loader:
    local_snr = batch['trace_snr']
    if isinstance(local_snr, str):
        local_snr = float(
            local_snr.replace("[", "").replace("]", "").strip().split(" ")[0]
        )
    else:
        local_snr = 0.0

    snr.append(local_snr)

snr = np.array(snr)
snr = snr[~p_nans]

Build ground truth.


In [53]:
snr = []

print("Build ground truth.")
for idx in range(len(test)):
    _, metadata = data.get_sample(idx)
    local_snr = metadata["trace_snr"]
    if isinstance(local_snr, str):
        local_snr = float(
            local_snr.replace("[", "").replace("]", "").strip().split(" ")[0]
        )

    snr.append(local_snr)

snr = np.array(snr)
snr = snr[~p_nans]

Build ground truth.


In [55]:
from sklearn.metrics import (
    confusion_matrix,
    roc_curve,
    precision_score,
    recall_score,
    f1_score,
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error,
)

detection_threshold = 0.5

print("Evaluate predictions.")
det_roc = roc_curve(det_true, det_pred)

# NOTE: detection_threshold is a hyperparamater
det_pred = np.ceil(det_pred - detection_threshold)

results = dict()

results["det_roc"] = det_roc
for det_metric in [confusion_matrix, precision_score, recall_score, f1_score]:
    results[f"det_{det_metric.__name__}"] = det_metric(det_true, det_pred)

for pick, true, pred in [("p", p_true, p_pred), ("s", s_true, s_pred)]:
    for name, metric in [("mu", np.mean), ("std", np.std)]:
        results[f"{pick}_{name}"] = metric(true - pred)
    for name, metric in [
        ("MAE", mean_absolute_error),
        ("MAPE", mean_absolute_percentage_error),
        ("RMSE", lambda true, pred: mean_squared_error(true, pred, squared=False))
    ]:
        results[f"{pick}_{name}"] = metric(true, pred)

results["p_res"] = p_true - p_pred
results["s_res"] = s_true - s_pred
results["snr"] = snr

Evaluate predictions.




In [56]:
for k in ['det_precision_score', 'det_recall_score', 'det_f1_score', 'p_mu', 'p_std', 'p_MAE', 'p_MAPE', 'p_RMSE', 's_mu', 's_std', 's_MAE', 's_MAPE', 's_RMSE']:
    print(f"{k}\t{results[k]:>.4f}")

det_precision_score	1.0000
det_recall_score	0.4629
det_f1_score	0.6328
p_mu	2.9877
p_std	1030.7170
p_MAE	160.4598
p_MAPE	0.0160
p_RMSE	1030.7214
s_mu	-4455.8651
s_std	4628.2067
s_MAE	4698.9750
s_MAPE	0.4699
s_RMSE	6424.5646


In [57]:
import pickle

In [58]:
with open("bedretto.pickle", "wb") as f:
    pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)