Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gin
import re
import tensorflow as tf
import datetime
from eq_mag_prediction.forecasting import encoders
from eq_mag_prediction.scripts import magnitude_predictor_trainer
from eq_mag_prediction.forecasting import training_examples
from eq_mag_prediction.forecasting import one_region_model
from eq_mag_prediction.utilities import catalog_analysis
from eq_mag_prediction.forecasting.training_examples import CatalogDomain
from eq_mag_prediction.utilities import data_utils, catalog_filters
from eq_mag_prediction.forecasting import metrics
from eq_mag_prediction.forecasting.training_examples import CatalogDomain
from eq_mag_prediction.forecasting.data_sources import target_catalog
from eq_mag_prediction.forecasting.encoders import SeismicityRateEncoder, RecentEarthquakesEncoder, BiggestEarthquakesEncoder, CatalogColumnsEncoder



In [None]:
from IPython.display import display, Markdown
def printmd(string):
    """Render `string` as Markdown in the notebook output area."""
    markdown_obj = Markdown(string)
    display(markdown_obj)

# Loading and experiment settings

## Load model

In [None]:
# Name of the trained model to analyse. It selects the sub-directory of
# results/trained_models/ that is loaded and, by substring match, the catalog
# and major-earthquake list used further down. Keep exactly one line active.
MODEL_NAME = 'Hauksson'
# MODEL_NAME = 'JMA'
# MODEL_NAME = 'GeoNet_NZ'

In [None]:
# Directory holding the saved model and its gin config for the chosen model.
experiment_dir = os.path.join(os.getcwd(), '..', 'results/trained_models/', MODEL_NAME)

# Custom callables referenced by the saved model; Keras needs them to be
# supplied explicitly in order to deserialize it.
custom_objects = {
    '_repeat': encoders._repeat,
}

# compile=False: the model is loaded without being compiled.
loaded_model = tf.keras.models.load_model(
    os.path.join(experiment_dir, 'model'),
    custom_objects=custom_objects,
    compile=False,
    # safe_mode=True
)

## Set gin

In [None]:
# Read the gin configuration stored next to the trained model, then apply it.
# The config is unlocked only for the duration of the parse, and bindings for
# configurables unknown to this session are skipped.
config_path = os.path.join(experiment_dir, 'config.gin')
with open(config_path) as config_file:
  config_text = config_file.read()
with gin.unlock_config():
  gin.parse_config(config_text, skip_unknown=True)

In [None]:
# Print gin's current configuration string for inspection.
print(gin.config_str())

In [None]:
# Finalize the gin configuration now that the config file has been parsed.
# NOTE(review): gin is expected to lock the config after this call, so later
# changes would need `gin.unlock_config()` - confirm against the gin docs.
gin.finalize()

In [None]:
# Build the catalog domain and the magnitude labels derived from it.
# CatalogDomain() is called without arguments - presumably its parameters come
# from the gin bindings; verify against training_examples.
domain = training_examples.CatalogDomain()
labels = training_examples.magnitude_prediction_labels(domain)

# Set relevant probability density and other definitions

In [None]:
# Recover the stretch of the probability-density support from the gin config
# text, falling back to 7 when the binding is absent.
gin_config = gin.config_str()

# Accept integer, decimal and exponent literals. Matching digits only would
# silently truncate a value such as `7.5` to `7`.
match = re.search(
    r'train_and_evaluate_magnitude_prediction_model\.pdf_support_stretch = '
    r'(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)',
    gin_config,
)
stretch = float(match.group(1)) if match else 7.0

In [None]:
# Density family used to interpret the model outputs, and the loss object that
# carries the shift/stretch applied to the magnitudes.
probability_density_function = metrics.kumaraswamy_mixture_instance
LOSS = metrics.MinusLoglikelihoodConstShiftStretchLoss(
    probability_density_function,
    domain.magnitude_threshold,
    stretch,
)

# Fall back to shift=0 / stretch=7 when the loss object does not expose them.
random_var_shift = getattr(LOSS, 'shift', 0)
random_var_stretch = getattr(LOSS, 'stretch', 7)


