# Set parameters

In [None]:
# Logging: INFO-level output when True, WARNING-level otherwise.
VERBOSE = True

# Dataset selection and ServiceX behaviour.
DATASET = ""  # key for determine_dataset: "mc_1TB", "multi_1TB" - 4 dids, "all"
IGNORE_CACHE = False  # forwarded to ServiceX as IgnoreLocalCache
DOWNLOAD_SX_RESULT = False  # True -> "LocalCache" delivery, False -> "SignedURLs"

# Dask execution.
DASK_CLIENT = "local"  # "local" or "scheduler"
DASK_SCHEDULER_ADDRESS = ""  # only used when DASK_CLIENT == "scheduler"
DASK_REPORT = False  # write dask-report.html around the main compute
PROFILE = False  # run main() under cProfile


In [1]:
import cProfile
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import awkward as ak
import dask
import dask_awkward as dak
import uproot
from dask.distributed import Client, LocalCluster, performance_report
from datasets import determine_dataset

import servicex as sx

In [2]:
class ElapsedFormatter(logging.Formatter):
    """Logging formatter that stamps each record with time since construction.

    An ``elapsed`` attribute (seconds, zero-padded to width 9 with four
    decimals) is attached to every record before formatting, so messages
    show how long the run has been going - handy for judging how long
    each operation took.
    """

    def __init__(self, fmt="%(elapsed)s - %(levelname)s - %(name)s - %(message)s"):
        super().__init__(fmt)
        # Reference point for every elapsed-time stamp this formatter produces.
        self._start_time = time.time()

    def format(self, record):
        seconds_since_start = time.time() - self._start_time
        # Must be set before delegating so "%(elapsed)s" in fmt resolves.
        record.elapsed = "{:0>9.4f}".format(seconds_since_start)
        return super().format(record)


