In [None]:
import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.metrics import roc_curve, auc
import qkeras
from qkeras import *
#import tensorflow_probability as tfp
#import keras_tuner
#from keras_tuner import Hyperband
import joblib

# Data files

All input files store calorimeter (Calo) regions already arranged on an (i, j) grid of shape (18, 14):<br>
i = 0 -> 17 corresponds to GCT_Phi = 0 -> 17<br>
j = 0 -> 13 corresponds to RCT_Eta = 4 -> 17

In [None]:
# Load the ZeroBias (ZB) background samples as float32 calo-region images of shape
# (events, 18, 14, 1). Every HDF5 file is opened in a `with` block so its handle is closed
# as soon as 'CaloRegions' has been read (the previous one-liners never closed them).
def _read_calo_regions(path, n_events=None):
    """Return the 'CaloRegions' dataset stored in `path` as a float32 array.

    n_events: keep only the first `n_events` rows; None reads the whole dataset.
    """
    with h5py.File(path, 'r') as h5_file:
        return h5_file['CaloRegions'][:n_events].astype('float32')

# 2018 Run A: all events of the three files, concatenated in file order 0, 1, 2.
ZB18A = np.concatenate([_read_calo_regions('bkg/ZeroBias2018RunA_%d.h5' % part) for part in range(3)])
ZB18A = np.reshape(ZB18A, (-1,18,14,1))

# 2018 Run D and 2023: the first 300k events of one file each.
ZB18D = _read_calo_regions('bkg/EphemeralZeroBias2018RunD_2.h5', 300000)
ZB18D = np.reshape(ZB18D, (-1,18,14,1))

ZB23 = _read_calo_regions('bkg/EZB2023_0.h5', 300000)
ZB23 = np.reshape(ZB23, (-1,18,14,1))

print('ZeroBias2018A shape: ' + str(ZB18A.shape))
print('ZeroBias2018D shape: ' + str(ZB18D.shape))
print('ZeroBias2023  shape: ' + str(ZB23.shape))

# Monte Carlo samples, one .h5 file each; MC[i] (loaded below) follows this list's order.
# Only the append() lines outside the ''' ... ''' strings are active; those strings just
# disable entries. The active indices are:
#   i=0 QCD, i=1 and i=2 SingleNeutrino guns (files under bkg/),
#   i=3 GluGluHToTauTau, i=4 GluGluToHHTo4B (cHHH1), i=5 TT,
#   i=6 haa4b_ma15_powheg, i=7 SUEP_L1_NOPU.
# NOTE: the "#i=.." markers inside the disabled strings refer to an older, longer list and
# do not match the active indices.
MC_files = []
MC_files.append('bkg/110X/QCD_Pt-15to7000_TuneCP5_Flat_14TeV_0.h5')#i=0
MC_files.append('bkg/120X/SingleNeutrino_E-10-gun_0.h5')#i=1
MC_files.append('bkg/120X/SingleNeutrino_Pt-2To20-gun_0.h5')#i=2
'''
MC_files.append('sig/110X/GluGluToHHTo4B_node_SM_TuneCP5_14TeV.h5')#i=3
MC_files.append('sig/110X/HTo2LongLivedTo4mu_MH-1000_MFF-450_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/HTo2LongLivedTo4mu_MH-125_MFF-12_CTau-900mm_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/HTo2LongLivedTo4mu_MH-125_MFF-25_CTau-1500mm_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/HTo2LongLivedTo4mu_MH-125_MFF-50_CTau-3000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/VBFHToTauTau_M125_TuneCUETP8M1_14TeV.h5')
MC_files.append('sig/110X/VBF_HH_CV_1_C2V_1_C3_1_TuneCP5_PSweights_14TeV.h5')
MC_files.append('sig/110X/VBF_HToInvisible_M125_TuneCUETP8M1_14TeV.h5')
MC_files.append('sig/110X/VectorZPrimeToQQ_M100_pT300_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/VectorZPrimeToQQ_M200_pT300_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/VectorZPrimeToQQ_M50_pT300_TuneCP5_14TeV.h5')#i=13
MC_files.append('sig/110X/ZprimeToZH_MZprime1000_MZ50_MH80_ZTouds_HTouds_narrow_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/ZprimeToZH_MZprime600_MZ50_MH80_ZTouds_HTouds_narrow_TuneCP5_14TeV.h5')
MC_files.append('sig/110X/ZprimeToZH_MZprime800_MZ50_MH80_ZTouds_HTouds_narrow_TuneCP5_14TeV.h5')
'''
MC_files.append('sig/120X/GluGluHToTauTau_M-125_TuneCP5_14TeV.h5')  # i=3 (H->tautau)
MC_files.append('sig/120X/GluGluToHHTo4B_node_cHHH1_TuneCP5_14TeV.h5')  # i=4 (SM HH->4b)
'''
MC_files.append('sig/120X/GluGluToHHTo4B_node_cHHH5_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-1000_MFF-450_CTau-100000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-1000_MFF-450_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-12_CTau-9000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-12_CTau-900mm_TuneCP5_14TeV.h5')#i=23
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-25_CTau-15000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-25_CTau-1500mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-50_CTau-30000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-125_MFF-50_CTau-3000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-250_MFF-120_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-250_MFF-120_CTau-1000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-250_MFF-60_CTau-1000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-160_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-160_CTau-1000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-160_CTau-500mm_TuneCP5_14TeV.h5')#i=33
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-80_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-80_CTau-1000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4b_MH-350_MFF-80_CTau-500mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4mu_MH-1000_MFF-450_CTau-10000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4mu_MH-125_MFF-12_CTau-900mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4mu_MH-125_MFF-25_CTau-1500mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/HTo2LongLivedTo4mu_MH-125_MFF-50_CTau-3000mm_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/SUSYGluGluToBBHToBB_NarrowWidth_M-120_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/SUSYGluGluToBBHToBB_NarrowWidth_M-350_TuneCP5_14TeV.h5')#i=43
MC_files.append('sig/120X/SUSYGluGluToBBHToBB_NarrowWidth_M-600_TuneCP5_14TeV.h5')
'''
MC_files.append('sig/120X/TT_TuneCP5_14TeV.h5')  # i=5 (TTbar)
'''
MC_files.append('sig/120X/TprimeBToTH_M-650_LH_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VBFHHTo4B_CV_1_C2V_2_C3_1_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VBFHToInvisible_M125_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VBFHToTauTau_M125_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VectorZPrimeGammaToQQGamma_M-10_GPt-75_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VectorZPrimeToQQ_M-100_Pt-300_TuneCP5_14TeV.h5')
MC_files.append('sig/120X/VectorZPrimeToQQ_M-200_Pt-300_TuneCP5_14TeV.h5')#i=52
MC_files.append('sig/120X/emj-mMed-800-mDark-10-ctau-0p1.h5')
MC_files.append('sig/120X/emj-mMed-800-mDark-10-ctau-1.h5')
MC_files.append('sig/120X/emj-mMed-800-mDark-10-ctau-1000.h5')
MC_files.append('sig/120X/emj-mMed-800-mDark-10-ctau-150.h5')
'''
MC_files.append('sig/120X/haa4b_ma15_powheg.h5')  # i=6 (H->aa->4b)
'''
MC_files.append('sig/120X/haa4b_ma15_powheg_FlatPU0To80.h5')
MC_files.append('sig/120X/haa4b_ma50_powheg.h5')
MC_files.append('sig/120X/haa4taus_ma15_powheg.h5')
'''
MC_files.append('sig/120X/SUEP_L1_NOPU.h5')  # i=7 (SUEP, no pileup; ZB is overlaid later)

