### Misclassification of Particles

To analyze particle misclassification correctly, follow these steps in order:

1. **Run `compare_models.ipynb` First**:  
   Start by executing the `compare_models.ipynb` notebook. This step is necessary because it generates the data and metrics required for comparing models, which this notebook depends on.

In [None]:
# Enable IPython's autoreload extension. Mode 2 is documented by IPython as
# "reload all modules before executing code", so edits to the project sources
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Switch to the project directory by moving up one directory level.
%cd ..
# The working directory should now be ../pdi. The relative paths used later in
# this notebook ('src', 'reports/...') are resolved against it.

In [None]:
import sys
import os
# Make the `src` directory (resolved against the current working directory)
# importable, taking care not to add it to sys.path twice.
module_path = os.path.abspath("src")
already_on_path = module_path in sys.path

if not already_on_path:
    sys.path.append(module_path)

2. **Set the `model_names` Variable**:  
   After running `compare_models.ipynb`, update the `model_names` variable in this notebook. Ensure it includes the names of the models you want to analyze. These names should match the ones used in the comparison notebook.

In [None]:
import os
import pandas as pd
import numpy as np

from pdi.constants import PARTICLES_DICT, TARGET_CODES
from pdi.data.config import RUN, MODEL_NAME

# Particle names looked up in PARTICLES_DICT for every code in TARGET_CODES.
particle_names = [PARTICLES_DICT[code] for code in TARGET_CODES]
# Models to analyse. Alternatives: ["Mean", "Regression", "Ensemble", "Delete", "NSigma"]
model_names = ["Proposed"]
metrics = ["precision", "recall", "f1"]
data_types = ["all", "complete_only"]

# Load predictions into a nested mapping:
#   tc_prediction_data[model_name][data_type][particle_name]
# Each leaf is the loaded .npz archive, or an empty dict when the file is missing.
tc_prediction_data = {}
for model_name in model_names:
    per_model = tc_prediction_data[model_name] = {}
    for data_type in data_types:
        per_data_type = per_model[data_type] = {}
        for particle_name in particle_names:
            # Placeholder first, so a missing file still leaves an entry behind.
            per_data_type[particle_name] = {}
            prediction_file = f"reports/predictions/run{RUN}/{MODEL_NAME}/{data_type}/{particle_name}/{model_name}.npz"
            if not os.path.exists(prediction_file):
                continue
            per_data_type[particle_name] = np.load(prediction_file, allow_pickle=True)

3. **Set the `features_to_plot` Variable**:  
   The `features_to_plot` variable determines the features used for scatterplots. It can be set in two ways:
   - **List of Features**: Provide a list of feature names (e.g., `["fP", "fTPCSignal", "fBeta"]`). The code will automatically create all possible 2-element combinations of these features for plotting.
   - **List of Tuples**: Provide a list of tuples where each tuple specifies the x-axis and y-axis features explicitly (e.g., `[("fP", "fTPCSignal"), ("fP", "fBeta")]`).

   Adjust this variable based on the features you want to visualize.

In [None]:
import pandas as pd
import numpy as np
import os
from pdi.data.data_exploration import generate_figure_thumbnails_from_iterator, plot_feature_combinations
from IPython.display import display

# Feature pairs to plot, each given explicitly as an (x-axis, y-axis) tuple.
features_to_plot = [("fP", "fTPCSignal"), ("fP", "fTRDPattern"), ("fP", "fBeta")]
# Filled by the main loop of this cell: one DataFrame per
# "<model_name>_<particle_name>_<data_type>" key.
saved_dataframes = {}  # Dictionary to store DataFrames

def prepare_dataframe(input_data, columns, targets, selected):
    """Build the analysis DataFrame for one set of predictions.

    The feature values in ``input_data`` are labelled with ``columns``, then
    two extra columns are appended: ``selected`` and ``targets`` (in that
    order). The returned frame is sorted by ``selected`` first and ``targets``
    second; the original row labels are kept.
    """
    return (
        pd.DataFrame(input_data, columns=columns)
        .assign(selected=selected, targets=targets)
        .sort_values(by=["selected", "targets"])
    )

