# Ray Tune Experiment Analysis 2023

In [None]:
pwd

In [None]:
from fanova import fANOVA

## Imports

In [None]:
import sys
import os
from pathlib import Path

In [None]:
from ray.tune import Analysis, ExperimentAnalysis
import matplotlib.pyplot as plt
from matplotlib import rc_file
import pandas as pd
from matplotlib import rc_file
import pandas as pd
import seaborn as sns

In [None]:
from utils import count_skipped_configurations

In [None]:
##### FZJ #####
# ray_results_folder = "/p/project/raise-ctp2/cern/ray_results/"  # Main folder containing all ray experiments

# exp_dir = ray_results_folder + "clic_gnn_scan"  # chosen experiment
# exp_dir = ray_results_folder + "clic_transformer_scan"  # chosen experiment
# exp_dir = ray_results_folder + "clic_transformer_search_asha_n500"  # chosen experiment
# exp_dir = ray_results_folder + "clic_transformer_search_asha_n500_v2"  # chosen experiment
# exp_dir = ray_results_folder + "clic_transformer_search_asha_hyperopt_n500"  # chosen experiment


##### Flatiron #####
# model_name selects the experiment directory below and is reused in the
# names of the CSV files this notebook writes.
# model_name = "gnn"
model_name = "transformer"
ray_results_folder = "/mnt/ceph/users/ewulff/ray_results/"  # Main folder containing all ray experiments
exp_dir = ray_results_folder + f"clic_{model_name}_search_asha_hyperopt_n500"

# exp_dir = ray_results_folder + "clic_gnn_search_asha_hyperopt_n500"
# exp_dir = ray_results_folder + "clic_transformer_search_asha_hyperopt_n500"
####################
# NOTE(review): count_skipped_configurations is a helper from the local utils
# module; presumably it reports how many sampled configurations were skipped
# in this experiment - confirm in utils.
count_skipped_configurations(exp_dir)

In [None]:
ll "{ray_results_folder}"

In [None]:
# Load the experiment stored in exp_dir; "val_loss" with mode "min" becomes the
# default metric/mode used by the best-trial queries further down.
expanalysis = ExperimentAnalysis(exp_dir, default_metric="val_loss", default_mode="min")

In [None]:
# NOTE(review): presumably makes Ray Tune join nested result keys with "/"
# (e.g. "config/<hyperparameter>"), which the "config/" column filters in this
# notebook rely on - confirm against the Ray Tune documentation.
os.environ["TUNE_RESULT_DELIM"] = '/'
# result_df = expanalysis.dataframe()
# Drop rows in which every column is NaN.
result_df = expanalysis.results_df.dropna(axis=0, how="all")

In [None]:
csv_folder = "hpo_results_csv"
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing the first table.
os.makedirs(csv_folder, exist_ok=True)
result_df.to_csv(f"{csv_folder}/unsorted_hpo_result_table_{model_name}.csv")

In [None]:
# List every column of the results table, one name per line.
for column_name in result_df.columns:
    print(column_name)

In [None]:
def get_hp_df(result_df):
    """Return only the hyperparameter columns of a results table.

    Args:
        result_df: DataFrame whose hyperparameter columns are named
            ``config/<name>``.

    Returns:
        A DataFrame restricted to the columns whose name starts with
        ``config/`` (original column order is kept).
    """
    # DataFrame.filter(regex=...) applies re.search, so the former glob-style
    # pattern "config/*" ("config" followed by zero or more slashes) also
    # selected any column that merely contained "config". Anchor on the prefix.
    return result_df.filter(regex=r"^config/")

In [None]:
# Hyperparameter-only view of the results table; reused later in this notebook
# as the fANOVA input X.
hp_df = get_hp_df(result_df)
hp_df

