Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# imports

In [None]:
import pandas as pd
import numpy as np
import numbers
import xarray as xr
from sklearn import metrics as skl_metrics
import scipy as sp
from scipy import stats
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.colors as mpl_colors


#--- for printing formatted text
from IPython.display import display, Markdown
def printmd(string):
    """Render `string` as Markdown in the notebook output area."""
    rendered = Markdown(string)
    display(rendered)

import os
import datetime
from absl import flags
import time
from matplotlib import pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import FixedLocator
import logging

import gin
# Switch gin to its interactive mode (presumably so re-running cells that
# re-register configurables does not fail -- confirm against the gin docs).
gin.enter_interactive_mode()

# Print arrays with 4 decimals; summarise only arrays with more than 2500 elements.
np.set_printoptions(precision=4, threshold=2500)


In [None]:
from eq_mag_prediction.scripts import calculate_benchmark_gr_properties
from eq_mag_prediction.scripts import magnitude_predictor_trainer   # import unused for gin config
from eq_mag_prediction.forecasting import metrics, training_examples
from eq_mag_prediction.forecasting import encoders
from eq_mag_prediction.forecasting import one_region_model
from eq_mag_prediction.utilities import geometry
from eq_mag_prediction.utilities import statistics_utils as statistics
from eq_mag_prediction.utilities import catalog_analysis


# Plotting settings

## Colors settings

In [None]:
# @title Create warm-cool gradient colormap

# Hex colour stops of the gradient, listed from dark blue through pale yellow
# to dark red.
listed_colors_discrete_warm_cool =[
    '#081d58',
    '#253494',
    '#225ea8',
    '#1d91c0',
    '#41b6c4',
    '#7fcdbb',
    '#c7e9b4',
    '#edf8b1',
    '#ffffcc',
    '#ffeda0',
    '#fed976',
    '#feb24c',
    '#fd8d3c',
    '#fc4e2a',
    '#e31a1c',
    '#bd0026',
    '#800026',
]

# Continuous colormap interpolated between the stops above.
# NOTE(review): the name reads like a typo of "warm_cold"; it is kept because
# it is a notebook-level global.
warn_cold_cmap = mpl_colors.LinearSegmentedColormap.from_list('warn_cold_cmap', listed_colors_discrete_warm_cool)
# Bare expression: displays the colormap as the cell output.
warn_cold_cmap

In [None]:
#@title Discrete colors

# Qualitative palette; also used for the matplotlib colour cycle and for the
# per-model colours further down the notebook.
listed_colors_discrete = [
    '#e41a1c',
    '#377eb8',
    '#4daf4a',
    '#984ea3',
    '#ff7f00',
    '#ffff33',
    '#f0027f',
]

# The same palette wrapped as a discrete colormap.
color_list_cmap = mpl_colors.ListedColormap(listed_colors_discrete)
# Bare expression: displays the colormap as the cell output.
color_list_cmap


## Plot properties

In [None]:
#@title Plot lines res font etc...

FONT_SIZE = 8     #@param{type:'number'}
TITLE_SIZE = 6    #@param{type:'number'}

MARKER_SIZE = 4
AXIS_LINE = 0.5     #@param{type:'number'}
PLOT_LINE_WIDTH = 2 #@param{type:'number'}

FIG_WIDTH = 3 #@param{type:'number'}
FIG_HEIGHT = 3 #@param{type:'number'}
light_grey = '#EDEDED'

figure_dpi = 150 #@param{type:'number'}

# Matplotlib defaults applied to every figure created after this cell.
rc_params = {'axes.linewidth': AXIS_LINE,
             'axes.titlesize': TITLE_SIZE,     # font size of the axes title
             'axes.labelsize': FONT_SIZE,      # font size of the x and y labels
             'axes.prop_cycle': mpl.cycler(color=listed_colors_discrete),  # line colours cycle through the discrete palette

             'font.family': 'STIXGeneral',
             'font.sans-serif': 'Lato',
             'font.weight': 'normal',
             'font.size': FONT_SIZE,           # controls default text sizes

             'legend.frameon': True,           # draw a frame (box) around the legend
             'legend.fontsize': FONT_SIZE,     # legend font size

             'figure.titlesize': FONT_SIZE,   # font size of the figure title
             'figure.figsize': (FIG_WIDTH, FIG_HEIGHT),   # default figure size (width, height)
             'figure.dpi' : figure_dpi,

             'lines.linewidth': PLOT_LINE_WIDTH
}

mpl.rcParams.update(rc_params)

# Loading and configuration
Load the trained model, its gin configuration, the features, the labels and the forecasts.


In [None]:
# Name of the trained-model sub-directory to analyse (used to build the
# experiment and scaler paths). Switch model by (un)commenting a line.
# MODEL_NAME = 'Hauksson'
MODEL_NAME = 'JMA'

In [None]:
# Directory of the selected trained model, relative to the current working directory.
experiment_dir = os.path.join(os.getcwd(), '..', 'results/trained_models/', MODEL_NAME)
# Custom callables keyed by the name '_repeat' (presumably needed by Keras to
# deserialise the saved model -- confirm).
custom_objects={
    '_repeat': encoders._repeat,
    }


In [None]:
# Load the trained Keras model stored under the experiment directory.
# The model is not compiled: it is only used for prediction in this notebook.
model_path = os.path.join(experiment_dir, 'model')
loaded_model = tf.keras.models.load_model(
    model_path,
    compile=False,
    custom_objects=custom_objects,
    # safe_mode=True
)

In [None]:
# Apply the gin bindings stored in the experiment's config.gin file.
gin_config_path = os.path.join(experiment_dir, 'config.gin')
with open(gin_config_path) as f, gin.unlock_config():
    gin.parse_config(f.read(), skip_unknown=False)

In [None]:
# Show the gin configuration string for inspection.
print(gin.config_str())

In [None]:
# Catalog domain and the per-set magnitude labels derived from it.
domain = training_examples.CatalogDomain()
# Computed once; the original cell evaluated this identical call twice.
labels = training_examples.magnitude_prediction_labels(domain)

# Directory from which the feature scalers of the selected model are loaded.
scaler_saving_dir = os.path.join(os.getcwd(), '..', 'results/trained_models', MODEL_NAME, 'scalers')

all_encoders = one_region_model.build_encoders(domain)



In [None]:
# Bare expression: displays the encoders as the cell output.
all_encoders

In [None]:
# Compute (or reuse, since force_recalculate is False) the cached features,
# scalers and encoders for this domain.
one_region_model.compute_and_cache_features_scaler_encoder(
    domain,
    all_encoders,
    force_recalculate = False,
)
features_and_models = one_region_model.load_features_and_construct_models(
    domain, all_encoders, scaler_saving_dir
)
# The second argument selects the data split: 0 -> train, 1 -> validation, 2 -> test.
train_features = one_region_model.features_in_order(features_and_models, 0)
validation_features = one_region_model.features_in_order(features_and_models, 1)
test_features = one_region_model.features_in_order(features_and_models, 2)

In [None]:
# Run the loaded model on every data split.
# An explicit mapping replaces the former locals()[...] lookup, which only
# worked because the cell runs at notebook top level.
features_per_set = {
    'train': train_features,
    'validation': validation_features,
    'test': test_features,
}
forecasts = {}
for set_name in ['train', 'validation', 'test']:
    forecasts[set_name] = loaded_model.predict(features_per_set[set_name])

# Analysis and plotting

## Coordinate handling functions

In [None]:

def is_data_at_180(longitude_coors):
  """Tell whether the longitudes straddle the +/-180 degree meridian.

  True when the smallest longitude is <= -170, the largest is >= 170 and no
  value lies inside the closed interval [-100, 100].
  """
  reaches_west_edge = longitude_coors.min() <= -170
  reaches_east_edge = longitude_coors.max() >= 170
  in_middle_band = (longitude_coors >= -100) & (longitude_coors <= 100)
  return reaches_west_edge & reaches_east_edge & (in_middle_band.sum() == 0)

def longitude_to_theta(longitudes, norm_by_pi=True, convert_to_rad=True):
  """Wrap negative longitudes into [0, 360) and optionally convert to radians.

  Args:
    longitudes: array of longitudes in degrees; it is copied, not modified.
    norm_by_pi: when converting to radians, also divide the result by pi.
      Ignored when `convert_to_rad` is False.
    convert_to_rad: return radians instead of degrees.

  Returns:
    Array of the same shape holding the wrapped angles.
  """
  theta = longitudes.copy()
  negative = longitudes < 0
  theta[negative] = theta[negative] + 360
  if not convert_to_rad:
    return theta
  radians = np.deg2rad(theta)
  if norm_by_pi:
    return radians/np.pi
  return radians


def lon_lat_to_spherical(
    lon_lat_array: np.ndarray, # needs to be a NX2 array, 1st column lon 2nd lat
    norm_by_pi=True,
    convert_to_rad=False,
    ):
  """Convert an (N, 2) [longitude, latitude] array to [theta, phi].

  The longitude column always goes through `longitude_to_theta`. The latitude
  column is only touched when `convert_to_rad` is True: it is converted to
  radians and, if `norm_by_pi`, divided by pi. The input is not modified.
  """
  spherical = lon_lat_array.copy()
  spherical[:, 0] = longitude_to_theta(
      lon_lat_array[:, 0], norm_by_pi, convert_to_rad=convert_to_rad)
  if convert_to_rad:
    spherical[:, 1] = np.deg2rad(lon_lat_array[:, 1])
    if norm_by_pi:
      spherical[:, 1] = spherical[:, 1] / np.pi
  return spherical


def lon_lat_for_map_plotting(longitude_coors, latitude_coors):
  """Return (longitudes, latitudes) suitable for plotting on a map.

  When the data straddles the +/-180 meridian (see `is_data_at_180`) the
  coordinates are flattened and negative longitudes are shifted by 360 degrees;
  otherwise both inputs are returned unchanged.
  """
  if not is_data_at_180(longitude_coors):
    return longitude_coors, latitude_coors
  lon_column = longitude_coors.ravel()[:, None]
  lat_column = latitude_coors.ravel()[:, None]
  shifted = lon_lat_to_spherical(
      np.hstack((lon_column, lat_column)), convert_to_rad=False)
  longs, lats = shifted.T[0], shifted.T[1]
  return longs, lats

## Set relevant probability density and other definitions

In [None]:
# Factory used below to turn raw model forecasts into a distribution object
# (it is called with a forecasts tensor and the result exposes .prob and
# .survival_function).
probability_density_function = metrics.kumaraswamy_mixture_instance

In [None]:
# Beta estimated from the training labels.
# NOTE(review): 'BPOS' presumably selects the estimation method and None an
# unset completeness magnitude -- confirm against catalog_analysis.estimate_beta.
BETA_OF_TRAIN_SET = catalog_analysis.estimate_beta(labels.train_labels, None, 'BPOS')
# Magnitude threshold of the domain; used as the shift of the label transform.
MAG_THRESH = domain.magnitude_threshold
# Number of seconds in one day.
DAY_TO_SECONDS = 60*60*24

In [None]:
# Read the stretch of the distribution's support from the gin configuration,
# falling back to a default (with a visible warning) when it is not bound.
try:
    support_stretch = gin.query_parameter('train_and_evaluate_magnitude_prediction_model.pdf_support_stretch')
except ValueError:
    # gin.query_parameter raises ValueError for an unknown configurable or an
    # unbound parameter; any other exception is a real error and propagates.
    default_stretch = 7
    message = f"<span style='color:red; font-size:25px'>pdf_support_stretch not defined in gin, setting to default: {default_stretch}</span>"
    display(Markdown(message))
    support_stretch = default_stretch



