In [1]:
# A __future__ import has to be the first statement of a module. IPython tolerates
# it later in a cell, but a plain Python run of this cell (e.g. after nbconvert)
# would fail with a SyntaxError, so it goes first.
from __future__ import division

import math
from collections import OrderedDict

import histbook as hb
from striped.job import SinglePointStripedSession as Session, IPythonDisplay

# (host, port) of the Striped job server this notebook talks to.
job_server = ("ifdb01.fnal.gov", 8765)
session = Session(job_server)

In [None]:
# Background datasets to process: one name per whitespace-separated token of the
# literal below; blank tokens are dropped.
bg_datasets = [
    ds.strip()
    for ds in """
Summer16.DYJetsToLL_M-50_TuneCUETP8M1_13TeV-madgraphMLM-pythia8
""".split()
    if ds.strip()
]

In [None]:
import fnal_column_analysis_tools.lookup_tools as lookup_tools
import cloudpickle
import zlib

# Import a bunch of correction histograms.
# Every line of newCorrectionFiles.txt is handed to the extractor as one weight-set
# description.
weightsext = lookup_tools.extractor()
# Read through a context manager so the file handle is closed deterministically.
with open("newCorrectionFiles.txt") as correction_file:
    correctionDescriptions = correction_file.readlines()
weightsext.add_weight_sets(correctionDescriptions)
weightsext.finalize()
weights_eval = weightsext.make_evaluator()
# Pickle and compress the extractor's internals so they can be passed to the workers
# as job parameters; the worker-side `evaluator` rebuilds the lookups from them.
# NOTE(review): this reaches into name-mangled private attributes of the extractor,
# so it is tied to the installed fnal_column_analysis_tools version.
weights_names = zlib.compress(cloudpickle.dumps(weightsext._extractor__names))
weights_vals = zlib.compress(cloudpickle.dumps(weightsext._extractor__weights))

#dir(weights_eval)
#print(weights_eval["muScaleFactor_TightId_Iso"])

In [None]:
# Registry of all booked histograms, keyed by the first argument given to add_1d;
# insertion order is kept and determines the layout built by buildDisplay.
all_hists = OrderedDict()
# Grouping axis on "category", shared by every histogram that add_1d creates.
cat = hb.groupby("category", keeporder=True)

def add_1d(*args):
    """Book a one-axis histogram and register it in ``all_hists``.

    All positional arguments are forwarded unchanged to ``hb.bin``; the first one
    doubles as the key under which the histogram is stored. The histogram is also
    grouped by the module-level ``cat`` axis.

    Returns the newly created ``hb.Hist``.
    """
    booked = hb.Hist(hb.bin(*args), cat)
    key = args[0]
    all_hists[key] = booked
    return booked

# Book the histograms this notebook fills. The arguments go straight to hb.bin
# (presumably: expression name, number of bins, low edge, high edge — confirm against histbook).
add_1d("leadingLeptonPt", 100, 0, 500)
add_1d("zMass", 120, 0, 120)
add_1d("stripeThroughput", 300, 0, 300000)

# Keyword options passed to every .step() call made in buildDisplay.
display_opts = {'width': 300, 'height': 300}
def buildDisplay(hists, cols=2):
    """Arrange the histograms of ``hists`` in a grid and wrap it in an IPythonDisplay.

    hists: mapping of name -> histogram; iteration order defines the grid order.
    cols:  number of plots per row (the last row may be shorter).

    Each histogram is overlaid by "category" and drawn as a step plot labelled with
    its name, using the module-level ``display_opts``.

    The items are materialised into a list before slicing, so this works with
    Python 3 dict views as well as Python 2 lists. The row boundaries come from an
    integer ``range`` step, so they no longer depend on true division being enabled.
    """
    named_hists = list(hists.items())
    rows = []
    for start in range(0, len(named_hists), cols):
        row_plots = tuple(h.overlay("category").step(n, **display_opts)
                          for n, h in named_hists[start:start + cols])
        rows.append(hb.beside(*row_plots))
    return IPythonDisplay(
        hb.below(*tuple(rows))
    )
    
# Display object covering every histogram booked so far in all_hists.
display = buildDisplay(all_hists)

