In [None]:
import numpy as np
# import awkward as ak
# import dask
import json
# from coffea import processor
# from coffea.analysis_tools import Weights, PackedSelection

import pandas as pd
import pyarrow.parquet as pq
from tqdm.auto import tqdm
import os
import xgboost as xgb
import matplotlib.pyplot as plt
from pathlib import Path
import pickle
import hist
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()

import mplhep as hep
plt.style.use([hep.style.CMS])

In [None]:
# with open('samples_noQCD2000.json', 'r') as file:
# Load the process -> list-of-sample-names map that get_paths() looks up.
pmap = json.loads(Path('samples.json').read_text())

print(pmap.keys())
    

In [None]:
# '''processing functions'''
'''Cut definitions'''
def minjetkin(df, pt_range=(450, 1200), msd_range=(40., 201.), max_abs_eta=2.5):
    """Keep only events whose leading fat jet passes the kinematic selection.

    The selection is evaluated directly on ``df`` so the boolean mask always
    has the same length and index as the frame it filters. (The previous
    implementation built the mask on a pre-filtered copy and then applied it
    to the unfiltered frame, which fails as soon as any row is removed by the
    pre-filter.)  The old loose pre-selection (pt > 200, |eta| < 2.5) is fully
    implied by the default cuts below, so it is not applied separately.

    :param df: DataFrame with columns ``FatJet0_pt``, ``FatJet0_msd`` and
        ``FatJet0_eta``.
    :param pt_range: ``(low, high)``; keeps ``low <= pt < high``.
    :param msd_range: ``(low, high)``; keeps ``low <= msd < high``.
    :param max_abs_eta: keeps ``|eta| < max_abs_eta``.
    :return: the rows of ``df`` passing all cuts (original index preserved).
    :raises KeyError: if one of the required columns is missing.
    """
    pt = df['FatJet0_pt']
    msd = df['FatJet0_msd']
    eta = df['FatJet0_eta']

    # Comparisons involving NaN evaluate to False, so events with a missing
    # leading fat jet are rejected.
    passes = (
        (pt >= pt_range[0])
        & (pt < pt_range[1])
        & (msd >= msd_range[0])
        & (msd < msd_range[1])
        & (eta.abs() < max_abs_eta)
    )
    return df[passes]
    
def get_paths(year, data_path, proc = 'QCD', deep=False):
    """Return the parquet locations for every sample of one process.

    Sample names are read from the module-level ``pmap`` dictionary
    (``pmap[proc]``); each one maps to
    ``<data_path>/<year>/<sample>/parquet/signal-all``.

    :param year: data-taking year, used as a path component.
    :param data_path: root directory holding the per-year folders.
    :param proc: key into ``pmap``.
    :param deep: if False, return the ``signal-all`` directories; if True,
        return the individual entries found inside those directories.
    :return: always a ``list`` of path strings (possibly empty). Previously
        the deep mode returned ``None``, a list or a numpy array depending on
        how many sample directories there were.
    """
    parquet_parents = [os.path.join(data_path, year, p, 'parquet','signal-all') for p in pmap[proc]]
    if not deep:
        return parquet_parents

    return [
        os.path.join(parent, entry)
        for parent in parquet_parents
        for entry in os.listdir(parent)
    ]

def mode2category(mode):
    """One-hot encode a production mode against the fixed order ggF, VBF, VH.

    Returns an integer array of length 3 with a single 1 at the position of
    ``mode``; raises ``ValueError`` when ``mode`` is not one of the three.
    """
    cats = np.array(['ggF', 'VBF', 'VH'])
    if mode in cats:
        return (cats == mode).astype(int)
    raise ValueError(f'Decay mode {mode} not in {cats}')

print(mode2category('VBF'))
        
    
    
