In [None]:
import os

import matplotlib.pyplot as plt
import seisbench.models as sbm
import torch

from dataset_creation.utils import remove_traces_not_to_use
from evaluation.noisy_dataset_evaluation import find_large_error_traces, eval_mse, batch_max_values_and_residuals
from utils.common import load_dataset_and_labels, load_pretrained_model, assert_path_exists

# Model Evaluation

## Requirements and Configuration

What is your user root directory? (For example, `/home/<username>/` on Linux machines.)

In [None]:
# Root directory of the user running the notebook; edit this to match your machine.
USER_ROOT_DIR = '/home/moshe/'
assert_path_exists(name='USER_ROOT_DIR', path_str=USER_ROOT_DIR)
USER_ROOT_DIR

What is the root folder of your datasets?

In [None]:
# Root folder of the datasets, located under the user root directory.
DATASETS_ROOT_DIR = os.path.join(USER_ROOT_DIR, 'datasets/GFZ/')
assert_path_exists(name='DATASETS_ROOT_DIR', path_str=DATASETS_ROOT_DIR)
DATASETS_ROOT_DIR

In [None]:
# Allowed values for the configuration cells that follow.
MODEL_TO_NUM_SAMPLES = {
    sbm.EQTransformer: 6000,
    sbm.PhaseNet: 3001,
}  # trace length, in samples, associated with each model class
SBM_CLASSES = [sbm.PhaseNet, sbm.EQTransformer]
DATASETS_ORIGINS = ['ethz', 'geofon']

In [None]:
# Origin of the dataset to evaluate on; must be one of DATASETS_ORIGINS.
dataset_origin = 'ethz'
assert dataset_origin in DATASETS_ORIGINS, (
    f'Expected dataset one of {DATASETS_ORIGINS}. Got {dataset_origin}.'
)

In [None]:
# Model class to evaluate; must be one of SBM_CLASSES.
SBM_CLASS = sbm.PhaseNet
assert SBM_CLASS in SBM_CLASSES
SBM_CLASS

In [None]:
# Trace length in samples, looked up from the selected model class:
# 3001 for PhaseNet, 6000 for EQTransformer.
NUM_SAMPLES = MODEL_TO_NUM_SAMPLES[SBM_CLASS]
NUM_SAMPLES

In [None]:
# Sampling rate in Hz - PhaseNet and EQTransformer expect 100Hz.
SAMPLE_RATE = 100

# Onset prediction errors above this value (seconds) are considered large for the metrics.
LARGE_ERROR_THRESHOLD_SECONDS = 1

# SNR levels of the synthetic data used.
SYNTHESIZED_SNR_LIST = [*range(2, 11)]

# How many original traces to use for the noisy dataset - a slice from the start.
NUM_OF_ORIGINAL_TRACES = 2100

In [None]:
# Express the large-error threshold in samples rather than seconds.
LARGE_ERROR_THRESHOLD_SAMPLES = SAMPLE_RATE * LARGE_ERROR_THRESHOLD_SECONDS
print(f'A residual of more than {LARGE_ERROR_THRESHOLD_SAMPLES} samples is considered large error')

In [None]:
# Supported SNR estimation strategies; the chosen one is part of the dataset folder name.
SNR_CALC_STRATEGY_STR_ALTERNATIVES = ['energy_ratio', 'max_amplitude_vs_rms_ratio']
SNR_CALC_STRATEGY_STR = 'energy_ratio'
# Validate the choice the same way dataset_origin and SBM_CLASS are validated,
# so a typo fails here rather than as a missing path later on.
assert SNR_CALC_STRATEGY_STR in SNR_CALC_STRATEGY_STR_ALTERNATIVES, (
    f'Expected SNR strategy one of {SNR_CALC_STRATEGY_STR_ALTERNATIVES}. Got {SNR_CALC_STRATEGY_STR}.'
)

Set the path of the dataset the model is going to be evaluated on.
It contains both the original data and the synthetically noised data.