In [None]:
# Shift-and-stretch transform applied to labels before they are evaluated
# under the predicted distribution.

random_var_shift = MAG_THRESH
random_var_stretch = support_stretch


def costum_shift_stretch(x, random_var_shift=random_var_shift, random_var_stretch=random_var_stretch):
  """Subtract the shift, divide by the stretch and cap the result at 1."""
  rescaled = (x - random_var_shift) / random_var_stretch
  return np.minimum(rescaled, 1)


shift_strech_input = costum_shift_stretch


In [None]:
# Per-set event timestamps, keyed by 'train' / 'validation' / 'test'.
timestamps_dict = calculate_benchmark_gr_properties.create_timestamps_dict(domain)
test_timestamps = timestamps_dict['test']
validation_timestamps = timestamps_dict['validation']
train_timestamps = timestamps_dict['train']
# All timestamps, concatenated in train -> validation -> test order.
all_timestamps = np.concatenate([train_timestamps, validation_timestamps, test_timestamps])

# Coordinates passed to the benchmark computation below.
coordinates_dict = calculate_benchmark_gr_properties.create_coordinates_dict(domain)

## Functions for computing likelihoods and baselines

In [None]:
def likelihood_probability_func(
      labels,
      forecasts,
      shift = random_var_shift,
      stretch = random_var_stretch,
      ):
  """Probability density of each label under the forecast distribution.

  Args:
    labels: label magnitudes; flattened to 1-D and cast to the forecasts' dtype.
    forecasts: model outputs parameterising the distribution built by
      `probability_density_function`; must expose `.dtype`.
    shift: shift applied to the labels before evaluating the density.
    stretch: stretch applied to the labels; the density is divided by it to
      account for the change of variables.

  Returns:
    The density evaluated at every (shifted and stretched) label.
  """
  # Create a tfp.distributions.Distribution instance:
  random_variable = probability_density_function(
      tf.convert_to_tensor(forecasts))
  labels_tensor = tf.reshape(tf.convert_to_tensor(labels, dtype=forecasts.dtype), (-1,))
  # Forward `shift` and `stretch` explicitly: previously the transform always
  # used its own defaults, so a non-default `shift` was silently ignored and a
  # non-default `stretch` was only applied to the normalisation.
  rescaled_labels = shift_strech_input(labels_tensor, shift, stretch)
  likelihood = random_variable.prob(rescaled_labels)/stretch
  return likelihood

def number_to_vector(num, set_name):
  """Broadcast a scalar to the shape of the set's labels; pass anything else through."""
  if not isinstance(num, numbers.Number):
    return num
  set_labels = getattr(labels, f'{set_name}_labels')
  return np.full_like(set_labels, num)

def p_of_m_model_above_cutoff(m_tilde, set_name, forecasts_i):
  """Model density of the labels of `set_name` that are >= the cutoff `m_tilde`.

  `m_tilde` may be a scalar or a per-label vector; `forecasts_i` is a dict of
  forecasts keyed by set name.
  """
  set_labels = getattr(labels, f'{set_name}_labels')
  all_likelihoods = likelihood_probability_func(
      labels=set_labels,
      forecasts=forecasts_i[set_name],
      shift=random_var_shift,
      stretch=random_var_stretch,
      )
  above_cutoff = set_labels >= number_to_vector(m_tilde, set_name)
  model_likelihood_above_cutoff = all_likelihoods[above_cutoff]
  return model_likelihood_above_cutoff

def model_survival_at_cutoff(m_tilde, set_name, forecasts_i):
  """Survival function of the forecast distribution at the cutoff.

  The cutoff is max(m_tilde, MAG_THRESH), passed through the shift/stretch
  transform before the survival function is evaluated.
  """
  forecast_tensor = tf.convert_to_tensor(forecasts_i[set_name])
  # Create a tfp.distributions.Distribution instance:
  random_variable = probability_density_function(forecast_tensor)
  cutoff = np.maximum(m_tilde, MAG_THRESH)
  return random_variable.survival_function(shift_strech_input(cutoff))


def conditioned_likelihood_model(m_tilde, set_name, forecasts_i):
  """Model likelihood conditioned on the label being >= `m_tilde`.

  Density of the labels above the cutoff divided by the model's survival
  probability at the cutoff, for those same labels.
  """
  m_tilde = number_to_vector(m_tilde, set_name)
  density_above = p_of_m_model_above_cutoff(m_tilde, set_name, forecasts_i)
  survival = model_survival_at_cutoff(m_tilde, set_name, forecasts_i)
  above_cutoff = getattr(labels, f'{set_name}_labels')>=number_to_vector(m_tilde, set_name)
  return density_above / survival[above_cutoff]


# split to sets dict
def split_to_sets_dicts(main_dict):
  """Split a dict into (train, validation, test) dicts by key suffix.

  Keys ending in '_train', '_validation' and '_test' go to the respective
  dict; keys with none of these suffixes are dropped.
  """
  def _entries_for(suffix):
    return {key: value for key, value in main_dict.items() if key.endswith(suffix)}

  return _entries_for('_train'), _entries_for('_validation'), _entries_for('_test')


def split_name_to_model_and_set(name):
  """Split '<model>_<set>' at its last underscore.

  Args:
    name: key such as 'train_gr_likelihood_test'.

  Returns:
    Tuple (model_name, set_name). A name with no underscore yields
    (name, ''); a name ending in an underscore yields an empty set name.
  """
  # The function used to be defined twice with identical bodies, and its
  # slicing arithmetic returned garbage for names with no underscore or a
  # trailing one.
  if '_' not in name:
    return (name, '')
  current_model, set_name = name.rsplit('_', 1)
  return (current_model, set_name)

def sort_strings_w_constraint(list_of_strings, start_with_constraint):
  """Sort strings alphabetically within groups defined by ordered prefixes.

  Args:
    list_of_strings: strings to order.
    start_with_constraint: prefixes, in the order their groups should appear.

  Returns:
    A new list: for each prefix in turn, the not-yet-placed strings starting
    with it (sorted), followed by all remaining strings (de-duplicated and
    sorted). A string matching several prefixes is placed only under the
    first one; previously it was emitted once per matching prefix.
  """
  sorted_list = []
  placed = set()
  for cons in start_with_constraint:
    cons_list = sorted(
        s for s in list_of_strings if s.startswith(cons) and s not in placed
    )
    sorted_list += cons_list
    placed.update(cons_list)
  sorted_list += sorted(set(list_of_strings) - placed)
  return sorted_list

# Compute model's results and baselines

In [None]:
#@title set cache path
# Directory for the cached benchmark properties, relative to the current
# working directory.
GR_PROPERTIES_CACHE = os.path.join(
    os.getcwd(), '..', 'results/cached_benchmarks'
)

### collect beta and mc for gr variations models

In [None]:
# Workaround: the flags of calculate_benchmark_gr_properties are normally set
# from the command line, so they are parsed manually here.
custom_args = [
    # absl treats argv[0] as the program name and does not parse it, so a
    # placeholder must come first (previously the cache-dir flag sat in this
    # slot and was silently skipped).
    'notebook',
    # Interpolate the path itself; the literal text 'GR_PROPERTIES_CACHE' used
    # to be passed instead of the variable's value.
    f"--{calculate_benchmark_gr_properties._CACHE_DIR.name}={GR_PROPERTIES_CACHE}",
    f"--{calculate_benchmark_gr_properties._FORCE_RECALCULATE.name}=False",
]
FLAGS = flags.FLAGS
FLAGS(custom_args)


In [None]:
# Set the root logger to DEBUG so progress messages are shown while running.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Whether to include the 'n_past_events_kde' benchmark in the computation below.
LOAD_KDE = False
# Two dicts keyed by '<benchmark>_<set>': one with the beta values (or KDE
# objects for the kde benchmarks) and one with the completeness magnitudes.
gr_models_beta, gr_models_mc = calculate_benchmark_gr_properties.compute_and_assign_benchmarks_all_sets(
    domain,
    timestamps_dict,
    coordinates_dict,
    BETA_OF_TRAIN_SET,
    MAG_THRESH,
    # compute_benchmark={'spatial_gr':False},
    compute_benchmark={'n_past_events_kde':LOAD_KDE},
)

In [None]:
# Rename the 'gr_spatial*' benchmark keys to the 'spatial_gr*' convention in
# both dictionaries. The keys are collected first so the dicts can be mutated
# safely inside the loop.
spatial_keys = [k for k in gr_models_beta if k.startswith('gr_spatial')]
for k in spatial_keys:
  k_new = k.replace('gr_spatial', 'spatial_gr')
  gr_models_beta[k_new] = gr_models_beta.pop(k)
  gr_models_mc[k_new] = gr_models_mc.pop(k)

## Compute likelihoods of models and baselines

### GR and other baselines

In [None]:
# Likelihood of every label under each benchmark, keyed by '<benchmark>_<set>'.
gr_likelihoods_and_baselines = {}
for k in gr_models_beta:
  set_name = k.split('_')[-1]
  set_labels = getattr(labels, f'{set_name}_labels')
  if 'events_kde' not in k:
    # Gutenberg-Richter style benchmark: closed-form likelihood from beta and mc.
    gr_likelihoods_and_baselines[k] = metrics.gr_likelihood(
        set_labels,
        gr_models_beta[k],
        gr_models_mc[k],
        )
    continue
  # KDE benchmark: one density estimator per label, evaluated at that label.
  per_label_density = [kde(l) for kde,l in zip(gr_models_beta[k], set_labels)]
  gr_likelihoods_and_baselines[k] = np.array(per_label_density).ravel()


### Model's scores

In [None]:
# Likelihood of every label under the model's forecasts, per data split,
# stored under 'model_<MODEL_NAME>_likelihood_<set>'.
likelihoods_and_baselines = {}

for set_name in ['train', 'validation', 'test']:
  likelihoods_and_baselines[f'model_{MODEL_NAME}_likelihood_{set_name}'] = np.array(
      likelihood_probability_func(
          getattr(labels, f'{set_name}_labels'),
          forecasts[set_name],
          MAG_THRESH,
          )
      )

# Add the benchmark likelihoods so that model and baselines share one dict.
likelihoods_and_baselines.update(gr_likelihoods_and_baselines)

## Display results

In [None]:
#@title Baselines to display

# Start from every trained-model entry that has a test-set likelihood ...
MODELS_TO_PLOT = [
    split_name_to_model_and_set(k)[0]
    for k in likelihoods_and_baselines.keys()
    # logical `and` (the original used bitwise `&` on two booleans)
    if k.startswith('model_') and k.endswith('_test')
]
# ... and add the baselines to show; (un)comment entries to change the selection.
MODELS_TO_PLOT += [
    'train_gr_likelihood',
    'test_gr_likelihood',
    # 'gr_last_10_days_constant_mc_likelihood',
    'gr_last_100_days_constant_mc_likelihood',
    # 'gr_last_1000_days_constant_mc_likelihood',
    # 'gr_last_10_days_fitted_mc_likelihood',
    # 'gr_last_100_days_fitted_mc_likelihood',
    # 'gr_last_1000_days_fitted_mc_likelihood',
    'n300_past_events_constant_mc',
    # 'n300_present_events_constant_mc',
    # 'n300_past_events_fitted_mc',
    # 'n300_present_events_fitted_mc',
    # 'spatial_gr_on_all_likelihood',
    # 'spatial_gr_on_train_likelihood',
    # 'gr_spatial_on_train_likelihood',
    # 'spatial_gr_on_test_likelihood',
    # 'n300_past_events_kde_constant_mc'
]
# Group by prefix in this order ('spatial' was misspelled 'saptial' and never matched).
MODELS_TO_PLOT = sort_strings_w_constraint(MODELS_TO_PLOT, ['model', 'train', 'test', 'gr', 'n', 'spatial'])

