In [1]:
import numpy as np
import awkward as ak
import hist
import warnings
import pickle
from coffea.ml_tools.torch_wrapper import torch_wrapper
import matplotlib.pyplot as plt
import hist
from sklearn.metrics import roc_curve, auc
import math
import os
import uproot
import json
from tqdm import tqdm

In [2]:
# Suppress the two recurring warning families by message pattern.
for _ignored_message in ('invalid value', 'No format'):
    warnings.filterwarnings('ignore', _ignored_message)

In [3]:
# Both input samples live in the same directory; hgg is later treated as
# signal and qcd as background when the ROC truth labels are built.
ECF_DATA_DIR = '/scratch365/cmoore24/training/data/ecfs'

hgg = ak.from_parquet(f'{ECF_DATA_DIR}/nanless_hgg.parquet')
# The QCD file is reduced with ak.firsts right after reading.
qcd = ak.firsts(
    ak.from_parquet(f'{ECF_DATA_DIR}/q476_ecfs.parquet')
)

In [4]:
# Keep hgg jets inside the 80-170 msoftdrop window and the 450-600 pt window.
# Cuts are applied one at a time: upper bound first, then lower bound.
for cut_field, lower, upper in (('msoftdrop', 80, 170), ('pt', 450, 600)):
    hgg = hgg[hgg[cut_field] <= upper]
    hgg = hgg[hgg[cut_field] >= lower]

In [5]:
# Keep qcd jets inside the 80-170 msoftdrop window and the 450-600 pt window.
# Cuts are applied one at a time: upper bound first, then lower bound.
for cut_field, lower, upper in (('msoftdrop', 80, 170), ('pt', 450, 600)):
    qcd = qcd[qcd[cut_field] <= upper]
    qcd = qcd[qcd[cut_field] >= lower]

In [6]:
def _complete_row_mask(records):
    """Return a per-row boolean mask that is True where no field is NaN or None.

    Fields that are themselves records are checked one level down (sub-field by
    sub-field); fields without sub-fields are checked directly.
    """
    # Start from an all-True mask shaped like the first field; fill_none keeps
    # missing entries in that field from leaving holes in the mask itself.
    mask = ak.fill_none(ak.ones_like(records[records.fields[0]], dtype='bool'), True)
    for name in records.fields:
        column = records[name]
        subfields = column.fields
        leaves = [column[sub] for sub in subfields] if subfields else [column]
        for leaf in leaves:
            # nan_to_none turns NaN into None so one is_none test catches both.
            mask = mask & (~ak.is_none(ak.nan_to_none(leaf)))
    return mask


def nan_remover(sample):
    """Drop every row in which any field is NaN or None.

    Parameters
    ----------
    sample : dict or awkward record array
        Either a single record array, or a dict (or dict subclass) whose values
        are record arrays.

    Returns
    -------
    The filtered sample. For a dict, each value is replaced in place by its
    filtered version and the same dict object is returned; for an array a new
    filtered array is returned.

    Notes
    -----
    Both input kinds share the same mask logic, so nested record fields are
    inspected one level down in the dict case as well as in the array case.
    """
    # isinstance (rather than an exact type comparison) so dict subclasses are
    # routed to the dict branch instead of failing on the missing `.fields`.
    if isinstance(sample, dict):
        for key in sample:
            sample[key] = sample[key][_complete_row_mask(sample[key])]
        return sample
    return sample[_complete_row_mask(sample)]

In [7]:
# Clean both samples with nan_remover before they are used further.
hgg, qcd = nan_remover(hgg), nan_remover(qcd)

In [8]:
def add_ratio(ratio, dataframe):
    """Evaluate a ratio described by a string of the form ``"num/den**exp"``.

    Parameters
    ----------
    ratio : str
        Ratio specification. The text before the first ``/`` names the
        numerator, the text between that ``/`` and the following ``**`` names
        the denominator, and the remainder is parsed as a float exponent.
    dataframe : mapping-like
        Object indexed with the numerator and denominator names.

    Returns
    -------
    ``dataframe[num] / dataframe[den] ** exp``

    Raises
    ------
    ValueError
        If ``ratio`` lacks the ``/`` or ``**`` separator, or the exponent is
        not a valid number. Malformed strings are rejected up front rather
        than being sliced at a "not found" index and silently evaluated with
        the wrong names.
    """
    numerator, slash, remainder = ratio.partition('/')
    denominator, power, exponent_text = remainder.partition('**')
    if not slash or not power:
        raise ValueError(
            f"Malformed ratio {ratio!r}: expected 'numerator/denominator**exponent'"
        )
    exponent = float(exponent_text)
    num_ecf = dataframe[numerator]
    den_ecf = dataframe[denominator]
    return num_ecf / (den_ecf ** exponent)

