In [25]:
import joblib 
import numpy as np
import pandas as pd 
from pathlib import Path
import matplotlib.pyplot as plt
import hist
from hist import Hist 
from uncertainties import ufloat, unumpy


In [2]:
period = "A" 
pythia_path = f'/global/cfs/projectdirs/atlas/hrzhao/HEP_Repo/QG_Calibration/tmp/pythia{period}_pred.pkl'

In [3]:
pythia_pd = joblib.load(pythia_path)

In [4]:
pythia_pd = pythia_pd[(pythia_pd["jet_nTracks"] > 1) & (pythia_pd["target"] != '-')] 
# remove jet with ntrk < 2 and the parton label = -1 

In [5]:
pythia_pd["target"] = pythia_pd["target"].astype(float)

In [6]:
pythia_pd.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 55968852 entries, 0 to 288769
Data columns (total 15 columns):
 #   Column                  Dtype  
---  ------                  -----  
 0   run                     float64
 1   event                   float64
 2   jet_pt                  float64
 3   jet_eta                 float64
 4   jet_nTracks             float64
 5   jet_trackWidth          float64
 6   jet_trackC1             float64
 7   jet_trackBDT            float64
 8   jet_PartonTruthLabelID  float64
 9   event_weight            float64
 10  is_forward              float64
 11  is_leading              float64
 12  pt_idx                  int64  
 13  target                  float64
 14  GBDT_newScore           float64
dtypes: float64(14), int64(1)
memory usage: 6.7 GB


In [23]:
pythia_pd.columns

Index(['run', 'event', 'jet_pt', 'jet_eta', 'jet_nTracks', 'jet_trackWidth',
       'jet_trackC1', 'jet_trackBDT', 'jet_PartonTruthLabelID', 'event_weight',
       'is_forward', 'is_leading', 'pt_idx', 'target', 'GBDT_newScore'],
      dtype='object')

In [None]:
def make_hist(values, bins, weights):
    # assuming bins numpy array with (start, stop, n_edges)
    _hist = Hist(hist.axis.Regular(bins=len(bins)-1, start=bins[0], stop=bins[-1], overflow=True, underflow=True), 
                                storage=hist.storage.Weight())
    _hist.fill(values, weight=weights)
    area = np.sum(_hist.values()) * _hist.axes[0].widths
    _normed_hist = _hist / area

    return _hist, _normed_hist

In [77]:
label_pt_bin = [500, 600, 800, 1000, 1200, 1500, 2000]
label_var = ["pt", "eta", "ntrk", "width", "c1", "bdt", "newBDT"]
label_leadingtype = ["LeadingJet", "SubLeadingJet"]
label_etaregion = ["Forward", "Central"]
label_jettype = ["Gluon", "Quark", "B_Quark", "C_Quark"]

label_var_map = {
    'pt':'jet_pt',
    'eta':'jet_eta',
    'ntrk':'jet_nTracks', 
    'width':'jet_trackWidth',
    'c1':'jet_trackC1', 
    'bdt':'jet_trackBDT', 
    'newBDT':'GBDT_newScore'
}

is_leading_map = {
    "LeadingJet": [1],
    "SubLeadingJet": [0],
}

is_forward_map = {
    "Forward": [1],
    "Central": [0],
}

label_jettype_map = {
    "Gluon" : [21], 
    "Quark" : [1, 2, 3],
    "B_Quark" : [5],
    "C_Quark" : [4],
}

HistBins = {
    'jet_pt' : np.linspace(500, 2000, 61),
    'jet_eta' : np.linspace(-2.5, 2.5, 51), 
    'jet_nTracks' : np.linspace(0, 60, 61),
    'jet_trackWidth' : np.linspace(0, 0.4, 61),
    'jet_trackC1' : np.linspace(0, 0.4, 61),
    'jet_trackBDT' : np.linspace(-1.0, 1.0, 101),
    'GBDT_newScore' : np.linspace(-1.0, 1.0, 101),
}

def digitize_pd(pd_input, reweight='event_weight'):
    values = []
    HistMap_unumpy = {}
    for pt_idx, pt in enumerate(label_pt_bin[-2:-1]):
        pt_input_idx = pd_input['pt_idx'] == pt_idx
        pd_input_at_pt = pd_input[pt_input_idx]

        for leadingtype in label_leadingtype:
            leadingtype_idx = pd_input_at_pt['is_leading'].isin(is_leading_map[leadingtype])
            pd_input_at_pt_leadingtype = pd_input_at_pt[leadingtype_idx]
            
            for eta_region in label_etaregion: 
                etaregion_idx = pd_input_at_pt_leadingtype['is_forward'].isin(is_forward_map[eta_region])
                pd_input_at_pt_leadingtype_etaregion = pd_input_at_pt_leadingtype[etaregion_idx]
                
                for jettype in label_jettype:
                    type_idx = pd_input_at_pt_leadingtype_etaregion['jet_PartonTruthLabelID'].isin(label_jettype_map[jettype])
                    pd_input_at_pt_leadingtype_etaregion_jettype = pd_input_at_pt_leadingtype_etaregion[type_idx]
                    for var in label_var:
                        key = f"{pt}_{leadingtype}_{eta_region}_{jettype}_{var}"
                        bin_var = HistBins[label_var_map[var]]

                        # TODO: Can change the format from unumpy to hist. Now just to test the plotting code. 
                        if len(pd_input_at_pt_leadingtype_etaregion_jettype) == 0: ## for subset, if len==0, give it an empty unumpy array
                            HistMap_unumpy[key] = unumpy.uarray(np.zeros(len(bin_var)-1), np.zeros(len(bin_var)-1))
                            continue
                        else:
                            _hist, _norm_hist = make_hist(values=pd_input_at_pt_leadingtype_etaregion_jettype[label_var_map[var]],
                                                bins=bin_var, weights=pd_input_at_pt_leadingtype_etaregion_jettype[reweight])
                            HistMap_unumpy[key] = unumpy.uarray(_hist.values(), np.sqrt(_hist.variances()))
                            values.append(pd_input_at_pt_leadingtype_etaregion_jettype)
                            
    return HistMap_unumpy, values


