In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

import tensorflow as tf
import tensorflow_datasets as tfds
import keras_tuner as kt
import keras.backend as K

from sklearn.metrics import accuracy_score, confusion_matrix

import seaborn as sns

import os
import pickle
import shap
from matplotlib.ticker import LogLocator, NullLocator
from matplotlib.colors import LogNorm

### Import data

In [None]:
# All input files are read from the current working directory.
cwd = os.getcwd() + '/'


def _load_array(stem):
    """Return the NumPy array stored in ``<cwd><stem>.npy``."""
    return np.load(os.path.join(cwd, stem + '.npy'))


# True and predicted labels of the MLP classifiers (charged and neutral samples).
charged_mlp_true = _load_array('charged_mlp_true')
charged_mlp_pred = _load_array('charged_mlp_pred')
neutral_mlp_true = _load_array('neutral_mlp_true')
neutral_mlp_pred = _load_array('neutral_mlp_pred')

# Per-track CDC dE/dx and momentum of the charged sample (histogrammed in Figure 2).
charged_dedx = _load_array('charged_dedx')
charged_momentum = _load_array('charged_momentum')

# Parameter triplets of the manual dE/dx cut curves.
manual_pid_cuts = _load_array('manual_pid_cuts')

# Predicted and true labels of the manual, cut-based PID.
predicted_manual_pid = _load_array('predicted_manual_pid')
true_manual_pid = _load_array('true_manual_pid')

# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.
with open(os.path.join(cwd, 'shapley_values_df.pkl'), 'rb') as pkl_file:
    df_list = pickle.load(pkl_file)

### Define dedx_function

In [None]:
def dedx_function(momentum, a, b, c):
    """Evaluate the exponential dE/dx cut curve ``exp(a * momentum + b) + c``.

    Operates element-wise, so ``momentum`` may be a scalar or a NumPy array.

    Parameters
    ----------
    momentum : float or array_like
        Momentum value(s) at which the curve is evaluated.
    a, b : float
        Slope and offset of the exponent.
    c : float
        Constant added to the exponential.

    Returns
    -------
    float or numpy.ndarray
        Curve value(s), with the broadcast shape of the inputs.
    """
    exponent = a * momentum + b
    return np.exp(exponent) + c

# Figure 2 - dE/dx versus momentum with cuts overlaid

In [None]:
# Figure 2: 2D histogram of CDC dE/dx versus momentum for charged tracks, with the
# three manual cut curves f1, f2, f3 (dedx_function evaluated with consecutive
# parameter triplets of manual_pid_cuts) and the PID bands they bound shaded on top.

# Visible window. matplotlib's hist2d sets the axis limits to this range, so the
# curves and bands (computed over a wider momentum span) are clipped to it.
MOMENTUM_RANGE = (0.1, 1.1)
DEDX_RANGE = (0.05*10**-5, 2*10**-5)

bin_size = 250  # number of histogram bins along each axis
# Dense sampling: np.linspace's default of 50 points left only ~14 samples inside
# the visible momentum window, so the exponential curves rendered as coarse polylines.
seq = np.linspace(-0.5, 3, 500)
a = 0.45  # alpha (opacity) of the shaded PID bands

fig, ax = plt.subplots(figsize=(12.5, 10))
ax.hist2d(charged_momentum, charged_dedx,
          bins=(bin_size, bin_size), range=[MOMENTUM_RANGE, DEDX_RANGE],
          cmap='viridis', norm=LogNorm())

# Cut curves f1, f2, f3 from the (a, b, c) triplets stored in manual_pid_cuts.
pr_dedx = dedx_function(seq, *manual_pid_cuts[0:3])
ka_dedx = dedx_function(seq, *manual_pid_cuts[3:6])
pi_dedx = dedx_function(seq, *manual_pid_cuts[6:9])

