In [None]:
import argparse
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import uproot 
import awkward as ak
from pathlib import Path

from typing import Dict, List 
import re
import pickle
from tqdm import tqdm

In [None]:
# Input pickle: a mapping (presumably keyed by JZ slice) of jet DataFrames.
# pickle can execute code on load, so only read files from a trusted location.
file = '/global/cfs/projectdirs/atlas/hrzhao/qgcal/BDT_EPEML/pkls_etalabel/all_JZs_format1.pkl'
with open(file, 'rb') as pkl_handle:
    all_sample = pd.read_pickle(pkl_handle)

In [None]:
# Stack every per-slice DataFrame (in the mapping's iteration order) into one jet table.
all_JZs_list = list(all_sample.values())
all_jets = pd.concat(all_JZs_list)

In [None]:
# The 'event' column is not carried into any of the saved samples.
all_jets = all_jets.drop(columns='event')

## Physics-weighted jet $p_T$ spectrum

In [None]:
# Physics-weighted jet pT spectrum. `bin_contents` (sum of total_weight per bin)
# and `pt_edges` are reused below to build the pT-flattening factors.
fig, ax = plt.subplots()
pt_edges = np.linspace(500, 2000, 61)  # 60 uniform bins between 500 and 2000
bin_contents, bin_edges, _ = ax.hist(all_jets['jet_pt'], bins=pt_edges, weights=all_jets['total_weight'])
ax.set_yscale('log')
ax.set(xlabel='jet_pt [GeV]', ylabel='sum of total_weight', title='Physically weighted jet pT')
plt.show()


## Flatten the jet $p_T$ spectrum

### Compute per-bin flattening weights and plot the flattened spectrum

In [None]:
# Per-bin factor 1 / (sum of weights in the bin), so that after reweighting every
# populated pT bin integrates to 1. Bins whose weights sum to exactly 0 get a factor
# of 0 instead of inf (inf would turn zero-weight jets into NaN weights).
flat_weight_factor = np.divide(
    1.0,
    bin_contents,
    out=np.zeros_like(bin_contents, dtype=float),
    where=bin_contents != 0,
)
# np.digitize bin number per jet: 1..len(pt_edges)-1 inside the range,
# 0 below pt_edges[0], len(pt_edges) at or above pt_edges[-1].
pt_binned_sample_alljets_idx = np.digitize(all_jets['jet_pt'], pt_edges)

In [None]:
# Start the flat-pT weight as an independent copy of the physics weight;
# it is rescaled per pT bin further down.
all_jets['flatpt_weight'] = all_jets['total_weight'].copy()

In [None]:
# Multiply each jet's flatpt_weight by the flattening factor of its pT bin (vectorised).
# np.digitize is 1-based for in-range jets. Jets below the first edge (index 0) or at/above
# the last edge (index len(pt_edges)) have no factor: they are masked out and keep
# total_weight, instead of indexing past the end of flat_weight_factor.
flat_bin_idx = pt_binned_sample_alljets_idx - 1
in_range = (flat_bin_idx >= 0) & (flat_bin_idx < len(flat_weight_factor))
flatpt_weight = all_jets['flatpt_weight'].to_numpy(copy=True)
flatpt_weight[in_range] *= flat_weight_factor[flat_bin_idx[in_range]]
# Positional assignment: safe even if the concatenated index has duplicate labels.
all_jets['flatpt_weight'] = flatpt_weight

In [None]:
# Closure check: with flatpt_weight every populated pT bin should sum to ~1,
# so a linear y-axis is used.
fig, ax = plt.subplots()
pt_edges = np.linspace(500, 2000, 61)
bin_contents, bin_edges, _ = ax.hist(all_jets['jet_pt'], bins=pt_edges, weights=all_jets['flatpt_weight'])
ax.set(xlabel='jet_pt [GeV]', ylabel='sum of flatpt_weight', title='Jet pT after flattening')
plt.show()


### Add an equal-weight column and group the weight columns together

In [None]:
# Weight of 1.0 (float64) for every jet.
all_jets['equal_weight'] = np.ones(all_jets.shape[0])