# One colour per model; wrap around the palette instead of raising IndexError
# when more models than colours are selected.
COLOR_PER_MODEL = {
    m: listed_colors_discrete[i % len(listed_colors_discrete)]
    for i, m in enumerate(MODELS_TO_PLOT)
}

#### display helper-functions

In [None]:
def create_scores_summary_df(
    likelihoods_and_baselines_dictionary,
    per_set_boolean_filter=None,
    exclude_zeros=False,
    drop_nans=True
    ):
  """Tabulate the mean negative log-likelihood per model and data split.

  Args:
    likelihoods_and_baselines_dictionary: dict keyed by '<model>_<set>' holding
      per-example likelihood arrays.
    per_set_boolean_filter: optional dict keyed by set name with a boolean mask
      of examples to keep.
    exclude_zeros: drop examples whose likelihood is exactly zero.
    drop_nans: drop examples whose likelihood is NaN.

  Returns:
    Numeric DataFrame indexed by model name with columns
    'train', 'validation' and 'test'.
  """
  model_names = {
      split_name_to_model_and_set(key)[0]
      for key in likelihoods_and_baselines_dictionary
  }
  row_order = sort_strings_w_constraint(
      list(model_names),
      ['model_', 'train', 'test', 'gr_', 'n_'],
  )
  summary_df = pd.DataFrame(index=row_order, columns=['train', 'validation', 'test'])

  for key, values in likelihoods_and_baselines_dictionary.items():
    current_model, set_name = split_name_to_model_and_set(key)

    # Start with every example selected, then apply the requested filters.
    keep = np.full_like(values.ravel(), True).astype(bool)
    if per_set_boolean_filter is not None:
      keep = keep & per_set_boolean_filter[set_name]
    if exclude_zeros:
      keep = keep & (values!=0)
    if drop_nans:
      keep = keep & (~np.isnan(values))

    mean_neg_log_likelihood = -np.log(values[keep]).mean()
    summary_df.loc[current_model, set_name] = float(mean_neg_log_likelihood)
  return summary_df.apply(pd.to_numeric)



def get_grad_colormap(original_color):
  """Two-stop colormap running from `original_color` to opaque white."""
  gradient_stops = np.array([list(original_color), (1, 1, 1, 1)])
  return mpl_colors.LinearSegmentedColormap.from_list('grad_colormap', gradient_stops)


def barplot_scores(scores_summary_df, models_to_plot_list, colors=None, set_name='test'):
  """Bar plot of the scores of selected models for one data split.

  Infinite scores are drawn as a bar fading to white, capped above the largest
  finite score and annotated with an infinity sign.

  Args:
    scores_summary_df: DataFrame indexed by model name with one column per set.
    models_to_plot_list: model names (index labels) to plot, in plotting order.
    colors: optional bar colours, one per plotted model.
    set_name: column of `scores_summary_df` to plot.

  Returns:
    Tuple (figure, axes).
  """
  # Select the rows requested by the caller (the global MODELS_TO_PLOT used to
  # be read here, ignoring the argument) and copy them so that replacing the
  # infinities below never touches the caller's frame.
  data_column = scores_summary_df[set_name].loc[models_to_plot_list].copy()
  are_infs = np.isinf(data_column)
  non_inf_max = data_column[~are_infs].max()
  margin = (non_inf_max - data_column[~are_infs].min())/4
  # Height at which infinite scores are drawn.
  replace_inf_val = non_inf_max + 2*margin
  data_column[are_infs] = replace_inf_val
  infs_bars = np.where(are_infs)[0]


  f, ax = plt.subplots(1, 1)
  bars_handle = ax.bar(
      models_to_plot_list,
      data_column,
      color=colors
      )
  #-- account for infs:
  if infs_bars.size > 0:
    bar_ax = bars_handle[0].axes
    # Remember the axis limits: imshow below would otherwise change them.
    lim = bar_ax.get_xlim()+bar_ax.get_ylim()
    for inf_idx in infs_bars:
      bar = bars_handle[inf_idx]
      bar.set_zorder(1)
      original_color = bar.get_facecolor()
      grad_colormap = get_grad_colormap(original_color)
      # Make the bar transparent and paint a gradient image in its place.
      bar.set_facecolor("none")
      x,y = bar.get_xy()
      w, h = bar.get_width(), bar.get_height()
      grad = np.atleast_2d(np.linspace(replace_inf_val, 0, 1000)).T
      normalizer = mpl.colors.PowerNorm(0.8, vmin=replace_inf_val-margin, vmax=replace_inf_val)
      ax.imshow(grad, extent=[x,x+w,y,y+h], aspect="auto", zorder=0, cmap=grad_colormap, norm=normalizer)
      # Raw string: same text as before, without the invalid '\i' escape.
      ax.text(x+w/2, replace_inf_val, r'$\infty$', ha='center', color=original_color)
    bar_ax.axis(lim)

  max_y = data_column.max() + margin
  min_y = data_column.min() - margin
  ax.set_ylim(min_y, max_y)
  for label in ax.get_xticklabels():
    label.set(rotation=-30, horizontalalignment='left')
  ax.set_ylabel(r'$-\langle \log \leftparen \mathtt{likelihood} \rightparen \rangle$')
  return f, ax

### Minus mean log-likelihood   $-\langle\mathcal{L}\rangle$

In [None]:
#@title Minus mean log likelihood

# Mean negative log-likelihood per model and set, ignoring NaN and exactly-zero
# likelihoods; the bare expression displays the table.
summary_df_mean_ll = create_scores_summary_df(likelihoods_and_baselines, drop_nans=True, exclude_zeros=True)
summary_df_mean_ll

In [None]:
# Bar plot of the test-set scores (set_name defaults to 'test'), one colour per model.
f_meanLL_barplot, ax_meanLL_barplot = barplot_scores(summary_df_mean_ll, MODELS_TO_PLOT, [COLOR_PER_MODEL[m] for m in MODELS_TO_PLOT])

## Results conditioned above a threshold

In [None]:
#@title computation of conditioned likelihood
# KDE likelihoods below this value are treated as numerically unreliable.
NUMERICAL_THRESH = 1e-10
def compute_conditioned_gr_variations(
    gr_models_beta_dict,
    gr_models_mc_dict,
    labels_inst,
    m_condition,
    numerical_thresh=NUMERICAL_THRESH,
    ):
  """Likelihoods of the benchmarks conditioned on magnitude >= `m_condition`.

  Args:
    gr_models_beta_dict: dict keyed by '<benchmark>_<set>' with the beta values
      (or, for keys containing 'events_kde', one KDE object per label).
    gr_models_mc_dict: dict with the same keys holding completeness magnitudes.
    labels_inst: object exposing '<set>_labels' attributes.
    m_condition: per-label conditioning magnitude.
    numerical_thresh: KDE likelihoods below this value are replaced by NaN.

  Returns:
    Dict with the same keys as `gr_models_beta_dict` holding the conditioned
    likelihoods.
  """
  conditioned_gr_variations = {}
  for k in gr_models_beta_dict:
    set_name = k.split('_')[-1]
    # Use the labels passed by the caller; the KDE branch used to read the
    # global `labels` and ignored `labels_inst`.
    set_labels = getattr(labels_inst, f'{set_name}_labels')
    if 'events_kde' in k:
      # Probability mass above the conditioning magnitude (integrated up to 20).
      survivals = np.array([kde.integrate_box_1d(mt, 20) for kde, mt in zip(gr_models_beta_dict[k], m_condition)]).ravel()
      above_mc_logical = set_labels>=m_condition
      kde_likelihoods = np.array([kde(l) for kde,l in zip(gr_models_beta_dict[k], set_labels)]).ravel()
      conditioned_gr_variations[k] = kde_likelihoods[above_mc_logical]/survivals[above_mc_logical]
      conditioned_gr_variations[k][kde_likelihoods[above_mc_logical] < numerical_thresh] = np.nan
    else:
      conditioned_gr_variations[k] = np.array(
          metrics.gr_conditioned_likelihood(
              set_labels,
              gr_models_beta_dict[k],
              gr_models_mc_dict[k],
              m_condition
              )
      )
  return conditioned_gr_variations

## Likelihoods conditioned above temporal incompleteness

In [None]:
# Split the benchmark dictionaries by the set suffix of their keys.
gr_models_beta_train, gr_models_beta_validation, gr_models_beta_test = split_to_sets_dicts(gr_models_beta)
gr_models_mc_train, gr_models_mc_validation, gr_models_mc_test = split_to_sets_dicts(gr_models_mc)

In [None]:
#@title mean LL conditioned on PRESENT incompleteness:  $m\geq \tilde{m}=m_c(t_{present})$


# --- dictionary collecting the conditioned likelihoods of model and benchmarks
likelihoods_and_baselines_cond_present_temp_incompleteness = {}

# --- model likelihoods conditioned on the 'n300_present_events_fitted_mc'
# --- completeness magnitude of each set
for set_name in ['train', 'validation', 'test']:
  likelihoods_and_baselines_cond_present_temp_incompleteness[f'model_{MODEL_NAME}_likelihood_{set_name}'] = np.array(
    conditioned_likelihood_model(
      gr_models_mc[f'n300_present_events_fitted_mc_{set_name}'],
      set_name,
      forecasts
      )
    )

# --- benchmark likelihoods conditioned on the same magnitude, one call per set
conditioned_present_temporal_incompleteness_train = compute_conditioned_gr_variations(
    gr_models_beta_train,
    gr_models_mc_train,
    labels,
    gr_models_mc_train['n300_present_events_fitted_mc_train']
)

conditioned_present_temporal_incompleteness_validation = compute_conditioned_gr_variations(
    gr_models_beta_validation,
    gr_models_mc_validation,
    labels,
    gr_models_mc_validation['n300_present_events_fitted_mc_validation']
)

conditioned_present_temporal_incompleteness_test = compute_conditioned_gr_variations(
    gr_models_beta_test,
    gr_models_mc_test,
    labels,
    gr_models_mc_test['n300_present_events_fitted_mc_test']
)

# --- merge the three per-set results into the common dictionary
likelihoods_and_baselines_cond_present_temp_incompleteness.update(conditioned_present_temporal_incompleteness_train)
likelihoods_and_baselines_cond_present_temp_incompleteness.update(conditioned_present_temporal_incompleteness_validation)
likelihoods_and_baselines_cond_present_temp_incompleteness.update(conditioned_present_temporal_incompleteness_test)


# --- summarise as mean negative log-likelihood and display the table
summary_df_mean_ll_cond_mc_present = create_scores_summary_df(likelihoods_and_baselines_cond_present_temp_incompleteness, drop_nans=True)
summary_df_mean_ll_cond_mc_present

In [None]:
# Bar plot of the conditioned test-set scores, one colour per model.
f_thresh_t_present_barplot, ax_thresh_t_present_barplot = barplot_scores(summary_df_mean_ll_cond_mc_present, MODELS_TO_PLOT, [COLOR_PER_MODEL[m] for m in MODELS_TO_PLOT])

### Plot $m_c(t_{present})$


In [None]:
f_completeness, ax_completeness = plt.subplots(1, 1, figsize=(7, 3))

# Test-set labels against their running index; marker size grows as 2.3**magnitude.
y_scatter = labels.test_labels
index = np.arange(y_scatter.size)

ax_completeness.scatter(
    index,
    y_scatter,
    s=2.3**y_scatter,
    alpha=0.4,
    c='k'
    )

#---- Plot catalog events below the magnitude threshold that fall in the test period
below_mc_bool = domain.earthquakes_catalog.magnitude < MAG_THRESH
test_events_bool = ((domain.earthquakes_catalog.time>=test_timestamps.min()) &
                    (domain.earthquakes_catalog.time<test_timestamps.max()))
