In [None]:
# Toggle the visibility of input areas and prompts when the page is not
# rendered inside the notebook application (nbsphinx)
import IPython.core.display as ipy_display

hide_inputs_script = '<script>jQuery(function() {if (jQuery("body.notebook_app").length == 0) { jQuery(".input_area").toggle(); jQuery(".prompt").toggle();}});</script>'
ipy_display.display_html(hide_inputs_script, raw=True)

# Particle classification (MODEL)

**Recommended datasample(s):** model file, train and test data produced with ``protopipe-MODEL``

**Data level(s):** DL1b (telescope-wise image parameters) + DL2 (shower geometry and estimated energy)

**Description:**

Test the performance of the trained model **before** using it to estimate the particle type of DL2 events.  
In a *protopipe* analysis, part of the TRAINING sample is used for *testing* the models to get some preliminary diagnostics (i.e. before launching the much heavier DL2 production).  
Note that this notebook shows camera-wise preliminary diagnostics (since one model is produced per camera): this means that the model output considered here is the _telescope-wise_ quantity and not the _event-wise_ one, which is instead benchmarked at a subsequent step.  
Settings and setup of the plots are done using the same configuration file used for training the model.

**Requirements and steps to reproduce:**

- produce the model with ``protopipe-MODEL``

- execute the notebook with ``protopipe-BENCHMARK``:

``protopipe-BENCHMARK launch --config_file configs/benchmarks.yaml -n TRAINING/benchmarks_MODELS_classification``

To obtain the list of all available parameters add ``--help-notebook``.

**Development and testing:**  

As with any other part of _protopipe_, and being part of the official repository, this notebook can be further developed by any interested contributor.   
The execution of this notebook is not currently automatic: it must be done locally by the user _before_ opening a pull request.  
Please strip the output before pushing.

