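# Save trained models

Download the best checkpoint of every seed-0 run from Weights & Biases into `trained_models/`, and export the train/validation/test NLLs of all runs to `results/nll_real_world.csv`.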

In [1]:
import torch
import eq
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seml
import seaborn as sns
import tempfile
import wandb

from tqdm.auto import tqdm

## Load results from the database

In [2]:
raw_results_df = seml.get_results("eq_mle", to_data_frame=True)
# Select the relevant columns and give them short names. An explicit old -> new
# mapping keeps every column paired with its new name (no positional renaming).
column_names = {
    "config.model_name": "model_name",
    "config.dataset_name": "dataset_name",
    "config.random_seed": "random_seed",
    "result.wandb_url": "wandb_url",
    "result.final_nll_train": "nll_train",
    "result.final_nll_val": "nll_val",
    "result.final_nll_test": "nll_test",
}
# Set random_seed for ETAS to zero (ETAS runs have no seed, so it comes back as NaN).
# Only this column is filled: a frame-wide fillna would also turn a missing NLL or
# W&B URL into 0.0 and silently corrupt the results saved below.
results_df = (
    raw_results_df[list(column_names)]
    .rename(columns=column_names)
    .assign(random_seed=lambda frame: frame["random_seed"].fillna(0.0))
)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

## Save trained models to disk

In [3]:
import shutil
from pathlib import Path

def download_model_to_disk(model_name, url, filename):
    """Download the best checkpoint of a W&B run to ``filename`` and load it.

    The run's ``best_model.ckpt`` file is downloaded into a temporary directory
    next to ``filename`` and then moved to ``filename``.

    Args:
        model_name: Name of a model class in ``eq.models`` (e.g. ``"ETAS"``).
        url: Run path / URL passed to ``wandb.Api().run``.
        filename: Destination path of the checkpoint (``str`` or ``Path``). Its
            parent directory is created if needed; an existing file at this path
            is replaced.

    Returns:
        The model restored via ``model_class.load_from_checkpoint(filename)``.

    Raises:
        AttributeError: if ``eq.models`` has no attribute ``model_name``.
        FileNotFoundError: if the run has no ``best_model.ckpt`` file.
    """
    filename = Path(filename)
    # Resolve the model class first so that a wrong name fails before any download.
    model_class = getattr(eq.models, model_name)
    run = wandb.Api().run(url)
    best_model = next((file for file in run.files() if file.name == "best_model.ckpt"), None)
    if best_model is None:
        raise FileNotFoundError(f"W&B run {url} has no file 'best_model.ckpt'")

    filename.parent.mkdir(parents=True, exist_ok=True)
    # Every run stores its checkpoint under the same name. Downloading into a fresh
    # private directory avoids colliding with a leftover file from an earlier
    # (e.g. interrupted) download. Placing it next to the target keeps the final
    # move on the same filesystem.
    with tempfile.TemporaryDirectory(dir=filename.parent) as tmp_dir:
        # File.download returns an open file handle; close it before moving the file.
        with best_model.download(root=tmp_dir) as downloaded:
            download_path = downloaded.name
        shutil.move(download_path, filename)
    # Load the model (this also checks that the checkpoint can be restored).
    return model_class.load_from_checkpoint(filename)

In [4]:
model_save_dir = Path(eq.__file__).parents[1] / "trained_models"
# Create the target directory up front so a fresh checkout does not fail on the first move.
model_save_dir.mkdir(parents=True, exist_ok=True)

# Save every run with random_seed == 0 to disk. This includes the ETAS runs, whose
# missing seed was set to 0 when the results were loaded.
for row in results_df.query("random_seed == 0").itertuples(index=False):
    save_path = model_save_dir / f"{row.dataset_name}_{row.model_name}.ckpt"
    print(f"Saving to {save_path}")
    # The returned model is not needed here; loading it only confirms that the
    # downloaded checkpoint can be restored.
    download_model_to_disk(model_name=row.model_name, url=row.wandb_url, filename=save_path)

Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/QTMSanJacinto_ETAS.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/QTMSaltonSea_ETAS.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/SCEDC_ETAS.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/White_ETAS.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/QTMSanJacinto_RecurrentTPP.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/QTMSaltonSea_RecurrentTPP.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/SCEDC_RecurrentTPP.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/White_RecurrentTPP.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained_models/ETAS_MultiCatalog_ETAS.ckpt
Saving to /nfs/homedirs/shchur/research/earthquake-ntpp-release/trained

## Save NLL results

In [5]:
# Order rows by dataset, then model, for the CSV export. Rebind instead of sorting
# with inplace=True, which pandas discourages and which mutates the frame in place.
results_df = results_df.sort_values(by=["dataset_name", "model_name"])

In [6]:
results_dir = Path(eq.__file__).parents[1] / "results"
# Create the output directory if it does not exist yet instead of failing in to_csv.
results_dir.mkdir(parents=True, exist_ok=True)
# NOTE: to_csv writes the DataFrame index (the row index inherited from the seml
# results) as the first, unnamed column.
results_df.to_csv(results_dir / "nll_real_world.csv")