# Place these events on the index axis by interpolating their times between
# the test-label timestamps.
interp_index = np.interp(
    domain.earthquakes_catalog['time'].values[test_events_bool & below_mc_bool],
    test_timestamps,
    index,
)
interp_magnitudes = domain.earthquakes_catalog['magnitude'].values[test_events_bool & below_mc_bool]

ax_completeness.scatter(
    interp_index,
    interp_magnitudes,
    c=None,
    facecolor='silver',
    s=2.3**interp_magnitudes,
)

#--- Plot Mc(t): the 'n300_present_events_fitted_mc' completeness magnitude per test label
ax_completeness.plot(
    index,
    gr_models_mc_test['n300_present_events_fitted_mc_test'],
    )
ax_completeness.set_ylabel('Magnitude')
ax_completeness.set_xlabel('Index')


## Spatial beta metrics

# Probability density per example

In [None]:
# -- Plot only a sub group in specific order: models first, then the 'n',
# -- 'gr', 'train' and 'test' prefixed baselines.
models_to_plot_by_order = sort_strings_w_constraint(MODELS_TO_PLOT, ['model', 'n', 'gr', 'train', 'test'])
# Data split whose likelihoods are plotted below.
set_to_plot = 'test'

In [None]:
# Per-example likelihood against label magnitude for the model and one
# baseline, with the train-set Gutenberg-Richter curve for reference.
f_likelihood_per_example, ax_likelihood_per_example = plt.subplots(1,1,)
mag_check = np.linspace(MAG_THRESH, 8, 1000)
window_size = 0.3
models_to_plot = [
    [m for m in models_to_plot_by_order if m.startswith('model_')][0],
    'n300_past_events_constant_mc',
]
for i_model in models_to_plot:
  if i_model not in models_to_plot_by_order:
    continue
  # Draw order follows the position in the ordered list.
  zorder = models_to_plot_by_order.index(i_model)
  sc = ax_likelihood_per_example.scatter(
      labels.test_labels,
      likelihoods_and_baselines[f'{i_model}_{set_to_plot}'],
      c=COLOR_PER_MODEL[i_model],
      alpha=0.4,
      label=i_model,
      zorder=zorder
      )

train_gr_curve = metrics.gr_likelihood(mag_check, BETA_OF_TRAIN_SET, MAG_THRESH)
gr_handle = ax_likelihood_per_example.plot(
    mag_check,
    train_gr_curve,
    '--',
    linewidth=3,
    color=COLOR_PER_MODEL['train_gr_likelihood'],
    label='train_gr_likelihood',
    )
leg = ax_likelihood_per_example.legend()
# Matplotlib renamed Legend.legendHandles to legend_handles (3.7) and later
# removed the old name; support both so the cell runs on old and new versions.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
# Make the legend markers opaque although the scatter points are translucent.
for lh in legend_handles:
    lh.set_alpha(1)
ax_likelihood_per_example.set_xlabel('magnitude', labelpad=10)
ax_likelihood_per_example.set_ylabel('label\'s likelihood', labelpad=10)
ax_likelihood_per_example.set_yscale('log')

# Organize likelihoods data in df

In [None]:
#--- create catalog for comparisons: catalog rows whose time is one of the label timestamps
rows_to_keep = np.isin(domain.earthquakes_catalog.time.values, all_timestamps)
# Slice first, then copy, so the new columns are added to an independent frame.
comparison_catalog = domain.earthquakes_catalog.iloc[rows_to_keep, :].copy()

#-- add set name indicator (rows are assumed to be ordered train -> validation -> test)
train_indicator = ['train']*len(labels.train_labels)
validation_indicator = ['validation']*len(labels.validation_labels)
test_indicator = ['test']*len(labels.test_labels)
set_indicator = np.concatenate([train_indicator, validation_indicator, test_indicator])
comparison_catalog['set_name'] = set_indicator


#-- unique model / baseline names, i.e. the dictionary keys without the set suffix.
#-- Sorted so that the order of the added columns is reproducible between runs.
set_names = ['train', 'validation', 'test']
baselines_names = sorted({
    k[:-(len(s)+1)]
    for k in likelihoods_and_baselines.keys()
    for s in set_names
    if k.endswith(s)
})

for base_name in baselines_names:
  #--- add unconditioned scores; sets without an entry are filled with NaN
  baseline_vector = np.empty((0))
  for s in set_names:
    k = base_name + '_' + s
    if k not in likelihoods_and_baselines.keys():
      baseline_vector = np.concatenate([baseline_vector, np.full_like(getattr(labels, f'{s}_labels'), np.nan)])
    else:
      baseline_vector = np.concatenate([baseline_vector, likelihoods_and_baselines[k]])
  comparison_catalog[base_name] = baseline_vector


  #--- add conditioned scores: defined only for labels above the present
  #--- completeness magnitude, NaN elsewhere
  baseline_vector_conditioned = np.empty((0))
  for s in set_names:
    k = base_name + '_' + s

    baseline_addition = np.full_like(getattr(labels, f'{s}_labels'), np.nan)

    # Look the completeness magnitude up in gr_models_mc directly instead of
    # going through locals() and the per-set split dictionaries.
    above_thresh_logical = getattr(labels, f'{s}_labels') >= gr_models_mc[f'n300_present_events_fitted_mc_{s}']
    if k in likelihoods_and_baselines_cond_present_temp_incompleteness.keys():
      baseline_addition[above_thresh_logical] = likelihoods_and_baselines_cond_present_temp_incompleteness[k]
    baseline_vector_conditioned = np.concatenate([baseline_vector_conditioned, baseline_addition])
  comparison_catalog[base_name+'_conditioned'] = baseline_vector_conditioned

# Comparisons to benchmarks

### Select data to present


In [None]:

# Columns of comparison_catalog to compare: the model's score against a baseline.
scores_name = f'model_{MODEL_NAME}_likelihood'
baseline_name = 'train_gr_likelihood'

# Use the conditioned variant of both columns.
scores_name += '_conditioned'
baseline_name += '_conditioned'

# Time window to present: the test period. The '+ 0' leaves room for a manual offset.
start_time = domain.test_start_time + 0
end_time = domain.test_end_time + 0
# Spatial window: bounding box of the test events, padded by 10% of its extent on each side.
longs = comparison_catalog[comparison_catalog.set_name == 'test'].longitude
lats = comparison_catalog[comparison_catalog.set_name == 'test'].latitude
min_longitude = longs.min() - (longs.max() - longs.min()) / 10
max_longitude = longs.max() + (longs.max() - longs.min()) / 10
min_latitude = lats.min() - (lats.max() - lats.min()) / 10
max_latitude = lats.max() + (lats.max() - lats.min()) / 10

### Functions for organizing data for plotting

In [None]:

def logical_for_cropped_comparison_cat():
  """Boolean mask of `comparison_catalog` rows inside the selected window.

  Reads the notebook-level bounds: latitude and longitude limits are inclusive,
  the time window is [start_time, end_time).
  """
  latitudes = comparison_catalog['latitude'].values
  longitudes = comparison_catalog['longitude'].values
  times = comparison_catalog['time'].values

  in_latitude_band = (latitudes >= min_latitude) & (latitudes <= max_latitude)
  in_longitude_band = (longitudes >= min_longitude) & (longitudes <= max_longitude)
  in_time_window = (times >= start_time) & (times < end_time)
  return in_latitude_band & in_longitude_band & in_time_window


def _cumulative_finite_sum(values):
  """Cumulative sum of `values` with non-finite entries counted as zero."""
  finite_values = np.copy(values)
  finite_values[~np.isfinite(finite_values)] = 0
  return np.nancumsum(finite_values)


def create_cropped_comparison_catalog(scores_name, baseline_name):
  """Crop `comparison_catalog` to the selected window and add gain columns.

  Args:
    scores_name: column holding the model's likelihoods; a trailing
      '_conditioned' is stripped to find the unconditioned column.
    baseline_name: column holding the baseline's likelihoods; treated alike.

  Returns:
    A new DataFrame with the cropped rows and the added columns
    `log_difference`, `information_gain` (cumulative), their `_conditioned`
    counterparts, and `actual_time` (datetime built from the epoch time).
  """
  total_logical = logical_for_cropped_comparison_cat()
  cropped_comparison_catalog = comparison_catalog[total_logical]

  baseline_name_no_suffix, scores_name_no_suffix = (
      s.removesuffix('_conditioned') for s in [baseline_name, scores_name]
  )

  # Non conditioned information gain
  add_log_difference = np.log(
      cropped_comparison_catalog[scores_name_no_suffix]
  ) - np.log(cropped_comparison_catalog[baseline_name_no_suffix])

  # Conditioned information gain
  add_log_difference_conditioned = np.log(
      cropped_comparison_catalog[scores_name_no_suffix + '_conditioned']
  ) - np.log(cropped_comparison_catalog[baseline_name_no_suffix + '_conditioned'])

  add_actual_time = [datetime.datetime.fromtimestamp(raw_t) for raw_t in cropped_comparison_catalog['time'].values]

  # assign() returns a new frame; the former trailing no-op `.assign()` call
  # and the duplicated cumulative-sum code were removed.
  return cropped_comparison_catalog.assign(
      log_difference=add_log_difference,
      information_gain=_cumulative_finite_sum(add_log_difference),
      log_difference_conditioned=add_log_difference_conditioned,
      information_gain_conditioned=_cumulative_finite_sum(add_log_difference_conditioned),
      actual_time=add_actual_time,
  )

## Temporal comparisons

In [None]:
# Cropped catalog with the log-difference and information-gain columns of the
# selected score / baseline pair.
cropped_comparison_catalog = create_cropped_comparison_catalog(
    scores_name, baseline_name
)

In [None]:
#@title plot temporal gain

#------ SCATTER INFORMATION GAIN -----

base_factor=2.3
def _forward_scatter_size_legend(x):
  return base_factor**x
def _backward_scatter_size_legend(x):
  return np.log(x)/np.log(base_factor)

f_likelihood_overtime, ax_likelihood_overtime = plt.subplots(1,1, figsize=(7, 4))

# One point per event of the cropped catalog: x is the event index, y the
# magnitude, colour the log-likelihood difference and size grows with magnitude.
x_scatter = np.arange(len(cropped_comparison_catalog['actual_time'].values))
sc = ax_likelihood_overtime.scatter(
    x_scatter,
    cropped_comparison_catalog['magnitude'].values,
    c=cropped_comparison_catalog['log_difference'].values,
    s=_forward_scatter_size_legend(cropped_comparison_catalog['magnitude']),
    alpha=0.8,
    cmap='coolwarm',
    # Diverging colour scale centred on zero, asymmetric limits.
    norm=mpl.colors.TwoSlopeNorm(vmin=-2, vcenter=0, vmax=4),
    )


_ = ax_likelihood_overtime.set_xlabel('Event index')
_ = ax_likelihood_overtime.set_ylabel('Magnitude')
# Label typo fixed ('lokelihood' -> 'likelihood').
cb = f_likelihood_overtime.colorbar(sc, ax=ax_likelihood_overtime, location='left', extend='both', aspect=28, label='Log-likelihood difference')
cb.ax.set_yscale('linear')


#----- SECOND X AXIS -----

def forward_2nd_xaxis(x):
  """Event index -> epoch time, interpolated over the plotted events.

  Uses the times of `cropped_comparison_catalog` (the events actually plotted)
  so both interpolation arrays always have the same length; `test_timestamps`
  only matches when the crop keeps exactly the test events.
  """
  return np.interp(x, x_scatter, cropped_comparison_catalog['time'].values)
