## Model Training: Vision Transformer

### Environment setup

In [45]:
import gc
import sys
sys.path.append('../')
sys.path.append('../../../')

import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from PIL import Image
from transformers import ViTImageProcessor, TFViTForImageClassification

import earthquake_detection.architectures as architectures
import earthquake_detection.training_utils as training_utils

RuntimeError: Failed to import transformers.models.vit.modeling_tf_vit because of the following error (look up to see its traceback):
Your currently installed version of Keras is Keras 3, but this is not yet supported in Transformers. Please install the backwards-compatible tf-keras package with `pip install tf-keras`.

### Load datasets

In [11]:
# Raw signal array extracted from STEAD (2000-trace subsample).
raw_signals = np.load(
    '../../../data/STEAD/extracted_raw_signals_subsample_2000.npy'
)

# Spectrogram image array created for the same subsample.
spectrogram_imgs = np.load(
    '../../../data/STEAD/created_spectrogram_images_subsample_2000.npy'
)

# Metadata table; reset_index() moves the frame's index into a regular
# column and installs a fresh 0..n-1 index.
metadata = pd.read_feather(
    '../../../data/STEAD/extracted_metadata_subsample_2000.feather'
).reset_index()

#### Inspect data to ensure it looks as expected

In [12]:
# Sanity check: dimensions of the loaded raw-signal array.
raw_signals.shape

(2000, 6000, 3)

In [13]:
# Sanity check: dimensions of the loaded spectrogram-image array.
spectrogram_imgs.shape

(2000, 200, 300, 3)

In [14]:
# Display the metadata frame (the notebook rendering elides middle rows/columns).
metadata

Unnamed: 0,trace_name,network_code,receiver_code,receiver_type,receiver_latitude,receiver_longitude,receiver_elevation_m,p_arrival_sample,p_status,p_weight,...,source_mechanism_strike_dip_rake,source_distance_deg,source_distance_km,back_azimuth_deg,snr_db,coda_end_sample,trace_start_time,trace_category,chunk,weight_for_subsample
0,B084.PB_20111212104350_EV,PB,B084,EH,33.611570,-116.456370,1271.0,800.0,manual,0.55,...,,0.5746,63.90,178.5,[15.39999962 15.5 17.10000038],[[3199.]],43:51.3,earthquake_local,1,9.706561e-07
1,B086.PB_20080618145426_EV,PB,B086,EH,33.557500,-116.531000,1392.0,800.0,manual,0.59,...,,0.5071,56.40,330.6,[45. 46.20000076 46.29999924],[[2230.]],54:27.0,earthquake_local,1,9.706561e-07
2,B023.PB_20130513182210_EV,PB,B023,EH,46.111200,-123.078700,177.4,500.0,manual,0.63,...,,0.9683,107.61,163.1,[12.60000038 12.10000038 11.19999981],[[3199.]],22:11.5,earthquake_local,1,9.706561e-07
3,B011.PB_20120724202107_EV,PB,B011,EH,48.649543,-123.448192,22.0,800.0,manual,0.76,...,,0.8950,99.44,82.7,[20.39999962 26.39999962 25. ],[[3199.]],21:08.5,earthquake_local,1,9.706561e-07
4,B082.PB_20150914105733_EV,PB,B082,HH,33.598182,-116.596005,1374.8,799.0,autopicker,0.92,...,,0.1618,17.99,123.9,[28.79999924 30.29999924 28.70000076],[[1516.]],57:35.0,earthquake_local,1,9.706561e-07
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,XAN.IC_20180116143912_NO,IC,XAN,BH,34.031300,108.923700,630.0,,,,...,,,,,,,2018-01-16 14:39:12,noise,6,4.247619e-06
1996,OK048.GS_20180115012442_NO,GS,OK048,HH,36.416220,-96.943740,295.0,,,,...,,,,,,,2018-01-15 01:24:42,noise,6,4.247619e-06
1997,A11.CN_20180115103824_NO,CN,A11,HH,47.243100,-70.196900,55.0,,,,...,,,,,,,2018-01-15 10:38:24,noise,6,4.247619e-06
1998,CMOB.NC_201606271215_NO,NC,CMOB,HN,37.810420,-121.802720,743.0,,,,...,,,,,,,2016-06-27 12:15:00,noise,6,4.247619e-06


In [15]:
# List the column names available in the metadata frame.
metadata.columns

Index(['trace_name', 'network_code', 'receiver_code', 'receiver_type',
       'receiver_latitude', 'receiver_longitude', 'receiver_elevation_m',
       'p_arrival_sample', 'p_status', 'p_weight', 'p_travel_sec',
       's_arrival_sample', 's_status', 's_weight', 'source_id',
       'source_origin_time', 'source_origin_uncertainty_sec',
       'source_latitude', 'source_longitude', 'source_error_sec',
       'source_gap_deg', 'source_horizontal_uncertainty_km', 'source_depth_km',
       'source_depth_uncertainty_km', 'source_magnitude',
       'source_magnitude_type', 'source_magnitude_author',
       'source_mechanism_strike_dip_rake', 'source_distance_deg',
       'source_distance_km', 'back_azimuth_deg', 'snr_db', 'coda_end_sample',
       'trace_start_time', 'trace_category', 'chunk', 'weight_for_subsample'],
      dtype='object')

### Create labels for classification model training

In [16]:
# Binary target: 1 where trace_category is 'earthquake_local', 0 for every
# other value. The comparison is vectorised rather than looped in Python, and
# the explicit int64 cast gives a platform-independent integer dtype.
metadata['label'] = metadata['trace_category'].eq('earthquake_local').astype('int64')

# Plain NumPy view of the labels for the dataset-preparation step.
classifier_labels = metadata['label'].to_numpy()

In [17]:
# Show the label array to confirm it holds 0/1 integers.
classifier_labels

array([1, 1, 1, ..., 0, 0, 0])

### Parameter setup

In [24]:
# Hugging Face checkpoint whose image-preprocessing configuration is loaded.
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
image_processor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)

In [25]:
# Echo the processor object to show the configuration it was loaded with.
image_processor

ViTImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

## Model training & evaluation

In [38]:
# Run each spectrogram image through the ViT image processor, one at a time,
# requesting TensorFlow tensors.
# `padding` is a tokenizer argument that ViTImageProcessor does not define,
# so it is not passed here.
# NOTE(review): each list element is the processor's output object for a
# single image, not a bare tensor — confirm this is the format
# `training_utils.prepare_datasets` expects for `imgs`.
preproc_imgs = [
    image_processor(img, return_tensors='tf')
    for img in spectrogram_imgs
]

In [43]:
# Hand the preprocessed images and the binary labels to the project helper
# (called with batch_size=32), then unpack its three results into the
# train / validation / test dataset names.
dataset_splits = training_utils.prepare_datasets(
    imgs=preproc_imgs,
    labels=classifier_labels,
    batch_size=32,
)
train_dataset_c, val_dataset_c, test_dataset_c = dataset_splits

In [44]:
# Load the pretrained ViT checkpoint with a classification head sized for
# two labels (num_labels=2).
model = TFViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=2,
)

# The TF* classes in transformers are already Keras models and expose no
# `to_tf()` method, so the loaded model is used directly for compile/fit.
vit_model = model

ImportError: 
ViTForImageClassification requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TFViTForImageClassification".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.


In [None]:
# Build the optimizer and loss up front so the compile call stays short.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
# from_logits=True: the loss is given raw logits rather than probabilities.
sparse_ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

vit_model.compile(
    optimizer=adam_optimizer,
    loss=sparse_ce_loss,
    metrics=['accuracy'],
)