# Band above f1 (p / p-bar). Fill at least up to the top of the visible window:
# stopping at pr_dedx.max() alone leaves an unshaded strip whenever the curve's
# maximum lies below the upper dE/dx limit.
pr_top = max(pr_dedx.max(), DEDX_RANGE[1])
ax.fill_between(seq, pr_dedx, pr_top, color='red', alpha=a, label=r'$p$ | $\bar{p}$')
ax.plot(seq, pr_dedx, c='k', linestyle='dotted', linewidth=5, label=r'$f_{1}(p)$')

# Band between f1 and f2 (K±), shaded only where f1 lies above f2.
ax.fill_between(seq, pr_dedx, ka_dedx, where=(pr_dedx > ka_dedx), color='orchid',
                alpha=a, interpolate=True, label=r'$K^{\pm}$')
ax.plot(seq, ka_dedx, c='k', linestyle='dashed', linewidth=5, label=r'$f_{2}(p)$')

# Band between f2 and f3, labelled e± in the legend.
ax.fill_between(seq, ka_dedx, pi_dedx, where=(ka_dedx > pi_dedx), color='darkorange',
                alpha=a, interpolate=True, label=r'$e^{\pm}$')
ax.plot(seq, pi_dedx, c='k', linestyle='solid', linewidth=5, label=r'$f_{3}(p)$')

# Band from f3 down to zero, labelled pi± / mu± in the legend.
ax.fill_between(seq, pi_dedx, 0, where=(pi_dedx > 0), color='blue', alpha=a,
                interpolate=True, label=r'$\pi^{\pm}$ | $\mu^{\pm}$')

ax.set_xlabel('Momentum [GeV/c]', fontsize=12)
# NOTE(review): the histogram range is ~5e-7..2e-5 -- confirm the unit in this label.
ax.set_ylabel('CDC dE/dx [keV/cm]', fontsize=12)

ax.legend(loc='upper right', fontsize=12)
plt.show()

# Figure 3 - Manual PID confusion matrix

In [None]:
# Figure 3: row-normalised confusion matrix of the manual (cut-based) charged PID.

# Overall accuracy, kept for reference; it is not drawn on the figure.
acc = accuracy_score(true_manual_pid, predicted_manual_pid)

# Rows are normalised over the generated (true) class and rounded to three decimals
# for the annotations. Only the first eight rows are kept -- the eight generated
# species named on the y axis. All columns stay, so the x axis also has "no ID".
cm = np.round(confusion_matrix(true_manual_pid, predicted_manual_pid, normalize='true'),
              decimals=3)[:8]

_charged_tick_labels = [r'$p$', r'$\bar{p}$', r'$K^{+}$', r'$K^{-}$', r'$e^{-}$', r'$e^{+}$',
                        r'$\pi^{+}$ | $\mu^{+}$', r'$\pi^{-}$ | $\mu^{-}$']

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

heatmap = sns.heatmap(cm, annot=True, fmt='g', ax=ax, vmin=0, vmax=1)
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=13)

ax.set_xlabel('Predicted Particle', fontsize=15)
ax.set_ylabel('Generated Particle', fontsize=15)
ax.xaxis.set_ticklabels(_charged_tick_labels + ['no ID'], fontsize=15, rotation='vertical')
ax.yaxis.set_ticklabels(_charged_tick_labels, fontsize=15, rotation='horizontal');

# Figure 4 - MLP PID on charged particles

In [None]:
# Figure 4: row-normalised confusion matrix of the MLP PID on charged particles.

# Overall accuracy, kept for reference; it is not drawn on the figure.
acc = accuracy_score(charged_mlp_true, charged_mlp_pred)

# Normalise each row over its generated class, round for annotation, then keep
# the first eight rows (the generated species labelled on the y axis).
normalised = confusion_matrix(charged_mlp_true, charged_mlp_pred, normalize='true')
cm = np.round(normalised, decimals=3)
cm = cm[:8]

