In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import numpy as np
import json

In [None]:
# Show the kernel's current working directory. A bare `pwd` resolves to IPython's
# %pwd automagic (unless a variable named `pwd` exists). Relative paths used later
# in this notebook (the matplotlib rc file and the saved figures) resolve against it.
pwd

In [None]:
# Training-run directories to analyse. The glob result is sorted because
# Path.glob yields entries in filesystem order, which is arbitrary; sorting makes
# the run order (and any later positional slicing of train_dirs) reproducible.
EXPERIMENTS_DIR = Path("/p/project/raise-ctp2/cern/particleflow/experiments/")
train_dirs = sorted(EXPERIMENTS_DIR.glob("before_raytune_with_jet_met_logs_*"))
info_string = "Before hypertuning"

print("Length of train_dirs:", len(train_dirs))

In [None]:
# Exclude the last run directory (presumably an unfinished/unwanted run — TODO confirm).
# Sorting first makes the excluded run deterministic regardless of glob order.
# NOTE: this cell is not idempotent — re-running it drops one more directory.
train_dirs = sorted(train_dirs)[:-1]

In [None]:
def get_full_history(hist_dir, verbose=False):
    """Concatenate per-epoch history JSON files into a single history dict.

    Every file in ``hist_dir`` matching ``history*.json`` is expected to be named
    ``history_<epoch>.json``. The files are read in numeric epoch order, and for
    each key the per-epoch values are collected into a list. Two derived entries
    are added: ``reg_loss`` and ``val_reg_loss``, the element-wise sums of the
    per-target regression losses (energy, pt, eta, sin_phi, cos_phi, charge).

    Parameters
    ----------
    hist_dir : pathlib.Path or str
        Directory containing the per-epoch history JSON files.
    verbose : bool
        If True, print how many history files were found.

    Returns
    -------
    tuple(dict, int)
        The merged history and the number of epoch files read.

    Raises
    ------
    FileNotFoundError
        If ``hist_dir`` contains no ``history*.json`` files.
    """
    reg_targets = ("energy", "pt", "eta", "sin_phi", "cos_phi", "charge")
    hist_dir = Path(hist_dir)

    jsons = list(hist_dir.glob("history*.json"))
    if verbose:
        print(f"{hist_dir.parent} has {len(jsons)} histories")
    if not jsons:
        raise FileNotFoundError(f"No history*.json files found in {hist_dir}")

    # Names look like "history_<epoch>.json"; sort numerically so that epoch 10
    # comes after epoch 9 rather than right after epoch 1.
    jsons.sort(key=lambda p: int(p.name.split("_")[1].split(".")[0]))

    # The keys of the first epoch file define the layout of the merged history.
    with open(jsons[0]) as h:
        keys = json.load(h).keys()
    full_history = {key: [] for key in keys}

    # Append every epoch's values to the corresponding per-key list.
    for path in jsons:
        with open(path) as h:
            epoch = json.load(h)
        for key, value in epoch.items():
            full_history[key].append(value)

    # Total regression loss = sum over all regression targets, per epoch.
    full_history["reg_loss"] = np.sum(
        np.array([full_history[f"{t}_loss"] for t in reg_targets]), axis=0
    )
    full_history["val_reg_loss"] = np.sum(
        np.array([full_history[f"val_{t}_loss"] for t in reg_targets]), axis=0
    )

    return full_history, len(jsons)


def get_histories(train_dirs):
    """Load the merged training history of every run directory.

    Each run's per-epoch history files are read from ``<train_dir>/logs/history``
    via ``get_full_history``; the epoch count it also returns is discarded.

    Parameters
    ----------
    train_dirs : iterable of str or pathlib.Path
        Run directories.

    Returns
    -------
    list of dict
        One merged history per run, in the order of ``train_dirs``.
    """
    run_dirs = [Path(run_dir) for run_dir in train_dirs]
    return [get_full_history(hist_dir=run_dir / "logs/history")[0] for run_dir in run_dirs]

