# Pre-processing

This notebook aims to do pre-processing before making the fit templates. The pre-processing includes calculating new variables and extracting reweight factors. 

The data analysis is backed by `awkward-array`, using the `coffea` non-processor workflow.

Please first set up `config.yaml`, then run the notebook cells in order. The results are stored in the backup folder `prep/`, in a sub-folder named after the "routine name" and "year condition" specified in the yaml file. 

In [None]:
from coffea.nanoevents import NanoEventsFactory, TreeMakerSchema, BaseSchema
import awkward1 as ak
import uproot4 as uproot
import numpy as np
import math
import os

In [None]:
from data_utils import get_hist, plot_hist
from cycler import cycler
import boost_histogram as bh
import matplotlib.pyplot as plt
import matplotlib as mpl
# mpl.use('AGG') # no rendering plots in the window

import mplhep as hep
use_helvet = True  ## True: use Helvetica for plots; make sure the font is installed on the system
if use_helvet:
    ## NOTE: `CMShelvet` is an alias of `hep.style.CMS` (no copy is made), so the font override is written into that same object
    CMShelvet = hep.style.CMS
    CMShelvet['font.sans-serif'] = ['Helvetica', 'Arial']  ## preferred sans-serif fonts, in order
    plt.style.use(CMShelvet)
else:
    ## Use the CMS style as shipped with mplhep
    plt.style.use(hep.style.CMS)

In [None]:
## Load the YAML config card into the `config` dict, which is read by the cells below
import yaml
with open('cards/config_bb_std.yml') as f:
    config = yaml.safe_load(f)
## Display the loaded configuration as the cell output
config

## 1. Load files

Load the ROOT files into lazy awkward arrays

In [None]:
## Year to process, as set in the config
year = config['year']

## Per-year value used below as the luminosity factor in the MC weight strings and in the fb^-1 plot labels
lumi = {2016: 35.92, 2017: 41.53, 2018: 59.74}

## Sample name -> ROOT file path, relative to the per-year sample directory built below
read_sample_list_map = {
    'qcd-mg-noht': 'mc/qcd-mg_tree.root',
    'qcd-herwig-noht': 'mc/qcd-herwig_tree.root',
    'top-noht': 'mc/top_tree.root',
    'v-qq-noht': 'mc/v-qq_tree.root',
    'jetht-noht': 'data/jetht_tree.root',
}
## Optionally add the dedicated b-flavour QCD sample
if config['samples']['use_bflav']:
    read_sample_list_map['qcd-mg-bflav-noht'] = 'mc/qcd-mg-bflav_tree.root'
## Truthy only if config['samples']['optional']['omit_herwig'] exists and is truthy
omit_herwig = 'optional' in config['samples'] and 'omit_herwig' in config['samples']['optional'] and config['samples']['optional']['omit_herwig']
if omit_herwig:
    read_sample_list_map.pop('qcd-herwig-noht', None)
## Drop any further samples listed in the optional config entry (names that are not in the map are ignored)
if 'optional' in config['samples'] and 'exclude_mc_sample_in_preprocessing' in config['samples']['optional']:
    for ex_sam in config['samples']['optional']['exclude_mc_sample_in_preprocessing']:
        read_sample_list_map.pop(ex_sam, None)
print('Read samples for preprocessing:', read_sample_list_map.keys())

## Read the root file into lazy awkward arrays
arr = {}
## Sample directory: the configured prefix followed by "_<year>"
sample_prefix = f"{config['samples']['sample_prefix']}_{year}"
for sam in read_sample_list_map:
    arr[sam] = NanoEventsFactory.from_root(f'{sample_prefix}/{read_sample_list_map[sam]}', schemaclass=BaseSchema).events()
## Store the branch names
def get_stored_branches(arr, read_sample_list_map):
    """Snapshot the field names currently present in each sample's array.

    Arguments:
        arr: dict of awkward arrays keyed by sample name
        read_sample_list_map: iterable of sample names to inspect

    Returns:
        dict mapping each sample name to `ak.fields(arr[sample])`
    """
    return {name: ak.fields(arr[name]) for name in read_sample_list_map}
## Snapshot of the field names present right after loading
stored_branches = get_stored_branches(arr, read_sample_list_map)
## Per-step snapshots of the field names, filled in by the cells below
stored_branches_interm = {}
## Name of the backup sub-folder under prep/
store_name = f"{config['samples']['name']}_SF{config['year']}"

In [None]:
## You can optionally run this block to extend the custom coffea NanoEventsArray to ExtendedNanoEventsArray.
## This functionality will automatically store and read NEW variables from disk instead of having them all in memory.

from data_utils import ExtendedNanoEventsArray
def use_extended_nanoarray(arr, store_name):
    """Replace every entry of `arr` in place by an `ExtendedNanoEventsArray` wrapping it.

    Arguments:
        arr: dict of event arrays keyed by sample name (modified in place)
        store_name: name of the backup folder; each sample is pointed to `prep/<store_name>/<sample>/`
    """
    for key, events in arr.items():
        extended = ExtendedNanoEventsArray(events)
        arr[key] = extended
        extended.record_awkward_items()
        extended.set_backup_path(f'prep/{store_name}/{key}/') # backup directly to backup_array destination

use_extended_nanoarray(arr, store_name)

## 2. Apply baseline selections

For data: apply the logical OR of all HT triggers to enhance statistics.

For MC: apply no HT trigger; we call this strategy "MC substitute".

We define an attribute `maskdict` in each sample that stores masks corresponding to different selections.

In [None]:
def backup_array(backup_name, stored_branches, read_sample_list_map, global_key_only=False, ext_branches=None):
    r"""Backup newly produced variables of the global awkward array dict `arr` to pickle files.

    For every sample, each field of `arr[sample]` that is not listed in `stored_branches[sample]`
    is pickled to `prep/<backup_name>/<sample>/<field>`; the sample's `maskdict` attribute, if any,
    is pickled alongside. Additional top-level keys of `arr` (not a sample, not starting with
    `subst_`, and not `real-signal`) are pickled to `prep/<backup_name>/<key>`.

    Arguments:
        backup_name: name of backup folder
        stored_branches: dict of sample name -> branches already stored in previous routines.
            It is not modified; an updated copy is returned instead.
        read_sample_list_map: sample list to read.
        global_key_only: if True, skip the per-sample branches and only store the additional top-level keys.
        ext_branches: already-known branches to store again anyway (e.g. branches modified in place).
            Ignored for samples held in an `ExtendedNanoEventsArray`.

    Returns:
        dict of sample name -> list of branches, i.e. `stored_branches` extended by the newly stored ones.
    """

    import pickle
    ## Never share or mutate a default/caller-owned list
    ext_branches = [] if ext_branches is None else list(ext_branches)
    updated_branches = {sam: list(branches) for sam, branches in stored_branches.items()}
    if not global_key_only:
        for sam in read_sample_list_map:
            fields = ak.fields(arr[sam])
            ## Decide per sample, so that one extended array does not disable `ext_branches` for the following samples
            if 'ExtendedNanoEventsArray' in str(type(arr[sam])):
                sam_ext_branches = []
            else:
                sam_ext_branches = [br for br in ext_branches if br in fields]
            os.makedirs(f'prep/{backup_name}/{sam}', exist_ok=True)
            new_vars = set(fields) - set(stored_branches[sam])
            for var in sorted(new_vars | set(sam_ext_branches)):
                with open(f'prep/{backup_name}/{sam}/{var}', 'wb') as fw:
                    pickle.dump(arr[sam][var], fw)
                print('storing...', sam, var)
            updated_branches[sam] = list(stored_branches[sam]) + sorted(new_vars)
            if hasattr(arr[sam], 'maskdict'):
                with open(f'prep/{backup_name}/{sam}/maskdict', 'wb') as fw:
                    pickle.dump(arr[sam].maskdict, fw)
    for key in arr.keys():
        if key not in read_sample_list_map and not key.startswith('subst_') and key != 'real-signal':
            ## The folder may not exist yet when only the top-level keys are stored
            os.makedirs(f'prep/{backup_name}', exist_ok=True)
            with open(f'prep/{backup_name}/{key}', 'wb') as fw:
                pickle.dump(arr[key], fw)
            print('storing additonal keys...', key)
    return updated_branches

