In [None]:
import os
from distributed import Client
from lpcjobqueue import LPCCondorCluster
import awkward as ak
import numpy as np
import torch
from utils.mlbench import SimpleWorkLog
from utils.mlbench import process_function, create_local_pnmodel, get_triton_client, run_inference_pnmodel, generate_pseudodata_from_seed
import time
import pathlib
#Can use ship_env and the .triton_env with LPCCondorCluster, but here's an alternative that should work for other cluster types
from distributed.diagnostics.plugin import UploadDirectory

In [None]:
# Start an HTCondor-backed Dask cluster; each job gets 2 cores, 7.5GB of
# memory and 4GB of disk.
cluster = LPCCondorCluster(
    cores=2,
    memory="7.5GB",
    disk="4GB",
    # Per-user log directory. The user name must be an actual string: the
    # previous `str(os.getlogin)` stringified the function object itself
    # ("<built-in function getlogin>"). The USER environment variable is
    # tried first because os.getlogin() needs a controlling terminal, which
    # a notebook kernel may not have.
    log_directory='/uscmst1b_scratch/lpc1/3DayLifetime/'
    + (os.environ.get("USER") or os.getlogin()),
    #ship_env=False,
    #death_timeout=240,
    #schedule_options={"dashboard_address": f":{__get_port():d}"},
)

In [None]:
# Pin adaptive scaling to a fixed size: minimum and maximum are both 50.
cluster.adapt(
    minimum=50,
    maximum=50,
)

In [None]:
# Display the cluster's `workers` attribute as the cell output.
# NOTE(review): scaling was only just requested via adapt(), so this may not
# yet reflect running workers — confirm before relying on its length.
cluster.workers

In [None]:
# Connect a distributed Client to the cluster. The bare `client` on the last
# line makes the notebook display the client's summary as the cell output.
client = Client(cluster)
client

In [None]:
# Upload the local ../utils directory to the workers through a nanny plugin.
# restart=True and update_path=True are passed so that the workers restart
# after the upload and the uploaded directory ends up on their import path
# (see the UploadDirectory documentation for the exact semantics).
client.register_worker_plugin(
    UploadDirectory(
        "../utils",
        restart=True,
        update_path=True,
    ),
    nanny=True,
)

In [None]:
# Upload the local ../models directory to the workers, using the same plugin
# options as for ../utils (restart=True, update_path=True, nanny=True).
client.register_worker_plugin(
    UploadDirectory(
        "../models",
        restart=True,
        update_path=True,
    ),
    nanny=True,
)

In [None]:
def test_structure(x):
    import os
    import sys
    import pathlib
    test = pathlib.Path("/srv/utils/")
    success = False
    try:
        from srv.utils.mlbench import SimpleWorkLog
        success = True
    except:
        pass
    success2 = False
    try:
        from utils.mlbench import SimpleWorkLog
        success2 = True
    except:
        pass
    success3 = False
    try:
        from mlbench import SimpleWorkLog
        success3 = True
    except:
        pass
    
    return os.environ, sys.path, success, success2, success3, list(test.iterdir())

def test_triton_dask(worker):
    """Check whether a Triton client can be obtained in this process.

    Parameters
    ----------
    worker : any
        Not used by the body.

    Returns
    -------
    str or type
        ``"success"`` when ``get_triton_client()`` returns something other
        than ``None``; otherwise the type of the returned value.
    """
    triton_client = get_triton_client()
    return type(triton_client) if triton_client is None else "success"

def print_cluster_info(cluster):
    """Pretty-print ``cluster.scheduler_info``.

    Every top-level entry is printed as ``key value``, except ``"workers"``:
    for that entry each worker address is printed on its own line, followed
    by its detail items with the ``=`` signs aligned in a column.

    Parameters
    ----------
    cluster : object
        Anything exposing a ``scheduler_info`` mapping whose ``"workers"``
        entry (if present) maps addresses to dicts with string keys.
    """
    # Read the attribute once so the whole report comes from one snapshot.
    info = cluster.scheduler_info
    for key, value in info.items():
        if key != "workers":
            print(key, value)
            continue

        print(key)
        for address, details in value.items():
            print("\t", address)
            # default=0 keeps a worker with no detail entries from raising
            # ValueError (max() of an empty sequence).
            width = max((len(dkey) for dkey in details), default=0)
            for dkey, dval in details.items():
                # Pad each key to the longest key's width, then add the
                # separator, so the values line up.
                extras = " " * (width - len(dkey)) + "  =\t"
                print("\t\t", dkey, extras, dval)
def test_workers(x):
    results = {}
    try:
        import os
        results["pid"] = os.getpid()
    except:
        results["pid"] = False
        
    import socket
    try:
        import socket
        results["hostname"] = socket.gethostname()
    except:
        results["hostname"] = False
        
    try:
        from utils.mlbench import SimpleWorkLog
        results["utils"] = True
    except:
        results["utils"] = False
        
    try:
        from utils.mlbench import get_triton_client
        _ = get_triton_client()
        results["triton"] = True
    except:
        results["triton"] = False
        
    try:
        from utils.mlbench import create_local_pnmodel
        _ = create_local_pnmodel()
        results["local"] = True
    except:
        results["local"] = False
    
    return results

