### Set parameters

In [None]:
VERBOSE = True  # NOTE(review): not read by any other cell in this notebook
# Dataset selection: "mc_1TB", "multi_1TB" (4 DIDs), "data_50TB" or "all".
DATASET = "multi_1TB"
# Files to process per dataset DID; 0 means all files. Kept as an int because
# build_servicex_spec compares it with 0 and passes it to ServiceX as NFiles.
NUM_FILES = 10
IGNORE_CACHE = False  # passed to ServiceX as IgnoreLocalCache
DOWNLOAD_SX_RESULT = True  # True: "LocalCache" delivery; False: "SignedURLs"

DASK_CLIENT = "local"  # "local" or "scheduler"
# Only used when DASK_CLIENT == "scheduler".
DASK_SCHEDULER_ADDRESS = "tcp://dask-kyungeonchoi-56c93432-0.af-jupyter:8786"
DASK_REPORT = False  # True: write dask-report.html for the compute step
STEPS_PER_FILE = 1  # steps uproot.dask splits each file into
STEPS_PER_FILE = 1


### Imports

In [None]:
from typing import Any, List, Tuple
import awkward as ak
import dask
import dask_awkward as dak
from dask.distributed import Client, LocalCluster, performance_report
import servicex as sx
import uproot

from datasets import determine_dataset

### Build ServiceX query