In [78]:
test, values = digitize_pd(pythia_pd)

In [92]:
def make_hist(values, bins, weights):
    # assuming bins numpy array with (start, stop, n_edges)
    _hist = Hist(hist.axis.Regular(bins=len(bins)-1, start=bins[0], stop=bins[-1], overflow=True, underflow=True), 
                                storage=hist.storage.Weight())
    _hist.fill(values, weight=weights)
    factor = np.sum(_hist.values())
    _normed_hist = _hist / (factor * _hist.axes[0].widths)

    return _hist, _normed_hist

In [11]:
pythia_pd.loc[pythia_pd['jet_PartonTruthLabelID'] == 1]

Unnamed: 0,run,event,jet_pt,jet_eta,jet_nTracks,jet_trackWidth,jet_trackC1,jet_trackBDT,jet_PartonTruthLabelID,event_weight,is_forward,is_leading,pt_idx,target,GBDT_newScore
7,364707.0,15124477.0,1606.541626,0.500120,18.0,0.003163,0.015169,-0.208305,1.0,0.002045,0.0,0.0,5,0.0,-2.264001
11,364707.0,15119785.0,1126.649902,-0.322667,14.0,0.008973,0.143120,-0.292958,1.0,0.001297,0.0,0.0,3,0.0,-2.755733
12,364707.0,15302504.0,1856.073364,0.165950,11.0,0.012667,0.083827,-0.517470,1.0,0.002276,0.0,1.0,5,0.0,-4.124952
18,364707.0,15299096.0,1913.370972,-0.328563,21.0,0.044191,0.257675,-0.010115,1.0,0.001816,0.0,1.0,5,0.0,-1.338654
20,364707.0,15130699.0,1882.922607,-0.433325,28.0,0.093330,0.283374,0.037830,1.0,0.004432,1.0,1.0,5,0.0,-1.011453
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
288742,364704.0,121007307.0,729.985596,1.092036,15.0,0.016490,0.195351,-0.102996,1.0,0.333955,1.0,1.0,1,0.0,-0.836292
288746,364704.0,121006482.0,602.241089,-1.354948,10.0,0.021024,0.116361,-0.343159,1.0,0.804881,1.0,1.0,1,0.0,-2.733771
288748,364704.0,121014092.0,721.598083,1.043867,11.0,0.010201,0.097644,-0.439583,1.0,0.303808,0.0,1.0,1,0.0,-3.110432
288753,364704.0,121015362.0,573.036743,-1.000995,12.0,0.013051,0.079654,-0.351687,1.0,0.597531,1.0,0.0,0,0.0,-2.462014


In [93]:
reweight_factor = joblib.load('../test_reweight/test_reweight_factor2.pkl')

In [94]:
reweight_factor

