In [None]:
import setGPU
import torch
import torch_geometric
import sklearn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import sys

# Make the sibling ../test directory importable; the train_end2end / graph_data
# modules imported next are presumably located there — confirm against repo layout.
sys.path.append("../test")

In [None]:
import train_end2end
import graph_data

In [None]:
# Compute device for inference: use CUDA when it is available, otherwise fall back to the
# CPU so the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Trained PFNet6 checkpoint (the `.best.pth` file stored under epoch_70).
# map_location remaps the stored tensors onto `device`, so a checkpoint written from a GPU
# run can also be opened when only the CPU (or a different GPU index) is available.
# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
WEIGHTS_PATH = (
    "/storage/user/jpata/particleflow/data/"
    "PFNet6__npar_5552143__cfg_a8420e1ef2__user_jpata__ntrain_7000__lr_0.0001__1581357310/"
    "epoch_70/"
    "PFNet6__npar_5552143__cfg_a8420e1ef2__user_jpata__ntrain_7000__lr_0.0001__1581357310.best.pth"
)
weights = torch.load(WEIGHTS_PATH, map_location=device)

In [None]:
# Build the network and load the trained weights into it.
# NOTE(review): the positional arguments (15, 512, 14) must match the configuration the
# checkpoint was trained with; their meaning is defined by train_end2end.PFNet6 - confirm there.
model = train_end2end.PFNet6(15, 512, 14)
model.load_state_dict(weights)
# This notebook only runs inference: switch any dropout / batch-norm layers to eval mode.
model.eval();

In [None]:
# Show the model's module structure (its repr) as this cell's output.
model

In [None]:
# Sample to evaluate on:
#   QCD_run3   - the proper validation sample
#   TTbar_run3 - a quick look only; the model was trained on these same events
DATA_BASE = "/storage/user/jpata/particleflow/data"
p = DATA_BASE + "/TTbar_run3"

full_dataset = graph_data.PFGraphDataset(root=p)
# Read raw files directly from `p` and processed files from `p/processed`.
full_dataset.raw_dir, full_dataset.processed_dir = p, f"{p}/processed"

In [None]:
# Run the model event by event and keep, per event, the predicted and target candidate
# class ids and momentum features as numpy arrays.
N_EVENTS = 1000  # number of events to evaluate (capped at the dataset size below)

model.eval()  # inference only; idempotent if already in eval mode
pred_ids = []
true_ids = []
pred_momenta = []
true_momenta = []

n_events = min(N_EVENTS, len(full_dataset))
# No gradients are needed; this avoids building autograd graphs for every event.
with torch.no_grad():
    for i in range(n_events):
        d = full_dataset.get(i)
        # Single-graph "batch": every node is assigned to graph 0.
        d.batch = torch.zeros((len(d.x)), dtype=torch.long)
        train_end2end.data_prep(d, device=device)
        edges, cand_id_onehot, cand_momentum = model(d)
        # Predicted class = argmax of the class scores along the last axis.
        _, pred_id = torch.max(cand_id_onehot, -1)
        pred_ids.append(pred_id.detach().cpu().numpy())
        true_ids.append(d.y_candidates_id.detach().cpu().numpy())
        pred_momenta.append(cand_momentum.detach().cpu().numpy())
        true_momenta.append(d.y_candidates.detach().cpu().numpy())

        if i % 10 == 0:
            # Progress: event index, #predicted and #target candidates with a non-zero class id
            # (class 0 presumably means "no candidate" - confirm in train_end2end).
            print(i, (pred_ids[-1] != 0).sum(), (true_ids[-1] != 0).sum())
        

In [None]:
# Per-event candidate multiplicities: how many entries carry a non-zero class id,
# for the prediction (n_preds) and for the target (n_trues).
n_preds = [np.sum(ids != 0) for ids in pred_ids]
n_trues = [np.sum(true_ids[ev] != 0) for ev in range(len(pred_ids))]

In [None]:
# Per-event scatter: target PF multiplicity (x) vs predicted GNN multiplicity (y).
fig, ax = plt.subplots(figsize=(5, 5))
mult_range = (1500, 5000)  # shared axis range for both multiplicities

ax.plot(list(mult_range), list(mult_range), color="black", lw=0.5)  # y = x reference
ax.scatter(n_trues, n_preds, marker=".", alpha=0.5)
ax.set_xlim(*mult_range)
ax.set_ylim(*mult_range)
ax.set_xlabel("Number of Target PF Candidates", fontsize=13)
ax.set_ylabel("Number of Predicted GNN Candidates", fontsize=13)