def process_single(df, 
                   cuts=False,
                   save_fields = ['weight','FatJet0_pt'],
                   signal = False,
                   category = 'QCD', # label given to non-signal events
               ):
    """Select events/columns from one parquet frame and tag them.

    :param df: input events; never modified.
    :param cuts: if True, apply ``minjetkin`` before selecting columns.
    :param save_fields: columns to keep in the output (not mutated).
    :param signal: falsy for background; otherwise the value is written to
        the ``category`` column and ``isSignal`` is set to 1.
    :param category: value written to the ``category`` column for background
        events. Defaults to ``'QCD'``, which was previously hard-coded (the
        argument used to be ignored).
    :return: a new DataFrame with ``save_fields`` plus the integer column
        ``isSignal`` and the column ``category``.
    """
    selected = minjetkin(df) if cuts else df
    #add more cuts here

    # Explicit copy of just the kept columns: the new columns below are then
    # written to an independent frame (no SettingWithCopyWarning) and the
    # full input no longer has to be deep-copied first.
    X = selected[list(save_fields)].copy()
    n_events = len(X)

    if signal:
        X['isSignal'] = np.ones(n_events, dtype=int)
        X['category'] = [signal] * n_events
    else:
        X['isSignal'] = np.zeros(n_events, dtype=int)
        X['category'] = [category] * n_events
    return X

def get_sum_genweights(data_dir: Path, dataset: str) -> float:
    """
    Get the sum of genweights for a given dataset.

    Every ``*.pkl`` file in ``<data_dir>/<dataset>/pickles`` is read; from each
    file the first value of the ``"sumw"`` mapping of the *last* top-level key
    is added to the total.

    :param data_dir: The directory where the datasets are stored.
    :param dataset: The name of the dataset to get the genweights for.
    :return: The sum of genweights for the dataset, or ``1`` (with a
        ``UserWarning``) when no pickle exists or one cannot be read.
    """
    # Imported locally because this notebook only imports ``warnings`` in a
    # later cell; the function must work regardless of cell execution order.
    import warnings

    pickle_dir = Path(data_dir) / dataset / "pickles"
    total_sumw = 0
    pickle_file = None
    try:
        for pickle_file in pickle_dir.glob("*.pkl"):
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # trusted processor output.
            with pickle_file.open("rb") as file:
                out_dict = pickle.load(file)
            # NOTE(review): as before, only the last top-level key of each
            # pickle contributes -- confirm each file holds exactly one key.
            last_key = list(out_dict)[-1]
            total_sumw += next(iter(out_dict[last_key]["sumw"].values()))
        if pickle_file is None:
            raise FileNotFoundError(f"no *.pkl files in {pickle_dir}")
        print(pickle_file)
    except Exception as err:
        warnings.warn(
            f"Error loading genweights for dataset: {dataset}. Skipping. ({err!r})",
            category=UserWarning,
            stacklevel=2,
        )
        # Neutral normalisation so downstream weight / sumW stays finite.
        total_sumw = 1

    return total_sumw