In [None]:
def get_top_k_df(analysis, k):
    """Return the ``k`` best trials of an experiment, best first.

    Args:
        analysis: Object exposing ``dataframe()``, ``results_df``,
            ``default_metric`` and ``default_mode`` (here: the notebook's
            ``ExperimentAnalysis``).
        k: Number of trials to keep.

    Returns:
        A DataFrame with at most ``k`` rows, ordered by ``default_metric``
        (ascending for mode 'min', descending for mode 'max').

    Raises:
        ValueError: If ``analysis.default_mode`` is neither 'min' nor 'max'.
    """
    mode = analysis.default_mode
    if mode not in ('min', 'max'):
        # Previously this fell through both branches and failed with an
        # UnboundLocalError on the return statement.
        raise ValueError(
            "analysis.default_mode must be 'min' or 'max', got {!r}".format(mode)
        )

    try:
        result_df = analysis.dataframe()
    except IndexError:
        # NOTE(review): dataframe() apparently raises IndexError for some
        # experiments; fall back to results_df without the all-NaN rows.
        result_df = analysis.results_df.dropna(axis=0, how="all")

    if mode == 'min':
        return result_df.nsmallest(k, analysis.default_metric)
    return result_df.nlargest(k, analysis.default_metric)

In [None]:
# Subplot titles, keyed by the metric column they describe.
vars2titles = dict(
    val_loss='Validation loss (a.u.)',
    val_cls_loss='Validation classification loss (a.u.)',
    val_reg_loss='Validation regression loss (a.u.)',
    val_cls_acc_weighted='Validation classification accuracy',
    val_jet_wd='Jet Wasserstein distance (a.u.)',
    val_met_wd='MET Wasserstein distance (a.u.)',
)


def trial_id2logdir(trial_id, trial_dfs, verbose=True):
    """Find the log directory whose trial history belongs to ``trial_id``.

    Args:
        trial_id: Identifier to look for.
        trial_dfs: Mapping from log directory to that trial's history table.
        verbose: Print a warning for every table without a "trial_id" entry.

    Returns:
        The first matching log directory, or ``None`` if there is no match.
    """
    for logdir, history in trial_dfs.items():
        if "trial_id" not in history.keys():
            if verbose:
                print(f"WARNING: no trial id in {logdir}")
            continue
        # Only the entry stored under label 0 of the history is compared.
        if history["trial_id"][0] == trial_id:
            return logdir
    return None


def run_label(x=0.67, y=0.90, fz=12):
    """Write the dataset description onto the current figure.

    Args:
        x: Horizontal position handed to ``plt.figtext``.
        y: Vertical position handed to ``plt.figtext``.
        fz: Font size of the label.
    """
    # Alternative label used for the Run 3 dataset, kept for reference:
    # r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, QCD with PU50; $\mu, \pi, \pi_0, \tau, \gamma$, single particle guns'
    dataset_text = r'CLIC cluster-based dataset v1.3.0, $\mathrm{t}\overline{\mathrm{t}}$, qq'
    plt.figtext(
        x,
        y,
        dataset_text,
        wrap=False,
        horizontalalignment='left',
        fontsize=fz,
    )