class Callback:
    """Callback object that refreshes a display whenever histograms are updated.

    NOTE(review): presumably the Striped job invokes on_histogram_update on the
    object given as user_callback — confirm against the striped API.
    """
    def __init__(self, display):
        # display: any object exposing update(); stored as-is.
        self.Display = display
        
    def on_histogram_update(self, nevents):
        """Redraw the stored display. ``nevents`` is accepted but not used."""
        self.Display.update()


In [None]:
#__worker_class__
    
#__worker_class__
    
# Break out of the worker sandbox so the import statements below can run.
# https://stackoverflow.com/questions/33880646/access-module-sys-without-using-import-machinery
# NOTE(review): this deliberately defeats the sandbox's import restriction. Keep it
# only for trusted worker code and revisit whether the sandbox can allow these imports.
# How it works: walk every subclass of `object`, look for an attribute that is a bound
# method whose underlying function has 'sys' in its globals, and take that module
# object. No import statement is involved.
sys = next(getattr(c, f).__func__.__globals__['sys'] for c in ().__class__.__base__.__subclasses__() for f in dir(c) if isinstance(getattr(c, f, None), type((lambda: 0).__get__(0))) and 'sys' in getattr(c, f).__func__.__globals__)
if 'sandbox' in sys.modules:
    # Put back the import function that the sandbox module kept as `saved_import`.
    # NOTE(review): subscripting __builtins__ assumes it is a dict in this execution
    # context (it is a module in __main__) — confirm for the worker environment.
    __builtins__['__import__'] = sys.modules['sandbox'].saved_import
    #import subprocess
    #raise Exception(subprocess.check_output("pip list".split(" ")))
    
import numpy as np
import awkward
import uproot_methods
import cloudpickle
import zlib
import time
#import fnal_column_analysis_tools

class denselookup(object):
    """Binned lookup table: returns the stored value of the bin each input falls into.

    values: numpy array of bin contents. For an N-dimensional table the array is
            indexed in reverse axis order (last axis of ``dims`` first).
    dims:   a single numpy array of bin edges (1D table), or a sequence with one
            edge array per axis.

    Inputs below the first edge or above the last edge are clamped to the first /
    last bin instead of raising.
    """
    def __init__(self,values,dims):
        self.__dimension = 0
        whattype = type(dims)
        if whattype == np.ndarray:
            # a bare edge array means a one-dimensional table
            self.__dimension = 1
        else:
            self.__dimension = len(dims)
        if self.__dimension == 0:
            raise Exception('Could not define dimension for {}'.format(whattype))
        self.__axes = dims
        self.__values = values
        self.__type = type(self.__values)

    def __call__(self,*args):
        """Evaluate the table for one input per axis.

        All inputs must be of the same kind: either all awkward JaggedArrays with
        identical counts (the result is then a JaggedArray with those counts), or
        flat inputs (the result is whatever fancy-indexing ``values`` returns).

        Raises Exception when JaggedArrays and numpy arrays are mixed, in either
        order, or when the JaggedArrays do not share the same counts.
        """
        inputs = list(args)
        counts = None       # counts of the first JaggedArray input, if any
        seen_flat = False   # True once a plain numpy array has been seen
        for i in range(len(inputs)):
            if isinstance(inputs[i], awkward.JaggedArray):
                if seen_flat:
                    raise Exception('do not mix JaggedArrays and numpy arrays when calling denselookup')
                these_counts = inputs[i].counts
                if counts is None:
                    counts = these_counts
                elif not np.array_equal(counts, these_counts):
                    # element-wise comparison: same length AND same per-event counts
                    raise Exception('counts not uniform between all input jagged arrays!')
                inputs[i] = inputs[i].flatten()
            elif isinstance(inputs[i], np.ndarray):
                if counts is not None:
                    raise Exception('do not mix JaggedArrays and numpy arrays when calling denselookup')
                seen_flat = True
        retval = self.__evaluate(*tuple(inputs))
        if counts is not None:
            # restore the jagged structure of the inputs
            retval = awkward.JaggedArray.fromcounts(counts,retval)
        return retval

    def __evaluate(self,*args):
        """Look up flat inputs; one positional argument per axis, in axis order."""
        for arg in args:
            if isinstance(arg, awkward.JaggedArray): raise Exception('JaggedArray in inputs')
        indices = []
        if self.__dimension == 1:
            indices.append(np.maximum(np.minimum(np.searchsorted(self.__axes, args[0], side='right')-1,self.__values.shape[0]-1),0))
        else:
            for dim in range(self.__dimension):
                # axis `dim` corresponds to array dimension (ndim - dim - 1) of the values
                nbins = self.__values.shape[len(self.__axes)-dim-1]
                found = np.searchsorted(self.__axes[dim], args[dim], side='right')-1
                indices.append(np.maximum(np.minimum(found,nbins-1),0))
        # values are stored with the last axis first, so index in reverse order
        indices.reverse()
        return self.__values[tuple(indices)]

    def __repr__(self):
        myrepr = "{} dimensional histogram with axes:\n".format(self.__dimension)
        temp = ""
        if self.__dimension == 1:
            temp = "\t1: {}\n".format(self.__axes)
        else:
            temp = "\t1: {}\n".format(self.__axes[0])
        for idim in range(1,self.__dimension):
            temp += "\t{}: {}\n".format(idim+1,self.__axes[idim])
        myrepr += temp
        return myrepr

