In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import h5py

import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from datetime import datetime, timedelta
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from keras import layers
from scipy.signal import spectrogram, stft, istft, resample

from sklearn.metrics import confusion_matrix

In [None]:
class SteadDataObject:
    """Container for a single STEAD trace.

    Instances start empty. ``to_dataobject`` attaches the waveform as
    ``data`` and copies the HDF5 attributes of the trace (for example
    ``trace_name`` and ``trace_start_time``) onto the instance.
    """

    def __init__(self):
        pass

    def get_component(self, component: str) -> np.ndarray:
        """Return a single channel of the waveform.

        ``data`` holds one channel per column: first column E, second
        column N, third column Z. ``np.rot90`` with ``k=1`` reverses the
        column order and transposes, so after rotating:

        * row 0 is the Z channel,
        * row 1 is the N channel,
        * row 2 is the E channel.

        Args:
            component: ``"e"`` or ``"n"`` (case-insensitive). Any other
                value selects the vertical (Z) channel.

        Returns:
            The selected row of the rotated array.
        """
        rotated = np.rot90(self.data, k=1, axes=(0, 1))
        key = component.lower()
        if key == "e":
            return rotated[2]
        if key == "n":
            return rotated[1]
        # Everything else (including "z") falls back to the vertical channel.
        return rotated[0]

    def get_raw(self):
        """Return ``data`` reshaped to ``(6000, 3, 1)``."""
        return self.data.reshape(6000, 3, 1)

    def get_components(self):
        """Return all channels as rows, ordered Z, N, E (see ``get_component``)."""
        return np.rot90(self.data, k=1, axes=(0, 1))

    def get_timespan(self, samples=6000, fs=100):
        """Return the timestamp of every sample as a list of ``datetime``.

        Args:
            samples: Number of timestamps to generate.
            fs: Sampling rate in samples per second.

        Returns:
            ``samples`` datetimes starting at the trace start time and
            spaced ``1 / fs`` seconds apart.
        """
        start = self.get_ts_datetime()
        # Compute each offset from the start instead of repeatedly adding
        # timedelta(seconds=1/fs): timedelta rounds to whole microseconds,
        # so a running sum drifts whenever 1/fs is not an exact number of
        # microseconds.
        return [start + timedelta(seconds=i / fs) for i in range(samples)]

    def get_ts_datetime(self):
        """Parse ``trace_start_time`` into a ``datetime``.

        Accepts ``YYYY-MM-DD HH:MM:SS.ffffff`` and, as a fallback, the same
        format without the fractional-second part.

        Raises:
            ValueError: If the timestamp matches neither format.
        """
        raw = self.trace_start_time
        try:
            return datetime.strptime(raw, "%Y-%m-%d %H:%M:%S.%f")
        except ValueError:
            return datetime.strptime(raw, "%Y-%m-%d %H:%M:%S")

    def get_ts_iso8601(self):
        """Return the trace start time as an ISO 8601 string."""
        return self.get_ts_datetime().isoformat()

    def get_ts_short(self):
        """Return the trace start time as ``YYYY-MM-DD HH:MM:SS`` (no fraction)."""
        return self.get_ts_datetime().strftime("%Y-%m-%d %H:%M:%S")

In [None]:
def plot_all(do, label, samples, fs, stft_size, nperseg, file_path=None):
    """Plot waveform, spectrogram and STFT magnitude for three components.

    The figure has three waveform panels (left), three spectrogram panels
    (right) and three STFT panels (bottom row).

    Args:
        do: Indexable with three 1-D traces; index 0 is labelled vertical,
            index 1 north and index 2 east.
        label: Figure super-title.
        samples: Number of leading samples drawn in the waveform panels.
            The spectrogram and STFT panels use the full traces.
        fs: Sampling rate passed to ``spectrogram`` and ``stft``.
        stft_size: Length of the tick vector used for both axes of the STFT
            panels; it is expected to match the dimensions of the STFT
            matrix produced for ``nperseg``.
        nperseg: Segment length passed to ``stft``.
        file_path: If given, the figure is saved there and then closed.
    """
    component_names = ("Vertical", "North", "East")

    plt.rc('font', size=11)
    plt.rc('axes', titlesize=16)

    fig = plt.figure(figsize=(16, 10), dpi=227)
    grid = (5, 6)
    wave_axes = [plt.subplot2grid(grid, (row, 0), colspan=3) for row in range(3)]
    spec_axes = [plt.subplot2grid(grid, (row, 3), colspan=3) for row in range(3)]
    stft_axes = [
        plt.subplot2grid(grid, (3, col), colspan=2, rowspan=2) for col in (0, 2, 4)
    ]

    plt.subplots_adjust(hspace=1, wspace=1)

    # Waveforms.
    for i, (ax, name) in enumerate(zip(wave_axes, component_names)):
        frame = pd.DataFrame(data=do[i][:samples])
        sns.lineplot(data=frame, ax=ax, linewidth=1, legend=None)
        ax.set_title(f"{name} component waveform")
        ax.set(xlabel="Samples", ylabel="Amp. counts")
        ax.locator_params(nbins=6, axis="y")

    # Spectrograms. Use the `fs` argument; the previous code referenced an
    # undefined global `FS` here.
    for i, (ax, name) in enumerate(zip(spec_axes, component_names)):
        freqs, times, sxx = spectrogram(x=do[i], fs=fs)
        ax.set_title(f"{name} component spectrogram")
        mesh = ax.pcolormesh(times, freqs, sxx, shading="gouraud")
        ax.set(xlabel="Time [sec]", ylabel="Freq. [Hz]")
        fig.colorbar(mesh, ax=ax)

    # STFT magnitudes, drawn against bin indices rather than Hz / seconds.
    ticks = np.arange(stft_size)
    for i, (ax, name) in enumerate(zip(stft_axes, component_names)):
        # "hann" is the canonical SciPy name of the window formerly also
        # reachable through the "hanning" alias.
        _, _, zxx = stft(do[i], window="hann", fs=fs, nperseg=nperseg)
        ax.set_title(f"{name} component STFT")
        mesh = ax.pcolormesh(ticks, ticks, np.abs(zxx), shading="auto")
        fig.colorbar(mesh, ax=ax)

    plt.suptitle(label, fontsize=14)

    if file_path is not None:
        fig.savefig(file_path, bbox_inches='tight')
        plt.close(fig)