_charged_tick_labels = [r'$p$', r'$\bar{p}$', r'$K^{+}$', r'$K^{-}$', r'$e^{-}$', r'$e^{+}$',
                        r'$\pi^{+}$ | $\mu^{+}$', r'$\pi^{-}$ | $\mu^{-}$']

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

heatmap = sns.heatmap(cm, annot=True, fmt='g', ax=ax, vmin=0, vmax=1)
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=13)

ax.set_xlabel('Predicted Particle', fontsize=15)
ax.set_ylabel('Generated Particle', fontsize=15)
# Prediction columns additionally include a "no ID" outcome.
ax.xaxis.set_ticklabels(_charged_tick_labels + ['no ID'], fontsize=15, rotation='vertical')
ax.yaxis.set_ticklabels(_charged_tick_labels, fontsize=15, rotation='horizontal');

# Figure 5 - MLP PID for neutral particles

In [None]:
# Figure 5: row-normalised confusion matrix of the MLP PID on neutral particles.

# Overall accuracy, kept for reference; it is not drawn on the figure.
acc_new = accuracy_score(neutral_mlp_true, neutral_mlp_pred)

# Normalise over the generated class and round for annotation, then drop the last
# row; the remaining rows carry the three neutral-species labels on the y axis.
cm = np.round(confusion_matrix(neutral_mlp_true, neutral_mlp_pred, normalize='true'),
              decimals=3)
cm = cm[:-1, :]

_neutral_tick_labels = [r'$\gamma$', r'$K_{L}^{0}$', r'$n$']

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)

heatmap = sns.heatmap(cm, annot=True, fmt='g', ax=ax, vmin=0, vmax=1,
                      annot_kws={"fontsize": 15})
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=13)

ax.set_xlabel('Predicted Particle', fontsize=15)
ax.set_ylabel('Generated Particle', fontsize=15)
ax.xaxis.set_ticklabels(_neutral_tick_labels + ['no ID'], fontsize=17, rotation='vertical')
ax.yaxis.set_ticklabels(_neutral_tick_labels, fontsize=17, rotation='horizontal');

# Figure 6 - Shapley Values for positively charged particles

In [None]:
# Figure 6: per-feature box plots of SHAP value magnitudes for the positively
# charged classes -- one coloured box per class at each feature -- on a log y-axis.
fig, ax = plt.subplots(figsize=(25, 8))

# Legend label for each class index of df_list.
ptypedic_new = {0: r'$\gamma$', 1: r'$K_{L}^{0}$', 2: r'$n$', 3: r'$p$', 4: r'$\bar{p}$',
                5: r'$K^{+}$', 6: r'$K^{-}$', 7: r'$e^{-}$', 8: r'$e^{+}$',
                9: r'$\pi^{+}$ | $\mu^{+}$', 10: r'$\pi^{-}$ | $\mu^{-}$'}

box_width = 0.2
cap_width = 2
whisker_width = 2
color_dict = {0: 'violet', 1: 'red', 2: 'skyblue', 3: 'green'}
# x-offset of the k-th class's boxes from each integer feature position.
pos_dict = {k: -0.25 + (k*0.1666) for k in range(4)}

# Positively charged classes: p, K+, e+, pi+/mu+ (see ptypedic_new).
pos_char = np.array([3, 5, 8, 9])

for count, i in enumerate(pos_char):
    colour = color_dict[count]
    n_features = len(df_list[i].columns)
    box = df_list[i].boxplot(ax=ax,
                             boxprops=dict(edgecolor='k', alpha=1, facecolor=colour),
                             positions=np.arange(n_features) + pos_dict[count],
                             capprops=dict(color=colour, linewidth=cap_width),
                             flierprops=dict(marker='.', markerfacecolor=colour, markersize=8,
                                             markeredgecolor=colour),
                             whiskerprops=dict(color=colour, linewidth=whisker_width),
                             medianprops=dict(color='k', linewidth=2.5),
                             showfliers=False,
                             widths=box_width,
                             patch_artist=True)
    # Legend proxy: an empty scatter adds a coloured square to the legend without
    # plotting a real point (a dummy point at y=0 is invalid on a log axis).
    ax.scatter([], [], c=colour, label=ptypedic_new[i], marker='s')