class evaluator(object):
    """Name-indexed collection of denselookup tables.

    names:      mapping of lookup name -> index into ``primitives``.
    primitives: indexable collection; ``primitives[idx]`` is unpacked as the
                constructor arguments of ``denselookup`` (values, dims).

    ``evaluator[name]`` returns the table; ``dir(evaluator)`` lists the names.
    """
    def __init__(self,names,primitives):
        self.__functions = {}
        # .items() rather than .iteritems() so this runs on Python 2 and 3 alike
        for key, idx in names.items():
            self.__functions[key] = denselookup(*primitives[idx])

    def __dir__(self):
        # an explicit list keeps the return type identical across Python versions
        return list(self.__functions.keys())

    def __getitem__(self, key):
        return self.__functions[key]


def p4_pt(p4):
    """Transverse momentum of each four-vector.

    p4 is an [n,4] numpy array; columns 0 and 1 are combined in quadrature.
    Returns an [n] numpy array.
    """
    px = p4[:, 0]
    py = p4[:, 1]
    return np.sqrt(px**2 + py**2)

def p4_eta(p4):
    """Pseudorapidity of each four-vector.

    p4 is an [n,4] numpy array whose first three columns are (px, py, pz).
    Returns an [n] numpy array with eta = arctanh(pz / |p|).

    The previous implementation returned pz / |p| (cos(theta)) without the
    arctanh, which is not the pseudorapidity the function name promises.

    Vectors along the beam axis give +/-inf and a zero three-momentum gives nan
    (numpy floating-point semantics; numpy may emit a RuntimeWarning).
    """
    p3 = np.sqrt(np.sum(p4[:, :3]**2, axis=-1))
    cos_theta = p4[:, 2] / p3
    return np.arctanh(cos_theta)


# JaggedArray class with the TLorentzVector array methods mixed in; used by
# JaggedDecoratedFourVector.pairs to build an empty "p4" column.
JaggedWithLorentz = awkward.Methods.mixin(uproot_methods.classes.TLorentzVector.ArrayMethods, awkward.JaggedArray)