In [None]:
# One merged history dict per selected run directory.
histories = get_histories(train_dirs=train_dirs)

In [None]:
# def cms_label(x0=0.12, x1=0.23, x2=0.67, y=0.90):
#     plt.figtext(x0, y,'CMS',fontweight='bold', wrap=True, horizontalalignment='left', fontsize=16)
#     plt.figtext(x1, y,'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left', fontsize=14)
#     plt.figtext(x2, y,r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, QCD with PU50',  wrap=False, horizontalalignment='left', fontsize=12)

def cms_label(x0=0.12, y=0.90, s=None, fz=22):
    """Draw the "CMS Simulation Preliminary" header and an optional text block.

    Parameters
    ----------
    x0, y : float
        Figure coordinates of the bold "CMS" word. "Simulation Preliminary" is
        placed 0.09 to its right, and the optional text 0.15 below it.
    s : str or None
        Extra text drawn below the header. A single trailing newline, if present,
        is removed (callers build ``s`` line by line with "\\n" terminators).
    fz : int
        Font size of "CMS"; the other two texts use ``fz - 3`` and ``fz - 6``.
    """
    plt.figtext(x0, y, 'CMS', fontweight='bold', wrap=True, horizontalalignment='left', fontsize=fz)
    plt.figtext(x0 + 0.09, y, 'Simulation Preliminary', style='italic', wrap=True,
                horizontalalignment='left', fontsize=fz - 3)
    if s is not None:
        # Strip only a trailing newline; never cut a real last character.
        text = s[:-1] if s.endswith("\n") else s
        plt.figtext(x=x0, y=y - 0.15, s=text, fontsize=fz - 6)

# def run_label(x=0.67, y=0.90, fz=12):
#     plt.figtext(x, y,r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, QCD with PU50',  wrap=False, horizontalalignment='left', fontsize=fz)


def run_label(x=0.67, y=0.90, fz=22):
    """Write the run conditions / sample description at figure coordinates (x, y).

    The text is left-aligned, not wrapped, and drawn with font size ``fz``.
    """
    description = r'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$, $\mathrm{z}\tau\tau$, QCD, QCD with high $p_T$, PU 55-75'
    plt.figtext(x, y, description,
                wrap=False,
                horizontalalignment='left',
                fontsize=fz)


def get_combined_array(key, history_list=None):
    """Stack one metric from several training histories into a 2-D array.

    Parameters
    ----------
    key : str
        History key to extract (e.g. ``"loss"`` or ``"val_reg_loss"``).
    history_list : sequence of dict or None
        Histories to combine. Defaults to the notebook-level ``histories``.

    Returns
    -------
    np.ndarray
        Row ``i`` is ``history_list[i][key]``. When each entry is a 1-D per-epoch
        sequence, the shape is ``(n_runs, n_epochs)``. A single history still
        yields a 2-D ``(1, n_epochs)`` array, so callers can always index
        ``[run, epoch]``.

    Raises
    ------
    ValueError
        If there are no histories to combine.
    """
    runs = histories if history_list is None else history_list
    if len(runs) == 0:
        raise ValueError(f"no histories to combine for key {key!r}")
    # Stack once rather than growing the result with repeated vstack calls,
    # which re-copies the accumulated array on every iteration.
    return np.vstack([np.asarray(run[key]) for run in runs])

In [None]:
# Keep only runs that completed the full training schedule. Curves of different
# lengths cannot be stacked per epoch by get_combined_array (np.vstack).
N_EPOCHS_EXPECTED = 100
finished_histories = [h for h in histories if len(h["loss"]) == N_EPOCHS_EXPECTED]
print(f"Keeping {len(finished_histories)} of {len(histories)} runs with {N_EPOCHS_EXPECTED} epochs")
histories = finished_histories