In [None]:
# Test that the workers can perform basic functions: submit one
# test_workers task per entry currently in cluster.workers and collect
# the resulting dicts into a list.
test = client.gather(
    client.map(
        test_workers,
        range(len(cluster.workers)),
    )
)

In [None]:
# Several test tasks may have run in the same worker process, so collapse
# the reports to one per (hostname, pid) pair before counting.
unique_workers = {}
for r in test:
    # Build the key with an f-string: test_workers reports `False` for a
    # hostname or pid it could not read, and `False + str(...)` would raise
    # TypeError. For normal reports the key is unchanged (hostname
    # immediately followed by the pid).
    unique_workers[f"{r['hostname']}{r['pid']}"] = r
n_workers = len(unique_workers)
# The per-check flags are booleans, so summing them counts the passing workers.
n_utils_imports = sum(r['utils'] for r in unique_workers.values())
n_triton_functioning = sum(r['triton'] for r in unique_workers.values())
n_local_functioning = sum(r['local'] for r in unique_workers.values())
n_workers, n_utils_imports, n_triton_functioning, n_local_functioning

In [None]:
# Sizes of the benchmark trials: n_workers sets how many argument tuples the
# short trials get; the long trial gets long_multiplier times as many.
n_workers = 100
long_multiplier = 10

# Each workargs* list holds four parallel sequences, in this order:
# seeds, #pseudo-events, batchsize, use triton (True/False)
workargstriton = [
    range(n_workers),
    [1000] * n_workers,
    [1000] * n_workers,
    [True] * n_workers,
]
workargslocal = [
    range(n_workers),
    [1000] * n_workers,
    [250] * n_workers,
    [False] * n_workers,
]
workargstritonlong = [
    range(n_workers * long_multiplier),
    [9999] * (n_workers * long_multiplier),
    [1000] * (n_workers * long_multiplier),
    [True] * (n_workers * long_multiplier),
]

In [None]:
# One direct call to the inference helper from the notebook process itself
# (no `client` call is involved), unpacking its three return values.
# NOTE(review): the two positional arguments to generate_pseudodata_from_seed
# are presumably (seed, number of pseudo-events), matching the column order
# documented for the workargs lists — confirm against utils.mlbench.
# NOTE(review): `worklog` receives the SimpleWorkLog class itself, not an
# instance — confirm this is what run_inference_pnmodel expects.
with_outputs, inf_worklogs, errors = run_inference_pnmodel(
    generate_pseudodata_from_seed(983, 1000), 
    get_triton_client(), 
    batchsize=1000, 
    triton=True, 
    worklog=SimpleWorkLog
)

In [None]:
# Triton, N workers trial
# Statement order matters: the two perf_counter readings bracket only the
# task submission (client.map) and result collection (client.gather).
print("time", time.time())  # wall-clock timestamp before the run
pft = time.perf_counter()  # start of the measured interval
futurestriton = client.map(process_function, *workargstriton)  # the four argument sequences are unpacked as parallel iterables
resulttriton = client.gather(futurestriton)  # collect the results of all submitted futures
print("runtime(s)", time.perf_counter() - pft)  # elapsed seconds for submit + gather
print("time", time.time())  # wall-clock timestamp after the run

In [None]:
# Triton, N workers trial long
# Same measurement as the short Triton trial, but over workargstritonlong.
# Statement order matters: the two perf_counter readings bracket only the
# task submission (client.map) and result collection (client.gather).
print("time", time.time())  # wall-clock timestamp before the run
pft = time.perf_counter()  # start of the measured interval (reuses the name `pft`)
futurestlong = client.map(process_function, *workargstritonlong)  # the four argument sequences are unpacked as parallel iterables
resulttlong = client.gather(futurestlong)  # collect the results of all submitted futures
print("runtime(s)", time.perf_counter() - pft)  # elapsed seconds for submit + gather
print("time", time.time())  # wall-clock timestamp after the run

In [None]:
# Local, N workers trial
# Uses workargslocal, whose last column is all False (use triton = False).
# Statement order matters: the two perf_counter readings bracket only the
# task submission (client.map) and result collection (client.gather).
print("time", time.time())  # wall-clock timestamp before the run
pfl = time.perf_counter()  # start of the measured interval
futureslocal = client.map(process_function, *workargslocal)  # the four argument sequences are unpacked as parallel iterables
resultlocal = client.gather(futureslocal)  # collect the results of all submitted futures
print("runtime(s)", time.perf_counter() - pfl)  # elapsed seconds for submit + gather
print("time", time.time())  # wall-clock timestamp after the run

In [None]:
# Tear down in order: disconnect the client first, then shut down the
# cluster, so the client is not left connected to a scheduler that has
# gone away.
client.close()
cluster.close()