In [None]:
def run_query(input_filenames=None):
    """Skim one input file down to a fixed list of branches for ServiceX.

    This is the transform ``build_servicex_spec`` hands to ServiceX
    (``Function=run_query`` with the ``"python"`` codegen). It reads the
    branches named in ``BRANCH_LIST`` from the ``CollectionTree`` tree,
    simplifies a few awkward types, keeps only the first ``PERCENT`` % of the
    events and returns them under the output tree name ``servicex_reduction``.

    Args:
        input_filenames: Path (or URL) of the single file to read.

    Returns:
        dict: ``{"servicex_reduction": ak.Array}`` with one field per branch
        that could be read and converted.

    Raises:
        ValueError: If ``input_filenames`` is None.
    """
    if input_filenames is None:
        raise ValueError("run_query needs the path of the file to transform")

    # Percentage of events, taken from the start of the file, to keep.
    PERCENT = 1

    # Imported here rather than at module level, presumably because ServiceX
    # runs only this function's source on the transformer - confirm before moving.
    import uproot
    import awkward as ak
    
    BRANCH_LIST = [
        "AnalysisJetsAuxDyn.pt", "AnalysisJetsAuxDyn.eta", "AnalysisJetsAuxDyn.phi", "AnalysisJetsAuxDyn.m",
        "AnalysisElectronsAuxDyn.pt", "AnalysisElectronsAuxDyn.eta", "AnalysisElectronsAuxDyn.phi",
        "AnalysisElectronsAuxDyn.m", "AnalysisMuonsAuxDyn.pt", "AnalysisMuonsAuxDyn.eta",
        "AnalysisMuonsAuxDyn.phi", "AnalysisJetsAuxDyn.EnergyPerSampling", "AnalysisJetsAuxDyn.SumPtTrkPt500",
        "AnalysisJetsAuxDyn.TrackWidthPt1000", "PrimaryVerticesAuxDyn.z", "PrimaryVerticesAuxDyn.x",
        "PrimaryVerticesAuxDyn.y", "AnalysisJetsAuxDyn.NumTrkPt500", "AnalysisJetsAuxDyn.NumTrkPt1000",
        "AnalysisJetsAuxDyn.SumPtChargedPFOPt500", "AnalysisJetsAuxDyn.Timing",
        "AnalysisJetsAuxDyn.JetConstitScaleMomentum_eta", "AnalysisJetsAuxDyn.ActiveArea4vec_eta",
        "AnalysisJetsAuxDyn.DetectorEta", "AnalysisJetsAuxDyn.JetConstitScaleMomentum_phi",
        "AnalysisJetsAuxDyn.ActiveArea4vec_phi", "AnalysisJetsAuxDyn.JetConstitScaleMomentum_m",
        "AnalysisJetsAuxDyn.JetConstitScaleMomentum_pt", "AnalysisJetsAuxDyn.EMFrac",
        "AnalysisJetsAuxDyn.Width", "AnalysisJetsAuxDyn.ActiveArea4vec_m", "AnalysisJetsAuxDyn.ActiveArea4vec_pt",
        "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_TracksWidth", "AnalysisJetsAuxDyn.PSFrac",
        "AnalysisJetsAuxDyn.JVFCorr", "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_TracksC1",
        "AnalysisJetsAuxDyn.DFCommonJets_fJvt", "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_NTracks",
        "AnalysisJetsAuxDyn.GhostMuonSegmentCount", "AnalysisMuonsAuxDyn.muonSegmentLinks",
        "AnalysisMuonsAuxDyn.msOnlyExtrapolatedMuonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.extrapolatedMuonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.inDetTrackParticleLink", "AnalysisMuonsAuxDyn.muonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.momentumBalanceSignificance", "AnalysisMuonsAuxDyn.topoetcone20_CloseByCorr",
        "AnalysisMuonsAuxDyn.scatteringCurvatureSignificance", "AnalysisMuonsAuxDyn.scatteringNeighbourSignificance",
        "AnalysisMuonsAuxDyn.neflowisol20_CloseByCorr", "AnalysisMuonsAuxDyn.topoetcone20",
        "AnalysisMuonsAuxDyn.topoetcone30", "AnalysisMuonsAuxDyn.topoetcone40", "AnalysisMuonsAuxDyn.neflowisol20",
        "AnalysisMuonsAuxDyn.segmentDeltaEta", "AnalysisMuonsAuxDyn.DFCommonJetDr",
        "AnalysisMuonsAuxDyn.combinedTrackParticleLink", "AnalysisMuonsAuxDyn.InnerDetectorPt",
        "AnalysisMuonsAuxDyn.MuonSpectrometerPt", "AnalysisMuonsAuxDyn.clusterLink",
        "AnalysisMuonsAuxDyn.spectrometerFieldIntegral", "AnalysisElectronsAuxDyn.ambiguityLink",
        "AnalysisMuonsAuxDyn.EnergyLoss", "AnalysisJetsAuxDyn.NNJvtPass", "AnalysisElectronsAuxDyn.topoetcone20_CloseByCorr",
        "AnalysisElectronsAuxDyn.topoetcone20ptCorrection", "AnalysisElectronsAuxDyn.topoetcone20",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt500_CloseByCorr",
        "AnalysisElectronsAuxDyn.DFCommonElectronsECIDSResult", "AnalysisElectronsAuxDyn.neflowisol20",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt500", "AnalysisMuonsAuxDyn.ptcone40",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt1000_CloseByCorr",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt1000", "AnalysisMuonsAuxDyn.ptvarcone40",
        "AnalysisElectronsAuxDyn.f1", "AnalysisMuonsAuxDyn.ptcone20_Nonprompt_All_MaxWeightTTVA_pt500",
        "PrimaryVerticesAuxDyn.vertexType", "AnalysisMuonsAuxDyn.ptvarcone30", "AnalysisMuonsAuxDyn.ptcone30",
        "AnalysisMuonsAuxDyn.ptcone20_Nonprompt_All_MaxWeightTTVA_pt1000",
        "AnalysisElectronsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVALooseCone_pt500", "AnalysisMuonsAuxDyn.CaloLRLikelihood"
    ]

    # Only the listed branches are of interest; a set makes the filter O(1).
    wanted = frozenset(BRANCH_LIST)

    def _is_wanted(name):
        return name in wanted

    with uproot.open({input_filenames: "CollectionTree"}, filter_name=_is_wanted) as tree:
        branches = {}
        for name in BRANCH_LIST:
            try:
                array = tree[name].array()
                type_str = str(array.type)
                if "Link" in type_str:
                    # Link-typed branches are replaced by a 0..n-1 placeholder
                    # with one entry per event.
                    array = ak.Array(range(len(array)))
                elif "var * var *" in type_str:
                    # Remove one level of list nesting by summing over axis=1.
                    array = ak.sum(array, axis=1)
            except Exception:
                # Best effort: a branch that is missing from this file or cannot
                # be read/converted is left out instead of failing the whole file.
                # Only Exception is caught, so KeyboardInterrupt/SystemExit propagate.
                continue
            branches[name] = array
        ak_table = ak.Array(branches)

        # Integer arithmetic avoids float rounding in the cut-off.
        n_keep = len(ak_table) * PERCENT // 100
        return {"servicex_reduction": ak_table[:n_keep]}