# Read up to the first 100k events of every MC file as float32 calo-region images of shape
# (events, 18, 14, 1); MC[i] corresponds to MC_files[i]. The `with` block closes each HDF5
# file once its events are in memory (previously the handles were left open).
MC = []
for i in range(len(MC_files)):
    with h5py.File(MC_files[i], 'r') as h5_file:
        MC.append(h5_file['CaloRegions'][:100000].astype('float32'))
    MC[i] = np.reshape(MC[i], (-1,18,14,1))
    print('i = ' + str(i) + ': ' + str(MC[i].shape))

In [None]:
#SUEP_NoPU
# The SUEP sample MC[7] ('sig/120X/SUEP_L1_NOPU.h5') has no pileup. The next cells add a
# randomly chosen ZB2018D event to every SUEP event and compare the two versions.

In [None]:
# Overlay pileup on the no-pileup SUEP sample: for each SUEP event draw one random ZB2018D
# event (fixed seed for reproducibility) and add the two region grids element-wise.
np.random.seed(420)
n_suep = MC[7].shape[0]
ZB_random_event = np.random.randint(low=0, high=ZB18D.shape[0], size=n_suep)
# Vectorised form of the per-event loop: the float32 sum is stored as float64, as writing
# it into an np.empty(...) buffer did.
suep_zb = (ZB18D[ZB_random_event] + MC[7]).astype(np.float64)

In [None]:
# Event displays for SUEP events 1260-1269: the event from MC[7] (left) and the same event
# with the random ZB2018D event added (right), both on the overlaid event's colour scale.
# NOTE: once MC[7] has been overwritten with suep_zb, the left panel is the overlaid event too.
for i in range(1260,1270):
    # plt.figure() rather than plt.subplots(): subplots() adds a full-figure Axes that
    # plt.subplot(2, 2, k) no longer removes (matplotlib >= 3.6), leaving an empty frame
    # behind the heatmaps.
    fig = plt.figure(figsize = (10,10))
    print(str(MC_files[7]))
    vmax = suep_zb[i,:,:,0].max()
    panels = ((MC[7][i,:,:,0], 'SUEP_NoPU'), (suep_zb[i,:,:,0], 'SUEP_NoPU overlayed with ZB'))
    for position, (image, title) in enumerate(panels, start=1):
        ax = plt.subplot(2, 2, position)
        ax = sns.heatmap(image.reshape(18, 14), vmin = 0, vmax = vmax, cmap = "Purples", cbar_kws = {'label': 'Et (GeV)'})
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title(title)
    plt.show()
    # This loop creates ten figures; release each one after it is displayed.
    plt.close(fig)

In [None]:
# Replace the no-pileup SUEP events in MC[7] with their ZB-overlaid versions. suep_zb is
# float64, and the values are cast to MC[7]'s float32 on assignment.
# NOTE(review): this mutates MC[7] in place and is not idempotent. Re-running the overlay
# cell afterwards would add ZB pileup a second time; re-run the MC loading cell first.
MC[7][:,:,:,0] = suep_zb[:,:,:,0]

Inspect some ZeroBias (ZB) statistics.