In [None]:
def plot_seismograms(do, label, samples, file_path=None):
    """Plot the three component waveforms of a trace, one panel per row.

    Args:
        do: Indexable with three 1-D traces; index 0 is labelled vertical,
            index 1 north and index 2 east.
        label: Figure super-title.
        samples: Number of leading samples of each trace to draw.
        file_path: If given, the figure is saved there and then closed.
    """
    component_names = ("Vertical", "North", "East")

    plt.rc('font', size=11)
    plt.rc('axes', titlesize=16)

    fig = plt.figure(figsize=(8, 5), dpi=227)
    axes = [plt.subplot2grid((3, 1), (row, 0), colspan=1) for row in range(3)]

    plt.subplots_adjust(hspace=1, wspace=1)

    for i, (ax, name) in enumerate(zip(axes, component_names)):
        frame = pd.DataFrame(data=do[i][:samples])
        sns.lineplot(data=frame, ax=ax, linewidth=1, legend=None)
        ax.set_title(f"{name} component waveform")
        ax.set(xlabel="Samples", ylabel="Amp. counts")
        ax.locator_params(nbins=6, axis="y")

    plt.suptitle(label, fontsize=14)

    if file_path is not None:
        fig.savefig(file_path, bbox_inches='tight')
        plt.close(fig)

In [None]:
def plot_seismograms_downsampled(do, label, file_path=None, samples=4000):
    """Plot the three component waveforms after resampling each trace.

    Each trace is resampled as a whole to ``samples`` points. Previously
    the trace was cut to its first 4000 points *before* resampling to 4000,
    which silently truncated longer inputs instead of downsampling them.
    Inputs that already have ``samples`` points are plotted as before.

    Args:
        do: Indexable with three 1-D traces; index 0 is labelled vertical,
            index 1 north and index 2 east.
        label: Figure super-title.
        file_path: If given, the figure is saved there and then closed.
        samples: Number of points each trace is resampled to.
    """
    component_names = ("Vertical", "North", "East")

    plt.rc('font', size=11)
    plt.rc('axes', titlesize=16)

    fig = plt.figure(figsize=(8, 5), dpi=227)
    axes = [plt.subplot2grid((3, 1), (row, 0), colspan=1) for row in range(3)]

    plt.subplots_adjust(hspace=1, wspace=1)

    for i, (ax, name) in enumerate(zip(axes, component_names)):
        frame = pd.DataFrame(data=resample(do[i], samples))
        sns.lineplot(data=frame, ax=ax, linewidth=1, legend=None)
        ax.set_title(f"{name} component waveform")
        ax.set(xlabel="Samples", ylabel="Amp. counts")
        ax.locator_params(nbins=6, axis="y")

    plt.suptitle(label, fontsize=14)

    if file_path is not None:
        fig.savefig(file_path, bbox_inches='tight')
        plt.close(fig)