def topk_summary_plot_v2(analysis, k, save=False, save_dir=None, skip=0, last=None, ylim=None, supress_labels=False, figsize=(12,11)):
    """Plot the validation curves of the ``k`` best trials, one row per metric.

    Args:
        analysis: Experiment analysis object; passed to ``get_top_k_df`` and
            queried for ``trial_dataframes``.
        k: Number of top trials to draw.
        save: If true, write the figure to the working directory instead of
            showing it.
        save_dir: Directory for the PDF. Setting it implies saving; the
            directory is created when missing.
        skip: First history index to plot (histories are sliced ``[skip:last]``).
        last: End index of the slice, ``None`` for "until the end".
        ylim: Optional sequence with one entry per subplot row, each passed to
            ``Axes.set_ylim``.
        supress_labels: If true, curves get no legend label.
        figsize: Figure size passed to ``plt.subplots``.
    """
    to_plot = [
        'val_loss', 'val_cls_loss', 'val_reg_loss', 'val_jet_wd', 'val_met_wd',
    ]
    # val_reg_loss is not read from the history; it is the sum of these terms.
    reg_loss_terms = ["energy", "pt", "eta", "sin_phi", "cos_phi", "charge"]

    dd = get_top_k_df(analysis, k)
    dfs = analysis.trial_dataframes

    # Resolve each top trial to its history once, instead of once per subplot.
    if "logdir" in dd.keys():
        logdirs = list(dd["logdir"])
    else:
        # Without a logdir column the table index holds the trial ids.
        logdirs = [trial_id2logdir(trial_id, dfs, verbose=False) for trial_id in dd.index]

    ranked_histories = []
    for rank, logdir in enumerate(logdirs, start=1):
        history = dfs.get(logdir)
        if history is None:
            # Formerly a KeyError (e.g. dfs[None]) aborted the whole plot.
            print(f"WARNING: no training history found for top trial #{rank}, skipping it")
            continue
        ranked_histories.append((rank, history))

    fig, axs = plt.subplots(len(to_plot), 1, figsize=figsize, tight_layout=False, sharex=True)
    plt.tight_layout(rect=[0.05, 0.02, 0.9, 1.0])
    for irow, (var, ax_row) in enumerate(zip(to_plot, axs)):
        for rank, history in ranked_histories:
            if var == 'val_reg_loss':
                values = sum([history["val_{}_loss".format(term)].values for term in reg_loss_terms])
                values = values[skip:last]
            else:
                values = history[var][skip:last]

            iterations = history.index.values[skip:last]

            # Only the first row carries legend labels, so that the shared
            # figure legend lists every trial exactly once.
            plot_kwargs = {"alpha": 0.5}
            if (irow == 0) and (not supress_labels):
                plot_kwargs["label"] = "#{}".format(rank)
            ax_row.plot(iterations, values, **plot_kwargs)

        # Per-axes decoration does not depend on the trial: apply it once.
        ax_row.set_title(vars2titles[var])
        ax_row.grid(alpha=0.3)
        if ylim:
            ax_row.set_ylim(ylim[irow])

    axs[-1].set_xlabel("Epoch")
    fig.legend(loc="center right", bbox_to_anchor=(1, 0.5), )
    plt.figtext(0.89, 0.61, "Top trials", fontsize=18)
    fig.patch.set_facecolor('white')
    plt.subplots_adjust(left=None, bottom=None, right=0.8, top=0.9, wspace=None, hspace=None)
    run_label(x=0.38, y=0.95, fz=18)

    if save or save_dir:
        if save_dir:
            out_dir = Path(save_dir)
            # savefig does not create missing directories.
            out_dir.mkdir(parents=True, exist_ok=True)
            plt.savefig(str(out_dir / "topk_summary_plot_v2.pdf"))
        else:
            plt.savefig("topk_summary_plot_v2.pdf")
    else:
        plt.show()

In [None]:
# Ten best trials according to the default metric and mode of expanalysis.
tops = get_top_k_df(expanalysis, 10)

In [None]:
# Load matplotlib rc settings from a local file (path is relative to the
# working directory).
rc_file("my_matplotlib_rcparams.txt")

In [None]:
# Plot the six best trials starting at history index 40; passing save_dir
# makes the function write the PDF into plots/ instead of showing the figure.
topk_summary_plot_v2(expanalysis, 6, skip=40, figsize=(12, 12), save_dir="plots/")

In [None]:
from utils import get_best_checkpoint, get_best_epoch, get_best_loss

best_config = expanalysis.get_best_config()
best_logdir = expanalysis.get_best_logdir()

print("best logdir:", best_logdir)
print("best config", best_config)

# Evaluate each helper once and reuse the value for the type report; the
# original called every helper twice per line.
for label, getter in (
    ("best checkpoint:", get_best_checkpoint),
    ("best epoch:", get_best_epoch),
    ("best loss:", get_best_loss),
):
    value = getter(best_logdir)
    print(label, value, ", type:", type(value))

In [None]:
# Smallest "val_loss" value in the dataframe expanalysis exposes for its best trial.
expanalysis.best_dataframe["val_loss"].min()

In [None]:
from utils import showJetMet
# NOTE(review): showJetMet is a local utils helper; presumably it draws the
# jet/MET plots of the best trial and stores them in plots/ - confirm in utils.
showJetMet(best_logdir, save_dir="plots")

In [None]:
def strip_config_str(key):
    """Return the part of ``key`` that follows its last "config/" marker.

    Keys without the marker are returned unchanged.
    """
    _, _, tail = key.rpartition("config/")
    return tail