In [None]:
# Per-region mean ET over all events of each ZB sample, one 18x14 heatmap per sample.
ZB18A_mean = np.mean(ZB18A, axis = 0)
ZB18D_mean = np.mean(ZB18D, axis = 0)
ZB23_mean = np.mean(ZB23, axis = 0)

for mean_map, sample_tag in ((ZB18A_mean, 'ZB2018RunA'), (ZB18D_mean, 'ZB2018RunD'), (ZB23_mean, 'ZB2023')):
    # plt.figure() rather than plt.subplots(): the full-figure Axes created by subplots()
    # is not removed by plt.subplot() (matplotlib >= 3.6) and would show as an empty frame.
    fig = plt.figure(figsize = (10,10))
    ax = plt.subplot(2, 2, 2)
    ax = sns.heatmap(mean_map.reshape(18, 14), vmin = 0, vmax = mean_map.max(), cmap = "Purples", cbar_kws = {'label': 'ET (GeV)'})
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('Mean Et (%s)' % sample_tag)
    plt.show()
    plt.close(fig)

In [None]:
# Normalised distribution of per-region ET for each ZB sample (log y-axis), followed by
# the mean ET per region of each sample.
for era_label, zb_sample in (('2018A', ZB18A), ('2018D', ZB18D), ('2023', ZB23)):
    plt.hist(zb_sample.reshape((-1)), bins = 20, range=(0,1024), density=1, label=era_label, log = True, histtype='step')
plt.xlabel("ZB Et")
plt.legend(loc='best')
plt.show()

mean_reports = (('Mean ZB2018A pT = ', ZB18A), ('Mean ZB2018D pT = ', ZB18D), ('Mean ZB2023  pT = ', ZB23))
for report_prefix, zb_sample in mean_reports:
    print(report_prefix + str(np.mean(zb_sample.reshape(-1))))

In [None]:
# Repeat the ZB comparison without the boundary eta columns j = 0 and j = 13
# (RCT_Eta = 4 and 17). Only the first 500k events of one file per era are read here.
# The histogram and the scalar means use only the 12 interior columns. Previously the
# boundary columns were set to 0 but kept, so every event still contributed 36 zero entries.
# That inflated the first histogram bin and biased the "excluding" means low by 12/14.
noBoundary_inputs = (
    # (histogram label, printed label, heatmap title tag, file)
    ('2018A', 'Mean ZB2018A pT (excluding RCT_Eta=4,17)= ', 'ZB2018RunA', 'bkg/ZeroBias2018RunA_0.h5'),
    ('2018D', 'Mean ZB2018D pT (excluding RCT_Eta=4,17)= ', 'ZB2018RunD', 'bkg/EphemeralZeroBias2018RunD_2.h5'),
    ('2023', 'Mean ZB2023  pT (excluding RCT_Eta=4,17)= ', 'ZB2023', 'bkg/EZB2023_0.h5'),
)

noBoundary_results = []
for hist_label, print_label, title_tag, path in noBoundary_inputs:
    # `with` closes the HDF5 file once the events are in memory.
    with h5py.File(path, 'r') as h5_file:
        regions = h5_file['CaloRegions'][:500000].astype('float32')
    regions = np.reshape(regions, (-1,18,14,1))
    interior = regions[:, :, 1:13, :]
    plt.hist(interior.ravel(), bins = 20, range=(0,1024), density=1, label=hist_label, log = True, histtype='step')
    # Per-region mean map; the boundary columns are zeroed for the display only.
    mean_map = np.mean(regions, axis = 0)
    mean_map[:, 0, 0] = 0
    mean_map[:, 13, 0] = 0
    noBoundary_results.append((print_label, np.mean(interior), title_tag, mean_map))
plt.xlabel("ZB Et (excluding RCT_Eta=4,17)")
plt.legend(loc='best')
plt.show()

for print_label, mean_value, _, _ in noBoundary_results:
    print(print_label + str(mean_value))

for _, _, title_tag, mean_map in noBoundary_results:
    # plt.figure() rather than plt.subplots(): avoids an empty full-figure Axes behind the
    # heatmap (plt.subplot no longer removes overlapping Axes in matplotlib >= 3.6).
    fig = plt.figure(figsize = (10,10))
    ax = plt.subplot(2, 2, 2)
    ax = sns.heatmap(mean_map.reshape(18, 14), vmin = 0, vmax = mean_map.max(), cmap = "Purples", cbar_kws = {'label': 'ET (GeV)'})
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('Mean Et (%s), excluding RCT_Eta=4,17' % title_tag)
    plt.show()
    plt.close(fig)

# Free the large per-era arrays and intermediate maps.
del regions, interior, mean_map, noBoundary_results

# CNN autoencoder (teacher model)

In [None]:
# Teacher encoder. Input: one (18, 14, 1) calo-region image per event.
encoder_input = tf.keras.Input(shape=(18,14,1), name='input')

# Two 3x3 'same'-padded conv stages. The 2x2 average pooling between them halves the grid
# to 9x7, which the decoder's Reshape((9,7,30)) mirrors.
encoder = layers.Conv2D(20, (3,3), strides=1, padding='same', name='conv2d_1')(encoder_input)
encoder = layers.Activation('relu', name='relu_1')(encoder)
encoder = layers.AveragePooling2D((2,2), name='pool_1')(encoder)
encoder = layers.Conv2D(30, (3,3), strides=1, padding='same', name='conv2d_2')(encoder)
encoder = layers.Activation('relu', name='relu_2')(encoder)
encoder = layers.Flatten(name='flatten')(encoder)

