# Import libraries

In [1]:
import numpy as np
from numpy.random import RandomState
import numpy.ma as ma


import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
%matplotlib inline

import h5py
import ot
from numpy.random import Generator, PCG64
from sklearn import metrics
import itertools

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm

In [2]:
# Each signal sample as (short alias used as its dict key below, HDF5 file name).
# Building both lists from the same pairs keeps them index-aligned by construction.
sigAliasList, sigFilenameList = map(list, zip(
    ('sig_A',   'Ato4l_lepFilter_13TeV_filtered.h5'),
    ('sig_h0',  'hToTauTau_13TeV_PU20_filtered.h5'),
    ('sig_hch', 'hChToTauNu_13TeV_PU20_filtered.h5'),
    ('sig_LQ',  'leptoquark_LOWMASS_lepFilter_13TeV_filtered.h5'),
))

In [3]:
#-- Set base directory and data directory path --#
basePath   = './'
dataPath   = 'Data/'

# Full paths to the HDF5 inputs. Built by plain string concatenation, so basePath and
# dataPath must each end with '/'.
bkgPath     = basePath + dataPath + 'background_for_training.h5'
# One path per entry of sigFilenameList, same order (hence index-aligned with sigAliasList).
sigPathList = [basePath + dataPath + fname for fname in sigFilenameList]

# Functions (loaded from `centralFunctions.ipynb`)

In [4]:
# Execute the helper notebook inside this kernel so its definitions become available here.
# NOTE(review): presumably it defines randomDataSample, calcOTDistance and
# SVM_cross_validation, which are used below but not defined in this notebook — confirm.
%run centralFunctions.ipynb

# Loading Data

In [5]:
# Open every input file read-only, keyed by sample name ('bkg' or a signal alias).
# NOTE(review): these handles are never closed in this notebook. The arrays sliced out of
# them in the next cell should be in-memory NumPy copies, so closing afterwards is
# presumably safe — confirm no later cell reads dataDict directly.
dataDict = {'bkg': h5py.File(bkgPath, 'r')}
for alias, sigPath in zip(sigAliasList, sigPathList):
    dataDict[alias] = h5py.File(sigPath, 'r')

In [6]:
# Read the 'Particles' dataset of each sample, keeping only the first three entries of the
# last (feature) axis. NOTE(review): column 0 is summed as pT further down; presumably the
# other two are eta and phi — confirm against the dataset description.
bkg_data = dataDict['bkg']['Particles'][:, :, :3]
sig_data = {
    alias: dataDict[alias]['Particles'][:, :, :3]
    for alias in sigAliasList
}

# Low $p_T$ range

In [7]:
# Sampling configuration: number of events requested from randomDataSample per sample and
# per pT bin, and the seeded generator handed to it so the draws are reproducible.
nEvents = 500
random_state = Generator(PCG64(123))

# Options forwarded unchanged to calcOTDistance. NOTE(review): the meaning of each flag is
# defined in centralFunctions.ipynb (names suggest pT normalisation, balanced OT,
# dropping zero-padded particles, per-event OT) — confirm there.
OTSCHEME = dict(
    normPT=True,
    balanced=True,
    noZeroPad=False,
    individualOT=False,
)

In [8]:
# Total scalar pT per event: sum of feature column 0 over the particle axis.
# NOTE(review): assumes column 0 is particle pT and zero-padded slots contribute 0 — confirm.
total_event_pT = {'bkg': np.sum(bkg_data[:, :, 0], axis=1)}
for alias in sigAliasList:
    total_event_pT[alias] = np.sum(sig_data[alias][:, :, 0], axis=1)

# Bin edges in total event pT; each consecutive pair defines one bin.
pTrange = [0,50,100,150,200,500,1000]
# Not used in this part of the notebook; presumably consumed/filled by later cells.
neighbor_list = list(range(5, 400,10))
avg_aucs = []
std_aucs = []
avg_ks = []
std_ks = []


def _pT_bin_mask(pT, lower, upper, closed_upper):
    """Boolean mask selecting events with ``lower <= pT < upper``.

    If ``closed_upper`` is true the upper edge is inclusive (``pT <= upper``); this is
    used for the last bin only, mirroring the numpy.histogram convention, so that
    interior edges never place one event in two bins.
    """
    upper_ok = (pT <= upper) if closed_upper else (pT < upper)
    return (pT >= lower) & upper_ok


