In [None]:
import argparse
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import uproot 
import awkward as ak
from pathlib import Path

from typing import Dict, List 
import re
import pickle
from tqdm import tqdm
import logging
logging.basicConfig(level=logging.INFO)

import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 


In [None]:
# Input pickle for the periodA Pythia dijet sample (absolute path -- adjust for your environment).
pkl_file_path = Path('/global/cfs/projectdirs/atlas/hrzhao/qgcal/BDT_EPEML/perpared_dijets_w_etalabel/Processed_Samples_Pythia_Nov8/periodA/dijet_pythia_mc16Aevent.pkl')

In [None]:
# Load the dict of per-JZ-slice arrays. NOTE: pickle.load can execute arbitrary code,
# so only point this at files you produced or otherwise trust.
with pkl_file_path.open('rb') as pkl_file:
    dijet_array_all_JZs = pickle.load(pkl_file)

In [None]:
# Show which slice keys (e.g. 'JZ1', 'JZ2', ...) the loaded pickle contains.
dijet_array_all_JZs.keys()

In [None]:
# Per-slice and total file/event counts. Iterate over the slices actually present in the
# pickle (as the processing cells below do) instead of assuming keys JZ1..JZ9 exist.
n_files = 0
n_events = 0
for key, dijet_array_JZ in dijet_array_all_JZs.items():
    # presumably one array per input file -- len() of the slice is logged as "num. of files"
    n_files_JZ = len(dijet_array_JZ)
    # Summing per-file lengths equals len(np.concatenate(..., axis=0)) without copying the
    # data, and naturally yields 0 for an empty slice.
    n_events_JZ = sum(len(file_array) for file_array in dijet_array_JZ)
    n_files += n_files_JZ
    n_events += n_events_JZ
    logging.info(f"{key}: num. of files: \t {n_files_JZ}, \t num. of events:{n_events_JZ}" )

logging.info(f"Stats: total num. of files: \t {n_files}, \t total num. of events {n_events}" )

# Format 1: one row per jet (two consecutive rows per event)

In [None]:
# Format 1: one DataFrame per JZ slice with one row per jet (two consecutive rows per event).
# Two label columns are appended to every jet's input features:
#   is_forward -- 1 for the jet with the larger |eta| in its event, 0 for the other jet
#   target     -- 1 for gluon (|PartonTruthLabelID| == 21), 0 for quark (1-5), -1 otherwise
output_pd = {}
column_names = ["jet_pt", "jet_eta", "jet_nTracks", "jet_trackWidth", "jet_trackC1", "jet_trackBDT", "jet_PartonTruthLabelID", "total_weight"]
column_names_pd = ["event"] + column_names + ["is_forward", "target"]
# Feature positions along the last axis; assumes the pickled arrays follow column_names order.
eta_idx = column_names.index("jet_eta")                          # 1
truth_parton_idx = column_names.index("jet_PartonTruthLabelID")  # 6

for key, dijet_array_JZ in tqdm(dijet_array_all_JZs.items()):
    logging.info(f"Processing {key}....")

    n_files_JZ = len(dijet_array_JZ)
    if n_files_JZ == 0:
        logging.warning(f"{key} is empty! Skipping!... ")
        continue

    # Concatenate the per-file arrays directly. Wrapping the list in np.array() first builds a
    # ragged array when files hold different numbers of events, which NumPy >= 1.24 rejects
    # with ValueError (older versions only emitted VisibleDeprecationWarning).
    dijet_array_JZ = np.concatenate(dijet_array_JZ, axis=0)
    n_events_JZ = len(dijet_array_JZ)
    assert dijet_array_JZ.ndim == 3 and dijet_array_JZ.shape[1:] == (2, len(column_names)), (
        f"{key}: expected an (n_events, 2, {len(column_names)}) array, got {dijet_array_JZ.shape}"
    )

    # Label is_forward: jet with the larger |eta| in each event (argmax -> ties go to jet 0).
    forward_idx = np.argmax(np.abs(dijet_array_JZ[:, :, eta_idx]), axis=1)
    is_forward = np.zeros((n_events_JZ, 2))
    is_forward[np.arange(n_events_JZ), forward_idx] = 1

    # Categorize truth ID: 1 = gluon, 0 = quark, -1 = anything else.
    truth_parton_id = np.abs(dijet_array_JZ[:, :, truth_parton_idx])
    target = np.full(truth_parton_id.shape, -1.0)
    target[truth_parton_id == 21] = 1
    target[np.isin(truth_parton_id, [1, 2, 3, 4, 5])] = 0

    # Append both labels as extra per-jet features: (n_events, 2, 8) -> (n_events, 2, 10).
    dijet_array_JZ_w_etalabel = np.concatenate(
        (dijet_array_JZ, is_forward[:, :, None], target[:, :, None]), axis=2
    )

    # Flat: one row per jet, shape (2*nevents, 10), then prepend the per-slice event index.
    # (Explicit column count instead of -1 so a slice with zero events still reshapes.)
    events = np.repeat(np.arange(n_events_JZ), 2)
    dijet_array_JZ_w_etalabel = dijet_array_JZ_w_etalabel.reshape(
        (2 * n_events_JZ, dijet_array_JZ_w_etalabel.shape[-1])
    )
    dijet_array_JZ_w_etalabel = np.concatenate((events[:, None], dijet_array_JZ_w_etalabel), axis=1)

    pd_JZ = pd.DataFrame(data=dijet_array_JZ_w_etalabel, columns=column_names_pd)
    output_pd[key] = pd_JZ

