In [1]:
import sys
import numpy as np
import uproot
import hist
import awkward as ak

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from coffea import processor
from coffea.nanoevents.schemas.base import BaseSchema, zip_forms
from coffea.nanoevents.methods import base

if sys.version_info >= (3, 0):
    import queue
else:
    import Queue as queue

In [2]:
# ---- user-tunable settings ----

# connection to the inference server
url     = "agc-triton-inference-server:8001"  # server address; use port 8001 for grpcclient, 8000 for httpclient
verbose = False                               # turn on verbose output of grpcclient

# model selection
model_name = "binary_classifier"  # ML model to query (make sure it is loaded properly)
model_vers = ""                   # model version, if one has to be specified

# input data and batching
test_events = "testevents.csv"  # csv file with the input events
batch_size  = 1000              # events sent in a single inference request
num_batches = 5                 # batches to process (num_batches*batch_size events in total)

# parallelism
num_cores = 4  # scaling for the FuturesExecutor setup

## Set up gRPC client and get model info

In [3]:
# gRPC client through which the notebook talks to the inference server
triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)

In [4]:
# ask the server for the metadata of the selected model
model_metadata = triton_client.get_model_metadata(model_name=model_name,
                                                  model_version=model_vers)

In [5]:
# ask the server for the configuration of the selected model
model_config = triton_client.get_model_config(model_name, model_vers)

In [6]:
# Unwrap the get_model_config response: the fields read further down
# (max_batch_size, input) are taken from its `config` attribute.
# getattr with a fallback keeps this cell safe to re-run: if the object bound
# to model_config has no `config` attribute (e.g. it was already unwrapped by
# an earlier run of this cell), it is left unchanged instead of raising.
model_config = getattr(model_config, "config", model_config)

In [7]:
# first input / output tensor declared by the model
input_metadata = model_metadata.inputs[0]
output_metadata = model_metadata.outputs[0]
input_config = model_config.input[0]

# a positive max_batch_size is taken to mean the input shape has a leading batch axis
max_batch_size = model_config.max_batch_size
input_batch_dim = max_batch_size > 0

# tensor names used when building inference requests
input_name = input_metadata.name
output_name = output_metadata.name

# the feature axis is index 1 when a batch axis is present, otherwise index 0
n_features = input_metadata.shape[int(input_batch_dim)]

# NOTE(review): `format` shadows the Python builtin of the same name in this notebook
format = input_config.format
dtype = input_metadata.datatype

## Define the processor and schema, then load data

In [8]:
processor_base = processor.ProcessorABC

class SigBkgInference(processor_base):
    """Coffea processor that histograms toy events together with a per-event label.

    For every chunk of events the four variables var0..var3 are stacked into a
    feature matrix, the events are walked through in batches of ``batch_size``
    and each event receives a classification value. var0, var1, var2 and the
    classification are then filled into a histogram.

    The request to the inference server is currently disabled: every event is
    assigned the placeholder label 1 (see ``_classify_batch``).
    """

    def __init__(self, batch_size, model_name, model_vers, url, verbose):
        """Store the settings and build the empty template histogram.

        Parameters
        ----------
        batch_size : int
            Number of events handled per batch (one inference request each).
        model_name : str
            Name of the model on the inference server.
        model_vers : str
            Model version string.
        url : str
            Address passed to ``grpcclient.InferenceServerClient``.
        verbose : bool
            Verbosity flag passed to ``grpcclient.InferenceServerClient``.
        """
        # three regular axes for the plotted variables, an integer category
        # axis (0/1) for the classification, and weighted storage
        self.hist = (hist.Hist.new.Reg(50, -7, 7, name="var0", label="Variable 0")
                     .Reg(50, -45, 5, name="var1", label="Variable 1")
                     .Reg(50, 0, 15, name="var2", label="Variable 2")
                     .IntCat(range(2), name="classification", label="ML Classification (Sig/Bkg)")
                     .Weight())

        self.batch_size = batch_size
        # NOTE: input_name, dtype and output_name are not constructor
        # arguments; they are read from notebook-level variables, which must
        # therefore be defined before this class is instantiated.
        self.input_name = input_name
        self.dtype = dtype
        self.output_name = output_name
        self.model_name = model_name
        self.model_vers = model_vers
        self.url = url
        self.verbose = verbose

    def _classify_batch(self, triton_client, batch):
        """Return one classification value per row of ``batch``.

        Remote inference is currently switched off, so ``triton_client`` is
        not used and every event gets the placeholder label 1. The request
        that belongs here is:

            inpt = [grpcclient.InferInput(self.input_name, batch.shape, self.dtype)]
            inpt[0].set_data_from_numpy(batch)
            output = grpcclient.InferRequestedOutput(self.output_name)
            results = triton_client.infer(model_name=self.model_name,
                                          inputs=inpt,
                                          outputs=[output]).as_numpy(self.output_name)
            return results.T[0]
        """
        return np.ones(len(batch))

    def process(self, events):
        """Classify one chunk of events and fill the histogram.

        Returns a dict with ``"nevents"`` (number of events in the chunk,
        keyed by ``events.metadata["dataset"]``) and ``"hist"`` (the filled
        histogram).
        """
        histogram = self.hist.copy()

        # get event variables
        var0 = events.var0
        var1 = events.var1
        var2 = events.var2
        var3 = events.var3

        # one row per event, one column per variable
        features = ak.concatenate([var0[..., np.newaxis],
                                   var1[..., np.newaxis],
                                   var2[..., np.newaxis],
                                   var3[..., np.newaxis]], axis=1)

        data_length = len(var0)
        classification = np.zeros(data_length)

        # The client stays open for the whole batch loop and is closed even
        # if a batch fails.
        triton_client = grpcclient.InferenceServerClient(url=self.url, verbose=self.verbose)
        try:
            # Each batch's result is written to the same [start:stop) range
            # its features were read from. The final batch may be shorter;
            # stop is clamped to the number of events.
            for start in range(0, data_length, self.batch_size):
                stop = min(start + self.batch_size, data_length)
                batch = features[start:stop].to_numpy().astype(np.float32)
                classification[start:stop] = self._classify_batch(triton_client, batch)
        finally:
            triton_client.close()

        histogram.fill(var0=var0,
                       var1=var1,
                       var2=var2,
                       classification=classification)

        output = {"nevents": {events.metadata["dataset"]: len(var0)},
                  "hist": histogram}

        return output

    def postprocess(self, accumulator):
        """Return the accumulated output unchanged."""
        return accumulator