def accumulator(proc, isSignal=False, shallow=False, path=None,
                data_dir='/uscms/home/bweiss/nobackup/hbb/', year='2023'):
    """Load, select and normalise all events of one process.

    For every subset (sample directory) of the process, each parquet file is
    read, passed through ``process_single`` with cuts enabled, and the results
    are concatenated. Each subset then gets the columns ``sumW`` (from
    ``get_sum_genweights``), ``weight_final`` (``|weight| / sumW``) and
    ``MC_name``.

    :param proc: process key understood by ``get_paths``.
    :param isSignal: forwarded to ``process_single`` as ``signal``.
    :param shallow: falsy to read every file; otherwise the maximum number of
        files read per subset (``True`` means exactly one). The previous
        ``i > shallow`` test read one file more than requested.
    :param path: optional override; a single parquet file or a directory
        whose entries are treated as subsets. If None, ``get_paths`` is used.
    :param data_dir: root of the sample tree (was hard-coded).
    :param year: year sub-directory below ``data_dir`` (was hard-coded).
    :return: one DataFrame with a fresh RangeIndex, or None if nothing was read.
    """
    if path is None:
        dirs = get_paths(year, data_dir, proc)
    elif os.path.isfile(path):
        dirs = [path]
    else:
        # os.listdir returns bare names; keep them anchored to ``path``.
        dirs = [os.path.join(path, entry) for entry in os.listdir(path)]

    excluded_cols = ['MET']
    max_files = int(shallow) if shallow else None  # None slices to "all"

    per_dataset = []
    for d in tqdm(dirs, desc="Processing "+str(proc)): #runs through subsets of a process
        if os.path.isfile(d):
            files = [d]
            region_dir = Path(d).parent
        else:
            files = [os.path.join(d, name) for name in os.listdir(d)]
            region_dir = Path(d)

        frames = []
        for file_path in files[:max_files]: #runs through files in subset
            df = pd.read_parquet(file_path)
            save_cols = [c for c in df.columns
                         if (c not in excluded_cols) and ('Gen' not in c)]
            #apply cuts, save select columns, add isSignal column
            frames.append(process_single(df, cuts=True,
                                         save_fields=save_cols,
                                         signal=isSignal))
        if not frames:
            # Empty subset: nothing to normalise.
            continue
        # Single concat per subset instead of growing a frame file by file.
        dataset = pd.concat(frames, axis=0, ignore_index=True)

        # Layout is <sample>/parquet/<region>[/file]; the sample name sits two
        # levels above the region directory.
        this_dataset = region_dir.parent.parent.name
        print(this_dataset)
        #reweight events by the sum of generator weights of the MC dataset
        sumW = get_sum_genweights(Path(data_dir) / year, this_dataset)
        dataset['sumW'] = np.ones_like(dataset['weight'])*sumW
        dataset['weight_final'] = abs(dataset['weight'])/sumW
        print(f'sum of all weights in {d} is {sumW}')
        dataset['MC_name'] = this_dataset
        per_dataset.append(dataset)

    if not per_dataset:
        return None
    return pd.concat(per_dataset, axis=0, ignore_index=True)

def df2Dmatrix(X):
    """Convert the final DataFrame to an ``xgb.DMatrix``.

    ``isSignal`` is used as the label and ``weight_noxsec`` as the per-event
    weight. Both columns are removed from the feature matrix: previously the
    whole frame, label included, was passed as features, which leaks the
    target into training.

    :param X: DataFrame containing ``isSignal`` and ``weight_noxsec``.
    :return: DMatrix with ``-9999`` treated as the missing-value sentinel.
    :raises KeyError: if either required column is absent.
    """
    label_col, weight_col = 'isSignal', 'weight_noxsec'
    features = X.drop(columns=[label_col, weight_col])
    return xgb.DMatrix(features, label=X[label_col], missing=-9999, weight=X[weight_col])

In [None]:
# Open a single VBF parquet file to inspect the available columns and weights.
path = '/uscms/home/bweiss/nobackup/hbb/2023/VBFHto2B_M-125_dipoleRecoilOn/parquet/signal-all/part0.parquet'
df = pd.read_parquet(path)

for preview in (df.columns, df['weight'].head()):
    print(preview)

# data = accumulator('VH', isSignal='test', shallow=100, path = path)

# print(data.columns)
# print('nJet: ', data['nFatJet'].head())
# print(type(data['FatJet1_pt']), data['FatJet1_pt'].head())

# for c in proc_data.columns:
#     print(c, type(data[c]), type(data[c][10]))

# proc_data = data

In [None]:
''' Accumulate data, prepare it, save to mega DF '''
import warnings
warnings.filterwarnings('ignore') 
# Filters added later take precedence: keep the genweight fallback visible,
# because a silent sumW = 1 would mis-normalise every event of that dataset.
warnings.filterwarnings('always', message='Error loading genweights')

samples = ['VBF', 'VH', 'ggF', 'ttH'] #processes to acquire ('QCD' currently left out)
isSignal = ['VBF', 'VH', 'ggF', 'ttH'] #label passed to accumulator for each sample

# Accumulate every process first and concatenate once; repeatedly
# concatenating onto a growing frame copies all previous rows each time.
X = pd.concat(
    [accumulator(s, isSignal=sig_label, shallow=False)
     for s, sig_label in zip(samples, isSignal)],
    axis = 0, ignore_index=True,
)