In [None]:
def build_servicex_spec(
    ignore_cache: bool,
    num_files: int,
    download: bool
):
    """Build the ServiceX request for every Rucio DID of the configured dataset.

    One ``sx.Sample`` is created per DID returned by
    ``determine_dataset(DATASET)``; each is transformed by ``run_query`` using
    the ``"python"`` code generator and written out as ROOT.

    Args:
        ignore_cache: Passed to every sample as ``IgnoreLocalCache``.
        num_files: Files to process per DID; 0 (or None) means all files.
            Numeric strings such as ``"10"`` are converted with ``int``.
        download: True for ``"LocalCache"`` delivery, False for ``"SignedURLs"``.

    Returns:
        sx.ServiceXSpec: the spec to pass to ``sx.deliver``.
    """
    # Accept the config value as an int or a numeric string; with a string,
    # "0" would never compare equal to 0 and the wrong message would print.
    if num_files is not None:
        num_files = int(num_files)

    # List of RucioDIDs
    ds_names = determine_dataset(DATASET)

    # The transform is a plain python function run by ServiceX's python codegen.
    query_function = run_query
    codegen = "python"

    for ds_name in ds_names:
        print(f"Querying dataset {ds_name}")
    if not num_files:
        print("Running on the full dataset.")
    else:
        print(f"Running on {num_files} files of dataset.")

    return sx.ServiceXSpec(
        General=sx.General(
            ServiceX="testing4",
            Codegen=codegen,
            OutputFormat=sx.ResultFormat.root,  # type: ignore
            Delivery=("LocalCache" if download else "SignedURLs"),  # type: ignore
        ),
        Sample=[
            sx.Sample(
                Name=f"{ds_name}"[0:40],
                RucioDID=ds_name,
                Function=query_function,
                NFiles=num_files,
                IgnoreLocalCache=ignore_cache,
            )  # type: ignore
            for ds_name in ds_names
        ],
    )

In [None]:
%time
# Assemble the ServiceX request from the configuration cell's settings.
sx_spec = build_servicex_spec(
    ignore_cache=IGNORE_CACHE,
    num_files=NUM_FILES,
    download=DOWNLOAD_SX_RESULT,
)

### Run ServiceX

In [None]:
%time
# Returns a complete list of paths (be they local or url's) for the root or parquet files.
results = sx.deliver(sx_spec)

# Print what we got
for sample in results:
    print(f"Dataset {sample} contains {len(results[sample])} files")

### Set up the Dask client

In [None]:
%time
if DASK_CLIENT == "local":
    # Do not know how to do it otherwise.
    n_workers = 8
    print("Creating local Dask cluster for {n_workers} workers")
    cluster = LocalCluster(
        n_workers=n_workers, processes=False, threads_per_worker=1
    )
    client = Client(cluster)
    steps_per_file = 20
elif DASK_CLIENT == "scheduler":
    print("Connecting to Dask scheduler at {DASK_SCHEDULER_ADDRESS}")
    client = Client(DASK_SCHEDULER_ADDRESS)
    steps_per_file = 2
else:
    print("Unknown dask client!")

