In [None]:
### Init ###

# Packages
from typing import Callable, Any, Dict, List

import numpy as np
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset, DatasetDict
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

import os, json

from multiprocessing import Pool

import torch

import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Constants
# Template that turns one earthquake record into text. Each entry of
# ``earthquake_prompt_features`` is an upper-case placeholder appearing in the template.
earthquake_prompt_features = ["PLACE", "DATE", "LATITUDE", "LONGITUDE", "DEPTH"]
earthquake_prompt_template = (
    "on DATE utc, an earthquake struck on PLACE. "
    "the epicenter was located at latitude LATITUDE, longitude LONGITUDE, "
    "with a depth of DEPTH km beneath the earth's surface."
)

# Fixed values for the heatmap prompts: only the coordinates vary across the grid.
earthquake_place = "earth"
earthquake_date = "2023-01-01 00:00:00"
earthquake_depth = 50

earthquake_prompt_template_heatmap = (
    earthquake_prompt_template
    .replace("PLACE", earthquake_place)
    .replace("DATE", earthquake_date)
    .replace("DEPTH", str(earthquake_depth))
)
earthquake_prompt_features_heatmap = ["LATITUDE", "LONGITUDE"]

# Heatmap grid: one point every ``step_size`` degrees over the whole globe.
step_size = 4

min_longitude, max_longitude = -180, 180
min_latitude, max_latitude = -90, 90

# ``+ step_size`` lets ``np.arange`` reach the upper bound (exactly, when step_size divides
# the span). Rows follow latitude and columns follow longitude (``indexing="xy"``).
longitude_grid, latitude_grid = np.meshgrid(
    np.arange(min_longitude, max_longitude + step_size, step_size),
    np.arange(min_latitude, max_latitude + step_size, step_size),
    indexing = "xy",
)

# Runtime settings
batch_size = 8  # prompts per forward pass when predicting the heatmap
num_proc = os.cpu_count()  # worker processes used by ``dataset.map``
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
shuffle_seed = 42  # seed used when sampling dataset subsets

# Datasets paths (uncomment an entry to include it)
datasets_paths = [
    # "Datasets/Earthquakes-180d-filtered.csv",  # Earthquakes-180d Dataset
    "Datasets/Earthquakes-1990-2023-filtered.csv",  # Earthquakes-1990-2023 Dataset
]

# Per dataset: subset name -> (number of sampled rows, checkpoint step of the model trained on it).
# Earlier configurations:
#   Earthquakes-180d:       {"18K": (18000, 4047)}
#   Earthquakes-1990-2023:  {"1M": (), "2M": (), "3M": ()}
datasets_subsets_sizes = {
    datasets_paths[0]: {"1M": (1_000_000, 225000)},
}

heatmap_dataset_path = "Datasets/Heatmap.csv"
heatmap_dataset_prompts_path = heatmap_dataset_path.removesuffix(".csv") + "-prompts.csv"

# Models paths: base model ids passed to ``from_pretrained`` (uncomment an entry to include it)
models_paths = [
    "distilbert/distilbert-base-uncased",  # Distilbert-base-uncased: 67M params
    # "FacebookAI/roberta-base",  # Roberta-base: 110M params
    # "google-bert/bert-base-uncased",  # Bert-base-uncased: 110M params
    # "google-bert/bert-large-uncased",  # Bert-large-uncased: 340M params
]


def _trained_model_path(dataset_path: str, model_path: str, dataset_subset_size_name: str, checkpoint_step: int) -> str:
    """Return the checkpoint directory of ``model_path`` fine-tuned on a subset of ``dataset_path``.

    ``Datasets/<name>-filtered.csv`` maps to
    ``Models/<name>-prompts-tokenized/<model_path lower-cased>-<subset name>/checkpoint-<step>``.

    Built without nesting same-type quotes inside an f-string, which is a
    SyntaxError before Python 3.12.
    """
    models_dir = dataset_path.replace("Datasets/", "Models/").replace("-filtered.csv", "-prompts-tokenized")
    return f"{models_dir}/{model_path.lower()}-{dataset_subset_size_name}/checkpoint-{checkpoint_step}"


# dataset path -> model path -> subset name -> checkpoint directory.
# ``subset_size[1]`` is the checkpoint step stored in ``datasets_subsets_sizes``.
trained_models_paths = {
    dataset_path: {
        model_path: {
            dataset_subset_size_name: _trained_model_path(dataset_path, model_path, dataset_subset_size_name, dataset_subset_size[1])
            for dataset_subset_size_name, dataset_subset_size in datasets_subsets_sizes[dataset_path].items()
        }
        for model_path in models_paths
    }
    for dataset_path in datasets_paths
}