def costum_shift_stretch(x, random_var_shift=random_var_shift, random_var_stretch=random_var_stretch):
  """Shift and scale `x`, capping the result at 1."""
  scaled = (x - random_var_shift) / random_var_stretch
  return np.minimum(scaled, 1)


shift_strech_input = costum_shift_stretch

# Beta estimated from the training labels, and the magnitude threshold of the
# domain; both are printed for reference.
MAG_THRESH = domain.magnitude_threshold
BETA_OF_TRAIN_SET = catalog_analysis.estimate_beta(labels.train_labels, None, 'BPOS')
print(BETA_OF_TRAIN_SET)
print(MAG_THRESH)

In [None]:
# Which data split to analyse. It is used further down to pick the matching
# `*_timestamps` array, e.g. 'train', 'validation' or 'test'.
SET_NAME = 'test'

In [None]:
# Number of seconds in one day; used to express time margins in seconds.
DAY_TO_SECONDS = 24 * 60 * 60
# Magnitude threshold of the configured domain.
MAG_THRESH = domain.magnitude_threshold
# Beta estimated from the training labels with the 'BPOS' method.
# Alternative estimator kept for reference:
# BETA_OF_TRAIN_SET = catalog_analysis.estimate_beta(labels.train_labels, domain.magnitude_threshold, 'MLE')
BETA_OF_TRAIN_SET = catalog_analysis.estimate_beta(labels.train_labels, None, 'BPOS')

In [None]:
# Timestamps (the keys of each examples mapping) per data split, plus their
# concatenation in train / validation / test order.
train_timestamps, validation_timestamps, test_timestamps = (
    np.array(list(split_examples.keys()))
    for split_examples in (
        domain.train_examples,
        domain.validation_examples,
        domain.test_examples,
    )
)
all_timestamps = np.concatenate(
    (train_timestamps, validation_timestamps, test_timestamps)
)

In [None]:

# Show the first and last timestamp of every data split as Markdown.
# NOTE(review): `fromtimestamp` without a tz argument yields naive local time;
# this assumes the timestamps are epoch seconds - confirm.
for split_label, split_timestamps in (
    ('train', train_timestamps),
    ('validation', validation_timestamps),
    ('test', test_timestamps),
):
  split_start = datetime.datetime.fromtimestamp(split_timestamps.min())
  split_end = datetime.datetime.fromtimestamp(split_timestamps.max())
  # Formatting a datetime in an f-string with no format spec uses str().
  printmd(f'<h3>{split_label} time:</h3> {split_start}  -  {split_end}')

In [None]:
# Resolve the timestamps of the split chosen by SET_NAME through an explicit
# mapping rather than a variable-name lookup, so that an unsupported name fails
# with a clear message.
timestamps_by_set = {
    'train': train_timestamps,
    'validation': validation_timestamps,
    'test': test_timestamps,
    'all': all_timestamps,
}
if SET_NAME not in timestamps_by_set:
  raise KeyError(
      f'Unknown SET_NAME {SET_NAME!r}; expected one of {sorted(timestamps_by_set)}'
  )
set_timestamps = timestamps_by_set[SET_NAME]

# Get major earthquakes in catalog

In [None]:
# Pick the earthquake catalog and the matching list of major earthquakes from
# substrings of MODEL_NAME (case-insensitive).
model_name_lower = MODEL_NAME.lower()
if any(x in model_name_lower for x in ['jma', 'japan']):
  major_eq_df = data_utils.japan_major_earthquakes_dataframe()
  loaded_catalog = data_utils.jma_dataframe()
elif any(x in model_name_lower for x in ['scsn', 'hauksson', 'cali']):
  major_eq_df = data_utils.california_major_earthquakes_dataframe()
  loaded_catalog = data_utils.hauksson_dataframe()
elif any(x in model_name_lower for x in ['nz', 'geonet']):
  major_eq_df = data_utils.nz_major_earthquakes_dataframe()
  loaded_catalog = data_utils.nz_geonet_dataframe()
