Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Imports

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
import os
import gin
import tensorflow as tf
import joblib
import keras
from eq_mag_prediction.forecasting import encoders
from eq_mag_prediction.scripts import magnitude_predictor_trainer
from eq_mag_prediction.forecasting import training_examples
from eq_mag_prediction.forecasting import one_region_model
from eq_mag_prediction.utilities import catalog_analysis
from eq_mag_prediction.forecasting.training_examples import CatalogDomain
from eq_mag_prediction.utilities import data_utils, catalog_filters
from eq_mag_prediction.forecasting import metrics
from eq_mag_prediction.forecasting.training_examples import CatalogDomain
from eq_mag_prediction.forecasting.data_sources import target_catalog
from eq_mag_prediction.forecasting.encoders import SeismicityRateEncoder, RecentEarthquakesEncoder, BiggestEarthquakesEncoder, CatalogColumnsEncoder



# Loading and experiment setting

## Load model

In [None]:
# Name of the trained model to analyse. It selects the sub-directory under
# results/trained_models/ from which the model, gin config and scalers are read.
# MODEL_NAME = 'Hauksson'
MODEL_NAME = 'JMA'

In [None]:
# Directory holding the artifacts of the trained model selected by MODEL_NAME.
experiment_dir = os.path.join(os.getcwd(), '..', 'results/trained_models/', MODEL_NAME)

# `_repeat` is registered as a custom object so that Keras can resolve it by
# name while deserializing the saved model.
custom_objects = {
    '_repeat': encoders._repeat,
}

# compile=False skips restoring the training configuration; in this notebook
# the model is only used through `predict`.
loaded_model = tf.keras.models.load_model(
    os.path.join(experiment_dir, 'model'),
    custom_objects=custom_objects,
    compile=False,
)

## Set gin

In [None]:
# Read the experiment's gin file into memory, then apply it while the gin
# config is unlocked. skip_unknown=True is forwarded to gin.parse_config
# unchanged.
gin_config_path = os.path.join(experiment_dir, 'config.gin')
with open(gin_config_path) as config_file:
  gin_config_text = config_file.read()

with gin.unlock_config():
  gin.parse_config(gin_config_text, skip_unknown=True)

In [None]:
# Print gin's current configuration string so the parsed bindings can be
# inspected before they are used.
print(gin.config_str())

In [None]:
# Finalize the gin configuration now that all config text has been parsed.
gin.finalize()

# Prepare and predict

## Preparation

In [None]:
# The catalog domain is created without explicit arguments -- presumably its
# parameters come from the gin bindings parsed earlier (TODO confirm).
domain = CatalogDomain()
# Magnitude labels for the examples of this domain.
labels = training_examples.magnitude_prediction_labels(domain)

## Compute and construct features

### Encoders

In [None]:
# Directory of this model's scalers (passed to the feature-loading step below).
scaler_saving_dir = os.path.join(os.getcwd(), '..', f'results/trained_models/{MODEL_NAME}/scalers')

# `labels` was already computed from `domain` in an earlier cell, so it is not
# recomputed here; only the encoders are built.
all_encoders = one_region_model.build_encoders(domain)

### Construct and save

In [None]:

# Compute the features, scalers and encoders for `domain` and cache them.
# force_recalculate=True presumably bypasses any previously cached result
# (TODO confirm against one_region_model).
one_region_model.compute_and_cache_features_scaler_encoder(
    domain, all_encoders, force_recalculate=True
)

### Reload features

In [None]:

# Load the features and build the per-encoder models.
features_and_models = one_region_model.load_features_and_construct_models(
    domain, all_encoders, scaler_saving_dir
)

# Split indices 0, 1 and 2 are bound to the train, validation and test
# feature sets respectively.
train_features, valid_features, test_features = (
    one_region_model.features_in_order(features_and_models, split_index)
    for split_index in range(3)
)

## Predict

In [None]:
# Run the loaded model on the test-split features; the printed shape is a
# quick sanity check of the forecast array.
test_forecasts = loaded_model.predict(test_features)
print(test_forecasts.shape)

## Plot predictions

In [None]:
#@title utility function
def to_rgb_string(rgb_list):
  """Formats color components as an ``rgb(...)`` string.

  Args:
    rgb_list: Iterable of color components. Each component is rendered with
      its default string formatting.

  Returns:
    The components joined by ', ' and wrapped in 'rgb(' and ')', e.g.
    'rgb(1, 2, 3)'. An empty input yields 'rgb()'.
  """
  # Joining avoids slicing off a trailing separator, which used to eat into
  # the 'rgb(' prefix (producing 'rg)') when no components were given.
  return 'rgb(' + ', '.join(f'{n}' for n in rgb_list) + ')'

In [None]:
# Density family used to interpret the model output, and the matching loss.
probability_density_function = metrics.kumaraswamy_mixture_instance
LOSS = metrics.MinusLoglikelihoodLoss(probability_density_function, domain.magnitude_threshold)

