# Analysis of different MLPF event losses

In [None]:
pwd

In [None]:
#|export
from pathlib import Path
import shutil
from tqdm import tqdm
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl

from sigfig import round

In [None]:
# Experiment output directories that are scanned for trials.
exp_dirs = [
    Path(root)
    for root in (
        "/p/project/raise-ctp2/cern/ray_results/event_loss_scan",
        "/p/project/raise-ctp2/cern/ray_results/event_loss_scan6",
    )
]

# Gather every `build_model_and_train_*` entry found in any of the experiment directories.
trial_dirs = []
for exp_dir in exp_dirs:
    trial_dirs.extend(exp_dir.glob("build_model_and_train_*"))

In [None]:
# Show how many trials were found, followed by one path per line.
print(f"Number of trials: {len(trial_dirs)}")
for trial_dir in trial_dirs:
    print(str(trial_dir))

In [None]:
def trial_dir2loss_names(trial_dir):
    """Extract ``(event_loss_name, met_loss_name)`` from a trial directory name.

    The directory name is split on single quotes: the text after the first
    quote is returned as the event loss name, and the text before the last
    quote as the MET loss name.
    """
    pieces = trial_dir.name.split("'")
    event_loss_name = pieces[1]
    met_loss_name = pieces[-2]
    return event_loss_name, met_loss_name

In [None]:
# Sanity check: the loss names must be extractable from every trial directory name.
for trial_dir in trial_dirs:
    event_loss_name, met_loss_name = trial_dir2loss_names(trial_dir)
    print(f"{event_loss_name} {met_loss_name}")

In [None]:
# Work out, per trial, the archive directory named after its loss configuration.
# NOTE: the mkdir/copytree calls at the bottom are commented out, so this loop
# currently only computes `dest` and does not copy anything.
progress = tqdm(enumerate(trial_dirs), total=len(trial_dirs), desc="Copying files")
for count, trial_dir in progress:
    event_loss_name, met_loss_name = trial_dir2loss_names(trial_dir)
    dest = Path(
        f"/p/project/raise-ctp2/cern/particleflow/event_loss_logs_and_histories_60h/{event_loss_name}_{met_loss_name}"
    )

    # dest.mkdir(parents=True, exist_ok=True)
    # shutil.copytree(trial_dir / "logs", dest / f"logs_{count}")
    # shutil.copytree(trial_dir / "history", dest / f"history_{count}")

In [None]:
# def cms_label(x0=0.12, x1=0.23, x2=0.67, y=0.90):
#     plt.figtext(x0, y,'CMS',fontweight='bold', wrap=True, horizontalalignment='left', fontsize=16)
#     plt.figtext(x1, y,'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left', fontsize=14)
#     plt.figtext(x2, y,r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, $\mathrm{z}\tau\tau$, QCD with PU50, QCD with high $p_T$',  wrap=False, horizontalalignment='left', fontsize=12)


def cms_label(x0=0.12, y=0.90, s=None, fz=30):
    """Add the "CMS" / "Simulation Preliminary" label via ``plt.figtext``.

    Args:
        x0: Horizontal figure coordinate of the bold "CMS" text; the italic
            "Simulation Preliminary" text is placed 0.1 to the right of it.
        y: Vertical figure coordinate of both label parts.
        s: Optional extra text, drawn 0.2 below the label. Its last character
            is dropped before drawing.
        fz: Font size of "CMS"; the other texts use ``fz - 3`` and ``fz - 6``.
    """
    left_aligned = dict(wrap=True, horizontalalignment="left")
    plt.figtext(x0, y, "CMS", fontweight="bold", fontsize=fz, **left_aligned)
    plt.figtext(x0 + 0.1, y, "Simulation Preliminary", style="italic", fontsize=fz - 3, **left_aligned)

    if s is None:
        return

    # s[:-1] removes the final character (presumably a trailing newline — verify against callers).
    plt.figtext(x=x0, y=y - 0.2, s=s[:-1], fontsize=fz - 6)

