Skip to content

Commit

Permalink
Abort submission if Quantum missing value.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleGower committed Apr 29, 2022
1 parent 1c4f7c9 commit 2e30eb2
Show file tree
Hide file tree
Showing 7 changed files with 906 additions and 20 deletions.
2 changes: 2 additions & 0 deletions doc/changes/DM-34265.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* Fix cluster naming bug where variables in clusterTemplate were replaced too early.
* Fix cluster naming bug when neither clusterTemplate nor templateDataId is given.
3 changes: 3 additions & 0 deletions doc/changes/DM-34265.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
* Abort submission if a Quantum is missing a dimension required by the clustering definition.
* Abort submission if clustering definition results in cycles in the ClusteredQuantumGraph.
* Add unit tests for the quantum clustering functions.
6 changes: 5 additions & 1 deletion python/lsst/ctrl/bps/clustered_quantum_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def from_quantum_node(cls, quantum_node, template):

# Use dictionary plus template format string to create name. To avoid
# key errors from generic patterns, use defaultdict.
name = template.format_map(defaultdict(lambda: "", info))
try:
name = template.format_map(defaultdict(lambda: "", info))
except TypeError:
_LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
raise
name = re.sub("_+", "_", name)
_LOG.debug("template name = %s", name)

Expand Down
104 changes: 85 additions & 19 deletions python/lsst/ctrl/bps/quantum_clustering_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from collections import defaultdict

from lsst.pipe.base import NodeId
from networkx import DiGraph, is_directed_acyclic_graph

from . import ClusteredQuantumGraph, QuantaCluster

Expand Down Expand Up @@ -71,7 +72,7 @@ def single_quantum_clustering(config, qgraph, name):
if found:
template = "{node_number}_{label}_" + template_data_id
else:
template = "{node_number:08d}"
template = "{node_number}"
cached_template[qnode.taskDef.label] = template

cluster = QuantaCluster.from_quantum_node(qnode, cached_template[qnode.taskDef.label])
Expand All @@ -91,6 +92,66 @@ def single_quantum_clustering(config, qgraph, name):
return cqgraph


def _check_clusters_tasks(cluster_config, taskGraph):
    """Validate the pipetask lists of the cluster definitions.

    Each task label may be claimed by at most one cluster definition, and
    collapsing the task graph onto the clusters must still yield a directed
    acyclic graph.

    Parameters
    ----------
    cluster_config : `lsst.ctrl.bps.BpsConfig`
        The cluster section from the BPS configuration.
    taskGraph : `lsst.pipe.base.taskGraph`
        Directed graph of tasks.

    Returns
    -------
    task_labels : `set` [`str`]
        Set of task labels from the cluster definitions.

    Raises
    ------
    RuntimeError
        Raised if task label appears in more than one cluster def or
        if there's a cycle in the cluster defs.
    """
    owner_of = {}  # task label -> label of the cluster that contains it
    claimed_labels = set()
    cluster_graph = DiGraph()

    # One node per configured cluster, recording which tasks it claims.
    for cluster_label in cluster_config:
        _LOG.debug("cluster = %s", cluster_label)
        raw_labels = cluster_config[cluster_label]["pipetasks"].split(",")
        for raw_label in raw_labels:
            task_label = raw_label.strip()
            if task_label in claimed_labels:
                raise RuntimeError(
                    f"Task label {task_label} appears in more than one cluster definition. "
                    "Aborting submission."
                )
            claimed_labels.add(task_label)
            owner_of[task_label] = cluster_label
        cluster_graph.add_node(cluster_label)

    # Any task not claimed by a definition becomes a cluster of its own.
    for task in taskGraph:
        if task.label in claimed_labels:
            continue
        owner_of[task.label] = task.label
        cluster_graph.add_node(task.label)

    # Project task dependencies onto clusters, skipping edges that stay
    # inside a single cluster.
    for edge in taskGraph.edges:
        parent_cluster = owner_of[edge[0].label]
        child_cluster = owner_of[edge[1].label]
        if parent_cluster != child_cluster:
            cluster_graph.add_edge(parent_cluster, child_cluster)

    _LOG.debug("clustered_task_graph.edges = %s", list(cluster_graph.edges))

    if is_directed_acyclic_graph(cluster_graph):
        return claimed_labels
    raise RuntimeError("Cluster pipetasks do not create a DAG")