In [None]:
# Fit the label encoder on the process names and store the integer class in 'y'.
X['y'] = le.fit_transform(X['category'])
print(X[['y', 'category']].sample(frac=1, random_state=11))

In [None]:
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "serif",
#     "font.serif": ["Computer Modern Roman"]
# })
# Plot the largest pairwise |delta eta| among the four leading AK4 jets, split by category.

# Reset matplotlib rcParams to their defaults (this undoes the CMS style
# applied in the import cell).
plt.rcdefaults()

ak4eta_cols = ['Jet0_eta', 'Jet1_eta', 'Jet2_eta','Jet3_eta', 'category']
# All unordered pairs of the four jets, as positions into ak4eta_cols.
jj_pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]

ak4_etas = X[ak4eta_cols]
dEta_jj_all = pd.DataFrame()

#f'jj_{jj}'
# One column per jet pair, named e.g. 'jj_(0, 1)', holding |eta_b - eta_a|.
for jj in jj_pairs:
    jet1_eta, jet2_eta = ak4_etas[ak4eta_cols[ jj[0] ]], ak4_etas[ak4eta_cols[ jj[1] ]]
        
    dEta_jj = pd.DataFrame()
    dEta_jj[f'jj_{jj}'] = abs(jet2_eta-jet1_eta)
    # dEta_jj['dEta'] = abs(jet2_eta-jet1_eta)
    
    dEta_jj_all = pd.concat([dEta_jj_all, dEta_jj], axis = 1)
    print(dEta_jj_all.columns)

# Row-wise maximum over the pair columns, ignoring NaN entries.
# NOTE(review): this frame gets a fresh RangeIndex while the concat below
# aligns on index -- assumes X also has a default RangeIndex; confirm.
dEta_jj_max = pd.DataFrame(np.nanmax(dEta_jj_all.values, axis = 1), columns = ['max'])

dEta_jj_all = pd.concat([dEta_jj_all, dEta_jj_max], axis = 1)
dEta_jj_all = pd.concat([dEta_jj_all, X['category']], axis = 1)

# print(dEta_jj_all)

stacks = pd.DataFrame()
# NOTE(review): this rebinds the global `samples` that the accumulation cell
# also defines, and lists 'QCD', which that cell does not load as written.
samples = ['QCD', 'VBF', 'ggF',  'VH']

# One column per category; the column-wise concat leaves NaN where a
# category has fewer events than the longest one.
for s in samples:
    cat_mask = dEta_jj_all['category'] == s
    subset = pd.DataFrame(dEta_jj_all['max'][cat_mask].values, columns = [s])
    stacks = pd.concat([stacks, subset], axis = 1)
fig, ax = plt.subplots(1,1, figsize=(7,5) )
# Unweighted, overlaid (not stacked) step histograms, one per category.
ax.hist(stacks, stacked = False, label = stacks.columns, histtype = 'step', lw = 2)
ax.set(yscale = 'log', xlabel = 'max(dEta_jj)')
# ax.legend(samples)
ax.legend()


# plt.hist(dEta_jj_all['max'])


In [None]:
from sklearn.model_selection import train_test_split

# Columns available before the split. Negative weights are not modified in
# this cell (manual omission is currently disabled).
print(X.columns)

# Shuffled 80/20 split with a fixed seed; not stratified.
split_options = dict(test_size=0.2, random_state=42, shuffle=True)
X_train, X_test = train_test_split(X, **split_options)

In [None]:
# Define the BDT model
# import xgboost as xgb

# see https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier
# for detailed explanations of parameters




In [None]:
# Show the non-ttH mask and the raw category column of the training set.
for view in (X_train['category'].ne('ttH'), X_train['category']):
    print(view)