# Experiment labels, positioned in axes coordinates.
label_kw = dict(transform=ax.transAxes, va="top", ha="left")
ax.text(0.67, 1.05, "Run 3 (14 TeV)", size=12, **label_kw)
ax.text(0.02, 0.98, "CMS", size=16, fontweight='bold', **label_kw)
ax.text(0.18, 0.975, "Simulation Preliminary", size=12, style='italic', **label_kw)
# Sample label; use "QCD dijet events" instead when evaluating the QCD sample.
ax.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", size=12, **label_kw)
fig.tight_layout()
fig.savefig("num_pred.pdf")

In [None]:
# `import sklearn` alone is not guaranteed to load the metrics submodule.
import sklearn.metrics

# Sum per-event confusion matrices. Passing a fixed `labels` list gives every per-event matrix
# the same (n_classes, n_classes) shape, even when some classes are absent from an event.
n_classes = len(train_end2end.class_labels)
cms = []
for i in range(len(pred_ids)):
    cm = sklearn.metrics.confusion_matrix(
        true_ids[i],
        pred_ids[i], labels=range(n_classes)
    )
    cms.append(cm)
assert cms, "no events were evaluated"

# Average per event, dividing by the number of events actually evaluated rather than a
# hard-coded 1000. No rounding here: rounding the per-event averages to 0.1 before the
# normalization applied when plotting distorts rare classes (a row can round to all zeros).
cm = sum(cms) / len(cms)

In [None]:
# Plot the averaged confusion matrix with normalize=True (normalization is done by
# train_end2end.plot_confusion_matrix) and save it.
train_end2end.plot_confusion_matrix(cm, [int(x) for x in train_end2end.class_labels], normalize=True)
# The events evaluated in this notebook come from the TTbar_run3 sample, so the title
# names that sample (it previously said QCD).
plt.title("Normalized Confusion Matrix (TTbar Run3)")
plt.savefig("cm.pdf")

In [None]:
# Flatten the per-event lists into one array per quantity, covering all candidates of all events:
#   pm / tm - predicted / target momentum features
#   ti / pi - target / predicted class ids
pm, tm, ti, pi = map(np.concatenate, (pred_momenta, true_momenta, true_ids, pred_ids))

In [None]:
# Candidate pT spectrum (momentum feature column 0), counting only candidates with a
# non-zero class id: GNN prediction vs PF target.
plt.figure(figsize=(5, 5))

ax = plt.axes()
bins = np.linspace(0, 50, 100)
# pm/pi are the model (GNN) outputs and tm/ti the PF targets; label them accordingly.
h0 = plt.hist(pm[pi!=0, 0], bins=bins, histtype="step", lw=1, label="GNN");
h1 = plt.hist(tm[ti!=0, 0], bins=bins, histtype="step", lw=1, label="PF");
plt.yscale("log")
plt.legend(frameon=False)
plt.ylim(10, 1e7)

plt.xlabel(r"Candidate $p_{\mathrm{T}}$ (a.u.)",fontsize=13)
plt.ylabel("Number of Candidates",fontsize=13)

plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("pt_hist.pdf")

In [None]:
# Candidate eta distribution (momentum feature column 1), counting only candidates with a
# non-zero class id: GNN prediction vs PF target.
plt.figure(figsize=(5, 5))
ax = plt.axes()

bins = np.linspace(-4, 4, 100)
# Label both histograms so they can be told apart (same convention as the pT plot).
plt.hist(pm[pi!=0, 1], bins=bins, histtype="step", lw=1, label="GNN");
plt.hist(tm[ti!=0, 1], bins=bins, histtype="step", lw=1, label="PF");
plt.yscale("log")
plt.legend(frameon=False)

plt.ylim(1000, 1e6)
plt.xlabel(r"Candidate $\eta$ (a.u.)",fontsize=13)
plt.ylabel("Number of Candidates",fontsize=13)
plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("eta_hist.pdf")

In [None]:
# Candidate phi distribution (momentum feature column 2), counting only candidates with a
# non-zero class id: GNN prediction vs PF target.
plt.figure(figsize=(5, 5))