In [None]:
# Peek at the first five jet rows of the JZ8 slice (Format 1).
output_pd['JZ8'].head(5)

In [None]:
# Save Format 1 (dict: slice key -> DataFrame) as a pickle. Create the target directory
# first: open(..., 'wb') raises FileNotFoundError if ../pkls_etalabel does not exist yet.
output_path = '../pkls_etalabel/all_JZs_format1.pkl'
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump(output_pd, f)

# Format 2: one row per event (jet1_* and jet2_* columns side by side)

In [None]:
# Format 2: one DataFrame per JZ slice with one row per event; the two jets sit side by
# side as jet1_* columns followed by jet2_* columns. Per-jet labels appended to the inputs:
#   is_forward -- 1 for the jet with the larger |eta| in its event, 0 for the other jet
#   target     -- 1 for gluon (|PartonTruthLabelID| == 21), 0 for quark (1-5), -1 otherwise
output_pd2 = {}
jet1_vars = ["jet1_pt", "jet1_eta", "jet1_nTracks", "jet1_trackWidth", "jet1_trackC1", "jet1_trackBDT", "jet1_PartonTruthLabelID", "jet1_total_weight", "jet1_is_forward", "jet1_target"]
jet2_vars = ["jet2_pt", "jet2_eta", "jet2_nTracks", "jet2_trackWidth", "jet2_trackC1", "jet2_trackBDT", "jet2_PartonTruthLabelID", "jet2_total_weight", "jet2_is_forward", "jet2_target"]
column_names_pd = ["event"] + jet1_vars + jet2_vars
# Feature positions along the last axis of the input arrays (the first 8 per-jet columns,
# i.e. everything except the two appended labels); assumes the pickle follows this order.
eta_idx = jet1_vars.index("jet1_eta")                          # 1
truth_parton_idx = jet1_vars.index("jet1_PartonTruthLabelID")  # 6

for key, dijet_array_JZ in tqdm(dijet_array_all_JZs.items()):
    logging.info(f"Processing {key}....")

    n_files_JZ = len(dijet_array_JZ)
    if n_files_JZ == 0:
        logging.warning(f"{key} is empty! Skipping!... ")
        continue

    # Concatenate the per-file arrays directly. Wrapping the list in np.array() first builds a
    # ragged array when files hold different numbers of events, which NumPy >= 1.24 rejects
    # with ValueError (older versions only emitted VisibleDeprecationWarning).
    dijet_array_JZ = np.concatenate(dijet_array_JZ, axis=0)
    n_events_JZ = len(dijet_array_JZ)
    n_input_vars = len(jet1_vars) - 2
    assert dijet_array_JZ.ndim == 3 and dijet_array_JZ.shape[1:] == (2, n_input_vars), (
        f"{key}: expected an (n_events, 2, {n_input_vars}) array, got {dijet_array_JZ.shape}"
    )

    # Label is_forward: jet with the larger |eta| in each event (argmax -> ties go to jet 0).
    forward_idx = np.argmax(np.abs(dijet_array_JZ[:, :, eta_idx]), axis=1)
    is_forward = np.zeros((n_events_JZ, 2))
    is_forward[np.arange(n_events_JZ), forward_idx] = 1

    # Categorize truth ID: 1 = gluon, 0 = quark, -1 = anything else.
    truth_parton_id = np.abs(dijet_array_JZ[:, :, truth_parton_idx])
    target = np.full(truth_parton_id.shape, -1.0)
    target[truth_parton_id == 21] = 1
    target[np.isin(truth_parton_id, [1, 2, 3, 4, 5])] = 0

    # Append both labels as extra per-jet features: (n_events, 2, 8) -> (n_events, 2, 10).
    dijet_array_JZ_w_etalabel = np.concatenate(
        (dijet_array_JZ, is_forward[:, :, None], target[:, :, None]), axis=2
    )

    # Flat: one row per event, shape (nevents, 10*2) = jet1 block then jet2 block,
    # then prepend the per-slice event index.
    events = np.arange(n_events_JZ)
    dijet_array_JZ_w_etalabel = dijet_array_JZ_w_etalabel.reshape(
        (n_events_JZ, 2 * dijet_array_JZ_w_etalabel.shape[-1])
    )
    dijet_array_JZ_w_etalabel = np.concatenate((events[:, None], dijet_array_JZ_w_etalabel), axis=1)

    pd_JZ = pd.DataFrame(data=dijet_array_JZ_w_etalabel, columns=column_names_pd)
    output_pd2[key] = pd_JZ

 

In [None]:
# Peek at the first five event rows of the JZ8 slice (Format 2).
output_pd2['JZ8'].head(5)

In [None]:
# Save Format 2 (dict: slice key -> DataFrame) as a pickle. Create the target directory
# first: open(..., 'wb') raises FileNotFoundError if ../pkls_etalabel does not exist yet.
output_path2 = '../pkls_etalabel/all_JZs_format2.pkl'
Path(output_path2).parent.mkdir(parents=True, exist_ok=True)
with open(output_path2, 'wb') as f:
    pickle.dump(output_pd2, f)