In [None]:
import boto3
import uproot
import os
import time
import tempfile
import sys
import shutil
import json
import pickle

sys.path += ["../hepaccelerate"]
import hepaccelerate

from dask.distributed import Client, get_worker

In [None]:
# Create the Dask distributed client used by every later cell. No scheduler
# address is passed, so the cluster it attaches to is whatever Client() picks
# by default in this environment.
cl = Client()
# Bare last expression: shows the client's repr as the cell output.
cl

In [None]:
from distributed.diagnostics.plugin import WorkerPlugin

class SandboxPlugin(WorkerPlugin):
    """Worker plugin that unpacks the code sandbox on a Dask worker.

    ``setup`` extracts ``sandbox.tgz``, changes the worker process' working
    directory into the extracted ``hepaccelerate-cms`` tree and makes the
    bundled ``hepaccelerate`` and ``coffea`` packages importable.

    Parameters
    ----------
    env : mapping
        Environment handed over by the client. It is only stored on the
        instance; nothing in this class reads it back.
    """

    def __init__(self, env):
        self.env = env

    def setup(self, worker):
        """Unpack the sandbox and enter it.

        Raises
        ------
        RuntimeError
            If ``tar`` exits with a non-zero status. Previously the status was
            ignored, so a failed extraction could silently leave the worker
            running against a stale or missing sandbox.
        """
        self.worker = worker

        # NOTE(review): the archive is looked up relative to the worker's
        # current working directory; changing into worker.local_directory
        # first was tried and is disabled.
        #os.chdir(worker.local_directory)

        status = os.system("tar xf sandbox.tgz")
        if status != 0:
            raise RuntimeError(
                "extracting sandbox.tgz failed with status {0}".format(status)
            )

        os.chdir("hepaccelerate-cms")
        # The entries are relative, so they resolve against the sandbox
        # directory entered just above.
        if "./hepaccelerate" not in sys.path:
            sys.path += ["./hepaccelerate", "./coffea"]

    def transition(self, key, start, finish, *args, **kwargs):
        """Task state changes are deliberately ignored."""
        pass

In [None]:
# Resolve AWS credentials through boto3's default lookup chain and export them
# as environment variables, which the worker functions read from their `env`
# argument.
session = boto3.Session()

credentials = session.get_credentials()
if credentials is None:
    # get_credentials() returns None when no credential source is configured;
    # fail here with a clear message instead of an AttributeError below.
    raise RuntimeError(
        "No AWS credentials found; configure them before running this notebook"
    )

# Take one consistent snapshot so key and secret always belong together, even
# if the underlying credentials are refreshable.
frozen_credentials = credentials.get_frozen_credentials()
access_key = frozen_credentials.access_key
secret_key = frozen_credentials.secret_key

# NOTE(review): only the key pair is exported; a session token (temporary
# credentials) is not forwarded -- confirm long-lived keys are in use.
os.environ["AWS_ACCESS_KEY_ID"] = access_key
os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key

s3 = boto3.client(
    's3',
)

In [None]:
# Rebuild the sandbox archive from the checked-out hepaccelerate-cms tree and
# ship it to freshly restarted workers.

# Remove a stale archive only if one exists, so the cell also runs on a fresh
# environment where no archive has been built yet.
if os.path.exists("/home/jovyan/sandbox.tgz"):
    os.remove("/home/jovyan/sandbox.tgz")

cwd = os.getcwd()
os.chdir("/home/jovyan")
try:
    tar_status = os.system("tar -czf sandbox.tgz hepaccelerate-cms")
finally:
    # Always return to the original directory, even if tar could not be run.
    os.chdir(cwd)
if tar_status != 0:
    # Do not upload a missing or truncated archive to the workers.
    raise RuntimeError(
        "creating sandbox.tgz failed with status {0}".format(tar_status)
    )

plugin = SandboxPlugin(os.environ)
# Restart first, then upload the archive and register the plugin, whose
# setup() unpacks it on the workers.
cl.restart(timeout=120)
cl.upload_file("/home/jovyan/sandbox.tgz")
cl.register_worker_plugin(plugin)

In [None]:
# Fetch the list of input files from the skim bucket into the working directory.
s3.download_file(Bucket="hepaccelerate-hmm-skim-merged", Key="files.txt", Filename="files.txt")

In [None]:
# Shell escape: print the number of lines in the downloaded file list.
!wc -l files.txt

In [None]:
# Read the file list with a context manager so the handle is closed, and skip
# blank lines so an empty entry is never submitted as an S3 key.
with open("files.txt") as file_list:
    fl = [line.strip() for line in file_list if line.strip()]

