In [26]:
import os
import numpy as np
import pandas as pd
from scipy.stats import wasserstein_distance
from itertools import combinations
from datetime import datetime
from packages.constants import VAR_INFOS_DTYPES
from packages.utils import get_electron_label, get_jet_label, get_logger

## Parameters

In [16]:
logger = get_logger('wass_distance', file=False)
basepath = os.path.join('..', '..')
datapath = os.path.join(basepath, 'data')
var_info_path = os.path.join(datapath, 'var_infos.csv')
collision_path = os.path.join(datapath, 
    'ided_data17_13TeV.AllPeriods.sgn.probes_lhvloose_EGAM1.bkg.vprobes_vlhvloose_EGAM7.GRL_v97.25bins.parquet')
#   'ided_data17_13TeV.AllPeriods.sgn.probes_lhvloose_EGAM1.bkg.vprobes_vlhvloose_EGAM7.GRL_v97.25bins.parquet_et4_eta4.parquet')
boosted_path = os.path.join(datapath, 
    'ided_mc16_13TeV.302236_309995_341330.sgn.boosted_probes.WZ_llqq_plus_radion_ZZ_llqq_plus_ggH3000.merge.25bins.v2.parquet')

## Loading data

In [12]:
basepath = os.path.join('..', '..')
datapath = os.path.join(basepath, 'x')
var_infos = pd.read_csv(var_info_path, index_col=0, dtype=VAR_INFOS_DTYPES)
is_ss = var_infos['type'] == 'shower_shape'
shower_shapes_names =  var_infos.loc[is_ss, 'name'].to_list()
shower_shapes_cols = var_infos.loc[is_ss & (~var_infos['l2calo'].isnull()), 'l2calo'].to_list()
shower_shapes_cols += var_infos.loc[is_ss & (var_infos['l2calo'].isnull()), 'offline'].to_list()
shower_shapes = dict(zip(shower_shapes_names, shower_shapes_cols))
print(f'Selected shower shapes {", " .join(shower_shapes_cols)}')
var_infos

Selected shower shapes trig_L2_cl_reta, trig_L2_cl_eratio, trig_L2_cl_f1, trig_L2_cl_f3, trig_L2_cl_wstot, trig_L2_cl_weta2, el_rhad, el_rhad1, el_rphi


Unnamed: 0,name,label,type,lower_lim,upper_lim,l2calo,offline,TaP,description
0,et,$E_T$,var,0.0,inf,trig_L2_cl_et,el_et,,transverse particle energy on the calorimeter
1,eta,$\eta$,var,-2.5,2.5,trig_L2_cl_eta,el_eta,,pseudorapidity
2,reta,$R_{\eta}$,shower_shape,0.0,1.0,trig_L2_cl_reta,el_reta,,Ratio of the energy in 3x7 cells over the ener...
3,eratio,$E_{ratio}$,shower_shape,0.0,1.0,trig_L2_cl_eratio,el_eratio,,Ratio of the energy difference between the max...
4,f1,$f_1$,shower_shape,0.0,1.0,trig_L2_cl_f1,el_f1,,Ratio of the energy in the first layer to the ...
5,ehad1,$E_{had1}$,unidentified,-inf,inf,trig_L2_cl_ehad1,el_ehad1,,unidentified
6,f3,$f_3$,shower_shape,0.0,1.0,trig_L2_cl_f3,el_f3,,Ratio of the energy in the third layer to the ...
7,wstot,$\omega_{stot}$,shower_shape,0.0,inf,trig_L2_cl_wstot,el_wstot,,Shower width er > 150 GeV only on EM1
8,weta2,$\omega_{\eta 2}$,shower_shape,0.0,inf,trig_L2_cl_weta2,el_weta2,,Lateral shower width on EM2
9,e2tsts1,e2tsts1,unidentified,-inf,inf,trig_L2_cl_e2tsts1,el_e2tsts1,,unidentified