In [3]:
def run_query(input_filenames=None):
    """Read a fixed set of branches from one file and apply an electron-pT cut.

    The ``uproot``/``awkward`` imports are repeated inside the body,
    presumably because this function is handed to ServiceX (``"python"``
    codegen) and executed on its own - keep it self-contained (no module-level
    helpers or names).

    Args:
        input_filenames: Path or URL of the input ROOT file; its
            ``CollectionTree`` tree is read.

    Returns:
        ``{"servicex_reduction": table}`` where ``table`` is an awkward array
        of the branches that could be read, restricted to events passing the
        electron selection below.
    """
    ELECTRON_PT_THRESHOLD = 100e3

    import uproot
    import awkward as ak

    BRANCH_LIST = [
        "AnalysisJetsAuxDyn.pt", "AnalysisJetsAuxDyn.eta", "AnalysisJetsAuxDyn.phi", "AnalysisJetsAuxDyn.m",
        "AnalysisElectronsAuxDyn.pt", "AnalysisElectronsAuxDyn.eta", "AnalysisElectronsAuxDyn.phi",
        "AnalysisElectronsAuxDyn.m", "AnalysisMuonsAuxDyn.pt", "AnalysisMuonsAuxDyn.eta",
        "AnalysisMuonsAuxDyn.phi", "AnalysisJetsAuxDyn.EnergyPerSampling", "AnalysisJetsAuxDyn.SumPtTrkPt500",
        "AnalysisJetsAuxDyn.TrackWidthPt1000", "PrimaryVerticesAuxDyn.z", "PrimaryVerticesAuxDyn.x",
        "PrimaryVerticesAuxDyn.y", "AnalysisJetsAuxDyn.NumTrkPt500", "AnalysisJetsAuxDyn.NumTrkPt1000",
        "AnalysisJetsAuxDyn.SumPtChargedPFOPt500", "AnalysisJetsAuxDyn.Timing",
        "AnalysisJetsAuxDyn.JetConstitScaleMomentum_eta", "AnalysisJetsAuxDyn.ActiveArea4vec_eta",
        "AnalysisJetsAuxDyn.DetectorEta", "AnalysisJetsAuxDyn.JetConstitScaleMomentum_phi",
        "AnalysisJetsAuxDyn.ActiveArea4vec_phi", "AnalysisJetsAuxDyn.JetConstitScaleMomentum_m",
        "AnalysisJetsAuxDyn.JetConstitScaleMomentum_pt", "AnalysisJetsAuxDyn.EMFrac",
        "AnalysisJetsAuxDyn.Width", "AnalysisJetsAuxDyn.ActiveArea4vec_m", "AnalysisJetsAuxDyn.ActiveArea4vec_pt",
        "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_TracksWidth", "AnalysisJetsAuxDyn.PSFrac",
        "AnalysisJetsAuxDyn.JVFCorr", "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_TracksC1",
        "AnalysisJetsAuxDyn.DFCommonJets_fJvt", "AnalysisJetsAuxDyn.DFCommonJets_QGTagger_NTracks",
        "AnalysisJetsAuxDyn.GhostMuonSegmentCount", "AnalysisMuonsAuxDyn.muonSegmentLinks",
        "AnalysisMuonsAuxDyn.msOnlyExtrapolatedMuonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.extrapolatedMuonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.inDetTrackParticleLink", "AnalysisMuonsAuxDyn.muonSpectrometerTrackParticleLink",
        "AnalysisMuonsAuxDyn.momentumBalanceSignificance", "AnalysisMuonsAuxDyn.topoetcone20_CloseByCorr",
        "AnalysisMuonsAuxDyn.scatteringCurvatureSignificance", "AnalysisMuonsAuxDyn.scatteringNeighbourSignificance",
        "AnalysisMuonsAuxDyn.neflowisol20_CloseByCorr", "AnalysisMuonsAuxDyn.topoetcone20",
        "AnalysisMuonsAuxDyn.topoetcone30", "AnalysisMuonsAuxDyn.topoetcone40", "AnalysisMuonsAuxDyn.neflowisol20",
        "AnalysisMuonsAuxDyn.segmentDeltaEta", "AnalysisMuonsAuxDyn.DFCommonJetDr",
        "AnalysisMuonsAuxDyn.combinedTrackParticleLink", "AnalysisMuonsAuxDyn.InnerDetectorPt",
        "AnalysisMuonsAuxDyn.MuonSpectrometerPt", "AnalysisMuonsAuxDyn.clusterLink",
        "AnalysisMuonsAuxDyn.spectrometerFieldIntegral", "AnalysisElectronsAuxDyn.ambiguityLink",
        "AnalysisMuonsAuxDyn.EnergyLoss", "AnalysisJetsAuxDyn.NNJvtPass", "AnalysisElectronsAuxDyn.topoetcone20_CloseByCorr",
        "AnalysisElectronsAuxDyn.topoetcone20ptCorrection", "AnalysisElectronsAuxDyn.topoetcone20",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt500_CloseByCorr",
        "AnalysisElectronsAuxDyn.DFCommonElectronsECIDSResult", "AnalysisElectronsAuxDyn.neflowisol20",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt500", "AnalysisMuonsAuxDyn.ptcone40",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt1000_CloseByCorr",
        "AnalysisMuonsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVA_pt1000", "AnalysisMuonsAuxDyn.ptvarcone40",
        "AnalysisElectronsAuxDyn.f1", "AnalysisMuonsAuxDyn.ptcone20_Nonprompt_All_MaxWeightTTVA_pt500",
        "PrimaryVerticesAuxDyn.vertexType", "AnalysisMuonsAuxDyn.ptvarcone30", "AnalysisMuonsAuxDyn.ptcone30",
        "AnalysisMuonsAuxDyn.ptcone20_Nonprompt_All_MaxWeightTTVA_pt1000",
        "AnalysisElectronsAuxDyn.ptvarcone30_Nonprompt_All_MaxWeightTTVALooseCone_pt500", "AnalysisMuonsAuxDyn.CaloLRLikelihood",
    ]

    # frozenset gives O(1) membership tests for the name filter, which may be
    # called once per branch in the file, instead of scanning the list.
    wanted_branches = frozenset(BRANCH_LIST)

    def filter_name(name):
        return name in wanted_branches

    with uproot.open({input_filenames: "CollectionTree"}, filter_name=filter_name) as tree:
        branches = {}
        for branch_name in BRANCH_LIST:
            try:
                branches[branch_name] = tree[branch_name].array()
                type_str = str(branches[branch_name].type)
                if "m_persKey" in type_str:
                    # Link-style records (m_persKey/m_persIndex): keep only the
                    # count of non-zero m_persIndex entries.
                    branches[branch_name] = ak.count_nonzero(
                        branches[branch_name]["m_persIndex"], axis=-1
                    )
                elif "var * var *" in type_str:
                    # Doubly-jagged branches are summed over axis 1.
                    branches[branch_name] = ak.sum(branches[branch_name], axis=1)
            except Exception:
                # Best effort: skip branches this file lacks or that fail to
                # read/reduce (if only the reduction fails, the raw array
                # stored above is kept). Unlike a bare `except:`, this still
                # lets KeyboardInterrupt/SystemExit propagate.
                pass
        ak_table = ak.Array(branches)
        # Keep events where ak.min over the per-electron (pt > threshold) flags
        # is True, i.e. every electron passes the cut.
        # NOTE(review): with ak.min's default identity masking, events with no
        # electrons presumably yield None here - confirm how they should be treated.
        ak_table = ak_table[
            ak.min(ak_table["AnalysisElectronsAuxDyn.pt"] > ELECTRON_PT_THRESHOLD, axis=1)
        ]

        return {"servicex_reduction": ak_table}