In [None]:
# Peek at the table with its new weight columns (avoid rendering the full frame).
all_jets.head()

In [None]:
# Rename the physics weight to event_weight and record the resulting column order.
all_jets = all_jets.rename(columns={'total_weight': 'event_weight'})
col_list = all_jets.columns.tolist()
print(col_list)

In [None]:
# Reorder the columns so the weights sit together: equal_weight, event_weight, flatpt_weight.
# The slicing below is positional and relies on flatpt_weight and equal_weight having been
# appended last, in that order; fail loudly rather than silently mis-order the columns.
assert col_list[-2:] == ['flatpt_weight', 'equal_weight'], col_list[-2:]
insert_pos = col_list.index('event_weight')
# NOTE(review): unless event_weight was the last original column, col_list[insert_pos + 1]
# is not part of adj_col_list, so the reindex that uses it drops that column — confirm intended.
adj_col_list = col_list[:insert_pos] + ['equal_weight'] + [col_list[insert_pos]] + [col_list[-2]] + col_list[insert_pos+2:-2]

In [None]:
# Display the reordered column list for a visual check before reindexing.
adj_col_list

In [None]:
# Apply the new column order, then keep only jets with a non-zero physics weight.
all_jets = (
    all_jets
    .reindex(columns=adj_col_list)
    .loc[lambda jets: jets['event_weight'].ne(0)]
)

In [None]:
# Inspect the first rows after reordering columns and dropping zero-weight jets.
all_jets.head()

In [None]:
# Checkpoint the full reweighted jet table; later sections reload it from this file.
pd.to_pickle(all_jets, "./sample_all_jets.pkl")

In [None]:
# Reload the checkpoint (pickle: only load files produced by this workflow).
all_jets = pd.read_pickle("./sample_all_jets.pkl")

## Sample jets with $p_T \geq 1500$ GeV and a parton truth label

### Balanced quark/gluon subset (~2.8M jets)

In [None]:
# Jets with pT >= 1500 that carry a parton truth label (label -1 excluded).
sample_1500 = all_jets.loc[all_jets['jet_pt'].ge(1500) & all_jets['jet_PartonTruthLabelID'].ne(-1)]

In [None]:
# First rows of the pT >= 1500 sample (labelled jets only).
sample_1500.head()

In [None]:
# Summary statistics of the numeric columns of the pT >= 1500 sample.
sample_1500.describe()

In [None]:
# Persist the full (unbalanced) pT >= 1500 sample.
pd.to_pickle(sample_1500, "./sample_1500_all_jets.pkl")

In [None]:
# Round-trip check: read the saved pT >= 1500 sample back.
test_sample = pd.read_pickle("./sample_1500_all_jets.pkl")

In [None]:
# Number of gluon jets (target == 1); vectorised count instead of a Python-level loop.
(test_sample['target'] == 1).sum()

In [None]:
# Number of quark jets (target == 0); vectorised count instead of a Python-level loop.
(test_sample['target'] == 0).sum()

In [None]:
# Split the pT >= 1500 sample by flavour (target 0 = quark, 1 = gluon) and count each.
sample_quark, sample_gluon = (sample_1500[sample_1500['target'] == flavour] for flavour in (0, 1))
n_quark, n_gluon = len(sample_quark), len(sample_gluon)

In [None]:
# Size of the minority flavour: the per-flavour draw for a balanced sample.
n_sample = np.minimum(n_quark, n_gluon)

In [None]:
# Draw the same number of quark and gluon jets (balanced classes). A fixed random_state,
# matching the other sampling cells, makes the saved subset reproducible.
subset_sample_quark = sample_quark.sample(n=n_sample, random_state=42)
subset_sample_gluon = sample_gluon.sample(n=n_sample, random_state=42)


In [None]:
# Stack quark jets first, then gluon jets.
subset_sample_1500 = pd.concat([subset_sample_quark, subset_sample_gluon], axis=0)

In [None]:
# Total jets in the balanced subset (2 * n_sample).
subset_sample_1500.shape[0]