def inverse_2nd_xaxis(x):
  """Epoch time -> event index; inverse of `forward_2nd_xaxis`."""
  return np.interp(x, cropped_comparison_catalog['time'].values, x_scatter)



# ----- list of first day in  each month:
min_datetime = datetime.datetime.fromtimestamp(test_timestamps.min())
min_year = min_datetime.year

max_datetime = datetime.datetime.fromtimestamp(test_timestamps.max())
max_year = max_datetime.year

first_of_months = []
for y in range(min_year, max_year+1):
  for m in range(1, 13):
    first_of_months.append(datetime.datetime(year=y, month=m, day=1))
first_of_months = [t for t in first_of_months if ((t>min_datetime) & (t<max_datetime))][::6]
first_of_months_epoch_time = [t.timestamp() for t in first_of_months]
first_of_months_string = [t.strftime('%b-%y') for t in first_of_months]

secax = ax_likelihood_overtime.secondary_xaxis(-0.1, functions=(forward_2nd_xaxis, inverse_2nd_xaxis))
secax.spines['bottom'].set_color('dimgrey')
secax.tick_params(axis='x', labelcolor='dimgrey')
_ = secax.xaxis.set_major_locator(FixedLocator(first_of_months_epoch_time))
_ = secax.xaxis.set_minor_locator(FixedLocator([]))
_ = secax.set_xticklabels(first_of_months_string)
for label in secax.get_xticklabels(which='major'):
  _ = label.set(rotation=-60, horizontalalignment='left')
_ = secax.set_xlabel('Date', color='dimgrey')


#----- PLOT CUM INFORMATION ON 2ND Y AXIS -----

ax_cum_info_gain = ax_likelihood_overtime.twinx()
x_info_gain = np.arange(len(cropped_comparison_catalog['actual_time'].values))
_ = ax_cum_info_gain.plot(
    x_info_gain,
    cropped_comparison_catalog['information_gain'].values,
    '--',
    color='k',
    linewidth=2,
    label='cumulative information gain')
_ = ax_cum_info_gain.set_ylabel('cumulative information gain', color='dimgrey')
ax_cum_info_gain.tick_params(axis='y', labelcolor='dimgrey')

_ = ax_cum_info_gain.plot(
    x_info_gain,
    cropped_comparison_catalog['information_gain_conditioned'].values,
    '-.',
    color='dimgrey',
    linewidth=2,
    label='cumulative information gain - conditioned')
# _ = ax_cum_info_gain.set_ylim(-10, 290)
_ = ax_cum_info_gain.legend(loc='upper right')

# f_likelihood_overtime

## Spatial comparisons

### Plot LL difference in space

In [None]:
f_map, ax_map = plt.subplots()

# NOTE(review): the two selections below are not used by this plot; they are
# kept because other cells may read them.
is_test_row = comparison_catalog['set_name'] == 'test'
comparison_catalog_test = comparison_catalog[is_test_row]
test_row_times = comparison_catalog_test['time'].values
time_logical = (test_row_times >= start_time) & (test_row_times < end_time)

# Event coordinates converted by the notebook's map-plotting helper.
longs, lats = lon_lat_for_map_plotting(
    cropped_comparison_catalog['longitude'].values,
    cropped_comparison_catalog['latitude'].values,
)

# Color = log-likelihood difference, marker size = 2.3 ** magnitude.
ll_difference_norm = mpl.colors.TwoSlopeNorm(vmin=-2, vcenter=0, vmax=4)
sc = ax_map.scatter(
    longs,
    lats,
    c=cropped_comparison_catalog['log_difference'].values,
    s=2.3**cropped_comparison_catalog['magnitude'],
    alpha=0.3,
    cmap='coolwarm',
    norm=ll_difference_norm,
)

cb = f_map.colorbar(sc, ax=ax_map, shrink=0.5)
cb.ax.set_yscale('linear')

# Crop the axes to the plotted events and keep degrees square.
ax_map.set_xlim(longs.min(), longs.max())
ax_map.set_ylim(lats.min(), lats.max())
ax_map.set_aspect('equal')

### Spatial distribution compared to seismicity

In [None]:
# One row per test example: (longitude, latitude, key of the example).
# NOTE(review): the key is binned as a time coordinate below -- presumably an
# epoch timestamp; confirm against domain.test_examples.
examples_array = np.array(
    [
        (example[0][0].lng, example[0][0].lat, key)
        for key, example in domain.test_examples.items()
    ]
)
# Replace the raw coordinates by their map-plotting counterparts.
examples_array[:, 0], examples_array[:, 1] = lon_lat_for_map_plotting(
    examples_array[:, 0], examples_array[:, 1]
)

# Same nested layout as domain.test_examples, but holding the converted points.
mod_test_examples = {
    row[2]: [[geometry.Point(row[0], row[1])]] for row in examples_array
}

# 0.2-degree spatial bins over the examples' bounding box; a single temporal bin.
example_lon_column = examples_array[:, 0]
example_lat_column = examples_array[:, 1]
example_key_column = examples_array[:, 2]
longitude_bins = np.arange(example_lon_column.min(), example_lon_column.max(), 0.2)
latitude_bins = np.arange(example_lat_column.min(), example_lat_column.max(), 0.2)
temporal_bins = np.linspace(example_key_column.min(), example_key_column.max(), 2)


counts_in_bin, _ = np.histogramdd(
    examples_array,
    bins=(longitude_bins, latitude_bins, temporal_bins),
    density=False,
)


In [None]:
#--- Gaussian blur with nans

def gauss_blur_w_nans(U, sigma=0.5, truncate=4.0):
  """Gaussian-blurs `U` while ignoring its NaN entries.

  NaNs are zeroed before filtering and the result is renormalized by the
  filtered mask of valid entries, so missing values do not pull their
  neighbours towards zero. `U` itself is not modified.

  Args:
    U: Array to blur; may contain NaNs.
    sigma: Standard deviation passed to `scipy.ndimage.gaussian_filter`.
    truncate: Kernel truncation passed to `scipy.ndimage.gaussian_filter`.

  Returns:
    Array shaped like `U`: blurred values divided by blurred validity weights.
  """
  nan_mask = np.isnan(U)

  # Data with the missing entries contributing nothing to the blur.
  zero_filled = U.copy()
  zero_filled[nan_mask] = 0

  # 1 where data is valid, 0 where it is missing.
  validity = np.ones_like(U)
  validity[nan_mask] = 0

  blurred_values = sp.ndimage.gaussian_filter(zero_filled, sigma=sigma, truncate=truncate)
  blurred_weights = sp.ndimage.gaussian_filter(validity, sigma=sigma, truncate=truncate)
  return blurred_values/blurred_weights

## Relation between spatial advantage and $m_c$

### Compute $m_c(x,y)$ map

#### function to create $m_c$ map of constant calculation radii

In [None]:
def map_mc_beta(
    sub_catalog,
    grid_spacing=0.1,
    grid_side_degrees=None,
    method = 'MAXC',
    lon_vec=None,
    lat_vec=None,
    ):
  """Maps completeness magnitude, beta and event counts on a lon/lat grid.

  Works on a copy of `sub_catalog` whose coordinates are replaced by the output
  of `lon_lat_for_map_plotting`. Every grid node is used as the center of a
  square of side `grid_side_degrees`, evaluated over the whole catalog.

  Args:
    sub_catalog: Catalog with `longitude` and `latitude` columns.
    grid_spacing: Node spacing (degrees) used when `lon_vec`/`lat_vec` are not
      given.
    grid_side_degrees: Side of the square around each node; defaults to
      `grid_spacing`.
    method: Forwarded to `catalog_analysis.completeness_and_beta_in_square`.
    lon_vec: Optional explicit longitudes of the grid nodes.
    lat_vec: Optional explicit latitudes of the grid nodes.

  Returns:
    Tuple `(mc_xr, beta_xr, counts_in_bin_xr)` of `xr.DataArray`s with dims
    `('latitude', 'longitude')`.
  """
  plot_lons, plot_lats = lon_lat_for_map_plotting(
      sub_catalog.longitude.values, sub_catalog.latitude.values
  )
  sub_catalog = sub_catalog.copy()
  sub_catalog.longitude = plot_lons
  sub_catalog.latitude = plot_lats

  # Default grid: the catalog's bounding box sampled every `grid_spacing`.
  if lon_vec is None:
    lon_vec = np.arange(sub_catalog.longitude.min(), sub_catalog.longitude.max(), grid_spacing)
  if lat_vec is None:
    lat_vec = np.arange(sub_catalog.latitude.min(), sub_catalog.latitude.max(), grid_spacing)
  if grid_side_degrees is None:
    grid_side_degrees = grid_spacing

  longs, lats = np.meshgrid(lon_vec, lat_vec)
  grid_shape = lats.shape
  # A single "time step" holding every grid node, flattened row-major.
  centers = [[geometry.Point(ln, lt) for (ln, lt) in zip(longs.ravel(), lats.ravel())]]
  whole_catalog = slice(0, len(sub_catalog))

  mc_beta = catalog_analysis.completeness_and_beta_in_square(
      catalog = sub_catalog,
      time_slice=whole_catalog,
      centers=centers,
      grid_side_degrees=grid_side_degrees,
      method = method,
  )
  counts_in_bin = catalog_analysis.counts_in_square(
      catalog = sub_catalog,
      time_slice=whole_catalog,
      centers=centers,
      grid_side_degrees=grid_side_degrees,
  )

  def _as_map(flat_values):
    # Fold per-node values back onto the (latitude, longitude) grid.
    return xr.DataArray(
        flat_values.reshape(grid_shape),
        dims=['latitude', 'longitude'],
        coords={'latitude':lat_vec, 'longitude':lon_vec},
    )

  # Last axis of mc_beta: index 0 is m_c, index 1 is beta.
  mc_xr = _as_map(mc_beta[0,:, 0])
  beta_xr = _as_map(mc_beta[0,:, 1])
  counts_in_bin_xr = _as_map(counts_in_bin)
  return mc_xr, beta_xr, counts_in_bin_xr

In [None]:
# True: m_c per grid node from an adaptive neighbourhood holding at least
# MIN_EVENTS_IN_BIN events. False: fixed-size squares via map_mc_beta.
MC_BY_MINIMAL_N_EVENTS = True
MIN_EVENTS_IN_BIN = 100

# Centers of the spatial bins defined by longitude_bins / latitude_bins.
lon_centers = longitude_bins[:-1] + (longitude_bins[1:] - longitude_bins[:-1])/2
lat_centers = latitude_bins[:-1] + (latitude_bins[1:] - latitude_bins[:-1])/2

# Positional slice of the catalog covering the training period.
# NOTE(review): the slice end is exclusive, so the event whose time equals
# train_timestamps.max() is left out -- confirm this is intended.
time_slice = slice(
    np.where(domain.earthquakes_catalog.time==train_timestamps.min())[0][0],
    np.where(domain.earthquakes_catalog.time==train_timestamps.max())[0][0],
    )
# Explicit copy: the coordinate columns are overwritten below, and assigning
# into a row slice of the original catalog triggers pandas'
# SettingWithCopyWarning and leaves it ambiguous whether the parent is touched.
catalog_for_mc = domain.earthquakes_catalog[time_slice].copy()
new_lon, new_lat = lon_lat_for_map_plotting(catalog_for_mc.longitude.values, catalog_for_mc.latitude.values)
catalog_for_mc['longitude'] = new_lon
catalog_for_mc['latitude'] = new_lat