In [None]:
def get_num_events(fn, env=os.environ):
    """Download one file from the skim bucket and measure it.

    Parameters
    ----------
    fn : str
        Entry from the file list. Its first character is dropped to form the
        S3 key (presumably a leading "/" -- verify against files.txt).
    env : mapping
        Must provide "AWS_ACCESS_KEY_ID" and "AWS_SECRET_ACCESS_KEY".

    Returns
    -------
    dict
        "num_events" (``len()`` of the "Events" object), "file_size" in bytes,
        the wall-clock "t0"/"t1" timestamps and their difference "time_delta".

    Notes
    -----
    The temporary file is created with ``mkstemp`` (``mktemp`` is deprecated
    and race-prone) and removed in a ``finally`` block, so a failed download or
    open no longer leaves it behind on the worker.
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=env["AWS_ACCESS_KEY_ID"],
        aws_secret_access_key=env["AWS_SECRET_ACCESS_KEY"],
    )

    t0 = time.time()
    fd, tmp = tempfile.mkstemp()
    # Only the reserved path is needed; the download writes the content.
    os.close(fd)
    try:
        s3.download_file(
            "hepaccelerate-hmm-skim-merged",
            fn[1:],
            tmp
        )
        tf = uproot.open(tmp)
        tt = tf.get("Events")
        # NOTE(review): len() of the tree is taken as the event count; confirm
        # this holds for the installed uproot version.
        nev = len(tt)
        file_size = os.path.getsize(tmp)
    finally:
        if os.path.exists(tmp):
            os.remove(tmp)
    # Taken after cleanup, as before, so the delta includes removing the file.
    t1 = time.time()

    ret = {
        "num_events": nev,
        "file_size": file_size,
        "time_delta": t1 - t0,
        "t1": t1,
        "t0": t0,
    }

    return ret

In [None]:
from dask.distributed import progress

In [None]:
# Benchmark on the first ten files only: one get_num_events task per file.
first_files = fl[:10]
rets = cl.map(get_num_events, first_files)
# Show a progress display for the submitted futures.
progress(rets)

In [None]:
# Block on each future in submission order and keep the returned dicts.
rets2 = [future.result() for future in rets]

In [None]:
def aggregate(rets):
    """Combine per-file measurement dicts into one summary.

    Each entry must provide "file_size", "t0" and "t1". The result holds the
    summed size plus the earliest "t0" and the latest "t1" across all entries.
    """
    summary = {}
    summary["total_size"] = sum(entry["file_size"] for entry in rets)
    summary["start_time"] = min(entry["t0"] for entry in rets)
    summary["end_time"] = max(entry["t1"] for entry in rets)
    return summary

In [None]:
agg = aggregate(rets2)
# Wall-clock span from the earliest task start to the latest task end (seconds).
elapsed_seconds = agg["end_time"] - agg["start_time"]
# Aggregate throughput: bytes per second scaled down by 1000 twice, i.e. MB/s.
agg["total_size"] / elapsed_seconds / 1000 / 1000

In [None]:
def check_sandbox(args, env=os.environ):
    """Run the sandboxed analysis script on one input file.

    Downloads the file from the skim bucket, writes a one-entry job
    description to a temporary JSON file, runs ``tests/hmm/run_jd.py`` on it
    and returns the unpickled result.

    Parameters
    ----------
    args : tuple
        ``(fn, dataset_name, dataset_era, is_mc, num_chunk, random_seed)``.
        The first character of ``fn`` is dropped to form the S3 key.
    env : mapping
        Must provide "AWS_ACCESS_KEY_ID" and "AWS_SECRET_ACCESS_KEY".

    Returns
    -------
    object
        Whatever ``run_jd.py`` pickled into
        ``<dataset_name>_<dataset_era>_<num_chunk>.pkl``.

    Raises
    ------
    RuntimeError
        If the script exits with a non-zero status. Previously the status was
        ignored and the failure surfaced later as a missing pickle file.

    Notes
    -----
    Temporary paths come from ``mkstemp``/``mkdtemp`` (``mktemp`` is
    deprecated and race-prone) and are removed in a ``finally`` block, so
    failed tasks no longer leave files behind on the worker.
    """
    fn, dataset_name, dataset_era, is_mc, num_chunk, random_seed = args

    # The returned handle is not used; the lookup is kept so that calling this
    # function outside a Dask worker behaves as it did before.
    get_worker()

    s3 = boto3.client(
        's3',
        aws_access_key_id=env["AWS_ACCESS_KEY_ID"],
        aws_secret_access_key=env["AWS_SECRET_ACCESS_KEY"],
    )

    # Reserve the paths up front; only the names are needed, so the
    # descriptors are closed immediately.
    fd, tmproot = tempfile.mkstemp(suffix=".root")
    os.close(fd)
    fd, tmpjson = tempfile.mkstemp(suffix=".json")
    os.close(fd)
    tmpout = tempfile.mkdtemp(suffix="_out")

    try:
        s3.download_file(
            "hepaccelerate-hmm-skim-merged",
            fn[1:],
            tmproot
        )

        job_descriptions = [
            {
                "dataset_name": dataset_name,
                "dataset_era": dataset_era,
                "filenames": [tmproot],
                "is_mc": is_mc,
                "dataset_num_chunk": num_chunk,
                "random_seed": random_seed
            }
        ]
        with open(tmpjson, "w") as fi:
            json.dump(job_descriptions, fi)

        # The relative script path and PYTHONPATH entries resolve against the
        # process' current working directory.
        status = os.system("PYTHONPATH=hepaccelerate:coffea:. python tests/hmm/run_jd.py {0} {1}".format(tmpjson, tmpout))
        if status != 0:
            raise RuntimeError(
                "run_jd.py failed with status {0} for {1}".format(status, fn)
            )

        # The pickle is written by the script launched just above; never point
        # this at files from an untrusted source.
        with open(tmpout + "/{0}_{1}_{2}.pkl".format(dataset_name, dataset_era, num_chunk), "rb") as fi:
            ret = pickle.load(fi)
    finally:
        for path in (tmproot, tmpjson):
            if os.path.exists(path):
                os.remove(path)
        shutil.rmtree(tmpout, ignore_errors=True)

    return ret

In [None]:
# One argument tuple per file, laid out as check_sandbox unpacks it:
# (fn, dataset_name, dataset_era, is_mc, num_chunk, random_seed).
# The file's position in the list serves as both chunk number and random seed.
args = [(fn, "dy_0j", "2016", True, i, i) for i, fn in enumerate(fl[:10])]

In [None]:
# Submit one check_sandbox task per argument tuple, with retries disabled.
futs = cl.map(check_sandbox, args, retries=0)
# Wait for every task and collect the results in submission order.
rets = cl.gather(futs)

In [None]:
# Display the result returned for the first file as a spot check.
rets[0]