In [None]:
def calculate_total_count(
    ds_name: str, steps_per_file: int, files: List[str]
) -> Tuple[Any, Any]:
    """Build (but do not compute) the dask graph for the non-zero count.

    Every jagged field of the ``servicex_reduction`` tree is reduced to a
    per-event count of non-zero entries. The per-event counts are summed over
    all such fields, and the result is the number of events whose summed count
    is non-zero.

    Args:
        ds_name (str): Sample name; used only to label the printed messages.
        steps_per_file (int): The number of steps to split each file into.
        files (List[str]): The files (local paths or URLs) to read.

    Returns:
        Tuple[Any, Any]: ``(report, total_count)`` - the lazy read-error report
        from ``uproot.dask`` and the lazy total count.

    Raises:
        ValueError: If none of the fields holds per-event lists.
    """
    data, report_to_be = uproot.dask(
        {f: "servicex_reduction" for f in files},
        open_files=False,
        steps_per_file=steps_per_file,
        allow_read_errors_with_report=True,
    )
    assert isinstance(data, dak.Array)  # type: ignore

    # The straightforward way to do this leads to a very large dask graph, so
    # each field is first reduced to one number per event and only those 1-dim
    # arrays are summed.
    print(
        f"{ds_name}: Generating the dask compute graph for"
        f" {len(data.fields)} fields"  # type: ignore
    )

    total_count = None
    n_skipped = 0
    for field in data.fields:
        column = data[field]
        if column.ndim < 2:
            # Already one value per event: reducing over axis=-1 would collapse
            # the whole column to a scalar instead of a per-event count.
            n_skipped += 1
            continue
        per_event = ak.count_nonzero(column, axis=-1)  # reduce innermost
        while per_event.ndim > 1:
            # Deeper nesting still leaves lists per event; keep reducing.
            per_event = ak.count_nonzero(per_event, axis=-1)
        # Accumulate inside the loop so every field contributes, not just the last.
        total_count = per_event if total_count is None else total_count + per_event

    if n_skipped:
        print(f"{ds_name}: Skipped {n_skipped} field(s) with one value per event")
    if total_count is None:
        raise ValueError(
            f"{ds_name}: no jagged fields to count in {len(files)} file(s)"
        )

    total_count = ak.count_nonzero(total_count, axis=0)  # reduce to int

    n_optimized_tasks = len(dask.optimize(total_count)[0].dask)  # type: ignore
    print(
        f"{ds_name}: Number of tasks in the dask graph: optimized: "
        f"{n_optimized_tasks:,} "  # type: ignore
        f"unoptimized: {len(total_count.dask):,}",  # type: ignore
    )

    return report_to_be, total_count

In [None]:
%time
# now materialize everything.
print(
    f"Using `uproot.dask` to open files (splitting files {STEPS_PER_FILE} ways)."
)
# The 20 steps per file was tuned for this query and 8 CPU's and 32 GB of memory.
print("Starting build of DASK graphs")
all_dask_data = {
    k: calculate_total_count(k, STEPS_PER_FILE, files)
    for k, files in results.items()
}
print("Done building DASK graphs.")

# Do the calc now.
print("Computing the total count")
all_tasks = {k: v[1] for k, v in all_dask_data.items()}
if DASK_REPORT:
    with performance_report(filename="dask-report.html"):
        results = dask.compute(*all_tasks.values())  # type: ignore
        result_dict = dict(zip(all_tasks.keys(), results))
else:
    results = dask.compute(*all_tasks.values())  # type: ignore
    result_dict = dict(zip(all_tasks.keys(), results))

for k, r in result_dict.items():
    print(f"{k}: result = {r:,}")

# Scan through for any exceptions that happened during the dask processing.
all_report_tasks = {k: v[0] for k, v in all_dask_data.items()}
all_reports = dask.compute(*all_report_tasks.values())  # type: ignore
for k, report_list in zip(all_report_tasks.keys(), all_reports):
    for process in report_list:
        if process.exception is not None:
            print(
                f"Exception in process '{process.message}' on file {process.args[0]} "
                "for ds {k}"
            )