In [None]:
import math
import numpy as np
import awkward as ak
import time
from datetime import datetime

In [None]:
from utils.tritonutils import wrapped_triton

# Client handle for the 'reconstruction_bdt_xgb' model served by a Triton
# inference server over gRPC.
TRITON_URL = "triton+grpc://test-3.apps.okddev.fnal.gov:443/reconstruction_bdt_xgb/1"
triton_model = wrapped_triton(TRITON_URL)

In [None]:
# Load an XGBoost classifier from its JSON dump for local (in-process) inference.
from xgboost import XGBClassifier

LOCAL_MODEL_PATH = '/srv/models/xgb_demo.json'
local_model = XGBClassifier()
local_model.load_model(LOCAL_MODEL_PATH)
local_model

In [None]:
# Small smoke-test batch: 5 rows x 20 float32 features drawn uniformly from [0, 1).
data = np.asarray(np.random.rand(5, 20), dtype=np.float32)

In [None]:
# Local baseline: class probabilities for the smoke-test batch.
local_probs = local_model.predict_proba(data)
local_probs

In [None]:
# Send the same batch to the Triton-served model; tensors are addressed by
# name ('input__0' in, 'output__0' out).
triton_request = {'input__0': data}
out = triton_model(triton_request, 'output__0')
out

In [None]:
def process_jets(in_jets, batch_size=10000, use_triton=False):
    """Run batched inference over ``in_jets`` and record cumulative wall-clock time.

    NOTE(review): ``process_jets`` is defined again in a later cell (default
    ``batch_size=1024``); the timing cells run after that cell and call the
    later definition.

    Parameters
    ----------
    in_jets : ak.Array
        Record array of inputs; each field is converted with ``ak.to_numpy``.
    batch_size : int, optional
        Number of entries sent to the model per inference call.
    use_triton : bool, optional
        If True, each batch is sent to the module-level ``triton_model``;
        otherwise it is evaluated in-process with ``local_model.predict_proba``.

    Returns
    -------
    njets : np.ndarray
        Cumulative number of entries processed after each batch.
    t : np.ndarray
        Cumulative elapsed wall-clock time in seconds after each batch,
        aligned element-wise with ``njets``. Both are empty for empty input.
    """
    print('Running triton server inference' if use_triton else 'Running local inference')

    n_total = len(in_jets)
    # Collect in lists and convert once: np.append in a loop copies every time.
    njets = []
    t = []
    t_begin = time.time()

    # loop through input data batches and run inference on each batch
    for start in range(0, n_total, batch_size):
        print('%i/%i jets processed, processing next batch' % (start, n_total))

        # Slicing past the end never raises, so no try/except is needed; clamp
        # the stop index so a short final batch is counted correctly.
        stop = min(start + batch_size, n_total)
        jets_eval = in_jets[start:stop]

        if use_triton:
            # Triton inputs are named '<field>__<position>', e.g. 'input__0'.
            X = {f'{k}__{c}': ak.to_numpy(jets_eval[k])
                 for c, k in enumerate(jets_eval.fields)}
            # NOTE(review): the smoke-test cell passes an explicit output name
            # ('output__0'); confirm wrapped_triton's default output here.
            outputs = triton_model(X)
        else:
            # predict_proba takes a single feature matrix (as in the smoke
            # test). Only the last field is used; with the single 'input'
            # field this is the whole feature array.
            X = ak.to_numpy(jets_eval[jets_eval.fields[-1]])
            outputs = local_model.predict_proba(X)

        njets.append(stop)
        t.append(time.time() - t_begin)

    if t:
        print('Total time elapsed = %f sec' % t[-1])

    return np.asarray(njets, dtype=float), np.asarray(t, dtype=float)
    

In [None]:
# Synthetic benchmark input: n_inputs rows of 20 float32 features in [0, 1),
# wrapped in an awkward record array with a single field named 'input'.
n_inputs = 10_000_000
# Draw float32 directly rather than rand(...).astype(np.float32): this avoids
# a temporary float64 array (~1.6 GB at 10M x 20) and cuts peak memory.
rng = np.random.default_rng()
test_inputs = {'input': rng.random((n_inputs, 20), dtype=np.float32)}

test_inputs_ak = ak.Array(test_inputs)