In [None]:
# Folder of the noisy dataset that matches the chosen origin, trace length and SNR strategy.
DATASET_PATH = os.path.join(
    DATASETS_ROOT_DIR,
    f'noisy_datasets/{dataset_origin}_{NUM_SAMPLES}_sample_joachim_noises_{SNR_CALC_STRATEGY_STR}_snr/',
)
assert_path_exists(name='DATASET_PATH', path_str=DATASET_PATH)
DATASET_PATH

In [None]:
# One sub-folder per synthesized SNR level, in the same order as SYNTHESIZED_SNR_LIST.
NOISY_DATA_PATH_LIST = [
    os.path.join(DATASET_PATH, f'noisy_dataset_snr_{synthesized_snr}')
    for synthesized_snr in SYNTHESIZED_SNR_LIST
]
# DATASET_PATH is already checked in the cell that defines it, so only the
# per-SNR folders are checked here.
for ndp in NOISY_DATA_PATH_LIST:
    assert_path_exists(path_str=ndp, name='NOISY_DATA_PATH')
NOISY_DATA_PATH_LIST

### Load the Synthetic Noisy Traces

In [None]:
# Files expected inside every per-SNR folder, one list entry per SNR level.
synthetic_noisy_dataset_paths = [os.path.join(ndp, 'traces.pt') for ndp in NOISY_DATA_PATH_LIST]
synthetic_noisy_labels_paths = [os.path.join(ndp, 'labels.pt') for ndp in NOISY_DATA_PATH_LIST]
augmented_noises_paths = [os.path.join(ndp, 'full_noise_traces.pt') for ndp in NOISY_DATA_PATH_LIST]
factors_paths = [os.path.join(ndp, 'factors.pt') for ndp in NOISY_DATA_PATH_LIST]

# Check every expected file before any loading is attempted.
for synthetic_noisy_dataset_path, synthetic_noisy_labels_path, augmented_noises_path, factors_path in zip(
    synthetic_noisy_dataset_paths,
    synthetic_noisy_labels_paths,
    augmented_noises_paths,
    factors_paths,
):
    assert_path_exists(path_str=synthetic_noisy_dataset_path, name='synthetic_noisy_dataset_path')
    assert_path_exists(path_str=synthetic_noisy_labels_path, name='synthetic_noisy_labels_path')
    assert_path_exists(path_str=augmented_noises_path, name='augmented_noises_path')
    assert_path_exists(path_str=factors_path, name='factors_path')

# This notebook only reads these files, so the message says "loaded from" rather than "stored in".
print('Synthetic data will be loaded from:')
print(synthetic_noisy_dataset_paths)
print(synthetic_noisy_labels_paths)
print(augmented_noises_paths)
print(factors_paths)

In [None]:
# Load the traces and labels of every SNR level, cast them to float32,
# validate their shapes and collect them in SNR order.
noisy_traces_list = []
noisy_traces_labels_list = []
for synthetic_noisy_dataset_path, synthetic_noisy_labels_path in zip(
    synthetic_noisy_dataset_paths, synthetic_noisy_labels_paths
):
    synthetic_noisy_dataset, synthetic_noisy_labels = load_dataset_and_labels(
        dataset_path=synthetic_noisy_dataset_path,
        labels_path=synthetic_noisy_labels_path,
    )
    synthetic_noisy_dataset = synthetic_noisy_dataset.float()
    synthetic_noisy_labels = synthetic_noisy_labels.float()

    # The last axis must match the trace length of the selected model.
    assert NUM_SAMPLES == synthetic_noisy_dataset.shape[-1], f'Expected Dataset contain {NUM_SAMPLES} samples. Got {synthetic_noisy_dataset.shape[-1]}'
    # There must be exactly one label per trace.
    assert synthetic_noisy_dataset.shape[0] == synthetic_noisy_labels.shape[0], f'Expected Dataset contain label for each trace. Got {synthetic_noisy_dataset.shape[0]} traces and {synthetic_noisy_labels.shape[0]} labels'

    print(f'The loaded dataset has {synthetic_noisy_dataset.shape[0]} traces')
    print(f'Each has {synthetic_noisy_dataset.shape[1]} channels of {synthetic_noisy_dataset.shape[2]} samples.')
    print(f'Each entry is of type {synthetic_noisy_dataset.dtype}')

    print(f'The loaded labels have {synthetic_noisy_labels.shape[0]} labels.')
    print(f'Each entry is of type {synthetic_noisy_labels.dtype}')

    noisy_traces_list.append(synthetic_noisy_dataset)
    noisy_traces_labels_list.append(synthetic_noisy_labels)