else:
  # Fail here with a clear message instead of a NameError further down.
  raise ValueError(
      f'Cannot infer a catalog from MODEL_NAME={MODEL_NAME!r}; expected it to '
      "contain one of: 'jma', 'japan', 'scsn', 'hauksson', 'cali', 'nz', 'geonet'."
  )

# Keep only the catalog rows whose time appears among the selected set's
# timestamps. Filtering before copying avoids duplicating the full catalog.
rows_to_keep = np.isin(loaded_catalog.time.values, set_timestamps)
set_catalog = loaded_catalog.iloc[rows_to_keep, :].copy()


In [None]:
# Keep the major earthquakes that occurred within the time span of the
# selected set, widened by one day on each side.
window_start = set_timestamps.min() - DAY_TO_SECONDS
window_end = set_timestamps.max() + DAY_TO_SECONDS
after_start = major_eq_df.time >= window_start
before_end = major_eq_df.time <= window_end
time_logical = after_start & before_end
relevant_major_eq = major_eq_df[time_logical]

In [None]:
# Display the major earthquakes that fall within the selected time window.
relevant_major_eq

In [None]:
# Tolerances used to match a listed major earthquake to a catalog row.
time_margin = 2*DAY_TO_SECONDS  # seconds
space_margin = 4  # same units as the catalog longitude/latitude columns
mag_margin = 0.8


def _within_margin(values, center, margin):
  """Return a boolean mask of `values` inside [center - margin, center + margin]."""
  return (center - margin <= values) & (center + margin >= values)


# For every listed major earthquake, find the largest-magnitude catalog event
# within the time/space/magnitude tolerances and record both descriptions.
all_event_details = []
for event in relevant_major_eq.itertuples():
  time_logical = _within_margin(set_catalog.time, event.time, time_margin)
  mag_logical = _within_margin(set_catalog.magnitude, event.magnitude, mag_margin)
  total_logical = time_logical & mag_logical
  # A non-finite reference coordinate places no constraint on that axis.
  if np.isfinite(event.longitude):
    lon_logical = _within_margin(set_catalog.longitude, event.longitude, space_margin)
    total_logical = total_logical & lon_logical
  if np.isfinite(event.latitude):
    lat_logical = _within_margin(set_catalog.latitude, event.latitude, space_margin)
    total_logical = total_logical & lat_logical

  # Vectorized reduction; the builtin any() would iterate the Series in Python.
  if not total_logical.any():
    print(f'NOTICE!!! no event found for {event.name} magnitude {event.magnitude}')
    continue

  # Select positionally so the result is correct even if index labels repeat.
  candidate_positions = np.flatnonzero(total_logical.to_numpy())
  candidates = set_catalog.iloc[candidate_positions]
  best_candidate = candidates.magnitude.to_numpy().argmax()
  event_index = candidates.index[best_candidate]
  # Work on an explicit copy so the added fields cannot touch `set_catalog`.
  event_details = candidates.iloc[best_candidate].copy()
  event_details['name'] = event.name
  event_details['index_in_set'] = candidate_positions[best_candidate]
  event_details['wiki_magnitude'] = event.magnitude
  event_details['wiki_longitude'] = event.longitude
  event_details['wiki_latitude'] = event.latitude
  event_details['catalog_date'] = datetime.datetime.fromtimestamp(event_details.time).strftime('%Y-%m-%d %H:%M')
  event_details['wiki_date'] = datetime.datetime.fromtimestamp(event.time).strftime('%Y-%m-%d %H:%M')
  all_event_details.append(event_details)
all_event_details = pd.DataFrame(all_event_details)

# Plot specific distributions

In [None]:
# Display the catalog rows that were matched to the major earthquakes.
all_event_details

In [None]:

# Directory of the scalers belonging to the selected model.
scaler_saving_dir = os.path.join(os.getcwd(), '..', f'results/trained_models/{MODEL_NAME}/scalers')
# Encoders built for the configured domain.
all_encoders = one_region_model.build_encoders(domain)