## Fetch variables from the backup file
def load_backup_array(backup_name, read_sample_list_map):
    r"""Restore the content of `prep/<backup_name>` into the global awkward array dict `arr`.

    Sample folders provide the pickled variables and the `maskdict` of each sample; regular
    files directly under the backup folder are restored as additional top-level keys of `arr`.
    For every sample except 'jetht-noht', `arr['subst_<sample>']` is set as a reference to `arr[<sample>]`.

    Arguments:
        backup_name: name of backup folder
        read_sample_list_map: sample list to read.
    """

    import pickle
    for entry in os.listdir(f'prep/{backup_name}'):
        if entry not in read_sample_list_map:
            ## Not a sample folder: a non-hidden regular file is an additional top-level key
            if not entry.startswith('.') and os.path.isfile(f'prep/{backup_name}/{entry}'):
                with open(f'prep/{backup_name}/{entry}', 'rb') as f:
                    arr[entry] = pickle.load(f)
                print('loading...', entry)
            continue

        for fname in os.listdir(f'prep/{backup_name}/{entry}'):
            if fname.startswith('.'):
                continue
            if fname == 'maskdict':
                arr[entry].maskdict = {}
                with open(f'prep/{backup_name}/{entry}/maskdict', 'rb') as f:
                    arr[entry].maskdict = pickle.load(f)
                print('loading...', entry, 'maskdict', arr[entry].maskdict.keys())
                continue
            if 'ExtendedNanoEventsArray' in str(type(arr[entry])):
                ## Extended nanoarrays are skipped here: only plain arrays get their variables loaded from the pickles
                continue
            with open(f'prep/{backup_name}/{entry}/{fname}', 'rb') as f:
                arr[entry][fname] = pickle.load(f)
            print('loading...', entry, fname)
        if entry != 'jetht-noht':
            arr['subst_'+entry] = arr[entry] # make a reference

In [None]:
def eval_expr(ak_array, expr, mask=None):
    """Evaluate the string `expr` on the columns of an awkward array, imitating `eval` in pandas.

    Every name used in `expr` (except names starting with '_' and the module aliases
    `awkward`, `ak`, `np`, `numpy`, `math`) is looked up as a column of `ak_array`.
    If `mask` is given, each column is passed through `.mask[mask]` before the evaluation.
    The expression is run with Python's `eval`, so it must come from a trusted source.
    """
    import ast

    module_aliases = {'awkward', 'ak', 'np', 'numpy', 'math'}
    referenced = {
        node.id
        for node in ast.walk(ast.parse(expr))
        if isinstance(node, ast.Name) and not node.id.startswith('_')
    }

    namespace = {}
    for name in sorted(referenced - module_aliases):
        column = ak_array[name]
        namespace[name] = column if mask is None else column.mask[mask]
    namespace.update({'math': math, 'numpy': np, 'np': np, 'awkward': ak, 'ak': ak})
    return eval(expr, namespace)

In [None]:
def mask_and(arr, mask_list):
    """Return the element-wise AND of the masks in `arr.maskdict` named by `mask_list`"""
    selected = [arr.maskdict[name] for name in mask_list]
    combined = np.logical_and.reduce(selected)
    return combined

def concat_array(arrdict, expr, sam_list, filter_list):
    """Evaluate `expr` on each sample, keep the entries passing all masks in `filter_list`, and concatenate the results.

    `sam_list` may be a list of sample names or a single sample name.
    """
    samples = sam_list if isinstance(sam_list, list) else [sam_list]
    pieces = []
    for sam in samples:
        events = arrdict[sam]
        values = eval_expr(events, expr)
        pieces.append(np.array(values[mask_and(events, filter_list)]))
    return np.concatenate(pieces)

def mask_and_fj12(arr, mask_list):
    """Combine the `mask_and` results for fj_1 and fj_2: 'fj_x' in each mask name is replaced by the jet prefix, and the two results are concatenated (fj_1 first)"""
    per_jet = [
        mask_and(arr, [name.replace('fj_x', jet) for name in mask_list])
        for jet in ('fj_1', 'fj_2')
    ]
    return np.concatenate(per_jet)

def concat_array_fj12(arrdict, expr, sam_list, filter_list):
    """Combine the `concat_array` results for fj_1 and fj_2: 'fj_x' in `expr` and in each filter name is replaced by the jet prefix, and the two results are concatenated (fj_1 first)"""
    per_jet = []
    for jet in ('fj_1', 'fj_2'):
        jet_filters = [name.replace('fj_x', jet) for name in filter_list]
        per_jet.append(concat_array(arrdict, expr.replace('fj_x', jet), sam_list, jet_filters))
    return np.concatenate(per_jet)

def calc_rwgt_akarray(arr, rwgt_edge, rwgt):
    """Map each value of `arr` to the weight of the bin it falls into.

    `rwgt[0]` is used below the first edge, `rwgt[k+1]` for values in [rwgt_edge[k], rwgt_edge[k+1]),
    and `rwgt[-1]` at or above the last edge.
    """
    weights = (arr < rwgt_edge[0]) * rwgt[0]
    for idx, (low, high) in enumerate(zip(rwgt_edge[:-1], rwgt_edge[1:]), start=1):
        in_bin = (arr >= low) & (arr < high)
        weights = weights + in_bin * rwgt[idx]
    weights = weights + (arr >= rwgt_edge[-1]) * rwgt[-1]
    return weights

In [None]:
# ## If you have run some blocks in the following - use this to restore the backup variables directly
# load_backup_array(store_name, read_sample_list_map)

In [None]:
### ================ Pre-processing for data  ===================

## Baseline selection applied to data. 
## Note that we use the OR or all HT triggers (some are pre-scaled triggers)

## Take the snapshot of the existing field names only once; it is passed to backup_array below as the list of already-stored branches
if 'prep-data' not in stored_branches_interm:
    stored_branches_interm['prep-data'] = get_stored_branches(arr, read_sample_list_map)
hlt_branches = {  ## used HLT_PFHT* branches depend on year
    2016: ['HLT_PFHT125', 'HLT_PFHT200', 'HLT_PFHT250', 'HLT_PFHT300', 'HLT_PFHT350', 'HLT_PFHT400', 'HLT_PFHT475', 'HLT_PFHT600', 'HLT_PFHT650', 'HLT_PFHT800', 'HLT_PFHT900'],
    2017: ['HLT_PFHT180', 'HLT_PFHT250', 'HLT_PFHT370', 'HLT_PFHT430', 'HLT_PFHT510', 'HLT_PFHT590', 'HLT_PFHT680', 'HLT_PFHT780', 'HLT_PFHT890', 'HLT_PFHT1050', 'HLT_PFHT350'],
    2018: ['HLT_PFHT180', 'HLT_PFHT250', 'HLT_PFHT370', 'HLT_PFHT430', 'HLT_PFHT510', 'HLT_PFHT590', 'HLT_PFHT680', 'HLT_PFHT780', 'HLT_PFHT890', 'HLT_PFHT1050', 'HLT_PFHT350'],
}
## Expression string "(A | B | ...)": the OR of all trigger branches of the chosen year
htcut_incl = '('+' | '.join(hlt_branches[year])+')'
## Per-jet baseline selection; 'fj_x' is replaced by 'fj_1' / 'fj_2' below
basesel_ext_noht_prep = f"passmetfilters & (fj_x_pt>200) & fj_x_is_qualified"
## Data sample(s) processed in this cell
sl_prep = ['jetht-noht']

for sam in sl_prep:
    assert 'noht' in sam
    ## (Re)start the mask dictionary of the sample
    arr[sam].maskdict = {}
    arr[sam].maskdict['hlt'] = eval_expr(arr[sam], htcut_incl)
    for i in '12':
        ## The baseline selection for data
        print('baseline selection for data: ', sam, f'jet{i}')
        arr[sam].maskdict[f'fj_{i}_base'] = arr[sam].maskdict['hlt'] & eval_expr(arr[sam], basesel_ext_noht_prep.replace('fj_x', f'fj_{i}'))

## Store new variables
stored_branches = backup_array(store_name, stored_branches_interm['prep-data'], read_sample_list_map)

In [None]:
# ## FOR TEST: check the xsecWeight for MG samples & genWeight for Herwig sample (to avoid extremely large values) 
# from collections import Counter
# print(Counter(np.array(arr['qcd-mg-noht'].xsecWeight)),'\n')
# for i in [0.96, 0.98, 0.99]:
#     print(np.quantile(np.array(arr['qcd-herwig-noht'].genWeight), q=i))

In [None]:
### ================ Pre-processing for MC substitute  ===================

## Baseline selection applied to MC.
## No HT trigger is applied, based on the "MC substitute" strategy
## Take the snapshot of the existing field names only once; it is passed to backup_array below as the list of already-stored branches
if 'prep-mc' not in stored_branches_interm:
    stored_branches_interm['prep-mc'] = get_stored_branches(arr, read_sample_list_map)