def run_label(x=0.67, y=0.90, fz=22):
    """Add the run / sample description text via ``plt.figtext``.

    Args:
        x: Horizontal figure coordinate of the left-aligned text.
        y: Vertical figure coordinate of the text.
        fz: Font size of the text.
    """
    # Only `p_T` is typeset in math mode. The closing `$` used to sit at the very
    # end of the string, which put ", PU 55-75" into math mode as well (spaces
    # dropped, hyphen rendered as a minus sign).
    plt.figtext(
        x,
        y,
        r"Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, $\mathrm{z}\tau\tau$, QCD, QCD with high $p_T$, PU 55-75",
        wrap=False,
        horizontalalignment="left",
        fontsize=fz,
    )


def get_full_history(hist_dir, verbose=False):
    """Join the per-epoch ``history_<N>.json`` files of one trial into a single history.

    Args:
        hist_dir: ``pathlib.Path`` of the directory holding the ``history*.json`` files.
            The epoch number is parsed from the part of the file name between the
            first underscore and the following dot.
        verbose: If true, print how many history files were found.

    Returns:
        A tuple ``(full_history, n_files)``. ``full_history`` maps every key of the
        first epoch file to the list of its values in epoch order, plus the two
        derived entries ``"reg_loss"`` and ``"val_reg_loss"`` (NumPy arrays holding
        the per-epoch sum of the energy, pt, eta, sin_phi, cos_phi and charge
        losses). ``n_files`` is the number of history files read.

    Raises:
        FileNotFoundError: If ``hist_dir`` contains no ``history*.json`` file.
        KeyError: If a later epoch file has a key that the first one lacks, or if
            one of the component losses is missing.
    """
    jsons = list(hist_dir.glob("history*.json"))
    if verbose:
        print(f"{hist_dir.parent} has {len(jsons)} histories")
    if not jsons:
        # Fail with an explicit message instead of an IndexError on jsons[0].
        raise FileNotFoundError(f"No history*.json files found in {hist_dir}")
    jsons.sort(key=lambda p: int(p.name.split("_")[1].split(".")[0]))  # sort according to epoch number

    # Parse every file exactly once, in epoch order.
    epochs = []
    for path in jsons:
        with open(path) as h:
            epochs.append(json.load(h))

    # The key set of the first epoch defines the history; join epoch values per key.
    full_history = {key: [] for key in epochs[0]}
    for epoch in epochs:
        for key, value in epoch.items():
            full_history[key].append(value)

    # Total regression loss per epoch = sum over the individual regression targets.
    reg_targets = ("energy", "pt", "eta", "sin_phi", "cos_phi", "charge")
    full_history["reg_loss"] = np.sum(
        np.array([full_history[f"{target}_loss"] for target in reg_targets]),
        axis=0,
    )
    full_history["val_reg_loss"] = np.sum(
        np.array([full_history[f"val_{target}_loss"] for target in reg_targets]),
        axis=0,
    )

    return full_history, len(jsons)

In [None]:
# Load the joined history of the first trial; the returned file count is not needed here.
fh, _ = get_full_history(trial_dirs[0] / "logs/history")

In [None]:
# Tabulate the history with the default `from_dict` orientation (one column per key of `fh`;
# rows are presumably epochs — confirm). The trailing comment keeps the alternative orientation.
df = pd.DataFrame.from_dict(fh)  # orient="index", columns=[f"epoch {i}" for i in range(len(fh["loss"]))])

In [None]:
df

In [None]:
#|export
def get_combined_array(nested_list, max_length=None):
    """Stack several value sequences into one 2-D array, one row per sequence.

    Args:
        nested_list: Sequence of sequences (e.g. one list of per-epoch values per
            trial).
        max_length: If given, every inner sequence is truncated to its first
            ``max_length`` entries before stacking. Sequences of different
            lengths can only be stacked when this makes them equally long.

    Returns:
        A NumPy array of shape ``(len(nested_list), n)``. A single input sequence
        yields shape ``(1, n)`` rather than a 1-D array, so callers can always
        index the result as ``array[:, ...]``.

    Raises:
        ValueError: If ``nested_list`` is empty or the (truncated) rows differ in
            length.
    """
    # One vstack over all rows: always 2-D, and avoids re-copying the growing
    # array once per row.
    rows = [np.asarray(row[:max_length]) for row in nested_list]
    return np.vstack(rows)