# filtered_events[sample][str(upper_edge)] -> events drawn by randomDataSample from that
# pT bin (nEvents requested per sample and bin).
filtered_events = {'bkg': {}}
for alias in sigAliasList:
    filtered_events[alias] = {}

for i in range(len(pTrange) - 1):
    lower_bound = pTrange[i]
    upper_bound = pTrange[i+1]
    binKey = str(upper_bound)
    closedUpper = (i == len(pTrange) - 2)  # only the last bin includes its upper edge

    # Draw order (bkg first, then each signal in sigAliasList order) determines how the
    # shared random_state is consumed; keep it stable for reproducibility.
    mask = _pT_bin_mask(total_event_pT['bkg'], lower_bound, upper_bound, closedUpper)
    filtered_events['bkg'][binKey] = randomDataSample(bkg_data[mask], nEvents, random_state)

    for alias in sigAliasList:
        mask = _pT_bin_mask(total_event_pT[alias], lower_bound, upper_bound, closedUpper)
        filtered_events[alias][binKey] = randomDataSample(sig_data[alias][mask], nEvents, random_state)

# Free the full arrays; only the per-bin samples are used from here on.
# NOTE: this makes the cell non-rerunnable on its own — re-run the data-loading cells (and
# the config cell, to reset random_state) before running it again.
del sig_data
del bkg_data

In [9]:
# Hyper-parameter grids scanned by SVM_cross_validation.
gamma_list = [0.001, 0.01, 0.1, 1, 10, 20]
C_list = [0.001,0.01,1,10,100]

# Labels before shuffling: background first (0), then signal (1), matching the
# concatenation order below. Loop-invariant, so built once.
unshuffled_labels = np.asarray([0] * nEvents + [1] * nEvents)

# svm_results[(str(upper_edge), alias)] -> the three lists returned by SVM_cross_validation,
# kept so later cells can use them instead of re-parsing printed output.
svm_results = {}

for i in range(0, len(pTrange)-1):
    binKey = str(pTrange[i+1])
    binLabel = f"pT bin {pTrange[i]}-{pTrange[i+1]}"
    # Seeds NumPy's *global* RNG once per bin: fixes the shuffle below and presumably any
    # global-RNG use inside calcOTDistance / SVM_cross_validation, so it is kept global.
    np.random.seed(i)
    permutation = np.random.permutation(nEvents*2)
    for alias in sigAliasList:
        event_list = np.concatenate((filtered_events['bkg'][binKey], filtered_events[alias][binKey]))
        # An under- or over-filled bin would otherwise raise an opaque IndexError or
        # silently misalign events and labels.
        assert len(event_list) == 2 * nEvents, (
            f"{binLabel}, {alias}: expected {2 * nEvents} events, got {len(event_list)}")
        event_list = event_list[permutation]
        event_labels = unshuffled_labels[permutation]

        # Event-vs-event OT distances; the progress bars in the output count
        # (2*nEvents)**2 iterations, i.e. every pair is evaluated.
        distance_matrix = calcOTDistance(event_list, event_list, OTSCHEME, '2D', Matrix = True)

        auc_list, best_gamma_list, best_C_list = SVM_cross_validation(distance_matrix, event_labels, gamma_list, C_list)
        svm_results[(binKey, alias)] = {'auc': auc_list, 'gamma': best_gamma_list, 'C': best_C_list}

        print(f"{binLabel}, {alias}")
        print(f"  AUC   mean/std: {np.mean(auc_list)} {np.std(auc_list)}")
        print(f"  gamma mean/std: {np.mean(best_gamma_list)} {np.std(best_gamma_list)}")
        print(f"  C     mean/std: {np.mean(best_C_list)} {np.std(best_C_list)}")

100%|██████████| 1000000/1000000 [01:10<00:00, 14162.20it/s]


0.6621256209724191 0.037137544535874104
14.2 7.652450587883596
6.4 4.409081537009721


100%|██████████| 1000000/1000000 [01:10<00:00, 14155.73it/s]