def dimension_clustering(config, qgraph, name):
"""Follow config instructions to make clusters based upon dimensions.
Expand All @@ -115,33 +176,26 @@ def dimension_clustering(config, qgraph, name):
# save mapping in order to create dependencies later
quantum_to_cluster = {}

# save which task labels have been handled
task_labels_seen = set()

cluster_config = config["cluster"]
task_labels = _check_clusters_tasks(cluster_config, qgraph.taskGraph)
for cluster_label in cluster_config:
_LOG.debug("cluster = %s", cluster_label)
cluster_dims = []
if "dimensions" in cluster_config[cluster_label]:
cluster_dims = [d.strip() for d in cluster_config[cluster_label]["dimensions"].split(",")]
_LOG.debug("cluster_dims = %s", cluster_dims)

if "clusterTemplate" in cluster_config[cluster_label]:
template = cluster_config[cluster_label]["clusterTemplate"]
elif cluster_dims:
template = f"{cluster_label}_" + "_".join(f"{{{dim}}}" for dim in cluster_dims)
else:
template = cluster_label
found, template = cluster_config[cluster_label].search("clusterTemplate", opt={"replaceVars": False})
if not found:
if cluster_dims:
template = f"{cluster_label}_" + "_".join(f"{{{dim}}}" for dim in cluster_dims)
else:
template = cluster_label
_LOG.debug("template = %s", template)

cluster_tasks = [pt.strip() for pt in cluster_config[cluster_label]["pipetasks"].split(",")]
for task_label in cluster_tasks:
if task_label in task_labels_seen:
raise ValueError(
f"Task label {task_label} appears in more than one cluster definition. "
"Aborting submission."
)
task_labels_seen.add(task_label)
task_labels.add(task_label)

# Currently getQuantaForTask is a mapping of taskDef to
# Quanta, so quick enough to call repeatedly.
Expand All @@ -157,22 +211,33 @@ def dimension_clustering(config, qgraph, name):
# Gather info for cluster name template into a dictionary.
info = {}

missing_info = set()
data_id_info = qnode.quantum.dataId.byName()
for dim_name in cluster_dims:
_LOG.debug("dim_name = %s", dim_name)
if dim_name in data_id_info:
info[dim_name] = data_id_info[dim_name]
else:
missing_info.add(dim_name)
if equal_dims:
for pair in [pt.strip() for pt in equal_dims.split(",")]:
dim1, dim2 = pair.strip().split(":")
if dim1 in cluster_dims and dim2 in data_id_info:
info[dim1] = data_id_info[dim2]
missing_info.remove(dim1)
elif dim2 in cluster_dims and dim1 in data_id_info:
info[dim2] = data_id_info[dim1]
missing_info.remove(dim2)

info["label"] = cluster_label
_LOG.debug("info for template = %s", info)

if missing_info:
raise RuntimeError(
"Quantum %s (%s) missing dimensions %s required for cluster %s"
% (qnode.nodeId, data_id_info, ",".join(missing_info), cluster_label)
)

# Use dictionary plus template format string to create name.
# To avoid key errors from generic patterns, use defaultdict.
cluster_name = template.format_map(defaultdict(lambda: "", info))
Expand All @@ -197,15 +262,15 @@ def dimension_clustering(config, qgraph, name):

# Assume any task not handled above is supposed to be 1 cluster = 1 quantum
for task_def in qgraph.iterTaskGraph():
if task_def.label not in task_labels_seen:
if task_def.label not in task_labels:
_LOG.info("Creating 1-quantum clusters for task %s", task_def.label)
found, template_data_id = config.search(
"templateDataId", opt={"curvals": {"curr_pipetask": task_def.label}, "replaceVars": False}
)
if found:
template = "{node_number}_{label}_" + template_data_id
else:
template = "{node_number:08d}"
template = "{node_number}"

for qnode in qgraph.getNodesForTask(task_def):
cluster = QuantaCluster.from_quantum_node(qnode, template)
Expand All @@ -222,7 +287,8 @@ def dimension_clustering(config, qgraph, name):
cqgraph.add_dependency(
quantum_to_cluster[parent.nodeId], quantum_to_cluster[child.nodeId]
)
except KeyError as e:
except KeyError as e: # pragma: no cover
# For debugging a problem internal to method
nid = NodeId(e.args[0], qgraph.graphID)
qnode = qgraph.getQuantumNodeByNodeId(nid)

Expand Down
160 changes: 160 additions & 0 deletions tests/cqg_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This file is part of ctrl_bps.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing.
"""

import uuid

from networkx import is_directed_acyclic_graph