In [None]:
# Combine each metric across all selected runs with get_combined_array.
loss_array, reg_loss_array, cls_loss_array = map(
    get_combined_array, ("loss", "reg_loss", "cls_loss")
)

val_loss_array, val_reg_loss_array, val_cls_loss_array = map(
    get_combined_array, ("val_loss", "val_reg_loss", "val_cls_loss")
)

cls_acc_weighted_array, val_cls_acc_weighted_array = map(
    get_combined_array, ("cls_acc_weighted", "val_cls_acc_weighted")
)

# Jet / MET metrics (Wasserstein distance, IQR, median) on the validation set.
(
    val_met_wd_array,
    val_jet_wd_array,
    val_met_iqr_array,
    val_jet_iqr_array,
    val_met_med_array,
    val_jet_med_array,
) = map(
    get_combined_array,
    ("val_met_wd", "val_jet_wd", "val_met_iqr", "val_jet_iqr", "val_met_med", "val_jet_med"),
)

In [None]:
# Display training and validation loss array shapes (expected to match).
(loss_array.shape, val_loss_array.shape)

In [None]:
def plot_variance_curve(array_list,
                        labels,
                        skip=0,
                        ylim=None,
                        save_path=None,
                        x=0.45,
                        y=0.53,
                        loc=None,
                        ylabel=None,
                        custom_info=None,
                        info=None,
                       ):
    """Plot the run-to-run mean ± standard deviation of per-epoch metric curves.

    For each array, the mean over runs is drawn as a line and ±1 std as a shaded
    band. The final-epoch mean/std is printed, and a summary text block is drawn
    with the CMS and run labels.

    Parameters
    ----------
    array_list : sequence of np.ndarray
        One 2-D array per curve, indexed ``[run, epoch]``.
    labels : sequence of str
        Legend label for each array. Also used (lower-cased) in the summary text.
    skip : int
        Number of leading epochs to leave out of the plot and statistics.
    ylim : (float, float) or None
        ``(bottom, top)`` y-axis limits.
    save_path : str or pathlib.Path or None
        If given, the figure is saved with both a ``.png`` and a ``.pdf`` suffix.
    x, y : float
        Figure coordinates of the CMS label and summary text block.
    loc : str or None
        Legend location. None uses matplotlib's default placement.
    ylabel : str or None
        Y-axis label.
    custom_info : sequence of dict or None
        Per-curve ``{"mean": ..., "std": ...}`` values shown in the summary text
        instead of the computed final values.
    info : str or None
        First line of the summary text. Defaults to the notebook-level ``info_string``.

    Raises
    ------
    ValueError
        If ``array_list`` is empty or its length differs from ``labels``.
    """
    if len(array_list) == 0:
        raise ValueError("array_list must contain at least one array")
    if len(labels) != len(array_list):
        raise ValueError(
            "labels ({}) and array_list ({}) must have the same length".format(len(labels), len(array_list))
        )
    if info is None:
        info = info_string

    plt.figure()
    final_means = []
    final_stds = []
    for array, label in zip(array_list, labels):
        # Epochs are numbered from 1; drop the first `skip` of them.
        epochs = np.arange(1, array.shape[1] + 1)[skip:]
        runs = array[:, skip:]

        mean = np.mean(runs, axis=0)
        std = np.std(runs, axis=0)

        plt.plot(epochs, mean, label=label)
        plt.fill_between(epochs, mean - std, mean + std, alpha=0.4)

        print(label + ": {:.4f} +/- {:#.2g}".format(mean[-1], std[-1]))
        final_means.append(mean[-1])
        final_stds.append(std[-1])

    if loc is not None:
        plt.legend(loc=loc)
    else:
        plt.legend()
    plt.xlabel("Epochs")
    if ylabel:
        plt.ylabel(ylabel)

    # Number of trainings reported in the summary, taken from the last array
    # (all curves are expected to come from the same set of runs).
    n_runs = array_list[-1].shape[0]
    s = "{:s}\nMean and standard deviation of {:d} trainings\n".format(info, n_runs)
    for ii, label in enumerate(labels):
        if custom_info:
            s += "Final {}: {} +/- {}\n".format(label.lower(), custom_info[ii]['mean'], custom_info[ii]["std"])
        else:
            s += "Final {}: {:.3f} +/- {:#.2g}\n".format(label.lower(), final_means[ii], final_stds[ii])

    if ylim:
        plt.ylim(top=ylim[1], bottom=ylim[0])

    cms_label(x0=x, y=y, s=s, fz=26)
    run_label(x=0.12, y=0.89, fz=24)
    if save_path:
        plt.savefig(Path(save_path).with_suffix('.png'))
        plt.savefig(Path(save_path).with_suffix('.pdf'))
    plt.show()