In [None]:
# Filter the noisy traces/labels and keep the indices that were retained.
# NOTE: this cell overwrites its own inputs, so re-running it without re-running
# the loading cell above feeds already-filtered lists back in.
noisy_traces_list, noisy_traces_labels_list, total_indicies_to_use = remove_traces_not_to_use(
    noisy_traces_list=noisy_traces_list,
    noisy_labels_list=noisy_traces_labels_list,
    noisy_data_path_list=NOISY_DATA_PATH_LIST,
    num_of_original_traces=NUM_OF_ORIGINAL_TRACES,
)
# Display the (traces, labels) shapes per SNR level and the number of retained traces.
(
    [(traces.shape, labels.shape) for traces, labels in zip(noisy_traces_list, noisy_traces_labels_list)],
    f'{len(total_indicies_to_use)} traces each snr level',
)

### Load the original High SNR Traces

Load a dataset of high SNR traces taken from the original ETHZ/GEOFON dataset.

In [None]:
dataset_traces_path = os.path.join(DATASET_PATH, 'original_dataset.pt')
dataset_labels_path = os.path.join(DATASET_PATH, 'original_labels.pt')

# Select the same indices that were kept for the noisy datasets, and cast to float32
# exactly as the noisy traces and labels are cast when they are loaded, so the model
# sees one dtype for both kinds of data.
# NOTE(review): torch.load unpickles the file; only load datasets from a trusted source.
original_dataset = torch.load(dataset_traces_path)[total_indicies_to_use].float()
original_labels = torch.load(dataset_labels_path)[total_indicies_to_use].float()

num_original_traces = original_dataset.shape[0]
num_original_labels = original_labels.shape[0]
num_original_samples = original_dataset.shape[-1]

# One label per trace, and the trace length the selected model expects.
assert num_original_labels == num_original_traces, f'Expected traces equal num labels.Got {num_original_traces} traces and {num_original_labels} labels'
assert num_original_samples == NUM_SAMPLES, f'Expected {NUM_SAMPLES} in each trace. Got {num_original_samples}.'

print(f'Loaded {num_original_traces} traces and corresponding labels.')

### Load a Pretrained Model

In [None]:
# Load the model pretrained on the chosen dataset origin and switch it to inference mode.
pretrained_model = load_pretrained_model(
    dataset_trained_on=dataset_origin,
    model_class=SBM_CLASS,
)
pretrained_model.eval()

## Evaluate the Pretrained Model on the Loaded Datasets

### Root-Mean-Squared-Errors (RMSE)

The RMSE is the most common metric. Note that it is highly sensitive to outliers.

In [None]:
# Error metric for every SNR level (in SNR order), followed by the original traces.
# NOTE(review): the surrounding text calls this RMSE while the helper is named eval_mse —
# confirm which of the two it returns.
noisy_datasets_metric_result = [
    eval_mse(model=pretrained_model, traces=traces, labels=labels, batch_size=32)
    for traces, labels in zip(noisy_traces_list, noisy_traces_labels_list)
]
original_dataset_metric_result = eval_mse(
    model=pretrained_model,
    traces=original_dataset,
    labels=original_labels,
    batch_size=32,
)
noisy_datasets_metric_result, original_dataset_metric_result

