# Post histogram maker


This notebook helps to make plots using the pre-processed data, which includes
 - data/MC comparison plots under some given event selection;
 - Xbb/cc signal and proxy jets comparison plots on various jet observables;
 - other comparisons.


The data analysis is backed by the `ak-array` data structure.

In [None]:
from coffea.nanoevents import NanoEventsFactory, TreeMakerSchema, BaseSchema
import awkward1 as ak
import uproot4 as uproot
import uproot3
import numpy as np
import math
import os

In [None]:
from data_utils import get_hist, plot_hist
from cycler import cycler
import boost_histogram as bh
import matplotlib.pyplot as plt
import matplotlib as mpl
# mpl.use('AGG') # no rendering plots in the window

import mplhep as hep
use_helvet = True  ## True: use Helvetica for plots; make sure the font is installed on the system
if use_helvet:
    ## Work on a copy: assigning into `hep.style.CMS` directly would modify
    ## the style dict shared by everything else that imports mplhep.
    CMShelvet = dict(hep.style.CMS)
    CMShelvet['font.sans-serif'] = ['Helvetica', 'Arial']
    plt.style.use(CMShelvet)
else:
    plt.style.use(hep.style.CMS)

In [None]:
from data_utils import get_hist, plot_hist

In [None]:
## Load the config.yml
## NOTE: the card path is hard-coded here; the resulting `config` dict is
## read by the cells below (year, samples, tagger, pt ranges, ...).
import yaml
with open('cards/config_cc_ak15_std.yml') as f:
    config = yaml.safe_load(f)

## Load files

Load the ROOT files into lazy awkward arrays

In [None]:
## Data-taking year; used as key of `lumi` and in the input/output paths
year = config['year']

## Per-year luminosity value; it is multiplied into the MC weight strings and
## printed in the plot labels with the unit fb^{-1}
lumi = {2016: 35.92, 2017: 41.53, 2018: 59.74}

## Sample name -> ROOT file path (relative to `sample_prefix`, see below)
read_sample_list_map = {
    'qcd-mg-noht': 'mc/qcd-mg_tree.root',
    'qcd-herwig-noht': 'mc/qcd-herwig_tree.root',
    'top-noht': 'mc/top_tree.root',
    'v-qq-noht': 'mc/v-qq_tree.root',
    'jetht-noht': 'data/jetht_tree.root',
}
## Optional extra sample, switched on from the config card
if config['samples']['use_bflav']:
    read_sample_list_map['qcd-mg-bflav-noht'] = 'mc/qcd-mg-bflav_tree.root'

## Read the root file into lazy awkward arrays
## `arr` is the global sample-name -> events dict used by all helpers below
arr = {}
sample_prefix = f"{config['samples']['sample_prefix']}_{year}"
for sam in read_sample_list_map:
    arr[sam] = NanoEventsFactory.from_root(f'{sample_prefix}/{read_sample_list_map[sam]}', schemaclass=BaseSchema).events()

## Store the branch names
stored_branches = {}
for sam in read_sample_list_map:
    stored_branches[sam] = ak.fields(arr[sam])
## Name of the backup folder under `prep/` passed to `load_backup_array`
store_name = f"{config['samples']['name']}_SF{config['year']}"

## Load backup pickles

In [None]:
## Fetch variables from the backup file
def load_backup_array(backup_name, read_sample_list_map):
    r"""Load the variables stored under `prep/<backup_name>` into the global `arr` dict.

    Every entry of `prep/<backup_name>` is handled as follows:
      - an entry named like a sample in `read_sample_list_map` is treated as a
        folder: each non-hidden file in it is unpickled and stored as
        `arr[sample][<file name>]`, except the file `maskdict`, which is stored
        as the attribute `arr[sample].maskdict`. For every such sample except
        'jetht-noht', `arr['subst_' + sample]` is set to the same object.
      - any other non-hidden regular file is unpickled into `arr[<file name>]`.

    Arguments:
        backup_name: name of the backup folder inside `prep/`
        read_sample_list_map: container of sample names to read (only membership is tested)

    NOTE: the files are read with `pickle.load`; only point this at trusted backups.
    """

    import pickle

    def _unpickle(path):
        with open(path, 'rb') as fin:
            return pickle.load(fin)

    base_dir = f'prep/{backup_name}'
    for entry in os.listdir(base_dir):
        entry_path = f'{base_dir}/{entry}'

        ## Not a known sample: load it as a stand-alone pickle, if it is one
        if entry not in read_sample_list_map:
            if not entry.startswith('.') and os.path.isfile(entry_path):
                arr[entry] = _unpickle(entry_path)
                print('loading...', entry)
            continue

        ## Known sample: one pickle file per stored variable
        for var in os.listdir(entry_path):
            if var.startswith('.'):
                continue
            if var == 'maskdict':
                arr[entry].maskdict = {}
                arr[entry].maskdict = _unpickle(f'{entry_path}/maskdict')
                print('loading...', entry, 'maskdict', arr[entry].maskdict.keys())
            else:
                arr[entry][var] = _unpickle(f'{entry_path}/{var}')
                print('loading...', entry, var)
        if entry != 'jetht-noht':
            arr['subst_'+entry] = arr[entry] # make a reference

load_backup_array(store_name, read_sample_list_map)

In [None]:
def eval_expr(ak_array, expr, mask=None):
    """Evaluate the expression string `expr` on the fields of `ak_array`, imitating `eval` in pandas.

    Every name found in `expr` (except names starting with '_' and the module
    aliases below) is looked up as `ak_array[name]`; if `mask` is given, the
    looked-up value is replaced by `ak_array[name].mask[mask]`. The names
    `math`, `numpy`, `np`, `awkward` and `ak` refer to the respective modules.

    NOTE: `expr` is passed to Python's `eval`; use trusted expressions only.
    """
    import ast

    module_aliases = {'awkward', 'ak', 'np', 'numpy', 'math'}
    referenced = {
        node.id
        for node in ast.walk(ast.parse(expr))
        if isinstance(node, ast.Name) and not node.id.startswith('_')
    }

    namespace = {}
    for key in sorted(referenced - module_aliases):
        column = ak_array[key]
        namespace[key] = column if mask is None else column.mask[mask]
    namespace.update({'math': math, 'numpy': np, 'np': np, 'awkward': ak, 'ak': ak})
    return eval(expr, namespace)

In [None]:
def mask_and(arr, mask_list):
    """Return the element-wise logical AND of the masks `arr.maskdict[name]` for each name in `mask_list`."""
    selected = [arr.maskdict[name] for name in mask_list]
    return np.logical_and.reduce(selected)

def concat_array(arrdict, expr, sam_list, filter_list):
    """Evaluate `expr` on each sample, keep the entries passing all masks in `filter_list`, and concatenate.

    `sam_list` may be a single sample name or a list of names; the result is
    one numpy array with the samples appended in the given order.
    """
    samples = sam_list if isinstance(sam_list, list) else [sam_list]
    pieces = []
    for sam in samples:
        events = arrdict[sam]
        values = eval_expr(events, expr)
        passed = mask_and(events, filter_list)
        pieces.append(np.array(values[passed]))
    return np.concatenate(pieces)

def mask_and_fj12(arr, mask_list):
    """Combine the `mask_and` results for fj_1 and fj_2.

    Each mask name is used twice, with 'fj_x' replaced by 'fj_1' and then by
    'fj_2'; the two resulting masks are concatenated in that order.
    """
    per_jet = []
    for jet in ('fj_1', 'fj_2'):
        jet_masks = [name.replace('fj_x', jet) for name in mask_list]
        per_jet.append(mask_and(arr, jet_masks))
    return np.concatenate(per_jet)