class JaggedDecoratedFourVector(awkward.JaggedArray,):
    """JaggedArray whose content carries a mandatory "p4" column plus extra columns.

    Combinations built with pairs()/distincts()/cross() get a "p4" column holding
    the sum of the constituents' p4. The `_ispair` / `_iscross` flags record how
    an instance was built and are carried over when an instance is re-wrapped.
    """
    def __init__(self,jagged):        
        """Wrap an existing jagged array (its starts, stops and content are reused).

        Raises Exception if the wrapped array has no "p4" column.
        """
        super(JaggedDecoratedFourVector, self).__init__(jagged.starts,
                                                        jagged.stops,
                                                        jagged.content)
        if 'p4' not in self.columns:
            raise Exception('JaggedDecoratedFourVector declared without "p4" column: {}'.format(self.columns))
        
        # default flags, overridden below when the wrapped array already carries them
        self._ispair = False
        self._iscross = False
        if hasattr(jagged,'_ispair'):
            self._ispair = jagged._ispair
        if hasattr(jagged,'_iscross'):
            self._iscross = jagged._iscross
        
    @classmethod
    def fromcounts(cls,counts,p4,**kwargs):
        """Build an instance from per-event counts, a p4 column and extra columns.

        p4 may already be a TLorentzVectorArray; otherwise its columns 0..3 are
        passed, in that order, to the TLorentzVectorArray constructor.
        kwargs: additional columns stored next to "p4" in the content table.
        """
        the_p4 = p4
        if not isinstance(p4,uproot_methods.TLorentzVectorArray):
            the_p4 = uproot_methods.TLorentzVectorArray(p4[:,0],p4[:,1],p4[:,2],p4[:,3])
        items = {'p4':the_p4}
        items.update(kwargs)
        return JaggedDecoratedFourVector(awkward.JaggedArray.fromcounts(counts,awkward.Table(items)))
    
    @property
    def p4(self):
        """The "p4" column."""
        return self['p4']
    
    def at(self,what):
        """Return constituent `what` via the parent class's at().

        The result is re-wrapped as JaggedDecoratedFourVector when it has a "p4"
        column. Without a "p4" column it is returned as-is, unless it is also
        empty (total count 0), in which case an Exception listing its columns is raised.
        """
        raw = super(JaggedDecoratedFourVector,self).at(what)
        if 'p4' in raw.columns:
            return JaggedDecoratedFourVector(raw)
        if( np.sum(raw.counts) == 0):
            raise Exception("{}".format(raw.columns))
    
        return raw
    
    def distincts(self):
        """Shorthand for pairs(same=False)."""
        return self.pairs(same=False)
    
    def pairs(self, same=True):
        """Form pairs via the parent class and attach the summed p4 of both members.

        `same` is forwarded to the parent pairs(). When no pairs exist at all, an
        empty JaggedWithLorentz column is stored as "p4" instead of a sum.
        """
        outs = super(JaggedDecoratedFourVector, self).pairs(same)        
        if( sum(outs.counts) > 0 ):
            outs['p4'] = outs.at(0)['p4'] + outs.at(1)['p4']
        else:
            outs['p4'] = JaggedWithLorentz.fromcounts(outs.counts,[])        
        outs._ispair = True
        return JaggedDecoratedFourVector(outs)
    
    def cross(self, other):
        """Cross-combine with `other` via the parent class and attach the summed p4.

        NOTE(review): reads outs._iscross on the parent's result; presumably that
        attribute survives the parent cross() call — confirm.
        """
        outs = super(JaggedDecoratedFourVector, self).cross(other)
        # currently JaggedArray.cross() has some funny behavior when it encounters the
        # p4 column and makes some weird new column... for now I just delete it and reorder;
        # everything looks ok after that
        if outs._iscross:
            keys = outs.columns
            reorder = False
            # drop every column whose content is not a Table (the spurious column)
            for key in keys:
                if not isinstance(outs[key].content,awkward.array.table.Table):
                    del outs[key]
                    reorder = True
            if reorder:
                # rename the remaining columns to consecutive indices "0", "1", ...
                keys = outs.columns
                realkey = {}
                for i in xrange(len(keys)):
                    realkey[keys[i]] = str(i)
                for key in keys:
                    if realkey[key] != key:
                        outs[realkey[key]] = outs[key]
                        del outs[key]
            # accumulate the p4 of every constituent column
            keys = outs.columns
            for key in keys:                    
                if 'p4' not in outs.columns:
                    outs['p4'] = outs.at(int(key))['p4']
                else:
                    outs['p4'] = outs['p4'] + outs.at(int(key))['p4']
        else:
            # first cross: exactly two constituents
            outs['p4'] = outs.at(0)['p4'] + outs.at(1)['p4']
            outs._iscross = True
        return JaggedDecoratedFourVector(outs)
    
    def __getattr__(self,what):
        """Attribute-style access to columns.

        A column name returns that column; a name of the form `_<digits>` returns
        the column named by the digits; anything else is looked up on the parent class.
        """
        if what in self.columns:
            return self[what]
        if what[0] == '_' and what[1:].isdigit() :
            return self[what[1:]]
        return getattr(super(JaggedDecoratedFourVector,self),what)