0.5007336715370221 0.02489754407738444
2.8 3.6000000000000005
22.6 38.85665966086123


100%|██████████| 1000000/1000000 [01:10<00:00, 14131.04it/s]


0.5976552858748009 0.019596376159695575
8.2 3.6
8.2 3.6


100%|██████████| 1000000/1000000 [01:10<00:00, 14192.30it/s]


0.7747856747897253 0.023324677531647332
18.0 4.0
2.8 3.6


100%|██████████| 1000000/1000000 [01:10<00:00, 14088.93it/s]


0.5694413150003783 0.06046903975398483
2.8 3.6
2.8 3.6


100%|██████████| 1000000/1000000 [01:12<00:00, 13878.17it/s]


0.5362113098080017 0.023034251189970302
0.8001999999999999 0.3996
20.8 39.6


100%|██████████| 1000000/1000000 [01:10<00:00, 14131.05it/s]


0.5123997826402175 0.0309946558054473
1.0 0.0
2.8 3.6000000000000005


100%|██████████| 1000000/1000000 [01:10<00:00, 14105.42it/s]


0.5395825138119626 0.03411848684580315
1.0 0.0
1.0 0.0


100%|██████████| 1000000/1000000 [01:10<00:00, 14201.19it/s]


0.5252004975951261 0.005512542922361326
0.82 0.36
2.8 3.6000000000000005


100%|██████████| 1000000/1000000 [01:10<00:00, 14099.92it/s]


0.5076017426530938 0.003802591281631008
0.42219999999999996 0.47304054794488815
24.4 38.01368174749717


100%|██████████| 1000000/1000000 [01:11<00:00, 14044.41it/s]


0.5026316798814857 0.007181715052068852
0.4204 0.474620100712138
42.4 47.144883073351664


100%|██████████| 1000000/1000000 [01:10<00:00, 14098.22it/s]


0.5047179702762185 0.005216844368939403
1.0 0.0
1.0 0.0


100%|██████████| 1000000/1000000 [01:09<00:00, 14294.48it/s]


0.5055165746396545 0.007924297651358072
1.0 0.0
22.6 38.856659660861226


100%|██████████| 1000000/1000000 [01:10<00:00, 14140.98it/s]


0.49992967484023315 0.0032433056773941865
0.64 0.440908153700972
4.6 4.409081537009721


100%|██████████| 1000000/1000000 [01:10<00:00, 14177.18it/s]


0.5019704911667637 0.00241339775315438
0.45999999999999996 0.440908153700972
1.0 0.0


100%|██████████| 1000000/1000000 [01:10<00:00, 14147.41it/s]


0.49971107984621677 0.005656061701097151
0.82 0.36
22.6 38.85665966086123


100%|██████████| 1000000/1000000 [01:11<00:00, 13890.02it/s]


0.5118438968397124 0.005940791105905887
1.0 0.0
1.0 0.0


100%|██████████| 1000000/1000000 [01:40<00:00, 9914.02it/s]


0.5069122387395792 0.00811049924128961
0.6220000000000001 0.4638275541621045
2.8 3.6


100%|██████████| 1000000/1000000 [21:15<00:00, 783.90it/s] 


0.501900718727663 0.007459004728839429
0.44020000000000004 0.4585020828742221
20.8 39.6


100%|██████████| 1000000/1000000 [18:36<00:00, 895.95it/s] 


0.5069221397296783 0.003967586628124781
0.4600000000000001 0.440908153700972
1.0 0.0


100%|██████████| 1000000/1000000 [01:10<00:00, 14220.01it/s]


0.5593364529387064 0.020524249597955832
1.0 0.0
1.0 0.0


100%|██████████| 1000000/1000000 [04:06<00:00, 4058.05it/s]


0.5007211080732816 0.007862718071649155
0.6202 0.4662104245938737
20.8 39.6


100%|██████████| 1000000/1000000 [17:43<00:00, 940.00it/s] 


0.5018637309292648 0.003901981029775161
0.82 0.36
2.8 3.6000000000000005


100%|██████████| 1000000/1000000 [16:10<00:00, 1030.39it/s]


0.5085868982936543 0.0034610207976638716
0.6202 0.4662104245938737
24.4 38.01368174749717
