<a href="https://colab.research.google.com/github/seisbench/seisbench/blob/additional_example_workflows/examples/03a_training_phasenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

![image](https://raw.githubusercontent.com/seisbench/seisbench/main/docs/_static/seisbench_logo_subtitle_outlined.svg)

*This cell is needed on Colab to install SeisBench. If SeisBench is already installed on your machine, you can skip it.*

In [98]:
# !pip install seisbench

*This cell is required to circumvent an issue with Colab and obspy. For details, see this issue in the obspy issue tracker: https://github.com/obspy/obspy/issues/2547*

In [99]:
# try:
#     import obspy
#     obspy.read()
# except TypeError:
#     # Needs to restart the runtime once, because obspy only works properly after restart.
#     print('Stopping RUNTIME. If you run this code for the first time, this is expected. Colaboratory will restart automatically. Please run again.')
#     exit()

As training data we use the ETHZ dataset. Note that we set the sampling rate to 100 Hz to ensure that all examples are consistent in terms of sampling rate. We split the data into training, development and test sets according to the splits provided.

In [100]:
import seisbench.models as sbm
import seisbench.generate as sbg
import numpy as np
import seisbench.data as sbd
import torch
import matplotlib.pyplot as plt
from obspy.clients.fdsn import Client
from obspy import UTCDateTime

In [None]:
from config import load_config
cfg = load_config('Kaki-cfg.yml')
print(cfg)

In [None]:
from pathlib import Path
data = sbd.WaveformDataset(cfg.path.dataset, sampling_rate=100)
train, dev, test = data.train_dev_test()
print(train, dev, test, sep='\n')

In [103]:
# [key for key in data.metadata.keys() if 'arrival_sample' in key]

## Generation pipeline

The ETHZ dataset contains detailed labels for the phases. However, for this example we only want to differentiate between P and S picks. Therefore, we define a dictionary mapping the detailed picks to their phases.

In [104]:
phase_dict = {
    "trace_p_arrival_sample": "P",
    "trace_pP_arrival_sample": "P",
    "trace_P_arrival_sample": "P",
    "trace_P1_arrival_sample": "P",
    "trace_Pg_arrival_sample": "P",
    "trace_Pn_arrival_sample": "P",
    "trace_PmP_arrival_sample": "P",
    "trace_pwP_arrival_sample": "P",
    "trace_pwPm_arrival_sample": "P",
    "trace_s_arrival_sample": "S",
    "trace_S_arrival_sample": "S",
    "trace_S1_arrival_sample": "S",
    "trace_Sg_arrival_sample": "S",
    "trace_SmS_arrival_sample": "S",
    "trace_Sn_arrival_sample": "S",
}
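
As a rough illustration of what this mapping achieves, the sketch below collapses the detailed pick columns of a single metadata row into coarse P and S sample indices. The helper `collapse_picks` and the example row are hypothetical and not part of SeisBench; the sketch also assumes that missing picks are stored as NaN.

```python
import math

# Subset of the mapping above, repeated so the example is self-contained.
phase_dict = {
    "trace_Pg_arrival_sample": "P",
    "trace_Pn_arrival_sample": "P",
    "trace_Sg_arrival_sample": "S",
}


def collapse_picks(row, mapping):
    """Group the available pick samples of one metadata row by coarse phase."""
    picks = {}
    for column, phase in mapping.items():
        sample = row.get(column)
        # Assumption: a missing pick is either absent or stored as NaN.
        if sample is None or math.isnan(sample):
            continue
        picks.setdefault(phase, []).append(int(sample))
    return picks


# Hypothetical metadata row with one Pg pick, no Pn pick and one Sg pick.
row = {
    "trace_Pg_arrival_sample": 512.0,
    "trace_Pn_arrival_sample": float("nan"),
    "trace_Sg_arrival_sample": 1490.0,
}
picks = collapse_picks(row, phase_dict)
print(picks)
```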

Now we define two generators with identical augmentations, one for training, one for validation. The augmentations are:
1. Selection of a (long) window around a pick. This way, we ensure that our data always contains a pick.
1. Selection of a random window with 3001 samples, the input length of PhaseNet.
1. A normalization, consisting of demeaning and amplitude normalization.
1. A change of datatype to float32, as this is expected by the pytorch model.
1. A probabilistic label, which encodes each pick as a Gaussian-shaped probability curve (here with a width of `sigma=30` samples) instead of a single-sample spike.

In [105]:
augmentations = [
    sbg.WindowAroundSample(
        list(phase_dict.keys()),
        samples_before=3000, windowlen=6000,
        selection="random", strategy="variable"),
    sbg.RandomWindow(
        windowlen=3001, strategy="pad"),
    sbg.Normalize(
        demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.ChangeDtype(
        np.float32),
    sbg.ProbabilisticLabeller(
        label_columns=phase_dict, sigma=30, dim=0)
]

dev_generator = sbg.GenericGenerator(dev)
dev_generator.add_augmentations(augmentations)
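
To make the normalization and labelling steps listed above more concrete, here is a minimal NumPy sketch of what they compute on a single example. It is not the SeisBench implementation: the functions `normalize` and `gaussian_label`, the synthetic waveform and the pick positions are all assumptions made for illustration.

```python
import numpy as np


def normalize(waveform):
    """Demean each component and scale it by its peak absolute amplitude."""
    demeaned = waveform - waveform.mean(axis=-1, keepdims=True)
    peak = np.abs(demeaned).max(axis=-1, keepdims=True)
    return demeaned / (peak + 1e-10)  # small constant avoids division by zero


def gaussian_label(n_samples, pick_sample, sigma=30):
    """Gaussian-shaped pick probability that reaches 1 at the pick sample."""
    t = np.arange(n_samples)
    return np.exp(-((t - pick_sample) ** 2) / (2 * sigma ** 2))


rng = np.random.default_rng(0)
waveform = rng.normal(loc=5.0, size=(3, 3001))  # synthetic 3-component trace
x = normalize(waveform).astype(np.float32)

# Hypothetical P and S picks at samples 1000 and 1800.
p_label = gaussian_label(3001, pick_sample=1000)
s_label = gaussian_label(3001, pick_sample=1800)
noise_label = np.clip(1.0 - p_label - s_label, 0.0, 1.0)
y = np.stack([p_label, s_label, noise_label])  # order: P, S, noise

print(x.shape, x.dtype, y.shape)
```

Because the three label curves sum to one at every sample, they can be read as per-sample class probabilities, which is what a model with a softmax output predicts.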

SeisBench generators are pytorch datasets. Therefore, we can pass them to pytorch data loaders. These will automatically take care of parallel loading and batching. Here we create one loader for training and one for validation. We choose a batch size of 256 samples. This batch size should fit on most hardware.
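
A minimal sketch of such a loader is shown below. To keep it runnable without a dataset, it uses a hypothetical `ToyGenerator` that returns dictionaries with the keys `X` and `y`, like the examples drawn from `dev_generator` later in this notebook; in practice the generator itself would be passed to `DataLoader`. The number of examples and `num_workers=0` are arbitrary choices for this sketch.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyGenerator(Dataset):
    """Hypothetical stand-in for a SeisBench generator."""

    def __init__(self, n_examples=600):
        self.n_examples = n_examples

    def __len__(self):
        return self.n_examples

    def __getitem__(self, idx):
        # Random data in the shapes produced by the augmentations above.
        rng = np.random.default_rng(idx)
        return {
            "X": rng.normal(size=(3, 3001)).astype(np.float32),
            "y": rng.random(size=(3, 3001)).astype(np.float32),
        }


# The default collate function stacks the dictionary entries into batched tensors.
loader = DataLoader(ToyGenerator(), batch_size=256, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch["X"].shape, batch["y"].shape)
```

For validation, shuffling is usually switched off so that results are comparable between epochs.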

We now have all the components for training the model. What remains is to define the optimizer and the loss, and to write the training and validation loops.
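
The sketch below shows one possible shape of these pieces. It is not taken from this notebook: the cross-entropy loss `loss_fn`, the Adam optimizer with a learning rate of `1e-2`, the tiny `toy_model` standing in for PhaseNet, and the synthetic batch are all assumptions chosen so that the example runs on its own.

```python
import torch


def loss_fn(y_pred, y_true, eps=1e-5):
    """Cross-entropy between predicted and target probability curves."""
    h = y_true * torch.log(y_pred + eps)  # shape: (batch, classes, samples)
    h = h.mean(-1).sum(-1)                # mean over samples, sum over classes
    return -h.mean()                      # mean over the batch


torch.manual_seed(0)

# Hypothetical stand-in for PhaseNet: per-sample class probabilities via softmax.
toy_model = torch.nn.Sequential(
    torch.nn.Conv1d(3, 3, kernel_size=7, padding=3),
    torch.nn.Softmax(dim=1),
)
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

# Synthetic batch whose targets are a simple function of the input.
x = torch.randn(8, 3, 200)
y_true = torch.softmax(5 * x, dim=1)

# Training loop: forward pass, loss, backward pass, parameter update.
losses = []
toy_model.train()
for step in range(50):
    optimizer.zero_grad()
    loss = loss_fn(toy_model(x), y_true)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Validation pass: no gradients, no parameter updates.
toy_model.eval()
with torch.no_grad():
    val_loss = loss_fn(toy_model(x), y_true).item()

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, validation {val_loss:.3f}")
```

In a real run, the inner loop would iterate over the batches of the training loader, and the validation pass over the batches of the validation loader.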

In [None]:
model = sbm.PhaseNet.load(cfg.path.dl_model, version_str=cfg.training.version_str)
model_org = sbm.PhaseNet.from_pretrained('original')


## Evaluating the model

Now that we have trained the model, we can evaluate it. First, we'll check how the model does on an example from the development set. Note that the model will most likely not be fully trained after only five epochs.

In [111]:
def visual_eval(data_X, data_Y, model1, model2):
    with torch.no_grad():
        pred1 = model1(torch.tensor(
            data_X,
            device=model1.device).unsqueeze(0)
        )
        pred1 = pred1[0].cpu().numpy()
        #
        pred2 = model2(torch.tensor(
            data_X,
            device=model2.device).unsqueeze(0)
        )
        pred2 = pred2[0].cpu().numpy()
    #
    fig, axs = plt.subplots(
        4, 1,
        figsize=(15, 7),
        sharex=True,
        gridspec_kw={"hspace": 0, "height_ratios": [3, 1, 1, 1]}
    )
    ax1, ax2, ax3, ax4 = axs
    #
    ax1.plot(data_X.T)
    ax1.set_ylabel('Waveform')
    #
    ax2.plot(data_Y.T)
    ax2.set_ylabel('Manual')
    #
    ax3.plot(pred1.T)
    ax3.set_ylabel('Mine')
    #
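    # Reorder the output channels of the original weights to match the label order.
    # Assumption: the original model outputs (noise, P, S), the labels are (P, S, noise).
    pred2 = pred2[[1, 2, 0], :]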
    ax4.plot(pred2.T)
    ax4.set_ylabel('Origin')
    #
    ax1.set_yticks([0, 0.5, 1])
    for ax in [ax2, ax3, ax4]:
        ax.set_yticks([0.5, 1])

In [None]:
rand_num = np.random.randint(len(dev_generator))
sample = dev_generator[rand_num]
visual_eval(data_X=sample["X"],
            data_Y=sample["y"],
            model1=model,
            model2=model_org)

In [92]:
y = sample["y"]
p, s, n = y.argmax(axis=1)

In [94]:
def numeric_eval(data_X, data_Y, model1, model2):
    with torch.no_grad():
        pred1 = model1(torch.tensor(
            data_X,
            device=model1.device).unsqueeze(0)
        )
        pred1 = pred1[0].cpu().numpy()
        #
        pred2 = model2(torch.tensor(
            data_X,
            device=model2.device).unsqueeze(0)
        )
        pred2 = pred2[0].cpu().numpy()
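        # Reorder the output channels of the original weights to match the label order.
        # Assumption: the original model outputs (noise, P, S), the labels are (P, S, noise).
        pred2 = pred2[[1, 2, 0], :]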
    p0, s0, n0 = data_Y.argmax(axis=1)
    p1, s1, n1 = pred1.argmax(axis=1)
    p2, s2, n2 = pred2.argmax(axis=1)
    return {'p': [p0, p1, p2], 's': [s0, s1, s2]}

In [None]:
lst_stats = []
for ii in range(len(dev_generator)):
    sample = dev_generator[ii]
    stat = numeric_eval(
        data_X=sample["X"],
        data_Y=sample["y"],
        model1=model,
        model2=model_org)
    lst_stats.append(stat)

In [None]:
for phase in ['p', 's']:
    arr = []
    for stat in lst_stats:
        arr.append(stat[phase])
    arr = np.array(arr)
    m1 = arr[:, 0] - arr[:, 1]
    m2 = arr[:, 0] - arr[:, 2]
    #
    sps = 100
    m1 = m1 / sps
    m2 = m2 / sps

    lim = 2 # seconds
    step = lim/10
    m1 = m1[np.abs(m1)<=lim]
    m2 = m2[np.abs(m2)<=lim]
    bins = np.arange(-lim, lim, step) + step/2
    plt.hist(m1, bins=bins, edgecolor='k')
    plt.title('m1'+phase)
    plt.show()
    plt.hist(m2, bins=bins, edgecolor='k')
    plt.title('m2'+phase)
    plt.show()
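
Besides the histograms, a few summary numbers can make the two models easier to compare. The sketch below computes them for the same residual definition as above (manual pick minus predicted pick, converted to seconds). The function `residual_summary`, the tolerance of 0.5 s and the example values are hypothetical.

```python
import numpy as np


def residual_summary(true_samples, pred_samples, sampling_rate=100, tolerance=0.5):
    """Summarize pick residuals (true minus predicted) in seconds."""
    residuals = (np.asarray(true_samples) - np.asarray(pred_samples)) / sampling_rate
    return {
        "mean": float(residuals.mean()),
        "std": float(residuals.std()),
        "mae": float(np.abs(residuals).mean()),
        # Fraction of picks that lie within the tolerance of the manual pick.
        "within_tolerance": float((np.abs(residuals) <= tolerance).mean()),
    }


# Hypothetical manual and predicted P pick positions, in samples.
true_p = [1000, 1500, 2000, 1200]
pred_p = [1010, 1490, 2100, 1200]
summary = residual_summary(true_p, pred_p)
print(summary)
```

With the statistics collected above, the inputs would be the columns of `arr`, for example `arr[:, 0]` and `arr[:, 1]` for the retrained model.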

As a second option, we'll directly apply our model to an obspy waveform stream using the `annotate` function. For this, we download waveforms through FDSN and annotate them. Note that you could use the `classify` function in a similar fashion.

As we trained the model on Swiss data, we use an example event from Switzerland. Note that we deliberately chose a rather easy example, as the model is not fully trained after the low number of epochs. The exact performance of the model will vary, because model training and initialization involve random aspects.

In [None]:
client = Client("ETH")

t = UTCDateTime("2019-11-04T00:59:46.419800Z")
stream = client.get_waveforms(network="CH", station="EMING", location="*", channel="HH?", starttime=t-30, endtime=t+50)

annotations = model.annotate(stream)
annotations_org = model_org.annotate(stream)


fig = plt.figure(figsize=(15, 7))
axs = fig.subplots(3, 1, sharex=True, gridspec_kw={'hspace': 0})

offset = annotations[0].stats.starttime - stream[0].stats.starttime
offset_org = annotations_org[0].stats.starttime - stream[0].stats.starttime

for i in range(3):
    axs[0].plot(stream[i].times(), stream[i].data, label=stream[i].stats.channel)
    if annotations[i].stats.channel[-1] != "N":  # Do not plot noise curve
        axs[1].plot(annotations[i].times() + offset, annotations[i].data, label=annotations[i].stats.channel)
    if annotations_org[i].stats.channel[-1] != "N":  # Do not plot noise curve
        axs[2].plot(annotations_org[i].times() + offset_org,
                    annotations_org[i].data,
                    label=annotations_org[i].stats.channel)

axs[0].legend()
axs[1].legend()
axs[2].legend()
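
The annotation traces are probability curves over time. One simple way to turn such a curve into discrete picks is to take the maximum of every contiguous run of samples above a threshold. The sketch below illustrates this on a synthetic curve; the function `pick_from_curve` and the threshold of 0.3 are assumptions for illustration, and this is not a description of how `classify` works internally.

```python
import numpy as np


def pick_from_curve(prob, times, threshold=0.3):
    """Return (time, probability) of the peak in each run above the threshold."""
    above = prob >= threshold
    # Pad with False so that every run has a rising and a falling edge.
    padded = np.concatenate([[False], above, [False]])
    edges = np.flatnonzero(np.diff(padded.astype(int)))
    starts, ends = edges[::2], edges[1::2]  # ends are exclusive
    picks = []
    for start, end in zip(starts, ends):
        peak = start + int(np.argmax(prob[start:end]))
        picks.append((float(times[peak]), float(prob[peak])))
    return picks


# Synthetic probability curve at 100 Hz with peaks at 30 s and 41 s.
times = np.arange(8000) / 100
prob = (
    0.9 * np.exp(-((times - 30.0) ** 2) / (2 * 0.3 ** 2))
    + 0.6 * np.exp(-((times - 41.0) ** 2) / (2 * 0.3 ** 2))
)
picks = pick_from_curve(prob, times)
print(picks)
```

For the example above, the inputs would be `annotations[i].data` and `annotations[i].times() + offset` for the P and S traces.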

## Remarks

As discussed in the data basics tutorial, loading a SeisBench dataset only loads the metadata into memory. To save memory, the waveforms are loaded only when they are requested. By default, waveforms are **not** cached in memory. For training, this means that the data has to be read from file again in every epoch, which can take a long time depending on your hardware. To avoid this, set the `cache` option when creating the dataset. Then call `preload_waveforms`, and the data will be loaded into memory and cached automatically. For most practical applications, this option is recommended.
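
A minimal sketch of this pattern is given below. The wrapper `load_cached_dataset` is hypothetical, and the value `cache="trace"` is an assumption about a suitable caching mode; please check the SeisBench documentation for the modes available in your version.

```python
def load_cached_dataset(path, sampling_rate=100):
    """Load a SeisBench dataset with waveform caching and preload one split."""
    # Imported here so that the function can be defined without SeisBench installed.
    import seisbench.data as sbd

    # Assumption: "trace" is an accepted value for the cache option.
    data = sbd.WaveformDataset(path, sampling_rate=sampling_rate, cache="trace")
    train, dev, test = data.train_dev_test()
    # Read the waveforms of the split once; later accesses use the in-memory cache.
    dev.preload_waveforms(pbar=True)
    return train, dev, test
```

Preloading only the splits that are actually used keeps the memory footprint lower than preloading the full dataset.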