In [None]:
# Columns that must never be used as BDT inputs: labels, weights, bookkeeping
# and the leading-fat-jet kinematic / tagger variables.
omit_cols = ['isSignal', 'weight','weight_final', 
             'category', 'FatJet0_pt', 'FatJet0_msd', 'FatJet0_msdmatched',
             'MC_name', 'y', 'sumW', 'genWeight',
             'FatJet0_pnetMass', 'FatJet0_pnetTXbb', 'FatJet0_pnetTXgg',
             'FatJet0_pnetTXcc', 'FatJet0_pnetTXqq', 'FatJet0_pnetXbbXcc', 'FatJet0_pnetTQCD',
            ]

# ttH is excluded from training and from the validation set.
train_not_ttH = X_train['category']!='ttH'
test_not_ttH = X_test['category']!='ttH'
print(len(X_train))
X_train = X_train[train_not_ttH]
print(len(X_train))

Y_train = X_train['y']
Y_test = X_test['y']
W_train = X_train['weight_final']
W_test = X_test['weight_final']
# Background / signal weight ratio; informational only (scale_pos_weight is
# not passed to the model). Vectorised sums instead of Python-level sum().
pos_weight = W_train[X_train['isSignal'] == 0].sum()/W_train[X_train['isSignal'] == 1].sum()
print(pos_weight)

# see https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier
model = xgb.XGBClassifier(
    n_estimators=100,  # number of boosting rounds (i.e. number of decision trees)
    max_depth=8,  # max depth of each decision tree
    learning_rate=0.5,
    early_stopping_rounds=20,  # rounds to wait for the last eval set to improve
    missing = np.nan,
    eval_metric='merror',
    # 'multi:softprob' optimises the same loss as 'multi:softmax' but keeps
    # per-class probabilities available, which predict_proba() below needs.
    objective = 'multi:softprob',
    # The number of classes is inferred from the labels by the sklearn
    # wrapper; the former `num_classes=3` is not an XGBoost parameter.
)

# Build the feature frames once and reuse them for fit and evaluation.
train_features = X_train.drop(omit_cols, axis=1) #no label column OR weights
test_features = X_test[test_not_ttH].drop(omit_cols, axis=1)

trained_model = model.fit(
    train_features,
    Y_train, #labels
    sample_weight=W_train,
    # xgboost uses the last set for early stopping
    # https://xgboost.readthedocs.io/en/stable/python/python_intro.html#early-stopping
    # NOTE(review): the eval sets are unweighted while training is weighted;
    # consider sample_weight_eval_set if weighted merror is wanted.
    eval_set=[(train_features, Y_train), 
              (test_features, Y_test[test_not_ttH])],  # sets for which to save the loss
    verbose=True,
)

In [None]:
# Per-class scores for the full test set, then the most probable class.
Y_predict = model.predict_proba(X_test.drop(omit_cols, axis=1))
print('Pred. class:', np.argmax(Y_predict, axis = 1))
print('True class:', Y_test.values)

# Evaluation-metric history recorded during fit for both eval sets.
evals_result = trained_model.evals_result()
fig = plt.figure(figsize=(5, 4))
loss_ax = fig.add_subplot()
for i, label in enumerate(["Train", "Test"]):
    history = evals_result[f"validation_{i}"]["merror"]
    loss_ax.plot(history, label=label, linewidth=2)
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()


# Plot ROC
# Y_predict = model.predict_proba(X_test.drop(omit_cols, axis=1))

# # Y_predict = Y_predict[:, 1].squeeze()
# # print(Y_predict)
# # Y_predict = le.inverse_transform(Y_predict)
# # print(Y_predict)
# # X_test['isSignal'] = le.inverse_transform(X_test['isSignal'])

# from sklearn.metrics import roc_curve, auc

# fig, ax = plt.subplots(1, 1, figsize=(5, 4))

# samples = [False, 'VBF', 'ggF',  'VH']
# for i, s in enumerate(samples[1:]):
#     category_mask = ((X_test['category'] == s) | (X_test['category'] == 'QCD'))
#     fpr, tpr, thresholds = roc_curve(X_test[category_mask]['isSignal'].astype(int), 
#                                      Y_predict[category_mask,1], 
#                                      sample_weight = X_test[category_mask]['weight_final'],
#                                      pos_label=1)
#     roc_auc = auc(fpr, tpr)
    
