In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import numpy as np
import json
from uncertainties import ufloat

In [None]:
from utils import get_histories, get_full_history, get_combined_array

In [None]:
# Show the kernel's working directory: the relative paths used further down
# ("my_matplotlib_rcparams.txt", "std_plots/...") are resolved against it.
# NOTE(review): a bare `pwd` presumably relies on IPython automagic (`%pwd`) — confirm.
pwd

In [None]:
# Location of the repeated trainings on each compute site, given as
# (experiments root, glob pattern matching the run directories).
EXPERIMENT_SOURCES = {
    "FZJ": (
        Path("/p/project/raise-ctp2/cern/particleflow/experiments/"),
        "before_raytune_with_jet_met_logs_*",
    ),
    "Flatiron": (
        Path("/mnt/ceph/users/ewulff/particleflow/experiments/"),
        "clic_gnn_beforeHPO*",
    ),
}

# Select the site explicitly. Previously both sites were assigned one after the
# other, so the FZJ value was always silently overwritten by the Flatiron one;
# "Flatiron" therefore reproduces what this cell effectively did.
SITE = "Flatiron"

experiments_root, run_pattern = EXPERIMENT_SOURCES[SITE]

# Sorted so the order of the runs does not depend on the file system.
train_dirs = sorted(experiments_root.glob(run_pattern))

# First line of the annotation text drawn by plot_variance_curve.
info_string = "Before HPO"

print("Length of train_dirs:", len(train_dirs))

In [None]:
# Load one training history per run directory in `train_dirs`.
# NOTE(review): `get_histories` comes from the local `utils` module; the cells below
# index each entry by metric name (e.g. history['loss']) and take its len() — confirm
# the exact structure against utils.
histories = get_histories(train_dirs)

In [None]:
def find_shortest_history(histories):
    """Return the length of the shortest training history.

    The length of a history is the number of entries stored under its
    ``'loss'`` key.

    Args:
        histories: Sized collection of mappings, each with a ``'loss'`` entry
            that supports ``len()``.

    Returns:
        The smallest ``len(history['loss'])`` found in ``histories``.

    Raises:
        ValueError: If ``histories`` is empty.
    """
    if len(histories) == 0:
        raise ValueError("Given history list is empty list")
    # min() has no built-in ceiling, unlike starting from a fixed "large" value,
    # so histories of any length are handled correctly.
    return min(len(history['loss']) for history in histories)

# Display the smallest 'loss' history length among all loaded runs.
find_shortest_history(histories)

In [None]:
# Print the 'loss' history length of every run, one per line
# (the next cell keeps only runs with exactly 150 entries).
for history in histories:
    print(len(history['loss']))

In [None]:
# Metric keys of interest in each history.
# NOTE(review): `metrics` is only referenced by the disabled truncation code below.
metrics = ['loss', 'reg_loss', 'cls_loss', 'val_loss', 'val_reg_loss', 'val_cls_loss',
           # 'cls_acc_weighted', 'val_cls_acc_weighted',
           'val_met_wd', 'val_jet_wd', 'val_met_iqr', 'val_jet_iqr', 'val_met_med', 'val_jet_med'
          ]

# Disabled alternative: truncate every history to the length of the shortest one.
# shortest = find_shortest_history(histories)
# for history in histories:
#     for metric in metrics:
#         history[metric] = history[metric][:shortest]

# Number of 'loss' entries a run must have to be treated as finished.
EXPECTED_EPOCHS = 150

# Keep only the finished runs; shorter (or longer) histories are dropped so that
# all remaining histories have the same length.
finished_histories = [
    history for history in histories if len(history['loss']) == EXPECTED_EPOCHS
]
histories = finished_histories

In [None]:
# Combine each metric across all retained runs into one array per metric.
# NOTE(review): `get_combined_array` is defined in `utils`; plot_variance_curve indexes
# the result as array[:, skip:] and reads shape[0] as the number of trainings, i.e. it
# is treated as (n_trainings, n_epochs) — confirm against utils.

# Training losses
loss_array, reg_loss_array, cls_loss_array = (
    get_combined_array(histories, key)
    for key in ("loss", "reg_loss", "cls_loss")
)

# Validation losses
val_loss_array, val_reg_loss_array, val_cls_loss_array = (
    get_combined_array(histories, key)
    for key in ("val_loss", "val_reg_loss", "val_cls_loss")
)

# Weighted classification accuracy (disabled)
# cls_acc_weighted_array = get_combined_array("cls_acc_weighted")
# val_cls_acc_weighted_array = get_combined_array("val_cls_acc_weighted")

# Validation jet / MET metrics
(
    val_met_wd_array,
    val_jet_wd_array,
    val_met_iqr_array,
    val_jet_iqr_array,
    val_met_med_array,
    val_jet_med_array,
) = (
    get_combined_array(histories, key)
    for key in (
        "val_met_wd",
        "val_jet_wd",
        "val_met_iqr",
        "val_jet_iqr",
        "val_met_med",
        "val_jet_med",
    )
)