# Filled in by the heatmap prompts cells; read by ``print_model_heatmap_visualization``.
heatmap_dataset_prompts = None

In [None]:
### Methods ###

def load_dataset_from_file(dataset_path: str):
    """Load a local data file with ``datasets.load_dataset`` and return its ``"train"`` split.

    The builder name is the text after the last ``.`` of the path
    (e.g. ``"csv"`` for ``"Datasets/Heatmap.csv"``).
    """
    builder_name = dataset_path.rsplit(".", 1)[-1]
    loaded = load_dataset(builder_name, data_files = dataset_path)
    return loaded["train"]

def load_model(model_path: str):
    """Load a sequence-classification model with a single output (``num_labels = 1``) from ``model_path``."""
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels = 1,
    )
    return model

def create_dataset(dataset, dataset_path: str, create_dataset: Callable, create_dataset_params: Dict[str, Any],
                   load_dataset: bool = True, save_dataset: bool = True):
    """Build a derived dataset, optionally loading its input from disk and saving the result.

    Args:
        dataset: The input dataset or, when ``load_dataset`` is True, the path of a file
            to load it from with ``load_dataset_from_file``.
        dataset_path: Destination path of the result, also shown in progress messages.
            Must end in ``.csv`` or ``.parquet`` when ``save_dataset`` is True.
        create_dataset: Transformation applied as ``create_dataset(dataset, **create_dataset_params)``
            (this parameter shadows the function's own name inside the body).
        create_dataset_params: Keyword arguments forwarded to ``create_dataset``.
        load_dataset: Whether ``dataset`` is a path to load first (shadows the imported
            ``datasets.load_dataset`` inside the body).
        save_dataset: Whether to write the result to ``dataset_path``.

    Returns:
        The transformed dataset.

    Raises:
        ValueError: If ``save_dataset`` is True and ``dataset_path`` is neither ``.csv`` nor
            ``.parquet``. The check runs before the (possibly expensive) transformation instead
            of silently skipping the save afterwards.
    """
    if save_dataset and not dataset_path.endswith((".csv", ".parquet")):
        raise ValueError(f"Unsupported output format for {dataset_path!r}: expected a .csv or .parquet path")

    print(f"Start of creation of dataset ({dataset_path})")

    # Load dataset
    if load_dataset:
        dataset = load_dataset_from_file(dataset)

    # Create dataset
    dataset = create_dataset(dataset, **create_dataset_params)

    # Save dataset (extension already validated above)
    if save_dataset:
        if dataset_path.endswith(".csv"):
            dataset.to_csv(dataset_path)
        else:
            dataset.to_parquet(dataset_path)

    print(f"End of creation of dataset ({dataset_path})")

    return dataset

def create_prompts_dataset(dataset, prompt_template: str, prompt_features: List[str]):
    """Turn every row of ``dataset`` into a natural-language prompt.

    For each placeholder in ``prompt_features`` (e.g. ``"LATITUDE"``), the row value stored
    under the lower-cased name is stringified, lower-cased, and substituted for the first
    occurrence of the placeholder in ``prompt_template``.

    Returns a dataset holding only the new ``"prompt"`` column; all original columns are removed.
    """
    original_columns = dataset.column_names

    def to_prompt(row):
        # Look every value up first, then substitute each placeholder once.
        values = {placeholder: str(row[placeholder.lower()]).lower() for placeholder in prompt_features}
        text = prompt_template
        for placeholder, value in values.items():
            text = text.replace(placeholder, value, 1)
        return {"prompt": text}

    with_prompts = dataset.map(to_prompt, num_proc = num_proc)
    return with_prompts.remove_columns(original_columns)

def create_tokenized_dataset(dataset, tokenizer):
    """Tokenize the ``"prompt"`` column of ``dataset`` in batches.

    Uses ``padding = "max_length"`` with truncation; the tokenizer outputs are added as
    new columns and the existing columns (including ``"prompt"``) are kept.
    """
    def tokenize_batch(batch):
        return tokenizer(batch["prompt"], padding = "max_length", truncation = True)

    return dataset.map(tokenize_batch, batched = True, num_proc = num_proc)

def create_subset(dataset, subset_size: int):
    """Return a reproducible random sample of at most ``subset_size`` rows.

    The dataset is shuffled with the global ``shuffle_seed`` and the first
    ``subset_size`` rows are kept (every row when the dataset is smaller).
    """
    row_count = min(subset_size, len(dataset))
    shuffled = dataset.shuffle(seed = shuffle_seed)
    return shuffled.select(range(row_count))

