In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn import metrics

In [None]:
def NormOne(a):
    """Return uniform weights for `a`: an array shaped like `a` filled with 1/len(a).

    Intended as the `weights` argument of a histogram, so that bin contents
    are expressed as fractions of the entries in `a` rather than raw counts.
    """
    n_entries = len(a)
    unit_weights = np.ones_like(a)
    return unit_weights / n_entries

## Look into inference output

Read the file with inference results on the validation data set and compute useful derived quantities.

In [None]:
# Load the per-hit inference output into a DataFrame.
# The file is only read in this notebook, so open it read-only: mode='a'
# opens the HDF5 store for reading *and* writing, which fails on read-only
# storage and needlessly risks touching the input file.
pred = pd.read_hdf("infer_test_all2.h5",mode='r')
pred

In [None]:
# Count the distinct (run, subrun, event) combinations present in the table.
event_keys = ['run','subrun','event']
n_events = len(pred.groupby(event_keys).size())
print('total number of events:',n_events)

In [None]:
# Semantic classes. The position of a class in this list is the integer code
# that the plots below translate back into a name (classes[code], tick labels).
classes=['MIP','HIP','shower','michel','diffuse']

# Encode true and predicted labels as integer codes.
# Passing `categories` pins every code to the class position in `classes`.
# Without it the codes follow the column's own category order (or, for a
# plain string column, the sorted unique values), which need not agree with
# `classes` and would silently mislabel the plots.
# Values that are missing or not listed in `classes` get code -1.
pred['sem_label'] = pd.Categorical(pred['y_semantic'], categories=classes).codes
pred['sem_pred'] = pd.Categorical(pred['x_semantic'], categories=classes).codes

# Name of the class with the second-highest score: hide every entry equal to
# the row maximum, then take the column holding the largest remaining score.
class_scores = pred[classes]
is_row_max = class_scores.eq(class_scores.max(axis=1), axis=0)
pred['x_semantic_2nd'] = class_scores.mask(is_row_max).idxmax(axis=1)

In [None]:
# Flag each hit: 1 when the predicted class code equals the true one, else 0.
pred['isgood'] = (pred['sem_label'] == pred['sem_pred']).astype(int)

In [None]:
# Per-hit score of the true class, of the predicted class and of the
# second-best class, picked from the per-class score columns.
# The columns are created as floats (0.0): they are filled with class scores
# below (presumably floating-point, since they are exponentiated later), and
# starting from the integer 0 would force pandas to upcast an int column on
# assignment, which recent pandas versions deprecate.
# Rows whose label matches none of `classes` keep the initial 0.0.
for score_col in ('true_score', 'pred_score', 'pred_score_2nd'):
    pred[score_col] = 0.0

label_to_score = (
    ('y_semantic', 'true_score'),
    ('x_semantic', 'pred_score'),
    ('x_semantic_2nd', 'pred_score_2nd'),
)
for ctg in classes:
    for label_col, score_col in label_to_score:
        # Compute the row selection once and use it on both sides.
        selected = pred[label_col] == ctg
        pred.loc[selected, score_col] = pred.loc[selected, ctg]

In [None]:
# Turn raw scores into softmax probabilities,
#     p = exp(s) / sum_k exp(s_k),   k running over `classes`,
# for the predicted class and for the true class of every hit.
# The normalisation is evaluated in log space with the row maximum subtracted
# first, so large scores cannot overflow exp() and turn the ratio into NaN.
# skipna=False preserves the behaviour that a missing score in any class
# gives a missing probability.
class_scores = pred[classes]
row_max = class_scores.max(axis=1, skipna=False)
shifted_sum = np.exp(class_scores.sub(row_max, axis=0)).sum(axis=1, skipna=False)
log_norm = row_max + np.log(shifted_sum)
pred['pred_score_p'] = np.exp(pred['pred_score'] - log_norm)
pred['true_score_p'] = np.exp(pred['true_score'] - log_norm)
pred

In [None]:
# Compare how often each category occurs in the true and in the predicted
# labels, restricted to hits with sem_label>=0.
labelled = pred.query('sem_label>=0')
category_edges = np.linspace(0,5,6)
for column, legend_name in (('sem_label', 'true'), ('sem_pred', 'predicted')):
    codes = labelled[column]
    plt.hist(codes, bins=category_edges, weights=NormOne(codes), histtype='step', label=legend_name)
# Put the class names at the centre of each unit-wide bin.
plt.xticks([0.5,1.5,2.5,3.5,4.5],classes)
plt.title('abundance of categories in dataset')
plt.xlabel('category')
plt.ylabel('fraction of hits')
plt.legend()
plt.show()

## Study network confidence in prediction

Here we use the softmax of the original scores as a confidence score, which gives an estimate of the probability of the prediction.

In [None]:
# Distribution of the softmax probability of the predicted class, drawn
# separately for each predicted category; every histogram is weighted by
# 1/N of its own entries.
probability_bins = np.linspace(0,1,11)
for code,ctg in enumerate(classes):
    selection = 'sem_label>=0 and sem_pred==%i'%code
    scores = pred.query(selection)['pred_score_p']
    plt.hist(scores, bins=probability_bins, weights=NormOne(scores), histtype='step', label=ctg)
plt.title('predicted score per hit in different categories')
plt.xlabel('score')
plt.ylabel('fraction of entries')
plt.legend(loc=2)
plt.show()