{500: {'jet_nTracks': {'quark_factor': array([0.        , 0.        , 0.76504383, 0.77453366, 0.78228986,
          0.78910322, 0.80815286, 0.8349122 , 0.85043639, 0.87644848,
          0.90028127, 0.93057548, 0.96142615, 0.98730172, 1.01508972,
          1.03846653, 1.05588899, 1.075738  , 1.09366685, 1.110069  ,
          1.12514566, 1.1222418 , 1.13719806, 1.15714548, 1.16301699,
          1.182969  , 1.18455275, 1.1917966 , 1.19317652, 1.21329927,
          1.221427  , 1.21577912, 1.22022633, 1.22589986, 1.25712537,
          1.25063228, 1.27060326, 1.24808208, 1.31119057, 1.25645101,
          1.23597459, 1.3256365 , 1.29446667, 1.31503331, 1.29904655,
          1.23445892, 1.43688728, 1.22237645, 1.39451109, 1.50267667,
          1.09579408, 1.18314577, 1.41519986, 1.39648542, 1.47515813,
          0.86330381, 0.70737141, 1.35553626, 1.01694768, 1.5672156 ]),
   'gluon_factor': array([0.        , 0.        , 0.51758646, 0.67283604, 0.78769898,
          0.82296798, 0.83735847, 0.

In [101]:
def attach_reweight_factor(pd_input, reweight_factor):
    reweighting_vars = ['jet_nTracks', 'jet_trackBDT', 'GBDT_newScore'] 
    for reweighting_var in reweighting_vars:
        pd_input[f'{reweighting_var}_quark_reweighting_weights'] = pd_input['event_weight'].copy()
        pd_input[f'{reweighting_var}_gluon_reweighting_weights'] = pd_input['event_weight'].copy()

    reweighted_sample = []
    #### reweight_factor[pt][var]['quark_factor']
    for pt_idx, pt in enumerate(label_pt_bin[:-1]):
        pd_input_at_pt = pd_input[pd_input['pt_idx'] == pt_idx]

        for reweighting_var in reweighting_vars:
            bin_var = HistBins[reweighting_var]
            quark_factor_idx = pd_input_at_pt.columns.get_loc(f'{reweighting_var}_quark_reweighting_weights')
            gluon_factor_idx = pd_input_at_pt.columns.get_loc(f'{reweighting_var}_gluon_reweighting_weights')

            quark_factor = reweight_factor[pt][reweighting_var]['quark_factor']
            gluon_factor = reweight_factor[pt][reweighting_var]['gluon_factor']

            var_idx = np.digitize(x=pd_input_at_pt[reweighting_var] , bins=bin_var) - 1  # Binned feature distribution 
            for i, score in enumerate(bin_var[:-1]): # Loop over the bins 
                mod_idx = np.where(var_idx == i)[0]
                pd_input_at_pt.iloc[mod_idx, quark_factor_idx] *= quark_factor[i]
                pd_input_at_pt.iloc[mod_idx, gluon_factor_idx] *= gluon_factor[i]
            
        reweighted_sample.append(pd_input_at_pt)                
    
    return pd.concat(reweighted_sample)

In [102]:
reweighted_pythia_pd = attach_reweight_factor(pd_input= pythia_pd, reweight_factor=reweight_factor)
# 2m for pythiaA

In [103]:
reweighted_pythia_pd

Unnamed: 0,run,event,jet_pt,jet_eta,jet_nTracks,jet_trackWidth,jet_trackC1,jet_trackBDT,jet_PartonTruthLabelID,event_weight,...,is_leading,pt_idx,target,GBDT_newScore,jet_nTracks_quark_reweighting_weights,jet_nTracks_gluon_reweighting_weights,jet_trackBDT_quark_reweighting_weights,jet_trackBDT_gluon_reweighting_weights,GBDT_newScore_quark_reweighting_weights,GBDT_newScore_gluon_reweighting_weights
53939,364707.0,15450286.0,554.589966,1.039075,20.0,0.061749,0.199279,-0.024167,21.0,0.002462,...,0.0,0,1.0,-0.189685,0.002770,0.002384,0.002413,0.002312,0.002385,0.002360
72879,364707.0,12004426.0,582.164856,-0.601817,21.0,0.025921,0.223785,0.227491,21.0,0.005609,...,0.0,0,1.0,1.010659,0.006295,0.005530,0.006245,0.005664,0.005609,0.005609
120607,364707.0,14720243.0,544.812134,1.529484,11.0,0.024930,0.126167,-0.292486,1.0,0.006114,...,0.0,0,0.0,-2.082529,0.005690,0.005314,0.005889,0.005741,0.006114,0.006114
203109,364707.0,12634591.0,576.683289,0.778905,6.0,0.021471,0.126657,-0.476749,2.0,0.001317,...,0.0,0,0.0,-2.955314,0.001064,0.001103,0.001231,0.001237,0.001317,0.001317
15349,364707.0,12000643.0,589.608765,0.122688,17.0,0.040411,0.153857,-0.048371,2.0,0.003040,...,0.0,0,0.0,-1.046684,0.003271,0.002827,0.002955,0.002846,0.003040,0.003040
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
450,364709.0,6594460.0,1659.721680,0.385620,23.0,0.027218,0.203993,0.015940,2.0,0.000006,...,1.0,5,0.0,-0.673563,0.000006,0.000006,0.000006,0.000006,0.000006,0.000006
451,364709.0,6594460.0,1618.221802,0.386658,29.0,0.029542,0.220075,0.158161,2.0,0.000006,...,0.0,5,0.0,0.429353,0.000006,0.000006,0.000006,0.000006,0.000006,0.000006
452,364709.0,6599446.0,1918.357178,0.592627,19.0,0.028643,0.202892,-0.121426,21.0,0.000015,...,1.0,5,1.0,-1.748678,0.000015,0.000014,0.000015,0.000014,0.000015,0.000015
454,364709.0,6593182.0,1899.958862,0.509660,42.0,0.103400,0.299461,0.196188,21.0,0.000011,...,1.0,5,1.0,0.009521,0.000013,0.000012,0.000012,0.000011,0.000012,0.000011


In [104]:
joblib.dump(reweighted_pythia_pd, f"reweighted_pythia{period}_pred.pkl")

['reweighted_pythiaA_pred.pkl']