In [None]:
# Close the current matplotlib figure, if one is open (no-op otherwise).
plt.close()

In [None]:
# Load base matplotlib settings from a local rc file. The path is relative to the
# kernel's working directory (see `pwd`). Later cells override some of these rcParams.
mpl.rc_file("my_matplotlib_rcparams")

In [None]:
# Figure styling overrides: font sizes, inward-pointing ticks, invisible grid lines.
mpl.rcParams.update({
    # Axes
    "axes.titlesize": 20,
    "axes.labelsize": 30,
    # Ticks
    "xtick.labelsize": 25,
    "ytick.labelsize": 25,
    "xtick.direction": "in",
    "ytick.direction": "in",
    # Legend
    "legend.fontsize": 24,
    # Fully transparent grid lines
    "grid.alpha": 0,
})

In [None]:
# Show the working directory again (IPython %pwd automagic): the figures saved
# below use relative paths, so this is where they will be written.
pwd

In [None]:
# Total loss. The runs come from "before_raytune_*" directories (info_string
# "Before hypertuning"), so the output files are named "before_tuning".
# skip=10 leaves out the first 10 epochs (presumably so the large early losses
# do not dominate the y-range).
plot_variance_curve([loss_array, val_loss_array],
                    labels=["Training loss", "Validation loss"],
                    skip=10,
                    # ylim=[0.5, 2.5],
                    save_path="loss_curves_std_before_tuning.png",
                    x=0.37,
                    y=0.68,
                    ylabel="Loss"
                   )

In [None]:
# Weighted classification accuracy. The output file is named "before_tuning"
# to match the "before_raytune_*" runs being plotted.
plot_variance_curve([cls_acc_weighted_array, val_cls_acc_weighted_array],
                    ["Train accuracy", "Validation accuracy"],
                    skip=10,
                    # ylim=(0.7, 0.9),
                    save_path="cls_acc_std_before_tuning.png",
                    x=0.37,
                    y=0.48,
                    loc="lower right",
                    ylabel="Accuracy"
                   )

In [None]:
# Summed regression loss over all regression targets (not saved to disk).
plot_variance_curve(
    array_list=[reg_loss_array, val_reg_loss_array],
    labels=["Training regression loss", "Validation regression loss"],
    skip=10,
    ylabel="Regression loss",
    # ylim=(0.0, 0.2),
)

In [None]:
# Validation-set jet and MET Wasserstein distances, shown from the first epoch.
# The labels use the singular form consistently; they also appear lower-cased
# in the summary text.
plot_variance_curve([val_jet_wd_array, val_met_wd_array],
                    labels=["Jet Wasserstein distance", "MET Wasserstein distance"],
                    skip=0,
                    # ylim=(0.0, 0.2),
                    ylabel="Jet and MET Wasserstein distance",
                   )

In [None]:
# Classification loss (not saved to disk).
plot_variance_curve(
    array_list=[cls_loss_array, val_cls_loss_array],
    labels=["Training classification loss", "Validation classification loss"],
    skip=10,
    ylabel="Classification loss",
    # ylim=(0.0, 0.005),
)