def style_df(df, metric=None):
    """Return a pandas Styler highlighting the loss-like columns of ``df``.

    Args:
        df: Summary table; it must contain every column listed in
            ``min_is_better`` below.
        metric: Metric name shown in the caption. Defaults to the default
            metric of the notebook-level ``expanalysis`` object, which is what
            the caption always used before this parameter existed.

    Returns:
        The Styler with a red gradient and highlighted minimum on the
        lower-is-better columns, a caption, and the index hidden.
    """
    if metric is None:
        metric = expanalysis.default_metric

    cm_red = sns.light_palette("red", as_cmap=True)

    # Only lower-is-better columns are styled. Higher-is-better columns such as
    # 'val_cls_acc_weighted' would need a green gradient plus highlight_max.
    min_is_better = ['loss', 'cls_loss', 'val_loss', 'val_cls_loss', 'val_reg_loss', 'val_jet_wd', 'val_met_wd', 'val_jet_iqr', 'val_met_iqr']

    styler = (
        df.style
        .background_gradient(cmap=cm_red, subset=min_is_better)
        .highlight_min(subset=min_is_better, props='color:black; font-weight:bold; background-color:yellow;')
        .set_caption('Top {} trials according to {}'.format(len(df), metric))
    )

    # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;
    # Styler.hide(axis="index") replaces it. Keep the old call as a fallback
    # for installations that predate Styler.hide.
    if hasattr(styler, "hide"):
        return styler.hide(axis="index")
    return styler.hide_index()


def summarize_top_k(analysis, k, save=False, save_dir=None):
    """Build a summary table of the ``k`` best trials.

    The table holds the training/validation losses, the summed validation
    regression loss, the jet/MET metrics and all hyperparameter columns (with
    their "config/" prefix removed).

    Args:
        analysis: Experiment analysis object passed to ``get_top_k_df``.
        k: Number of top trials to include.
        save: If true, write the styled table to ``summary_table.xlsx`` in the
            working directory.
        save_dir: Directory for the Excel file. Setting it implies saving; the
            directory is created when missing.

    Returns:
        Tuple ``(summary, styled_summary)``: the plain DataFrame and the
        Styler produced by ``style_df``.
    """
    dd = get_top_k_df(analysis, k)

    # The total validation regression loss is the sum of the per-target losses.
    reg_loss_terms = ["energy", "pt", "eta", "sin_phi", "cos_phi", "charge"]
    val_reg_loss = sum([dd["val_{}_loss".format(term)].values for term in reg_loss_terms])

    summary = pd.concat(
        [
            dd[["loss", "cls_loss", "val_loss", "val_cls_loss"]],
            pd.DataFrame({"val_reg_loss": val_reg_loss}, index=dd.index),
            dd[['val_jet_wd', 'val_met_wd', 'val_jet_iqr', 'val_met_iqr']],
            # Anchored: the former pattern "config/*" matched every column
            # that merely contained "config".
            dd.filter(regex=r"^config/"),
        ],
        axis=1,
    )
    summary.columns = [strip_config_str(col) for col in summary.columns]

    styled_summary = style_df(summary)

    if save or save_dir:
        out_dir = Path(save_dir) if save_dir else Path(".")
        # to_excel does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)
        # Same engine on both paths; the original only named it for save_dir.
        styled_summary.to_excel(str(out_dir / "summary_table.xlsx"), engine='openpyxl')
    return summary, styled_summary

In [None]:
# Number of best trials kept in the summary table (also used in the CSV name below).
top_k = 50
summ, styled = summarize_top_k(expanalysis, top_k)

In [None]:
# Write the plain (unstyled) summary table; the index column is labelled "trial_number".
summ.to_csv(f"{csv_folder}/top{top_k}_sorted_hpo_result_table_{model_name}.csv", index_label="trial_number")

In [None]:
styled

In [None]:
# Keep at most the ten best rows and number them 1..N. Sizing the new index
# from the actual row count avoids the length-mismatch error that range(1, 11)
# raised whenever fewer than ten trials were available.
su = summ.iloc[:10, :].copy()
su.index = range(1, len(su) + 1)

