In [None]:
import glob
import json

import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import pandas as pd
import scipy
import scipy.optimize
import sklearn
import sklearn.metrics


In [None]:
# Shell escape: list every pred.npz found exactly one directory below ../experiments.
!ls ../experiments/*/pred.npz

In [None]:
def flatten(arr):
    """Merge the two leading axes of ``arr`` into a single axis.

    The returned array has shape
    ``(arr.shape[0] * arr.shape[1], arr.shape[2])``.
    """
    n_rows = arr.shape[0] * arr.shape[1]
    n_cols = arr.shape[2]
    return arr.reshape((n_rows, n_cols))

In [None]:
# Display names for the class ids 1..7 (id 0 has no entry here).
pid_names = dict(
    enumerate(["ch.had", "n.had", "HFEM", "HFHAD", "g", "el", "mu"], start=1)
)

# Display names for the feature columns 1..6 of the target/prediction arrays.
var_names = dict(
    enumerate(["charge", "pt", "eta", "sin phi", "cos phi", "energy"], start=1)
)

In [None]:
# Read the saved arrays of one experiment. The archive is opened in a ``with``
# block so that its file handle is released as soon as the arrays have been
# read; after the block ``dd`` is still bound but refers to a closed archive.
with np.load("../experiments/cms-gnn-skipconn-v2-6c655f0d/pred.npz") as dd:
    X = dd["X"]
    ygen = dd["ygen"]
    ycand = dd["ycand"]
    ypred = dd["ypred"]
    ypred_raw = dd["ypred_raw"]

# 2-D versions of the arrays above with the first two axes merged into one
# (presumably event x element -> row; TODO confirm the axis meaning).
X_f = flatten(X)
ygen_f = flatten(ygen)
ycand_f = flatten(ycand)
ypred_f = flatten(ypred)
ypred_raw_f = flatten(ypred_raw)

In [None]:
# Row selector over the flattened inputs: True where column 0 of X_f is
# non-zero (presumably zero marks padded entries -- TODO confirm).
msk_X = np.not_equal(X_f[:, 0], 0)

In [None]:
# Tick labels for the 12 input codes (rows) and the 8 class codes (columns)
# shown in the confusion-matrix plots.
x_labels = "none track PS1 PS2 ECAL HCAL GSF BREM HFEM HFHAD SC HO".split()
y_labels = "none ch.had n.had HFEM HFHAD g el mu".split()

# Column 0 of the inputs against column 0 of ycand, restricted to the rows
# selected by msk_X. normalize="pred" makes every column sum to one.
cm0 = sklearn.metrics.confusion_matrix(
    y_true=X_f[msk_X, 0],
    y_pred=ycand_f[msk_X, 0],
    labels=range(12),
    normalize="pred",
)

fig, ax = plt.subplots(figsize=(8, 8))
# Only the first 8 columns are drawn; the matrix itself is 12 x 12.
im = ax.imshow(cm0[:12, :8], cmap="Blues")
fig.colorbar(im, ax=ax)
ax.set_yticks(range(12))
ax.set_yticklabels(x_labels)
ax.set_xticks(range(8))
ax.set_xticklabels(y_labels)
ax.set_xlabel("PFCandidate")
ax.set_ylabel("PFElement")

In [None]:
# Column 0 of the inputs against column 0 of the model prediction (ypred),
# restricted to the rows selected by msk_X. normalize="pred" makes every
# column sum to one.
cm1 = sklearn.metrics.confusion_matrix(
    X_f[msk_X, 0],
    ypred_f[msk_X, 0],
    labels=range(12),
    normalize="pred"
)

plt.figure(figsize=(8, 8))
# Only the first 8 columns are drawn; the matrix itself is 12 x 12.
plt.imshow(cm1[:12, :8], cmap="Blues")
plt.colorbar()
plt.yticks(ticks=range(12), labels=x_labels);
plt.xticks(ticks=range(8), labels=y_labels);
# The columns come from ypred, not ycand, so label them as the model output
# to keep this figure distinguishable from the ycand-based one.
plt.xlabel("MLPFCandidate")
plt.ylabel("PFElement")

In [None]:
def _argmax_above_thresholds(raw, thresholds):
    """Return, along the last axis, the index of the largest thresholded score.

    Scores that do not strictly exceed their class threshold are zeroed
    before the argmax is taken. Works for any number of leading axes.

    Parameters
    ----------
    raw : ndarray
        Score array whose last axis enumerates the classes.
    thresholds : sequence of float
        ``thresholds[i]`` applies to class ``i``. It may be shorter than the
        number of classes; classes without a threshold are never zeroed.

    Returns
    -------
    ndarray of int
        ``raw.shape[:-1]``-shaped array of class indices. When every score of
        an element is zeroed the argmax falls back to index 0.

    Raises
    ------
    IndexError
        If more thresholds are given than there are classes.
    """
    thresholds = np.asarray(thresholds)
    n_thr = len(thresholds)
    if n_thr > raw.shape[-1]:
        raise IndexError(
            "got {} thresholds for {} classes".format(n_thr, raw.shape[-1])
        )
    # 1 where the score clears its threshold, 0 where it does not.
    msk = np.ones_like(raw)
    msk[..., :n_thr] = raw[..., :n_thr] > thresholds
    return np.argmax(raw * msk, axis=-1)


def apply_thresholds_f(thresholds):
    """Thresholded class ids for the flattened raw predictions ``ypred_raw_f``."""
    return _argmax_above_thresholds(ypred_raw_f, thresholds)


def apply_thresholds(thresholds):
    """Thresholded class ids for the unflattened raw predictions ``ypred_raw``."""
    return _argmax_above_thresholds(ypred_raw, thresholds)

In [None]:
# Bookkeeping for the objective below; both are reset whenever this cell runs.
niter = 0
accs = []

# Per-event number of entries of each class 1..7 in column 0 of ycand. These
# counts do not depend on the thresholds, so they are computed once here
# instead of on every objective evaluation.
ntrue_by_class = {
    icls: np.sum(ycand[:, :, 0] == icls, axis=1) for icls in range(1, 8)
}


def func(thresholds):
    """Objective for the threshold search: summed per-class multiplicity error.

    For every class 1..7 the per-event count obtained from
    ``apply_thresholds(thresholds)`` is compared with the count in ``ycand``.
    The class term is the root of the summed squared count differences divided
    by the mean true count; the return value is the sum of the class terms.

    Side effects: appends the value to ``accs``, increments ``niter`` and
    prints a progress line on every 10th call.
    """
    global niter

    ypred_id = apply_thresholds(thresholds)

    err = 0
    for icls, ntrue in ntrue_by_class.items():
        npred = np.sum(ypred_id == icls, axis=1)
        mean_ntrue = np.mean(ntrue)
        # A class that never occurs would give a zero denominator and turn the
        # whole objective into NaN/inf, which Nelder-Mead cannot recover from.
        # In that case fall back to the unnormalised error for the class.
        denom = mean_ntrue if mean_ntrue > 0 else 1.0
        err += np.sqrt(np.sum((ntrue - npred) ** 2)) / denom

    accs.append(err)
    niter += 1
    if niter % 10 == 0:
        print(niter, err, thresholds)
    return err


# Derivative-free search starting from a threshold of 0.5 for each of 8 classes.
ret = scipy.optimize.minimize(
    func,
    0.5 * np.ones(8),
    method="Nelder-Mead",
    options={"adaptive": True, "xatol": 0.01, "fatol": 0.01},
)

In [None]:
# Objective value recorded at every evaluation of the threshold search.
fig, ax = plt.subplots()
ax.plot(accs)
# Lower limit: 0.8 times the smallest recorded value; upper limit fixed at 100.
ax.set_ylim(0.8 * np.min(accs), 100)

In [None]:
# Class ids under the optimised thresholds (ret.x): first for the unflattened
# raw predictions, then for the flattened ones.
ypred_id, ypred_id_f = apply_thresholds(ret.x), apply_thresholds_f(ret.x)

In [None]:
# Confusion matrices of column 0 of ycand (rows) against the thresholded
# prediction (columns), on the rows selected by msk_X. cm_norm is normalised
# so every row sums to one; cm holds the raw counts.
cm_norm, cm = (
    sklearn.metrics.confusion_matrix(
        ycand_f[msk_X, 0],
        ypred_id_f[msk_X],
        labels=range(8),
        normalize=how,
    )
    for how in ("true", None)
)

In [None]:
# Row-normalised confusion matrix without the row and column of class 0.
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm_norm[1:, 1:], cmap="Blues")
fig.colorbar(im, ax=ax)

In [None]:
# Raw-count confusion matrix without the row and column of class 0, drawn
# with a logarithmic colour scale.
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm[1:, 1:], cmap="Blues", norm=matplotlib.colors.LogNorm())
fig.colorbar(im, ax=ax)

In [None]:
# Histogram bin edges per feature column: 100 evenly spaced edges between the
# given lower and upper bound.
bins = {
    ivar: np.linspace(lo, hi, 100)
    for ivar, (lo, hi) in {
        2: (0, 100),
        3: (-8, 8),
        4: (-1, 1),
        5: (-1, 1),
        6: (0, 500),
    }.items()
}

In [None]:
# (unique class ids, counts) on the rows selected by msk_X, for:
#   u1 - column 0 of ycand, u2 - thresholded prediction, u3 - column 0 of ygen.
u1, u2, u3 = (
    np.unique(ids, return_counts=True)
    for ids in (ycand_f[msk_X, 0], ypred_id_f[msk_X], ygen_f[msk_X, 0])
)

In [None]:
# Distribution of one feature column (ivar) for one class (icls), overlaid
# for ygen, ycand and the prediction.
icls = 6
ivar = 2
b = np.linspace(0, 100, 100)

# ygen and ycand rows are selected by their own column 0; prediction rows are
# selected by the thresholded class id.
for values, selected in (
    (ygen_f, ygen_f[:, 0] == icls),
    (ycand_f, ycand_f[:, 0] == icls),
    (ypred_f, ypred_id_f == icls),
):
    plt.hist(values[selected, ivar], bins=b, histtype="step", lw=2)
plt.yscale("log")

In [None]:
# Side-by-side bars per class id: counts from ycand (u1, shifted left) next to
# counts from the thresholded prediction (u2, shifted right). The ygen counts
# in u3 are not drawn.
fig, ax = plt.subplots()
ax.bar(u1[0] - 0.2, u1[1], width=0.4)
ax.bar(u2[0] + 0.2, u2[1], width=0.4)

In [None]:
# Summary grid: one row per class 1..7. The first panel of a row compares
# per-event counts; the remaining five panels are 2D histograms of a feature
# column in ycand against the same column in ypred.
fig, axes = plt.subplots(7, 6, figsize=(6*6, 7*5))

for row_axes, icls in zip(axes, range(1, 8)):
    # `row_axes` is already the 1-D row of six axes; it gets its own name so
    # the `axes` grid being iterated over is not rebound inside the loop.
    npred = np.sum(ypred_id == icls, axis=1)
    ncand = np.sum(ycand[:, :, 0] == icls, axis=1)

    # Shared x/y range around the observed counts. Local names are used so the
    # notebook-level `b` bin array is not overwritten by a scalar.
    lo = 0.5 * np.min(ncand)
    hi = 1.5 * np.max(ncand)
    if hi <= lo:
        # All counts are zero: widen the range so the limits are not singular.
        hi = lo + 1.0

    count_ax = row_axes[0]
    count_ax.scatter(ncand, npred)
    count_ax.set_xlim(lo, hi)
    count_ax.set_ylim(lo, hi)
    # Diagonal marks equal counts in ycand and in the prediction.
    count_ax.plot([lo, hi], [lo, hi], color="black", ls="--")
    count_ax.set_title(pid_names[icls])
    count_ax.set_xlabel("number of PFCandidates")
    count_ax.set_ylabel("number of MLPFCandidates")

    # Rows where both ycand and the thresholded prediction give this class;
    # computed once per class instead of twice per panel.
    both = (ycand_f[:, 0] == icls) & (ypred_id_f == icls)

    for ivar, ax in zip([2, 3, 4, 5, 6], row_axes[1:]):
        hist = np.histogram2d(
            ycand_f[both, ivar],
            ypred_f[both, ivar],
            bins=(bins[ivar], bins[ivar]),
        )
        hep.hist2dplot(
            hist, cmap="Blues",
            # Linear colour scale with an upper limit of at least 10.
            norm=matplotlib.colors.Normalize(vmin=0, vmax=max(10, np.max(hist[0]))),
            ax=ax
        )
        ax.set_title("{}, {}".format(pid_names[icls], var_names[ivar]))
        ax.set_xlabel("true value (PFCandidate)")
        ax.set_ylabel("reconstructed value (MLPF)")
plt.tight_layout()