def check_cqg(cqg, truth=None):
    """Check a ClusteredQuantumGraph for correctness (used by unit tests).

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to check.
    truth : `dict` [`str`, `Any`], optional
        Information describing what this cluster should look like.  When
        given (and non-empty), it is compared against ``dump_cqg(cqg)``.

    Raises
    ------
    AssertionError
        Raised if the cluster graph has a cycle, a quantum node id appears
        in more than one cluster, the number of clustered quanta differs
        from the size of the QuantumGraph, or the graph does not match
        ``truth``.
    """
    # Checks independent of data.

    # The cluster graph must not contain cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph), "Cluster graph is not a directed acyclic graph."

    # Check has all QGraph nodes (include message about duplicate node).
    node_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in node_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId.byName()}) appears more "
                "than once in CQG."
            )
            node_ids.add(id_)
    num_quanta = len(cqg._quantum_graph)
    assert len(node_ids) == num_quanta, (
        f"Mismatch number of quanta: clusters hold {len(node_ids)}, QuantumGraph has {num_quanta}"
    )

    # If given what should be there, check values.
    if truth:
        cqg_info = dump_cqg(cqg)
        compare_cqg_dicts(truth, cqg_info)


def replace_node_name(name, label, dims):
    """Swap a leading node id in a cluster name for a stable placeholder.

    Node ids change every run, which makes testing difficult.  Only names
    whose first underscore-separated piece parses as a UUID are rewritten;
    any other name is returned untouched.

    Parameters
    ----------
    name : `str`
        Cluster name
    label : `str`
        Cluster label
    dims : `dict` [`str`, `Any`]
        Dimension names and values in order to make new name unique.

    Returns
    -------
    name : `str`
        New name of cluster.
    """
    node_part, separator, remainder = name.partition("_")
    try:
        # Raises ValueError when the leading piece is not a UUID.
        uuid.UUID(node_part)
        if separator:
            renamed = f"NODENAME_{remainder}"
        else:
            # The name is nothing but the node id, so build a unique
            # replacement from the label and dimension values.
            renamed = f"NODEONLY_{label}_{dims!s}"
    except ValueError:
        renamed = name
    return renamed


def dump_cqg(cqg):
    """Represent ClusteredQuantumGraph as dictionary for testing.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of ClusteredQuantumGraph with keys
        ``"name"``, ``"nodes"`` (per-cluster label, dims and quanta counts,
        keyed by the run-independent cluster name) and ``"edges"`` (pairs of
        run-independent cluster names).
    """
    non_dim_tags = ["label", "node_number"]
    info = {"name": cqg.name, "nodes": {}}

    # Cluster names may embed node ids, so remember how each original name
    # was rewritten in order to translate the edges below.
    renamed = {}
    for cluster in cqg.clusters():
        dims = {tag: value for tag, value in cluster.tags.items() if tag not in non_dim_tags}
        stable_name = replace_node_name(cluster.name, cluster.label, dims)
        renamed[cluster.name] = stable_name
        info["nodes"][stable_name] = {
            "label": cluster.label,
            "dims": dims,
            "counts": dict(cluster.quanta_counts),
        }

    info["edges"] = [(renamed[edge[0]], renamed[edge[1]]) for edge in cqg._cluster_graph.edges]

    return info


def compare_cqg_dicts(truth, cqg):
    """Compare dicts representing two ClusteredQuantumGraphs.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph.

    Raises
    ------
    AssertionError
        Raised as soon as a discrepancy between the two dicts is found.
    """
    expected_name, actual_name = truth["name"], cqg["name"]
    assert expected_name == actual_name, f"Mismatch name: truth={expected_name}, cqg={actual_name}"

    expected_nodes, actual_nodes = truth["nodes"], cqg["nodes"]
    assert len(expected_nodes) == len(
        actual_nodes
    ), f"Mismatch number of nodes: truth={len(expected_nodes)}, cqg={len(actual_nodes)}"

    # Node attribute key paired with the wording used in its failure message.
    checked_fields = (("label", "label"), ("dims", "dims"), ("counts", "quanta counts"))
    for node_name in expected_nodes:
        assert node_name in actual_nodes, f"Could not find {node_name} in cqg"
        expected = expected_nodes[node_name]
        actual = actual_nodes[node_name]
        for field, wording in checked_fields:
            assert (
                expected[field] == actual[field]
            ), f"Mismatch cluster {wording}: truth={expected[field]}, cqg={actual[field]}"

    # Edge order is irrelevant, so compare as sets.
    assert set(truth["edges"]) == set(
        cqg["edges"]
    ), f"Mismatch edges: truth={truth['edges']}, cqg={cqg['edges']}"

0 comments on commit 2e30eb2

Please sign in to comment.