def concat_array_fj12(arrdict, expr, sam_list, filter_list):
    """Combine the `concat_array` results for fj_1 and fj_2.

    'fj_x' is replaced by 'fj_1' (then 'fj_2') in both the expression and the
    mask names; the fj_1 result comes first in the returned array.
    """
    per_jet = [
        concat_array(
            arrdict,
            expr.replace('fj_x', jet),
            sam_list,
            [name.replace('fj_x', jet) for name in filter_list],
        )
        for jet in ('fj_1', 'fj_2')
    ]
    return np.concatenate(per_jet)

def calc_rwgt_akarray(arr, rwgt_edge, rwgt):
    """Map each value of the reweight variable `arr` to its weight.

    Values below `rwgt_edge[0]` get `rwgt[0]`, values in
    [`rwgt_edge[i]`, `rwgt_edge[i+1]`) get `rwgt[i+1]`, and values at or above
    `rwgt_edge[-1]` get `rwgt[-1]`.
    """
    weights = (arr < rwgt_edge[0]) * rwgt[0]
    for idx, (low, high) in enumerate(zip(rwgt_edge[:-1], rwgt_edge[1:]), start=1):
        in_bin = (arr >= low) & (arr < high)
        weights = weights + in_bin * rwgt[idx]
    weights = weights + (arr >= rwgt_edge[-1]) * rwgt[-1]
    return weights

------------
# Data/MC comparison plots

Based on the ak-array dict `arr`, this section aims to make data and MC plots, where the MC is categorized into three flavors: C/B/L.
With the universal `make_data_mc_plots` function, one can specify any final selection and any sample list to produce the standard hist+ratio plot.

The below recipe can make a default set of plots.

In [None]:
### ================ configuration  ===================

def make_config_dm(sl_dm, wgtstr_dm):
    """Build the per-category configuration for the data/MC plots.

    Returns a dict: category -> (label, sample or sample list, weight string,
    category selection). The three MC flavour categories use every entry of
    `sl_dm` except the last one, together with the weight string `wgtstr_dm`.
    """
    flavour_cuts = (
        ('flvB', 'MC (flvB)', 'fj_x_nbhadrons>=1'),
        ('flvC', 'MC (flvC)', '(fj_x_nbhadrons==0) & (fj_x_nchadrons>=1)'),
        ('flvL', 'MC (flvL)', '(fj_x_nbhadrons==0) & (fj_x_nchadrons==0)'),
    )
    config_dm = {'data': ('Data', 'jetht-noht', '1.0', '')}
    for key, label, cut in flavour_cuts:
        ## a fresh slice per category, so the entries do not share one list
        config_dm[key] = (label, sl_dm[:-1], wgtstr_dm, cut)
    return config_dm
## Categories drawn by `make_data_mc_plots`; the MC categories are handed to
## the stacked plot in this order, 'data' is drawn separately
categories_dm = ['flvL', 'flvB', 'flvC', 'data']

## Variables to plot. `nbin` is either an int (uniform bins between xmin and
## xmax) or a list of bin edges (xmin/xmax are then taken from the edges).
## A label ending in '-u' / '-o' / '-a' makes `make_data_mc_plots` switch off
## the underflow / the overflow / both; the suffix is stripped from the axis label.
bininfo_dm = [ #(savename, vname, nbin, xmin, xmax, label)
    ('ht', 'ht', 60, 0, 3000, r'$H_{T}$ [GeV]'),
#     ('fj_x_pt', 'fj_x_pt', 100, 0, 2500, r'$p_{T}(AK15)$ [GeV]'),
#     ('fj_x_eta', 'fj_x_eta', 20, -2.5, 2.5, r'$\eta(AK15)$'),
#     ('fj_x_sdmass', 'fj_x_sdmass', 15, 50, 200, r'$m_{SD}(AK15)$ [GeV]'),
#     ('fj_x_sfBDT', 'fj_x_sfBDT', 50, 0.5, 1, r'$sfBDT(AK15)$'),

#     ("fj_x_mSV12_dxysig", "fj_x_mSV12_dxysig", 50, 0, 20, r'$log(m_{SV1,d_{xy}sig\,max}\; /GeV)$'),
#     ("fj_x_btagcsvv2", "fj_x_btagcsvv2", [0,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.98,0.99,0.995,1], None, None, r'$CSVv2$'),
#     ("fj_x_mSV12_ptmax_log", "fj_x_mSV12_ptmax_log", [-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2,3.9], None, None, r'$log(m_{SV1,p_{T}\,max}\; /GeV)$'),
#     ("fj_x_mSV12_dxysig_log", "fj_x_mSV12_dxysig_log", [-0.8,-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2], None, None, r'$log(m_{SV1,d_{xy}sig\,max}\; /GeV)$'),
]
## The tagger discriminant configured in the card is always plotted
bininfo_dm += [
    (config['tagger']['var'], config['tagger']['var'], 100, 0, 1, config['tagger']['var'].replace('fj_x_','')),
#     (config['tagger']['var'], config['tagger']['var'], 50, 0.8, 1, config['tagger']['var'].replace('fj_x_','')+'-u'),
]

In [None]:
### ================ slim on tagger, sfBDT, then make data/MC plots ===================

import seaborn as sns
def set_sns_color(*args):
    """Show the seaborn palette described by `args` with `sns.palplot`, then apply it via `sns.set_palette`."""
    palette = sns.color_palette(*args)
    sns.palplot(palette)
    sns.set_palette(*args)

def calc_custom_masks(sl_dm, filter_list, config_dm):
    """Compute the masks needed for a data/MC plot and store them in `arr[sam].maskdict`.

    For every sample in `sl_dm`, each mask in `filter_list` plus the category
    selections of `config_dm` that apply to that sample is evaluated twice,
    with 'fj_x' replaced by 'fj_1' and by 'fj_2'. The result is stored under
    the replaced mask name (existing entries are overwritten).

    A mask of the form 'fj_x_pt<min>to<max>' is not an expression but a
    pT-range label: it is translated to '(pt>=min) & (pt<max)', where 'Inf' as
    upper bound stands for 100000. Any other mask is evaluated as is.

    Arguments:
        sl_dm: sample list (keys of the global `arr`)
        filter_list: list of mask expressions / pT-range labels written with 'fj_x'
        config_dm: category configuration as returned by `make_config_dm`
    """
    import re

    ## Raw string: '\S' inside a plain string literal is an invalid escape sequence
    pt_label = re.compile(r'fj_x_pt(\S+)to(\S+)')

    for sam in sl_dm:
        ext_filter_list = [config_dm[s][-1] for s in config_dm if (sam in config_dm[s][1] or config_dm[s][1]==sam) and config_dm[s][-1] != '']
        for mask in filter_list + ext_filter_list:
            ## Only a mask that really matches the pT-range pattern is treated as a label;
            ## other masks mentioning fj_x_pt (e.g. 'fj_x_pt>200') are ordinary expressions
            pt_match = pt_label.search(mask) if 'fj_x_pt' in mask else None
            for i in '12':
                mask_name = mask.replace('fj_x', f'fj_{i}')
                print('new mask calculated (fj_x -> fj_1/2): ', sam, mask_name)
                if pt_match is not None:
                    ptmin, ptmax = pt_match.groups()
                    ptmax = '100000' if ptmax=='Inf' else ptmax
                    expr = f'(fj_{i}_pt>={ptmin}) & (fj_{i}_pt<{ptmax})'
                else:
                    expr = mask_name
                arr[sam].maskdict[mask_name] = eval_expr(arr[sam], expr)

