In [None]:
import glob
import json
import os

import numpy as np
import pandas

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def load_history(path, min_epoch=None, max_epoch=None):
    """Load per-epoch JSON training-history files into one DataFrame.

    Each file matched by the glob ``path`` holds one JSON document (presumably
    a dict of metric name -> value for that epoch; confirm against the training
    script). The epoch number is parsed from the file's basename: the text after
    the last ``_`` and before the first ``.``, e.g. ``epoch_12.json`` -> 12.

    Args:
        path: glob pattern for the history files, e.g. ``"logs_0/history/*.json"``.
        min_epoch: first epoch to include (inclusive). Defaults to the smallest
            epoch found. An explicit ``0`` is honoured.
        max_epoch: last epoch to include (inclusive). Defaults to the largest
            epoch found. An explicit ``0`` is honoured.

    Returns:
        pandas.DataFrame with one row per epoch in ``[min_epoch, max_epoch]``,
        in epoch order, with a default RangeIndex (row 0 is ``min_epoch``).

    Raises:
        ValueError: if ``path`` matches no files.
        KeyError: if an epoch inside the requested range has no file.
    """
    history = {}
    for fname in glob.glob(path):
        # Close each file promptly instead of leaking the handle.
        with open(fname) as fh:
            data = json.load(fh)
        # Parse from the basename only, so underscores in directory names
        # (e.g. "logs_0/") cannot leak into the epoch number.
        epoch = int(os.path.basename(fname).split("_")[-1].split(".")[0])
        history[epoch] = data

    if not history:
        raise ValueError("no history files match {!r}".format(path))

    # Compare against None so an explicit epoch 0 is not mistaken for "unset".
    if max_epoch is None:
        max_epoch = max(history)
    if min_epoch is None:
        min_epoch = min(history)

    epochs = range(min_epoch, max_epoch + 1)
    missing = [e for e in epochs if e not in history]
    if missing:
        raise KeyError("missing history for epochs {} in {!r}".format(missing, path))

    return pandas.DataFrame([history[e] for e in epochs])

In [None]:
# Load the training history of every run of every loss configuration.
# Result: data_losses[loss_type] -> list of per-run DataFrames (one row per epoch).
# The dict's insertion order (LOSS_TYPES order) is the order curves are drawn in.
EXPERIMENTS_DIR = "../experiments/fromEric"
LOSS_TYPES = [
    "baseline",
    "baseline-mask_reg_cls0",
    "genjet_logcosh",
    "genjet_mse",
    "genjet_logcosh_mask_reg_cls0",
    "swd",
]

data_losses = {}
for losstype in LOSS_TYPES:
    # glob.glob returns paths in arbitrary, filesystem-dependent order; sort so
    # that anything iterating over the runs (e.g. per-run plots) is reproducible.
    run_dirs = sorted(glob.glob("{}/{}/logs_*".format(EXPERIMENTS_DIR, losstype)))
    data_losses[losstype] = [
        load_history("{}/history/*.json".format(run_dir)) for run_dir in run_dirs
    ]

In [None]:
def plot_as_shaded(loss, key, data=None):
    """Plot the median of one metric across runs with a shaded 25-75 percentile band.

    Args:
        loss: column name in the per-run history DataFrames, e.g. ``"val_cls_loss"``.
        key: loss-configuration name. Selects the runs and is used as the legend label.
        data: optional mapping ``key -> list of per-run DataFrames``. Defaults to
            the notebook-global ``data_losses``, so existing calls are unchanged.

    Runs may have different lengths; every run is truncated to the shortest one
    so they can be stacked. The x axis is the row position (0, 1, ...). That
    equals the epoch number only if the histories start at epoch 0.

    Returns:
        The matplotlib ``Line2D`` of the median curve.

    Raises:
        ValueError: if there are no runs for ``key``.
    """
    runs = (data_losses if data is None else data)[key]
    if not runs:
        raise ValueError("no runs loaded for {!r}".format(key))

    n_common = min(len(run) for run in runs)
    # Shape (n_common, n_runs): one column per run, truncated positionally.
    loss_vals = np.stack([run[loss].iloc[:n_common].to_numpy() for run in runs], axis=-1)
    x = np.arange(n_common)
    # One call for all three percentiles, taken across runs (last axis).
    err_lo, y, err_hi = np.percentile(loss_vals, [25, 50, 75], axis=-1)
    (line,) = plt.plot(x, y, label=key)
    # Reuse the median line's colour so the band and the curve match.
    plt.fill_between(x, err_lo, err_hi, alpha=0.2, color=line.get_color())
    return line

In [None]:
# Training "loss" of every baseline run stacked column-wise: shape (n_epochs, n_runs).
# Runs can stop at different epochs, and np.stack raises on ragged inputs, so
# trim every run to the shortest one (the same truncation plot_as_shaded uses).
n_common_baseline = min(len(run) for run in data_losses["baseline"])
X = np.stack(
    [run["loss"].iloc[:n_common_baseline].to_numpy() for run in data_losses["baseline"]],
    axis=-1,
)

In [None]:
# Each baseline run's path through (validation pt loss, validation classification
# loss) space, one line per run. The axis limits are a hand-picked zoom window;
# adjust them if the runs move outside it.
for run in data_losses["baseline"]:
    plt.plot(run["val_pt_loss"], run["val_cls_loss"])
plt.xlim(0.13, 0.15)
plt.ylim(0.053, 0.055)
plt.title("baseline runs: validation loss trajectories")
plt.xlabel("pt loss")
plt.ylabel("classification loss");

In [None]:
# Validation classification loss for every loss configuration: median over runs
# with the 25-75 percentile band. Iterating over data_losses keeps this cell in
# sync with the configurations that were actually loaded.
loss = "val_cls_loss"
for key in data_losses:
    plot_as_shaded(loss, key)
plt.legend(loc="best")
plt.ylim(0.053, 0.06)
plt.title("{} (median over runs, 25-75% band)".format(loss))
plt.ylabel(loss)
plt.xlabel("epoch");

In [None]:
# Validation pt loss for every loss configuration: median over runs with the
# 25-75 percentile band. Iterating over data_losses keeps this cell in sync with
# the configurations that were actually loaded.
loss = "val_pt_loss"
for key in data_losses:
    plot_as_shaded(loss, key)
plt.legend(loc="best")
plt.ylim(0.1, 0.3)
plt.title("{} (median over runs, 25-75% band)".format(loss))
plt.ylabel(loss)
plt.xlabel("epoch");

In [None]:
# Validation energy loss for every loss configuration: median over runs with the
# 25-75 percentile band. Iterating over data_losses keeps this cell in sync with
# the configurations that were actually loaded.
loss = "val_energy_loss"
for key in data_losses:
    plot_as_shaded(loss, key)
plt.legend(loc="best")
plt.ylim(2.5, 2.7)
plt.title("{} (median over runs, 25-75% band)".format(loss))
plt.ylabel(loss)
plt.xlabel("epoch");