In [None]:
# Sanity check: display the shapes of the combined training and validation loss arrays.
loss_array.shape, val_loss_array.shape

In [None]:
def sigdigits(mean, std):
    """Format a value and its uncertainty as a single string.

    Wraps ``mean`` and ``std`` in a ``ufloat`` and renders it with the
    ``.2u`` format specification of the ``uncertainties`` package.

    Args:
        mean: Nominal value.
        std: Uncertainty attached to ``mean``.

    Returns:
        The formatted string.
    """
    value = ufloat(mean, std)
    return format(value, ".2u")


def run_label(x=0.67, y=0.90, fz=12):
    """Write the model / dataset description onto the current figure.

    Args:
        x: Horizontal position passed to ``plt.figtext``; the text is
            right-aligned at this position.
        y: Vertical position passed to ``plt.figtext``.
        fz: Font size of the text.
    """
    description = r'GNN-based model, cluster-based CLIC dataset v1.3.0, $\mathrm{t}\overline{\mathrm{t}}$, qq'
    plt.figtext(
        x,
        y,
        description,
        wrap=False,
        horizontalalignment='right',
        fontsize=fz,
    )


def cms_label(x0=0.12, y=0.90, s=None, fz=22):
    """Draw the annotation text ``s`` on the current figure.

    The "CMS / Simulation Preliminary" header lines are disabled (kept below as
    comments), so the only thing drawn is the optional text ``s``.

    Args:
        x0: Horizontal position passed to ``plt.figtext``.
        y: Reference height; the text is placed at ``y - 0.15``.
        s: Text to draw. Nothing is drawn when it is ``None``. A single
            trailing newline, if present, is removed before drawing.
        fz: Base font size; the text is drawn at ``fz - 6``.
    """
    # plt.figtext(x0, y,'CMS',fontweight='bold', wrap=True, horizontalalignment='left', fontsize=fz)
    # plt.figtext(x0+0.09, y,'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left', fontsize=fz-3)
    if s is None:
        return
    # The caller builds `s` line by line, so it normally ends with "\n". Strip
    # only that newline: dropping the last character unconditionally would cut
    # off real text whenever `s` has no trailing newline.
    text = s[:-1] if s.endswith("\n") else s
    plt.figtext(x=x0, y=y - 0.15, s=text, fontsize=fz - 6)


def plot_variance_curve(array_list,
                        labels,
                        skip=0,
                        ylim=None,
                        save_path=None,
                        x=0.45,
                        y=0.53,
                        loc=None,
                        ylabel=None,
                        custom_info=None,
                        info_text=None,
                       ):
    """Plot the mean curve and a +/- one standard deviation band per metric.

    For every array in ``array_list`` the mean and standard deviation
    (``np.mean`` / ``np.std`` along axis 0) are computed per column, drawn
    against the 1-based column index ("Epochs") and summarised in an
    annotation placed on the figure. The figure is optionally saved and then
    shown.

    Args:
        array_list: Non-empty sequence of 2-D arrays; axis 0 is averaged over
            (reported as "trainings"), axis 1 is the x-axis.
        labels: Legend label for each array, indexed like ``array_list``.
        skip: Number of leading columns left out of the plot.
        ylim: Optional ``(bottom, top)`` limits of the y-axis.
        save_path: Optional output path. Its suffix is replaced, and the
            figure is written once as ``.png`` and once as ``.pdf``. Missing
            parent directories are created.
        x: Passed to ``cms_label`` as ``x0`` (annotation position).
        y: Passed to ``cms_label`` as ``y`` (annotation position).
        loc: Optional legend location passed to ``plt.legend``.
        ylabel: Optional y-axis label.
        custom_info: Optional sequence indexed like ``labels``; each entry
            provides ``'mean'`` and ``'std'`` that replace the last-column
            values in the annotation text.
        info_text: First line of the annotation. Defaults to the module-level
            ``info_string``.

    Raises:
        ValueError: If ``array_list`` is empty.
    """
    # Fail early with a clear message; otherwise the annotation code below
    # would hit an undefined loop variable.
    if len(array_list) == 0:
        raise ValueError("array_list must contain at least one array")
    if info_text is None:
        info_text = info_string

    plt.figure()
    final_means = []
    final_stds = []
    for ii, array in enumerate(array_list):
        print(f"{labels[ii]} is averaged over {array.shape[0]} trainings.")

        # 1-based epoch numbers, with the first `skip` epochs dropped.
        epochs = np.arange(1, array.shape[1] + 1)[skip:]
        array = array[:, skip:]

        std = np.std(array, axis=0)
        mean = np.mean(array, axis=0)

        plt.plot(epochs, mean, label=labels[ii])
        plt.fill_between(epochs, mean - std, mean + std, alpha=0.4)

        # Add individual loss curves
        # plt.plot(np.tile(epochs, reps=[10,1]).transpose(), array.transpose(), linewidth=0.2)

        # Values at the last plotted epoch are what the annotation reports.
        print(labels[ii] + ": {:s}".format(sigdigits(mean[-1], std[-1])))
        final_means.append(mean[-1])
        final_stds.append(std[-1])