def make_data_mc_plots(sl_dm, config_dm, filter_list, prefix, **kwargs):
    r"""Make the standard hist+ratio plots based on the sample list and the final selection.

    One plot is produced for every entry of the module-level `bininfo_dm`
    (both fj_1 and fj_2 enter each histogram). The plots are written as png
    and pdf to `plots/{g_dirname}_{year}/`. As a side effect the selected
    content and weights of each MC category are dumped to
    `.roofit_<category>.pickle` in the working directory (overwritten on each variable).

    Relies on the module-level names `arr`, `bininfo_dm`, `categories_dm`,
    `g_dirname`, `year` and `lumi`.

    Arguments:
        sl_dm: sample list
        config_dm: configuration set for each category in the plots, in the dict format. name: (label, sample/sample list, weight string, cat selection)
        filter_list: masks (written with 'fj_x') applied on top of 'fj_x_base' to produce the plots
        prefix: prefix string used in the output plot file name
        kwargs: optional settings
            plot_vars: if given, only the `bininfo_dm` entries whose savename is in it are plotted
            g_do_kde_vars: dict savename -> bool, also make KDE-smoothed plots for these variables
            g_custom_kde_bw: dict savename -> number, the default KDE bandwidth factor is divided by it
            custom_kde: dict savename -> {category: (callable pdf, normalisation)}, used instead of a fitted KDE
    """
    import pickle

    calc_custom_masks(sl_dm, filter_list, config_dm)
    for savename, vname, nbin, xmin, xmax, vlabel in bininfo_dm:
        if 'plot_vars' in kwargs and savename not in kwargs['plot_vars']:
            continue
        if not isinstance(nbin, int):
            edges, xmin, xmax, nbin = nbin, min(nbin), max(nbin), len(nbin)
        else:
            edges = np.linspace(xmin, xmax, nbin+1)

        label, hdm = {}, {}
        ## A label suffix '-u' / '-o' / '-a' switches off the underflow / overflow / both
        underflow = False if vlabel[-2:] in ['-u','-a'] else True
        overflow  = False if vlabel[-2:] in ['-o','-a'] else True
        if vlabel[-2:] in ['-u','-o','-a']:
            vlabel = vlabel[:-2]

        if 'g_do_kde_vars' in kwargs and savename in kwargs['g_do_kde_vars'] and kwargs['g_do_kde_vars'][savename]==True:
            do_kde = True
            kde = {}
        else:
            do_kde = False

        ## Integration range of each bin for the KDE: the first/last bin extends to
        ## -inf/+inf when the underflow/overflow is included in the histogram.
        ## (The last bin has index n_bins-1; the former test against len(edges)-1 never fired.)
        n_bins = len(edges) - 1
        bin_bounds = [
            (-np.inf if (i==0 and underflow) else edges[i], +np.inf if (i==n_bins-1 and overflow) else edges[i+1])
            for i in range(n_bins)
        ]

        ## Loop over categories to extract the hist for each flavor and data
        for cat in categories_dm:
            lab, sam, wgt, sel = config_dm[cat]
            label[cat] = lab
            if cat != 'data':
                cat_filters = ['fj_x_base']+filter_list+([sel] if sel!='' else [])
                _content = concat_array_fj12(arr, expr=vname, sam_list=sam, filter_list=cat_filters)
                _weights = concat_array_fj12(arr, expr=wgt,   sam_list=sam, filter_list=cat_filters)
                ## Dump the selected arrays of this category (overwritten for every variable)
                with open(f'.roofit_{cat}.pickle', 'wb') as fw:
                    pickle.dump({'content':_content, 'weights':_weights}, fw)
                hdm[cat] = get_hist(_content, bins=edges, weights=_weights, underflow=underflow, overflow=overflow)
                if do_kde:
                    from scipy.stats import gaussian_kde
                    from scipy import integrate
                    if 'custom_kde' in kwargs and savename in kwargs['custom_kde']:
                        ## User-provided (pdf, normalisation): integrate the pdf numerically per bin
                        kde[cat] = kwargs['custom_kde'][savename][cat]
                        kde_int_res = [integrate.quad(kde[cat][0], lo, hi) for lo, hi in bin_bounds]
                    else:
                        ## Negative weights are clipped to zero for the KDE
                        kdetmp = gaussian_kde(_content, weights=np.clip(_weights, 0, np.inf))
                        if 'g_custom_kde_bw' in kwargs and savename in kwargs['g_custom_kde_bw']:
                            kdetmp = gaussian_kde(_content, weights=np.clip(_weights, 0, np.inf), bw_method=kdetmp.factor/kwargs['g_custom_kde_bw'][savename])
                        kde[cat] = (kdetmp, _weights.sum())
                        kde_int_res = [(kde[cat][0].integrate_box_1d(lo, hi), 0.) for lo, hi in bin_bounds]
                    ## KDE histogram: per-bin integral scaled to the total yield, no uncertainty
                    hdm[cat+'_kde'] = hdm[cat].copy()
                    hdm[cat+'_kde'].view(flow=True).value = np.array([res[0] for res in kde_int_res]) * kde[cat][1]
                    hdm[cat+'_kde'].view(flow=True).variance = np.zeros(n_bins)

            else: ## is data: no sel, weight=1
                _content = concat_array_fj12(arr, expr=vname, sam_list=sam, filter_list=['fj_x_base']+filter_list)
                _weights = np.ones_like(_content)
                hdm[cat] = get_hist(_content, bins=edges, weights=_weights, underflow=underflow, overflow=overflow)

        cat_sufs = ['']
        if do_kde:
            cat_sufs += ['_kde']
        for cat_suf in cat_sufs:
            ## Draw the standard hist_ratio plot
            set_sns_color('cubehelix_r', 3) ## set the color palette
            f = plt.figure(figsize=(12,12))
            gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.05)

            ## Upper histogram panel
            ax = f.add_subplot(gs[0])
            ## use the configured year (was hard-coded to 2016), as in the other plots
            hep.cms.label(data=True, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
            ax.set_xlim(xmin, xmax); ax.set_xticklabels([]); ax.set_ylabel('Events / bin', ha='right', y=1.0)

            plot_hist([hdm[cat+cat_suf] for cat in categories_dm if cat!='data'], bins=edges, label=[label[cat] for cat in categories_dm if cat!='data'], histtype='fill', edgecolor='k', linewidth=1, stack=True) ## draw stacked bkg
            cats_mc = list(set(categories_dm) - set(['data']))
            hdm_add = hdm[cats_mc[0]+cat_suf].copy()
            for cat in cats_mc[1:]:
                hdm_add += hdm[cat+cat_suf]
            bkgtot, bkgtot_err = hdm_add.view(flow=True).value, np.sqrt(hdm_add.view(flow=True).variance)
            ax.fill_between(edges, (bkgtot-bkgtot_err).tolist()+[0], (bkgtot+bkgtot_err).tolist()+[0], label='BKG unce.', step='post', hatch='///', edgecolor='darkblue', facecolor='none', linewidth=0) ## draw bkg unce.
            plot_hist(hdm['data'], bins=edges, label='Data', histtype='errorbar', color='k', markersize=15, elinewidth=1.5) ## draw data
            ax.set_yscale('log'); ax.set_ylim(1e-1, ax.get_ylim()[1])

            ax.legend()
            # ax.legend(loc='upper left'); ax.set_ylim(0, 1.4*ax.get_ylim()[1])

            ## Ratio panel
            ax1 = f.add_subplot(gs[1]); ax1.set_xlim(xmin, xmax); ax1.set_ylim(0.001, 1.999)
            ax1.set_xlabel(vlabel, ha='right', x=1.0); ax1.set_ylabel('Data / MC', ha='center')
            ax1.plot([xmin,xmax], [1,1], 'k'); ax1.plot([xmin,xmax], [0.5,0.5], 'k:'); ax1.plot([xmin,xmax], [1.5,1.5], 'k:')

            hr = hdm['data'].view(flow=True).value / hdm_add.view(flow=True).value
            # hr_err = hr * np.sqrt(hdm['data'].view(flow=True).variance/(hdm['data'].view(flow=True).value**2) + hdm_add.view(flow=True).variance/(hdm_add.view(flow=True).value**2))
            ## Only the data uncertainty enters the ratio points; the MC one is drawn as the hatched band
            hr_dataerr = hr * np.sqrt(hdm['data'].view(flow=True).variance/(hdm['data'].view(flow=True).value**2))
            ax1.fill_between(edges, ((bkgtot-bkgtot_err)/bkgtot).tolist()+[0], ((bkgtot+bkgtot_err)/bkgtot).tolist()+[0], step='post', hatch='///', edgecolor='darkblue', facecolor='none', linewidth=0) ## draw bkg unce.
            ## NaN ratios (0/0) are moved to -1, i.e. outside the visible y range
            hep.histplot(np.nan_to_num(hr, nan=-1), bins=edges, yerr=np.nan_to_num(hr_dataerr), histtype='errorbar', color='k', markersize=15, elinewidth=1) ## draw data in ratio plot

            filter_list_str = '_'.join(filter_list)
            print('save plot: ', f'plots/{g_dirname}_{year}/{prefix}__{filter_list_str}__{savename}{cat_suf}.png/pdf')
            plt.savefig(f'plots/{g_dirname}_{year}/{prefix}__{filter_list_str}__{savename}{cat_suf}.png')
            plt.savefig(f'plots/{g_dirname}_{year}/{prefix}__{filter_list_str}__{savename}{cat_suf}.pdf')

        ## kde/orig comparison plots
        if do_kde:
            mpl.rcParams['axes.prop_cycle'] = cycler(color=['blue', 'red', 'green'])
            f, ax = plt.subplots(figsize=(12,12))
            hep.cms.label(data=False, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
            x_contin = np.linspace(xmin, xmax, 201)
            ## Width of a bin in the middle of the range, used to scale the continuous KDE curve
            bin_width = edges[int(nbin/2)+1] - edges[int(nbin/2)]
            for cat, color in zip(['flvC', 'flvB', 'flvL'], ['blue', 'red', 'green']):
                lab, sam, wgt, sel = config_dm[cat]
                ax.plot(x_contin, kde[cat][0](x_contin) * kde[cat][1] * bin_width, label=lab+' KDE', linestyle=':', color=color)
            for cat, color in zip(['flvC', 'flvB', 'flvL'], ['blue', 'red', 'green']):
                lab, sam, wgt, sel = config_dm[cat]
                hep.histplot(hdm[cat+'_kde'].view(flow=True).value, bins=edges, label=lab+' KDE integral', linestyle='--', color=color)
                plot_hist(hdm[cat], bins=edges, label=lab, normed=False, color=color)
            ax.set_xlim(xmin, xmax); ax.set_xlabel(vlabel, ha='right', x=1.0); ax.set_ylabel('A.U.', ha='right', y=1.0); ax.legend()

            filter_list_str = '_'.join(filter_list)
            plt.savefig(f'plots/{g_dirname}_{year}/{prefix}:kde_shape__{filter_list_str}__{savename}.png')
            plt.savefig(f'plots/{g_dirname}_{year}/{prefix}:kde_shape__{filter_list_str}__{savename}.pdf')
            

## KDE settings that can be passed to `make_data_mc_plots` as keyword arguments
## (only the commented-out Herwig recipe below passes them):
##  - g_do_kde_vars: variables (savename) for which the KDE plots are also made
##  - g_custom_kde_bw: the default KDE bandwidth factor is divided by this number
g_do_kde_vars = {'fj_x_btagcsvv2':True, 'fj_x_mSV12_ptmax_log':True, 'fj_x_mSV12_dxysig_log':True}
g_custom_kde_bw = {'fj_x_btagcsvv2':15, 'fj_x_mSV12_ptmax_log':4, 'fj_x_mSV12_dxysig_log':4}

## Output folder; `make_data_mc_plots` reads `g_dirname` as a global
g_dirname = config['routine_name']+'_datamc'
if not os.path.exists(f'plots/{g_dirname}_{year}'):
    os.makedirs(f'plots/{g_dirname}_{year}')

for ptrange in config['pt_range']['range']:
    ## pT-range label, understood by `calc_custom_masks` as 'fj_x_pt<min>to<max>'
    ptlab = f'pt{ptrange[0]}to{ptrange[1]}'
    ## sfBDT cut sequence of this pT range (loaded from the backup); take its middle element
    bdt_seq = arr[f"bdt_seq_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}"][(ptrange[0], ptrange[1])]
    bdt_cent = bdt_seq[int((len(bdt_seq)-1)/2)]
    ## First element of every tagger working-point range, sorted ascending
    tagger_wp = sorted([rg[0] for rg in config['tagger']['working_points']['range'].values()])

    ## 1. With MadGraph sample list
    wgtstr_dm = f'genWeight*xsecWeight*puWeight*{lumi[year]}*fj_x_htwgt'
    ## the last entry must be the data sample: `make_config_dm` uses sl_dm[:-1] as the MC list
    sl_dm = ['subst_qcd-mg-noht', 'subst_top-noht', 'subst_v-qq-noht', 'jetht-noht']
    make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', 'fj_x_sfBDT>0.5'], prefix='mg')
    make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', f'fj_x_sfBDT>{bdt_cent:.3f}'], prefix='mg')
    make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', f'fj_x_sfBDT>{bdt_cent:.3f}', f"{config['tagger']['var']}>{tagger_wp[-1]}"], prefix='mg')

#     ## 2. With MadGraph sample list, while using the optional MC-to-data reweight scheme (on pT)
#     wgtstr_dm = f'genWeight*xsecWeight*puWeight*{lumi[year]}*fj_x_ad_ptwgt'
#     sl_dm = ['subst_qcd-mg-noht', 'subst_top-noht', 'subst_v-qq-noht', 'jetht-noht']
#     make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', f'fj_x_sfBDT>{bdt_cent:.3f}'], prefix='mg_ptwgt')
    
#     ## 3. With Herwig sample list
#     wgtstr_dm = f'genWeight*xsecWeight*puWeight*{lumi[year]}*fj_x_htwgt_herwig'
#     sl_dm = ['subst_qcd-herwig-noht', 'subst_top-noht', 'subst_v-qq-noht', 'jetht-noht']
#     make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', 'fj_x_sfBDT>0.5'], prefix='herwig')
#     make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', f'fj_x_sfBDT>{bdt_cent:.3f}'], prefix='herwig')
#     make_data_mc_plots(sl_dm, make_config_dm(sl_dm, wgtstr_dm), filter_list=[f'fj_x_{ptlab}', f'fj_x_sfBDT>{bdt_cent:.3f}', f"{config['tagger']['var']}>{tagger_wp[-1]}"], prefix='herwig', 
#                        g_do_kde_vars=g_do_kde_vars, g_custom_kde_bw=g_custom_kde_bw) ## also make the KDE plots

----------
# Signal/proxy comparison plots

Based on the ak-array dict `arr`, the recipe below creates the proxy jet (from MC) and signal jet comparison plots on various jet observables.

In [None]:
## Load the signal tree
## The file path comes from the config card; its '$YEAR' placeholder is replaced by the configured year
import re
arr['real-signal'] = NanoEventsFactory.from_root(config['main_analysis_tree']['path'].replace('$YEAR', str(year)), treepath='/'+config['main_analysis_tree']['treename'], schemaclass=BaseSchema).events()

## Attach a mask dict to the signal events (same layout as for the other samples)
## and store the baseline selection of the analysis tree under the key 'base'
basecut_signal = config['main_analysis_tree']['selection']
arr['real-signal'].maskdict = {}
arr['real-signal'].maskdict['base'] = eval_expr(arr['real-signal'], basecut_signal)

## Named baseline selections for the proxy jets. name: (cut string, plot label)
basesel = {
    'sv': (
        "(fj_x_sj1_nsv>=1) & (fj_x_sj2_nsv>=1)",
        r'$N_{SV}^{match}\geq 1$',
    ),
    'tightsv': (
        "((fj_x_sj1_sv1_ntracks>2) & (np.abs(fj_x_sj1_sv1_dxy)<3) & (fj_x_sj1_sv1_dlensig>4) & (fj_x_sj2_sv1_ntracks>2) & (np.abs(fj_x_sj2_sv1_dxy)<3) & (fj_x_sj2_sv1_dlensig>4))",
        r'$N_{SV,tight}^{match}\geq 1$',
    ),
}


def func_basesel(name):
    """Return the (cut string, plot label) pair of a baseline selection name.

    `name` is either a key of `basesel`, or 'sfbdt<N>', which stands for the
    cut sfBDT > N/1000 (e.g. 'sfbdt900' -> sfBDT > 0.900).

    Raises:
        RuntimeError: if the name is neither of the two.
    """
    if name in basesel:
        return basesel[name]
    if name.startswith('sfbdt'):
        threshold = float(name[len('sfbdt'):]) / 1000.
        return ('(fj_x_sfBDT>%.3f)' % threshold, r'$sfBDT>%.3f$' % threshold)
        ## alternative sfBDT definitions, kept for reference:
#         return (f'(fj_x_sfBDT>-0.5*np.exp(70*(fj_x_ParticleNetMD_XbbVsQCD-1))+{x:.3f})', r'$sfBDT+0.5\,exp(70\,T_{xbb})>%.3f$'%x)
#         return ('(fj_x_sfBDT_nopresel_ext_pt600_j1>%.3f)'%x, r'$sfBDT>%.3f$'%x)
    raise RuntimeError('Baseline cut name not recognized.')

In [None]:
## Variables for the signal/proxy comparison. The first field is the proxy-jet
## variable written with 'fj_x' (or a (savename, vname) pair); the sixth field is the
## expression evaluated on the signal tree; `xlim` optionally overrides the x-axis range.
bininfo = [ #(vname, nbin, xmin, xmax, label, *vname for nominal*, xlim)   
#     ('fj_x_pt', 40, 0, 1000, r'$p_{T}$ (AK15)', 'fj_1_pt', None),
    ('fj_x_sdmass', 15, 50, 200, r'$m_{SD}$ (AK15)', 'fj_1_sdmass', None),
#     ('fj_x_tau21', 20, 0, 1, r'$\tau_{21}$ (AK15)', 'ak15_tau21', None), ##available
    
#     ('fj_x_deltaR_sj12', 40, 0, 1.5, r'$\Delta R_{sj_{1},sj_{2}}$ (AK15)', 'fj_1_deltaR_sj12', None),
#     ('fj_x_pt', 40, 0, 1000, r'$p_{T}$ (AK15)', 'ak15_pt', None),
#     ('fj_x_sj1_pt', 40, 0, 1000, r'$p_{T,sj_{1}}$ (AK15)', 'ak15_sj1_pt', None),
#     ('fj_x_sj1_rawmass', 40, 0, 200, r'$m_{sj_{1},raw}$ (AK15)', 'ak15_sj1_rawmass', None), ##available
#     ('fj_x_sj2_pt', 40, 0, 1000, r'$p_{T,sj_{2}}$ (AK15)', 'ak15_sj2_pt', None),
#     ('fj_x_sj2_rawmass', 40, 0, 200, r'$m_{sj_{2},raw}$ (AK15)', 'ak15_sj2_rawmass', None), ##available
    
#     ('fj_x_nsv', 10, 0, 10, r'$N_{SV}$ (AK15)', 'ak15_nlooseSV', None), ##available
#     ('fj_x_nsv_ptgt25', 8, 0, 8, r'$N_{SV,p_{T}\geq 25}$ (AK15)', 'ak15_nlooseSV_ptgt25', None), ##available
#     ('fj_x_nsv_ptgt50', 8, 0, 8, r'$N_{SV,p_{T}\geq 50}$ (AK15)', 'ak15_nlooseSV_ptgt50', None), ##available
#     ('fj_x_ntracks', 20, 0, 20, r'$N_{tracks}$ (AK15)', 'ak15_nlooseSV_ntracks', None), ##available
#     ('fj_x_ntracks_sv12', 20, 0, 20, r'$N_{tracks\;for\;SV_{1,2}}$ (AK15)', 'ak15_nlooseSV_ntracks_sv12', None), ##available
#     ('fj_x_sj1_nsv', 20, 0, 20, r'$N_{SV\;from\;sj_{1}}$ (AK15)', 'ak15_sj1_nlooseSV', None), ##available
#     ('fj_x_sj1_ntracks', 20, 0, 20, r'$N_{tracks\;from\;sj_{1}}$ (AK15)', 'ak15_sj1_nlooseSV_ntracks', None), ##available
#     ('fj_x_sj1_sv1_pt', 20, 0, 200, r'$p_{T,\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_pt', None),
#     ('fj_x_sj1_sv1_mass', 20, 0, 50, r'$m_{SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_mass', None), ##available
#     ('fj_x_sj1_sv1_masscor', 20, 0, 50, r'$m_{cor\;for\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_masscor', None),
#     ('fj_x_sj1_sv1_ntracks', 20, 0, 20, r'$N_{tracks\;from\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_ntracks', None),
#     ('fj_x_sj1_sv1_dxy', 20, 0, 5, r'$d_{xy,\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_dxy', None),
#     ('fj_x_sj1_sv1_dxysig', 20, 0, 20, r'$\sigma_{d_{xy},\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_dxysig', None),
#     ('fj_x_sj1_sv1_dlen', 20, 0, 5, r'$d_{z,\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_dlen', None),
#     ('fj_x_sj1_sv1_dlensig', 20, 0, 20, r'$\sigma_{d_{z},\;SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_dlensig', None),
#     ('fj_x_sj1_sv1_chi2ndof', 20, 0, 5, r'$\chi^2 / Ndof_{SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_chi2ndof', None),
#     ('fj_x_sj1_sv1_pangle', 40, 0, 5, r'$pAngle_{SV_{1}\;in\;sj_{1}}$ (AK15)', 'ak15_sj1_looseSV_pangle', None),
]
## The tagger discriminant is always compared; its signal-tree name comes from the config card
bininfo += [
    (config['tagger']['var'], 50, 0, 1, config['tagger']['var'].replace('fj_x_',''), config['main_analysis_tree']['tagger'], None),
#     ((config['tagger']['var']+'_WP', config['tagger']['var']), 50, 0, 1, config['tagger']['var'].replace('fj_x_',''), config['main_analysis_tree']['tagger'], None),
]

In [None]:
## Output folder of the signal/proxy comparison plots
g_dirname = config['routine_name']+'_sigpxy'
os.makedirs(f'plots/{g_dirname}_{year}', exist_ok=True)

## Make comparison plots for normal weight (MC adopt the same weight as in the fit), or for additional mass / pT / tau21 weight
# for wgtfac, pfwgt in zip(['1','massdatamcwgt','ptdatamcwgt'], ['nom', 'massdatamcwgt', 'ptdatamcwgt']):
for wgtfac, pfwgt in zip(['1'], ['nom']):

    wgtstr = f'genWeight*xsecWeight*puWeight*fj_x_htwgt*{wgtfac}'
    wgtstr_signal = config['main_analysis_tree']['weight']

    mpl.rcParams['axes.prop_cycle'] = cycler(color=['blue', 'red', 'green', 'violet', 'darkorange', 'black', 'cyan', 'yellow'])
    ## NOTE: `rwgt_ext_label` is only read when do_rwgt is truthy and is not defined in this cell
    do_rwgt = 0
    for ptmin, ptmax in config['pt_range']['range']:
        ptlab = f'pt{ptmin}to{ptmax}'
        ## presel: pT cut on the proxy jets ('fj_x'); presel1: pT cut on the signal tree
        presel, presel1 = f'(fj_x_pt>{ptmin}) & (fj_x_pt<{ptmax})', f"({config['main_analysis_tree']['pt_var']}>={ptmin}) & ({config['main_analysis_tree']['pt_var']}<{ptmax})"
        label = {'proxy': f"g({config['type']})", 'real-signal':config['main_analysis_tree']['label']}
                
        for vname, nbin, xmin, xmax, vlabel, vname1, xlim in bininfo:
            if not isinstance(vname, str): ## savename is specified other than the variable name
                savename, vname = vname
            else:
                savename = vname
            if not isinstance(nbin, int):
                edges, xmin, xmax, nbin = nbin, min(nbin), max(nbin), len(nbin)
            else:
                edges = np.linspace(xmin, xmax, nbin+1)

            f, ax = plt.subplots(figsize=(12,12))
            hep.cms.label(data=False, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
            
            ## Signal jet
            for sam in ['real-signal']:
                arr[sam].maskdict['_tmp_sigpxy_presel'] = eval_expr(arr[sam], presel1)
                _content = concat_array(arr, expr=vname1, sam_list=sam, filter_list=['base', '_tmp_sigpxy_presel'])
                _weights = concat_array(arr, expr=wgtstr_signal, sam_list=sam, filter_list=['base', '_tmp_sigpxy_presel'])
                h = get_hist(_content, bins=edges, weights=_weights)
                ## raw string: '\g' in a plain literal is an invalid escape sequence
                plot_hist(h, bins=edges, label=label[sam]+r' $N_{SV}^{match}\geq 1$' if sam=='qcd-mg' else label[sam], normed=True)

            ## Proxy jet
            use_standard_sfbdt = False  # if True, use standard sfBDT variation list for plot: [0.5, 0.8, 0.9, 0.95]
            proxy_sl = ['subst_qcd-mg-noht', 'subst_top-noht', 'subst_v-qq-noht']
            if use_standard_sfbdt:
                selclist, suf_label = ['sv+sfbdt500', 'sv+sfbdt800', 'sv+sfbdt900', 'sv+sfbdt950'], ['','','','']
            else:
                bdt_seq = arr[f"bdt_seq_{config['pt_range']['name']}__{config['main_analysis_tree']['name']}"][(ptmin,ptmax)]
                selclist = ['sv+sfbdt500'] + [f'sv+sfbdt{int(b*1000)}' for b in [bdt_seq[0], bdt_seq[int((len(bdt_seq)-1)/2)], bdt_seq[-1]]]
                suf_label = ['', ' (lower)', ' (central)', ' (upper)']
            ## NOTE(review): the next line unconditionally overrides the choice made above, so the
            ## standard sfBDT list is always used regardless of `use_standard_sfbdt` -- confirm whether
            ## this is a leftover; it is kept as is to leave the produced plots unchanged.
            selclist, suf_label = ['sv+sfbdt500', 'sv+sfbdt800', 'sv+sfbdt900', 'sv+sfbdt950'], ['','','','']
            for ext, slb in zip(selclist, suf_label):
                cutstr = ' & '.join(list(filter(None, [presel]+[func_basesel(cname)[0] for cname in ext.split('+')]))) ## join the cut string
                for sam in proxy_sl:
                    for i in '12':
                        arr[sam].maskdict[f'_tmp_fj_{i}_sigpxy_{ext}'] = eval_expr(arr[sam], cutstr.replace('fj_x', f'fj_{i}'))
                        ## Compute the pT mask only if missing. The check must use the per-jet key:
                        ## the former 'fj_x_...' key is never stored, so the mask was always recomputed.
                        if f'fj_{i}_{ptlab}' not in arr[sam].maskdict:
                            arr[sam].maskdict[f'fj_{i}_{ptlab}'] = eval_expr(arr[sam], f'(fj_{i}_pt>={ptmin}) & (fj_{i}_pt<{ptmax})')
                        if any([m not in arr[sam].maskdict for m in [f'fj_{i}_flvB', f'fj_{i}_flvC', f'fj_{i}_flvL']]):
                            arr[sam].maskdict[f'fj_{i}_flvB'] = eval_expr(arr[sam], f'fj_{i}_nbhadrons>=1')
                            arr[sam].maskdict[f'fj_{i}_flvC'] = eval_expr(arr[sam], f'(fj_{i}_nbhadrons==0) & (fj_{i}_nchadrons>=1)')
                            arr[sam].maskdict[f'fj_{i}_flvL'] = eval_expr(arr[sam], f'(fj_{i}_nbhadrons==0) & (fj_{i}_nchadrons==0)')
                ## Proxy jets: baseline, flavour given by the first letter of config['type'], pT range and the current cut set
                _content = concat_array_fj12(arr, expr=vname, sam_list=proxy_sl, filter_list=['fj_x_base', f"fj_x_flv{config['type'][0].upper()}", f'fj_x_{ptlab}', f'_tmp_fj_x_sigpxy_{ext}'])
                _weights = concat_array_fj12(arr, expr=wgtstr, sam_list=proxy_sl, filter_list=['fj_x_base', f"fj_x_flv{config['type'][0].upper()}", f'fj_x_{ptlab}', f'_tmp_fj_x_sigpxy_{ext}'])
                h = get_hist(_content, bins=edges, weights=_weights)
                plot_hist(h, bins=edges, label=label['proxy']+' '+(rwgt_ext_label if do_rwgt else '')+' & '.join([func_basesel(cname)[1] for cname in ext.split('+')])+slb, normed=True)

            ax.legend()
            ax.set_xlim((xmin, xmax) if xlim is None else xlim)
            ax.set_xlabel(vlabel, ha='right', x=1.0); ax.set_ylabel('A.U.', ha='right', y=1.0); 
            print('save plot: ', f'plots/{g_dirname}_{year}/{pfwgt}_{presel}__{savename}.png/pdf')
            plt.savefig(f'plots/{g_dirname}_{year}/{pfwgt}_{presel}__{savename}.png')
            plt.savefig(f'plots/{g_dirname}_{year}/{pfwgt}_{presel}__{savename}.pdf')

--------------
# Other comparisons

The below function enables one to make a simple comparison with the given sample lists, weight strings, pre-selection strings, and labels.

In [None]:
def simple_comp_plot(arr, bininfo, sam_list, wgtstr, base_mask, presel, label, isnormed=True):
    """Overlay one histogram per sample group for every variable in ``bininfo``.

    The five per-group arguments (``sam_list``, ``wgtstr``, ``base_mask``,
    ``presel``, ``label``) are consumed in lockstep with ``zip``, so the
    shortest one decides how many curves are drawn.

    Arguments:
        arr: Mapping from sample name to an object exposing a ``maskdict`` dict;
            it is handed to ``eval_expr`` and ``concat_array_fj12``.
        bininfo: Iterable of ``(vname, nbin, xmin, xmax, vlabel)``. ``vname`` is
            either an expression string or a ``(savename, expression)`` pair.
            ``nbin`` is either an int (uniform bins between ``xmin`` and
            ``xmax``) or a sequence of bin edges, in which case ``xmin``/``xmax``
            are taken from the edges.
        sam_list: One entry per group: a sample name or a list of sample names.
        wgtstr: One weight expression per group; the literal ``'1'`` means
            unit weights.
        base_mask: One entry per group: a mask name or a list of mask names,
            used as the leading part of ``filter_list``.
        presel: One selection expression per group; every ``fj_x`` in it is
            replaced by ``fj_1`` / ``fj_2`` before evaluation.
        label: One legend label per group.
        isnormed: Forwarded to ``plot_hist`` as ``normed``; also selects the
            y-axis title ('A.U.' vs. 'Events / bin').

    Side effects:
        Sets ``mpl.rcParams['axes.prop_cycle']``, creates one figure per
        variable, and writes ``_tmp_fj_1_simple_comp`` / ``_tmp_fj_2_simple_comp``
        into ``arr[sam].maskdict``. Reads the module-level ``year`` and ``lumi``.
    """
    # Normalise bare strings to one-element lists in *new* containers, so the
    # lists passed in by the caller are left untouched.
    sam_groups = [[s] if isinstance(s, str) else s for s in sam_list]
    mask_groups = [[m] if isinstance(m, str) else list(m) for m in base_mask]

    mpl.rcParams['axes.prop_cycle'] = cycler(color=['blue', 'red', 'green', 'violet', 'darkorange', 'black', 'cyan', 'yellow'])
    for vname, nbin, xmin, xmax, vlabel in bininfo:
        if not isinstance(vname, str):
            # (savename, expression) pair: nothing is written to disk here,
            # so only the expression is needed.
            _, vname = vname
        if not isinstance(nbin, int):
            # Explicit bin edges were given; the axis range follows from them.
            edges, xmin, xmax = nbin, min(nbin), max(nbin)
        else:
            edges = np.linspace(xmin, xmax, nbin+1)

        f, ax = plt.subplots(figsize=(12,12))
        hep.cms.label(data=False, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')

        for sl, wgt, bmask, sel, lab in zip(sam_groups, wgtstr, mask_groups, presel, label):
            # Evaluate the group's selection separately for each of the two jets
            # and stash it under a temporary mask name.
            for sam in sl:
                for i in '12':
                    arr[sam].maskdict[f'_tmp_fj_{i}_simple_comp'] = eval_expr(arr[sam], sel.replace('fj_x',f'fj_{i}'))
            filters = bmask + ['_tmp_fj_x_simple_comp']
            _content = concat_array_fj12(arr, expr=vname, sam_list=sl, filter_list=filters)
            if wgt != '1':
                _weights = concat_array_fj12(arr, expr=wgt, sam_list=sl, filter_list=filters)
            else:
                _weights = ak.ones_like(_content)
            h = get_hist(_content, bins=edges, weights=_weights)
            plot_hist(h, bins=edges, label=lab, normed=isnormed)

        ax.legend()
        ax.set_xlim(xmin, xmax)
        ax.set_xlabel(vlabel, ha='right', x=1.0)
        ax.set_ylabel('A.U.' if isnormed else 'Events / bin', ha='right', y=1.0)

## Standard vs. extra b-enriched sample comparison

In [None]:
bininfo = [ # (vname, bin edges, xmin, xmax, vlabel); xmin/xmax are None because simple_comp_plot derives them from the explicit edges
    ('fj_x_btagcsvv2', [0,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.98,0.99,0.995,1], None, None, r'$CSVv2$'),
    ('fj_x_mSV12_ptmax_log', [-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2,3.9], None, None, r'$log(m_{SV1,p_{T}\,max}\; /GeV)$'),
    ('fj_x_mSV12_dxysig_log', [-0.8,-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2], None, None, r'$log(m_{SV1,d_{xy}sig\,max}\; /GeV)$'),
]
# Jet pT window interpolated into the pre-selection string below
ptmin, ptmax = 250, 350
# Two curves per variable: the standard QCD sample vs. the b-flavour sample,
# the latter with the extra fj_x_bflav_htwgt factor in its weight.
simple_comp_plot(
    arr=arr, bininfo=bininfo,
    sam_list=[['subst_qcd-mg-noht'],['subst_qcd-mg-bflav-noht']],
    wgtstr=['genWeight*xsecWeight*puWeight*fj_x_htwgt', 'genWeight*xsecWeight*puWeight*fj_x_htwgt*fj_x_bflav_htwgt'],
    base_mask=['fj_x_base']*2,
    presel=[f'(fj_x_nbhadrons>=1) & (fj_x_pt>{ptmin}) & (fj_x_pt<{ptmax}) & (fj_x_sfBDT>0.9)']*2,
    label=['standard','b-flavor'],
    isnormed=True
)

## QCD Pythia vs. Herwig sample comparison

In [None]:
# Single variable: the tagger score named in the config, in 100 uniform bins on [0, 1];
# the axis label is the variable name with the 'fj_x_' prefix removed.
bininfo = [
    (config['tagger']['var'], 100, 0, 1, config['tagger']['var'].replace('fj_x_','')),
]
# Two curves: the qcd-mg sample (labelled 'Pythia') vs. the qcd-herwig sample,
# each with its own HT weight branch, under the same pre-selection.
simple_comp_plot(
    arr=arr, bininfo=bininfo,
    sam_list=[['subst_qcd-mg-noht'],['subst_qcd-herwig-noht']],
    wgtstr=['genWeight*xsecWeight*puWeight*fj_x_htwgt', 'genWeight*xsecWeight*puWeight*fj_x_htwgt_herwig'],
    base_mask=['fj_x_base']*2,
    presel=[f'(fj_x_nbhadrons>=1) & (fj_x_sfBDT>0.95) & (fj_x_pt>800) ']*2,
    label=['Pythia','Herwig'],
    isnormed=True
)

-----------------
# Check systematic templates

In [None]:
import seaborn as sns
def set_sns_color(*args):
    """Preview a seaborn palette and make it the active one.

    ``args`` are forwarded unchanged to both ``sns.color_palette`` (for the
    ``sns.palplot`` preview) and ``sns.set_palette``.
    """
    preview_colors = sns.color_palette(*args)
    sns.palplot(preview_colors)
    sns.set_palette(*args)

# Preview and activate a 3-colour 'cubehelix_r' palette
set_sns_color('cubehelix_r', 3)

In [None]:
year = config['year']
# Luminosity value per year, used in the plot header label (r'%s $fb^{-1}$ (13 TeV)')
lumi = {2016: 35.92, 2017: 41.53, 2018: 59.74}

# Each entry: (name, bin edges, xmin, xmax, vlabel).
# make_stacked_plots_for_shapeunc picks the first entry whose name is a substring of its inputdir;
# xmin/xmax are None because they are derived from the explicit bin edges.
bininfo_dm = [ # vtitlecontains, bins, xmin, xmax, vlabel
    # v2
    ("csvv2_var22binsv2", [0,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.98,0.99,0.995,1], None, None, r'$CSVv2$'),
    ("csvv2_var20binsv2", [0,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.98,1], None, None, r'$CSVv2$'),
    ('msv12_ptmax_log_var22binsv2', [-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2,3.9], None, None, r'$log(m_{SV1,p_{T}\,max}\; /GeV)$'),
#     ('msv12_dxysig_log_var22binsv2', [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8], None, None, r'$log(m_{SV1,d_{xy}sig\,max}\; /GeV)$'),
    ('msv12_dxysig_log_var22binsv2', [-0.8,-0.4,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,2.5,3.2], None, None, r'$log(m_{SV1,d_{xy}sig\,max}\; /GeV)$'),
]

# Flavour category order and matching colours: flvC is listed last for the 'cc' type,
# flvB is listed last otherwise (with the first two palette colours swapped).
if config['type'] == 'cc':
    color_order, cat_order = sns.color_palette('cubehelix_r', 3), ['flvL','flvB','flvC']
else:
    color_order, cat_order = np.array(sns.color_palette('cubehelix_r', 3))[[1,0,2]], ['flvL','flvC','flvB']

In [None]:
def make_stacked_plots_for_shapeunc(inputdir, unce_type=None, plot_unce=True, draw_stacked_plots=False, save_unce_comp_plots=True, show_plots=True, norm_unce=False):
    r"""Make the shape comparison and/or the stacked histograms for a specific type of shape uncertainty.

    Templates are read from ``{inputdir}/nominal``, ``{inputdir}/{unce_type}Up`` and
    ``{inputdir}/{unce_type}Down`` (files ``inputs_pass.root`` and ``inputs_fail.root``),
    one histogram per entry of the module-level ``cat_order``. The binning is taken from
    the first ``bininfo_dm`` entry whose name is a substring of ``inputdir``; the year is
    inferred from an ``SF2016`` / ``SF2017`` / ``SF2018`` substring of ``inputdir``.

    Arguments:
        inputdir: Directory holding the ``nominal``, ``<unce_type>Up`` and ``<unce_type>Down`` sub-directories.
        unce_type: Name of shape uncertainty (w/o Up or Down) to plot.
        plot_unce: Currently unused by this function; kept for interface compatibility.
        draw_stacked_plots: Whether to also draw the stacked data/MC histograms (the shape comparison plots are always drawn). Default: False
        save_unce_comp_plots: Whether to store the shape comparison plots as png/pdf inside ``inputdir``. Default: True
        show_plots: Whether to keep the shape comparison figures open so they are displayed; when False each figure is closed after drawing/saving. Default: True
        norm_unce: Normalize the up/down templates to the nominal yield. Default: False

    Raises:
        RuntimeError: No ``bininfo_dm`` entry matches ``inputdir``, or ``unce_type`` is not a
            string / its Up or Down directory does not exist.
        KeyError: Presumably when no ``SF20xx`` tag is found in ``inputdir`` (``lumi[None]``);
            depends on the keys of the module-level ``lumi``.
    """
    import os

    year = 2016 if 'SF2016' in inputdir else 2017 if 'SF2017' in inputdir else 2018 if 'SF2018' in inputdir else None
    for vname, nbin, xmin, xmax, vlabel in bininfo_dm:
        if vname in inputdir:
            break
    else:
        raise RuntimeError('Bininfo not found')
    if not isinstance(unce_type, str) or not os.path.exists(f'{inputdir}/{unce_type}Up') or not os.path.exists(f'{inputdir}/{unce_type}Down'):
        raise RuntimeError('Uncertainty type not exist')

    if not isinstance(nbin, int):
        # Explicit bin edges were given; the axis range follows from them.
        edges, xmin, xmax, nbin = nbin, min(nbin), max(nbin), len(nbin)
    else:
        edges = np.linspace(xmin, xmax, nbin+1)
    print(inputdir, '--unce--', unce_type)

    # curves for unce
    cats_rev = cat_order[::-1]
    curve_colors = ['blue', 'red', 'green']
    for b in ['pass', 'fail']:
        # Open each ROOT file once and read every category from it.
        f_nom = uproot3.open(f'{inputdir}/nominal/inputs_{b}.root')
        f_up = uproot3.open(f'{inputdir}/{unce_type}Up/inputs_{b}.root')
        f_down = uproot3.open(f'{inputdir}/{unce_type}Down/inputs_{b}.root')
        # [1:-1] drops the first and last entries of allvalues / allvariances
        content = [f_nom[f'{cat}'].allvalues[1:-1] for cat in cats_rev]
        yerror  = [np.sqrt(f_nom[f'{cat}'].allvariances[1:-1]) for cat in cats_rev]
        content_up   = [f_up[f'{cat}_{unce_type}Up'].allvalues[1:-1] for cat in cats_rev]
        content_down = [f_down[f'{cat}_{unce_type}Down'].allvalues[1:-1] for cat in cats_rev]
        lab_suf = ''
        if norm_unce:
            lab_suf = '(norm)'
            # Rescale the varied templates in place so their integral equals the nominal one.
            for icat in range(len(cats_rev)):
                content_up[icat] *= content[icat].sum() / content_up[icat].sum()
                content_down[icat] *= content[icat].sum() / content_down[icat].sum()
        f, ax = plt.subplots(figsize=(12,12))
        hep.cms.label(data=True, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
        # Three passes (nominal, up, down) so the legend groups the curves by variation.
        for icat, (cat, color) in enumerate(zip(cats_rev, curve_colors)):
            hep.histplot(content[icat], yerr=yerror[icat], bins=edges, label=f'QCD ({cat})', color=color)
        for icat, (cat, color) in enumerate(zip(cats_rev, curve_colors)):
            hep.histplot(content_up[icat], bins=edges, label=f'QCD ({cat}) {unce_type}Up {lab_suf}', color=color, linestyle='--')
        for icat, (cat, color) in enumerate(zip(cats_rev, curve_colors)):
            hep.histplot(content_down[icat], bins=edges, label=f'QCD ({cat}) {unce_type}Down {lab_suf}', color=color, linestyle=':')
        ax.set_xlim(xmin, xmax); ax.set_xlabel(vlabel, ha='right', x=1.0); ax.set_ylabel('Events / bin', ha='right', y=1.0)
        ax.legend(prop={'size': 18})

        if save_unce_comp_plots:
            plt.savefig(f'{inputdir}/unce_comp_{unce_type}_{b}.png')
            plt.savefig(f'{inputdir}/unce_comp_{unce_type}_{b}.pdf')
        # Honour show_plots regardless of whether the figure was saved, so that
        # show_plots=False never leaves figures open.
        if not show_plots:
            plt.close()

    # stacked plots
    if draw_stacked_plots:
        # A trailing '-u' / '-o' / '-a' on the label is a flag, not part of the axis title;
        # strip it once before drawing.
        if vlabel[-2:] in ['-u','-o','-a']:
            vlabel = vlabel[:-2]
        for filedir in ['nominal', unce_type+'Up', unce_type+'Down']:
            # Varied histograms carry the variation name as a suffix; nominal ones do not.
            roothist_suf = '' if filedir=='nominal' else '_'+filedir
            for b in ['pass', 'fail']:
                set_sns_color(color_order)
                f = plt.figure(figsize=(12,12))
                gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.05)
                ax = f.add_subplot(gs[0])
                hep.cms.label(data=True, paper=False, year=year, ax=ax, rlabel=r'%s $fb^{-1}$ (13 TeV)'%lumi[year], fontname='sans-serif')
                ax.set_xlim(xmin, xmax); ax.set_xticklabels([])
                ax.set_ylabel('Events / bin', ha='right', y=1.0)

                f_mc = uproot3.open(f'{inputdir}/{filedir}/inputs_{b}.root')
                content = [f_mc[f'{cat}{roothist_suf}'].allvalues[1:-1] for cat in cat_order]
                bkgtot = np.sum(content, axis=0)
                hep.histplot(content, bins=edges, label=[f'QCD ({cat})' for cat in cat_order], histtype='fill', edgecolor='k', linewidth=1, stack=True) ## draw MC
                # Data always come from the nominal file.
                h_data = uproot3.open(f'{inputdir}/nominal/inputs_{b}.root')['data_obs']
                data = h_data.allvalues[1:-1]
                data_errh = data_errl = np.sqrt(h_data.allvariances[1:-1])
                hep.histplot(data, yerr=(data_errl, data_errh), bins=edges, label='Data', histtype='errorbar', color='k', markersize=15, elinewidth=1.5) ## draw data
                ax.set_ylim(0, ax.get_ylim()[1])
                ax.legend()

                # Lower panel: data / total MC with reference lines at 0.5, 1 and 1.5
                ax1 = f.add_subplot(gs[1]); ax1.set_xlim(xmin, xmax); ax1.set_ylim(0.001, 1.999)
                ax1.set_xlabel(vlabel, ha='right', x=1.0); ax1.set_ylabel('Data / MC', ha='center')
                ax1.plot([xmin,xmax], [1,1], 'k'); ax1.plot([xmin,xmax], [0.5,0.5], 'k:'); ax1.plot([xmin,xmax], [1.5,1.5], 'k:')

                hep.histplot(data/bkgtot, yerr=(data_errl/bkgtot, data_errh/bkgtot), bins=edges, histtype='errorbar', color='k', markersize=15, elinewidth=1)

In [None]:
## ====== config me! ======
# inputdir must contain an 'SF2016' / 'SF2017' / 'SF2018' tag and the name of one bininfo_dm entry;
# make_stacked_plots_for_shapeunc derives the year and the binning from these substrings.
inputdir = '/home/pku/licq/hcc/new/results/20210412_bb_M120_SF2018_pnV02bb_-val_pt-_HP_msv12_dxysig_log_var22binsv2/Cards/pt400to500/bdt968/'
# Uncertainty name without the Up/Down suffix; '<unce_type>Up' and '<unce_type>Down' must exist under inputdir
unce_type = 'pu'
## ========================

# Draw only the nominal/up/down shape comparison; nothing is written to disk
make_stacked_plots_for_shapeunc(inputdir, unce_type=unce_type, save_unce_comp_plots=False, draw_stacked_plots=False)