# Take shift/stretch from the loss when it defines them; otherwise fall back
# to a shift of 0 and a stretch of 7.
random_var_shift = getattr(LOSS, 'shift', 0)
random_var_stretch = getattr(LOSS, 'stretch', 7)


def costum_shift_stretch(x, random_var_shift=random_var_shift, random_var_stretch=random_var_stretch):
  """Shifts and rescales `x`, capping the result at 1.

  The defaults capture the shift/stretch values at definition time.
  """
  return np.minimum((x - random_var_shift) / random_var_stretch, 1)


shift_strech_input = costum_shift_stretch


# Beta estimated on the training labels ('BPOS' estimator) and the magnitude
# threshold of the domain; both are printed for reference.
BETA_OF_TRAIN_SET = catalog_analysis.estimate_beta(labels.train_labels, None, 'BPOS')
print(BETA_OF_TRAIN_SET)
MAG_THRESH = domain.magnitude_threshold
print(MAG_THRESH)

In [None]:
#--- setup data
plot_above_thresh = MAG_THRESH
# 500 evenly spaced magnitudes from the threshold up to 7.
m_vec = np.linspace(MAG_THRESH, 7, 500)
prob_density_inst = probability_density_function(test_forecasts)
# Evaluate the density on the shifted/stretched magnitudes (as a column, so
# it broadcasts against the forecasts) and divide by the stretch to express
# it per unit of magnitude.
scaled_m_column = (m_vec[:, None] - random_var_shift) / random_var_stretch
prob_vecs = prob_density_inst.prob(scaled_m_column) / random_var_stretch


In [None]:
# Keep only the test examples whose label is at or above the plot threshold.
above_thresh_mask = labels.test_labels >= plot_above_thresh
test_labels_to_plot_from = labels.test_labels[above_thresh_mask]
prob_vecs_to_plot_from = prob_vecs.numpy()[:, above_thresh_mask]

# Sampling weights proportional to exp(beta * magnitude), normalized to sum
# to one.
p_for_mags = np.exp(BETA_OF_TRAIN_SET*test_labels_to_plot_from)
p_for_mags /= p_for_mags.sum()
# rnd_seed = np.random.RandomState(seed=1902) # nice preview for socal hauksson
rnd_seed = np.random.RandomState(seed=1000)
# Sampling without replacement cannot draw more items than are available, so
# cap the number of curves at the number of candidate examples.
num_curves_to_plot = min(100, prob_vecs_to_plot_from.shape[1])
label_idxs_to_plot = np.sort(
    rnd_seed.choice(
        prob_vecs_to_plot_from.shape[1],
        num_curves_to_plot,
        replace=False,
        p=p_for_mags,
    )
)
labels_to_plot = test_labels_to_plot_from[label_idxs_to_plot]
mpl.rcParams.update({'font.size': 16})

In [None]:
#@title setup figure
# Magnitude scale used to color-code the curves.
num_mags = 25
min_mag = 2
max_mag = 6.5
m_scale = np.linspace(min_mag-0.01, max_mag, num_mags)
norm_inst = plt.Normalize(min_mag, max_mag)

chosen_colormap = plt.cm.gist_stern_r
colors = chosen_colormap(np.linspace(0,1,num_mags))
# Each curve takes the color of the m_scale entry closest to its label.
colors2plot = colors[np.argmin(np.abs(labels_to_plot[:,None] - m_scale[None,:]), axis=1)]


f_dist_fig, ax_dist_fig = plt.subplots(1,1)

# Set to True to annotate the peak of every curve with its label value.
add_text = False
for idx, lbl_index in enumerate(label_idxs_to_plot):
  curve = prob_vecs_to_plot_from[:, lbl_index]
  ax_dist_fig.plot(m_vec, curve, alpha=0.4, color=colors2plot[idx], linewidth=4)

  if add_text:
    # add text
    peak_position = np.argmax(curve)
    ax_dist_fig.text(m_vec[peak_position], curve[peak_position], str(labels_to_plot[idx]))

# plot GR train set
train_gr_curve = metrics.gr_likelihood(m_vec, BETA_OF_TRAIN_SET, MAG_THRESH)
gr_handle = ax_dist_fig.plot(m_vec, train_gr_curve, 'k--', label='train_gr_likelihood', linewidth=3)
ax_dist_fig.legend(handles=gr_handle, frameon=False)

# Colorbar that maps the curve colors back to label magnitudes; without it
# the color coding of the curves cannot be read from the figure.
sm = plt.cm.ScalarMappable(cmap=chosen_colormap, norm=norm_inst)
sm.set_array([])  # required by older matplotlib versions before drawing a colorbar
f_dist_fig.colorbar(sm, ax=ax_dist_fig, label='label magnitude');

ax_dist_fig.set_xlabel('magnitude')
ax_dist_fig.set_ylabel('p(magnitude)')
ax_dist_fig.set_xscale('linear')
f_dist_fig