def print_model_error_visualization(dataset_path: str, model_path: str, dataset_subset_size_name: str, trained_model_path: str):
    """Plot the train and test loss curves stored in a checkpoint's ``trainer_state.json``.

    Args:
        dataset_path: Path of the source dataset; its name (without the ``Datasets/``
            prefix and ``.csv``) is shown in the title.
        model_path: Model identifier shown in the title.
        dataset_subset_size_name: Subset label (e.g. ``"1M"``) shown in the title.
        trained_model_path: Checkpoint directory containing ``trainer_state.json``.

    Entries of ``log_history`` carrying ``"loss"`` form the train curve and entries carrying
    ``"eval_loss"`` form the test curve. Any other entry is ignored instead of raising
    ``KeyError``.
    """
    with open(os.path.join(trained_model_path, "trainer_state.json")) as trainer_state_file:
        log_history = json.load(trainer_state_file)["log_history"]

    train_epochs, train_losses = [], []
    test_epochs, test_losses = [], []

    for entry in log_history:
        if "loss" in entry:
            train_epochs.append(entry["epoch"])
            train_losses.append(entry["loss"])
        elif "eval_loss" in entry:
            test_epochs.append(entry["epoch"])
            test_losses.append(entry["eval_loss"])

    # ``[-1]`` also copes with paths that lack the "Datasets/" prefix (``[1]`` raised IndexError).
    dataset_name = dataset_path.split("Datasets/", maxsplit = 1)[-1].replace(".csv", "")

    fig, ax = plt.subplots()
    ax.set_title(f"### {model_path} ({dataset_name}-{dataset_subset_size_name}) ###")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("MSE error")

    ax.plot(train_epochs, train_losses, marker = "o", label = "Train")
    ax.plot(test_epochs, test_losses, marker = "o", label = "Test")

    ax.legend()
    plt.show()

def _predict_prompt_magnitudes(prompts, model, tokenizer):
    """Predict one magnitude per prompt with ``model``, in batches of ``batch_size`` prompts.

    Args:
        prompts: List of prompt strings.
        model: Sequence-classification model with a single output logit.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        1-D float ``np.ndarray`` with one prediction per prompt, in input order.
    """
    # Move the model to ``device`` once rather than once per batch.
    model = model.to(device)
    model.eval()  # make sure dropout is disabled for inference

    predictions = []
    with torch.no_grad():
        for start in range(0, len(prompts), batch_size):
            batch_prompts = prompts[start:start + batch_size]
            inputs = tokenizer(batch_prompts, return_tensors = "pt", padding = True, truncation = True).to(device)
            logits = model(**inputs).logits
            # squeeze(-1) drops only the single-logit dimension. A bare squeeze() turns a batch
            # of one into a 0-d tensor whose .tolist() is a float, not a list.
            predictions.extend(logits.squeeze(-1).tolist())

    return np.asarray(predictions, dtype = float)


def print_model_heatmap_visualization(dataset_path: str, model_path: str, dataset_subset_size, trained_model, model_tokenizer):
    """Draw the model's predicted magnitude over the lat/lon grid, overlaid with sampled observations.

    Args:
        dataset_path: Path of the source dataset; a subset is re-sampled from it for the overlay.
        model_path: Model identifier shown in the title.
        dataset_subset_size: A ``(name, (subset_size, checkpoint_step))`` item of
            ``datasets_subsets_sizes[dataset_path]``; ``subset_size`` rows are sampled.
        trained_model: Fine-tuned model with a single output logit.
        model_tokenizer: Tokenizer matching ``trained_model``.

    Requires the global ``heatmap_dataset_prompts`` (one prompt per grid point, in
    ``longitude_grid.ravel()`` order) to have been created or loaded first.

    Raises:
        RuntimeError: If ``heatmap_dataset_prompts`` has not been set yet.
    """
    if heatmap_dataset_prompts is None:
        raise RuntimeError("heatmap_dataset_prompts is not set; run the heatmap prompts cells before plotting")

    # Prompts follow ``longitude_grid.ravel()`` order, so predictions reshape back onto the grid.
    prompts = list(heatmap_dataset_prompts["prompt"])
    magnitude_predictions = _predict_prompt_magnitudes(prompts, trained_model, model_tokenizer).reshape(longitude_grid.shape)

    # Sample with the same size and ``shuffle_seed`` as configured for this subset
    # (presumably the rows the model was trained on; verify against the training code).
    dataset = create_dataset(dataset_path, dataset_path, create_subset, {"subset_size": dataset_subset_size[1][0]}, True, False)
    observed_magnitudes = np.asarray(dataset["magnitude"], dtype = float)

    # One colour scale for both layers, so the colorbar also describes the scattered points
    # (each layer used to be normalised to its own min/max).
    all_magnitudes = np.concatenate([magnitude_predictions.ravel(), observed_magnitudes])
    vmin, vmax = np.nanmin(all_magnitudes), np.nanmax(all_magnitudes)

    fig = plt.figure(figsize = (25, 12.5))
    ax = fig.add_subplot(1, 1, 1, projection = ccrs.PlateCarree())

    ax.add_feature(cfeature.COASTLINE, linewidth = 0.8)
    ax.add_feature(cfeature.BORDERS, linewidth = 0.5)
    ax.add_feature(cfeature.LAND, facecolor = "lightgray")

    # ``[-1]`` also copes with paths that lack the "Datasets/" prefix.
    dataset_name = dataset_path.split("Datasets/", maxsplit = 1)[-1].replace(".csv", "")

    ax.set_title(f"### {model_path} ({dataset_name}-{dataset_subset_size[0]}) ###", fontsize = 20)
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    img = ax.imshow(magnitude_predictions, origin = "lower", extent = [min_longitude, max_longitude, min_latitude, max_latitude],
                    transform = ccrs.PlateCarree(), cmap = "coolwarm", alpha = 0.6, vmin = vmin, vmax = vmax)

    cbar = fig.colorbar(img, ax = ax, orientation = "vertical")
    cbar.ax.tick_params(labelsize = 15, direction = "in", pad = -30)

    ax.scatter(dataset["longitude"], dataset["latitude"], c = observed_magnitudes, cmap = "coolwarm", vmin = vmin, vmax = vmax,
               transform = ccrs.PlateCarree(), edgecolors = "black")

    plt.show()