# 100-dimensional ReLU latent vector (the autoencoder bottleneck).
encoder_output = layers.Dense(100, activation='relu', name='latent')(encoder)

# `encoder` is rebound here from the last intermediate tensor to a standalone
# input -> latent Model. In this notebook only its summary is printed; the decoder is
# built on `encoder_output`.
encoder = tf.keras.models.Model(encoder_input, encoder_output)
encoder.summary()

In [None]:
# Teacher decoder, mirroring the encoder: latent vector -> Dense to 9*7*30 values ->
# 9x7x30 map -> conv -> 2x upsampling back to 18x14 -> conv.
# The Dense layer has no activation; the ReLU is applied after the Reshape.
decoder = layers.Dense(9*7*30, name='dense')(encoder_output)
decoder = layers.Reshape((9,7,30), name='reshape2')(decoder)
decoder = layers.Activation('relu', name='relu_3')(decoder)
decoder = layers.Conv2D(30, (3,3), strides=1, padding='same', name='conv2d_3')(decoder)
decoder = layers.Activation('relu', name='relu_4')(decoder)
decoder = layers.UpSampling2D((2,2), name='upsampling')(decoder)
decoder = layers.Conv2D(20, (3,3), strides=1, padding='same', name='conv2d_4')(decoder)
decoder = layers.Activation('relu', name='relu_5')(decoder)

# Single-channel 18x14 reconstruction; the final ReLU keeps reconstructed values >= 0.
decoder_output = layers.Conv2D(1, (3,3), activation='relu', strides=1, padding='same', name='output')(decoder)

In [None]:
# Full teacher autoencoder: (18, 14, 1) input -> 100-d latent -> (18, 14, 1) reconstruction.
teacher = tf.keras.Model(encoder_input, decoder_output)
teacher.summary()

In [None]:
# Reconstruction objective: mean squared error, optimised with Adam at learning rate 0.002.
teacher.compile(optimizer = keras.optimizers.Adam(learning_rate=0.002), loss = 'mse')

# Training

In [None]:
# Train on ZB2023 only: 1% training, 1% validation, and the remaining 98% held out as the
# test set. Both splits use random_state=123 so the partition is reproducible.
X = ZB23
train_ratio, val_ratio = 0.01, 0.01
test_ratio = 1 - train_ratio - val_ratio
X_train_val, X_test = train_test_split(X, random_state = 123, test_size = test_ratio)
# Share of the held-in 2% that becomes validation (0.5 with the ratios above).
val_fraction = val_ratio/(val_ratio + train_ratio)
X_train, X_val = train_test_split(X_train_val, random_state = 123, test_size = val_fraction)

In [None]:
# Autoencoder training: each input image is also its own target.
history = teacher.fit(X_train, X_train,
                      epochs = 80,
                      validation_data = (X_val, X_val),
                      batch_size = 128)

In [None]:
# Teacher training curve: per-epoch training and validation MSE.
plt.figure(figsize = (15,10))
axes = plt.subplot(2, 2, 1)
for history_key, curve_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label = curve_label)
axes.legend(loc = "upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')
# Explicit show; otherwise the cell also echoes the Text object returned by set_ylabel.
plt.show()

# Save/load trained models

In [None]:
# Save the trained teacher; the next cell reloads it from the same path.
teacher.save('saved_models/teacher_ZB2023')

In [None]:
# Reload the saved teacher, e.g. to skip re-training in a fresh session.
teacher = tf.keras.models.load_model('saved_models/teacher_ZB2023')
teacher.summary()

In [None]:
# Save the trained student. `student` is only created in the "Knowledge Distillation"
# section further down, so on a fresh kernel (Restart & Run All) this cell raised a
# NameError. It is guarded here; re-run it after training the student.
#student.save('saved_models/student_ZB2023_v1/')
if 'student' in globals():
    student.save('saved_models/student_ZB2023_v2/')
else:
    print('student is not defined yet: train it in the Knowledge Distillation section, then re-run this cell')

In [None]:
# Only show TensorFlow log messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')
# Load a previously saved QKeras student; keep exactly one of these lines uncommented.
#student = qkeras.utils.load_qmodel('saved_models/model_sA') # 2018 v1
#student = qkeras.utils.load_qmodel('saved_models/qmodel_oct24') # 2018 v2
#student = qkeras.utils.load_qmodel('saved_models/student_ZB2023_v1') # 2023 v1
student = qkeras.utils.load_qmodel('saved_models/student_ZB2023_v2') # 2023 v2
student.summary()
# Last expression: the model config is displayed as the cell output.
student.get_config()

# Loss distribution

In [None]:
# Teacher reconstructions for the three ZB splits and for every MC sample (same order as MC).
X_train_predict_teacher, X_val_predict_teacher, X_test_predict_teacher = (
    teacher.predict(zb_split) for zb_split in (X_train, X_val, X_test)
)
MC_predict_teacher = [teacher.predict(mc_sample) for mc_sample in MC]

In [None]:
def loss(y_true, y_pred, choice):
    """Per-event reconstruction loss between two batches of images.

    Parameters
    ----------
    y_true, y_pred : array-like
        Batches of equal shape with at least 4 dimensions, e.g. (events, 18, 14, 1);
        the loss is averaged over axes 1, 2 and 3.
    choice : str
        'mse' (mean squared error) or 'mae' (mean absolute error).

    Returns
    -------
    numpy.ndarray
        One loss value per event, shape (events,) for 4-D input.

    Raises
    ------
    ValueError
        If `choice` is not a supported loss name. Previously an unknown name silently
        returned None, which only failed much later.
    """
    diff = np.asarray(y_true) - np.asarray(y_pred)
    if choice == 'mse':
        return np.mean(diff**2, axis = (1,2,3))
    if choice == 'mae':
        return np.mean(np.abs(diff), axis = (1,2,3))
    raise ValueError("unsupported loss choice %r; expected 'mse' or 'mae'" % (choice,))

In [None]:
# Per-event teacher MSE (the teacher's anomaly score) for the ZB splits and each MC sample.
zb_pairs = ((X_train, X_train_predict_teacher), (X_val, X_val_predict_teacher), (X_test, X_test_predict_teacher))
X_train_loss_teacher, X_val_loss_teacher, X_test_loss_teacher = (
    loss(images, reconstructions, 'mse') for images, reconstructions in zb_pairs
)

MC_loss_teacher = [loss(MC[k], MC_predict_teacher[k], 'mse') for k in range(len(MC))]

In [None]:
# Teacher anomaly-score distributions: the ZB test set (filled, log y-axis) against the
# five signal samples (MC indices 3-7).
nbins, rmin, rmax = 80, 0, 100
score_range = (rmin, rmax)
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}
plt.hist(X_test_loss_teacher, density = 1, bins = nbins, alpha = 0.3, label = 'test (ZeroBias)', range = score_range, log = True)
for mc_index, (signal_label, signal_color) in signal_styles.items():
    plt.hist(MC_loss_teacher[mc_index], density = 1, bins = nbins, label = signal_label, color=signal_color, histtype = 'step', range = score_range)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.xlabel("Anomaly score (teacher)")