In [13]:
boosted_data = pd.read_parquet(boosted_path, columns=shower_shapes_cols)
boosted_data.tail()

Unnamed: 0,trig_L2_cl_reta,trig_L2_cl_eratio,trig_L2_cl_f1,trig_L2_cl_f3,trig_L2_cl_wstot,trig_L2_cl_weta2,el_rhad,el_rhad1,el_rphi
78909,0.943132,98.999992,0.135097,0.003572,-9999.0,0.012791,0.000759,0.001316,0.948161
78910,0.935138,0.987012,0.090264,0.006514,1.377259,0.011028,0.000289,0.000644,0.950956
78911,0.947895,98.999992,0.151062,0.00345,-9999.0,0.011621,-0.002085,-0.000811,0.950298
78912,0.941024,0.98141,0.096922,0.00691,1.735974,0.011279,0.001961,0.001739,0.957002
78913,0.941024,0.98141,0.096922,0.00691,1.735974,0.011279,0.000985,0.00055,0.94525


In [5]:
start_time = datetime.now()
print(f'Start: {start_time}')
add_cols = ['target', 'el_lhmedium', 'el_lhvloose']
collision_data = pd.read_parquet(collision_path, columns=shower_shapes_cols + add_cols)
read_time = datetime.now()
print(f'Time to read x {read_time-start_time}')
print(collision_data.shape)
jet_label = get_jet_label(collision_data, 'el_lhvloose')
el_label = get_electron_label(collision_data, 'el_lhmedium')
print(f'There are {jet_label.sum()} jets and {el_label.sum()} electrons')
label_time = datetime.now()
print(f'Time to labeling {label_time-read_time}')
print(f'Do electrons and jet have intersections? {(jet_label & el_label).any()}')
collision_data.drop(add_cols, axis=1, inplace=True)
el_data = collision_data.loc[el_label]
jet_data = collision_data.loc[jet_label]
del collision_data
drop_time = datetime.now()
print(f'Time to drop {drop_time-label_time}')
el_data.head()

Start: 2022-10-25 22:47:08.812884
Time to read x 0:00:14.341038
(43311283, 12)
There are 10906928 jets and 28955057 electrons
Time to labeling 0:00:00.230347
Do electrons and jet have intersections? False
Time to drop 0:00:01.565391


Unnamed: 0,trig_L2_cl_reta,trig_L2_cl_eratio,trig_L2_cl_f1,trig_L2_cl_f3,trig_L2_cl_wstot,trig_L2_cl_weta2,el_rhad,el_rhad1,el_rphi
0,0.972962,0.948686,0.388993,0.008554,1.520756,0.009378,-0.014604,-0.010184,0.954679
1,0.953338,0.948622,0.321716,0.003106,1.440343,0.009217,0.001482,0.001126,0.709704
2,1.013074,0.883187,0.487872,0.002307,1.51531,0.008863,-0.016014,-0.007992,0.825018
3,1.028176,0.926765,0.426812,-0.000522,2.569318,0.010547,0.008715,0.008617,1.005868
4,1.061034,0.974475,0.404273,-0.000863,2.09733,0.009515,0.034444,0.035376,0.959343


In [6]:
data = {
    'boosted': boosted_data,
    'el': el_data,
    'jet': jet_data
}

## Computing distances

In [37]:
ss_filters = {
    'f3': lambda x: x,
    'weta2': lambda x: x[x <= 99],
    'reta': lambda x: x,
    'wstot': lambda x: x[x != -9999],
    'eratio': lambda x: x[x < 98],
    'f1': lambda x: x,
    'rphi': lambda x: x[x.between(-0.5, 1.5, inclusive='both')],
    'rhad': lambda x: x,
    'rhad1': lambda x:x
}