In [None]:
def plot_metric(logdir, trial_dfs, metric, skip=0, end=None, include_val=True, logdirs=None, save=False, xlim=None, ylim=None):
    """Plot one metric of a single trial in a new figure.

    Args:
        logdir: Key of the trial in ``trial_dfs``.
        trial_dfs: Mapping from log directory to that trial's history table.
        metric: Column to plot as the "Training" curve.
        skip: First history index to plot (columns are sliced ``[skip:end]``).
        end: End index of the slice, ``None`` for "until the end".
        include_val: Also plot the ``"val_" + metric`` column as a dashed
            "Validation" curve in the same colour.
        logdirs: Unused; kept so existing callers keep working.
        save: If true, save the figure as ``figs/<metric>.pdf``.
        xlim: Passed to ``plt.xlim``.
        ylim: Passed to ``plt.ylim``.
    """
    plt.figure()

    history = trial_dfs[logdir]
    plt.plot(history[metric][skip:end], label="Training")
    if include_val:
        train_color = plt.gca().lines[-1].get_color()  # reuse the training curve's colour
        plt.plot(history["val_" + metric][skip:end], "--", color=train_color, label="Validation")
    plt.legend()
    plt.ylabel(metric)
    plt.xlabel("epoch")
    plt.grid(alpha=0.3)

    plt.xlim(xlim)
    plt.ylim(ylim)

    if save:
        # savefig does not create missing directories.
        os.makedirs("figs", exist_ok=True)
        print(f"Saving figs/{metric}.pdf")
        plt.savefig(f"figs/{metric}.pdf")


def monitor_plot(logdir, trial_dfs, skip=0, end=None, **kwargs):
    """Draw the standard set of monitoring plots for one trial.

    Metrics in the first group are plotted with ``plot_metric``'s default
    ``include_val`` setting; the ``val_*`` metrics in the second group are
    plotted on their own. Extra keyword arguments are forwarded to
    ``plot_metric``.
    """
    # NOTE(review): 'sin_phi_loss' is not in this list although 'cos_phi_loss'
    # is - confirm whether that is intentional.
    with_validation = ('loss', 'pt_loss', 'charge_loss', 'cls_loss', 'cos_phi_loss', 'energy_loss', 'eta_loss')
    validation_only = ('val_jet_iqr', 'val_jet_med', 'val_met_wd', 'val_met_iqr', 'val_met_med')

    shared = dict(skip=skip, end=end, **kwargs)

    for name in with_validation:
        plot_metric(logdir, trial_dfs, name, **shared)

    for name in validation_only:
        plot_metric(logdir, trial_dfs, name, include_val=False, **shared)


In [None]:
# Per-trial history tables; the plotting helpers above index this mapping by
# trial log directory.
trial_dfs = expanalysis.trial_dataframes

### Hyperparameter importance using fANOVA

In [None]:
# fANOVA inputs: Y is the val_loss column of the results table, X the
# hyperparameter columns of the same table (hp_df was derived from result_df).
Y = result_df["val_loss"].values
X = hp_df

In [None]:
X.shape, Y.shape

In [None]:
# NOTE(review): fANOVA presumably needs purely numeric hyperparameter columns
# and finite Y values - confirm against the fanova documentation.
f = fANOVA(X, Y)

In [None]:
# Importance of the hyperparameter in column 0 of X; the result is keyed by the
# same index tuple, (0,).
res = f.quantify_importance((0,))

In [None]:
res[(0,)]['individual importance']

In [None]:
# Collect the individual importance of every hyperparameter column of X.
# Iterating over X.columns replaces the hard-coded range(10), which raised an
# IndexError for fewer than ten hyperparameters and ignored any beyond ten.
res = {"hp": [], "importance": []}
for i, column in enumerate(X.columns):
    # Compute the importance once and reuse it for the table and the printout;
    # the original evaluated quantify_importance twice per hyperparameter.
    importance = f.quantify_importance((i,))
    # [-1] keeps the full column name when it carries no "config/" prefix,
    # where [1] raised an IndexError.
    hp_name = column.split("config/")[-1]
    res["importance"].append(importance[(i,)]['individual importance'])
    res["hp"].append(hp_name)
    print(importance, hp_name)

In [None]:
# One row per hyperparameter, with the columns "hp" and "importance".
importance_df = pd.DataFrame(res)
importance_df

In [None]:
# Store the importances next to the other result tables, without the index column.
importance_df.to_csv(f"{csv_folder}/hp_importances_{model_name}.csv", index=False)

In [None]:
# NOTE(review): presumably returns the ten hyperparameter pairs with the
# largest pairwise importance - confirm against the fanova documentation.
f.get_most_important_pairwise_marginals(n=10)

In [None]:
# Monitoring plots for the best trial, starting at history index 50; figures
# are not saved.
monitor_plot(best_logdir, trial_dfs, skip=50, save=False, ylim=None)