In [None]:
def plot_stft(do, label, samples, fs, stft_size, nperseg, file_path=None):
    """Plot the STFT magnitude of three components side by side.

    Args:
        do: Indexable with three 1-D traces; index 0 is labelled vertical,
            index 1 north and index 2 east.
        label: Figure super-title.
        samples: Unused; kept so existing callers keep working.
        fs: Sampling rate passed to ``stft``.
        stft_size: Length of the tick vector used for both axes of every
            panel; it is expected to match the dimensions of the STFT
            matrix produced for ``nperseg``.
        nperseg: Segment length passed to ``stft``.
        file_path: If given, the figure is saved there and then closed.
    """
    component_names = ("Vertical", "North", "East")

    plt.rc('font', size=11)
    plt.rc('axes', titlesize=16)

    fig = plt.figure(figsize=(16, 13), dpi=227)
    axes = [
        plt.subplot2grid((6, 6), (0, col), colspan=2, rowspan=2) for col in (0, 2, 4)
    ]

    plt.subplots_adjust(hspace=1, wspace=1)

    # Panels are drawn against bin indices rather than Hz / seconds.
    ticks = np.arange(stft_size)
    for i, (ax, name) in enumerate(zip(axes, component_names)):
        # "hann" is the canonical SciPy name of the window formerly also
        # reachable through the "hanning" alias.
        _, _, zxx = stft(do[i], window="hann", fs=fs, nperseg=nperseg)
        ax.set_title(f"{name} component STFT")
        mesh = ax.pcolormesh(ticks, ticks, np.abs(zxx), shading="auto")
        fig.colorbar(mesh, ax=ax)

    plt.suptitle(label, fontsize=14)

    if file_path is not None:
        fig.savefig(file_path, bbox_inches='tight')
        plt.close(fig)

In [None]:
def to_dataobject(obj) -> SteadDataObject:
    """Wrap a dataset-like object in a ``SteadDataObject``.

    The values of ``obj`` are copied into a NumPy array stored as ``data``,
    and every entry of ``obj.attrs`` becomes an instance attribute of the
    same name.
    """
    wrapped = SteadDataObject()
    wrapped.data = np.array(obj)

    attributes = obj.attrs
    for name in attributes:
        setattr(wrapped, name, attributes[name])

    return wrapped

In [None]:
def get_event_data(idx_start, idx_end, random_state=None,
                   csv_path="data/STEAD.csv", hdf5_path="data/STEAD.hdf5"):
    """Load a slice of shuffled local-earthquake traces.

    The catalogue is shuffled, filtered to ``earthquake_local`` traces with
    ``source_distance_km <= 50`` and ``source_magnitude < 3.5``, and the
    trace names in ``[idx_start:idx_end]`` are loaded from the HDF5 file.

    Args:
        idx_start: Start index into the filtered, shuffled trace list.
        idx_end: End index (exclusive) into that list.
        random_state: Seed forwarded to ``DataFrame.sample``. The default
            ``None`` keeps the shuffle unseeded, so the selection differs
            between runs; pass an int for a reproducible selection.
        csv_path: Path of the catalogue CSV.
        hdf5_path: Path of the waveform HDF5 file.

    Returns:
        List of ``SteadDataObject``, one per selected trace.

    Raises:
        KeyError: If a selected trace name has no dataset in the HDF5 file.
    """
    catalog = pd.read_csv(csv_path)
    catalog = catalog.sample(frac=1, random_state=random_state)
    selected = catalog[
        (catalog.trace_category == "earthquake_local")
        & (catalog.source_distance_km <= 50)
        & (catalog.source_magnitude < 3.5)
    ]
    trace_names = selected["trace_name"].to_list()[idx_start:idx_end]

    streams = []
    # to_dataobject copies the samples and attributes into memory, so the
    # file can be closed as soon as the loop finishes.
    with h5py.File(hdf5_path, "r") as dtfl:
        for trace_name in trace_names:
            dataset = dtfl.get(f"data/{trace_name}")
            if dataset is None:
                raise KeyError(f"data/{trace_name} not found in {hdf5_path}")
            streams.append(to_dataobject(dataset))

    return streams

In [None]:
# Load the first 100 traces of the shuffled, filtered catalogue.
events = get_event_data(idx_start=0, idx_end=100)

In [None]:
# Position (within `events`) of the trace used for the figures below.
idx = 88
plot_all(
    do=events[idx].get_components(),
    label=events[idx].trace_name,
    samples=6000,
    fs=100,
    stft_size=78,
    nperseg=155,
)

In [None]:
# Save the three full-length component waveforms of the selected trace.
plot_seismograms(
    do=events[idx].get_components(),
    label=events[idx].trace_name,
    samples=6000,
    file_path="msc-experiment2-prep-0.png",
)

In [None]:
# Resample every component of the selected trace to 4000 points and save
# the resulting waveforms.
downsampled = [
    resample(component, 4000) for component in events[idx].get_components()
]
plot_seismograms_downsampled(
    do=downsampled,
    label=events[idx].trace_name,
    file_path="msc-experiment2-prep-1.png",
)

In [None]:
# Resample every component of the selected trace to 4000 points and save
# the STFT magnitude panels.
downsampled = [
    resample(component, 4000) for component in events[idx].get_components()
]
plot_stft(
    do=downsampled,
    label=events[idx].trace_name,
    samples=4000,
    fs=66,
    stft_size=64,
    nperseg=127,
    file_path="msc-experiment2-prep-2.png",
)

In [None]:
# Show the name of the trace used for the figures above.
events[idx].trace_name