ax = plt.axes()
bins = np.linspace(-3, 3, 60)
# Label both histograms so they can be told apart (same convention as the pT plot).
plt.hist(pm[pi!=0, 2], bins=bins, histtype="step", lw=1, label="GNN");
plt.hist(tm[ti!=0, 2], bins=bins, histtype="step", lw=1, label="PF");
plt.yscale("log")
plt.legend(frameon=False)
plt.ylim(1000, 1e6)

plt.xlabel(r"Candidate $\phi$ (a.u.)",fontsize=13)
plt.ylabel("Number of Candidates",fontsize=13)

plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("phi_hist.pdf")

In [None]:
# Target vs predicted pT (momentum feature column 0) for a random sample of up to 1000
# candidates that have a non-zero class id in both the target and the prediction.
plt.figure(figsize=(5, 5))
ax = plt.axes()

subidx = np.where((pi!=0)&(ti!=0))[0]
# Seeded so the sampled candidates are reproducible across runs; `rp` is also used by the
# eta and phi correlation plots that follow.
rp = np.random.RandomState(0).permutation(len(subidx))[:1000]

plt.plot([0,2],[0,2], color="black", lw=0.5)  # y = x reference
# x axis = PF target (tm), y axis = GNN prediction (pm), matching the axis labels below.
plt.scatter(tm[subidx[rp], 0], pm[subidx[rp], 0], marker=".", alpha=0.5)
plt.xlim(0,2)
plt.ylim(0,2)

plt.xlabel(r"Target PF Candidate $p_{\mathrm{T}}$ (a.u.)",fontsize=13)
plt.ylabel(r"Predicted GNN Candidate $p_{\mathrm{T}}$ (a.u.)", fontsize=13)

plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("pt_corr.pdf")

In [None]:
# Target vs predicted eta (momentum feature column 1) for the candidate sample `subidx[rp]`
# drawn in the pT correlation cell.
plt.figure(figsize=(5, 5))
ax = plt.axes()

plt.plot([-7, 7], [-7, 7], color="black", lw=0.5)  # y = x reference
# x axis = PF target (tm), y axis = GNN prediction (pm), matching the axis labels below.
plt.scatter(tm[subidx[rp], 1], pm[subidx[rp], 1], marker=".", alpha=0.5)
plt.xlim(-7, 7)
plt.ylim(-7, 7)

plt.xlabel(r"Target PF Candidate $\eta$ (a.u.)",fontsize=13)
plt.ylabel(r"Predicted GNN Candidate $\eta$ (a.u.)",fontsize=13)
plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("eta_corr.pdf")

In [None]:
# Target vs predicted phi (momentum feature column 2) for the candidate sample `subidx[rp]`
# drawn in the pT correlation cell.
plt.figure(figsize=(5, 5))
ax = plt.axes()

plt.plot([-3, 3], [-3, 3], color="black", lw=0.5)  # y = x reference over the plotted range
# x axis = PF target (tm), y axis = GNN prediction (pm), matching the axis labels below.
plt.scatter(tm[subidx[rp], 2], pm[subidx[rp], 2], marker=".", alpha=0.5)
plt.xlim(-3,3)
plt.ylim(-3,3)

plt.xlabel(r"Target PF Candidate $\phi$ (a.u.)",fontsize=13)
plt.ylabel(r"Predicted GNN Candidate $\phi$ (a.u.)",fontsize=13)

plt.text(0.67, 1.05, "Run 3 (14 TeV)", transform=ax.transAxes, va="top", ha="left",size=12)
plt.text(0.02, 0.98, "CMS", transform=ax.transAxes, va="top", ha="left",size=16, fontweight='bold')
plt.text(0.18, 0.975, "Simulation Preliminary", transform=ax.transAxes, va="top", ha="left",size=12,style='italic')
# Sample label; use "QCD dijet events" for the QCD sample.
plt.text(0.03, 0.92, r"$\mathrm{t}\overline{\mathrm{t}}$ events", transform=ax.transAxes, va="top", ha="left",size=12)
plt.tight_layout()
plt.savefig("phi_corr.pdf")

In [None]:
import pandas as pd
import tqdm

import matplotlib as mpl
mpl.rcParams['figure.figsize'] = [8.0, 6.0]
mpl.rcParams['font.size'] = 12
mpl.rcParams['legend.fontsize'] = 'large'
mpl.rcParams['figure.titlesize'] = 'medium'