def get_largest_common(data, cap=100):
    """Return the largest length shared by all sequences in ``data``, limited by ``cap``.

    Args:
        data: Iterable of sized objects (e.g. one list of per-epoch values per
            trial).
        cap: Upper bound for the result. Defaults to 100, the value that was
            previously hard-coded. Pass ``None`` to disable the bound.

    Returns:
        ``min(cap, shortest length in data)``. For empty ``data`` this is ``cap``,
        or 0 when ``cap`` is ``None``.
    """
    shortest = min((len(trial) for trial in data), default=None)
    if shortest is None:
        return 0 if cap is None else cap
    if cap is None:
        return shortest
    return min(shortest, cap)


def plot_variance_curve(
    data_list,
    metric,
    labels,
    skip=0,
    ylim=None,
    save_path=None,
    x=0.45,
    y=0.53,
    loc=None,
    ylabel=None,
    verbose=False,
    s_xpos=0.5,
    s_ypos=0.6,
):
    """Plot mean +/- standard deviation of a metric over epochs, one curve per group.

    Args:
        data_list: List of groups; each group is a list of history dicts that
            contain ``metric`` as a sequence of per-epoch values.
        metric: Key of the history entry to plot.
        labels: Legend labels, ``labels[i]`` belonging to ``data_list[i]``.
        skip: Number of leading epochs to leave out of the plot.
        ylim: Optional ``(bottom, top)`` limits of the y axis.
        save_path: If given, the figure is saved next to this path with the
            suffixes ``.png`` and ``.pdf``; otherwise ``plt.show()`` is called.
        loc: Optional legend location passed to ``plt.legend``.
        ylabel: Optional y axis label.
        verbose: If true, print the final mean and standard deviation per group.
        x, y, s_xpos, s_ypos: Unused; kept so existing callers keep working.

    Within a group, the histories are cut to the length returned by
    ``get_largest_common`` before mean and standard deviation are computed
    across the group.
    """
    plt.figure()
    for ii, trials in enumerate(data_list):
        curves = [trial[metric] for trial in trials]

        largest_common_epoch = get_largest_common(curves)

        # atleast_2d keeps a group with a single trial usable as (1, n_epochs).
        array = np.atleast_2d(get_combined_array(curves, max_length=largest_common_epoch))
        xx = np.arange(1, array.shape[1] + 1)  # Epochs, counted from 1

        xx = xx[skip:]
        array = array[:, skip:]

        std = np.std(array, axis=0)
        mean = np.mean(array, axis=0)

        plt.plot(xx, mean, label=labels[ii])
        plt.fill_between(xx, mean - std, mean + std, alpha=0.4)

        if verbose:
            print(labels[ii] + ": {}".format(round(mean[-1], std[-1], cutoff=99)))

    # The summary text that used to be assembled here was never drawn, and it
    # failed for an empty `data_list` (unbound `array`) or for more labels than
    # curves, so it has been removed.

    if loc is not None:
        plt.legend(loc=loc)
    else:
        plt.legend()
    plt.xlabel("Epochs")
    if ylabel:
        plt.ylabel(ylabel)

    if ylim:
        plt.ylim(top=ylim[1], bottom=ylim[0])

    cms_label(x0=0.135, y=0.82, s=None, fz=28)
    run_label(x=0.16, y=0.90, fz=22)
    plt.grid(alpha=0.3)

    if save_path:
        target = Path(save_path)
        plt.savefig(target.with_suffix(".png"))
        plt.savefig(target.with_suffix(".pdf"))
        # NOTE(review): the figure stays open after saving, as before; scripts
        # that create many plots may want to close it explicitly.
    else:
        plt.show()