In [9]:
def imapper(array, ratio_list):
    """Collect the requested variables into the ``{'vars': {...}}`` input map.

    Parameters
    ----------
    array : object
        Source of the variables. Each name is first looked up as
        ``array.ratios[name]``; if that lookup fails, ``array[name]`` is used
        instead (this is the path taken for a plain dict).
    ratio_list : iterable of str
        Names of the variables to collect, in the order they should appear.

    Returns
    -------
    dict
        ``{'vars': {name: values, ...}}`` with one entry per requested name.
    """
    selected = {}
    for name in ratio_list:
        try:
            selected[name] = array.ratios[name]
        except Exception:
            # Only ordinary lookup failures trigger the fallback; interrupts
            # and interpreter-exit requests must propagate to the caller.
            selected[name] = array[name]
    return {'vars': selected}

In [10]:
class EnergyCorrelatorFunctionTagger(torch_wrapper):
    """torch_wrapper subclass that builds its input from a map of ECF variables."""

    def prepare_awkward(self, events, scaler, imap):
        """Turn the variable map into the positional/keyword arguments for the model.

        Parameters
        ----------
        events
            Accepted for the interface; not read by this method.
        scaler
            Object exposing ``transform``; applied to the stacked 'vars' group.
        imap : dict
            Maps a group name to a dict of per-variable arrays. Every group is
            stacked column-wise; only the 'vars' group is used for the output.

        Returns
        -------
        tuple
            ``((x,), {})`` where ``x`` is the scaled 'vars' matrix as float32.
        """
        stacked = {}
        for group, variables in imap.items():
            # Give each variable a trailing axis so they can sit side by side.
            columns = [values[:, np.newaxis] for values in variables.values()]
            stacked[group] = ak.concatenate(columns, axis=1)

        scaled = scaler.transform(stacked['vars'])
        features = ak.values_astype(scaled, "float32")
        return (features,), {}

In [11]:
def get_cut(data, target_percentile=0.20):
    """Return the value of ``data`` at the ``(1 - target_percentile)`` percentile.

    With the default of 0.20 this is the 80th percentile, i.e. the threshold
    that the top 20% of entries sit at or above.

    Parameters
    ----------
    data : array-like
        Values to take the percentile of; converted with ``np.array``.
    target_percentile : float, optional
        Fraction of entries that should lie above the returned value.
    """
    values = np.array(data)
    percentile_rank = 100 * (1 - target_percentile)
    return np.percentile(values, percentile_rank)

In [12]:
# Root of the batch training outputs; each listed entry is later used as a
# sub-directory name holding selected_ecfs.txt, traced_model.pt and scaler.pkl.
path = '/scratch365/cmoore24/training/hgg/batch/outputs'
# Trailing slash kept so the listed location is spelled exactly as before.
ml_dirs = os.listdir(f'{path}/')

In [13]:
# Event totals and cross sections are read from the same jsons directory;
# both are needed for the histogram scale factor in the scoring loop.
JSON_DIR = '/afs/crc.nd.edu/user/c/cmoore24/Public/hgg/ml/ml_processor_work/jsons'

with open(f'{JSON_DIR}/subregion_event_totals.json', 'r') as totals_file:
    totals = json.load(totals_file)

with open(f'{JSON_DIR}/my_xsecs.json', 'r') as xsecs_file:
    xsecs = json.load(xsecs_file)

In [14]:
# Reset the results file to an empty JSON object; the scoring loop reads it,
# adds one entry per model and writes it back.
with open('model_results.json', 'w') as results_file:
    results_file.write(json.dumps({}))