plt.show()

# Comparison between original and reconstructed inputs

In [None]:
# Original vs. teacher reconstruction for events 20-29, taken from the ZB test set
# (show_ZB = True) or from MC sample n. All three panels use the original event's maximum
# as the colour scale.
show_ZB = True
#show_ZB = False
n = 5
for i in range(20,30):
    if show_ZB == True:
        original, reconstructed = X_test[i], X_test_predict_teacher[i]
        print('ZB test\nloss = ' + str(X_test_loss_teacher[i]))
    else:
        original, reconstructed = MC[n][i], MC_predict_teacher[n][i]
        print(str(MC_files[n]) + '\nloss = ' + str(MC_loss_teacher[n][i]))
    original = original.reshape(18, 14)
    reconstructed = reconstructed.reshape(18, 14)
    vmax = original.max()

    # plt.figure() rather than plt.subplots(): subplots() adds a full-figure Axes that
    # plt.subplot() no longer removes (matplotlib >= 3.6), leaving an empty frame behind.
    fig = plt.figure(figsize = (17,17))
    panels = ((original, 'Original'),
              (reconstructed, 'Reconstructed'),
              (np.absolute(reconstructed - original), 'abs(original-reconstructed)'))
    for position, (image, title) in enumerate(panels, start=1):
        ax = plt.subplot(3, 3, position)
        ax = sns.heatmap(image, vmin = 0, vmax = vmax, cmap = "Purples", cbar_kws = {'label': 'ET (GeV)'})
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title(title)

    plt.show()
    plt.close(fig)

# Knowledge Distillation (+ quantizing with QKeras)

In [None]:
# v1
# Student v1: a small fully-connected QKeras regressor from the 252 flattened regions
# (18 x 14) to a single score, trained below on the teacher's per-event MSE.
# Running the v2 cell after this one rebinds `student`.
x_in = layers.Input(shape=(252,), name="In")

# 15 hidden units, quantized kernel (qkeras quantized_bits(bits, integer, symmetric, ...)), no bias.
x = QDense(15, kernel_quantizer=quantized_bits(2, 0, 1, alpha=1.0),
           use_bias=False, name='dense1')(x_in)
# Batch normalisation with 10-bit quantized parameters.
x = QBatchNormalization(beta_quantizer=quantized_bits(10, 2, 1, alpha='auto'),
                        gamma_quantizer=quantized_bits(10, 2, 1, alpha='auto'),
                        mean_quantizer=quantized_bits(10, 2, 1, alpha='auto'),
                        variance_quantizer=quantized_bits(10, 2, 1, alpha='auto'),
                        name = 'QBN1')(x)
x = QActivation('quantized_relu(5, 2)', name='relu1')(x)
# One output unit with no activation: the predicted anomaly score.
x = QDense(1, kernel_quantizer=quantized_bits(4, 0, 1, alpha=1.0),
           use_bias=False, name='output')(x)

student = tf.keras.models.Model(x_in, x)
student.summary()
student.compile(optimizer = 'adam', loss = 'mse')

In [None]:
# v2
# Student v2: the 252 flattened regions are reshaped back to an 18x14x1 image, passed
# through one quantized conv layer and two quantized dense layers, and mapped to a single
# score. This cell rebinds `student` (replacing v1 or a model loaded from disk).
x_in = layers.Input(shape=(252,), name="In")
x = layers.Reshape((18,14,1), name='reshape')(x_in)

# 3 filters, 3x3 kernel, stride 2, 'valid' padding: 18x14 -> 8x6x3 (144 values after Flatten).
x = QConv2D(3,(3,3), strides=2, padding="valid", use_bias=False,
            kernel_quantizer=quantized_bits(16,4,1,alpha='auto'), name='conv')(x)