In [None]:
# Same metric as above, passing ignore_events_above_samples_threshold=500 to eval_mse.
noisy_datasets_metric_result_trimmed = [
    eval_mse(
        model=pretrained_model,
        traces=traces,
        labels=labels,
        ignore_events_above_samples_threshold=500,
        batch_size=32,
    )
    for traces, labels in zip(noisy_traces_list, noisy_traces_labels_list)
]
original_dataset_metric_result_trimmed = eval_mse(
    model=pretrained_model,
    traces=original_dataset,
    labels=original_labels,
    ignore_events_above_samples_threshold=500,
    batch_size=32,
)
# Display the trimmed results computed in this cell (previously the untrimmed ones were shown).
noisy_datasets_metric_result_trimmed, original_dataset_metric_result_trimmed

In [None]:
# RMSE per SNR level, untrimmed (left) and trimmed (right); the dashed line is the
# value obtained on the original traces.
fig, (ax, ax_trimmed) = plt.subplots(1, 2, figsize=(10, 8), sharey='all')
plt.suptitle(f'RMSE(samples) vs. SNR(dB)  (dashed line - original data RMSE) SNR Estimated using {SNR_CALC_STRATEGY_STR}')

# The x values and the extent of the dashed line come from SYNTHESIZED_SNR_LIST,
# so the plot stays correct when the list of SNR levels changes.
ax.set_title('All residuals included')
ax.plot(SYNTHESIZED_SNR_LIST, noisy_datasets_metric_result);
ax.hlines(y=original_dataset_metric_result, xmin=min(SYNTHESIZED_SNR_LIST), xmax=max(SYNTHESIZED_SNR_LIST), linestyles='dashed');
ax.set_xlabel('SNR (dB)')
ax.set_ylabel('RMSE (samples)')

ax_trimmed.set_title('Residuals more than 500 samples omitted')
ax_trimmed.plot(SYNTHESIZED_SNR_LIST, noisy_datasets_metric_result_trimmed);
ax_trimmed.hlines(y=original_dataset_metric_result_trimmed, xmin=min(SYNTHESIZED_SNR_LIST), xmax=max(SYNTHESIZED_SNR_LIST), linestyles='dashed');
ax_trimmed.set_xlabel('SNR (dB)');

### Large Errors

Evaluate both the original traces and the synthetic noisy traces, and save the large-error traces: the traces where the model's picking error (residual) is larger than a predefined threshold (1 second by default, i.e. 100 samples).

In [None]:
# Large-error traces (residual above LARGE_ERROR_THRESHOLD_SAMPLES) in the original dataset.
large_error_traces_index_list_original_dataset = find_large_error_traces(
    dataset=original_dataset,
    model=pretrained_model.float(),
    labels=original_labels,
    threshold_samples=LARGE_ERROR_THRESHOLD_SAMPLES,
)

# The same search for every SNR level, in SNR order.
large_error_traces_index_list_synthetic_noisy_datasets = [
    find_large_error_traces(
        dataset=traces,
        model=pretrained_model.float(),
        labels=labels,
        threshold_samples=LARGE_ERROR_THRESHOLD_SAMPLES,
    )
    for traces, labels in zip(noisy_traces_list, noisy_traces_labels_list)
]

In [None]:
# Number of large-error traces per SNR level.
dataset_lens = [len(index_list) for index_list in large_error_traces_index_list_synthetic_noisy_datasets]
print(f'There are {dataset_lens} large errors in the noisy datasets')

# Number of large-error traces in the original dataset.
dataset_len_original = len(large_error_traces_index_list_original_dataset)
print(f'There are {dataset_len_original} large errors in the original dataset')

In [None]:
# Same search as above, additionally passing ignore_errors_larger_than_samples=500.
large_error_traces_index_list_original_dataset_trimmed = find_large_error_traces(
    dataset=original_dataset,
    model=pretrained_model.float(),
    labels=original_labels,
    threshold_samples=LARGE_ERROR_THRESHOLD_SAMPLES,
    ignore_errors_larger_than_samples=500,
)

