# Import libraries

In [1]:
import numpy as np
from numpy.random import RandomState
import numpy.ma as ma

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
%matplotlib inline

import h5py
import ot
from numpy.random import Generator, PCG64
from sklearn import metrics
import itertools

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm

In [2]:
# Each signal sample as an (alias, HDF5 filename) pair. The two public lists
# below are parallel: position i of one goes with position i of the other.
_signalSamples = (
    ('sig_A',   'Ato4l_lepFilter_13TeV_filtered.h5'),
    ('sig_h0',  'hToTauTau_13TeV_PU20_filtered.h5'),
    ('sig_hch', 'hChToTauNu_13TeV_PU20_filtered.h5'),
    ('sig_LQ',  'leptoquark_LOWMASS_lepFilter_13TeV_filtered.h5'),
)
sigAliasList    = [alias for alias, _ in _signalSamples]
sigFilenameList = [filename for _, filename in _signalSamples]

In [3]:
#-- Base directory and data directory, relative to the working directory --#
basePath = './'
dataPath = 'Data/'

# Path of the background training file, then one path per entry of
# sigFilenameList (same order).
bkgPath     = basePath + dataPath + 'background_for_training.h5'
sigPathList = [basePath + dataPath + filename for filename in sigFilenameList]

# Functions

In [4]:
# Execute the companion notebook centralFunctions.ipynb from this kernel.
# NOTE(review): randomDataSample, calcOTDistance and kNN_cross_validation are
# called further down but are not defined in this notebook; presumably they
# are defined by centralFunctions.ipynb. Confirm before reordering cells.
%run centralFunctions.ipynb

# Loading Data

In [5]:
# Open every input file read-only, keyed by sample name: 'bkg' first, then one
# entry per signal alias, each paired with the path at the same position in
# sigPathList.
# The handles are not closed in this cell.
dataDict = {'bkg': h5py.File(bkgPath, 'r')}

for alias, sigPath in zip(sigAliasList, sigPathList):
    dataDict[alias] = h5py.File(sigPath, 'r')

In [6]:
# From the 'Particles' dataset of each sample, keep the first three entries
# along the last axis: one array for the background, one per signal alias.
# NOTE(review): index 0 of that axis is the one summed as pT further down;
# what indices 1 and 2 hold is not visible here. Confirm.
bkg_data = dataDict['bkg']['Particles'][:, :, :3]
sig_data = {
    alias: dataDict[alias]['Particles'][:, :, :3]
    for alias in sigAliasList
}

# Low $p_T$ range

In [7]:
# Sample size requested from randomDataSample for each class in each pT bin.
nEvents = 100

# Seeded generator handed to randomDataSample.
random_state = Generator(PCG64(123))

# Option flags passed to calcOTDistance.
OTSCHEME = {
    'normPT':       True,
    'balanced':     True,
    'noZeroPad':    False,
    'individualOT': False,
}

In [8]:
# Total pT of each event per sample: feature index 0 summed along axis 1.
# (assumes axis 0 indexes events and axis 1 the particles of an event; TODO confirm)
total_event_pT = {}

total_event_pT['bkg'] = np.sum(bkg_data[:, :, 0], axis=1)

for alias in sigAliasList:
    total_event_pT[alias] = np.sum(sig_data[alias][:,:,0], axis=1)


def _pt_bin_mask(event_pT, lower_bound, upper_bound, include_upper):
    """Return a boolean mask selecting the events of one total-pT bin.

    Bins are half-open, [lower_bound, upper_bound), so an event sitting exactly
    on an edge shared by two bins is assigned to one bin only. When
    include_upper is True (used for the last bin) the upper edge is kept too,
    so the top of the scanned range is not dropped.

    Parameters
    ----------
    event_pT : array of per-event total pT values.
    lower_bound, upper_bound : edges of the bin.
    include_upper : whether upper_bound itself belongs to the bin.
    """
    above_lower = event_pT >= lower_bound
    if include_upper:
        below_upper = event_pT <= upper_bound
    else:
        below_upper = event_pT < upper_bound
    return above_lower & below_upper


# Edges of the total-pT bins; consecutive pairs define the bins scanned below.
pTrange = [0,50,100,150,200,500,1000]
# Candidate neighbour counts handed to kNN_cross_validation.
neighbor_list = list(range(5, 400,10))

# Flat result lists, filled bin-major: for every pT bin, one entry per signal
# in sigAliasList order.
avg_aucs = []
std_aucs = []
avg_ks = []
std_ks = []