In [9]:
# define schema to load toy data

class ToySchema(BaseSchema):
    """Schema that keeps only the branches var0, var1, var2 and var3 of the input form."""

    def __init__(self, base_form):
        super().__init__(base_form)
        # replace the full branch listing with the reduced one
        all_branches = self._form["contents"]
        self._form["contents"] = self._build_collections(all_branches)

    def _build_collections(self, branch_forms):
        """Pick the four toy branches out of ``branch_forms``, in fixed order."""
        return {branch: branch_forms[branch]
                for branch in ("var0", "var1", "var2", "var3")}

    @property
    def behavior(self):
        """A fresh dict holding the entries of ``base.behavior``."""
        return dict(base.behavior)

In [10]:
# single dataset 'process0' made of the files data/testdata0.root ... data/testdata4.root
fileset = {
    'process0': {
        'files': [f'data/testdata{index}.root' for index in range(5)],
        'nevts': 5 * 20000,
    },
}

In [11]:
# local executor; the number of workers comes from the num_cores option
executor = processor.FuturesExecutor(workers=num_cores)

In [12]:
# runner that reads events through ToySchema and hands chunks to the executor
run = processor.Runner(
    executor=executor,
    schema=ToySchema,
    chunksize=500_000,
    metadata_cache={},
    savemetrics=True,
)

In [13]:
# preprocessing step of the runner for the "events" tree of every file in the fileset
filemeta = run.preprocess(fileset, treename="events")

Output()

In [14]:
# process the "events" tree of the whole fileset with a SigBkgInference instance;
# with savemetrics enabled the runner returns the output together with metrics
all_histograms, metrics = run(
    fileset,
    "events",
    processor_instance=SigBkgInference(batch_size, model_name, model_vers, url, verbose),
)

Output()

In [266]:
# read the comma-separated test events into a float32 array
data = np.loadtxt(test_events, delimiter=',', dtype=np.float32)

In [267]:
# batch bookkeeping
data_length = data.shape[0]

# upper limit on the number of batches for this many events and this batch_size
max_num_batches = int(np.ceil(data_length / batch_size))

# cap the requested number of batches so it never runs past the available events
num_batches = np.minimum(max_num_batches, num_batches)

In [268]:
# run the inference requests batch by batch
startind = 0

for i in range(num_batches):
    # cut out the events of this request, then move the cursor to the next batch
    data_current = data[startind:startind + batch_size, :]
    startind += batch_size

    client = grpcclient

    print(data_current.shape)

    # requested output and input tensor of the request
    output = client.InferRequestedOutput(output_name)
    inpt = [client.InferInput(input_name, data_current.shape, dtype)]
    inpt[0].set_data_from_numpy(data_current)

    results = triton_client.infer(model_name=model_name, inputs=inpt, outputs=[output])

    # show the rounded model output of this batch
    inference_output = results.as_numpy(output_name)
    print(f"Inference Results for Batch {i}: ", np.round(inference_output).T)

(1000, 4)


KeyboardInterrupt: 

In [31]:
# scratch check of slice assignment: 100 zeros whose last 20 entries are overwritten with ones
test = np.zeros(100)
test[-20:] = np.ones(20)
test

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])