In [None]:
### Load trained models ###

# One tokenizer per base model id, in ``models_paths`` order.
models_tokenizers = dict(zip(models_paths, map(AutoTokenizer.from_pretrained, models_paths)))

# dataset path -> model path -> subset name -> model loaded from its checkpoint directory.
trained_models = {
    dataset_path: {
        model_path: {
            dataset_subset_size_name: load_model(trained_models_paths[dataset_path][model_path][dataset_subset_size_name])
            for dataset_subset_size_name in datasets_subsets_sizes[dataset_path]
        }
        for model_path in models_paths
    }
    for dataset_path in datasets_paths
}

In [None]:
### MSE error models visualization ###

# One loss-curve figure per (dataset, model, subset) combination.
for dataset_path in datasets_paths:
    for model_path in models_paths:
        for dataset_subset_size_name in datasets_subsets_sizes[dataset_path]:
            print_model_error_visualization(
                dataset_path = dataset_path,
                model_path = model_path,
                dataset_subset_size_name = dataset_subset_size_name,
                trained_model_path = trained_models_paths[dataset_path][model_path][dataset_subset_size_name],
            )

In [None]:
### Create heatmap dataset ###

# One row per grid point, flattened in C (row-major) order. The heatmap plot reshapes the
# predictions back to ``longitude_grid.shape`` and relies on this order.
heatmap_dataset = Dataset.from_dict(dict(
    longitude = np.ravel(longitude_grid),
    latitude = np.ravel(latitude_grid),
))

heatmap_dataset.to_csv(heatmap_dataset_path)

In [None]:
### Load heatmap dataset ###

# Reload the grid written by the previous cell.
heatmap_dataset = load_dataset_from_file(dataset_path = heatmap_dataset_path)

In [None]:
### Create heatmap prompts dataset ###

# Only latitude/longitude vary; place, date and depth are already fixed in the template.
heatmap_dataset_prompts_args = dict(
    prompt_template = earthquake_prompt_template_heatmap,
    prompt_features = earthquake_prompt_features_heatmap,
)

heatmap_dataset_prompts = create_dataset(
    heatmap_dataset,
    heatmap_dataset_prompts_path,
    create_prompts_dataset,
    heatmap_dataset_prompts_args,
    load_dataset = False,
    save_dataset = True,
)

In [None]:
### Load heatmap prompts dataset ###

# Reload the saved prompts; this is the global read by the heatmap visualization.
heatmap_dataset_prompts = load_dataset_from_file(dataset_path = heatmap_dataset_prompts_path)

In [None]:
### Heatmap earthquakes risk models visualization ###

# One map per (dataset, model, subset); ``dataset_subset_size`` is a (name, sizes) item.
for dataset_path in datasets_paths:
    for model_path in models_paths:
        for dataset_subset_size in datasets_subsets_sizes[dataset_path].items():
            print_model_heatmap_visualization(
                dataset_path = dataset_path,
                model_path = model_path,
                dataset_subset_size = dataset_subset_size,
                trained_model = trained_models[dataset_path][model_path][dataset_subset_size[0]],
                model_tokenizer = models_tokenizers[model_path],
            )