#     ax.plot(fpr, tpr, lw=2, label=f"{s} auc = %.3f" % (roc_auc))
# ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="k", label="random chance")
# ax.set_xlim([0, 1.0])
# ax.set_ylim([0, 1.0])
# ax.set_xlabel("false positive rate")
# ax.set_ylabel("true positive rate")
# ax.set_title("receiver operating curve")
# ax.legend(loc="lower right")
# plt.show()

In [None]:
from pathlib import Path
main_dir = '/uscms/home/bweiss/nobackup/hbb/'
categories = ['VBF', 'VH', 'ggF', 'ttH']

# Collect the per-file frames and concatenate once at the end. The per-file
# frame is deliberately NOT called `X`, so the main dataset built by the
# accumulation cell is no longer overwritten here.
region_frames = []

for proc in categories:
    paths = get_paths('2023', main_dir, proc = proc)
    for samp in paths:
        print(samp)
        parquet_dir = Path(samp).parent  # holds one sub-directory per region
        for k, region in enumerate(os.listdir(str(parquet_dir))):
            if k==0:
                print(f'Processing MC: {proc} categorized as {region}')
            reg_in_cat = os.path.join(parquet_dir, region)
            for p in os.listdir(reg_in_cat):
                events = pd.read_parquet(os.path.join(reg_in_cat, p), columns = ['weight'])
                events['true_cat'] = proc
                events['cutBased_cat'] = region
                region_frames.append(events)

CutBased_events = (
    pd.concat(region_frames, axis = 0, ignore_index=True)
    if region_frames else pd.DataFrame()
)

In [None]:
# print(X_test['category'])
# Normalise the cut-based event weights per true category.
for cat in categories:
    cat_mask = X_test['category'] == cat
    # NOTE(review): this is the mean of sumW over all test events of the
    # category; if a category mixes several MC samples with different sumW
    # this is only an average -- confirm that is intended.
    sumW = np.nanmean(X_test[cat_mask]['sumW'].values)
    print(sumW)
    # Write only the rows of this category. Assigning the subset's values to
    # the whole column fails on length mismatch and would overwrite the
    # previous categories on every iteration.
    is_cat = CutBased_events['true_cat'] == cat
    CutBased_events.loc[is_cat, 'weights_true'] = CutBased_events.loc[is_cat, 'weight'] / sumW

print(CutBased_events)

In [None]:
''' Purity plot'''

def purity_plot(ax, X_test, category = '', use_weights = False, label =''):
    """Draw one stacked bar with the true-process make-up of a BDT-predicted category.

    Uses the module-level ``model``, ``omit_cols`` and ``le``. Events are
    assigned to the class with the highest predicted probability; those
    assigned to ``category`` are split by their true ``category`` column.

    :param ax: matplotlib axes to draw on.
    :param X_test: events with the feature columns plus ``category`` and
        ``weight_final``.
    :param category: predicted class to inspect (must be known to ``le``).
    :param use_weights: False stacks raw event counts; True stacks the
        fractions of summed ``weight_final`` (y axis fixed to [0, 1]).
    :param label: unused; kept for interface compatibility.
    :return: None. Bars are drawn in the encoder's class order.
    """
    proba = model.predict_proba(X_test.drop(omit_cols, axis=1))
    Y_predict = np.argmax(proba, axis = 1)
    cat_index = le.transform([category])  # length-1 array, broadcasts below
    X_pred = X_test[Y_predict == cat_index]
    use_weights = int(use_weights)

    # Column 0: raw counts, column 1: summed weights, one row per true class.
    n_classes = len(le.classes_)
    yields = np.zeros((n_classes, 2))
    for cat, truecat in enumerate(le.classes_):
        cat_mask = X_pred['category'] == truecat
        yields[cat][:] = [np.sum(cat_mask), 
                          np.sum(X_pred['weight_final'][cat_mask].values, axis = 0)]
    sumw = np.sum(yields[:,1])
    # Skip the normalisation when nothing was predicted into this category;
    # dividing by zero would fill the bar heights with NaN.
    if sumw != 0:
        yields[:,1] = yields[:,1]/sumw

    bottom = 0
    for c in range(n_classes):
        ax.bar(category, yields[c, use_weights], bottom=bottom)
        bottom += yields[c,use_weights]
    # Axis cosmetics only need to be applied once, not once per bar.
    if use_weights:
        ax.set_ylim([0,1])
    ax.set_title(f'n={round(np.sum(yields[:,0]))} \n yield={round(sumw, 2)}', fontsize=16)
    print("----------------------------------------")
    print(f'Predicted {category} purity: {yields[cat_index,1]})')
    print(f'contains {yields[0][0]} VBF, {yields[1][0]} VH, {yields[2][0]} ggF')
    print(f'with yields: {yields[0][1]} VBF, {yields[1][1]} VH, {yields[2][1]} ggF')