In [4]:
def query_servicex(
    ignore_cache: bool,
    num_files: int,
    ds_names: List[str],
    download: bool,
    query: Tuple[Any, str],
) -> Dict[str, List[str]]:
    """Submit one ServiceX sample per dataset and return the delivered files.

    Args:
        ignore_cache: Forwarded to every sample as ``IgnoreLocalCache``.
        num_files: ``NFiles`` for every sample; 0 is logged as "full dataset".
        ds_names: Rucio DIDs, one sample each (named ``speed_test_<DID>``,
            truncated to 128 characters).
        download: True selects "LocalCache" delivery, False "SignedURLs".
        query: ``(query_object, codegen_name)`` pair.

    Returns:
        The result of ``sx.deliver``: lists of paths or URLs (local files or
        signed URLs) for the output ROOT files, presumably keyed by sample name.

    Raises:
        AssertionError: if ``sx.deliver`` returns None.
    """
    logging.info("Building ServiceX query")

    # Do the query.
    # TODO: Where is the enum that does DeliveryEnum come from?
    # TODO: Why does `Sample` fail type checking - that type ignore has already hidden one bug!
    # TODO: If I change Name after running, cache seems to fail (name doesn't track).
    # TODO: If you change the name of the item you'll get a multiple cache hit!
    # TODO: `servicex cache list` doesn't work and can't figure out how to make it work.
    # TODO: servicex_query_cache.json is being ignored (feature?)
    # TODO: Why does OutputFormat and delivery not work as enums? And fail typechecking with
    #       strings?
    # TODO: If some of these submissions work and others do not, we lose the ability to track the
    #       ones we fired off.
    #       an example is a title that is longer than 128 characters causes an immediate crash -
    #       but other queries
    #       already worked. Cache recovery @ the server would mean this wasn't important.

    query_object, codegen = query
    delivery_mode = "LocalCache" if download else "SignedURLs"
    general_section = sx.General(
        ServiceX="uc-af",
        Codegen=codegen,
        OutputFormat=sx.ResultFormat.root,  # type: ignore
        Delivery=delivery_mode,  # type: ignore
    )

    sample_list = []
    for did in ds_names:
        # Names longer than 128 characters crash the submission (see TODO above).
        sample_name = f"speed_test_{did}"[:128]
        sample_list.append(
            sx.Sample(
                Name=sample_name,
                RucioDID=did,
                Query=query_object,
                NFiles=num_files,
                IgnoreLocalCache=ignore_cache,
            )  # type: ignore
        )

    spec = sx.ServiceXSpec(General=general_section, Sample=sample_list)

    for did in ds_names:
        logging.info(f"Querying dataset {did}")
    scope_message = (
        "Running on the full dataset."
        if num_files == 0
        else f"Running on {num_files} files of dataset."
    )
    logging.info(scope_message)

    logging.info("Starting ServiceX query")
    delivered = sx.deliver(spec)
    assert delivered is not None
    return delivered