x = QActivation('quantized_relu(16,4)', name='relu1')(x)
x = layers.Flatten(name='flatten')(x)
x = QDense(20, kernel_quantizer=quantized_bits(16,4,1,alpha='auto'),
           use_bias=False, name='dense1')(x)
x = QActivation('quantized_relu(16,4)', name='relu2')(x)
# One output unit with no activation: the predicted anomaly score.
x = QDense(1, kernel_quantizer=quantized_bits(16,2,1,alpha='auto'),
           use_bias=False, name='output')(x)

student = tf.keras.models.Model(x_in, x)
student.summary()
student.compile(optimizer = 'adam', loss = 'mse')

In [None]:
# Knowledge distillation: the student regresses the teacher's per-event MSE from the
# flattened 252-region input. Inputs are reshaped to (events, 252) to match the student's
# Input(shape=(252,)). The previous (events, 252, 1) layout did not match that declared
# rank: Keras warns about the incompatible input shape, and the Dense-first v1 student
# would reject it.
history = student.fit(X_train.reshape((-1,252)), X_train_loss_teacher,
                      epochs = 100,
                      validation_data = (X_val.reshape((-1,252)), X_val_loss_teacher),
                      batch_size = 512)

In [None]:
# Student training curve: per-epoch training and validation MSE against the teacher scores.
plt.figure(figsize = (15,10))
axes = plt.subplot(2, 2, 1)
axes.plot(history.history['loss'], label = 'train loss')
#axes.set_yscale(value = "log")
axes.plot(history.history['val_loss'], label = 'val loss')
axes.legend(loc = "upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')
# Explicit show; otherwise the cell also echoes the Text object returned by set_ylabel.
plt.show()

In [None]:
# Student anomaly scores (the predicted teacher MSE; shape (events, 1) for the students
# defined in this notebook) for the ZB splits and each MC sample. Inputs are flattened to
# (events, 252) to match the student's Input(shape=(252,)); the old (events, 252, 1)
# layout carried an extra axis the model was not built for.
X_train_loss_student = student.predict(X_train.reshape((-1,252)))
X_val_loss_student = student.predict(X_val.reshape((-1,252)))
X_test_loss_student = student.predict(X_test.reshape((-1,252)))
MC_loss_student = []
for i in range(len(MC)):
    MC_loss_student.append(student.predict(MC[i].reshape((-1,252))))

In [None]:
# Student anomaly-score distributions: the ZB2023 test set (filled, log y-axis) against the
# five signal samples (MC indices 3-7).
nbins, rmin, rmax = 60, 0, 20
score_range = (rmin, rmax)
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}
plt.hist(X_test_loss_student, density = 1, bins = nbins, alpha = 0.3, label = 'ZB23', range = score_range, log = True)
for mc_index, (signal_label, signal_color) in signal_styles.items():
    plt.hist(MC_loss_student[mc_index], density = 1, bins = nbins, label = signal_label, color=signal_color, histtype = 'step', range = score_range)
plt.title('CICADA_v1, ZB scores')
plt.legend(loc='best')
plt.xlabel("Anomaly score")
plt.show()

# ROC plotting

### Assign labels and assemble the inputs for the ROC curves

In [None]:
# Baseline: pretend the model has learned nothing but the mean ZB event, so every input
# is "reconstructed" as the per-region mean of the ZB test set. The baseline score is the
# per-event MSE between the input and that mean, shape (events, 1).
ZB_mean = np.mean(X_test.reshape((-1,252,1)), axis = 0)

baseline_zb = np.mean(np.square(X_test.reshape(-1,252,1) - ZB_mean), axis = 1)
baseline_mc = [np.mean(np.square(mc_sample.reshape(-1,252,1) - ZB_mean), axis = 1) for mc_sample in MC]

In [None]:
#Assign labels for various signals (y = 1) and backgrounds (y = 0)
Y_zb = np.zeros((X_test.shape[0], 1))
Y_mc = []
for i in range(len(MC)):
    Y_mc.append(np.ones((MC[i].shape[0], 1)))

#Concatenate datasets to make ROC curves, i.e. QCD/SingleNu/signals vs ZB

#True labels
Y_true = []
#Baseline scores
Y_baseline = []
#Model scores
Y_teacher = []
Y_student = []

for i in range(len(MC)):
    Y_true.append(np.concatenate((Y_mc[i], Y_zb)))
    Y_baseline.append(np.concatenate((baseline_mc[i], baseline_zb)))
    #Y_teacher.append(np.concatenate((MC_loss_teacher[i], X_test_loss_teacher)))
    Y_student.append(np.concatenate((MC_loss_student[i], X_test_loss_student)))

### Baseline ROC