def generate_scatterplots(df, features, save_dir, key):
    """Render and display three groups of scatterplots for one DataFrame.

    Groups, each written to its own sub-directory of ``save_dir``:

    * ``all_observations``        - every row; mask is ``selected != targets``
    * ``target_observations``     - rows where ``targets`` holds; mask is ``~selected``
    * ``non_target_observations`` - rows where ``targets`` does not hold; mask is ``selected``

    Assumes ``df["targets"]`` and ``df["selected"]`` are boolean columns (they
    are used directly as row masks and with ``~``) - confirm against the
    prediction files.

    Args:
        df: DataFrame containing the feature columns plus ``selected`` and ``targets``.
        features: Passed through unchanged to ``plot_feature_combinations``.
        save_dir: Base output directory; sub-directories are created as needed.
        key: Label appended to every plot title.
    """
    legend = ("Correct prediction", "Incorrect prediction")

    # Build every subset/mask up front, before any directory is created.
    is_target = df["targets"]
    mismatched = df["selected"] != is_target
    target_rows = df[is_target]
    unselected_targets = ~target_rows["selected"]
    other_rows = df[~is_target]
    selected_others = other_rows["selected"]

    # (sub-directory, wording used in the title, rows to plot, condition mask)
    views = (
        ("all_observations", "all", df, mismatched),
        ("target_observations", "target", target_rows, unselected_targets),
        ("non_target_observations", "non-target", other_rows, selected_others),
    )

    for subdir, wording, rows, condition in views:
        out_dir = os.path.join(save_dir, subdir)
        os.makedirs(out_dir, exist_ok=True)

        # The doubled braces leave literal {feature1}/{feature2} placeholders in
        # the template; presumably plot_feature_combinations fills them in.
        figures = plot_feature_combinations(
            rows,
            features,
            condition,
            condition_legend=legend,
            title_template=f"Scatter of {{feature1}} vs {{feature2}} for {wording} observations in {key}",
            log_scale_y=True,
        )
        display(generate_figure_thumbnails_from_iterator(figures, out_dir))

# Main loop: for every (particle, model, data type) combination whose
# prediction file exists, build a DataFrame and store it in `saved_dataframes`.
# `model_name_by_key` remembers which model each DataFrame belongs to, so the
# plotting step below can save figures under the correct model directory
# instead of relying on the value `model_name` happens to hold after the loop.
model_name_by_key = {}
for particle_name in particle_names:
    for model_name in model_names:
        for data_type in data_types:
            prediction_file = f"reports/predictions/run{RUN}/{MODEL_NAME}/{data_type}/{particle_name}/{model_name}.npz"
            if not os.path.exists(prediction_file):
                print(f"Prediction file not found: {prediction_file}")
                continue

            # NOTE: allow_pickle=True can execute code embedded in the file;
            # only load archives produced by this project.
            # The `with` block closes the archive once the arrays are read.
            with np.load(prediction_file, allow_pickle=True) as prediction_data:
                input_data = prediction_data["input_data"]
                columns = prediction_data["columns"]
                targets = prediction_data["targets"]
                selected = prediction_data["selected"]

            # Prepare DataFrame
            df = prepare_dataframe(input_data, columns, targets, selected)
            key = f"{model_name}_{particle_name}_{data_type}"
            saved_dataframes[key] = df
            model_name_by_key[key] = model_name

# Generate scatterplots for each DataFrame, grouped by the model that produced it.
for key, df in saved_dataframes.items():
    save_dir = f"reports/figures/misclassified_particles_scatterplots/run{RUN}/{model_name_by_key[key]}/"
    save_subdir = os.path.join(save_dir, key)
    generate_scatterplots(df, features_to_plot, save_subdir, key)