# def cutBased_purity(ax, df, category = '', use_weights = False, label =''):
#     cut_mask = CutBased_events['true_cat']

    # return sumw
    # bins = [0,1]
    # _, _, patches = 
    # print(le.transform([category]))
    # print()


# One panel per BDT output class; ttH is listed for the legend only.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 6))
categories = ['VBF', 'VH', 'ggF', 'ttH']
# zip stops after the three axes, so only VBF, VH and ggF get a panel.
for panel, predicted in zip(axes, categories):
    purity_plot(panel, X_test, category=predicted, use_weights=True)
fig.legend(categories, ncol=4, loc='lower center', bbox_to_anchor=(0.5, -0.07))
fig.suptitle('BDT purity', x=0.5, y=0.92)
fig.tight_layout()





In [None]:
'''Get Cut based yields'''
from pathlib import Path
main_dir = '/uscms/home/bweiss/nobackup/hbb/'
categories = ['VBF', 'VH', 'ggF', 'ttH']
# Two-way map between cut-based region names and BDT category names.
CBtoBDT_cats = {'signal-vh': 'VH',
                'signal-ggf': 'ggF',
                'signal-vbf': 'VBF',
                'VH': 'signal-vh',
                'ggF': 'signal-ggf',
                'VBF': 'signal-vbf',}
CB_cats = ['signal-vh','signal-ggf','signal-vbf']
print(list(CBtoBDT_cats.keys()))

# Read the per-sample sum of weights once instead of once per parquet file.
sumW_file = './sumW.json'
with open(sumW_file, 'r') as json_file:
    sumW_dict = json.load(json_file)

# Per-file frames are collected and concatenated once; the per-file frame is
# not called `X`, so the main dataset is not overwritten by this cell.
cb_frames = []
for proc in categories:
    paths = get_paths('2023', main_dir, proc = proc)
    for samp in paths:
        parquet_dir = Path(samp).parent
        MC_name = os.path.basename(parquet_dir.parent)
        k = 0
        for region in os.listdir(str(parquet_dir)):
            # Regions are selected by substring, so look the category up via
            # the matched name; indexing with `region` itself raised KeyError
            # for any directory that merely contains a signal-region name.
            matched = next((sig for sig in CB_cats if sig in region), None)
            if matched is None:
                continue
            if k==0:
                print(f'Processing MC: {MC_name} categorized as {region}')
            reg_in_cat = os.path.join(parquet_dir, region)
            for p in os.listdir(reg_in_cat):
                events = pd.read_parquet(os.path.join(reg_in_cat, p), columns = ['weight'])
                events['true_cat'] = proc
                events['CB_cat'] = CBtoBDT_cats[matched]
                sumW = sumW_dict[MC_name]
                # NOTE(review): signed weight here, whereas accumulator()
                # uses abs(weight)/sumW -- confirm which is intended before
                # comparing the two purity plots.
                events['weight_final'] = events['weight']/sumW
                cb_frames.append(events)
            k+=1

CutBased_events = (
    pd.concat(cb_frames, axis = 0, ignore_index=True)
    if cb_frames else pd.DataFrame()
)

print(CutBased_events)
# print(os.listdir(str(Path(samp).parent)))




