# Demo 3: DNN inference

In this demo we show how different workflows can be used to optimize the performance of a user analysis, using inference with a generic DNN model as an example.


In [None]:
import numpy as np
import pandas as pd
import torch

from python.event_selection import load_events
from python.dnn_model import NeuralNet

# Dataset labels; each label names a "<label>.root" file located under `server`.
sources = ["data", "ttbar", "dy"]
server = "file:/depot/cms/purdue-af/demos/"
# Directory from which `inference` reads "model.ckpt".
model_dir = "/depot/cms/purdue-af/demos/"

# Columns selected from every loaded dataset and fed to the model.
features = [
    'mu1_pt', 'mu1_eta',
    'mu2_pt', 'mu2_eta',
    'dimuon_mass', 'met',
]

# Load every dataset once and keep only the feature columns, keyed by label.
dfs = {}
for src in sources:
    root_path = f"{server}/{src}.root"
    dfs[src] = load_events(root_path)[features]


### Connect to an existing cluster

In [None]:
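# Uncomment the lines below to attach to an already-running Dask scheduler,
# replacing the address with the one reported by that scheduler.
# from dask.distributed import Client

# client = Client("tcp://127.0.0.1:42573")
# client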

### Or create a local cluster

In [None]:
from dask.distributed import LocalCluster, Client
# Start a Dask cluster on this machine with default settings.
cluster = LocalCluster()
# `client` is the handle the inference cells below use to scatter data,
# submit tasks and gather results.
client = Client(cluster)
# Bare expression: the notebook renders the client summary as the cell output.
client

In [None]:
# Report which kind of device the torch-based inference below will run on.
print(
    "Will use GPU for inference."
    if torch.cuda.is_available()
    else "Will use CPUs for inference."
)

def inference(inp):
    """Evaluate the DNN on one dataset with PyTorch and return its scores.

    Parameters
    ----------
    inp : sequence
        Two-element sequence ``(label, df)``. ``label`` is passed through
        unchanged; ``df`` is a DataFrame whose ``.values`` are fed to the
        model. The network is built with 6 inputs, so ``df`` is expected to
        have 6 columns.

    Returns
    -------
    dict
        ``{"label": label, "output": scores}`` where ``scores`` is the model
        output flattened to a 1-D NumPy array.
    """
    label, df = inp[0], inp[1]

    # Alternative model location (Triton model repository):
    #model_path="/depot/cms/purdue-af/triton/models/test-model/1/model.pt"
    model_path = model_dir + "/model.ckpt"
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Rebuild the network and restore its trained weights on the chosen device.
    model = NeuralNet(6, [16, 8], 1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    batch = torch.from_numpy(df.values).to(device).float()

    # Pure inference: disable autograd so no computation graph is built or
    # kept in memory for the forward pass.
    with torch.no_grad():
        scores = model(batch)

    # Tensors produced under no_grad do not require grad, so they convert to
    # NumPy directly once moved to the CPU.
    scores = scores.cpu().numpy()
    return {
        "label": label,
        "output": scores.ravel()
    }

# Send each (label, DataFrame) pair from `dfs` to the Dask cluster, run
# `inference` once per pair, and collect the returned dictionaries locally.
scattered_data = client.scatter(list(dfs.items()))
futures = client.map(inference, scattered_data)
results = client.gather(futures)

# Show the label and the score array returned for each dataset.
print("\nInference outputs:")
for res in results:
    print(res["label"], res["output"])

In [None]:
import tritonclient.grpc as grpcclient

# Alternative Triton endpoints, kept for reference:
#triton_address = '128.211.160.154:8001' #5gb
#triton_address = '128.211.160.153:8001' #10gb
#triton_address = '128.211.160.147:8001' #20gb
triton_address = 'hammer-f000.rcac.purdue.edu:8001'

print(f"Connecting to Triton inference server at {triton_address}")

# gRPC keepalive settings passed to every client created in `inference_triton`.
# keepalive_time_ms is set to the largest signed 32-bit integer, i.e. the
# longest ping period that value can express.
keepalive_options = grpcclient.KeepAliveOptions(
    keepalive_time_ms=2**31 - 1,
    keepalive_timeout_ms=20000,
    keepalive_permit_without_calls=False,
    http2_max_pings_without_data=2
)

def inference_triton(inp):
    """Evaluate the DNN on one dataset through the Triton inference server.

    Parameters
    ----------
    inp : sequence
        Two-element sequence ``(label, df)``. ``label`` is passed through
        unchanged; ``df`` is a DataFrame whose ``.values`` are sent to the
        server as the ``INPUT__0`` tensor.

    Returns
    -------
    dict
        ``{"label": label, "output": scores}`` where ``scores`` is the
        ``OUTPUT__0`` tensor flattened to a 1-D NumPy array.

    Raises
    ------
    ConnectionError
        If the gRPC client for ``triton_address`` cannot be created.
    """
    label = inp[0]
    df = inp[1]

    # Create Triton client
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=triton_address,
            verbose=False,
            keepalive_options=keepalive_options
        )
    except Exception as e:
        # Raise instead of exiting the process: this function runs inside
        # Dask workers, and the failure should reach the caller as a normal
        # task error rather than terminate the worker.
        raise ConnectionError("Channel creation failed: " + str(e)) from e

    try:
        # The input is declared as FP64 below, so make sure the array really
        # is float64 (no copy is made when it already is).
        data = np.asarray(df.values, dtype=np.float64)

        # Inputs and outputs should be compatible with model metadata
        # stored in /depot/cms/purdue-af/triton/models/test-model/config.pbtxt
        inputs = [grpcclient.InferInput('INPUT__0', list(data.shape), "FP64")]
        outputs = [grpcclient.InferRequestedOutput('OUTPUT__0')]

        # Load input data
        inputs[0].set_data_from_numpy(data)

        # Run inference on Triton server.
        # Models are stored in /depot/cms/purdue-af/triton/models/
        results = triton_client.infer(
            model_name="test-model",
            inputs=inputs,
            outputs=outputs,
            headers={'test': '1'},
        )

        output = results.as_numpy('OUTPUT__0')
    finally:
        # Release the gRPC channel even when the request fails.
        triton_client.close()

    return {
        "label": label,
        "output": output.flatten()
    }

# results = []
# n = 1
# for i in range(n):
#     for label, df in dfs.items():
#         results.append(inference_triton([label, df]))

# Send each (label, DataFrame) pair from `dfs` to the Dask cluster, run
# `inference_triton` once per pair, and collect the returned dictionaries.
scattered_data = client.scatter(list(dfs.items()))
futures = client.map(inference_triton, scattered_data)
results = client.gather(futures)

# Show the label and the score array returned for each dataset.
print("\nInference outputs:")
for res in results:
    print(res["label"], res["output"])

In [None]:
import matplotlib.pyplot as plt
# Bin edges spanning [0, 1] for the score histograms.
bins = np.linspace(0, 1, 100)
plt.figure(figsize=(5,4))

# Index the most recent inference results by dataset label.
dnn = {}
for res in results:
    dnn[res["label"]] = res["output"]

# Overlay one normalised, semi-transparent histogram per dataset.
for sample in ("dy", "ttbar", "data"):
    plt.hist(dnn[sample], bins, alpha=0.3, label=sample, density=True)

plt.xlabel('DNN Score')
plt.ylabel('Events')
leg = plt.legend(loc='upper left')