In [None]:
# Softmax score distributions for correctly and incorrectly classified hits.
# The same three selections are drawn twice: first as raw hit counts on a
# log scale, then with each histogram weighted by 1/N of its own entries.
score_bins = np.linspace(0,1,11)
score_selections = (
    # (row selection, score column, legend entry)
    ('sem_label>=0 and isgood==1', 'true_score_p', 'correct'),
    ('sem_label>=0 and isgood==0', 'true_score_p', 'incorrect, true'),
    ('sem_label>=0 and isgood==0', 'pred_score_p', 'incorrect, pred'),
)

# Figure 1: absolute hit counts.
for selection, column, legend_name in score_selections:
    scores = pred.query(selection)[column]
    plt.hist(scores, bins=score_bins, histtype='step', label=legend_name)
plt.xlabel('score')
plt.ylabel('hit count')
plt.yscale('log')
plt.legend(loc=2)
plt.title('score per hit')
plt.show()

# Figure 2: fractions of entries.
for selection, column, legend_name in score_selections:
    scores = pred.query(selection)[column]
    plt.hist(scores, bins=score_bins, weights=NormOne(scores), histtype='step', label=legend_name)
plt.xlabel('score')
plt.ylabel('fraction of entries')
plt.legend(loc=2)
plt.title('score per hit')
plt.show()

In [None]:
# exp(pred_score_2nd - pred_score): ratio between the exponentiated score of
# the second-best class and that of the predicted class, shown separately for
# correctly and incorrectly classified hits.
ratio_bins = np.linspace(0,1,11)
for flag, legend_name in ((1, 'correct'), (0, 'incorrect')):
    selected_hits = pred.query('sem_label>=0 and isgood==%i'%flag)
    ratio = selected_hits.eval('exp(pred_score_2nd-pred_score)')
    plt.hist(ratio, bins=ratio_bins, weights=NormOne(ratio), histtype='step', label=legend_name)
plt.title('ratio of score for top 2 categories per hit')
plt.xlabel('exp(pred_2nd-pred)')
plt.ylabel('a.u.')
plt.legend(loc=1)
plt.show()

In [None]:
# For misclassified hits only: exp(true_score - pred_score), i.e. how the
# exponentiated score of the true class compares with that of the (wrong)
# predicted class.
wrong_hits = pred.query('sem_label>=0 and isgood==0')
score_ratio = wrong_hits.eval('exp(true_score-pred_score)')
plt.hist(score_ratio, bins=np.linspace(0,1,11), weights=NormOne(score_ratio), histtype='step')
plt.title('ratio of score of true to incorrectly predicted hit category')
plt.xlabel('exp(true_score-pred_score)')
plt.ylabel('a.u.')
plt.show()

In [None]:
# ROC curves for the question "is the prediction correct?" (isgood), using the
# softmax probability of the predicted class as the discriminating score:
# once for all hits with sem_label>=0, then for each true category on its own.
labelled = pred.query('sem_label>=0')
fpr, tpr, _ = metrics.roc_curve(labelled['isgood'], labelled['pred_score_p'])
plt.plot(fpr,tpr,label='all')

for code, ctg in enumerate(classes):
    in_class = pred.query('sem_label==%i'%code)
    fpr, tpr, _ = metrics.roc_curve(in_class['isgood'], in_class['pred_score_p'])
    plt.plot(fpr,tpr,label=ctg)

plt.title('ROC curves as a function of score probability')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc=7)
plt.show()

## Confusion matrix by hit

In [None]:
def _plot_confusion(matrix, title, tick_labels=('MIP','HIP','SHR','MCL','DFS')):
    """Draw a (true label x assigned label) matrix with each value printed in its cell.

    `matrix[i, j]` is drawn at row i (y axis, true label) and column j
    (x axis, assigned label). Returns the created figure.
    """
    fig = plt.figure(figsize=(7, 6))
    plt.imshow(matrix,origin='lower',cmap='copper')
    n_rows, n_cols = matrix.shape
    for i in range(n_rows):
        for j in range(n_cols):
            # text() takes (x, y), i.e. (column, row).
            plt.text(j, i, "%.2f"%matrix[i, j],ha="center", va="center", color="w")
    plt.colorbar()
    # Fix the tick positions explicitly so that every label sits on its own
    # cell regardless of where the automatic locator would place the ticks.
    tick_positions = list(range(len(tick_labels)))
    plt.xticks(tick_positions, tick_labels)
    plt.yticks(tick_positions, tick_labels)
    plt.xlabel("assigned label")
    plt.ylabel("true label")
    plt.title(title)
    plt.tight_layout()
    plt.show()
    return fig


# Confusion counts for hits with sem_label>=0.
# np.histogram2d puts its FIRST argument along the first array axis (rows) and
# imshow draws rows along the y axis, so the true label must be passed first
# for the matrix to match the "true label" (y) / "assigned label" (x) axes.
labelled = pred.query('sem_label>=0')
code_edges = np.linspace(0,5,6)
x = np.histogram2d(labelled['sem_label'],labelled['sem_pred'],bins=[code_edges,code_edges])
counts = x[0]  # counts[true, assigned]

# Efficiency: each row (true class) normalised to 1, i.e. the fraction of
# hits of a true class that end up in each assigned class.
eff = counts / counts.sum(axis=1, keepdims=True)
# Purity: each column (assigned class) normalised to 1, i.e. the fraction of
# hits assigned to a class that come from each true class.
# A class with no entries in the normalising sum yields NaN cells.
pur = counts / counts.sum(axis=0, keepdims=True)

_plot_confusion(eff, 'efficiency (by hit)')
_plot_confusion(pur, 'purity (by hit)')


In [None]:
# Fraction of hits with sem_label>=0 whose predicted class equals the true class.
labelled_flags = pred.query('sem_label>=0')['isgood']
n_correct = labelled_flags.sum()
n_labelled = labelled_flags.count()
print('overall accuracy=',n_correct/n_labelled)