In [6]:
def calculate_total_count(
    ds_name: str, steps_per_file: int, files: List[str]
) -> Tuple[Any, Any]:
    """Build (but do not run) the dask graph that counts non-zero entries.

    Every jagged field of the ``servicex_reduction`` tree is reduced with
    repeated ``ak.count_nonzero`` until one value per event remains. These
    per-event counts are summed across fields and then counted along axis 0.
    Scalar (non-jagged) fields are skipped.

    Args:
        ds_name: Dataset label, used only in log messages.
        steps_per_file: Number of chunks ``uproot.dask`` splits each file into.
        files: Paths/URLs of the ROOT files to read.

    Returns:
        ``(report, total_count)``: the lazy read-error report from
        ``uproot.dask`` and the lazy total count; both still need computing.
        If no jagged field exists, ``total_count`` is the plain int 0.
    """
    data, report_to_be = uproot.dask(
        {f: "servicex_reduction" for f in files},
        open_files=False,
        steps_per_file=steps_per_file,
        allow_read_errors_with_report=True,
    )

    # Now, do the counting.
    # The straightforward way to do this leads to a very large dask graph. We can
    # do a little prep work here and make it more clean.
    logging.debug(
        f"{ds_name}: Generating the dask compute graph for"
        f" {len(data.fields)} fields"  # type: ignore
    )

    total_count = 0
    assert isinstance(data, dak.Array)  # type: ignore
    for field in data.fields:
        logging.debug(f"{ds_name}: Counting field {field}")
        if not str(data[field].type.content).startswith("var"):
            # count_nonzero raises NotImplementedError on scalar leaves such as
            # run-number or event-number, so they do not contribute.
            logging.debug(
                f"{ds_name}: Field {field} is a scalar field. Skipping count."
            )
            continue

        # Reduce the innermost axis repeatedly until one value per event remains.
        count = ak.count_nonzero(data[field], axis=-1)
        for _ in range(count.ndim - 1):  # type: ignore
            count = ak.count_nonzero(count, axis=-1)
        total_count = total_count + count  # type: ignore

    if isinstance(total_count, int):
        # Nothing contributed, so there is no dask graph: return 0 rather than
        # failing on ak.count_nonzero(0) / `.dask` below.
        logging.warning(f"{ds_name}: No jagged fields found; total count is 0.")
        return report_to_be, total_count

    # NOTE(review): this counts events whose summed count is non-zero, not
    # items; ak.sum(total_count, axis=0) would give the flattened item total -
    # confirm which is intended.
    total_count = ak.count_nonzero(total_count, axis=0)

    n_optimized_tasks = len(dask.optimize(total_count)[0].dask)  # type: ignore
    logging.info(
        f"{ds_name}: Number of tasks in the dask graph: optimized: "
        f"{n_optimized_tasks:,} "  # type: ignore
        f"unoptimized: {len(total_count.dask):,}",  # type: ignore
    )

    # Debug aid: uncomment to render the optimized/unoptimized graphs.
    # total_count.visualize(optimize_graph=True)  # type: ignore
    # opt = Path("mydask.png")
    # opt.replace("dask-optimized.png")
    # total_count.visualize(optimize_graph=False)  # type: ignore
    # opt.replace("dask-unoptimized.png")

    return report_to_be, total_count

