# Scaling laws for neural particle flow reconstruction in 2024

In [None]:
pwd

## Imports

In [None]:
import sys
from pathlib import Path

In [None]:
from ray import tune
from ray.tune import ExperimentAnalysis, ResultGrid
import matplotlib.pyplot as plt
from matplotlib import rc_file
import pandas as pd
from matplotlib import rc_file
import pandas as pd
import seaborn as sns
import numpy as np

In [None]:
from utils import count_skipped_configurations

In [None]:
# Location of the Ray Tune experiment to analyse. One cluster-specific root
# folder is active at a time; the other is kept commented out for switching.

# --- FZJ ---
# ray_results_folder = "/p/project/raise-ctp2/cern/ray_results/"

# --- Flatiron ---
ray_results_folder = "/mnt/ceph/users/ewulff/ray_results/"  # root folder holding all Ray experiments
exp_dir = f"{ray_results_folder}test_raytune_Dscan_v2/"

# Helper imported from the project's utils module; it receives the experiment
# directory. NOTE(review): presumably reports how many trial configurations
# were skipped (going by its name) - confirm against utils.
count_skipped_configurations(exp_dir)

In [None]:
# ExperimentAnalysis is how Ray used to organize results.
# "val_loss" / "min" are registered as its default metric and mode.
expanalysis = ExperimentAnalysis(exp_dir, default_metric="val_loss", default_mode="min")

# ResultGrid is the new way Ray organizes results; here it is built by wrapping
# the ExperimentAnalysis object created above.
resultgrid = ResultGrid(expanalysis)

In [None]:
# Dataframe view of the ResultGrid; the following cell adds the trial path and
# parameter-count columns to it.
rg_df = resultgrid.get_dataframe()

In [None]:
# Add the trial path and the parameter counts to the resultgrid dataframe.
#
# pandas aligns column assignments on the index, not on row position, so
# sorting the frames does not by itself guarantee that rows are paired up
# correctly. Every value is therefore looked up explicitly by trial_id.
param_columns = ["total_params", "trainable_params", "nontrainable_params"]

trial_id_list = []
path_list = []
params_frames = []

for result in resultgrid:
    trial_id = result.metrics["trial_id"]
    trial_id_list.append(trial_id)
    path_list.append(result.path)
    # Each trial directory is expected to contain a params.csv file.
    df = pd.read_csv(result.path + '/params.csv')
    df["trial_id"] = trial_id
    params_frames.append(df)

path_df = pd.DataFrame({"trial_id": trial_id_list, "path": path_list})
# Concatenate once instead of growing a dataframe inside the loop.
params_df = pd.concat(params_frames, ignore_index=True)

# Keep all three frames ordered by trial_id for easier inspection.
path_df.sort_values("trial_id", inplace=True)
rg_df.sort_values("trial_id", inplace=True)
params_df.sort_values("trial_id", inplace=True)

path_by_trial = path_df.set_index("trial_id")["path"]
# NOTE(review): assumes one row per params.csv; if a file holds several rows
# only the first one per trial is used.
params_by_trial = params_df.drop_duplicates("trial_id").set_index("trial_id")

rg_df["path"] = rg_df["trial_id"].map(path_by_trial)
for column in param_columns:
    rg_df[column] = rg_df["trial_id"].map(params_by_trial[column])

In [None]:
#TODO: implement merging of resultgrid dataframes from several Ray Tune runs.

## Scaling Law study

In [None]:
# Validation loss as a function of the number of training samples, log-log axes.
# A line plot connects points in row order, so the rows are sorted by the
# x column first; otherwise the curve can zigzag when the dataframe is ordered
# by something else (e.g. trial_id). The stable sort keeps the relative order
# of trials that share the same ntrain.
ntrain_col = "config/train_loop_config/ntrain"
axs = rg_df.sort_values(ntrain_col, kind="stable").plot(y="val_loss", x=ntrain_col)
axs.set_ylabel("Validation loss")
axs.set_xlabel("Training samples")
axs.set_xscale("log")
axs.set_yscale("log")
# axs.grid(True, axis="both")

## OLD CODE
## HPO analysis

In [None]:
# The HPO analysis below works on the ResultGrid dataframe; the commented-out
# lines are the older ExperimentAnalysis-based alternatives.
# result_df = expanalysis.results_df
result_df = rg_df
# result_df = expanalysis.results_df.dropna(axis=0, how="all")