# Compute and cache the features, scalers and encoders. force_recalculate=False
# presumably reuses an existing cache - verify in one_region_model.
# NOTE(review): `scaler_saving_dir` is not passed to this call, although it is
# passed to `load_features_and_construct_models` later in the notebook -
# confirm that both resolve to the same directory.
one_region_model.compute_and_cache_features_scaler_encoder(
    domain,
    all_encoders,
    force_recalculate = False,
)

In [None]:
# Load the cached features (and their models) using the scalers directory,
# then extract the ordered features for positions 0, 1 and 2, which are bound
# to the train, validation and test names respectively.
features_and_models = one_region_model.load_features_and_construct_models(
    domain,
    all_encoders,
    scaler_saving_dir,
)
train_features, valid_features, test_features = (
    one_region_model.features_in_order(features_and_models, split_position)
    for split_position in range(3)
)

In [None]:
# Magnitude grid for plotting: 500 points from the threshold up to 90% of
# (threshold + stretch).
plot_above_thresh = MAG_THRESH
grid_upper = 0.9*(MAG_THRESH+LOSS.stretch)
m_vec = np.linspace(MAG_THRESH, grid_upper, 500)

# Model outputs for the test features, wrapped as probability densities.
# NOTE(review): the forecasts are always computed on the test features, while
# the events plotted later are indexed within the set chosen by SET_NAME -
# confirm the two agree when SET_NAME is not 'test'.
test_forecasts = loaded_model.predict(test_features)
prob_density_inst = probability_density_function(test_forecasts)


In [None]:
# One figure per matched major earthquake:
#   - red line: the model's predicted magnitude density for that event,
#   - dashed black line: `metrics.gr_likelihood` evaluated with the beta of the
#     training set (presumably the Gutenberg-Richter reference - confirm),
#   - open circle: that reference curve at the event's catalog magnitude,
#   - star: the model density at the event's catalog magnitude.
for event in all_event_details.itertuples():
  # Look the event's distribution up once and bind it as a default, so the
  # helper does not depend on the loop variable at call time.
  event_density = prob_density_inst[event.index_in_set]

  def prob_func(m0, event_density=event_density):
    """Model density at magnitude `m0`, mapped back from the scaled variable."""
    return event_density.prob((m0 - random_var_shift)/random_var_stretch)/random_var_stretch

  prob_vec = prob_func(m_vec)

  f_major, ax_major = plt.subplots(1,1, figsize=(5,5))
  p = ax_major.plot(m_vec, prob_vec, alpha=0.4, color='r', linewidth=4)
  train_gr_curve = metrics.gr_likelihood(m_vec, BETA_OF_TRAIN_SET, MAG_THRESH)
  gr_handle = ax_major.plot(m_vec, train_gr_curve, 'k--', label='train_gr_likelihood', linewidth=3)
  eq_true_mag = metrics.gr_likelihood(np.array([event.magnitude]), BETA_OF_TRAIN_SET, MAG_THRESH)
  ax_major.scatter([event.magnitude], eq_true_mag, s=100, marker='o', edgecolors='k', c='none', linewidths=2)
  ax_major.scatter([event.magnitude], [prob_func(event.magnitude).numpy()], s=500, marker='*')
  title = f'{event.name}, {event.magnitude:.2f}\n{datetime.datetime.fromtimestamp(event.time).strftime("%d-%b-%y")}'
  ax_major.set_title(title, fontsize=15)
  ax_major.set_yscale('log')
  #-- set xticks: a single red tick at the event's catalog magnitude
  ax_major.set_xticks([event.magnitude], [f'{event.magnitude:.2f}'], color='red', size=18)
  #-- set yticks: hide them, only the relative shape of the curves matters
  ax_major.set_yticks([])
  ax_major.tick_params(
    axis='y',          # changes apply to the y-axis
    which='both',      # both major and minor ticks are affected
    left=False,        # ticks along the left edge are off
    right=False,       # ticks along the right edge are off
  )