# Draw the input graph of one event in the (eta, phi) plane, once per plot type:
#   input  - every edge of the input graph (blue, faint)
#   truth  - edges whose target label y_truth[k] is truthy (red)
#   output - every edge in green, opacity = the model's edge output for that edge
EVENT_INDEX = 1

# Model edge outputs for this event. data_prep is applied to `d` (in place, presumably)
# before the forward pass, so the event is re-read afterwards to get the raw inputs.
model.eval()
d = full_dataset.get(EVENT_INDEX)
d.batch = torch.zeros((len(d.x)), dtype=torch.long)
train_end2end.data_prep(d, device=device)
with torch.no_grad():
    edges, cand_id_onehot, cand_momentum = model(d)
output = edges.detach().cpu().numpy()

d = full_dataset.get(EVENT_INDEX)
x_data = d.x.detach().cpu().numpy()
# Elements whose feature columns 4-7 are all zero are drawn as clusters, all others as tracks.
mask = ((x_data[:,4]==0) & (x_data[:,5]==0) & (x_data[:,6]==0) & (x_data[:,7]==0))
# Feature columns 2:4 are used as (eta, phi) for both element types.
# NOTE(review): tracks use the same columns as clusters; if tracks should be drawn at
# different coordinates (e.g. other feature columns), select them here - confirm.
good_x = x_data[:,2:4].copy()
df = pd.DataFrame(good_x, columns=['eta','phi'])
df['isTrack'] = ~mask
row, col = d.edge_index.cpu().detach().numpy()
y_truth = d.y.cpu().detach().numpy()

# Plotted window; edges are kept if both endpoints lie within the window enlarged by `extra`.
min_phi = -1.25
max_phi = 1.25
min_eta = -1.25
max_eta = 1.25
extra = 1.0
x = 'eta'
y = 'phi'

# Plain arrays for fast per-edge lookups; df has a default RangeIndex, so labels == positions.
xs = df[x].to_numpy()
ys = df[y].to_numpy()

for plot_type in [['input'],['truth'],['output']]:
    plt.figure(figsize=(8, 6))
    # k is the edge index. It must advance for every edge, including the ones skipped below,
    # so that y_truth[k] and output[k] refer to the edge (i, j) currently being drawn.
    for k, (i, j) in enumerate(tqdm.tqdm(zip(row, col), total=len(row))):
        x1, x2 = xs[i], xs[j]
        y1, y2 = ys[i], ys[j]
        if (x1 < min_eta-extra or x1 > max_eta+extra) or (x2 < min_eta-extra or x2 > max_eta+extra): continue
        if (y1 < min_phi-extra or y1 > max_phi+extra) or (y2 < min_phi-extra or y2 > max_phi+extra): continue
        if 'input' in plot_type:
            plt.plot([x1, x2], [y1, y2], '-', c='b', alpha=0.1, zorder=1)
        if 'truth' in plot_type and y_truth[k]:
            plt.plot([x1, x2], [y1, y2], '-', c='r', alpha=0.8, zorder=2)
        if 'output' in plot_type:
            plt.plot([x1, x2], [y1, y2], '-', c='g', alpha=output[k].item(), zorder=3)
    # Draw the nodes inside the enlarged window: clusters as green circles, tracks as blue pentagons.
    cut_mask = (df[x] > min_eta-extra) & (df[x] < max_eta+extra) & (df[y] > min_phi-extra) & (df[y] < max_phi+extra)
    cluster_mask = cut_mask & ~df['isTrack']
    track_mask = cut_mask & df['isTrack']
    plt.scatter(df[x][cluster_mask], df[y][cluster_mask],c='g',marker='o',s=50,zorder=4,alpha=1)
    plt.scatter(df[x][track_mask], df[y][track_mask],c='b',marker='p',s=50,zorder=5,alpha=1)
    plt.xlabel(r"Track or Cluster $\eta$",fontsize=18)
    plt.ylabel(r"Track or Cluster $\phi$",fontsize=18)
    plt.xlim(min_eta, max_eta)
    plt.ylim(min_phi, max_phi)
    plt.figtext(0.12, 0.90,'CMS',fontweight='bold', wrap=True, horizontalalignment='left', fontsize=20)
    plt.figtext(0.22, 0.90,'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left', fontsize=18)
    plt.figtext(0.67, 0.90,'Run 3 (14 TeV)',  wrap=True, horizontalalignment='left', fontsize=18)
    plt.savefig('graph_%s_%s_%s.pdf'%(x,y,'_'.join(plot_type)))