# Evaluate *ritme* trials of all usecases


## Setup

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from src.evaluate_trials import (
    multi_boxplot_metric,
    plot_complexity_vs_metric,
    plot_trend_over_time_multi_models,
    plot_recent_param_cat_over_time,
    plot_recent_param_cont_over_time,
)

# hide FutureWarning deprecation notices so cell output stays readable
warnings.filterwarnings("ignore", category=FutureWarning)

# IPython magics: reload edited modules (e.g. src.evaluate_trials) before each
# cell executes, and render matplotlib figures inline below each cell
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
######## USER INPUTS ########

# path to the merged CSV of extracted MLflow trial logs
# (produced with the script extract_all_logs.sh)
log_folder_location = "merged_all_trials_v124.csv"

# which usecase to analyze: "u1", "u2", "u3" or "all"
usecase = "u1"

# which sampler to analyze: "tpe", "random" or "all"
sampler = "tpe"

# how many of the best-ranked trials to consider for the complexity vs.
# performance plot
top_x = 1000

# whether to save figures in this run (written to result_figures/)
save_figures = False
# resolution (dots per inch) used when saving figures
dpi = 400
######## END USER INPUTS #####

In [None]:
# set the plot title and whether the RMSE boxplots use a log-scaled x-axis
if usecase == "u1":
    title = "Usecase 1"
    log_x_scale = False
elif usecase == "u2":
    title = "Usecase 2"
    log_x_scale = False
elif usecase == "u3":
    title = "Usecase 3"
    log_x_scale = True
else:
    title = "All usecases"
    # log_x_scale must be defined on every branch: the boxplot cell passes it
    # as x_log_scale. The pooled trials include u3, which uses a log x-axis
    # above. NOTE(review): confirm a log scale is wanted for "all".
    log_x_scale = True

## Extract trial information

In [None]:
# load every extracted trial and rank them from best (lowest validation RMSE)
# to worst
all_trials = pd.read_csv(log_folder_location).sort_values(
    by="metrics.rmse_val", ascending=True
)
print(f"Found {len(all_trials)} trials")

In [None]:
# Restrict the trials to the selected usecase and sampler; both are matched
# against the experiment tag. na=False treats trials without a tag as
# non-matching instead of letting NaN break boolean indexing, and
# regex=False matches the patterns literally.
if usecase != "all":
    print(f"Analyzing trials for usecase: {usecase}")
    experiment_tags = all_trials["tags.experiment_tag"]
    if usecase == "u3":
        # only "u3_galaxy*" experiments, excluding those tagged "w_start"
        # NOTE(review): "w_start" presumably marks warm-started runs — confirm
        usecase_mask = experiment_tags.str.startswith(
            "u3_galaxy", na=False
        ) & ~experiment_tags.str.contains("w_start", regex=False, na=False)
    else:
        usecase_mask = experiment_tags.str.startswith(usecase, na=False)
    all_trials = all_trials[usecase_mask]

if sampler != "all":
    print(f"Analyzing trials for sampler: {sampler}")
    all_trials = all_trials[
        all_trials["tags.experiment_tag"].str.contains(sampler, regex=False, na=False)
    ]

print(f"Selected {all_trials.shape[0]} trials")

## Find best trial

In [None]:
# find the best trial (first row, since all_trials is sorted by ascending
# validation RMSE) and its model type
if all_trials.empty:
    # fail with a clear message instead of an IndexError on the lookup below
    raise ValueError(
        f"No trials left after filtering (usecase={usecase!r}, sampler={sampler!r})"
    )
top_1_trial = all_trials.head(1)
best_model_type = top_1_trial["params.model"].iloc[0]

print(best_model_type)
top_1_trial["tags.experiment_tag"]

In [None]:
# number of features of the best trial
top_1_trial.loc[:, "metrics.nb_features"]

## Insights on performance: ALL trials — feature engineering / model type vs. validation RMSE

In [None]:
# validation RMSE of all selected trials, grouped by each feature-engineering
# step and by model type
rmse_groupings = [
    ("params.data_aggregation", "Data aggregation"),
    ("params.data_selection", "Data selection"),
    ("params.data_transform", "Data transform"),
    ("params.data_enrich", "Data enrichment"),
    ("params.model", "Model type"),
]
fig, axes = multi_boxplot_metric(
    all_trials,
    metric_col="metrics.rmse_val",
    metric_name="RMSE Validation",
    group_specs=rmse_groupings,
    title=title,
    x_log_scale=log_x_scale,
    order_by_median=True,
    showfliers=False,
)
if save_figures:
    boxplot_path = f"result_figures/boxplot_all_trials_{usecase}_{sampler}.pdf"
    fig.savefig(boxplot_path, bbox_inches="tight", dpi=dpi)

In [None]:
# number of features per model type across all selected trials
fig, axes = multi_boxplot_metric(
    all_trials,
    metric_col="metrics.nb_features",
    metric_name="Number of features",
    group_specs=[
        ("params.model", "Model type"),
    ],
    order_by_median=True,
    showfliers=False,
    title=title,
    x_log_scale=True,
    figsize=(6, 4),
)
# lay out the figure returned above explicitly, rather than whatever pyplot
# currently considers the "current" figure
fig.tight_layout()
if save_figures:
    fig.savefig(
        f"result_figures/boxplot_all_trials_{usecase}_{sampler}_nb_fts.pdf",
        bbox_inches="tight",
        dpi=dpi,
    )

## Model complexity vs. performance: top X trials

In [None]:
# the top_x best trials (all_trials is sorted by ascending validation RMSE);
# after filtering, fewer than top_x trials may exist
top_x_trials = all_trials.head(top_x)