class ColumnGroup(object):
    """Named set of columns read from one object collection of an events container.

    events:  container exposing the collection as the attribute ``objName``.
    objName: name of the collection; its ``count`` attribute is kept as the counts.
    args:    column names. A dotted name such as "p4.pt" is resolved by chained
             attribute access and stored under the full dotted name.
    """
    def __init__(self,events,objName,*args):
        self.__map = {}
        eventObj = getattr(events,objName)
        self.__counts = eventObj.count
        for arg in args:
            # walk "a.b.c" one attribute at a time, starting from the collection
            retval = eventObj
            for attr in arg.split('.'):
                retval = getattr(retval,attr)
            self.__map[arg] = retval

    def __getitem__(self,name):
        return self.__map[name]

    def columnsWithout(self,toremove):
        """Return a copy of the column map without ``toremove``.

        toremove: a single column name (str) or an iterable of names.
        Raises KeyError if a name to remove is not a column of this group.
        """
        out = dict(self.__map)
        if isinstance(toremove,str):
            del out[toremove]
        else:
            for key in toremove:
                del out[key]
        return out

    def columns(self):
        """Return the internal name -> column map (not a copy)."""
        return self.__map

    def counts(self):
        """Return the ``count`` attribute of the collection, as read at construction."""
        return self.__counts
    
class PhysicalColumnGroup(ColumnGroup):
    """ColumnGroup in which one column is designated as the four-momentum.

    events, objName: as for ColumnGroup.
    p4Name: name of the four-momentum column, or None to leave it unset until
            setP4Name is called.
    args:   names of the remaining columns.
    """
    def __init__(self,events,objName,p4Name,*args):
        self.__p4 = None
        # Only request the p4 column when a name was given. Passing None on to
        # ColumnGroup would fail on None.split('.') before the check below runs.
        allargs = list(args)
        if p4Name is not None:
            allargs.insert(0,p4Name)
        super(PhysicalColumnGroup,self).__init__(events,objName,*allargs)
        if p4Name is not None:
            self.setP4Name(p4Name)

    def setP4Name(self,name):
        """Designate column ``name`` as the p4; raises Exception if it is not a column."""
        if name not in self.columns().keys():
            raise Exception('{} not an available name in this PhysicalColumnGroup'.format(name))
        self.__p4 = name

    def p4Name(self):
        """Return the name of the p4 column; raises Exception if none is set."""
        if self.__p4 is None:
            raise Exception('p4 is not set for this PhysicalColumnGroup')
        return self.__p4

    def p4Column(self):
        """Return the column designated as p4."""
        return self[self.p4Name()]

    def otherColumns(self):
        """Return a dict of every column except the p4 column."""
        return self.columnsWithout(self.p4Name())

def jaggedFromColumnGroup(cgroup):
    """Turn a column group into a jagged array.

    A PhysicalColumnGroup becomes a JaggedDecoratedFourVector built from its
    counts, its p4 column and its remaining columns. Any other group becomes a
    plain awkward JaggedArray over a Table of all its columns.
    """
    if not isinstance(cgroup,PhysicalColumnGroup):
        group_counts = cgroup.counts()
        table = awkward.Table(cgroup.columns())
        return awkward.JaggedArray.fromcounts(group_counts, table)

    group_counts = cgroup.counts()
    four_momenta = cgroup.p4Column()
    extra_columns = cgroup.otherColumns()
    return JaggedDecoratedFourVector.fromcounts(counts = group_counts,
                                                p4 = four_momenta,
                                                **extra_columns)
    