## Per-jet baseline selection; 'fj_x' is replaced by 'fj_1' / 'fj_2' below
basesel_noht_prep_subst = "passmetfilters & (fj_x_pt>200) & fj_x_is_qualified"
## Mark sample name with "subst_" as a reminder of MC substitute. Default is ['subst_qcd-mg-noht', 'subst_qcd-herwig-noht', 'subst_top-noht', 'subst_v-qq-noht']  
sl_prep_subst = ['subst_'+sam for sam in read_sample_list_map.keys() if 'jetht' not in sam]
## The b-flavour sample is normally already derived from `read_sample_list_map`; only append it when missing, so the sample is not listed (and processed) twice
if config['samples']['use_bflav'] and 'subst_qcd-mg-bflav-noht' not in sl_prep_subst:
    sl_prep_subst += ['subst_qcd-mg-bflav-noht']
for sam in sl_prep_subst:
    assert 'noht' in sam
    arr[sam] = arr[sam.replace('subst_','')]  ## use the name subst_ as a ref
    arr[sam].maskdict = {}
    ## The Herwig genWeight threshold (96% quantile) is the same for both jets: compute it once per sample
    genweight_cut = None
    if sam == 'subst_qcd-herwig-noht':
        genweight_cut = np.quantile(np.array(arr[sam].genWeight), q=0.96)
    for i in '12':
        print('baseline selection for: ', sam, f'jet{i}')
        arr[sam].maskdict[f'fj_{i}_base'] = eval_expr(arr[sam], basesel_noht_prep_subst.replace('fj_x', f'fj_{i}'))
        ## Drop MG events with extremely large xsecWeight (coming from low HT sample in the HT-binned MG list)
        if sam == 'subst_qcd-mg-noht':
            arr[sam].maskdict[f'fj_{i}_base'] = arr[sam].maskdict[f'fj_{i}_base'] & eval_expr(arr[sam], 'xsecWeight<5.')
        ## Drop Herwig events with extremely large genWeight
        if sam == 'subst_qcd-herwig-noht':
            arr[sam].maskdict[f'fj_{i}_base'] = arr[sam].maskdict[f'fj_{i}_base'] & eval_expr(arr[sam], 'genWeight<{}'.format(genweight_cut))
    ## Fix a 2016 bug: Herwig sample xsec is mistaken
    ## The 'xsecWeight_is_normed' column marks that the correction was applied, so that it is not applied twice
    if year == 2016 and sam == 'subst_qcd-herwig-noht' and not hasattr(arr[sam], 'xsecWeight_is_normed'):
        arr[sam]['xsecWeight'] = arr[sam]['xsecWeight'] * 2400.
        arr[sam]['xsecWeight_is_normed'] = True

## Produce new variables used for fit
for sam in sl_prep + sl_prep_subst:
    for i in '12':
        _mask = arr[sam].maskdict[f'fj_{i}_base']
        print('calculating new vars for: ', sam, f'jet{i}')
        ## mSV12_ptmax: corrected mass of the leading SV of the subjet whose leading SV has the larger pT
        arr[sam][f'fj_{i}_mSV12_ptmax'] = eval_expr(arr[sam], f'(fj_{i}_sj1_sv1_pt>fj_{i}_sj2_sv1_pt)*fj_{i}_sj1_sv1_masscor + (fj_{i}_sj1_sv1_pt<=fj_{i}_sj2_sv1_pt)*fj_{i}_sj2_sv1_masscor', mask=_mask)
        arr[sam][f'fj_{i}_mSV12_ptmax_log'] = eval_expr(arr[sam], f'np.log(fj_{i}_mSV12_ptmax)', mask=_mask)
        ## mSV12_dxysig: same choice, based on the larger dxysig instead of the larger pT
        arr[sam][f'fj_{i}_mSV12_dxysig'] = eval_expr(arr[sam], f'(fj_{i}_sj1_sv1_dxysig>fj_{i}_sj2_sv1_dxysig)*fj_{i}_sj1_sv1_masscor + (fj_{i}_sj1_sv1_dxysig<=fj_{i}_sj2_sv1_dxysig)*fj_{i}_sj2_sv1_masscor', mask=_mask)
        arr[sam][f'fj_{i}_mSV12_dxysig_log'] = eval_expr(arr[sam], f'np.log(fj_{i}_mSV12_dxysig)', mask=_mask)

## Store new variables ('xsecWeight' is an existing branch that may have been modified above, hence passed explicitly)
stored_branches = backup_array(store_name, stored_branches_interm['prep-mc'], read_sample_list_map, ext_branches=['xsecWeight'])

## 3. Obtain reweight factors

We need to extract some reweight factors as well as the BDT variation points specific to each pT range. Steps 3-1 to 3-3 are necessary for the nominal fit routine. The factors obtained in steps 3-4 and 3-5 are used for the validation fits.

 3-1. **MC substitute-to-data reweight factor**: reweight based on the 3D (HT, pT, jet index) grid. The goal is to bring the shape of the MC (without pre-selection on the jet-HT triggers) back to the shape of the data (passing the logical OR of the prescaled jet-HT triggers). Remember that the raw MC yield is always much larger than the data yield. The new variables are named `htwgt` and `htwgt_herwig`. (`htwgt_herwig` is derived using the Herwig QCD sample and is only used in the validation fit.)

 3-2. **sfBDT central point and variation range**: a set of sfBDT cut values which are specific for different pT range. The values are derived by judging the similarity of the tagger shape between the signal and proxy jet samples.

 3-3. **sfBDT reweight factor**: reweight on the sfBDT variable based on (pT, jet index) bins. The reweight factors `sfbdtwgt_g50` are obtained, which is only used to derive the systematics shape templates in the nominal fit. `sfbdtwgt_g50_herwig` is derived as well using the Herwig sample, used in the validation fit.

 3-4. **Additional MC substitute-to-data reweight factor on $p_{T}$ only**: A possible replacement for the factors in step 3-1. This factor is only used in the validation fit, to check whether a different reweighting scheme affects the SF fit results. The new variables are named `ad_ptwgt` and `ad_ptwgt_herwig`.
 
 3-5. **Proxy-to-signal reweight factor on $m_{SD}$ / $p_{T}$ / $\tau_{21}$**: based on the shape of the MC after applying the MC-to-data factors from step 3-1 and on the H->cc signal jet shape. The factor is only used in the validation fit, in which we apply it to both MC and data to check whether the SF results are affected. The new variables are named `(mass|pt|tau21)datamcwgt` and `(mass|pt|tau21)datamcwgt_herwig`.

In [None]:
### ================ 3-1. Reweight MC substitute to data (stored as variables "fj_x_htwgt", "fj_x_htwgt_herwig") ===================

## True: if the block has run before, we can obtain the reweight factor from the previously stored pickle output
is_read_from_pickel = False

## Take the snapshot of the existing field names only once; it is passed to backup_array at the end of this cell
if 'prep-3-1' not in stored_branches_interm:
    stored_branches_interm['prep-3-1'] = get_stored_branches(arr, read_sample_list_map)