In [5]:
def main(
    ignore_cache: bool = False,
    num_files: int = 10,
    dask_report: bool = False,
    ds_names: Optional[List[str]] = None,
    download_sx_result: bool = False,
    steps_per_file: int = 3,
    query: Optional[Tuple[Any, str]] = None,
):
    """Match the operations found in the `materialize_branches` notebook.

    Load all the branches from the datasets returned by ServiceX, count the
    non-zero items with dask, and log the result per dataset. Any read errors
    recorded by ``uproot.dask`` are logged as well.

    Args:
        ignore_cache: Passed to ServiceX as ``IgnoreLocalCache``.
        num_files: Files per dataset to request (0 is logged as the full dataset).
        dask_report: If True, wrap the compute in a dask performance report
            written to ``dask-report.html``.
        ds_names: Rucio DIDs to process. Required.
        download_sx_result: True downloads results locally; False uses signed URLs.
        steps_per_file: How many chunks ``uproot.dask`` splits each file into.
        query: ``(query, codegen)`` pair for ServiceX. Required.

    Raises:
        AssertionError: if ``query`` or ``ds_names`` is missing, or a dataset
            returned no files.
    """
    from contextlib import nullcontext

    assert query is not None, "No query provided to run."

    # Make sure there is a file here to save the SX query ID's to
    # improve performance!
    sx_query_ids = Path("./servicex_query_cache.json")
    if not sx_query_ids.exists():
        sx_query_ids.touch()

    assert ds_names is not None
    dataset_files = query_servicex(
        ignore_cache=ignore_cache,
        num_files=num_files,
        ds_names=ds_names,
        download=download_sx_result,
        query=query,
    )

    for ds, files in dataset_files.items():
        logging.info(f"Dataset {ds} has {len(files)} files")
        assert len(files) > 0, "No files found in the dataset"

    # Now build the lazy graphs for everything.
    logging.info(
        f"Using `uproot.dask` to open files (splitting files {steps_per_file} ways)."
    )
    # The 20 steps per file was tuned for this query and 8 CPU's and 32 GB of memory.
    logging.info("Starting build of DASK graphs")
    all_dask_data = {
        k: calculate_total_count(k, steps_per_file, files)
        for k, files in dataset_files.items()
    }
    logging.info("Done building DASK graphs.")

    # Do the calc now.
    logging.info("Computing the total count")
    ds_keys = list(all_dask_data)
    count_tasks = [all_dask_data[k][1] for k in ds_keys]
    report_tasks = [all_dask_data[k][0] for k in ds_keys]
    report_ctx = (
        performance_report(filename="dask-report.html") if dask_report else nullcontext()
    )
    with report_ctx:
        # Counts and read-error reports are computed in one call so the shared
        # file-reading tasks run once; a separate compute for the reports
        # would read every file a second time.
        computed = dask.compute(*count_tasks, *report_tasks)  # type: ignore
    n_ds = len(ds_keys)
    counts = computed[:n_ds]
    reports = computed[n_ds:]

    for k, r in zip(ds_keys, counts):
        logging.info(f"{k}: result = {r:,}")

    # Scan through for any exceptions that happened during the dask processing.
    for k, report_list in zip(ds_keys, reports):
        for process in report_list:
            if process.exception is not None:
                logging.error(
                    f"Exception in process '{process.message}' on file {process.args[0]} "
                    f"for ds {k}"
                )

In [None]:
# Route all log output through one stream handler whose formatter stamps each
# line with the time elapsed since it was created, so people can "track" what
# is going on.
root_logger = logging.getLogger()
handler = logging.StreamHandler()
handler.setFormatter(ElapsedFormatter())

# VERBOSE selects INFO-level output; otherwise only warnings and above appear.
root_logger.setLevel(level=logging.INFO if VERBOSE else logging.WARNING)
root_logger.addHandler(handler)


# Choose the Dask backend. The uproot.dask split factor (`steps_per_file`) is
# set per backend; 1 is used when neither known backend is selected.
steps_per_file = 1
if DASK_CLIENT == "local":
    # Worker count is hard-coded; no better way to size it was found.
    n_workers = 8
    logging.debug(f"Creating local Dask cluster for {n_workers} workers")
    cluster = LocalCluster(
        n_workers=n_workers, processes=False, threads_per_worker=1
    )
    client = Client(cluster)
    steps_per_file = 20
elif DASK_CLIENT == "scheduler":
    logging.debug(f"Connecting to Dask scheduler at {DASK_SCHEDULER_ADDRESS}")
    client = Client(DASK_SCHEDULER_ADDRESS)
    steps_per_file = 2
else:
    # Previously silent: make an unexpected setting visible.
    logging.warning(
        f"Unrecognized DASK_CLIENT {DASK_CLIENT!r}; no distributed client created."
    )


ds_names = determine_dataset(DATASET)

# Build the query: the function handed to ServiceX plus its codegen name.
query = (run_query, "python")

main_kwargs = {
    "ignore_cache": IGNORE_CACHE,
    "dask_report": DASK_REPORT,
    "ds_names": ds_names,
    "download_sx_result": DOWNLOAD_SX_RESULT,
    "steps_per_file": steps_per_file,
    "query": query,
}

# Now run the main function, optionally under cProfile.
if not PROFILE:
    main(**main_kwargs)
else:
    # Profile the call directly (no eval'd statement string), so the
    # arguments are the same ones used in the non-profiled path.
    profiler = cProfile.Profile()
    try:
        profiler.runcall(main, **main_kwargs)
    finally:
        # Like cProfile.run, write the stats even if main() raises.
        profiler.dump_stats("sx_materialize_branches.pstats")
    logging.info("Profiling data saved to `sx_materialize_branches.pstats`")