if MC_BY_MINIMAL_N_EVENTS:
  # Finer grid for the 'Hauksson' model, coarser for all others.
  ddeg = 0.1 if MODEL_NAME=='Hauksson' else 0.5
  mc_xr, nevents_xr, radius_xr = catalog_analysis.compute_grid_of_local_completeness(
      catalog_for_mc,
      grid_spacing=ddeg,
      minimal_radius=ddeg,
      minimal_events=MIN_EVENTS_IN_BIN,
  )

else:
  mc_xr, beta_xr, counts_in_bin_xr = map_mc_beta(
      catalog_for_mc,
      grid_spacing=0.1,
      grid_side_degrees=0.2,
      lon_vec=lon_centers,
      lat_vec=lat_centers,
      )


In [None]:
example_lons, example_lats = lon_lat_for_map_plotting(examples_array[:, 0], examples_array[:, 1])

if MC_BY_MINIMAL_N_EVENTS:
  # Index of the first grid coordinate >= the example's coordinate, capped at
  # the last grid node.
  x_inds = np.minimum(
      np.searchsorted(mc_xr.longitude, example_lons), len(mc_xr.longitude)-1
  )
  y_inds = np.minimum(
      np.searchsorted(mc_xr.latitude, example_lats), len(mc_xr.latitude)-1
  )

  local_mc = mc_xr.values[y_inds, x_inds]
  local_nevents = nevents_xr.values[y_inds, x_inds]
  local_radius = radius_xr.values[y_inds, x_inds]

else:
  # Index of the bin whose left edge precedes the example, clamped to the grid.
  x_inds = np.clip(
      np.searchsorted(longitude_bins, example_lons) - 1, 0, len(mc_xr.longitude)-1
  )
  y_inds = np.clip(
      np.searchsorted(latitude_bins , example_lats) - 1, 0, len(mc_xr.latitude)-1
  )

  local_mc = mc_xr.values[y_inds, x_inds]
  local_bin_counts = counts_in_bin_xr.values[y_inds, x_inds]
  # Too few events in the bin: m_c is treated as unknown there.
  sparse_bin = local_bin_counts < MIN_EVENTS_IN_BIN
  local_mc[sparse_bin] = np.nan


# True for test events whose magnitude is at or above the local m_c
# (False wherever local_mc is NaN).
mc_spatial_total_bool = (local_mc <= labels.test_labels)


In [None]:
#@title Plot all events with their m_c

f_map_local_mc, ax_map_local_mc = plt.subplots()

# Only events at or above their local m_c are drawn, colored by that m_c.
sc = ax_map_local_mc.scatter(
    example_lons[mc_spatial_total_bool],
    example_lats[mc_spatial_total_bool],
    c=local_mc[mc_spatial_total_bool],
    alpha=0.3,
    cmap='coolwarm',
)

cb = f_map_local_mc.colorbar(sc, ax=ax_map_local_mc, shrink=0.5)
cb.ax.set_yscale('linear')

# Axis limits follow all examples, not just the drawn subset.
ax_map_local_mc.set_xlim(example_lons.min(), example_lons.max())
ax_map_local_mc.set_ylim(example_lats.min(), example_lats.max())
ax_map_local_mc.set_aspect('equal')

In [None]:
#@title plot $m_c$ and other relevant maps
f_maps, ax_maps = plt.subplots(1, 3, figsize=(27,9))


# The panels depend on how m_c was computed: the adaptive method also yields
# the event count and radius maps, the fixed-square method yields beta and
# counts.
if MC_BY_MINIMAL_N_EVENTS:
  _ = mc_xr.plot(ax=ax_maps[0], )
  ax_maps[0].set_title('$m_c$')
  _ = nevents_xr.plot(ax=ax_maps[1])
  # Raw string: '\ ' is not a valid Python escape sequence (it is the mathtext
  # spacing command); the resulting title text is unchanged.
  ax_maps[1].set_title(r'$n\ events$')
  _ = radius_xr.plot(ax=ax_maps[2])
  ax_maps[2].set_title('$R$')
else:
  _ = mc_xr.plot(ax=ax_maps[0], )
  ax_maps[0].set_title('$m_c$')
  _ = beta_xr.plot(ax=ax_maps[1], vmax=4)
  ax_maps[1].set_title('$\\beta$')
  _ = counts_in_bin_xr.plot(ax=ax_maps[2], vmax=150)
  ax_maps[2].set_title('$counts$')

## Condition on spatial $m_c(x,y)$

In [None]:
#@title mean LL conditioned on LOCAL incompleteness:  $m\geq \tilde{m}=m_c(x,y)$

# --- set dictionary
likelihoods_and_baselines_cond_local_incompleteness = {}

# --- set model conditioned likelihoods
# Only the test set is evaluated. Scores exist only for events at or above
# their local m_c, so they are scattered into a NaN-filled vector aligned with
# labels.test_labels.
set_name = 'test'
mll_local_completeness = np.full_like(labels.test_labels, np.nan)
spatial_result = np.array(
    conditioned_likelihood_model(local_mc, set_name, forecasts)
)
mll_local_completeness[mc_spatial_total_bool] = spatial_result
likelihoods_and_baselines_cond_local_incompleteness[
    f'model_{MODEL_NAME}_likelihood_{set_name}'
] = mll_local_completeness

# --- baselines conditioned on the same local m_c, aligned the same way
conditioned_local_incompleteness_test = compute_conditioned_gr_variations(
    gr_models_beta_test,
    gr_models_mc_test,
    labels,
    local_mc,
)
for k in conditioned_local_incompleteness_test.keys():
  mll_local_completeness = np.full_like(labels.test_labels, np.nan)
  mll_local_completeness[mc_spatial_total_bool] = conditioned_local_incompleteness_test[k]
  conditioned_local_incompleteness_test[k] = mll_local_completeness

likelihoods_and_baselines_cond_local_incompleteness.update(
    conditioned_local_incompleteness_test
)


# --- display all
summary_df_mean_ll_cond_mc_local = create_scores_summary_df(
    likelihoods_and_baselines_cond_local_incompleteness,
    drop_nans=True,
    exclude_zeros=False,
)
summary_df_mean_ll_cond_mc_local

In [None]:
# Bar plot of the locally conditioned scores, one bar color per plotted model.
colors_of_models_to_plot = [COLOR_PER_MODEL[m] for m in MODELS_TO_PLOT]
f_thresh_t_local_barplot, ax_thresh_t_local_barplot = barplot_scores(
    summary_df_mean_ll_cond_mc_local,
    MODELS_TO_PLOT,
    colors_of_models_to_plot,
)

### Add local conditioning to cropped catalog

In [None]:
# First key starting with 'model' is taken as the model's entry.
model_name_likelihood_test = [
    k for k in likelihoods_and_baselines_cond_local_incompleteness.keys()
    if k.startswith('model')
][0]

# Per-event log-likelihood gain of the model over the train-set GR baseline.
log_likelihood_model_local = np.log(
    likelihoods_and_baselines_cond_local_incompleteness[model_name_likelihood_test]
)
log_likelihood_baseline_local = np.log(
    likelihoods_and_baselines_cond_local_incompleteness['train_gr_likelihood_test']
)
add_log_difference_spatial_conditioned = log_likelihood_model_local - log_likelihood_baseline_local

# Non-finite gains (NaN / +-inf) contribute nothing to the cumulative sum.
finite_difference_spatial_conditioned = np.copy(add_log_difference_spatial_conditioned)
is_not_finite = ~np.isfinite(finite_difference_spatial_conditioned)
finite_difference_spatial_conditioned[is_not_finite] = 0
add_information_gain_spatial_conditioned = np.nancumsum(finite_difference_spatial_conditioned)


cropped_comparison_catalog = cropped_comparison_catalog.assign(
    log_difference_spatial_conditioned=add_log_difference_spatial_conditioned,
    information_gain_spatial_conditioned=add_information_gain_spatial_conditioned,
)


# One extra column per model/baseline, named '<model>_local_condition'.
assign_dict = {
    split_name_to_model_and_set(k)[0]+'_local_condition': v
    for k, v in likelihoods_and_baselines_cond_local_incompleteness.items()
}
cropped_comparison_catalog = cropped_comparison_catalog.assign(**assign_dict)

### Add spatially conditioned curve to cumulative info plot

In [None]:
# Overlay the spatially conditioned curve on the existing cumulative-gain axis
# and rebuild the legend so that it lists the new line.
_ = ax_cum_info_gain.plot(
    x_info_gain,
    add_information_gain_spatial_conditioned,
    linestyle='-.',
    color='red',
    linewidth=2,
    label='cumulative information gain - spatially conditioned',
)
_ = ax_cum_info_gain.legend(loc='upper right')

# Re-display the updated figure.
f_likelihood_overtime

## Show spatial distribution of information gain

In [None]:
#--- Compute LL spatial distributions
# Turn the grid coordinates of mc_xr (used as cell centers) into bin edges:
# every center shifted half a cell down, plus one closing edge at the end.
longitude_centers = mc_xr.longitude.values
longitude_half_diff = (longitude_centers[1] - longitude_centers[0])/2
longitude_bins = np.append(
    longitude_centers - longitude_half_diff,
    longitude_centers[-1] + longitude_half_diff,
)
latitude_centers = mc_xr.latitude.values
latitude_half_diff = (latitude_centers[1] - latitude_centers[0])/2
latitude_bins = np.append(
    latitude_centers - latitude_half_diff,
    latitude_centers[-1] + latitude_half_diff,
)

# Mean log-likelihood per spatial bin for every test-set entry; NaNs are
# ignored by nanmean.
test_likelihood_keys = [
    k for k in likelihoods_and_baselines_cond_local_incompleteness.keys()
    if k.endswith('_test')
]
H_meanLL_dict_cond_local = {}
for k in test_likelihood_keys:
  values = likelihoods_and_baselines_cond_local_incompleteness[k]
  statistic, x_edge, y_edge, binnumber = stats.binned_statistic_2d(
      example_lons,
      example_lats,
      values=np.log(values),
      statistic=np.nanmean,
      bins=[longitude_bins, latitude_bins],
  )
  H_meanLL_dict_cond_local[k] = statistic


set_to_plot = 'test'
H_to_plot = [f'{k}_{set_to_plot}' for k in MODELS_TO_PLOT]
model_full_name = [h_name for h_name in H_to_plot if h_name.startswith('model_')][0]


# Model minus baseline, transposed so that rows are latitude and columns
# longitude.
spatial_advantage_cond_local = {
    k: H_meanLL_dict_cond_local[model_full_name].T - H_meanLL_dict_cond_local[k].T
    for k in H_to_plot
    if not k.startswith('model_')
}

In [None]:
# --- Plot differences in space

# longitude_bins / latitude_bins hold the bin EDGES used for the binned
# statistics in the previous cell, so the image spans exactly first-to-last
# edge. Padding by another half bin (as done previously) stretched every pixel
# and shifted the map relative to the true coordinates.
longitude_diff = longitude_bins[1] - longitude_bins[0]  # bin width; not needed for the extent
latitude_diff = latitude_bins[1] - latitude_bins[0]  # bin width; not needed for the extent
extent = (
    longitude_bins[0],
    longitude_bins[-1],
    latitude_bins[0],
    latitude_bins[-1],
)


vmin = -2
vmax = 3
gamma = 0.3  # NOTE(review): not used in this cell


# One figure per baseline: mean log-likelihood advantage of the model per cell.
spatial_diff_figs = {}
for k in spatial_advantage_cond_local:
  f, ax = plt.subplots(1, 1,)
  ii0 = ax.imshow(
      spatial_advantage_cond_local[k],
      origin='lower',
      extent=extent,
      cmap='coolwarm',
      # Symmetric-log color scale, linear within +-0.01, clipped to [vmin, vmax].
      norm=mpl.colors.SymLogNorm(0.01, vmin=vmin, vmax=vmax, clip=True),
      )
  ax.set_title(f'Difference to {k}')
  cb = f.colorbar(ii0, ax=ax, extend='both')
  ax.set_aspect('equal', 'box')
  ax.set_xlim(extent[0], extent[1])
  ax.set_ylim(extent[2], extent[3])
  spatial_diff_figs[k] = f