large_error_traces_index_list_synthetic_noisy_datasets_trimmed = [
    find_large_error_traces(
        dataset=traces,
        model=pretrained_model.float(),
        labels=labels,
        threshold_samples=LARGE_ERROR_THRESHOLD_SAMPLES,
        ignore_errors_larger_than_samples=500,
    )
    for traces, labels in zip(noisy_traces_list, noisy_traces_labels_list)
]

In [None]:
# Number of large-error traces per SNR level for the trimmed search.
dataset_trimmed_lens = [len(index_list) for index_list in large_error_traces_index_list_synthetic_noisy_datasets_trimmed]
print(f'There are {dataset_trimmed_lens} large errors in the noisy datasets')

# Number of large-error traces in the original dataset for the trimmed search.
dataset_trimmed_len_original = len(large_error_traces_index_list_original_dataset_trimmed)
print(f'There are {dataset_trimmed_len_original} large errors in the original dataset')

In [None]:
# Large-error count per SNR level, untrimmed (left) and trimmed (right); the dashed
# line is the count obtained on the original traces.
fig, (ax, ax_trimmed) = plt.subplots(1, 2, figsize=(20, 8), sharey='all')
plt.suptitle(f'Large Error Count vs. SNR (dashed line - original data large error count) SNR Estimated using {SNR_CALC_STRATEGY_STR}')

# The x values and the extent of the dashed line come from SYNTHESIZED_SNR_LIST,
# so the plot stays correct when the list of SNR levels changes.
ax.set_title('All Errors Included')
ax.plot(SYNTHESIZED_SNR_LIST, dataset_lens);
ax.hlines(y=dataset_len_original, xmin=min(SYNTHESIZED_SNR_LIST), xmax=max(SYNTHESIZED_SNR_LIST), linestyles='dashed');
ax.set_xlabel('SNR (dB)')
ax.set_ylabel('Large error count')

ax_trimmed.set_title('Errors Larger Than 500 samples Omitted')
ax_trimmed.plot(SYNTHESIZED_SNR_LIST, dataset_trimmed_lens);
ax_trimmed.hlines(y=dataset_trimmed_len_original, xmin=min(SYNTHESIZED_SNR_LIST), xmax=max(SYNTHESIZED_SNR_LIST), linestyles='dashed');
ax_trimmed.set_xlabel('SNR (dB)');

## Maximum Value of Prediction Function vs. Residual for each SNR

Aggregate the maximum values of all the prediction functions.

In [None]:
# Maximum values and residuals of the predictions on the noisy traces, one entry per SNR level.
max_values_list = []
residuals_list = []
for noised_traces, noised_labels in zip(noisy_traces_list, noisy_traces_labels_list):
    max_values, residuals = batch_max_values_and_residuals(
        batch=noised_traces,
        labels=noised_labels,
        model=pretrained_model,
    )
    max_values_list.append(max_values)
    residuals_list.append(residuals)

# The same quantities for the original traces.
orig_max_values, orig_residuals = batch_max_values_and_residuals(
    batch=original_dataset,
    labels=original_labels,
    model=pretrained_model,
)

### Plot a scatter of All Max Values

### Compute The Mean and Standard Deviation of The Prediction Maximum Values

In [None]:
# Mean and standard deviation of the prediction maxima and of the residuals on the original traces.
orig_max_values_mean, orig_max_values_std = orig_max_values.mean(), orig_max_values.std()
orig_residuals_mean, orig_residuals_std = orig_residuals.mean(), orig_residuals.std()

orig_max_values.shape, orig_max_values_mean

In [None]:
# Per-SNR statistics, with the statistics of the original traces appended as the last entry.
max_values_mean_list = [values.mean() for values in max_values_list]
max_values_mean_list.append(orig_max_values_mean)

max_values_stds_list = [values.std() for values in max_values_list]
max_values_stds_list.append(orig_max_values_std)

residuals_means_list = [res.mean() for res in residuals_list]
residuals_means_list.append(orig_residuals_mean)