nBins = len(pTrange) - 1
for i in range(0, nBins):
    lower_bound = pTrange[i]
    upper_bound = pTrange[i+1]
    is_last_bin = (i == nBins - 1)

    filtered_events = {}
    mask = _pt_bin_mask(total_event_pT['bkg'], lower_bound, upper_bound, is_last_bin)

    # One background sample per bin, shared by all signals of that bin.
    filtered_events['bkg'] = randomDataSample(bkg_data[mask],nEvents,random_state)

    # Re-seed the legacy global RNG with the bin index; the resulting shuffle
    # is reused for every signal in this bin.
    # NOTE(review): this also resets the global np.random state seen by any
    # helper called below that draws from it.
    np.random.seed(i)
    permutation = np.random.permutation(nEvents*2)

    for alias in sigAliasList:
        mask = _pt_bin_mask(total_event_pT[alias], lower_bound, upper_bound, is_last_bin)
        filtered_events[alias] = randomDataSample(sig_data[alias][mask],nEvents,random_state)

        # Background first (label 0), then signal (label 1); events and labels
        # are shuffled with the same permutation so they stay aligned.
        event_list = np.concatenate((filtered_events['bkg'],filtered_events[alias]))
        event_labels = np.asarray([0] * nEvents + [1] * nEvents)
        event_list = event_list[permutation]
        event_labels = event_labels[permutation]

        # Distances of the shuffled events against themselves.
        distance_matrix = calcOTDistance(event_list, event_list, OTSCHEME, '2D', Matrix = True)

        avg_auc, std_auc, avg_k, std_k = kNN_cross_validation(distance_matrix, event_labels, neighbor_list, k_fold=5)
        print(avg_auc, std_auc, avg_k, std_k)
        avg_aucs.append(avg_auc)
        std_aucs.append(std_auc)
        avg_ks.append(avg_k)
        std_ks.append(std_k)

100%|██████████| 4000000/4000000 [04:46<00:00, 13971.56it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.27it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.20it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.05it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.93it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 69.70it/s]


0.6193241561554373 0.013909554559565948 25.0 12.649110640673518


100%|██████████| 4000000/4000000 [04:46<00:00, 13944.60it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.13it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.73it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.37it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.88it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.99it/s]


0.5560362022323223 0.026013491872497616 85.0 87.40709353364863


100%|██████████| 4000000/4000000 [04:47<00:00, 13908.51it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.43it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.06it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.78it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.17it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 73.63it/s] 


0.5752272736176252 0.037555171735308014 115.0 127.59310326189265


100%|██████████| 4000000/4000000 [04:47<00:00, 13928.89it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.09it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.54it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.89it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.17it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.09it/s]


0.6840947062389287 0.03150385066000139 25.0 40.0


100%|██████████| 4000000/4000000 [04:45<00:00, 13990.30it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.64it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.99it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.11it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.44it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.89it/s]


0.797380753997553 0.008993921054521814 43.0 36.55133376499413


100%|██████████| 4000000/4000000 [04:46<00:00, 13967.24it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 70.35it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 73.71it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.99it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.57it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.61it/s]


0.6761395554738507 0.035688653156475125 57.0 49.95998398718719


100%|██████████| 4000000/4000000 [04:47<00:00, 13914.39it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.91it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.88it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.70it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.59it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.86it/s] 


0.607118567366921 0.026939186162531433 75.0 45.16635916254486


100%|██████████| 4000000/4000000 [04:47<00:00, 13931.42it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 69.75it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.37it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.41it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.93it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.44it/s] 


0.7468665866068318 0.021490142570251385 65.0 30.331501776206203


100%|██████████| 4000000/4000000 [04:44<00:00, 14055.29it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.04it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.45it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.01it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 65.33it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.93it/s]


0.7011452113981134 0.005785647854299646 15.0 8.94427190999916


100%|██████████| 4000000/4000000 [04:46<00:00, 13951.53it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.15it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.82it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 68.22it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.86it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.19it/s] 


0.5667354772174112 0.04953517731908581 89.0 98.91410415102591


100%|██████████| 4000000/4000000 [04:46<00:00, 13937.83it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.54it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.73it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.73it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.96it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.79it/s]


0.6015605971417531 0.03723336061553597 31.0 18.547236990991408


100%|██████████| 4000000/4000000 [04:48<00:00, 13856.90it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 72.65it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.02it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.92it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.93it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.00it/s]


0.6883180348051865 0.015572400289487753 63.0 47.07440918375928


100%|██████████| 4000000/4000000 [04:44<00:00, 14038.20it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.82it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.04it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.70it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.44it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.89it/s] 


0.6251299810860171 0.020683095471606176 59.0 24.979991993593593


100%|██████████| 4000000/4000000 [04:45<00:00, 13992.27it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.32it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.65it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.67it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.09it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.09it/s]


0.5761864516416402 0.02572489625664325 51.0 29.393876913398138


100%|██████████| 4000000/4000000 [04:46<00:00, 13985.04it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.91it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.47it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.96it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.86it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 81.40it/s] 


0.6665967908890511 0.021584085318690237 51.0 28.705400188814647


100%|██████████| 4000000/4000000 [04:45<00:00, 14010.59it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.40it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.88it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.90it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.25it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 82.46it/s] 


0.6981814488486944 0.019362046157130636 65.0 15.491933384829668


100%|██████████| 4000000/4000000 [04:44<00:00, 14052.55it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.66it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.68it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.28it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.42it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.10it/s] 