# Feature names are taken from the last class plotted; all frames are assumed to
# share the same columns -- TODO confirm.
char_labs = np.array(df_list[pos_char[-1]].columns)
ax.set_xticks(np.arange(len(char_labs)))
ax.set_xticklabels(char_labs)
ax.set_yscale('log')

# y-axis ticks and label on the right-hand side.
ax.tick_params(axis='y', labelsize=15)
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_ylabel('SHAP Value Magnitude', fontsize=20, labelpad=20, rotation=90)
plt.xticks(rotation=90, fontsize=20)

# Base-10 log major ticks and no minor ticks.
y_locator = LogLocator(base=10.0, numticks=15)
y_minor_locator = NullLocator()
ax.yaxis.set_major_locator(y_locator)
ax.yaxis.set_minor_locator(y_minor_locator)
ax.grid(True, which="both", axis="y")
plt.yticks(rotation=90, fontsize=15)

legend = ax.legend(fontsize=15, loc='lower right')

# Rotate the legend text to match the rotated tick labels.
for text in legend.get_texts():
    text.set_rotation(90)

plt.show()

# Figure 7 - Shapley values for negatively charged particles

In [None]:
# Figure 7: per-feature box plots of SHAP value magnitudes for the negatively
# charged classes -- one coloured box per class at each feature -- on a log y-axis.
fig, ax = plt.subplots(figsize=(25, 8))

# Legend label for each class index of df_list.
ptypedic_new = {0: r'$\gamma$', 1: r'$K_{L}^{0}$', 2: r'$n$', 3: r'$p$', 4: r'$\bar{p}$',
                5: r'$K^{+}$', 6: r'$K^{-}$', 7: r'$e^{-}$', 8: r'$e^{+}$',
                9: r'$\pi^{+}$ | $\mu^{+}$', 10: r'$\pi^{-}$ | $\mu^{-}$'}

box_width = 0.2
cap_width = 2
whisker_width = 2
color_dict = {0: 'violet', 1: 'red', 2: 'skyblue', 3: 'green'}
# x-offset of the k-th class's boxes from each integer feature position.
pos_dict = {k: -0.25 + (k*0.1666) for k in range(4)}

# Negatively charged classes: p-bar, K-, e-, pi-/mu- (see ptypedic_new).
neg_char = np.array([4, 6, 7, 10])

for count, i in enumerate(neg_char):
    colour = color_dict[count]
    n_features = len(df_list[i].columns)
    box = df_list[i].boxplot(ax=ax,
                             boxprops=dict(edgecolor='k', alpha=1, facecolor=colour),
                             positions=np.arange(n_features) + pos_dict[count],
                             capprops=dict(color=colour, linewidth=cap_width),
                             flierprops=dict(marker='.', markerfacecolor=colour, markersize=8,
                                             markeredgecolor=colour),
                             whiskerprops=dict(color=colour, linewidth=whisker_width),
                             medianprops=dict(color='k', linewidth=2.5),
                             showfliers=False,
                             widths=box_width,
                             patch_artist=True)
    # Legend proxy: an empty scatter adds a coloured square to the legend without
    # plotting a real point (a dummy point at y=0 is invalid on a log axis).
    ax.scatter([], [], c=colour, label=ptypedic_new[i], marker='s')

# Feature names are taken from the last class plotted; all frames are assumed to
# share the same columns -- TODO confirm.
char_labs = np.array(df_list[neg_char[-1]].columns)
ax.set_xticks(np.arange(len(char_labs)))
ax.set_xticklabels(char_labs)
ax.set_yscale('log')