#     plt.legend(bbox_to_anchor=(0.98, 0.78), loc="center right")
    if loc is not None:
        plt.legend(loc=loc)
    else:
        plt.legend()
    plt.xlabel("Epochs")
    if ylabel:
        plt.ylabel(ylabel)

    # The number of trainings is taken from the last array in `array_list`.
    s = "{:s}\nMean and standard deviation of {:d} trainings\n".format(info_text, array.shape[0])
    for ii, label in enumerate(labels):
        if custom_info:
            s += "Final {}: {:s}\n".format(label.lower(), sigdigits(custom_info[ii]['mean'], custom_info[ii]["std"]))
        else:
            s += "Final {}: {:s}\n".format(label.lower(), sigdigits(final_means[ii], final_stds[ii]))

    if ylim:
        plt.ylim(top=ylim[1], bottom=ylim[0])

    plt.subplots_adjust(left=0.14)

    cms_label(x0=x, y=y, s=s, fz=24)
    run_label(x=0.9, y=0.89, fz=22)
    if save_path:
        out_base = Path(save_path)
        # savefig does not create directories, so make sure the target exists.
        out_base.parent.mkdir(parents=True, exist_ok=True)
        for suffix in ('.png', '.pdf'):
            plt.savefig(out_base.with_suffix(suffix))
    plt.show()

In [None]:
# Close the current matplotlib figure.
plt.close()

In [None]:
# Load matplotlib rc settings from a file; the path is relative, so it is
# resolved against the kernel's working directory.
mpl.rc_file("my_matplotlib_rcparams.txt")

In [None]:
#Axes
mpl.rcParams["axes.labelsize"] = 24

# Ticks
mpl.rcParams["xtick.labelsize"] = 22
mpl.rcParams["ytick.labelsize"] = 22
mpl.rcParams["xtick.direction"] = "in"
mpl.rcParams["ytick.direction"] = "in"

# Legend
mpl.rcParams["legend.fontsize"] = 24

mpl.rcParams["grid.alpha"] = 0.3

In [None]:
# Total loss: training vs. validation. The first 10 epochs are left out of the
# plot and the y-axis is limited to [5, 12].
plot_variance_curve([loss_array, val_loss_array],
                    labels=["Training loss", "Validation loss"],
                    skip=10,
                    ylim=[5, 12],
                    save_path="std_plots/beforeHPO_gnn_loss_curves_std_after_tuning.png",
                    x=0.4,
                    y=0.4,
                    ylabel="Loss (a.u.)"
                   )

In [None]:
# Regression loss: training vs. validation. The first 10 epochs are left out of
# the plot and the y-axis is limited to (0.19, 0.55).
plot_variance_curve([reg_loss_array, val_reg_loss_array],
                    labels=["Training regression loss", "Validation regression loss"],
                    skip=10,
                    save_path="std_plots/beforeHPO_gnn_reg_loss_curves_std_after_tuning.png",
                    x=0.39,
                    y=0.4,
                    ylim=(0.19, 0.55),
                    ylabel="Regression loss (a.u.)",
                   )

In [None]:
# Validation Wasserstein distances for jets and MET. The first 4 epochs are left
# out of the plot and the y-axis is limited to (-1, 5).
# The MET label is singular, matching the jet label and the y-axis label.
plot_variance_curve([val_jet_wd_array, val_met_wd_array],
                    labels=["Jet Wasserstein distance", "MET Wasserstein distance"],
                    skip=4,
                    save_path="std_plots/beforeHPO_gnn_wd_curves_std_after_tuning.png",
                    x=0.39,
                    y=0.3,
                    ylim=(-1, 5),
                    ylabel="Jet & MET Wasserstein distance (a.u.)",
                   )

In [None]:
# Classification loss: training vs. validation. The first 10 epochs are left out
# of the plot and the y-axis is limited to (0.034, 0.075).
plot_variance_curve([cls_loss_array, val_cls_loss_array],
                    labels=["Training classification loss", "Validation classification loss"],
                    skip=10,
                    save_path="std_plots/beforeHPO_gnn_cls_loss_curves_std_after_tuning.png",
                    x=0.33,
                    y=0.4,
                    ylim=(0.034, 0.075),
                    ylabel="Classification loss (a.u.)",
                   )