class Worker(object):
    """Worker that selects dielectron / dimuon candidates and fills the job histograms.

    ``Columns`` lists the event columns this worker reads. ``run`` receives one batch
    of events plus the job handle, which gives access to the user parameters
    (``job[...]``) and to histogram filling (``job.fill``).
    """
    def __init__(self):
        self.Columns = ["Electron.charge", "Electron.p4", "Electron.tightID",
                        "Muon.charge", "Muon.p4", "Muon.tightID",
                        "Trigger.pass"]

    @staticmethod
    def _leadingPt(pairs):
        """Flat array holding, for every pair, the larger pt of its two members."""
        return np.maximum(pairs.at(0).p4.pt, pairs.at(1).p4.pt).flatten()

    def run(self, events, job):
        """Process one batch of events and fill leadingLeptonPt, zMass and stripeThroughput."""
        tic = time.time()
        # NOTE(review): the weights arrive as compressed pickles in the job parameters;
        # unpickling executes arbitrary code, so only run jobs from trusted submitters.
        weights_eval = evaluator(cloudpickle.loads(zlib.decompress(job["weights_names"])),
                                 cloudpickle.loads(zlib.decompress(job["weights_vals"])))

        #35: HLT_Ele32_WPTight_Gsf_v
        #36: HLT_Ele35_WPTight_Gsf_v
        #45: HLT_IsoMu20_v
        #46: HLT_IsoMu22_v
        #48: HLT_IsoMu24_v
        #50: HLT_IsoMu27_v
        # yes, this is trivial for now but perhaps it's more fun later
        triggerColumns = ColumnGroup(events,"Trigger","pass")
        triggers = jaggedFromColumnGroup(triggerColumns)

        # only care about processing events which have triggered
        # NOTE(review): the mask is computed but never applied. Everything below
        # still runs on the unfiltered `events`; confirm whether that is intended.
        events_triggered = events.filter()
        # np.bool_ instead of the np.bool alias, which NumPy 1.24 removed
        events_triggered.Mask = ( triggers["pass"][:,35] |
                                  triggers["pass"][:,36] |
                                  triggers["pass"][:,45] |
                                  triggers["pass"][:,46] |
                                  triggers["pass"][:,48] |
                                  triggers["pass"][:,50]   ).astype( dtype=np.bool_ )

        #events = events_triggered(events)

        electronCols = PhysicalColumnGroup(events,"Electron","p4","charge","tightID")
        electrons_new = jaggedFromColumnGroup(electronCols)
        electrons_new['SF'] = weights_eval["eleScaleFactor_TightId_POG"](electrons_new.p4.eta,
                                                                         electrons_new.p4.pt)

        muonCols = PhysicalColumnGroup(events,"Muon","p4","charge","tightID")
        muons_new = jaggedFromColumnGroup(muonCols)
        #muons_new['SF'] = weights_eval["muScaleFactor_TightId_Iso"](np.absolute(muons_new.p4.eta),
        #                                                            muons_new.p4.pt)

        # lepton selection: pt > 20, |eta| < 2.5, tight ID
        selected_electrons = electrons_new[(electrons_new.p4.pt > 20) &
                                           (np.absolute(electrons_new.p4.eta) < 2.5) &
                                           (electrons_new.tightID > 0)]

        selected_muons = muons_new[(muons_new.p4.pt > 20) &
                                   (np.absolute(muons_new.p4.eta) < 2.5) &
                                   (muons_new.tightID > 0)]

        dielectrons = selected_electrons.distincts()
        dimuons = selected_muons.distincts()

        # keep pairs in which at least one member passes the tighter pt cut
        selected_dielectrons = dielectrons[((dielectrons.at(0).p4.pt > 38) | (dielectrons.at(1).p4.pt > 38))]
        selected_dimuons = dimuons[((dimuons.at(0).p4.pt > 30) | (dimuons.at(1).p4.pt > 30))]

        # event categories: exactly one pair of one flavour and none of the other
        zee_cat = ((selected_dielectrons.counts == 1) &
                   (selected_dimuons.counts == 0))
        zmm_cat = ((selected_dimuons.counts == 1) &
                   (selected_dielectrons.counts == 0))

        # evaluate each selection once and reuse it for the per-flavour and "all" fills
        zee_pairs = selected_dielectrons[zee_cat]
        zee_leading_pt = self._leadingPt(zee_pairs)
        zee_mass = zee_pairs.p4.mass.flatten()

        zmm_pairs = selected_dimuons[zmm_cat]
        zmm_leading_pt = self._leadingPt(zmm_pairs)
        zmm_mass = zmm_pairs.p4.mass.flatten()

        #fill electrons
        job.fill(
            category="ee",
            leadingLeptonPt=zee_leading_pt
        )
        job.fill(
            category="ee",
            zMass=zee_mass
        )

        #fill muons
        job.fill(
            category="mm",
            leadingLeptonPt=zmm_leading_pt
        )
        job.fill(
            category="mm",
            zMass=zmm_mass
        )

        #fill both
        job.fill(
            category="all",
            leadingLeptonPt=np.hstack([zmm_leading_pt, zee_leading_pt]),
        )
        job.fill(
            category="all",
            zMass=np.hstack([zee_mass, zmm_mass])
        )

        #profiling info: events in this batch divided by the wall time spent in run()
        toc = time.time()
        job.fill(
            category="all",
            stripeThroughput=len(events.Muon.count)/(toc-tic),
        )
        

