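## Model Training: Vision Transformer

This notebook prepares a 2,000-trace subsample of the STEAD dataset for training a Vision Transformer (ViT) classifier. It loads the raw waveforms, their pre-computed spectrogram images and the trace metadata, then builds binary labels: `earthquake_local` traces versus all other traces.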

### Environment setup

In [1]:
# Standard library
import gc
import sys
# Make the project-local `earthquake_detection` package (imported below) resolvable;
# presumably one of these relative paths is the repository root as seen from this
# notebook's directory — TODO confirm which entry is actually needed.
sys.path.append('../')
sys.path.append('../../../')

# Third-party
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from PIL import Image
# NOTE(review): recent `transformers` releases deprecate ViTFeatureExtractor in favour of
# ViTImageProcessor — consider migrating when this notebook is next updated.
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Project-local helpers (found via the sys.path entries above)
import earthquake_detection.architectures as architectures
import earthquake_detection.training_utils as training_utils

2025-01-27 12:08:33.804085: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


### Load datasets

In [2]:
# Directory holding the 2,000-trace STEAD subsample artifacts.
DATA_DIR = '../../../data/STEAD'

# Load extracted raw signals
raw_signals = np.load(f'{DATA_DIR}/extracted_raw_signals_subsample_2000.npy')

# Load created spectrogram images
spectrogram_imgs = np.load(f'{DATA_DIR}/created_spectrogram_images_subsample_2000.npy')

# Load metadata and give it a clean 0..N-1 index. drop=True discards the old index
# instead of inserting it as an extra 'index' column; the recorded outputs of this
# notebook show no such column.
metadata = pd.read_feather(f'{DATA_DIR}/extracted_metadata_subsample_2000.feather')
metadata = metadata.reset_index(drop=True)

# Rows of the three artifacts are used side by side, so their counts must agree.
# Only the counts are checked here; row order is assumed (not verified) to be aligned.
assert raw_signals.shape[0] == spectrogram_imgs.shape[0] == len(metadata), (
    raw_signals.shape, spectrogram_imgs.shape, metadata.shape)

#### Inspect data to ensure it looks as expected

In [3]:
# Shape of the raw-waveform array; its leading dimension should match len(metadata).
np.shape(raw_signals)

(2000, 6000, 3)

In [4]:
# Shape of the spectrogram-image array; its leading dimension should match len(metadata).
np.shape(spectrogram_imgs)

(2000, 200, 300, 3)

In [5]:
# Render the metadata frame (pandas truncates the view to its first and last rows).
metadata.iloc[:]

Unnamed: 0,trace_name,network_code,receiver_code,receiver_type,receiver_latitude,receiver_longitude,receiver_elevation_m,p_arrival_sample,p_status,p_weight,...,source_mechanism_strike_dip_rake,source_distance_deg,source_distance_km,back_azimuth_deg,snr_db,coda_end_sample,trace_start_time,trace_category,chunk,weight_for_subsample
0,B084.PB_20111212104350_EV,PB,B084,EH,33.611570,-116.456370,1271.0,800.0,manual,0.55,...,,0.5746,63.90,178.5,[15.39999962 15.5 17.10000038],[[3199.]],43:51.3,earthquake_local,1,9.706561e-07
1,B086.PB_20080618145426_EV,PB,B086,EH,33.557500,-116.531000,1392.0,800.0,manual,0.59,...,,0.5071,56.40,330.6,[45. 46.20000076 46.29999924],[[2230.]],54:27.0,earthquake_local,1,9.706561e-07
2,B023.PB_20130513182210_EV,PB,B023,EH,46.111200,-123.078700,177.4,500.0,manual,0.63,...,,0.9683,107.61,163.1,[12.60000038 12.10000038 11.19999981],[[3199.]],22:11.5,earthquake_local,1,9.706561e-07
3,B011.PB_20120724202107_EV,PB,B011,EH,48.649543,-123.448192,22.0,800.0,manual,0.76,...,,0.8950,99.44,82.7,[20.39999962 26.39999962 25. ],[[3199.]],21:08.5,earthquake_local,1,9.706561e-07
4,B082.PB_20150914105733_EV,PB,B082,HH,33.598182,-116.596005,1374.8,799.0,autopicker,0.92,...,,0.1618,17.99,123.9,[28.79999924 30.29999924 28.70000076],[[1516.]],57:35.0,earthquake_local,1,9.706561e-07
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,XAN.IC_20180116143912_NO,IC,XAN,BH,34.031300,108.923700,630.0,,,,...,,,,,,,2018-01-16 14:39:12,noise,6,4.247619e-06
1996,OK048.GS_20180115012442_NO,GS,OK048,HH,36.416220,-96.943740,295.0,,,,...,,,,,,,2018-01-15 01:24:42,noise,6,4.247619e-06
1997,A11.CN_20180115103824_NO,CN,A11,HH,47.243100,-70.196900,55.0,,,,...,,,,,,,2018-01-15 10:38:24,noise,6,4.247619e-06
1998,CMOB.NC_201606271215_NO,NC,CMOB,HN,37.810420,-121.802720,743.0,,,,...,,,,,,,2016-06-27 12:15:00,noise,6,4.247619e-06


In [8]:
# Column names of the metadata frame (DataFrame.keys() returns the columns Index).
# NOTE(review): this cell appears above the label-creation cell but was executed after
# it (In[8] vs In[7]), which is why the recorded output lists 'label'. On a fresh
# Restart & Run All, 'label' will not be present here yet.
metadata.keys()

Index(['trace_name', 'network_code', 'receiver_code', 'receiver_type',
       'receiver_latitude', 'receiver_longitude', 'receiver_elevation_m',
       'p_arrival_sample', 'p_status', 'p_weight', 'p_travel_sec',
       's_arrival_sample', 's_status', 's_weight', 'source_id',
       'source_origin_time', 'source_origin_uncertainty_sec',
       'source_latitude', 'source_longitude', 'source_error_sec',
       'source_gap_deg', 'source_horizontal_uncertainty_km', 'source_depth_km',
       'source_depth_uncertainty_km', 'source_magnitude',
       'source_magnitude_type', 'source_magnitude_author',
       'source_mechanism_strike_dip_rake', 'source_distance_deg',
       'source_distance_km', 'back_azimuth_deg', 'snr_db', 'coda_end_sample',
       'trace_start_time', 'trace_category', 'chunk', 'weight_for_subsample',
       'label'],
      dtype='object')

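### Create binary labels for classifier training

Each trace is labelled `1` if its `trace_category` is `earthquake_local` and `0` otherwise (e.g. `noise`).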

In [7]:
# Binary target for the classifier: 1 for 'earthquake_local' traces, 0 for every other
# trace_category (the other category visible in this subsample is 'noise').
# A vectorised comparison replaces the row-by-row Python loop. Missing categories
# compare unequal and map to 0, as before, and the column stays int64.
POSITIVE_CATEGORY = 'earthquake_local'
metadata['label'] = (metadata['trace_category'] == POSITIVE_CATEGORY).astype('int64')
classifier_labels = metadata['label'].to_numpy()

In [9]:
# Display the label vector; np.asarray returns the existing ndarray unchanged.
np.asarray(classifier_labels)

array([1, 1, 1, ..., 0, 0, 0])

### Parameter setup

In [None]:
# Keyword arguments forwarded downstream for each input representation.
# NOTE(review): spectrogram_imgs has shape (N, 200, 300, 3), i.e. 200 rows x 300 columns,
# so (300, 200) presumably means (width, height) — confirm against the ordering the
# consuming code expects.
spectrogram_kwargs = dict(image_size=(300, 200))
waveform_kwargs = dict(image_size=(400, 200))