def extract_source_to_target_ht_weight(arr, sl_rwgt_source, wgtstr_rwgt_source, sl_rwgt_target, wgtstr_rwgt_target, wgtname, ext_sl_rwgt_source=None, presel='', do_plot=True):
    r"""Extract the "MC substitute to data" reweight factor on HT based on (pT, jet index) bins

    For each pT bin and each jet, the weighted HT histograms of the target and of the source are
    filled (including underflow and overflow bins), their ratio is clipped to [0, 2], and the
    resulting factor is written to the column `wgtname` (with 'fj_x' replaced by 'fj_1'/'fj_2')
    of every source sample. Events failing the baseline or pT selection keep the value 0.

    Besides its arguments, the function reads the globals `is_read_from_pickel`, `store_name`,
    `year` and `lumi`. The histograms and factors are written to (or, if `is_read_from_pickel` is
    set, read from) `prep/<store_name>/plots/<wgtname>_<year>.pickle`.

    Arguments:
        arr: awkward array dict as input
        sl_rwgt_(source|target): sample list for the source/target in this reweighting routine
        wgtstr_rwgt_(source|target): the weight string applied to the source/target to produce the histogram in this reweighting routine
        wgtname: the reweight name stored as a new column
        ext_sl_rwgt_source: extra source sample list for which we also calculate the reweight factors after extracting them
        presel: additional pre-selection applied before reweighting
        do_plot: if store plots of reweighting

    Returns:
        dict with keys 'ent_target', 'ent_source' and 'rwgt'; each maps '<ptlab><jetlab>' to an array of length nbins+2
    """
    import pickle
    import re

    ## Do not share a mutable default list between calls
    ext_sl_rwgt_source = [] if ext_sl_rwgt_source is None else list(ext_sl_rwgt_source)
    ## Samples that receive the new weight column
    sl_rwgt_apply = list(sl_rwgt_source) + ext_sl_rwgt_source

    rwgt_var = 'ht'
    ## The binning info for (pT, HT) grid. The adopted HT grid is based on MC shape in each pT bin
    rwgt_edge_dic = {}
    def linear_edge(start, end):
        return list(np.arange(start, end-50, 50)) + [end]
    rwgt_edge_dic[2016] = rwgt_edge_dic[2017] = rwgt_edge_dic[2018] = {
        'jet1': {
            'pt200to250': linear_edge(250, 1250),
            'pt250to300': linear_edge(350, 1400),
            'pt300to350': linear_edge(400, 1600),
            'pt350to400': linear_edge(450, 1700),
            'pt400to450': linear_edge(500, 1800),
            'pt450to500': linear_edge(550, 1900),
            'pt500to550': linear_edge(600, 1900),
            'pt550to600': linear_edge(650, 2000),
            'pt600to700': linear_edge(700, 2100),
            'pt700to800': linear_edge(800, 2200),
            'pt800to100000': linear_edge(1000, 2400),
        },
        'jet2': {
            'pt200to250': linear_edge(250, 1500),
            'pt250to300': linear_edge(350, 1600),
            'pt300to350': linear_edge(400, 1800),
            'pt350to400': linear_edge(450, 2000),
            'pt400to450': linear_edge(500, 2200),
            'pt450to500': linear_edge(550, 2400),
            'pt500to550': linear_edge(650, 2400),
            'pt550to600': linear_edge(750, 2400),
            'pt600to700': linear_edge(850, 2400),
            'pt700to800': linear_edge(1000, 2400),
            'pt800to100000': linear_edge(1200, 2400),
        },

    }

    def flow_hist_values(sam_list, wgtstr, filter_list, rwgt_edge):
        ## Weighted HT histogram content. Note: consider underflow & overflow bins, hence len = nbins+2
        return get_hist(
            concat_array(arr, expr=rwgt_var, sam_list=sam_list, filter_list=filter_list),
            bins=rwgt_edge,
            ## A unit weight is written as '0*ht+1' so that it evaluates to one entry per event
            weights=concat_array(arr, expr=wgtstr if wgtstr!='1' else '0*ht+1', sam_list=sam_list, filter_list=filter_list),
            underflow=True, overflow=True, mergeflowbin=False
        ).view(flow=True).value

    ## Initially fill the output column with 0, since we will fill the column iteratively for each pT bin
    for sam in sl_rwgt_apply:
        for i in '12':
            arr[sam][wgtname.replace('fj_x', f'fj_{i}')] = ak.zeros_like(arr[sam][rwgt_var])

    if is_read_from_pickel: ## restore info from a previously stored pickle
        with open(f'prep/{store_name}/plots/{wgtname}_{year}.pickle', 'rb') as f:
            res = pickle.load(f)
            ent_target, ent_source, rwgt = res['ent_target'], res['ent_source'], res['rwgt']
    else:
        ent_target, ent_source, rwgt = {}, {}, {}

    ## Reweight separately on jet pT bins
    ptlab_list = list(rwgt_edge_dic[year]['jet1'].keys())
    ptcutval_list = [re.findall(r'pt(\d+)to(\d+)', s)[0] for s in ptlab_list]
    for (ptmin, ptmax), ptlab in zip(ptcutval_list, ptlab_list):
        ptsel = f'(fj_x_pt>={ptmin}) & (fj_x_pt<{ptmax})'
        ## Reweight separately for 1st or 2nd jet
        for i, lab in zip(['1','2'], ['jet1','jet2']):
            print (' -- ', ptsel, lab)
            rwgt_edge = rwgt_edge_dic[year][lab][ptlab]
            ## Calculate the rwgt for the first time
            if not is_read_from_pickel:
                for sam in set(sl_rwgt_target) | set(sl_rwgt_source) | set(ext_sl_rwgt_source):
                    arr[sam].maskdict[f'fj_{i}_{ptlab}'] = eval_expr(arr[sam], ptsel.replace('fj_x', f'fj_{i}'))
                    if presel != '':  ## has additional preselection
                        arr[sam].maskdict[f'_tmp_fj_{i}_presel'] = eval_expr(arr[sam], presel.replace('fj_x', f'fj_{i}'))

                ## Get data and MC histogram with the same selection
                filter_list = [f'fj_{i}_base', f'fj_{i}_{ptlab}']+([f'_tmp_fj_{i}_presel'] if presel != '' else [])
                ent_target[ptlab+lab] = flow_hist_values(sl_rwgt_target, wgtstr_rwgt_target, filter_list, rwgt_edge)
                ent_source[ptlab+lab] = flow_hist_values(sl_rwgt_source, wgtstr_rwgt_source, filter_list, rwgt_edge)
                ## Calculate the reweight factor; empty source bins are expected, so silence the division warnings
                with np.errstate(divide='ignore', invalid='ignore'):
                    rwgt[ptlab+lab] = np.nan_to_num(ent_target[ptlab+lab] / ent_source[ptlab+lab], nan=0) # len=nbin+2
                rwgt[ptlab+lab] = np.clip(rwgt[ptlab+lab], 0, 2) # clip between [0, 2]
            print(ent_target[ptlab+lab], '\n', rwgt[ptlab+lab])

            ## Assign the reweight factor to the new column
            for sam in sl_rwgt_apply:
                _var = rwgt_var
                _wgtname = wgtname.replace('fj_x', f'fj_{i}')
                _mask = mask_and(arr[sam], mask_list=[f'fj_{i}_base', f'fj_{i}_{ptlab}'])
                ## Events outside this (pT, jet) bin are masked out and contribute 0 to the running sum
                arr[sam][_wgtname] = arr[sam][_wgtname] + ak.fill_none(calc_rwgt_akarray(arr[sam][_var].mask[_mask], rwgt_edge, rwgt[ptlab+lab]), 0)

    if not is_read_from_pickel: ## store the info for the first run
        os.makedirs(f'prep/{store_name}/plots', exist_ok=True)
        with open(f'prep/{store_name}/plots/{wgtname}_{year}.pickle', 'wb') as fw:
            pickle.dump({'ent_target':ent_target, 'ent_source':ent_source, 'rwgt':rwgt}, fw)

    # =========== plot ===========
    if do_plot:
        mpl.rcParams['axes.prop_cycle'] = cycler(color=['blue', 'red', 'green', 'violet', 'darkorange', 'black', 'cyan', 'yellow'])
        os.makedirs(f'prep/{store_name}/plots', exist_ok=True)
        for ptlab in ptlab_list:
            f = plt.figure(figsize=(12,12))
            gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.04) 
            ## Create each panel once and draw both jets on it; adding a subplot per jet would stack a second axes on top of the first
            ax = f.add_subplot(gs[0])
            ax1 = f.add_subplot(gs[1])
            hep.cms.label(data=True, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
            for lab, cm, cd in zip(['jet1', 'jet2'], ['blue', 'red'], ['royalblue', 'lightcoral']):
                ## Extend the edges by [0, first edge) and [last edge, 2500) to show the underflow/overflow bins
                plot_bins = [0]+list(rwgt_edge_dic[year][lab][ptlab])+[2500]
                hep.histplot(ent_source[ptlab+lab], bins=plot_bins, label='Jet '+lab[-1]+' (MC)', color=cm, ax=ax)
                hep.histplot(ent_target[ptlab+lab], bins=plot_bins, label='Jet '+lab[-1]+' (Data)', color=cd, linestyle='--', ax=ax)
                hep.histplot(rwgt[ptlab+lab], bins=plot_bins, label='Jet '+lab[-1], color=cm, ax=ax1)

            ax.set_xlim(0, 2500); ax.set_xticklabels([])
            ax.set_yscale('log'); ax.set_ylabel('Events', ha='right', y=1.0)
            ax.legend()
            ax1.set_xlim(0, 2500); ax1.set_xlabel('$H_{T}$ [GeV]', ha='right', x=1.0)
            ax1.legend()
            ax1.set_yscale('log')
            ax1.set_ylim(5e-3, 2e0); ax1.set_ylabel('Rwgt factor', ha='right', y=1.0);  ax1.set_yticks([1e-2,1e-1,1e0,1e1])
            ax1.plot([0, 2500], [1, 1], 'k:')

            ## `lab` still holds the last jet label ('jet2') here although the figure shows both jets; the file names are kept unchanged
            f.savefig(f'prep/{store_name}/plots/rwgtfac_{wgtname}_{year}_{ptlab}_{lab}.pdf')
            f.savefig(f'prep/{store_name}/plots/rwgtfac_{wgtname}_{year}_{ptlab}_{lab}.png')
    # ============================
    
    return {'ent_target':ent_target, 'ent_source':ent_source, 'rwgt':rwgt}

## Calculate two sets of reweight factor: one for the MG sample list and another for Herwig sample list
## MG set: the source is every MC sample except the Herwig and b-flavour QCD samples; the target is data with unit weight
htwgt = extract_source_to_target_ht_weight(
    arr, sl_rwgt_source=['subst_'+s for s in read_sample_list_map if s not in ['qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht']], ext_sl_rwgt_source=['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else [], 
    wgtstr_rwgt_source=f"{lumi[year]}*genWeight*xsecWeight*puWeight",
    sl_rwgt_target=['jetht-noht'], wgtstr_rwgt_target='1', wgtname='fj_x_htwgt',
)
## Herwig set: same as above, with the Herwig QCD sample in place of the MG one
if not omit_herwig:
    htwgt_herwig = extract_source_to_target_ht_weight(
        arr, sl_rwgt_source=['subst_'+s for s in read_sample_list_map if s not in ['qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht']], ext_sl_rwgt_source=['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else [],
        wgtstr_rwgt_source=f"{lumi[year]}*genWeight*xsecWeight*puWeight", 
        sl_rwgt_target=['jetht-noht'], wgtstr_rwgt_target='1', wgtname='fj_x_htwgt_herwig',
    )

## Calculate bflav factors: reweight bflav sample to inclusive QCD (after b selection cut)
if config['samples']['use_bflav']:
    bflav_htwgt = extract_source_to_target_ht_weight(
        arr, sl_rwgt_source=['subst_qcd-mg-bflav-noht'], wgtstr_rwgt_source=f"{lumi[year]}*genWeight*xsecWeight*puWeight",
        sl_rwgt_target=['subst_qcd-mg-noht'], wgtstr_rwgt_target=f"{lumi[year]}*genWeight*xsecWeight*puWeight", wgtname='fj_x_bflav_htwgt',
        presel='fj_x_nbhadrons>=1', do_plot=False,
    )
    ## All other MC samples get a unit weight column
    ## NOTE(review): this column is named 'bflav_htwgt', whereas the call above uses wgtname='fj_x_bflav_htwgt' (per-jet columns) -- confirm the intended column name
    for sam in ['subst_'+s for s in read_sample_list_map if s not in ['qcd-mg-bflav-noht', 'jetht-noht']]:
        arr[sam]['bflav_htwgt'] = ak.ones_like(arr[sam]['ht'])

## Store new variables
stored_branches = backup_array(store_name, stored_branches_interm['prep-3-1'], read_sample_list_map)

## Test output
ak.to_pandas(arr['subst_qcd-mg-noht'][['ht', 'fj_1_pt', 'fj_1_htwgt']][arr['subst_qcd-mg-noht'].maskdict['fj_1_base']])

In [None]:
### ================ 3-2. Determine the optimal sfBDT cut value for each pT range  ===================
# First load the h->cc signal ntuple. Adopt the selection used in the analysis
import re
## The file path comes from the config, with the placeholder '$YEAR' replaced by the chosen year
arr['real-signal'] = NanoEventsFactory.from_root(config['main_analysis_tree']['path'].replace('$YEAR', str(year)), treepath='/'+config['main_analysis_tree']['treename'], schemaclass=BaseSchema).events()

## Selection expression from the config, stored as the 'base' mask of the signal sample
basecut_signal = config['main_analysis_tree']['selection']
arr['real-signal'].maskdict = {}
arr['real-signal'].maskdict['base'] = eval_expr(arr['real-signal'], basecut_signal)

## Take the snapshot of the existing field names only once for this step
if 'prep-3-2' not in stored_branches_interm:
    stored_branches_interm['prep-3-2'] = get_stored_branches(arr, read_sample_list_map)

def extract_bdt_sequence(arr, sl_pxy, wgtstr, bdt_start, bdt_mod_factor=None, do_plot=True):
    r"""Extract the sfBDT sequence for specified pT range, based on the signal/proxy similarity
    
    Arguments:
        arr: awkward array dict as input
        sl_pxy: proxy sample list
        wgtstr: the weight string applied to proxy samples
        bdt_start: starting point of sfBDT for scanning
        bdt_mod_factor: modify the sfBDT cut by introducing a exponentially decay term from the tagger
    """
    
    ## Edges based on the tagger WPs
    edges = [0.] + sorted([rg[0] for rg in config['tagger']['working_points']['range'].values()]) + [1.]
    print('edges:', edges)
    rat_pxy = {}
    bdt_seq = {}

    ## Extract the optimal sfBDT and variation cut values for each pT range
    for ptmin, ptmax in config['pt_range']['range']:
        ptlab = f'pt{ptmin}to{ptmax}'
        print('pt range: ', ptmin, ptmax)

        ## Calculate the proportion of LP+MP+HP over inclusive tagger score for "signal jets"
        for sam in ['real-signal']:
            if ptlab not in arr[sam].maskdict.keys():
                arr[sam].maskdict[ptlab] = eval_expr(arr[sam], f"({config['main_analysis_tree']['pt_var']}>={ptmin}) & ({config['main_analysis_tree']['pt_var']}<{ptmax})")
        h = get_hist(
            concat_array(arr, expr=config['main_analysis_tree']['tagger'], sam_list=['real-signal'], filter_list=['base', ptlab]),
            bins=edges,
            weights=concat_array(arr, expr=config['main_analysis_tree']['weight'], sam_list=['real-signal'], filter_list=['base', ptlab]),
        )
        rat_hcc = np.array([h.view().value[-1], sum(h.view().value[1:])]) / sum(h.view().value) ## <LP, LP, MP, TP, LP+MP+TP

        ## Calculate the proportion for "proxy jets" as sfBDT floats

        if config['type'] == 'cc':
            pxy_base_sel = f'(fj_x_nbhadrons==0) & (fj_x_nchadrons>=1) & (fj_x_sfBDT>{bdt_start})'
        elif config['type'] == 'bb':
            pxy_base_sel = f'(fj_x_nbhadrons>=1) & (fj_x_sfBDT>{bdt_start})'
        for sam in sl_pxy:
            for i in '12':
                arr[sam].maskdict[f'fj_{i}_bdt_seq_pxy_base'] = eval_expr(arr[sam], pxy_base_sel.replace('fj_x', f'fj_{i}'))
                if f'fj_{i}_{ptlab}' not in arr[sam].maskdict.keys():
                    arr[sam].maskdict[f'fj_{i}_{ptlab}'] = eval_expr(arr[sam], f'(fj_{i}_pt>={ptmin}) & (fj_{i}_pt<{ptmax})')
        ratios = [[], []]
        bdt_scanlist = []
        _df = ak.to_pandas(
            concat_array_fj12(arr, expr=config['tagger']['var'], sam_list=sl_pxy, filter_list=['fj_x_base', 'fj_x_bdt_seq_pxy_base', f'fj_x_{ptlab}']),
            anonymous='tagger',
        ) # use pandas dataframe to speed up iterative processing
        _df['wgt'] = concat_array_fj12(arr, expr=wgtstr, sam_list=sl_pxy, filter_list=['fj_x_base', 'fj_x_bdt_seq_pxy_base', f'fj_x_{ptlab}'])
        _df['sfBDT'] = concat_array_fj12(arr, expr='fj_x_sfBDT', sam_list=sl_pxy, filter_list=['fj_x_base', 'fj_x_bdt_seq_pxy_base', f'fj_x_{ptlab}'])
        # =========== plot ===========
        if do_plot:
            mpl.rcParams['axes.prop_cycle'] = cycler(color=['blue', 'red', 'green', 'violet', 'darkorange', 'black', 'cyan', 'yellow'])
            f, ax = plt.subplots(figsize=(12,12))
            hep.cms.label(data=False, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
            edge_plot = np.linspace(0, 1, 51)
            ## Signal jet hist
            h_sig = get_hist(
                concat_array(arr, expr=config['main_analysis_tree']['tagger'], sam_list=['real-signal'], filter_list=['base', ptlab]),
                bins=edge_plot,
                weights=concat_array(arr, expr=config['main_analysis_tree']['weight'], sam_list=['real-signal'], filter_list=['base', ptlab]),
            )
            plot_hist(h_sig, bins=edge_plot, label=config['main_analysis_tree']['label'], normed=True)
            ## Proxy jet hist
            for bdt in [0.5, 0.85, 0.9, 0.95]:
                h_pxy = get_hist(_df.query(f'sfBDT>{bdt}')['tagger'], bins=edge_plot, weights=_df.query(f'sfBDT>{bdt}')['wgt'])
                plot_hist(h_pxy, bins=edge_plot, label=f"g({config['type']}) (sfBDT>{bdt:.2f})", normed=True)
            ax.set_xlabel(config['tagger']['var'].replace('fj_x_',''), ha='right', x=1.0); ax.set_ylabel('A.U.', ha='right', y=1.0);
            ax.set_ylim(bottom=0)
            plt.savefig(f"prep/{store_name}/plots/tagger_shape_comp_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}.pdf")
            plt.savefig(f"prep/{store_name}/plots/tagger_shape_comp_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}.png")
        # ============================
        
        bdt_expr = 'sfBDT'
        if bdt_mod_factor is not None:
            bdt_expr = f'sfBDT + 0.5*exp({bdt_mod_factor}*(tagger-1))'
        for bdt in np.arange(bdt_start, 0.999, 0.001):  # loop oversf BDT grid
            _df = _df.query(f'{bdt_expr}>{bdt}')
            h = get_hist(_df['tagger'].values, bins=edges, weights=_df['wgt'].values)
            rat = np.array([h.view().value[-1], sum(h.view().value[1:])]) / sum(h.view().value) ## <LP, LP, MP, TP, LP+MP+TP
            if len(ratios[1]) > 0 and rat[1] < ratios[1][-1] and bdt > 0.996:
                break
            rat_pxy[((ptmin,ptmax), np.round(bdt,3))] = rat
            for j in range(2):
                ratios[j].append(rat[j])
            bdt_scanlist.append(bdt)

        ## Get sfBDT cut WP
        from scipy.interpolate import interp1d
        bdt_wp = interp1d(ratios[1], bdt_scanlist, fill_value="extrapolate")(rat_hcc[1]) # chosen BDT WP: proxy proportion under LP+MP+TP reaches signal
        bdt_wp_hi = interp1d(ratios[0], bdt_scanlist, fill_value="extrapolate")(rat_hcc[0]) # chosen BDT WP (for 4/5's upper bound): proxy proportion under TP reaches signal
        print('all WP:\nsfBDT scan list:', bdt_scanlist[0], '->', bdt_scanlist[-1], 'proxy prop:', ratios[1][0], '->', ratios[1][-1], 'signal prop:', rat_hcc[1], 'interp bdt:', bdt_wp)
        print('HP:    \nsfBDT scan list:', bdt_scanlist[0], '->', bdt_scanlist[-1], 'proxy prop:', ratios[0][0], '->', ratios[0][-1], 'signal prop:', rat_hcc[0], 'interp bdt:', bdt_wp_hi)
        rat_wp, rat_wp_hi = rat_hcc[1], interp1d(bdt_scanlist, ratios[1], fill_value="extrapolate")(bdt_wp_hi) # corresponding LP+MP+TP proportion
        step = (rat_wp_hi - rat_wp) / 4
        rat_seq = np.linspace(rat_wp-step*5, rat_wp+step*5, 11) # derive an arithmetic sequence
        bdt_seq[(ptmin,ptmax)] = interp1d(ratios[1], bdt_scanlist, fill_value="extrapolate")(rat_seq)
        if bdt_seq[(ptmin,ptmax)][-1] >= 1.0:
            raise RuntimeError('The derived sfBDT sequence has values exceeded 1.0')
        print('BDT seq: ', bdt_seq[(ptmin,ptmax)])

    arr[f"bdt_seq_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}"] = bdt_seq
    arr[f"rat_pxy_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}"] = rat_pxy
    arr[f"bdt_mod_factor_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}"] = bdt_mod_factor

## ================ Tune this parameter if necessary ================
## With a numeric value, the scan in extract_bdt_sequence runs on the modified expression
##     sfBDT + 0.5*exp(bdt_mod_factor*(tagger-1))
## instead of the plain sfBDT, so that a set of BDT sequence can be extracted
## without *going out of interpolation range*.
## Set it to None to scan on the plain sfBDT instead.
bdt_mod_factor = 70
# ===================================================================
extract_bdt_sequence(
    arr,
    # proxy sample list: every loaded sample except Herwig QCD, the b-flavour QCD sample and data
    sl_pxy=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    wgtstr='genWeight*xsecWeight*puWeight*fj_x_htwgt',
    bdt_start=0.5,
    bdt_mod_factor=bdt_mod_factor,
)

In [None]:
## Back up the results of step 3-2, using the branch snapshot recorded under 'prep-3-2'.
## NOTE(review): global_key_only=1 presumably restricts the backup to the non-sample keys
## added to `arr` by extract_bdt_sequence -- confirm against backup_array.
stored_branches = backup_array(
    store_name,
    stored_branches_interm['prep-3-2'],
    read_sample_list_map,
    global_key_only=1,
)

In [None]:
### ================ 3-3. Extract the sfBDT>0.5 binned factor: stored as variable "sfbdtwgt_g50"; similar for herwig ===================

## Record the branch snapshot once, before this step adds any column; the snapshot is handed to
## backup_array at the end of this cell. The guard keeps the first snapshot when the cell is re-run.
if 'prep-3-3' not in stored_branches_interm:
    stored_branches_interm['prep-3-3'] = get_stored_branches(arr, read_sample_list_map)
def extract_further_sfbdt_weight(arr, sl_rwgt, wgtstr_rwgt, wgtname, rwgt_info, sl_ext_rwgt=None, presel=''):
    r"""Extract the binned "MC substitute to data" reweight factor on a given variable, separately for each jet pT range

    For every pT range in config['pt_range']['range'], the data histogram and the weighted MC histogram of
    the reweight variable are built from both fat jets together. Their bin-by-bin ratio (data / MC) is the
    reweight factor, which is accumulated into the column `wgtname` of every sample in sl_rwgt + sl_ext_rwgt.
    The histograms and factors are also pickled to prep/{store_name}/plots/{wgtname}_{year}.pickle.

    Arguments:
        arr: awkward array dict as input
        sl_rwgt: sample list for MC substitute in this reweighting routine
        wgtstr_rwgt: the weight string applied to MC to produce the histogram in this reweighting routine
        wgtname: the reweight name (the binned factors) stored as a new column; 'fj_x' is replaced by 'fj_1'/'fj_2'
        rwgt_info: info of the reweight variable, in the format of (var, nbin, xmin, xmax) or (var, edges list, None, None)
        sl_ext_rwgt: extra sample list for which we also calculate the reweight factors after extracting them
        presel: optional pre-selection expression (with the 'fj_x' placeholder) applied before reweighting

    Returns:
        dict with keys 'ent_data', 'ent_mc', 'rwgt'; each value is a dict keyed by the pT range tuple
    """
    import pickle

    # None instead of a shared mutable default list
    sl_ext_rwgt = [] if sl_ext_rwgt is None else list(sl_ext_rwgt)
    sl_assign = sl_rwgt + sl_ext_rwgt        # samples that receive the new column
    sl_masked = sl_assign + ['jetht-noht']   # samples for which the selection masks are needed
    has_presel = presel != ''

    ## Initially fill the output column with 0, since we will fill the column iteratively for each pT bin
    for sam in sl_assign:
        for i in '12':
            arr[sam][wgtname.replace('fj_x', f'fj_{i}')] = ak.zeros_like(arr[sam]['ht'])

    ## Reweight based on given variable
    rwgt_var, nbin, xmin, xmax = rwgt_info
    if isinstance(nbin, int):
        rwgt_edge = np.linspace(xmin, xmax, nbin+1)
    else:
        rwgt_edge = nbin  # the second slot already holds the explicit edge list
    print('rwgt info: ', rwgt_var, rwgt_edge)

    ## Reweight separately on jet pT bins
    ent_data, ent_mc, rwgt = {}, {}, {}
    for pt_range in config['pt_range']['range']:
        pt_range = tuple(pt_range)
        ptlab = f'pt{pt_range[0]}to{pt_range[1]}'
        rwgt_presel = f'(fj_x_pt>={pt_range[0]}) & (fj_x_pt<{pt_range[1]})'
        for sam in sl_masked:
            for i in '12':
                # the pT-range mask may already exist from an earlier step; only build it once
                if f'fj_{i}_{ptlab}' not in arr[sam].maskdict:
                    arr[sam].maskdict[f'fj_{i}_{ptlab}'] = eval_expr(arr[sam], rwgt_presel.replace('fj_x', f'fj_{i}'))

        ## Build the temporary per-jet masks of the optional pre-selection
        if has_presel:
            for sam in sl_masked:
                for i in '12':
                    arr[sam].maskdict[f'_tmp_fj_{i}_presel'] = eval_expr(arr[sam], presel.replace('fj_x', f'fj_{i}'))

        ## Get data and MC histogram. Note: consider underflow & overflow bins, hence len = nbins+2
        ## does not distinguish jet1 or jet2 on this reweighting
        ## The pre-selection mask is named with the 'fj_x' placeholder so that each jet picks up its OWN
        ## mask (previously the leaked loop variable made both jets use the mask of jet 2)
        filter_list = ['fj_x_base', f'fj_x_{ptlab}'] + (['_tmp_fj_x_presel'] if has_presel else [])
        ent_data[pt_range] = get_hist(
            concat_array_fj12(arr, expr=rwgt_var, sam_list=['jetht-noht'], filter_list=filter_list),
            bins=rwgt_edge,
            weights=np.ones(np.sum(mask_and_fj12(arr['jetht-noht'], mask_list=filter_list))),
            underflow=True, overflow=True, mergeflowbin=False,
        ).view(flow=True).value
        ent_mc[pt_range] = get_hist(
            concat_array_fj12(arr, expr=rwgt_var, sam_list=sl_rwgt, filter_list=filter_list),
            bins=rwgt_edge,
            weights=concat_array_fj12(arr, expr=wgtstr_rwgt, sam_list=sl_rwgt, filter_list=filter_list),
            underflow=True, overflow=True, mergeflowbin=False,
        ).view(flow=True).value

        ## Calculate the reweight factor
        ## NOTE: only NaN (0/0) is zeroed here; a bin with data but no MC gives inf, which
        ## np.nan_to_num maps to the largest finite float rather than to 0
        rwgt[pt_range] = np.nan_to_num(ent_data[pt_range] / ent_mc[pt_range], nan=0) # len=nbin+2
        print(ent_data[pt_range], rwgt[pt_range])

        ## Assign the reweight factor to the new column
        for sam in sl_assign:
            for i in '12':
                _var = rwgt_var.replace('fj_x', f'fj_{i}')
                _wgtname = wgtname.replace('fj_x', f'fj_{i}')
                _mask = mask_and(arr[sam], mask_list=[f.replace('fj_x', f'fj_{i}') for f in filter_list])
                # jets outside this pT range are masked out and contribute 0, so the sum over
                # pT ranges leaves each jet with the factor of its own range
                arr[sam][_wgtname] = arr[sam][_wgtname] + ak.fill_none(calc_rwgt_akarray(arr[sam][_var].mask[_mask], rwgt_edge, rwgt[pt_range]), 0)

    ## Store reweight factors
    result = {'ent_data': ent_data, 'ent_mc': ent_mc, 'rwgt': rwgt}
    plot_dir = f'prep/{store_name}/plots'
    os.makedirs(plot_dir, exist_ok=True)
    with open(f'{plot_dir}/{wgtname}_{year}.pickle', 'wb') as fw:
        pickle.dump(result, fw)

    return result

## Calculate two sets of reweight factor: one for the MG sample list and another for Herwig sample list
## MG set: every loaded sample except Herwig QCD, the b-flavour QCD sample and data
extract_further_sfbdt_weight(
    arr,
    sl_rwgt=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
    wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt",
    wgtname='fj_x_sfbdtwgt_g50',
    rwgt_info=('fj_x_sfBDT', 25, 0.5, 1.),
)
## Herwig set: same, but the MG QCD sample is the one left out
if not omit_herwig:
    extract_further_sfbdt_weight(
        arr,
        sl_rwgt=[
            f'subst_{s}' for s in read_sample_list_map
            if s not in ('qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
        ],
        sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
        wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt_herwig",
        wgtname='fj_x_sfbdtwgt_g50_herwig',
        rwgt_info=('fj_x_sfBDT', 25, 0.5, 1.),
    )

## Store new variables
stored_branches = backup_array(
    store_name,
    stored_branches_interm['prep-3-3'],
    read_sample_list_map,
)

## Test output (kept as the last expression so that the notebook displays the table)
ak.to_pandas(
    arr['subst_qcd-mg-noht'][['fj_1_pt', 'fj_1_sfBDT', 'fj_1_sfbdtwgt_g50']][arr['subst_qcd-mg-noht'].maskdict['fj_1_base']]
)

In [None]:
### ================ 3-4. [additional] Reweight MC substitute to data on pT: stored as variable "ad_ptwgt", "ad_ptwgt_herwig" ===================

## Record the branch snapshot once, before this step adds any column; the snapshot is handed to
## backup_array at the end of this cell. The guard keeps the first snapshot when the cell is re-run.
if 'prep-3-4' not in stored_branches_interm:
    stored_branches_interm['prep-3-4'] = get_stored_branches(arr, read_sample_list_map)

def extract_mc_to_data_pt_weight(arr, sl_rwgt, wgtstr_rwgt, wgtname, sl_ext_rwgt=None):
    r"""Extract the "MC substitute to data" reweight factor on pT as an optional choice

    The factor is the bin-by-bin ratio of the data histogram to the weighted MC histogram of the jet pT
    (20 uniform bins in [200, 1200] plus underflow/overflow), derived separately for the 1st and the 2nd
    fat jet and written to the column `wgtname` of every sample in sl_rwgt + sl_ext_rwgt.

    Arguments:
        arr: awkward array dict as input
        sl_rwgt: sample list for MC substitute in this reweighting routine
        wgtstr_rwgt: the weight string applied to MC to produce the histogram in this reweighting routine
        wgtname: the reweight name stored as a new column; 'fj_x' is replaced by 'fj_1'/'fj_2'
        sl_ext_rwgt: extra sample list for which we also calculate the reweight factors after extracting them
    """

    # None instead of a shared mutable default list
    sl_ext_rwgt = [] if sl_ext_rwgt is None else list(sl_ext_rwgt)

    # Apply simple 1D reweight to pT
    rwgt_var, nbin, xmin, xmax = 'fj_x_pt', 20, 200., 1200.
    rwgt_edge = np.linspace(xmin, xmax, nbin+1)

    ## Reweight separately on 1st/2nd jet
    for i in '12':
        _var = rwgt_var.replace('fj_x', f'fj_{i}')
        _wgtname = wgtname.replace('fj_x', f'fj_{i}')
        base_filter = [f'fj_{i}_base']

        ## Get data and MC histogram. Note: consider underflow & overflow bins, hence len = nbins+2
        ## Previously this extra factor is extracted with a presel of sfBDT>0.9. Now given that the sfBDT is optimized by pT range thus not fixed, we relax this cut
        ent_data = get_hist(
            concat_array(arr, expr=_var, sam_list=['jetht-noht'], filter_list=base_filter),
            bins=rwgt_edge,
            weights=np.ones(np.sum(mask_and(arr['jetht-noht'], mask_list=base_filter))),
            underflow=True, overflow=True, mergeflowbin=False,
        ).view(flow=True).value
        ent_mc = get_hist(
            concat_array(arr, expr=_var, sam_list=sl_rwgt, filter_list=base_filter),
            bins=rwgt_edge,
            weights=concat_array(arr, expr=wgtstr_rwgt, sam_list=sl_rwgt, filter_list=base_filter),
            underflow=True, overflow=True, mergeflowbin=False,
        ).view(flow=True).value

        ## Calculate the reweight factor
        ## NOTE: only NaN (0/0) is zeroed here; a bin with data but no MC gives inf, which
        ## np.nan_to_num maps to the largest finite float rather than to 0
        rwgt = np.nan_to_num(ent_data / ent_mc, nan=0) # len=nbin+2
        print(ent_data, rwgt)

        ## assign the reweight factor to the new column
        for sam in sl_rwgt + sl_ext_rwgt:
            _mask = mask_and(arr[sam], mask_list=base_filter)
            arr[sam][_wgtname] = calc_rwgt_akarray(arr[sam][_var].mask[_mask], rwgt_edge, rwgt)  ## fill the new column directly as a masked array
        
## Calculate two sets of reweight factor: one for the MG sample list and another for Herwig sample list
## MG set: every loaded sample except Herwig QCD, the b-flavour QCD sample and data
extract_mc_to_data_pt_weight(
    arr,
    sl_rwgt=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
    wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight",
    wgtname='fj_x_ad_ptwgt',
)
## Herwig set: same, but the MG QCD sample is the one left out
if not omit_herwig:
    extract_mc_to_data_pt_weight(
        arr,
        sl_rwgt=[
            f'subst_{s}' for s in read_sample_list_map
            if s not in ('qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
        ],
        sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
        wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight",
        wgtname='fj_x_ad_ptwgt_herwig',
    )

## Store new variables
stored_branches = backup_array(
    store_name,
    stored_branches_interm['prep-3-4'],
    read_sample_list_map,
)

## Test output (kept as the last expression so that the notebook displays the table)
ak.to_pandas(
    arr['subst_qcd-mg-noht'][['ht', 'fj_1_pt', 'fj_1_htwgt', 'fj_1_sfbdtwgt_g50', 'fj_1_ad_ptwgt']][arr['subst_qcd-mg-noht'].maskdict['fj_1_base']]
)

In [None]:
### ================ 3-5. [additional] Reweight MC (proxy jet) to H->cc signal jet on either mass/pT/tau21: stored as variable "(mass|pt|tau21)datamcwgt"; similar for herwig  ===================

# First load the h->cc signal ntuple. Adopt the selection used in the analysis
# The file path comes from config['main_analysis_tree']['path'] with '$YEAR' replaced by the configured year;
# the tree is read with the plain BaseSchema and stored under the key 'real-signal'
import re
arr['real-signal'] = NanoEventsFactory.from_root(config['main_analysis_tree']['path'].replace('$YEAR', str(year)), treepath='/'+config['main_analysis_tree']['treename'], schemaclass=BaseSchema).events()

# Attach an empty mask dictionary to the signal array, then register the analysis selection as its 'base' mask
basecut_signal = config['main_analysis_tree']['selection']
arr['real-signal'].maskdict = {}
arr['real-signal'].maskdict['base'] = eval_expr(arr['real-signal'], basecut_signal)

# Record the branch snapshot once, before this step adds any column; the snapshot is handed to
# backup_array at the end of this cell. The guard keeps the first snapshot when the cell is re-run.
if 'prep-3-5' not in stored_branches_interm:
    stored_branches_interm['prep-3-5'] = get_stored_branches(arr, read_sample_list_map)

def extract_mc_to_signal_weight(arr, sl_rwgt, wgtstr_rwgt, wgtname, rwgt_info, sl_ext_rwgt=None):
    r"""Extract the "MC substitute (proxy) to H->cc signal jet" reweight factor on a possible variable

    The factor is the bin-by-bin ratio of the unit-normalised signal histogram (from arr['real-signal'])
    to the unit-normalised weighted MC proxy histogram (both fat jets together, after sfBDT>0.9), clipped
    to [0, 50]. It is written to the column `wgtname` of the MC samples, the extra samples and the data
    sample 'jetht-noht'.

    Arguments:
        arr: awkward array dict as input
        sl_rwgt: sample list for MC substitute in this reweighting routine
        wgtstr_rwgt: the weight string applied to MC to produce the histogram in this reweighting routine
        wgtname: the reweight name stored as a new column; 'fj_x' is replaced by 'fj_1'/'fj_2'
        rwgt_info: variable and binning info, in the format of (var, nbin, xmin, xmax, var_nom), where `var` is
            the proxy-jet expression (with the 'fj_x' placeholder) and `var_nom` is the corresponding expression
            evaluated on the signal tree
        sl_ext_rwgt: extra sample list for which we also calculate the reweight factors after extracting them
    """

    # None instead of a shared mutable default list
    sl_ext_rwgt = [] if sl_ext_rwgt is None else list(sl_ext_rwgt)

    # Reweight info extracted from the function argument
    rwgt_var, nbin, xmin, xmax, rwgt_var_nom = rwgt_info
    print('rwgt info: ', rwgt_var, nbin, xmin, xmax)
    rwgt_edge = np.linspace(xmin, xmax, nbin+1)

    ## Requires the selection sfBDT>0.9 which is (averagely) used in the fit region
    rwgt_sel = 'fj_x_sfBDT>0.9'
    for sam in sl_rwgt + sl_ext_rwgt:
        for i in '12':
            arr[sam].maskdict[f'_tmp_fj_{i}_presel'] = eval_expr(arr[sam], rwgt_sel.replace('fj_x', f'fj_{i}'))

    ## Get MC and h->cc signal histogram. Note: consider underflow & overflow bins, hence len = nbins+2
    mc_filter = ['fj_x_base', '_tmp_fj_x_presel']
    wgt_mc = concat_array_fj12(arr, expr=wgtstr_rwgt, sam_list=sl_rwgt, filter_list=mc_filter)
    yield_mc = wgt_mc.sum()
    ent_mc = get_hist(
        concat_array_fj12(arr, expr=rwgt_var, sam_list=sl_rwgt, filter_list=mc_filter),
        bins=rwgt_edge,
        weights=wgt_mc,
        underflow=True, overflow=True, mergeflowbin=False,
    ).view(flow=True).value

    wgt_hcc = concat_array(arr, expr=config['main_analysis_tree']['weight'], sam_list=['real-signal'], filter_list=['base'])
    yield_hcc = wgt_hcc.sum()
    ent_hcc = get_hist(
        concat_array(arr, expr=rwgt_var_nom, sam_list=['real-signal'], filter_list=['base']),
        bins=rwgt_edge,
        weights=wgt_hcc,
        underflow=True, overflow=True, mergeflowbin=False,
    ).view(flow=True).value

    ## Calculate the reweight factors for the two normalized histograms, and clip to (0, 50)
    ## NaN (0/0) bins become 0; inf bins (signal but no MC) end up at the upper clip value
    rwgt = np.nan_to_num((ent_hcc/yield_hcc) / (ent_mc/yield_mc), nan=0) # len=nbin+2
    rwgt = np.clip(rwgt, 0, 50)
    print(ent_hcc, rwgt)

    ## assign the reweight factor to the new column (to MC, the extra samples, and data)
    ## The extra samples were documented as receiving the factor but were left out of this loop before
    for sam in sl_rwgt + sl_ext_rwgt + ['jetht-noht']:
        for i in '12':
            _var = rwgt_var.replace('fj_x', f'fj_{i}')
            _wgtname = wgtname.replace('fj_x', f'fj_{i}')
            # the factor is assigned to every jet passing the base selection, not only those with sfBDT>0.9
            _mask = mask_and(arr[sam], mask_list=[f'fj_{i}_base'])
            arr[sam][_wgtname] = calc_rwgt_akarray(arr[sam][_var].mask[_mask], rwgt_edge, rwgt)  ## fill the new column directly as a masked array
    
## For each reweight variable, calculate two sets of reweight factor: one for the MG sample list and another for Herwig sample list
## MG set: every loaded sample except Herwig QCD, the b-flavour QCD sample and data
extract_mc_to_signal_weight(
    arr,
    sl_rwgt=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
    wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt",
    wgtname='fj_x_massdatamcwgt',
    rwgt_info=('fj_x_sdmass', 15, 50, 200, config['main_analysis_tree']['addition_var']['mass']),
)
extract_mc_to_signal_weight(
    arr,
    sl_rwgt=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
    wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt",
    wgtname='fj_x_ptdatamcwgt',
    rwgt_info=('fj_x_pt', 20, 200, 1200, config['main_analysis_tree']['pt_var']),
)
extract_mc_to_signal_weight(
    arr,
    sl_rwgt=[
        f'subst_{s}' for s in read_sample_list_map
        if s not in ('qcd-herwig-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
    ],
    sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
    wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt",
    wgtname='fj_x_tau21datamcwgt',
    rwgt_info=('fj_x_tau21', 20, 0, 1, config['main_analysis_tree']['addition_var']['tau21']),
)
## Herwig set: same three variables, but the MG QCD sample is the one left out
if not omit_herwig:
    extract_mc_to_signal_weight(
        arr,
        sl_rwgt=[
            f'subst_{s}' for s in read_sample_list_map
            if s not in ('qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
        ],
        sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
        wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt_herwig",
        wgtname='fj_x_massdatamcwgt_herwig',
        rwgt_info=('fj_x_sdmass', 15, 50, 200, config['main_analysis_tree']['addition_var']['mass']),
    )
    extract_mc_to_signal_weight(
        arr,
        sl_rwgt=[
            f'subst_{s}' for s in read_sample_list_map
            if s not in ('qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
        ],
        sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
        wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt_herwig",
        wgtname='fj_x_ptdatamcwgt_herwig',
        rwgt_info=('fj_x_pt', 20, 200, 1200, config['main_analysis_tree']['pt_var']),
    )
    extract_mc_to_signal_weight(
        arr,
        sl_rwgt=[
            f'subst_{s}' for s in read_sample_list_map
            if s not in ('qcd-mg-noht', 'qcd-mg-bflav-noht', 'jetht-noht')
        ],
        sl_ext_rwgt=(['subst_qcd-mg-bflav-noht'] if config['samples']['use_bflav'] else []),
        wgtstr_rwgt=f"{lumi[year]}*genWeight*xsecWeight*puWeight*fj_x_htwgt_herwig",
        wgtname='fj_x_tau21datamcwgt_herwig',
        rwgt_info=('fj_x_tau21', 20, 0, 1, config['main_analysis_tree']['addition_var']['tau21']),
    )

## Store new variables
stored_branches = backup_array(
    store_name,
    stored_branches_interm['prep-3-5'],
    read_sample_list_map,
)

## Test output (kept as the last expression so that the notebook displays the table)
ak.to_pandas(
    arr['jetht-noht'][['fj_1_sdmass', 'fj_1_massdatamcwgt', 'fj_1_pt', 'fj_1_ptdatamcwgt', 'fj_1_tau21', 'fj_1_tau21datamcwgt']][arr['jetht-noht'].maskdict['fj_1_base']]
)