# Convert to a regression problem

In [None]:
# Switch for the regression-style scatter: mean of the predicted distribution
# against the true label, for each data split.
plot_kde = False
if plot_kde:
  sets_to_plot = ['train', 'validation', 'test']
  f_regression, ax_regression = plt.subplots(1, 1, figsize=(7, 7))
  alpha = 0.8  # NOTE(review): unused; the scatter below hard-codes alpha=0.2
  for k in sets_to_plot:
    true_magnitudes = getattr(labels, f'{k}_labels')
    # Mean of the predicted random variable, mapped back to magnitude units.
    predicted_mean = probability_density_function(forecasts[k]).mean()*random_var_stretch + random_var_shift
    ax_regression.scatter(true_magnitudes, predicted_mean, alpha=0.2, label=k)
  # Identity line: a perfect point prediction would fall on it.
  ax_regression.plot([MAG_THRESH, 5], [MAG_THRESH, 5], linewidth=3, color='k')
  ax_regression.set_xlabel('label')
  ax_regression.set_ylabel('mean of prediction')
  leg = ax_regression.legend()
  f_regression

# Plot resulting distributions

## Distribution for random queries

In [None]:
#---- setup data
plot_above_thresh = MAG_THRESH
m_vec = np.linspace(MAG_THRESH, 7, 500)
prob_density_inst = probability_density_function(forecasts['test'])
# Density over magnitudes: evaluate the model on the standardized variable
# (m - shift) / stretch and divide by stretch (change of variables).
standardized_m_vec = (m_vec[:, None] - random_var_shift)/random_var_stretch
prob_vecs = prob_density_inst.prob(standardized_m_vec)/random_var_stretch

# Keep only test events at or above the plotting threshold
# (rows: magnitude grid, columns: events).
is_above_plot_thresh = labels.test_labels>=plot_above_thresh
test_labels_to_plot_from = labels.test_labels[is_above_plot_thresh]
prob_vecs_to_plot_from = prob_vecs.numpy()[:, is_above_plot_thresh]


# Draw 100 distinct events, weighted by exp(beta * magnitude) so that large
# magnitudes are favoured; the seed is fixed for reproducibility.
p_for_mags = np.exp(BETA_OF_TRAIN_SET*test_labels_to_plot_from)
p_for_mags /= p_for_mags.sum()
rnd_seed = np.random.RandomState(seed=1000)
n_candidate_events = prob_vecs_to_plot_from.shape[1]
label_idxs_to_plot = np.sort(
    rnd_seed.choice(n_candidate_events, 100, replace=False, p=p_for_mags)
)
labels_to_plot = test_labels_to_plot_from[label_idxs_to_plot]

In [None]:
#--- setup figure
num_mags = 25
min_mag = 2
max_mag = 6.5
m_scale = np.linspace(min_mag-0.01, max_mag, num_mags)
norm_inst = plt.Normalize(min_mag, max_mag)

# Each curve gets the color of the m_scale level closest to its true magnitude.
chosen_colormap = warn_cold_cmap
colors = chosen_colormap(np.linspace(0,1,num_mags))
distance_to_scale = np.abs(
    test_labels_to_plot_from[label_idxs_to_plot][:,None] - m_scale[None,:]
)
colors2plot = colors[np.argmin(distance_to_scale, axis=1)]


f_dist_fig, ax_dist_fig = plt.subplots(1,1,)

for idx, lbl_index in enumerate(label_idxs_to_plot):
  predicted_density = prob_vecs_to_plot_from[:, lbl_index]
  p = ax_dist_fig.plot(m_vec, predicted_density, alpha=0.4, color=colors2plot[idx], linewidth=4)

  # Optional annotation of each curve's peak with its true magnitude.
  add_text = False
  if add_text:
    y_peak = predicted_density.max()
    x_peak = m_vec[np.argmax(predicted_density)]
    text = str(labels_to_plot[idx])
    txt = ax_dist_fig.text(x_peak, y_peak, text)

# plot GR train set
train_gr_curve = metrics.gr_likelihood(m_vec, BETA_OF_TRAIN_SET, MAG_THRESH)
gr_handle = ax_dist_fig.plot(m_vec, train_gr_curve, 'k--', label='train_gr_likelihood', linewidth=3)
ax_dist_fig.legend(handles=gr_handle, frameon=False)

# Colorbar translating curve color to true magnitude.
sm = plt.cm.ScalarMappable(cmap=chosen_colormap, norm=norm_inst)

cb = plt.colorbar(sm, ax=ax_dist_fig, label='True magnitude (label)')
ax_dist_fig.set_xlabel('magnitude')
ax_dist_fig.set_ylabel('p(magnitude)')
ax_dist_fig.set_xscale('linear')

## Marginal distribution

In [None]:
#--- Helper function
def num_to_vec(num, data):
  """Broadcasts a scalar to the shape of `data`; passes anything else through.

  Args:
    num: Either a number (`numbers.Number`) or an already vectorized value.
    data: Array whose shape and dtype the broadcast result takes.

  Returns:
    `np.full_like(data, num)` when `num` is a number, otherwise `num` itself.
  """
  if not isinstance(num, numbers.Number):
    return num
  return np.full_like(data, num)

In [None]:
#--- data setup

set_to_plot = 'test'
m_vec = np.linspace(MAG_THRESH, 7, 500)
marginal_norm = {}

# Model marginal: sum of the per-event predicted densities over the magnitude
# grid, normalized to unit area (the first grid point is left out of the
# integral).
# NOTE(review): np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
prob_density_inst = probability_density_function(forecasts[set_to_plot])
prob_vecs = prob_density_inst.prob((m_vec[:, None] - random_var_shift)/random_var_stretch)/random_var_stretch
marginal = prob_vecs.numpy().sum(axis=1)
marginal_norm[f'model_{MODEL_NAME}_likelihood'] = marginal/np.trapz(marginal[1:], m_vec[1:])

# Labels of the plotted set. The original code read f'{set_name}_labels',
# relying on a loop variable leaked from an earlier cell; it only matched
# set_to_plot by coincidence of execution order.
set_to_plot_labels = getattr(labels, f'{set_to_plot}_labels')

all_benchmarks_names = [split_name_to_model_and_set(k)[0] for k in gr_models_beta.keys() if k.endswith(f'_{set_to_plot}')]
for name in all_benchmarks_names:
  if 'events_kde' in name:
    # KDE benchmarks store one density callable per event.
    kde_pdf = np.array([kde_inst(m_vec) for kde_inst in gr_models_beta[f'{name}_{set_to_plot}']])
    marginal = kde_pdf.sum(axis=0)
  else:
    # GR-type benchmarks: exponential density with rate beta, shifted by m_c.
    # Scalars are broadcast to one value per event.
    beta_name = num_to_vec(
        gr_models_beta[f'{name}_{set_to_plot}'], set_to_plot_labels
    ).ravel()
    sampling_points_name = m_vec[:, None] - num_to_vec(
        gr_models_mc[f'{name}_{set_to_plot}'], set_to_plot_labels
    )[None, :]
    marginal_unraveled = tfp.distributions.Exponential(
        beta_name,
        force_probs_to_zero_outside_support=True
    ).prob(sampling_points_name).numpy()
    marginal = np.nansum(marginal_unraveled, axis=1)
  marginal_norm[name] = marginal/np.trapz(marginal[1:], m_vec[1:])



In [None]:
#--- plot

bins = np.arange(MAG_THRESH, 10, 0.25)

marginal_fig, marginal_ax = plt.subplots(1,1, figsize=(7,4))

# Empirical magnitude histograms of the test and train labels.
for hist_labels, hist_name, hist_color in (
    (labels.test_labels, 'test', 'royalblue'),
    (labels.train_labels, 'train', 'darkorange'),
):
  h = marginal_ax.hist(hist_labels, bins, alpha=0.3, density=True, label=hist_name, color=hist_color)

# Normalized marginal density of each plotted model / benchmark.
for model_name in MODELS_TO_PLOT:
  marginal_ax.plot(
      m_vec,
      marginal_norm[model_name],
      label=model_name,
      linewidth=3,
      color=COLOR_PER_MODEL[model_name],
  )
marginal_fig.legend(loc='outside upper right')
marginal_ax.set_xlabel('magnitude')
marginal_ax.set_ylabel('p(magnitude)')
marginal_ax.set_yscale('log')

# Convert to a binary classifier

In [None]:
# Set parameters for display
# Magnitude threshold of the binary task "is the event's magnitude >= m_chosen".
m_chosen = 4
m_thresh_vec = np.array([m_chosen])

### Utility functions

In [None]:
def auc_p_at_r(recall, precision, tp_ratio):
  """Area under the precision-recall curve in excess of a constant baseline.

  The curve is integrated as a step function: the points are ordered by recall
  and each recall interval is weighted by the precision at its right end. The
  area of a constant classifier with precision `tp_ratio` over the same recall
  range is then subtracted.

  Args:
    recall: 1-D array of recall values.
    precision: 1-D array of precision values, aligned with `recall`.
    tp_ratio: Baseline precision (fraction of positives).

  Returns:
    Scalar area difference.
  """
  order = np.argsort(recall)
  recall_steps = np.diff(recall[order])
  right_end_precision = precision[order][1:]
  curve_area = np.sum(recall_steps * right_end_precision)
  baseline_area = tp_ratio * np.sum(recall_steps)
  return curve_area - baseline_area

def boolean_metrics(random_variable_instance, test_labels, shift=0, m_thresh_vec=None):
  """Scores a magnitude distribution as a binary "magnitude >= threshold" classifier.

  For every threshold, the classifier score of an event is the predicted
  survival function at that threshold; ROC and precision-recall statistics are
  computed against the true labels.

  Args:
    random_variable_instance: Distribution object exposing
      `survival_function(x)` whose result has `.numpy()`.
    test_labels: Array of true magnitudes.
    shift: Value subtracted from the threshold before it is passed to the
      distribution (scalar or per-event array).
    m_thresh_vec: Iterable of thresholds. Defaults to
      `np.arange(MAG_THRESH, 8, 0.33)`.

  Returns:
    Dict of dicts keyed by threshold: 'tp_fp_results' (fpr, tpr),
    'pr_results' (precision, recall), 'ap', 'auc_pr', 'auc_p_at_r', 'roc_auc'
    and 'tp_ratio' (fraction of positive events).
  """
  if m_thresh_vec is None:
    m_thresh_vec = np.arange(MAG_THRESH, 8, 0.33)

  tp_fp_results = {}
  pr_results = {}
  roc_auc = {}
  tp_ratio = {}
  for m_thresh in m_thresh_vec:
    # NOTE(review): thresholds are truncated to integers, so results are keyed
    # by plain ints (e.g. 4); fractional thresholds collapse onto the same key
    # and overwrite each other -- confirm this is intended.
    m_thresh = int(m_thresh)
    boolean_test = test_labels >= m_thresh
    tp_ratio[m_thresh] = boolean_test.sum()/boolean_test.size

    # Score = P(M >= m_thresh) under the prediction. The division by the
    # stretch is a uniform rescaling of all scores; presumably the stretch is
    # positive, in which case the ranking-based metrics are unaffected.
    p_val = random_variable_instance.survival_function((m_thresh - shift)/random_var_stretch).numpy()/random_var_stretch
    fpr, tpr, _ = skl_metrics.roc_curve(boolean_test, p_val)
    tp_fp_results[m_thresh] = (fpr, tpr)
    try:
      roc_auc[m_thresh] = skl_metrics.roc_auc_score(boolean_test, p_val)
    except ValueError:
      # roc_auc_score raises ValueError when only one class is present; the
      # AUC is undefined there and reported as 0. Any other exception is a
      # genuine error and is no longer swallowed.
      roc_auc[m_thresh] = 0
    precision, recall, _ = skl_metrics.precision_recall_curve(boolean_test, p_val)
    pr_results[m_thresh] = (precision, recall)


  return {
      'tp_fp_results': tp_fp_results,
      'pr_results': pr_results,
      # Mean precision over the curve points with non-zero recall.
      'ap': {k:np.mean(pr_results[k][0][pr_results[k][1]!=0]) for k in pr_results.keys()},
      'auc_pr': {k:skl_metrics.auc(pr_results[k][1], pr_results[k][0]) for k in pr_results.keys()},
      'auc_p_at_r': {k:auc_p_at_r(pr_results[k][1], pr_results[k][0], tp_ratio[k]) for k in pr_results.keys()},
      'roc_auc': roc_auc,
      'tp_ratio': tp_ratio,
  }