## Table of contents
* [Feature importance](#Feature-importance)
* [Feature distributions](#Feature-distributions)
* [Boosted Decision Tree Error rate](#Boosted-Decision-Tree-Error-rate)
* [Model output](#Model-output)
* [Energy-dependent distributions](#Energy-dependent-distributions)
* [Energy-dependent ROC curves](#ROC-curve-variation-on-test-sample)
* [AUC as a function of reconstructed energy](#AUC-as-a-function-of-reconstructed-energy)
* [Precision-Recall](#Precision-Recall)

## Imports
[back to top](#Table-of-contents)

In [None]:
import glob
import gzip
import pickle
from pathlib import Path

import joblib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.style as style
import numpy as np
import pandas as pd
import yaml
from cycler import cycler
from matplotlib.pyplot import rc
from scipy.optimize import curve_fit
from sklearn.metrics import auc, roc_curve

plt.rcParams.update({"figure.max_open_warning": 0})

from protopipe.benchmarks.operations import get_evt_subarray_model_output
from protopipe.benchmarks.plot import (
    BoostedDecisionTreeDiagnostic,
    ClassifierDiagnostic,
    plot_distributions,
    plot_hist,
    plot_roc_curve,
)
from protopipe.benchmarks.utils import string_to_boolean
from protopipe.pipeline.io import load_config, load_obj

## Load models
[back to top](#Table-of-contents)

In [None]:
# Notebook parameters. The ``None`` placeholders have to be replaced with real
# values before the following cells run (the next cell notes that papermill
# passes values in as strings).
analyses_directory = None  # joined with analysis_name to build every input path
analysis_name = None  # sub-directory of analyses_directory holding configs/ and estimators/
analysis_name_2 = None  # not referenced by any other cell of this notebook
model_configuration_filename = None  # Name of the configuration file of the model
output_directory = Path.cwd()  # default output directory for plots
use_seaborn = False  # when truthy, the plot-aesthetics cell applies the seaborn theme

In [None]:
# Handle boolean variables (papermill reads them as strings).
# The helper is called with a one-element list: if it answers with a sequence,
# unwrap it, because a non-empty list such as ``[False]`` is always truthy and
# would make ``if use_seaborn`` succeed regardless of the requested value.
# NOTE(review): confirm the return type of ``string_to_boolean``.
_converted = string_to_boolean([use_seaborn])
if isinstance(_converted, (list, tuple)):
    _converted = _converted[0]
use_seaborn = bool(_converted)

In [None]:
# Check that the model configuration file has been defined
# either from the CLI or from the benchmarks configuration file (default)
if model_configuration_filename is None:
    try:
        model_configuration_filename = model_configuration_filenames["classification"]
    except (NameError, KeyError) as error:
        # NameError: ``model_configuration_filenames`` is not assigned anywhere
        # in this notebook, so it only exists if it was supplied from outside.
        # KeyError: the mapping exists but has no "classification" entry.
        raise ValueError(
            "The name of the configuration file is undefined."
        ) from error

In [None]:
# Every input used by this notebook lives under <analyses_directory>/<analysis_name>
analysis_root = Path(analyses_directory) / analysis_name

analysis_configuration_path = analysis_root / "configs" / "analysis.yaml"
model_configuration_path = analysis_root / "configs" / model_configuration_filename
input_directory = analysis_root / "estimators" / "gamma_hadron_classifier"

In [None]:
# Read the analysis configuration first, then the model configuration
ana_cfg, cfg = map(
    load_config, (analysis_configuration_path, model_configuration_path)
)

# Get info from configs
model_type = "classifier"
# Keep only what follows the last "." of the configured method name
method_name = cfg["Method"]["name"].rsplit(".", 1)[-1]

In [None]:
def _single_file(pattern):
    """Return the first path matching ``pattern``.

    Raises
    ------
    FileNotFoundError
        If nothing matches, so that a missing input is reported with its
        path instead of a bare ``IndexError``.
    """
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No file found matching '{pattern}'")
    return matches[0]


# Model files are named "<model_type>_<camera>_<method_name>.pkl.gz": the
# camera is the second "_"-separated token of the file name.
# ``Path(...).name`` is used instead of splitting on "/" so that the extraction
# does not depend on the path separator; ``dict.fromkeys`` drops duplicated
# camera names (several files for the same camera) while keeping their order.
cameras = list(
    dict.fromkeys(
        Path(model_file).name.split("_")[1]
        for model_file in glob.glob(f"{input_directory}/{model_type}*.pkl.gz")
    )
)
data = {
    camera: dict.fromkeys(["model", "data_scikit", "data_train", "data_test"])
    for camera in cameras
}

# NOTE(security): ``pd.read_pickle`` and ``joblib.load`` unpickle their input
# (``load_obj`` presumably does too), which can execute arbitrary code: only
# point this notebook at analysis directories you trust.
for camera in cameras:

    file_suffix = f"{model_type}_{method_name}_{camera}.pkl.gz"

    data[camera]["data_scikit"] = load_obj(
        _single_file(f"{input_directory}/data_scikit_{file_suffix}")
    )
    data[camera]["data_train"] = pd.read_pickle(
        _single_file(f"{input_directory}/data_train_{file_suffix}")
    )
    data[camera]["data_test"] = pd.read_pickle(
        _single_file(f"{input_directory}/data_test_{file_suffix}")
    )

    # Note the different token order in the model file name
    data[camera]["model"] = joblib.load(
        _single_file(f"{input_directory}/{model_type}_{camera}_{method_name}.pkl.gz")
    )

## Settings and setup
[back to top](#Table-of-contents)

In [None]:
# "gammaness" is the default name of the model output; "score" is used only
# when the configuration provides Method.use_proba with a value other than True
output_model_name = "gammaness"
try:
    if cfg["Method"]["use_proba"] is not True:
        output_model_name = "score"
except KeyError:
    pass

In [None]:
# Energy (both true and reconstructed) [TeV]
energy_settings = cfg["Diagnostic"]["energy"]
nbins = energy_settings["nbins"]

# nbins logarithmically spaced bins require nbins + 1 edges (both ends included)
energy_edges = np.logspace(
    start=np.log10(energy_settings["min"]),
    stop=np.log10(energy_settings["max"]),
    num=nbins + 1,
    endpoint=True,
)

In [None]:
# Parameters for energy variation:
# one selection string per pair of consecutive energy edges, later passed to
# DataFrame.query (edges are written with two decimals)
cut_list = [
    f"reco_energy >= {low_edge:.2f} and reco_energy <= {high_edge:.2f}"
    for low_edge, high_edge in zip(energy_edges[:-1], energy_edges[1:])
]

In [None]:
feature_lists = cfg["FeatureList"]
features_basic = feature_lists["Basic"]
features_derived = feature_lists["Derived"]
# Basic features followed by the derived ones, sorted
features = sorted(features_basic + list(features_derived))

In [None]:
# A missing "use_proba" key is treated as True, matching the "gammaness"
# default chosen for ``output_model_name`` in that same situation; indexing
# the key directly would raise KeyError here instead.
is_output_proba = cfg["Method"].get("use_proba", True)

# One diagnostic object per camera, built from that camera's model and
# train/test samples
diagnostic = {
    camera: ClassifierDiagnostic(
        model=data[camera]["model"],
        feature_name_list=features,
        target_name=cfg["Method"]["target_name"],
        data_train=data[camera]["data_train"],
        data_test=data[camera]["data_test"],
        model_output_name=output_model_name,
        is_output_proba=is_output_proba,
    )
    for camera in cameras
}

## Benchmarks
[back to top](#Table-of-contents)

In [None]:
# Make sure a "plots" folder exists inside the output directory:
# it is created (with any missing parent) unless it is already there.
plots_folder = Path(output_directory, "plots")
plots_folder.mkdir(exist_ok=True, parents=True)

In [None]:
# Plot aesthetics settings
# NOTE: ``matplotlib_settings`` and ``seaborn_settings`` are not assigned in
# this notebook; they have to be supplied from outside before this cell runs.

style.use(matplotlib_settings["style"])
cmap = matplotlib_settings["cmap"]

if matplotlib_settings["style"] == "seaborn-colorblind":

    # Custom ordering of the colors used for the axes property cycle
    colors_order = ['#0072B2', '#D55E00', '#F0E442', '#009E73', '#CC79A7', '#56B4E9']
    rc('axes', prop_cycle=cycler(color=colors_order))

if use_seaborn:
    import seaborn as sns

    theme = seaborn_settings["theme"]

    # Resolve the optional theme entries once, so that the same defaults are
    # used by all three seaborn calls below (indexing "style"/"context"
    # directly would raise KeyError although they are optional for set_theme).
    sns_context = theme.get("context", "talk")
    sns_style = theme.get("style", "whitegrid")
    sns_font_scale = theme.get("font_scale", 1.0)

    sns.set_theme(
        context=sns_context,
        style=sns_style,
        palette=theme.get("palette", None),
        font=theme.get("font", "Fira Sans"),
        font_scale=sns_font_scale,
        color_codes=theme.get("color_codes", True),
    )

    sns.set_style(sns_style, rc=seaborn_settings["rc_style"])
    sns.set_context(sns_context, font_scale=sns_font_scale)

### Feature importance
[back to top](#Table-of-contents)

In [None]:
# Styling forwarded to the feature-importance plot of every camera
importance_style = dict(alpha=0.7, edgecolor="black", linewidth=2, color="darkgreen")

for camera in cameras:
    fig, ax = plt.subplots(figsize=(5, 5))
    ax = diagnostic[camera].plot_feature_importance(ax, **importance_style)
    ax.set_ylabel("Feature importance")
    ax.grid()
    plt.title(camera)
    plt.tight_layout()

### Feature distributions
[back to top](#Table-of-contents)

**Note:** quantities like ``h_max`` and ``impact_dist`` are automatically shown as ``log10`` for these plots for better clarity.

In [None]:
for camera in cameras:

    # (color, legend label, alpha, filled, line style) of each sample, listed
    # in the same order as the entries of ``samples`` below
    sample_styles = [
        ("blue", "Gamma training sample", 0.2, True, "-"),
        ("blue", "Gamma test sample", 1, False, "--"),
        ("red", "Proton training sample", 0.2, True, "-"),
        ("red", "Proton test sample", 1, False, "--"),
    ]

    # label==1 rows (train, test) followed by label==0 rows (train, test)
    samples = [
        data[camera][split].query(selection)
        for selection in ("label==1", "label==0")
        for split in ("data_train", "data_test")
    ]

    fig, axes = diagnostic[camera].plot_features(
        camera,
        data_list=samples,
        nbin=30,
        hist_kwargs_list=[
            {
                "edgecolor": shade,
                "color": shade,
                "label": legend_label,
                "alpha": opacity,
                "fill": filled,
                "ls": line_style,
                "lw": 2,
            }
            for shade, legend_label, opacity, filled, line_style in sample_styles
        ],
        error_kw_list=[
            dict(ecolor=shade, lw=2, capsize=3, capthick=2, alpha=opacity)
            for shade, _, opacity, _, _ in sample_styles
        ],
        ncols=3,
    )
    plt.show()

### Boosted Decision Tree Error rate
[back to top](#Table-of-contents)

In [None]:
if method_name != "AdaBoostClassifier":

    print("The model is not an AdaBoostClassifier")

else:

    opt = {"color": "darkgreen", "ls": "-", "lw": 2}

    for camera in cameras:
        # Use the model and the scikit data loaded for *this* camera: the bare
        # names ``model`` and ``data_scikit`` are not defined at notebook level.
        model = data[camera]["model"]
        data_scikit = data[camera]["data_scikit"]

        plt.figure(figsize=(5, 5))
        ax = plt.gca()
        BoostedDecisionTreeDiagnostic.plot_error_rate(
            ax, model, data_scikit, **opt
        )
        plt.title(camera)
        plt.tight_layout()

        plt.figure(figsize=(5, 5))
        ax = plt.gca()
        BoostedDecisionTreeDiagnostic.plot_tree_error_rate(ax, model, **opt)
        plt.title(camera)
        plt.show()

### Model output
[back to top](#Table-of-contents)

In [None]:
for camera in cameras:

    camera_diagnostic = diagnostic[camera]
    score_column = camera_diagnostic.model_output_name
    target_column = cfg["Method"]["target_name"]

    # Distribution of the telescope-wise model output
    fig, ax = camera_diagnostic.plot_image_model_output_distribution(camera, nbin=50)
    ax[0].set_xlim([0, 1])
    ax[0].set_ylim([0, 1])
    fig.tight_layout()

    # ROC curves: for each sample, scores and labels are read from the same
    # frame held by the diagnostic object, so they cannot get misaligned.
    plt.figure(figsize=(5, 5))
    ax = plt.gca()
    for sample, curve_color, curve_label in (
        (camera_diagnostic.data_train, "darkgreen", "Training sample"),
        (camera_diagnostic.data_test, "darkorange", "Test sample"),
    ):
        plot_roc_curve(
            ax,
            sample[score_column],
            sample[target_column],
            **dict(color=curve_color, lw=2, label=curve_label)
        )
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    # Diagonal of a random classifier, for reference
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax.legend(loc="lower right")
    plt.title(camera)
    plt.show()

## Energy-dependent distributions
[back to top](#Table-of-contents)

In [None]:
# (color, legend label, alpha, filled, line style) for the four samples, in
# the order: gamma training, gamma test, proton training, proton test
_sample_styles = (
    ("blue", "Gamma training sample", 0.2, True, "-"),
    ("blue", "Gamma test sample", 1, False, "--"),
    ("red", "Proton training sample", 0.2, True, "-"),
    ("red", "Proton test sample", 1, False, "--"),
)

# Histogram styling, one dictionary per sample
hist_kwargs_list = [
    {
        "edgecolor": shade,
        "color": shade,
        "label": legend_label,
        "alpha": opacity,
        "fill": filled,
        "ls": line_style,
        "lw": 2,
    }
    for shade, legend_label, opacity, filled, line_style in _sample_styles
]

# Error-bar styling, matching the color and opacity of each histogram
error_kw_list = [
    dict(ecolor=shade, lw=2, capsize=3, capthick=2, alpha=opacity)
    for shade, _, opacity, _, _ in _sample_styles
]

n_feature = len(cut_list)
ncols = 2
# Ceiling division: rows needed to host one panel per energy bin
nrows = -(-n_feature // ncols)

for camera in cameras:

    # squeeze=False always returns a 2D array of axes, so it can be flattened
    # whatever the shape of the grid
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False
    )
    plt.subplots_adjust(hspace=0.5, wspace=0.5)
    axes = axes.flatten()

    # Same order as hist_kwargs_list / error_kw_list
    data_list = [
        data[camera]["data_train"].query("label==1"),
        data[camera]["data_test"].query("label==1"),
        data[camera]["data_train"].query("label==0"),
        data[camera]["data_test"].query("label==0"),
    ]

    for i, cut in enumerate(cut_list):
        ax = axes[i]

        # Range for binning
        the_range = [0, 1]

        for sample, hist_kwargs, error_kw in zip(
            data_list, hist_kwargs_list, error_kw_list
        ):
            # Apply the energy cut *before* testing for emptiness: a sample can
            # be populated overall and still have no entry in this energy bin
            # (presumably a normalized histogram of zero entries is ill-defined
            # in plot_hist - TODO confirm).
            selected = sample.query(cut)
            if len(selected) == 0:
                continue

            ax = plot_hist(
                ax=ax,
                data=selected[output_model_name],
                nbin=30,
                limit=the_range,
                norm=True,
                yerr=True,
                hist_kwargs=hist_kwargs,
                error_kw=error_kw,
            )

        ax.set_xlim(the_range)
        ax.set_ylim(0, 1.2)
        ax.set_xlabel(output_model_name)
        ax.set_ylabel("Arbitrary units")
        # Only draw a legend when at least one sample was plotted in this bin
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="upper center")
        ax.set_title(
            f"{energy_edges[i]:.2f} TeV < E_reco < {energy_edges[i+1]:.2f} TeV"
        )
        ax.grid()

    # Hide the panels of the grid that have no energy bin assigned
    for unused_ax in axes[len(cut_list):]:
        unused_ax.set_visible(False)

    plt.suptitle(camera)

## ROC curve variation on test sample
[back to top](#Table-of-contents)

In [None]:
for camera in cameras:

    fig, ax = plt.subplots(figsize=(6, 6))

    # Each energy bin gets its own grey level, passed to matplotlib as the
    # string form of a float: it starts below 1.0 and decreases by one step
    # per bin.
    color = 1.0
    step_color = 1.0 / (len(cut_list))

    energy_bins = zip(cut_list, energy_edges[:-1], energy_edges[1:])
    for i, (cut, low_edge, high_edge) in enumerate(energy_bins):

        test_data = data[camera]["data_test"].query(cut)
        # Nothing to draw for an energy bin without test entries
        if len(test_data) == 0:
            continue

        plot_roc_curve(
            ax,
            test_data[output_model_name],
            test_data[cfg["Method"]["target_name"]],
            color=str(color - (i + 1) * step_color),
            lw=2,
            label=f"{low_edge:.2f} TeV < E_reco < {high_edge:.2f} TeV",
        )

    # Diagonal reference line
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax.set_title(camera)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    plt.tight_layout()

## AUC as a function of reconstructed energy
[back to top](#Table-of-contents)

In [None]:
# 21 logarithmically spaced edges (20 bins) from 0.02 to 200, ends included
finer_energy_edges = np.logspace(
    start=np.log10(0.02), stop=np.log10(200), num=21, endpoint=True
)

# One selection string per pair of consecutive edges (two-decimal precision)
cut_list_with_finer_energy_edges = [
    f"reco_energy >= {low_edge:.2f} and reco_energy <= {high_edge:.2f}"
    for low_edge, high_edge in zip(finer_energy_edges[:-1], finer_energy_edges[1:])
]

for camera in cameras:
    plt.figure(figsize=(8, 8))
    plt.title(camera)

    aucs = []
    reco_energy = []

    for i, cut in enumerate(cut_list_with_finer_energy_edges):

        selected_images = data[camera]["data_test"].query(cut)
        if len(selected_images) == 0:
            continue

        true_labels = selected_images[cfg["Method"]["target_name"]]
        # A ROC curve (hence its AUC) is undefined when the energy bin
        # contains a single class only
        if true_labels.nunique() < 2:
            continue

        fpr, tpr, _ = roc_curve(
            y_score=selected_images[output_model_name],
            y_true=true_labels,
        )
        roc_auc = auc(fpr, tpr)

        aucs.append(roc_auc)
        # The bins are logarithmically spaced and the x axis is logarithmic:
        # the geometric mean is the centre of the bin on that axis
        reco_energy.append(np.sqrt(finer_energy_edges[i] * finer_energy_edges[i + 1]))

    plt.plot(reco_energy, aucs, "bo")
    # Reference line at AUC = 1 spanning the whole axis, whatever its scale
    # (x limits read before switching to log scale may be non-positive)
    plt.axhline(1, linestyle="dashed", color="green")
    plt.ylim(0, 1.2)
    plt.xscale("log")
    # Energies are plotted as they are, on a logarithmic axis
    plt.xlabel("Reconstructed energy [TeV]")
    plt.ylabel("AUC")
    plt.grid()

## Precision-Recall
[back to top](#Table-of-contents)

In [None]:
from sklearn.metrics import PrecisionRecallDisplay

# Estimator method used to compute the curve: probabilities unless the
# configuration provides Method.use_proba with a value other than True
try:
    if cfg["Method"]["use_proba"] is True:
        response_method = "predict_proba"
    else:
        response_method = "decision_function"
except KeyError:
    response_method = "predict_proba"

for camera in cameras:

    plt.figure(figsize=(7, 5))
    plt.grid()
    plt.title(camera)

    for i, cut in enumerate(cut_list):

        selected_test_data = diagnostic[camera].data_test.query(cut)

        # skip the energy bin if it's not there
        if len(selected_test_data) == 0:
            continue

        # The feature matrix must be passed as ``X`` (upper case): a lower-case
        # ``x`` keyword is not the feature-matrix argument, so the call fails.
        PrecisionRecallDisplay.from_estimator(
            estimator=diagnostic[camera].model,
            X=selected_test_data[features].to_numpy(),
            y=selected_test_data[cfg["Method"]["target_name"]],
            response_method=response_method,
            ax=plt.gca(),
            name=f"{energy_edges[i]:.2f} TeV < E_reco < {energy_edges[i+1]:.2f} TeV",
            color=f"C{i}",
        )

    plt.ylim(0, 1)