In [38]:
data_combinations = combinations(data.keys(), 2)
combinations_str = [f'{left}_{right}' for left, right in combinations(data.keys(), 2)]
wass_distances = pd.DataFrame(index=shower_shapes, columns=combinations_str)
for ss_name, ss_col in shower_shapes.items():
    if ss_name != 'weta2':
        continue
    for left, right in combinations(data.keys(), 2):
        logger.info(f'{ss_name}: computing wasserstein_distance({left}, {right})')
        filter = ss_filters[ss_name]
        left_data = ss_filters[ss_name](data[left][ss_col])
        right_data = ss_filters[ss_name](data[right][ss_col])
        wass_distances.loc[ss_name, f'{left}_{right}'] = wasserstein_distance(left_data, right_data)
wass_distances.to_csv(os.path.join(datapath, 'wass_distances_v2.csv'))
wass_distances

2022-10-25 23:21:04,626 - weta2: computing wasserstein_distance(boosted, el)
2022-10-25 23:21:15,729 - weta2: computing wasserstein_distance(boosted, jet)
2022-10-25 23:21:19,534 - weta2: computing wasserstein_distance(el, jet)


Unnamed: 0,boosted_el,boosted_jet,el_jet
reta,0.012518,0.116747,0.106054
eratio,0.037337,0.540616,0.505021
f1,0.123779,0.015785,0.110229
f3,0.002957,0.013735,0.016679
wstot,0.208389,2.728402,2.596661
weta2,0.252399,0.005956,0.25604
rhad,0.008923,0.346042,0.34952
rhad1,0.00522,0.189539,0.191945
rphi,0.021686,0.12243,0.102108


In [40]:
ratios = dict(
ratio1=lambda x: (x['boosted_jet']-x['boosted_el'])/x['boosted_el'],
ratio2=lambda x: (x['boosted_jet']-x['boosted_el'])/x['el_jet'],
ratio3=lambda x: (x['boosted_jet']-x['boosted_el']))

In [42]:
for ratio, get_ratio in ratios.items():
    wass_distances[ratio] = wass_distances.apply(get_ratio, axis=1)
wass_distances.sort_values(by='ratio1', inplace=True, ascending=False)
wass_distances

Unnamed: 0,boosted_el,boosted_jet,el_jet,ratio1,ratio2,ratio3
rhad,0.008923,0.346042,0.34952,37.780687,0.964518,0.337119
rhad1,0.00522,0.189539,0.191945,35.308172,0.960268,0.184318
eratio,0.037337,0.540616,0.505021,13.479405,0.99655,0.503279
wstot,0.208389,2.728402,2.596661,12.092818,0.970482,2.520012
reta,0.012518,0.116747,0.106054,8.32635,0.982787,0.104229
rphi,0.021686,0.12243,0.102108,4.645591,0.986644,0.100744
f3,0.002957,0.013735,0.016679,3.644717,0.646163,0.010778
f1,0.123779,0.015785,0.110229,-0.872476,-0.979724,-0.107994
weta2,0.252399,0.005956,0.25604,-0.976402,-0.962513,-0.246442


In [43]:
wass_distances.sort_values(by='ratio2', ascending=False)

Unnamed: 0,boosted_el,boosted_jet,el_jet,ratio1,ratio2,ratio3
eratio,0.037337,0.540616,0.505021,13.479405,0.99655,0.503279
rphi,0.021686,0.12243,0.102108,4.645591,0.986644,0.100744
reta,0.012518,0.116747,0.106054,8.32635,0.982787,0.104229
wstot,0.208389,2.728402,2.596661,12.092818,0.970482,2.520012
rhad,0.008923,0.346042,0.34952,37.780687,0.964518,0.337119
rhad1,0.00522,0.189539,0.191945,35.308172,0.960268,0.184318
f3,0.002957,0.013735,0.016679,3.644717,0.646163,0.010778
weta2,0.252399,0.005956,0.25604,-0.976402,-0.962513,-0.246442
f1,0.123779,0.015785,0.110229,-0.872476,-0.979724,-0.107994