In [None]:
def process_jets(in_jets, batch_size=1024, use_triton=False):
    """Run batched inference over ``in_jets`` and record cumulative wall-clock time.

    Parameters
    ----------
    in_jets : ak.Array
        Record array of inputs; each field is converted with ``ak.to_numpy``.
    batch_size : int, optional
        Number of entries sent to the model per inference call.
    use_triton : bool, optional
        If True, each batch is sent to the module-level ``triton_model``;
        otherwise it is evaluated in-process with ``local_model.predict_proba``.

    Returns
    -------
    njets : np.ndarray
        Cumulative number of entries processed after each batch.
    t : np.ndarray
        Cumulative elapsed wall-clock time in seconds after each batch,
        aligned element-wise with ``njets``. Both are empty for empty input.
    """
    print('Running triton server inference' if use_triton else 'Running local inference')

    n_total = len(in_jets)
    # Collect in lists and convert once: np.append in a loop copies every time.
    njets = []
    t = []
    t_begin = time.time()

    # loop through input data batches and run inference on each batch
    for start in range(0, n_total, batch_size):
        print('%i/%i jets processed, processing next batch' % (start, n_total))

        # Slicing past the end never raises, so no try/except is needed; clamp
        # the stop index so a short final batch is counted correctly.
        stop = min(start + batch_size, n_total)
        jets_eval = in_jets[start:stop]

        if use_triton:
            # Triton inputs are named '<field>__<position>', e.g. 'input__0'.
            X = {f'{k}__{c}': ak.to_numpy(jets_eval[k])
                 for c, k in enumerate(jets_eval.fields)}
            # NOTE(review): the smoke-test cell passes an explicit output name
            # ('output__0'); confirm wrapped_triton's default output here.
            outputs = triton_model(X)
        else:
            # predict_proba takes a single feature matrix. Only the last field
            # is used; with the single 'input' field this is the whole array.
            X = ak.to_numpy(jets_eval[jets_eval.fields[-1]])
            outputs = local_model.predict_proba(X)

        njets.append(stop)
        t.append(time.time() - t_begin)

    if t:
        print('Total time elapsed = %f sec' % t[-1])

    return np.asarray(njets, dtype=float), np.asarray(t, dtype=float)
    

In [None]:
# Average the local model's cumulative timing curve over n trials
# (a single trial is enough here: no network connection is involved).
n = 1
local_t = None
for ii in range(n):
    print('------ Trial %i ------' % ii)
    local_njets, trial_t = process_jets(test_inputs_ak, use_triton=False, batch_size=100000)
    local_t = trial_t if local_t is None else local_t + trial_t
local_t /= n

In [None]:
# Average the Triton timing curve over n trials to smooth out run-to-run
# fluctuations from the connection to the server.
n = 10
triton_t = None
for ii in range(n):
    print('------ Trial %i ------' % ii)
    triton_njets, trial_t = process_jets(test_inputs_ak, use_triton=True, batch_size=100000)
    triton_t = trial_t if triton_t is None else triton_t + trial_t
triton_t /= n

In [None]:
# Sustained-load test: send fresh random batches to the Triton server for a
# fixed wall-clock duration and count how many batches complete.
# NOTE: each iteration also includes client-side generation of the batch.
c = 0
batch_size = 1000000
run_minutes = 8
print(datetime.now())
# time.monotonic() is unaffected by system clock adjustments during a long run.
t_start = time.monotonic()
while (time.monotonic() - t_start) < 60 * run_minutes:
    data = np.random.rand(batch_size, 20).astype(np.float32)
    out = triton_model({'input__0': data}, 'output__0')
    c += 1
print(datetime.now())
print('%i batches processed, batch size of %i' % (c, batch_size))

## now plot

In [None]:
import matplotlib.pyplot as plt
from matplotlib import gridspec
import mplhep as hep
# Apply mplhep's CMS style to the global matplotlib settings used by later figures.
cms_style = hep.style.CMS
plt.style.use(cms_style)

In [None]:
# Two-panel timing comparison: cumulative time vs. inputs processed (top,
# log y) and the local/Triton time ratio (bottom), sharing the x axis.
# figsize must be given when the figure is created; setting
# rcParams["figure.figsize"] afterwards does not resize an existing figure.
fig = plt.figure(figsize=(7, 6))
# set height ratios for subplots
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

# the first subplot, with a log-scale y axis
ax0 = fig.add_subplot(gs[0])
ax0.set_yscale("log")
line0, = ax0.plot(local_njets, local_t, color='r')
line1, = ax0.plot(triton_njets, triton_t, color='b')

# The ratio below is element-wise, so both runs must share batch boundaries.
assert np.array_equal(local_njets, triton_njets), "local and triton runs used different batching"

# the second subplot shares the x axis with the first
ax1 = fig.add_subplot(gs[1], sharex=ax0)
line2, = ax1.plot(local_njets, local_t / triton_t, color='black', linestyle='--')
plt.setp(ax0.get_xticklabels(), visible=False)
# remove the last y tick label of the ratio panel (presumably to avoid
# overlapping the upper panel once the vertical gap is removed)
yticks = ax1.yaxis.get_major_ticks()
yticks[-1].label1.set_visible(False)

# put legend on first subplot
ax0.legend((line0, line1), ('local model', 'triton model'), loc='lower right')

ax0.set_ylabel('time elapsed (s)')
ax1.set_ylabel('$t_{local}/t_{triton}$')
ax1.set_xlabel('# inputs processed')

# remove vertical gap between subplots
fig.subplots_adjust(hspace=0.0)
#plt.savefig("results/timing_results_xgb.eps", bbox_inches="tight")
plt.show()