figc, _ = plot_complexity_vs_metric(
    top_x_trials,
    metric_col="metrics.rmse_val",
    metric_name="RMSE Validation",
    group_col="params.model",
    group_name="Model type",
    # pass the number of trials actually selected, which is smaller than top_x
    # when fewer trials survived the usecase/sampler filters.
    # NOTE(review): assumes `n` only describes the trial count — confirm
    n=len(top_x_trials),
    figsize=(7, 6),
    title=title,
    x_log_scale=True,
)

if save_figures:
    figc.savefig(
        f"result_figures/complexity_top_trials_{usecase}_{sampler}.pdf",
        bbox_inches="tight",
        dpi=dpi,
    )

## Training over time: ALL trials

In [None]:
# count of distinct (non-missing) model types; scales the trend figure height
nb_models = all_trials["params.model"].dropna().nunique()

In [None]:
# windowed trend of validation RMSE over all selected trials; the figure height
# grows with nb_models (presumably one panel per model type)
fig, axes = plot_trend_over_time_multi_models(
    all_trials,
    y_col="metrics.rmse_val",
    y_log_scale=True,
    window=100,
    std_alpha=0.3,
    first_n=None,
    title_prefix="Model: ",
    figsize=(7, 3 * nb_models),
)
if save_figures:
    trend_path = f"result_figures/trend_over_time_{usecase}_{sampler}.pdf"
    fig.savefig(trend_path, bbox_inches="tight", dpi=dpi)

## Trend over time: last N trials of the best model type

In [None]:
# Show last N trials as a barplot colored by group
# settings for the "recent trials" plots of the best model type:
#   n_last        -> how many of the last trials to show
#   window_length -> value passed as `window=` to the plotting helpers
n_last, window_length = 200, 40

In [None]:
# keep only the trials that trained the best model type
best_model_trials = all_trials.loc[all_trials["params.model"].eq(best_model_type)]
best_model_trials.shape

In [None]:
# get continuous and categorical hyperparameter columns
# split the hyperparameter ("params*") columns into continuous (numeric dtype)
# and categorical (every other dtype) ones
cont_cols = best_model_trials.select_dtypes(include=[np.number]).columns.tolist()
cont_params_cols = [x for x in cont_cols if x.startswith("params")]

# "params.model" is constant within best_model_trials and
# "params.data_enrich_with" is deliberately left out of the categorical plots.
# Filtering them out (instead of list.remove) does not raise ValueError when
# one of these columns is absent from the logs.
excluded_cat_params = {"params.model", "params.data_enrich_with"}
cat_cols = [c for c in best_model_trials.columns if c not in cont_cols]
cat_params_cols = [
    x for x in cat_cols if x.startswith("params") and x not in excluded_cat_params
]

In [None]:
# select columns that are not NaN over all trials of the best_model_types
# keep only hyperparameter columns with at least one recorded (non-NaN) value
# among the best model type's trials; columns that are NaN for every such
# trial (presumably parameters of other model types) are dropped
param_has_value = best_model_trials.notna().any()
cont_params_cols = [c for c in cont_params_cols if param_has_value[c]]
cat_params_cols = [c for c in cat_params_cols if param_has_value[c]]

In [None]:
# one figure per categorical hyperparameter: validation RMSE over the last
# n_last trials of the best model type, grouped by that parameter's value
for cat_param in cat_params_cols:
    pretty_name = cat_param.replace("params.", "").replace("_", " ")
    fig_recent, ax_recent = plot_recent_param_cat_over_time(
        best_model_trials,
        y_col="metrics.rmse_val",
        y_log_scale=False,
        group_col=cat_param,
        time_col="start_time",
        n_last=n_last,
        window=window_length,
        title=f"{best_model_type} model — {pretty_name}",
        figsize=(7, 2),
        font_scale=0.9,
    )
    if not save_figures:
        continue
    file_suffix = cat_param.replace(".", "_")
    fig_recent.savefig(
        f"result_figures/recent_{n_last}_trials_{usecase}_{sampler}_by_{file_suffix}.pdf",
        bbox_inches="tight",
        dpi=dpi,
    )

In [None]:
# logarithmic hyperparameters:
log_hyperparams = [
    "params.alpha",
    "params.min_samples_split",
    "params.min_samples_leaf",
    "params.lambda",
    "params.eta",
    "params.gamma",
    "params.reg_alpha",
    "params.reg_lambda",
    "params.weight_decay",
    "params.early_stopping_min_delta",
]

In [None]:
# plot numeric values
# one figure per continuous hyperparameter of the best model type, showing its
# values over the last n_last trials split into 4 bins
for cont_param in cont_params_cols:
    print(f"Processing parameter: {cont_param}")

    binning = "log-uniform" if cont_param in log_hyperparams else "uniform"
    pretty_name = cont_param.replace("params.", "").replace("_", " ")
    fig_param, _ = plot_recent_param_cont_over_time(
        best_model_trials,
        param_col=cont_param,
        group_col="params.data_selection",
        time_col="start_time",
        n_last=n_last,
        window=window_length,
        n_bins=4,
        binning=binning,
        title=f"{best_model_type} model — {pretty_name}",
        figsize=(7, 2),
        font_scale=0.9,
        palette_name="Spectral",
        y_log_scale=False,
    )
    if not save_figures:
        continue
    file_suffix = cont_param.replace(".", "_")
    fig_param.savefig(
        f"result_figures/param_bins_last_{n_last}_{usecase}_{sampler}_{file_suffix}.pdf",
        bbox_inches="tight",
        dpi=dpi,
    )