In [15]:
def _record_result(model_name, roc_auc, sculpt_metric, results_path='model_results.json'):
    """Store one model's metrics in the JSON results file (read, update, rewrite)."""
    with open(results_path, 'r') as f:
        results = json.load(f)

    results[model_name] = {'roc_auc': roc_auc, 'sculpt_metric': sculpt_metric}

    with open(results_path, 'w') as f:
        json.dump(results, f)


# Score every trained model on both samples and record its ROC AUC together
# with a mass-sculpting metric. Results are flushed to disk after each model,
# so a partial run still leaves usable output.
for model_name in tqdm(ml_dirs):
    model_dir = f'{path}/{model_name}'

    ## Read in the list of ECFs used in this training
    with open(f'{model_dir}/selected_ecfs.txt', 'r') as f:
        ecf_list = [line.strip() for line in f]

    ## Calculate the actual ratios and add them to an appropriate dictionary
    hgg_ratios = {ratio: add_ratio(ratio, hgg.ECFs) for ratio in ecf_list}
    qcd_ratios = {ratio: add_ratio(ratio, qcd.ECFs) for ratio in ecf_list}

    ## Read in the model specific files
    model_path = f'{model_dir}/traced_model.pt'
    scaler_path = f'{model_dir}/scaler.pkl'
    # NOTE: pickle.load can execute arbitrary code from the file; only point
    # this at trusted training outputs.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)

    ## Start building the model
    tagger = EnergyCorrelatorFunctionTagger(model_path)

    hgg_imap = imapper(hgg_ratios, ecf_list)
    hgg_scores = tagger(hgg, scaler, hgg_imap)[:, 0]

    qcd_imap = imapper(qcd_ratios, ecf_list)
    qcd_scores = tagger(qcd, scaler, qcd_imap)[:, 0]

    ## Remove NaNs
    mask = ~np.isnan(qcd_scores)
    qcd_scores = qcd_scores[mask]
    qcd_train = qcd[mask]

    mask = ~np.isnan(hgg_scores)
    hgg_scores = hgg_scores[mask]
    hgg_train = hgg[mask]  # not used further in this cell

    ## Calculate ROC AUC (QCD labelled 0, Hgg labelled 1)
    bkg_zeros = ak.zeros_like(qcd_scores)
    sig_ones = ak.ones_like(hgg_scores)
    combined = ak.concatenate([qcd_scores, hgg_scores])
    combined_truth = ak.concatenate([bkg_zeros, sig_ones])

    try:
        fpr, tpr, thresholds = roc_curve(combined_truth, combined)
        roc_auc = auc(fpr, tpr)
    except Exception:
        # Best effort: a model whose ROC cannot be computed is recorded with
        # empty metrics under its own name, and the loop moves on.
        # `except Exception` lets KeyboardInterrupt stop the long-running loop.
        _record_result(model_name, None, None)
        continue

    ## Build Hists for sculpting metrics
    cut = get_cut(qcd_scores)

    # QCD jets failing the score cut
    mask = ~((qcd_scores >= cut))
    qcd_cut_msd = qcd_train.msoftdrop[mask]
    qcd_fail_hist = hist.Hist.new.Reg(40, 80, 170, name='msd', label='QCD MSD').Weight()
    qcd_fail_hist.fill(msd=qcd_cut_msd);

    # QCD jets passing the score cut
    mask = ((qcd_scores >= cut))
    qcd_cut_msd = qcd_train.msoftdrop[mask]
    qcd_pass_hist = hist.Hist.new.Reg(40, 80, 170, name='msd', label='QCD MSD').Weight()
    qcd_pass_hist.fill(msd=qcd_cut_msd);

    # Presumably luminosity (44.99) times cross section (x1000 unit change)
    # over the generated event total -- TODO confirm the units.
    scale = ((44.99*(xsecs['qcd']['qcd_470to600']*1000))/totals['qcd']['470to600'])
    qcd_pass_hist.view(flow=True)[:] *= scale
    qcd_fail_hist.view(flow=True)[:] *= scale

    total_qcd_hist = qcd_pass_hist + qcd_fail_hist

    # Sum over bins of the absolute difference between the normalised total
    # and the normalised passing mass shapes.
    sculpt_metric = sum(abs(total_qcd_hist.density() - qcd_pass_hist.density()))

    _record_result(model_name, roc_auc, sculpt_metric)

100%|███████████████████████████████████████| 648/648 [7:11:00<00:00, 39.91s/it]