In [None]:
# Baseline ROC curves: ZB test set vs. each MC sample. The FPR is multiplied by 28.61 so
# the x-axis reads as a trigger rate in MHz (the AUC is computed before scaling). The five
# signal samples are drawn, with a reference line at 1 kHz.
# The H->tautau label now matches the other plots (it was 'H->2Tau'), and the unused
# np.empty preallocations, which were immediately overwritten, are gone.
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}
plt.figure(figsize = (13, 13))
axes = plt.subplot(2, 2, 1)
fpr_baseline = []
tpr_baseline = []
thresholds_baseline = []
roc_auc_baseline = []
for i in range(len(MC)):
    fpr_i, tpr_i, thresholds_i = roc_curve(Y_true[i], Y_baseline[i])
    roc_auc_i = auc(fpr_i, tpr_i)
    fpr_i *= 28.61
    fpr_baseline.append(fpr_i)
    tpr_baseline.append(tpr_i)
    thresholds_baseline.append(thresholds_i)
    roc_auc_baseline.append(roc_auc_i)
    if i in signal_styles:
        signal_label, signal_color = signal_styles[i]
        axes.plot(fpr_i, tpr_i, linestyle = '-', lw = 1.5, color = signal_color, label = '%s (AUC = %.5f)' % (signal_label, roc_auc_i))
axes.plot([0.001, 0.001], [0, 1], linestyle = '--', lw = 1, color = 'black', label = 'Trigger rate = 1 kHz')
axes.set_xlim([0.0001, 28.61])
axes.set_ylim([0.000001, 1.0])
axes.set_xscale(value = "log")
axes.set_yscale(value = "log")
axes.set_xlabel('Trigger Rate (MHz)',size=15)
axes.set_ylabel('Signal Efficiency',size=15)
axes.set_title('Baseline (Energy Cut-Based)',size=15)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.3),fontsize=12)
plt.show()

### Teacher model ROC

In [None]:
# Teacher ROC curves: ZB test set vs. each MC sample, with the FPR scaled by 28.61 to a
# trigger rate in MHz (the AUC is computed before scaling).
# Scores are built directly from the teacher losses, in the same MC-then-ZB order as
# Y_true. The old code indexed Y_teacher, which the label cell left empty, so it crashed
# with an IndexError.
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}
plt.figure(figsize = (13, 13))
axes = plt.subplot(2, 2, 1)
fpr_teacher = []
tpr_teacher = []
thresholds_teacher = []
roc_auc_teacher = []
for i in range(len(MC)):
    teacher_scores = np.concatenate((MC_loss_teacher[i], X_test_loss_teacher))
    fpr_i, tpr_i, thresholds_i = roc_curve(Y_true[i], teacher_scores)
    roc_auc_i = auc(fpr_i, tpr_i)
    fpr_i *= 28.61
    fpr_teacher.append(fpr_i)
    tpr_teacher.append(tpr_i)
    thresholds_teacher.append(thresholds_i)
    roc_auc_teacher.append(roc_auc_i)
    if i in signal_styles:
        signal_label, signal_color = signal_styles[i]
        axes.plot(fpr_i, tpr_i, linestyle = '-', lw = 1.5, color = signal_color, label = '%s (AUC = %.5f)' % (signal_label, roc_auc_i))
axes.plot([0.001, 0.001], [0, 1], linestyle = '--', lw = 1, color = 'black', label = 'Trigger rate = 1 kHz')
#axes.set_xlim([0.00002861, 28.61])
axes.set_xlim([0.0001, 28.61])
axes.set_ylim([0.000001, 1.0])
axes.set_xscale(value = "log")
axes.set_yscale(value = "log")
axes.set_xlabel('Trigger Rate (MHz)',size=15)
axes.set_ylabel('Signal Efficiency',size=15)
axes.set_title('Teacher Network',size=15)
#axes.legend(loc='center left', bbox_to_anchor = (0.3, 0.3),fontsize=12)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.3),fontsize=12)
plt.show()

### Student model ROC

In [None]:
# Student ROC curves: ZB test set vs. each MC sample. The FPR is scaled by 28.61 to read as
# a trigger rate in MHz (the AUC is computed before scaling). The five signal samples are
# drawn together with a reference line at 3 kHz.
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}
plt.figure(figsize = (13, 13))
axes = plt.subplot(2, 2, 1)
fpr_student, tpr_student, thresholds_student, roc_auc_student = [], [], [], []

for mc_index in range(len(MC)):
    rates, efficiencies, cuts = roc_curve(Y_true[mc_index], Y_student[mc_index])
    area = auc(rates, efficiencies)
    rates *= 28.61
    fpr_student.append(rates)
    tpr_student.append(efficiencies)
    thresholds_student.append(cuts)
    roc_auc_student.append(area)
    if mc_index not in signal_styles:
        continue
    signal_label, signal_color = signal_styles[mc_index]
    axes.plot(rates, efficiencies, linestyle = '-', lw = 1.5, color = signal_color, label = '%s (AUC = %.5f)' % (signal_label, area))
axes.plot([0.003, 0.003], [0, 1], linestyle = '--', lw = 1, color = 'black', label = 'Trigger rate = 3 kHz')
axes.set_xlim([0.0001, 28.61])
axes.set_ylim([0.000001, 1])
axes.set_xscale(value = "log")
axes.set_yscale(value = "log")
axes.set_xlabel('Trigger Rate (MHz)',size=15)
axes.set_ylabel('Signal Efficiency',size=15)
axes.set_title('Student Network',size=15)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.3),fontsize=12)
plt.show()

# Cross-validation

In [None]:
# k-fold ROC curves for the model scores (the student by default). The ZB test scores and
# each signal sample's scores are cut into kf equal, contiguous folds. Per fold, a ROC curve
# is computed and interpolated onto a common FPR grid, and the fold curves are averaged.
# The band is +/- 1 std of the fold TPRs; the quoted uncertainty is the std of the fold AUCs.
Y_zb = np.zeros((X_test.shape[0], 1))
Y_mc = [np.ones((mc_sample.shape[0], 1)) for mc_sample in MC]