In [None]:
import time

display.init()
callback = Callback(display)
# Reset every histogram so re-running this cell does not add to earlier results.
# An explicit loop is used because map() called only for its side effects is lazy
# on Python 3 and would clear nothing.
for hist in all_hists.values():
    hist.clear()

nevents_total = 0
t1 = time.time()
for dataset in bg_datasets:
    job = session.createJob(dataset,
            fraction=1.,
            user_callback=callback,
            # compressed, pickled correction tables; unpacked by Worker.run
            user_params = {"weights_names":weights_names,
                           "weights_vals":weights_vals},
            # a real list, not a dict view, on every Python version
            histograms=list(all_hists.values())
    )
    job.run()
    runtime = job.TFinish - job.TStart
    nevents = job.EventsProcessed
    nevents_total += nevents
    # single-argument print(...) behaves identically as a Python 2 statement and a Python 3 call
    print("%-70s %7.3f M events, %7.3f M events/sec" % (dataset[:70], float(nevents)/1e6, nevents/runtime/1000000))
    display.update()

t2 = time.time()
print("Total events processed: %d in %.1f seconds -> %.6f million events/second" %(nevents_total, t2-t1, nevents_total/(t2-t1)/1000000))


In [None]:
def moment(df, n, binval='mid'):
    """
        n-th raw moment of a binned distribution.

        df: DataFrame with single-level MultiIndex specifying binning, and a shape of (nbins, )
        n: order of the moment; (over/under/nan)flow bins are ignored
        binval in ['left', 'right', 'mid']: attribute of each bin used as the bin value
    """
    bin_points = []
    for edge in df.index:
        # string index entries (the nanflow label) carry no bin value
        if type(edge) is str:
            continue
        bin_points.append(getattr(edge, binval))
    # drop underflow and overflow from the bin values ...
    centers = np.array(bin_points)[1:-1]
    # ... and underflow, overflow and nanflow (last bin) from the contents
    contents = np.array(df)[1:-2]
    weighted_total = sum(pow(centers, n)*contents)
    return weighted_total / sum(contents)

def mean(df, binval='mid'):
    """Mean of the binned distribution: its first moment (see ``moment``)."""
    return moment(df, n=1, binval=binval)

def std(df, binval='mid'):
    """Standard deviation of the binned distribution.

    Computed as sqrt(<x^2> - <x>^2) from ``moment``; a negative variance is
    clamped to zero before the square root.
    """
    second = moment(df, 2, binval)
    first = moment(df, 1, binval)
    variance = second - first**2
    return np.sqrt(max(variance, 0.))

In [None]:
# Mean of the per-stripe throughput histogram (category "all"), in events per second.
stripeThroughput = mean(all_hists['stripeThroughput'].pandas()['count()']["all"])
# Number of workers assumed when converting per-stripe throughput into an ideal
# wall time. NOTE(review): hard-coded; confirm it matches the cluster that ran the job.
nWorkers = 180
# Wall time the job would have taken if it were pure stripe processing spread
# evenly over nWorkers.
processingTime = nevents_total/stripeThroughput/nWorkers
# Single-argument print(...) is valid both as a Python 2 statement and a Python 3 call.
print("Stripe processing throughput: %.0f evt/s" % stripeThroughput)
print("Total throughput: %.0f evt/s" % (nevents_total/(t2-t1), ))
print("Striped server overhead: %.1f %%" % ((1-processingTime/(t2-t1))*100, ))