In [None]:
def get_hp_df(result_df):
    """Return only the hyperparameter columns of a results dataframe.

    Hyperparameter columns are those whose name starts with ``config/``.
    The pattern is anchored on purpose: ``DataFrame.filter(regex=...)``
    searches anywhere in the column name, and the former glob-style pattern
    ``config/*`` matched every column that merely contained ``config``.

    Args:
        result_df: Dataframe with one row per trial.

    Returns:
        A dataframe with the same rows restricted to the ``config/`` columns.
    """
    return result_df.filter(regex=r"^config/")

In [None]:
# Hyperparameter columns of the results dataframe.
hp_df = get_hp_df(result_df)
# Bare expression so the notebook displays the dataframe.
hp_df

In [None]:
def get_top_k_df(analysis, k):
    """Return the rows of the ``k`` best trials of a Ray Tune analysis object.

    Two kinds of object are supported:

    * objects exposing ``get_dataframe()`` (``ResultGrid``): trials are ranked
      by the smallest ``val_loss``;
    * objects exposing ``dataframe()``, ``default_metric`` and ``default_mode``
      (``ExperimentAnalysis``): trials are ranked by the default metric in the
      direction given by the default mode.

    The object kind is detected by the presence of ``get_dataframe`` rather
    than by an ``isinstance`` check on a Ray submodule path.

    Args:
        analysis: ``ResultGrid``- or ``ExperimentAnalysis``-like object.
        k: Number of trials to return.

    Returns:
        Dataframe with at most ``k`` rows, best trial first.

    Raises:
        ValueError: If the default mode is neither ``'min'`` nor ``'max'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if hasattr(analysis, "get_dataframe"):
        result_df = analysis.get_dataframe()
        metric, mode = "val_loss", "min"
    else:
        try:
            result_df = analysis.dataframe()
        except IndexError:
            # Fall back to the results_df attribute and drop trials that
            # have no results at all.
            result_df = analysis.results_df.dropna(axis=0, how="all")
        metric, mode = analysis.default_metric, analysis.default_mode

    if mode == "min":
        return result_df.nsmallest(k, metric)
    if mode == "max":
        return result_df.nlargest(k, metric)
    raise ValueError(
        "cannot rank trials: default_mode must be 'min' or 'max', got {!r}".format(mode)
    )

In [None]:
# Subplot titles for the logged validation quantities, keyed by the column
# name under which each quantity appears in the per-trial dataframes.
vars2titles = dict(
    val_loss="Validation loss (a.u.)",
    val_cls_loss="Validation classification loss (a.u.)",
    val_reg_loss="Validation regression loss (a.u.)",
    val_charge_loss="Validation charge loss (a.u.)",
    val_cls_acc_weighted="Validation classification accuracy",
    val_jet_wd="Jet Wasserstein distance (a.u.)",
    val_met_wd="MET Wasserstein distance (a.u.)",
)


def trial_id2logdir(trial_id, trial_dfs, verbose=True):
    """Find the key of ``trial_dfs`` whose entry belongs to ``trial_id``.

    Args:
        trial_id: Trial identifier to look for.
        trial_dfs: Mapping from logdir to a per-trial table. An entry matches
            when the element labelled ``0`` of its ``"trial_id"`` column
            equals ``trial_id``.
        verbose: Print a warning for every entry without a ``"trial_id"``
            column.

    Returns:
        The first matching logdir, or ``None`` if no entry matches.
    """
    for logdir, frame in trial_dfs.items():
        if "trial_id" not in frame.keys():
            if verbose:
                print(f"WARNING: no trial id in {logdir}")
            continue
        if frame["trial_id"][0] == trial_id:
            return logdir
    return None


def run_label(x=0.67, y=0.90, fz=12):
    """Write the dataset label onto the current figure.

    Args:
        x: Horizontal text position passed to ``plt.figtext``.
        y: Vertical text position passed to ``plt.figtext``.
        fz: Font size of the label.
    """
    # Label used for an earlier dataset, kept for reference:
    # r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, QCD with PU50; $\mu, \pi, \pi_0, \tau, \gamma$, single particle guns'
    label = r'CLIC cluster-based dataset v1.3.0, $\mathrm{t}\overline{\mathrm{t}}$, qq'
    plt.figtext(x, y, label, wrap=False, horizontalalignment='left', fontsize=fz)


def topk_summary_plot_v2(analysis, k, save=False, save_dir=None, skip=0, last=None, ylim=None, supress_labels=False, figsize=(12,11)):
    """Plot the validation-loss curves of the top ``k`` trials, one subplot per loss.

    Args:
        analysis: Object accepted by ``get_top_k_df`` that also exposes
            ``trial_dataframes`` (mapping from trial key to a per-iteration
            dataframe).
        k: Number of top trials to draw.
        save: If true and ``save_dir`` is not given, write
            ``topk_summary_plot_v2.pdf`` to the current working directory.
        save_dir: Directory for the PDF; created if it does not exist yet.
            Giving it implies saving.
        skip: First position of every curve to draw (curves are sliced
            ``[skip:last]``).
        last: End position of the slice, ``None`` for the full curve.
        ylim: Optional sequence with one y-limit entry per subplot row.
        supress_labels: If true, curves are drawn without legend labels.
        figsize: Figure size passed to ``plt.subplots``.

    Raises:
        KeyError: If a top-k trial cannot be matched to an entry of
            ``analysis.trial_dataframes``.
    """
    to_plot = [
        'val_loss', 'val_cls_loss', 'val_reg_loss', 'val_charge_loss',  # 'val_jet_wd', 'val_met_wd',
    ]

    dd = get_top_k_df(analysis, k)
    dfs = analysis.trial_dataframes

    # Resolve once which key of ``dfs`` belongs to each top trial: either the
    # "logdir" column directly, or a lookup via the trial id stored in the index.
    if "logdir" in dd.keys():
        trial_keys = list(dd["logdir"])
    else:
        trial_keys = [trial_id2logdir(trial_id, dfs, verbose=False) for trial_id in dd.index]
    if any(key is None for key in trial_keys):
        raise KeyError("could not match every top-k trial to an entry of analysis.trial_dataframes")

    fig, axs = plt.subplots(len(to_plot), 1, figsize=figsize, tight_layout=False, sharex=True)
    plt.tight_layout(rect=[0.05, 0.02, 0.9, 1.0])
    for irow, (var, ax_row) in enumerate(zip(to_plot, axs)):
        # Only the first row carries labels so the figure legend lists each trial once.
        with_labels = (irow == 0) and (not supress_labels)
        for rank, key in enumerate(trial_keys, start=1):
            trial_df = dfs[key]
            # val_reg_loss is read directly from the logged column (an earlier
            # version summed the per-target regression losses instead).
            values = trial_df[var][skip:last]
            iterations = trial_df.index.values[skip:last]

            if with_labels:
                ax_row.plot(iterations, values, alpha=0.5, label="#{}".format(rank))
            else:
                ax_row.plot(iterations, values, alpha=0.5)

        ax_row.set_title(vars2titles[var])
        ax_row.grid(alpha=0.3)
        if ylim:
            ax_row.set_ylim(ylim[irow])

    # The x axis is shared, so only the bottom subplot gets the axis label.
    axs[-1].set_xlabel("Epoch")
    fig.legend(loc="center right", bbox_to_anchor=(1, 0.5), )
    plt.figtext(0.89, 0.61, "Top trials", fontsize=18)
    fig.patch.set_facecolor('white')
    plt.subplots_adjust(left=None, bottom=None, right=0.8, top=0.9, wspace=None, hspace=None)
    run_label(x=0.38, y=0.95, fz=18)

    if save_dir:
        out_dir = Path(save_dir)
        # savefig does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(str(out_dir / "topk_summary_plot_v2.pdf"))
    elif save:
        plt.savefig("topk_summary_plot_v2.pdf")
    else:
        plt.show()

In [None]:
# Rows of the 10 best trials according to the ExperimentAnalysis object.
tops = get_top_k_df(expanalysis, 10)

In [None]:
# Load matplotlib rc settings from a file; the relative path is resolved
# against the current working directory.
rc_file("my_matplotlib_rcparams.txt")

In [None]:
# WARNING: if each model is trained on a differently sized dataset, epochs contain different numbers of training samples
# Draws the 16 best trials, skipping the first logged point of every curve,
# and saves the figure under plots/ instead of showing it.
topk_summary_plot_v2(expanalysis, 16, skip=1, figsize=(12, 12), save_dir="plots/")

In [None]:
# These helpers are only referenced by the commented-out prints at the end of
# this cell.
from utils import get_best_checkpoint, get_best_epoch, get_best_loss
# best_config = expanalysis.get_best_config()
# best_logdir = expanalysis.get_best_trial().local_path
# Config and output path of the best result in the ResultGrid.
# NOTE(review): presumably ranked by the default metric/mode given to
# ExperimentAnalysis (val_loss, min) - confirm.
best_config = resultgrid.get_best_result().config
best_logdir = resultgrid.get_best_result().path

print("best logdir:", best_logdir)
print("best config", best_config)

# print("best checkpoint:", get_best_checkpoint(best_logdir), ", type:", type(get_best_checkpoint(best_logdir)))
# print("best epoch:", get_best_epoch(best_logdir), ", type:", type(get_best_epoch(best_logdir)))
# print("best loss:", get_best_loss(best_logdir), ", type:", type(get_best_loss(best_logdir)))

In [None]:
# from utils import showJetMet
# showJetMet(best_logdir, save_dir="plots")

In [None]:
def strip_config_str(key):
    """Return the part of ``key`` that follows its last ``config/`` marker.

    Keys without the marker are returned unchanged. For nested keys such as
    ``config/train_loop_config/ntrain`` only the innermost name (``ntrain``)
    remains, because the marker also occurs at the end of ``train_loop_config/``.
    """
    _, _, tail = key.rpartition("config/")
    return tail


def style_df(df):
    """Return a pandas Styler that highlights the loss columns of ``df``.

    The loss columns get a red background gradient and their minimum is
    highlighted in bold on yellow. The index is hidden and a caption naming
    the number of trials and the default metric is added.

    Note: the caption reads ``default_metric`` from the notebook-level
    ``expanalysis`` object, not from an argument.

    Args:
        df: Summary dataframe; it must contain every column listed in
            ``min_is_better``.

    Returns:
        The configured ``Styler``.
    """
    cm_red = sns.light_palette("red", as_cmap=True)

    # Columns where a smaller value is better. Accuracy-style columns
    # (larger is better) are currently not styled.
    min_is_better = ['loss', 'cls_loss', 'reg_loss', 'charge_loss', 'val_loss', 'val_cls_loss', 'val_reg_loss', 'val_charge_loss']

    styler = (
        df.style
        .background_gradient(cmap=cm_red, subset=min_is_better)
        .highlight_min(subset=min_is_better, props='color:black; font-weight:bold; background-color:yellow;')
        .set_caption('Top {} trials according to {}'.format(len(df), expanalysis.default_metric))
    )

    # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0 in
    # favour of Styler.hide(axis="index"); keep the old call as a fallback.
    if hasattr(styler, "hide"):
        return styler.hide(axis="index")
    return styler.hide_index()


def summarize_top_k(analysis, k, save=False, save_dir=None):
    """Build a summary table of the top ``k`` trials and a styled version of it.

    The table holds the training and validation losses followed by the
    hyperparameter (``config/``) columns, with the ``config/`` prefix removed
    from the column names.

    Args:
        analysis: Object accepted by ``get_top_k_df``.
        k: Number of top trials to summarise.
        save: If true and ``save_dir`` is not given, write
            ``summary_table.xlsx`` to the current working directory.
        save_dir: Directory for the Excel file; created if it does not exist
            yet. Giving it implies saving.

    Returns:
        Tuple ``(summary, styled_summary)``: the plain dataframe and the
        Styler produced by ``style_df``.
    """
    dd = get_top_k_df(analysis, k)

    metric_columns = [
        "loss",
        "cls_loss",
        "reg_loss",
        "charge_loss",
        "val_loss",
        "val_cls_loss",
        "val_reg_loss",
        "val_charge_loss",
    ]
    # Anchored pattern: DataFrame.filter searches anywhere in the name, and
    # the glob-style "config/*" matched every column containing "config".
    summary = pd.concat([dd[metric_columns], dd.filter(regex=r"^config/")], axis=1)
    summary.columns = [strip_config_str(col) for col in summary.columns]

    # The styled table used to drop the last column positionally, which was
    # the logdir while that column was still part of the summary. Drop it by
    # name instead so no hyperparameter column is lost.
    styled_summary = style_df(summary.drop(columns="logdir", errors="ignore"))

    if save_dir:
        out_dir = Path(save_dir)
        # to_excel does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)
        styled_summary.to_excel(str(out_dir / "summary_table.xlsx"), engine='openpyxl')
    elif save:
        styled_summary.to_excel("summary_table.xlsx", engine='openpyxl')
    return summary, styled_summary

In [None]:
# NOTE(review): the result of the first call is overwritten immediately, so
# only the ResultGrid-based summary is kept - confirm whether the
# ExperimentAnalysis call is still needed.
summ, styled = summarize_top_k(expanalysis, 20)
summ, styled = summarize_top_k(resultgrid, 20)

In [None]:
summ

In [None]:
styled

In [None]:
def plot_metric(logdir, trial_dfs, metric, skip=0, end=None, include_val=True, logdirs=None, save=False, xlim=None, ylim=None):
    """Plot one logged metric of a single trial in a new figure.

    Args:
        logdir: Key of the trial in ``trial_dfs``.
        trial_dfs: Mapping from trial key to a per-epoch dataframe.
        metric: Column to plot.
        skip: First position of the curve to draw (sliced ``[skip:end]``).
        end: End position of the slice, ``None`` for the full curve.
        include_val: Also draw the ``"val_" + metric`` column as a dashed
            line in the same colour.
        logdirs: Unused; kept for backward compatibility.
        save: Save the figure as ``figs/<metric>.pdf``; the ``figs``
            directory is created if it does not exist yet.
        xlim: Passed to ``plt.xlim``.
        ylim: Passed to ``plt.ylim``.
    """
    plt.figure()

    df = trial_dfs[logdir]
    # A column that is itself a validation quantity must not be labelled as
    # the training curve.
    first_label = "Validation" if metric.startswith("val_") else "Training"
    plt.plot(df[metric][skip:end], label=first_label)
    if include_val:
        clr = plt.gca().lines[-1].get_color()  # reuse the colour of the curve just drawn
        plt.plot(df["val_" + metric][skip:end], "--", color=clr, label="Validation")
    plt.legend()
    plt.ylabel(metric)
    plt.xlabel("epoch")
    plt.grid(alpha=0.3)

    plt.xlim(xlim)
    plt.ylim(ylim)

    if save:
        # savefig does not create missing directories.
        Path("figs").mkdir(parents=True, exist_ok=True)
        print(f"Saving figs/{metric}.pdf")
        plt.savefig(f"figs/{metric}.pdf")


def monitor_plot(logdir, trial_dfs, skip=0, end=None, **kwargs):
    """Draw the standard set of monitoring plots for one trial.

    Loss metrics are plotted together with their ``val_`` counterpart; the
    jet/MET metrics are plotted on their own. Every plot is produced by
    ``plot_metric`` in its own figure.

    Args:
        logdir: Key of the trial in ``trial_dfs``.
        trial_dfs: Mapping from trial key to a per-epoch dataframe.
        skip: Forwarded to ``plot_metric``.
        end: Forwarded to ``plot_metric``.
        **kwargs: Further keyword arguments forwarded to ``plot_metric``.
    """
    with_validation = ('loss', 'pt_loss', 'charge_loss', 'cls_loss', 'cos_phi_loss', 'energy_loss', 'eta_loss')
    without_validation = ('val_jet_iqr', 'val_jet_med', 'val_met_wd', 'val_met_iqr', 'val_met_med')

    def draw(name, **extra):
        # Shared call; ``extra`` carries the per-group overrides.
        plot_metric(logdir, trial_dfs, name, **extra, skip=skip, end=end, **kwargs)

    for name in with_validation:
        draw(name)

    for name in without_validation:
        draw(name, include_val=False)


In [None]:
# Per-trial dataframes of the experiment, used by the monitoring plots below.
trial_dfs = expanalysis.trial_dataframes

In [None]:
# Monitoring plots for the best trial, skipping the first 50 logged points.
# NOTE(review): best_logdir comes from the ResultGrid while trial_dfs comes from
# the ExperimentAnalysis object - confirm that the path is a valid key of
# trial_dfs for the installed Ray version.
monitor_plot(best_logdir, trial_dfs, skip=50, save=False, ylim=None)