def is_numeric(var):
  """Returns True when `var` is scalar-like.

  A value counts as scalar-like if it has no `shape` attribute at all, or if
  its shape is empty (zero-dimensional).
  """
  return (not hasattr(var, 'shape')) or len(var.shape)==0

## Aggregate all classification data

In [None]:
# name -> (distribution, true labels, shift) for every model/benchmark and
# data split.
random_variables_and_labels_and_shifts = {}
for set_name in ['train', 'validation', 'test']:
  random_variables_and_labels_and_shifts[f'model_{MODEL_NAME}_likelihood_{set_name}'] = (
      probability_density_function(forecasts[set_name]),
      getattr(labels, f'{set_name}_labels'),
      MAG_THRESH,
  )

# set_name still holds the last value of the loop above ('test'), so the
# benchmark names are taken from the '*_test' keys.
all_benchmarks_names = [split_name_to_model_and_set(k)[0] for k in gr_models_beta.keys() if k.endswith(f'_{set_name}')]
for name in all_benchmarks_names:
  for set_name in ['train', 'validation', 'test']:
    benchmark_key = f'{name}_{set_name}'
    this_beta = gr_models_beta[benchmark_key]
    this_mc = gr_models_mc[benchmark_key]
    if 'events_kde' in name:
      # KDE benchmarks are not converted to a classifier.
      # random_var = kde_random_var(this_beta, samp_points)
      continue

    set_labels = getattr(labels, f'{set_name}_labels')
    if is_numeric(this_beta):
      # A single beta for the whole set: broadcast it to one rate per event.
      exponential_rates = num_to_vec(this_beta, set_labels).ravel()
      kept_labels = set_labels
      kept_mc = this_mc
    else:
      # Per-event beta and m_c: drop events where either is not finite.
      is_finite = np.isfinite(this_beta.ravel()) & np.isfinite(this_mc.ravel())
      exponential_rates = num_to_vec(this_beta[is_finite], set_labels[is_finite]).ravel()
      kept_labels = set_labels[is_finite]
      kept_mc = this_mc[is_finite]

    random_var = tfp.distributions.Exponential(
        exponential_rates, force_probs_to_zero_outside_support=True
    )
    random_variables_and_labels_and_shifts[benchmark_key] = (
        random_var,
        kept_labels,
        kept_mc,
    )

boolean_metrics_dict = {
    k: boolean_metrics(v[0], v[1], shift=v[2], m_thresh_vec=m_thresh_vec)
    for k, v in random_variables_and_labels_and_shifts.items()
}

## ROC curve

In [None]:
def plot_ROC(all_booleans_dict, models_to_plot, plot_colors_dict, ax, display_names_list, m_thresh=4, set_name='test', is_first=False):
  """Draws ROC curves of the selected models together with an AUC table.

  Args:
    all_booleans_dict: Mapping from '<model>_<set>' names to the dictionaries
      returned by `boolean_metrics`.
    models_to_plot: Model name prefixes to draw; each is resolved to the first
      key of `all_booleans_dict` that starts with it and ends with `set_name`.
    plot_colors_dict: Color per entry of `models_to_plot`.
    ax: Matplotlib axes to draw on.
    display_names_list: Names shown in the AUC table, aligned by position with
      `models_to_plot`.
    m_thresh: Magnitude threshold of the binary task; must be a key of the
      stored results.
    set_name: Data split to plot.
    is_first: Unused; kept for backward compatibility.

  Raises:
    IndexError: If a requested model has no matching key.
    AssertionError: If `m_thresh` is missing from the stored results.
  """
  plot_colors_set_dict = {}
  models_to_plot_set = []
  for m in models_to_plot:
    model_name = [k for k in all_booleans_dict.keys() if k.startswith(m) and k.endswith(set_name)][0]
    models_to_plot_set.append(model_name)
    plot_colors_set_dict[model_name] = plot_colors_dict[m]

  assert (m_thresh in all_booleans_dict[models_to_plot_set[0]]['tp_fp_results'].keys()), 'm_thresh not found in keys'

  #--- iterate on models and benchmarks (curves drawn in dictionary order)
  for model_benchmark, v in all_booleans_dict.items():
    if model_benchmark not in models_to_plot_set:
      continue
    print(f'model_benchmark: {model_benchmark}')

    roc_auc = v['roc_auc']
    fpr, tpr = v['tp_fp_results'][m_thresh]
    ax.plot(
        fpr, tpr,
        label=f'{model_benchmark} (auc={roc_auc[m_thresh]:.2f})',
        color=plot_colors_set_dict[model_benchmark],
        )

  # Chance line and ticks are drawn once, after (and therefore above) all
  # curves, instead of once per model.
  ax.plot([0, 1], [0, 1], 'k--', linewidth=3)
  ax.set_xticks([0, 1], [0, 1])
  ax.set_yticks([0, 1], [0, 1])

  #--- AUC table: a single inset, one row per model in `models_to_plot` order
  n_models = len(models_to_plot_set)
  axins = ax.inset_axes(
      [0.85, 0.05, 0.1, 0.5],
      xlim=(-1, 1), ylim=(-0.5, n_models+0.5), xticks=[], yticks=[])
  axins.axis('off')

  ha = 'right'
  x0 = 1
  for position, model_benchmark in enumerate(models_to_plot_set):
    roc_auc = all_booleans_dict[model_benchmark]['roc_auc'][m_thresh]
    y0 = n_models - position - 1.2
    # The display name is picked by the model's position in `models_to_plot`,
    # so name, color and AUC always belong to the same model.
    print_string = display_names_list[position] + ' ' + f'{roc_auc:.2f}'
    axins.text(x0, y0, print_string, fontsize=6, color=plot_colors_set_dict[model_benchmark], ha=ha, va='center')
  axins.text(x0, n_models, 'AUC', fontsize=6, weight='heavy', ha=ha, va='center')

  ax.set_ylabel('Recall ' + r'$( \frac{tp}{tp + fn} )$',)
  ax.set_xlabel(' False positive rate ' + r'$( \frac{fp}{fp + tn} )$',)

In [None]:
# ROC curves of the plotted models on the test set for the M >= 4 task.
f, ax = plt.subplots(1, 1)
plot_ROC(
    boolean_metrics_dict,
    MODELS_TO_PLOT,
    COLOR_PER_MODEL,
    ax,
    MODELS_TO_PLOT,
    m_thresh=4,
    set_name='test',
    is_first=True,
)

## P@R curve

In [None]:
# Spacing of the recall thresholds at which precision is sampled.
DELTA = 1e-2  # 3e-3

def precision_at_recall(precision, recall, delta):
  """Best precision achievable while keeping recall above each threshold.

  Args:
    precision: 1-D array of precision values.
    recall: 1-D array of recall values, aligned with `precision`.
    delta: Spacing of the recall thresholds `np.arange(0, 1, delta)`.

  Returns:
    Array with one entry per threshold `t`: the maximal precision among the
    points whose recall is strictly greater than `t`.
  """
  recall_thresholds = np.arange(0, 1, delta)
  return np.array(
      [np.max(precision[recall > t]) for t in recall_thresholds]
  )

def plot_PR(all_booleans_dict, models_to_plot, plot_colors_dict, ax, m_thresh=4, set_name='test'):
  """Draws precision-at-recall curves of the selected models with an AUC table.

  Args:
    all_booleans_dict: Mapping from '<model>_<set>' names to the dictionaries
      returned by `boolean_metrics`.
    models_to_plot: Model name prefixes to draw; each is resolved to the first
      key of `all_booleans_dict` that starts with it and ends with `set_name`.
    plot_colors_dict: Color per entry of `models_to_plot`.
    ax: Matplotlib axes to draw on.
    m_thresh: Magnitude threshold of the binary task; must be a key of the
      stored results.
    set_name: Data split to plot.

  Raises:
    IndexError: If a requested model has no matching key.
    AssertionError: If `m_thresh` is missing from the stored results.
  """
  plot_colors_set_dict = {}
  models_to_plot_set = []
  for m in models_to_plot:
    model_name = [k for k in all_booleans_dict.keys() if k.startswith(m) and k.endswith(set_name)][0]
    models_to_plot_set.append(model_name)
    plot_colors_set_dict[model_name] = plot_colors_dict[m]

  assert (m_thresh in all_booleans_dict[models_to_plot_set[0]]['tp_fp_results'].keys()), 'm_thresh not found in keys'

  delta = DELTA
  recall_thresholds = np.arange(0, 1, delta)

  #--- iterate on models and benchmarks (curves drawn in dictionary order)
  for model_benchmark, v in all_booleans_dict.items():
    if model_benchmark not in models_to_plot_set:
      continue
    print(f'model_benchmark: {model_benchmark}')

    # pr_results stores (precision, recall) per threshold.
    precision, recall = v['pr_results'][m_thresh]
    prec_at_recall = precision_at_recall(precision, recall, delta)
    ax.plot(
        recall_thresholds,
        prec_at_recall,
        color=plot_colors_set_dict[model_benchmark],
        )

  # Ticks and the AUC table are created once rather than once per model.
  ax.set_xticks([0, 1], [0, 1])
  ax.set_yticks([0, 1], [0, 1])

  #--- AUC table: a single inset, one row per model in `models_to_plot` order
  n_models = len(models_to_plot_set)
  axins = ax.inset_axes(
      [0.85, 0.42, 0.1, 0.5],
      xlim=(-1, 1), ylim=(-0.5, n_models+0.5), xticks=[], yticks=[])
  axins.axis('off')

  for position, model_benchmark in enumerate(models_to_plot_set):
    # Area between the P@R curve and the constant-precision baseline.
    auc_par = all_booleans_dict[model_benchmark]['auc_p_at_r'][m_thresh]
    y0 = n_models - position - 1.2
    print_string = model_benchmark + ' ' + f'{auc_par:.2f}'
    axins.text(1, y0, print_string, fontsize=6, color=plot_colors_set_dict[model_benchmark], ha='right', va='center',)
  axins.text(1, n_models, 'AUC', fontsize=6, weight='heavy', ha='right', va='center')


  ax.set_ylabel('Precision ' + r'$( \frac{tp}{tp + fp} )$',)
  ax.set_xlabel('Recall  ' + r'$( \frac{tp}{tp + fn} )$',)

In [None]:
# Precision-at-recall curves of the plotted models on the test set for the
# M >= 4 task.
f, ax = plt.subplots(1, 1)
plot_PR(
    boolean_metrics_dict,
    MODELS_TO_PLOT,
    COLOR_PER_MODEL,
    ax,
    m_thresh=4,
    set_name='test',
)