0.6258593862764383 0.024167621160884223 69.0 67.70524351924303


100%|██████████| 4000000/4000000 [04:44<00:00, 14053.22it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.94it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.34it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.60it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.76it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.18it/s] 


0.5841808994695518 0.006719364053971755 59.0 38.262252939417984


100%|██████████| 4000000/4000000 [04:43<00:00, 14111.99it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 81.35it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.52it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 81.59it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.52it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.63it/s] 


0.6885509588655917 0.020427488648409705 63.0 42.14261501141095


100%|██████████| 4000000/4000000 [04:44<00:00, 14072.39it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.53it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 81.35it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.59it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 80.45it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.86it/s] 


0.6166094819906541 0.020159541926277634 101.0 98.10198774744578


100%|██████████| 4000000/4000000 [04:47<00:00, 13908.07it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.13it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.30it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.18it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.28it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.50it/s] 


0.798913933102407 0.02230500501928155 29.0 10.198039027185569


100%|██████████| 4000000/4000000 [04:46<00:00, 13969.70it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 75.15it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.91it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.70it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.50it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.47it/s] 


0.5829378691399614 0.021238964288079334 83.0 63.686733312362634


100%|██████████| 4000000/4000000 [04:45<00:00, 14023.70it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.57it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.22it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 79.06it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 74.84it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.12it/s] 


0.6202817253317441 0.0341775808408001 81.0 74.45804187594514


100%|██████████| 4000000/4000000 [04:46<00:00, 13962.60it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.04it/s]
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.88it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 78.73it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 77.39it/s] 
Fitting Models: 100%|██████████| 40/40 [00:00<00:00, 76.53it/s] 


0.5284963612145445 0.015364811569643429 57.0 80.34923770640266


In [9]:
# Show the flat lists of mean AUCs and their standard deviations, one list per line.
for auc_values in (avg_aucs, std_aucs):
    print(auc_values)

[0.6193241561554373, 0.5560362022323223, 0.5752272736176252, 0.6840947062389287, 0.797380753997553, 0.6761395554738507, 0.607118567366921, 0.7468665866068318, 0.7011452113981134, 0.5667354772174112, 0.6015605971417531, 0.6883180348051865, 0.6251299810860171, 0.5761864516416402, 0.6665967908890511, 0.6981814488486944, 0.6258593862764383, 0.5841808994695518, 0.6885509588655917, 0.6166094819906541, 0.798913933102407, 0.5829378691399614, 0.6202817253317441, 0.5284963612145445]
[0.013909554559565948, 0.026013491872497616, 0.037555171735308014, 0.03150385066000139, 0.008993921054521814, 0.035688653156475125, 0.026939186162531433, 0.021490142570251385, 0.005785647854299646, 0.04953517731908581, 0.03723336061553597, 0.015572400289487753, 0.020683095471606176, 0.02572489625664325, 0.021584085318690237, 0.019362046157130636, 0.024167621160884223, 0.006719364053971755, 0.020427488648409705, 0.020159541926277634, 0.02230500501928155, 0.021238964288079334, 0.0341775808408001, 0.015364811569643429]


In [10]:
# avg_aucs and std_aucs are flat lists holding one entry per signal for each
# pT bin, so cutting them into chunks of len(sigAliasList) gives one row per
# pT bin. The chunk size is derived from the signal list rather than
# hard-coded, so the grouping stays aligned if signals are added or removed.
nSignals = len(sigAliasList)
grouped_data_1 = [avg_aucs[i:i+nSignals] for i in range(0, len(avg_aucs), nSignals)]
grouped_data_2 = [std_aucs[i:i+nSignals] for i in range(0, len(std_aucs), nSignals)]
print(grouped_data_1)
print(grouped_data_2)

[[0.6193241561554373, 0.5560362022323223, 0.5752272736176252, 0.6840947062389287], [0.797380753997553, 0.6761395554738507, 0.607118567366921, 0.7468665866068318], [0.7011452113981134, 0.5667354772174112, 0.6015605971417531, 0.6883180348051865], [0.6251299810860171, 0.5761864516416402, 0.6665967908890511, 0.6981814488486944], [0.6258593862764383, 0.5841808994695518, 0.6885509588655917, 0.6166094819906541], [0.798913933102407, 0.5829378691399614, 0.6202817253317441, 0.5284963612145445]]
[[0.013909554559565948, 0.026013491872497616, 0.037555171735308014, 0.03150385066000139], [0.008993921054521814, 0.035688653156475125, 0.026939186162531433, 0.021490142570251385], [0.005785647854299646, 0.04953517731908581, 0.03723336061553597, 0.015572400289487753], [0.020683095471606176, 0.02572489625664325, 0.021584085318690237, 0.019362046157130636], [0.024167621160884223, 0.006719364053971755, 0.020427488648409705, 0.020159541926277634], [0.02230500501928155, 0.021238964288079334, 0.0341775808408001,