def kfold(y, k):
    """Split `y` along axis 0 into `k` contiguous folds of len(y) // k entries each.

    The trailing len(y) % k entries are dropped so that all folds have the same size.
    """
    n = y.shape[0] // k
    return [y[i*n : (i+1)*n] for i in range(k)]

kf=10
X_test_loss_model = X_test_loss_student
Y_model = Y_student
MC_loss_model = MC_loss_student

X_test_loss_model_kf=kfold(X_test_loss_model,kf)
Y_zb_kf=kfold(Y_zb,kf)

# Only these samples are plotted, so only they get the (expensive) k-fold treatment.
signal_styles = {3: ('H->tautau', 'green'), 4: ('SM HH->4b', 'red'), 5: ('TTbar', 'blue'), 6: ('H->aa->4b', 'orange'), 7: ('SUEP', 'purple')}

plt.figure(figsize = (16, 16))
axes = plt.subplot(2, 2, 1)

# Full-sample ROC per MC sample (not plotted; kept for reference).
fpr = []
tpr = []
thresholds = []
roc_auc = []

for j in range(len(MC)):
    fpr_j, tpr_j, thresholds_j = roc_curve(Y_true[j], Y_model[j])
    roc_auc.append(auc(fpr_j, tpr_j))
    fpr.append(fpr_j * 28.61)
    tpr.append(tpr_j)
    thresholds.append(thresholds_j)
    if j not in signal_styles:
        continue

    MC_loss_model_kf=kfold(MC_loss_model[j],kf)
    Y_mc_kf=kfold(Y_mc[j],kf)

    # Common FPR grid. NOTE: 10M points x kf folds x 8 bytes is ~800 MB per sample.
    fpr_mean=np.linspace(0,1,10000000)
    tpr_total=[]
    roc_auc_kf=[]
    for i in range(kf):
        y_true_fold = np.concatenate((Y_mc_kf[i],Y_zb_kf[i]))
        y_score_fold = np.concatenate((MC_loss_model_kf[i],X_test_loss_model_kf[i]))
        fpr_fold, tpr_fold, _ = roc_curve(y_true_fold, y_score_fold)
        roc_auc_kf.append(auc(fpr_fold, tpr_fold))
        interp_tpr=np.interp(fpr_mean, fpr_fold, tpr_fold)
        interp_tpr[0]=0.0
        tpr_total.append(interp_tpr)
    # Convert once; np.mean and np.std would otherwise each convert the list separately.
    tpr_total=np.asarray(tpr_total)
    tpr_mean=np.mean(tpr_total, axis=0)
    tpr_mean[-1]=1.0
    roc_auc_mean=auc(fpr_mean,tpr_mean)
    roc_auc_std=np.std(roc_auc_kf)

    tpr_std=np.std(tpr_total, axis=0)
    tpr_up=np.minimum(tpr_mean+tpr_std,1)
    tpr_down=np.maximum(tpr_mean-tpr_std,0)

    fpr_mean *= 28.61

    signal_label, signal_color = signal_styles[j]
    # Raw string: '\p' in a normal literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python >= 3.12).
    axes.plot(fpr_mean, tpr_mean, linestyle = '-', lw = 1.5, color = signal_color, label = r'%s (AUC = %.3f $\pm$ %0.4f)' % (signal_label, roc_auc_mean, roc_auc_std))
    axes.fill_between(fpr_mean, tpr_down, tpr_up, color=signal_color, alpha=0.1)

axes.plot([0.003, 0.003], [0, 1], linestyle = '--', lw = 1, color = 'black', label = 'Trigger rate = 3 kHz')
axes.set_xlim([0.0001, 28.61])
axes.set_ylim([0.000001, 1])
axes.set_xscale(value = "log")
axes.set_yscale(value = "log")
axes.set_xlabel('Trigger Rate (MHz)',size=17)
axes.set_ylabel('Signal Efficiency',size=17)
axes.set_title('CICADA_v3_v1, signal(Run2) vs ZB(2023)',size=17)
axes.legend(loc='center left', bbox_to_anchor = (0.3, 0.3),fontsize=11)
plt.show()

# Sensitivity table

In [None]:
# Sensitivity table: signal efficiency (%) at the first ROC point whose scaled rate
# exceeds thr (in MHz, so 0.002 = 2 kHz), for the baseline and the student. The teacher
# column is disabled, so table_tpr_teacher and table_tpr_change stay empty.
thr=0.002
table_tpr_baseline = []
table_tpr_teacher = []
table_tpr_student = []
table_tpr_change = []
for i in range(len(fpr_baseline)):
    curves = ((fpr_baseline[i], tpr_baseline[i], table_tpr_baseline),
              (fpr_student[i], tpr_student[i], table_tpr_student))
    for rates, efficiencies, column in curves:
        first_above = next((k for k, rate in enumerate(rates) if rate > thr), None)
        if first_above is not None:
            column.append(efficiencies[first_above] * 100)

MC_names = ['H->tautau','SM HH->4b','TTbar','H->aa->4b','SUEP']
# Entries 3-7 of the per-sample lists are the five signal samples named above.
table_tpr = pd.DataFrame({'Baseline': table_tpr_baseline[3:],
                          'Student': table_tpr_student[3:]},
                         index = MC_names)

pd.set_option('display.max_colwidth', None)
table_tpr