residuals_stds_list = [res.std() for res in residuals_list]
residuals_stds_list.append(orig_residuals_std)

In [None]:
# Four panels summarising the prediction maxima and the residuals per SNR level;
# the last x position ('orig') holds the statistics of the original traces.
fig, axs = plt.subplots(1, 4, figsize=(20, 6))
plt.suptitle(f'Model Evaluation {str(SBM_CLASS)} on {str.upper(dataset_origin)}')

# Build the x labels from SYNTHESIZED_SNR_LIST as strings, so the axis is explicitly
# categorical instead of a mixed int/str list left to implicit coercion.
x_value_list = [str(snr) for snr in SYNTHESIZED_SNR_LIST] + ['orig']

axs[0].set_title('Prediction Function Max Value mean')
axs[0].plot(x_value_list, max_values_mean_list)
axs[1].set_title('Prediction Function Max Value std')
axs[1].plot(x_value_list, max_values_stds_list)
axs[2].set_title('Residual mean')
axs[2].plot(x_value_list, residuals_means_list)
axs[3].set_title('Residuals std')
axs[3].plot(x_value_list, residuals_stds_list)

fig.subplots_adjust(wspace=0.2)

## Plot Noised Example

Change the `idx` variable to plot a different example

In [None]:
# Index of the example to plot (the same index is used in every SNR level).
idx = 7

noised_traces = [traces[idx] for traces in noisy_traces_list]
# Take the labels from the labels list (previously the traces list was indexed by mistake).
noised_labels = [labels[idx] for labels in noisy_traces_labels_list]

# One panel for the original trace plus one per SNR level.
fig, axs = plt.subplots(len(noised_traces) + 1, 1, figsize=(10, 30), sharey='all')
plt.suptitle(f'Synthetic Noised Versions Of the Same Original Trace - SNR Estimated using {SNR_CALC_STRATEGY_STR}')

# Only the first channel of each trace is plotted; the dashed line marks the original label.
axs[0].set_title('Original Trace')
axs[0].plot(original_dataset[idx, 0])
axs[0].vlines(x=original_labels[idx], ymin=-1, ymax=2, linestyles='dashed')

# Highest SNR first, so the panels go from the cleanest to the noisiest version.
for axis, snr, trace in zip(axs[1:], reversed(SYNTHESIZED_SNR_LIST), reversed(noised_traces)):
    axis.set_title(f'SNR {snr}')
    axis.plot(trace[0])
    axis.vlines(x=original_labels[idx], ymin=-1, ymax=2, linestyles='dashed')
fig.subplots_adjust(hspace=0.5)

In [None]:
# Run the model on the original trace and on each noised version of it (uses `idx`
# and `noised_traces` from the previous cell) and plot the first output channel.
# NOTE(review): assumes the model returns a single tensor and that channel 0 is the
# curve of interest — confirm for the selected SBM_CLASS.
with torch.no_grad():
    # One panel for the original trace plus one per SNR level.
    fig, axs = plt.subplots(len(noised_traces) + 1, 1, figsize=(10, 30), sharey='all')
    plt.suptitle('Model Prediction Probability')
    label = original_labels[idx]

    # Cast to float32 to match the noisy traces, which are cast when they are loaded.
    pred_prob = pretrained_model(original_dataset[idx].float().unsqueeze(dim=0)).squeeze()
    axs[0].set_title('Original Trace')
    axs[0].plot(pred_prob[0])
    axs[0].vlines(x=label, ymin=-1, ymax=2, linestyles='dashed')

    # Highest SNR first, so the panels go from the cleanest to the noisiest version.
    for axis, snr, trace in zip(axs[1:], reversed(SYNTHESIZED_SNR_LIST), reversed(noised_traces)):
        pred_prob = pretrained_model(trace.unsqueeze(dim=0)).squeeze()
        axis.set_title(f'SNR {snr}')
        axis.plot(pred_prob[0])
        axis.vlines(x=label, ymin=-1, ymax=2, linestyles='dashed')

fig.subplots_adjust(hspace=0.5)