# y-axis ticks and label on the right-hand side.
ax.tick_params(axis='y', labelsize=15)
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_ylabel('SHAP Value Magnitude', fontsize=20, labelpad=20, rotation=90)
plt.xticks(rotation=90, fontsize=20)

# Base-10 log major ticks and no minor ticks.
y_locator = LogLocator(base=10.0, numticks=15)
y_minor_locator = NullLocator()
ax.yaxis.set_major_locator(y_locator)
ax.yaxis.set_minor_locator(y_minor_locator)
ax.grid(True, which="both", axis="y")
plt.yticks(rotation=90, fontsize=15)

legend = ax.legend(fontsize=15, loc='lower right')

# Rotate the legend text to match the rotated tick labels.
for text in legend.get_texts():
    text.set_rotation(90)

plt.show()

# Figure 8 - Shapley Values for neutral particles

In [None]:
# Figure 8: per-feature box plots of SHAP value magnitudes for the neutral
# classes -- one coloured box per class at each feature -- on a log y-axis.
fig, ax = plt.subplots(figsize=(12.5, 8))

# Legend label for each class index of df_list.
ptypedic_new = {0: r'$\gamma$', 1: r'$K_{L}^{0}$', 2: r'$n$', 3: r'$p$', 4: r'$\bar{p}$',
                5: r'$K^{+}$', 6: r'$K^{-}$', 7: r'$e^{-}$', 8: r'$e^{+}$',
                9: r'$\pi^{+}$ | $\mu^{+}$', 10: r'$\pi^{-}$ | $\mu^{-}$'}

box_width = 0.2
cap_width = 2
whisker_width = 2
color_dict = {0: 'violet', 1: 'red', 2: 'skyblue'}
# x-offset of the k-th class's boxes from each integer feature position.
pos_dict = {0: -0.25, 1: 0, 2: 0.25}

# Neutral classes: gamma, K_L^0, n (see ptypedic_new).
neutral_char = np.array([0, 1, 2])

for count, i in enumerate(neutral_char):
    colour = color_dict[count]
    n_features = len(df_list[i].columns)
    box = df_list[i].boxplot(ax=ax,
                             boxprops=dict(edgecolor='k', alpha=1, facecolor=colour),
                             positions=np.arange(n_features) + pos_dict[count],
                             capprops=dict(color=colour, linewidth=cap_width),
                             flierprops=dict(marker='.', markerfacecolor=colour, markersize=8,
                                             markeredgecolor=colour),
                             whiskerprops=dict(color=colour, linewidth=whisker_width),
                             medianprops=dict(color='k', linewidth=2.5),
                             showfliers=False,
                             widths=box_width,
                             patch_artist=True)
    # Legend proxy: an empty scatter adds a coloured square to the legend without
    # plotting a real point (a dummy point at y=0 is invalid on a log axis).
    ax.scatter([], [], c=colour, label=ptypedic_new[i], marker='s')

# Feature names are taken from the last class plotted; all frames are assumed to
# share the same columns -- TODO confirm.
char_labs = np.array(df_list[neutral_char[-1]].columns)
ax.set_xticks(np.arange(len(char_labs)))
ax.set_xticklabels(char_labs)
ax.set_yscale('log')

# y-axis ticks and label on the right-hand side.
ax.tick_params(axis='y', labelsize=15)
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_ylabel('SHAP Value Magnitude', fontsize=20, labelpad=20, rotation=90)
plt.xticks(rotation=90, fontsize=20)

# Base-10 log major ticks and no minor ticks.
y_locator = LogLocator(base=10.0, numticks=15)
y_minor_locator = NullLocator()
ax.yaxis.set_major_locator(y_locator)
ax.yaxis.set_minor_locator(y_minor_locator)
ax.grid(True, which="both", axis="y")
plt.yticks(rotation=90, fontsize=15)

legend = ax.legend(fontsize=15, loc='lower right')

# Rotate the legend text to match the rotated tick labels.
for text in legend.get_texts():
    text.set_rotation(90)

plt.show()