In [None]:
# NOTE(review): the "2p8M" in the file name assumes 2 * n_sample is ~2.8M; it is data-dependent.
pd.to_pickle(subset_sample_1500, "./sample_1500_2p8M_jets.pkl")

### Balanced quark/gluon subset (100k + 100k = 200k jets)

In [None]:
# Reload the full jet table checkpoint. Keep the path and the DataFrame in separately
# named variables instead of reusing `all_jets` for both.
all_jets_path = "./sample_all_jets.pkl"
all_jets = pd.read_pickle(all_jets_path)



In [None]:
# Inspect the reloaded full jet table.
all_jets.head()

In [None]:
# Jets with pT >= 1500 that carry a parton truth label (label -1 excluded).
sample_1500 = all_jets.loc[all_jets['jet_pt'].ge(1500) & all_jets['jet_PartonTruthLabelID'].ne(-1)]

In [None]:
n_sample = 100_000
# 100k quark (target 0) + 100k gluon (target 1) jets with a fixed seed; quark block first.
sample_quark, sample_gluon = (sample_1500[sample_1500['target'] == flavour] for flavour in (0, 1))
subset_sample_quark, subset_sample_gluon = (
    flavour_jets.sample(n=n_sample, random_state=42) for flavour_jets in (sample_quark, sample_gluon)
)
subset_sample_1500 = pd.concat([subset_sample_quark, subset_sample_gluon])
pd.to_pickle(subset_sample_1500, "./sample_1500_200k_jets.pkl")

In [None]:
# Last rows of the 200k subset (gluon jets, since they were concatenated last).
subset_sample_1500.tail()

# Sample 12M jets (1M quark + 1M gluon jets in each of 6 $p_T$ bins)

In [None]:
# Reload the full jet table checkpoint. Keep the path and the DataFrame in separately
# named variables instead of reusing `all_jets` for both.
all_jets_path = "./sample_all_jets.pkl"
all_jets = pd.read_pickle(all_jets_path)

In [None]:
# (n_jets, n_columns) of the reloaded full jet table.
all_jets.shape

In [None]:
# First rows of the reloaded full jet table.
all_jets.head()

In [None]:
# Keep only jets with a parton truth label (drop label -1) and at least 2 tracks.
all_jets = all_jets.loc[all_jets['jet_PartonTruthLabelID'].ne(-1) & all_jets['jet_nTracks'].ge(2)]

In [None]:
# pT bin edges for the 12M sample: 6 half-open bins [500, 600), ..., [1500, 2000).
label_pt_bin = [500, 600, 800, 1000, 1200, 1500, 2000]
# Integer bin index 0-5 per jet. Jets outside [500, 2000) get NaN, so the per-bin equality
# selection below never picks them. `assign` returns a new frame, which avoids a
# SettingWithCopyWarning on the boolean-filtered `all_jets`.
all_jets = all_jets.assign(
    pt_idx=pd.cut(x=all_jets['jet_pt'], bins=label_pt_bin, right=False, labels=False)
)


In [None]:
quark_jets = all_jets[all_jets['target'] == 0]
gluon_jets = all_jets[all_jets['target'] == 1]
n_sample_ptbin = 1_000_000
n_pt_bins = len(label_pt_bin) - 1
# One fixed-seed draw of n_sample_ptbin jets per (flavour, pT bin): all quark bins
# (0..n_pt_bins-1) first, then all gluon bins. sample() raises if a bin has too few jets.
jets_list = [
    flavour_jets[flavour_jets['pt_idx'] == bin_idx].sample(n=n_sample_ptbin, random_state=42)
    for flavour_jets in (quark_jets, gluon_jets)
    for bin_idx in range(n_pt_bins)
]

In [None]:
# 2 flavours x 6 pT bins x 1M jets each: quark bins first, then gluon bins.
jets_allpt_12M = pd.concat(jets_list, axis=0)

In [None]:
# Expected (12_000_000, n_columns): 2 flavours x 6 pT bins x 1M jets.
jets_allpt_12M.shape

In [None]:
# Persist the balanced all-pT sample.
pd.to_pickle(jets_allpt_12M, "./sample_allpt_12M_jets.pkl")