In [None]:
#|export
def get_history_summary(trial_dirs):
    """Group the joined histories of all trials by their loss configuration.

    Args:
        trial_dirs: Iterable with a length, containing trial directory paths.
            Each must be accepted by ``trial_dir2loss_names`` and have its
            history files under ``logs/history``.

    Returns:
        Dict mapping ``"<event_loss_name>:<met_loss_name>"`` to the list of full
        histories (as returned by ``get_full_history``) of the trials with that
        configuration. Keys appear in order of first occurrence in
        ``trial_dirs``.
    """
    summary = {}
    for trial_dir in tqdm(trial_dirs, total=len(trial_dirs), desc="Processing history files"):
        event_loss_name, met_loss_name = trial_dir2loss_names(trial_dir)
        fh, _ = get_full_history(trial_dir / "logs/history")
        # Create the group on first sight of a configuration, then append.
        summary.setdefault(f"{event_loss_name}:{met_loss_name}", []).append(fh)
    return summary

In [None]:
#|export
# Legend labels per loss configuration. The keys have the
# "<event_loss_name>:<met_loss_name>" form built in get_history_summary.
label_dict = {
    "sliced_wasserstein:none": "Sliced Wasserstein",
    "none:none": "Baseline",
    "gen_jet_logcosh:none": "Gen-jet logcosh",
    "hist_2d:none": "2D histogram",
    "none:met": "MET",
}


def plot_metric(metric, history_summary, save_path=None, verbose=False, ylabel=None, skip=0):
    """Plot one metric for every loss configuration in ``history_summary``.

    Args:
        metric: Key of the history entry to plot.
        history_summary: Dict mapping a loss-configuration key to a list of
            history dicts, as returned by ``get_history_summary``.
        save_path: Optional output path; its parent directory is created if
            needed. Forwarded to ``plot_variance_curve``.
        verbose: Forwarded to ``plot_variance_curve``.
        ylabel: Forwarded to ``plot_variance_curve``.
        skip: Forwarded to ``plot_variance_curve``.
    """
    if save_path is not None:
        Path(save_path).parent.mkdir(exist_ok=True, parents=True)

    keys = list(history_summary)
    plot_variance_curve(
        data_list=[history_summary[key] for key in keys],
        metric=metric,
        # Configurations without an entry in label_dict are labelled with their
        # raw key instead of aborting the plot with a KeyError.
        labels=[label_dict.get(key, key) for key in keys],
        ylabel=ylabel,
        save_path=save_path,
        verbose=verbose,
        skip=skip,
    )

In [None]:
#|export
# Apply the plotting style and load the histories of all trials.
mpl.rc_file("my_matplotlib_rcparams")
history_summary = get_history_summary(trial_dirs)

# The "hist_2d:none" configuration is excluded from the plots.
history_summary.pop("hist_2d:none", None)

# Axis label for every metric that gets plotted.
metric2name = {
    "val_cls_loss": "Validation classification loss",
    "val_reg_loss": "Validation regression loss",
    "val_jet_wd": "Validation jet Wasserstein distance",
    "val_jet_iqr": "Validation jet IQR",
    "val_jet_med": "Validation jet median",
    "val_met_wd": "Validation MET Wasserstein distance",
    "val_met_iqr": "Validation MET IQR",
    "val_met_med": "Validation MET median",
}

# Plot in the order in which the metrics are listed above.
metrics_to_plot = list(metric2name)

for metric in tqdm(metrics_to_plot, total=len(metrics_to_plot), desc="Plotting"):
    # One figure per metric, saved under event_loss_plots/<metric>; the first 5 epochs are skipped.
    plot_metric(
        metric,
        history_summary,
        save_path=Path("event_loss_plots") / metric,
        verbose=False,
        ylabel=metric2name[metric],
        skip=5,
    )

In [None]:
# Run nbdev's export on this notebook with "." as the second argument
# (presumably the output location for the cells tagged `#|export` — confirm against the nbdev docs).
from nbdev.export import nb_export

nb_export("analyze_event_loss_scan.ipynb", ".")

In [None]:
ll

In [None]:
cat event_loss_analysis.py