In [None]:
'''Plot cut based purity'''
def cutBased_purity(ax, X, category = '', use_weights = False, label =''):
    """Draw one stacked bar with the true-process make-up of a cut-based category.

    Uses the module-level label encoder ``le`` for the class order.

    :param ax: matplotlib axes to draw on.
    :param X: events with columns ``CB_cat`` (cut-based assignment),
        ``true_cat`` (generated process) and ``weight_final``.
    :param category: cut-based category to inspect (must be known to ``le``).
    :param use_weights: False stacks raw event counts; True stacks the
        fractions of summed ``weight_final`` (y axis fixed to [0, 1]).
    :param label: unused; kept for interface compatibility.
    :return: None. Bars are drawn in the encoder's class order.
    """
    cat_index = le.transform([category])  # length-1 array, used in the printout
    X_pred = X[X['CB_cat'] == category]
    use_weights = int(use_weights)

    # Column 0: raw counts, column 1: summed weights, one row per true class.
    n_classes = len(le.classes_)
    yields = np.zeros((n_classes, 2))
    for cat, truecat in enumerate(le.classes_):
        cat_mask = X_pred['true_cat'] == truecat
        yields[cat][:] = [np.sum(cat_mask), 
                          np.sum(X_pred['weight_final'][cat_mask].values, axis = 0)]
    sumw = np.sum(yields[:,1])
    # Skip the normalisation when the category is empty; dividing by zero
    # would fill the bar heights with NaN.
    if sumw != 0:
        yields[:,1] = yields[:,1]/sumw

    bottom = 0
    for c in range(n_classes):
        ax.bar(category, yields[c, use_weights], bottom=bottom)
        bottom += yields[c,use_weights]
    # Axis cosmetics only need to be applied once, not once per bar.
    if use_weights:
        ax.set_ylim([0,1])
    ax.set_title(f'n={round(np.sum(yields[:,0]))} \n yield={round(sumw, 2)}', fontsize=16)
    print("----------------------------------------")
    print(f'Predicted {category} purity: {yields[cat_index,1]})')
    print(f'contains {yields[0][0]} VBF, {yields[1][0]} VH, {yields[2][0]} ggF')
    print(f'with yields: {yields[0][1]} VBF, {yields[1][1]} VH, {yields[2][1]} ggF')

# fig, ax = plt.subplots(1,1)
# cutBased_purity(ax, CutBased_events, category = 'ggF', use_weights=True)

# One panel per cut-based category; ttH is listed for the legend only.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 6))
categories = ['VBF', 'VH', 'ggF', 'ttH']
# zip stops after the three axes, so only VBF, VH and ggF get a panel.
for panel, selected in zip(axes, categories):
    cutBased_purity(panel, CutBased_events, category=selected, use_weights=True)
fig.legend(categories, ncol=4, loc='lower center', bbox_to_anchor=(0.5, -0.07))
fig.suptitle('Cut based purity', x=0.5, y=0.92)
fig.tight_layout()

In [None]:
# print(trained_model.feature_importances_)
# Rank the input features by the importance reported by the trained model.
features = trained_model.get_booster().feature_names
importance = trained_model.feature_importances_

fi = pd.DataFrame({'features': features, 'importance': importance})
fi = fi.sort_values(by = 'importance', ascending=True).reset_index(drop=True)
print(fi)

n = 21 #number of top features to show
top = fi.iloc[-n:]
positions = list(range(len(fi)))[-n:]

# Single figure with an explicit axes; the extra plt.figure(figsize=(9,18))
# call used to leave an empty figure behind.
fi_fig, fi_ax = plt.subplots(figsize=(9,9))
fi_ax.barh(positions, top['importance'])
fi_ax.set_yticks(positions)
fi_ax.set_yticklabels(top['features'])
fi_ax.set_title('Multiclass BDT avg. feature importance')
plt.show()

In [None]:
# Persist the trained classifier to a JSON file in the current working
# directory (an existing file of the same name is overwritten).
model.save_model('MultiClassBDT_23Oct25.json')