DM-40441: Rewrite HiPS custom QG builder to inherit from new base class. #913

Merged (1 commit) on Apr 8, 2024
273 changes: 133 additions & 140 deletions python/lsst/pipe/tasks/hips.py
@@ -22,6 +22,7 @@
"""Tasks for making and manipulating HIPS images."""

__all__ = ["HighResolutionHipsTask", "HighResolutionHipsConfig", "HighResolutionHipsConnections",
"HighResolutionHipsQuantumGraphBuilder",
"GenerateHipsTask", "GenerateHipsConfig", "GenerateColorHipsTask", "GenerateColorHipsConfig"]

from collections import defaultdict
@@ -41,9 +42,11 @@

from lsst.sphgeom import RangeSet, HealpixPixelization
from lsst.utils.timer import timeMethod
from lsst.daf.butler import Butler, DataCoordinate, DatasetRef, Quantum
from lsst.daf.butler import Butler
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base.quantum_graph_builder import QuantumGraphBuilder
from lsst.pipe.base.quantum_graph_skeleton import QuantumGraphSkeleton, DatasetKey
import lsst.afw.geom as afwGeom
import lsst.afw.math as afwMath
import lsst.afw.image as afwImage
@@ -331,19 +334,19 @@ def build_quantum_graph_cli(cls, argv):
sys.exit(1)

pipeline = pipeBase.Pipeline.from_uri(args.pipeline)
expanded_pipeline = list(pipeline.toExpandedPipeline())
pipeline_graph = pipeline.to_graph()

if len(expanded_pipeline) != 1:
if len(pipeline_graph.tasks) != 1:
raise RuntimeError(f"Pipeline file {args.pipeline} may only contain one task.")

(task_def,) = expanded_pipeline
(task_node,) = pipeline_graph.tasks.values()

butler = Butler(args.butler_config, collections=args.input)

if args.subparser_name == "segment":
# Do the segmentation
hpix_pixelization = HealpixPixelization(level=args.hpix_build_order)
dataset = task_def.connections.coadd_exposure_handles.name
dataset = task_node.inputs["coadd_exposure_handles"].dataset_type_name
data_ids = set(butler.registry.queryDataIds("tract", datasets=dataset).expanded())
region_pixels = []
for data_id in data_ids:
@@ -378,15 +381,16 @@ def build_quantum_graph_cli(cls, argv):
"time": f"{datetime.now()}",
}

qg = cls.build_quantum_graph(
task_def,
butler.registry,
args.hpix_build_order,
build_ranges,
builder = HighResolutionHipsQuantumGraphBuilder(
pipeline_graph,
butler,
input_collections=args.input,
output_run=args.output_run,
constraint_order=args.hpix_build_order,
constraint_ranges=build_ranges,
where=args.where,
collections=args.input,
metadata=metadata,
)
qg = builder.build(metadata, attach_datastore_records=True)
qg.saveUri(args.save_qgraph)
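
For orientation, here is a minimal sketch of driving the new builder programmatically rather than through the CLI. The repository path, collection names, pipeline file, and pixel indices are hypothetical, and `RangeSet.insert` is assumed from `lsst.sphgeom` (the CLI code that assembles `build_ranges` is elided from this diff):

    from lsst.daf.butler import Butler
    from lsst.pipe.base import Pipeline
    from lsst.sphgeom import RangeSet

    # Load the single-task HiPS pipeline and convert it to a PipelineGraph.
    pipeline_graph = Pipeline.from_uri("hips.yaml").to_graph()  # hypothetical pipeline file
    butler = Butler("/repo/main", collections=["HSC/runs/RC2"])  # hypothetical repo/collections

    # Collect the constraint pixels (HEALPix NEST indices at the constraint
    # order) into a RangeSet.
    build_ranges = RangeSet()
    for pixel in (42, 43, 57):  # hypothetical coarse-order pixel indices
        build_ranges.insert(pixel)

    builder = HighResolutionHipsQuantumGraphBuilder(
        pipeline_graph,
        butler,
        input_collections=["HSC/runs/RC2"],
        output_run="u/example/hips",  # hypothetical output RUN collection
        constraint_order=7,
        constraint_ranges=build_ranges,
    )
    qg = builder.build({"output_run": "u/example/hips"}, attach_datastore_records=True)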

@classmethod
@@ -488,95 +492,101 @@ def _make_cli_parser(cls):

return parser

@classmethod
def build_quantum_graph(
cls,
task_def,
registry,

class HighResolutionHipsQuantumGraphBuilder(QuantumGraphBuilder):
"""A custom a `lsst.pipe.base.QuantumGraphBuilder` for running
`HighResolutionHipsTask` only.

This is a temporary workaround for incomplete butler query support for
HEALPix dimensions.

Member
Is there a ticket number for the real solution?

Member Author
No, and I'm not sure how I want to do it yet; there are a few options but none seem terribly close. Might be best to drop the "temporary" instead.

Parameters
----------
pipeline_graph : `lsst.pipe.base.PipelineGraph`
Pipeline graph with exactly one task, which must be a configuration
of `HighResolutionHipsTask`.
butler : `lsst.daf.butler.Butler`
Client for the butler data repository. May be read-only.
input_collections : `str` or `Iterable` [ `str` ], optional
Collection or collections to search for input datasets, in order.
If not provided, ``butler.collections`` will be searched.
output_run : `str`, optional
Name of the output collection. If not provided, ``butler.run`` will
be used.
constraint_order : `int`
HEALPix order used to constrain which quanta are generated, via
``constraint_indices``. This should be a coarser grid (smaller
order) than the order used for the task's quantum and output data
IDs, and ideally something between the spatial scale of a patch and
the data repository's "common skypix" system (usually ``htm7``).
constraint_ranges : `lsst.sphgeom.RangeSet`
RangeSet that describes constraint pixels (HEALPix NEST, with order
``constraint_order``) to constrain generated quanta.
where : `str`, optional
A boolean `str` expression of the form accepted by
`Registry.queryDatasets` to constrain input datasets. This may
contain a constraint on tracts, patches, or bands, but not HEALPix
indices. Constraints on tracts and patches should usually be
unnecessary, however: existing coadds that overlap the given
HEALPix indices will be selected without such a constraint, and
providing one may reject some that should normally be included.
"""

def __init__(
self,
pipeline_graph,
butler,
*,
input_collections=None,
output_run=None,
constraint_order,
constraint_ranges,
where=None,
collections=None,
metadata=None,
where="",
):
"""Generate a `QuantumGraph` for running just this task.

This is a temporary workaround for incomplete butler query support for
HEALPix dimensions.
super().__init__(pipeline_graph, butler, input_collections=input_collections, output_run=output_run)
self.constraint_order = constraint_order
self.constraint_ranges = constraint_ranges
self.where = where

Parameters
----------
task_def : `lsst.pipe.base.TaskDef`
Task definition.
registry : `lsst.daf.butler.Registry`
Client for the butler database. May be read-only.
constraint_order : `int`
HEALPix order used to constrain which quanta are generated, via
``constraint_indices``. This should be a coarser grid (smaller
order) than the order used for the task's quantum and output data
IDs, and ideally something between the spatial scale of a patch and
the data repository's "common skypix" system (usually ``htm7``).
constraint_ranges : `lsst.sphgeom.RangeSet`
RangeSet which describes constraint pixels (HEALPix NEST, with order
constraint_order) to constrain generated quanta.
where : `str`, optional
A boolean `str` expression of the form accepted by
`Registry.queryDatasets` to constrain input datasets. This may
contain a constraint on tracts, patches, or bands, but not HEALPix
indices. Constraints on tracts and patches should usually be
unnecessary, however: existing coadds that overlap the given
HEALPix indices will be selected without such a constraint, and
providing one may reject some that should normally be included.
collections : `str` or `Iterable` [ `str` ], optional
Collection or collections to search for input datasets, in order.
If not provided, ``registry.defaults.collections`` will be
searched.
metadata : `dict` [ `str`, `Any` ]
Graph metadata. It is required to contain "output_run" key with the
name of the output RUN collection.
"""
config = task_def.config
def process_subgraph(self, subgraph):
# Docstring inherited.
(task_node,) = subgraph.tasks.values()

dataset_types = pipeBase.PipelineDatasetTypes.fromPipeline(pipeline=[task_def], registry=registry)
# Since we know this is the only task in the pipeline, we know there
# is only one overall input and one overall output.
(input_dataset_type,) = dataset_types.inputs

# Extract the main output dataset type (which needs multiple
# DatasetRefs, and tells us the output HPX level), and make a set of
# what remains for more mechanical handling later.
output_dataset_type = dataset_types.outputs[task_def.connections.hips_exposures.name]
incidental_output_dataset_types = dataset_types.outputs.copy()
incidental_output_dataset_types.remove(output_dataset_type)
# is only one overall input and one regular output.
(input_dataset_type_node,) = subgraph.inputs_of(task_node.label).values()
assert input_dataset_type_node is not None, "PipelineGraph should be resolved by base class."
(output_edge,) = task_node.outputs.values()
output_dataset_type_node = subgraph.dataset_types[output_edge.parent_dataset_type_name]
(hpx_output_dimension,) = (
registry.dimensions.skypix_dimensions[d] for d in output_dataset_type.dimensions.skypix.names
self.butler.dimensions.skypix_dimensions[d]
for d in output_dataset_type_node.dimensions.skypix.names
)

constraint_hpx_pixelization = registry.dimensions[f"healpix{constraint_order}"].pixelization
common_skypix_name = registry.dimensions.commonSkyPix.name
common_skypix_pixelization = registry.dimensions.commonSkyPix.pixelization
constraint_hpx_pixelization = (
self.butler.dimensions.skypix_dimensions[f"healpix{self.constraint_order}"].pixelization
)
common_skypix_name = self.butler.dimensions.commonSkyPix.name
common_skypix_pixelization = self.butler.dimensions.commonSkyPix.pixelization

# We will need all the pixels at the quantum resolution as well
task_dimensions = registry.dimensions.conform(task_def.connections.dimensions)
(hpx_dimension,) = (
registry.dimensions.skypix_dimensions[d] for d in task_dimensions.names if d != "band"
self.butler.dimensions.skypix_dimensions[d] for d in task_node.dimensions.names if d != "band"
)
hpx_pixelization = hpx_dimension.pixelization
if hpx_pixelization.level < self.constraint_order:
raise ValueError(f"Quantum order {hpx_pixelization.level} must be < {self.constraint_order}")
hpx_ranges = self.constraint_ranges.scaled(4**(hpx_pixelization.level - self.constraint_order))

if hpx_pixelization.level < constraint_order:
raise ValueError(f"Quantum order {hpx_pixelization.level} must be < {constraint_order}")
hpx_ranges = constraint_ranges.scaled(4**(hpx_pixelization.level - constraint_order))

# We can be generous in looking for pixels here, because we constraint by actual
# patch regions below.
# We can be generous in looking for pixels here, because we constrain
# by actual patch regions below.
common_skypix_ranges = RangeSet()
for begin, end in constraint_ranges:
for begin, end in self.constraint_ranges:
for hpx_index in range(begin, end):
constraint_hpx_region = constraint_hpx_pixelization.pixel(hpx_index)
common_skypix_ranges |= common_skypix_pixelization.envelope(constraint_hpx_region)

# To keep the query from getting out of hand (and breaking) we simplify until we have fewer
# than 100 ranges which seems to work fine.
# To keep the query from getting out of hand (and breaking) we simplify
# until we have fewer than 100 ranges, which seems to work fine.
for simp in range(1, 10):
if len(common_skypix_ranges) < 100:
break
@@ -596,24 +606,26 @@ def build_quantum_graph(
where_terms.append(f"({common_skypix_name} >= cpx{n}a AND {common_skypix_name} <= cpx{n}b)")
bind[f"cpx{n}a"] = begin
bind[f"cpx{n}b"] = stop
if where is None:
if not self.where:
Member
I know this is pre-existing but you are always calculating the join of where_terms.

    where = " OR ".join(where_terms)
    if self.where:
        where = f"({self.where}) AND ({where})"

might be better?

where = " OR ".join(where_terms)
else:
where = f"({where}) AND ({' OR '.join(where_terms)})"
where = f"({self.where}) AND ({' OR '.join(where_terms)})"
# Query for input datasets with this constraint, and ask for expanded
# data IDs because we want regions. Immediately group this by patch so
# we don't do later geometric stuff n_bands more times than we need to.
input_refs = registry.queryDatasets(
input_dataset_type,
input_refs = self.butler.registry.queryDatasets(
input_dataset_type_node.dataset_type,
where=where,
findFirst=True,
collections=collections,
collections=self.input_collections,
bind=bind
).expanded()
inputs_by_patch = defaultdict(set)
patch_dimensions = registry.dimensions.conform(["patch"])
patch_dimensions = self.butler.dimensions.conform(["patch"])
for input_ref in input_refs:
inputs_by_patch[input_ref.dataId.subset(patch_dimensions)].add(input_ref)
dataset_key = DatasetKey(input_ref.datasetType.name, input_ref.dataId.required_values)
self.existing_datasets.inputs[dataset_key] = input_ref
inputs_by_patch[input_ref.dataId.subset(patch_dimensions)].add(dataset_key)
if not inputs_by_patch:
message_body = "\n".join(input_refs.explain_no_results())
raise RuntimeError(f"No inputs found:\n{message_body}")
@@ -622,66 +634,47 @@
# that overlap each one. Use that to associate inputs with output
# pixels, but only for the output pixels we've already identified.
inputs_by_hpx = defaultdict(set)
for patch_data_id, input_refs_for_patch in inputs_by_patch.items():
for patch_data_id, input_keys_for_patch in inputs_by_patch.items():
patch_hpx_ranges = hpx_pixelization.envelope(patch_data_id.region)
for begin, end in patch_hpx_ranges & hpx_ranges:
for hpx_index in range(begin, end):
inputs_by_hpx[hpx_index].update(input_refs_for_patch)
# Iterate over the dict we just created and create the actual quanta.
quanta = []
output_run = metadata["output_run"]
for hpx_index, input_refs_for_hpx_index in inputs_by_hpx.items():
inputs_by_hpx[hpx_index].update(input_keys_for_patch)

# Iterate over the dict we just created and create preliminary quanta.
skeleton = QuantumGraphSkeleton([task_node.label])
for hpx_index, input_keys_for_hpx_index in inputs_by_hpx.items():
# Group inputs by band.
input_refs_by_band = defaultdict(list)
for input_ref in input_refs_for_hpx_index:
input_refs_by_band[input_ref.dataId["band"]].append(input_ref)
input_keys_by_band = defaultdict(list)
for input_key in input_keys_for_hpx_index:
input_ref = self.existing_datasets.inputs[input_key]
input_keys_by_band[input_ref.dataId["band"]].append(input_key)
# Iterate over bands to make quanta.
for band, input_refs_for_band in input_refs_by_band.items():
data_id = registry.expandDataId({hpx_dimension: hpx_index, "band": band})

for band, input_keys_for_band in input_keys_by_band.items():
data_id = self.butler.registry.expandDataId({hpx_dimension.name: hpx_index, "band": band})
quantum_key = skeleton.add_quantum_node(task_node.label, data_id)
# Add inputs to the skeleton.
skeleton.add_input_edges(quantum_key, input_keys_for_band)
# Add the regular outputs.
hpx_pixel_ranges = RangeSet(hpx_index)
hpx_output_ranges = hpx_pixel_ranges.scaled(4**(config.hips_order - hpx_pixelization.level))
output_data_ids = []
hpx_output_ranges = hpx_pixel_ranges.scaled(
4**(task_node.config.hips_order - hpx_pixelization.level)
Member
Is there a note anywhere explaining where the 4 comes from? (it turns up in a couple of places).

Member Author
It's the number of healpixels at level N+1 in each healpixel at level N. I'll add a comment somewhere when I get a chance.

)
for begin, end in hpx_output_ranges:
for hpx_output_index in range(begin, end):
output_data_ids.append(
registry.expandDataId({hpx_output_dimension: hpx_output_index, "band": band})
dataset_key = skeleton.add_dataset_node(
output_dataset_type_node.name,
self.butler.registry.expandDataId(
{hpx_output_dimension: hpx_output_index, "band": band}
),
)
outputs = {
dt: [DatasetRef(dt, data_id, run=output_run)] for dt in incidental_output_dataset_types
}
outputs[output_dataset_type] = [DatasetRef(output_dataset_type, data_id, run=output_run)
for data_id in output_data_ids]
quanta.append(
Quantum(
taskName=task_def.taskName,
taskClass=task_def.taskClass,
dataId=data_id,
initInputs={},
inputs={input_dataset_type: input_refs_for_band},
outputs=outputs,
)
)

if len(quanta) == 0:
raise RuntimeError("Given constraints yielded empty quantum graph.")

# Define initOutputs refs.
empty_data_id = DataCoordinate.make_empty(registry.dimensions)
init_outputs = {}
global_init_outputs = []
if config_dataset_type := dataset_types.initOutputs.get(task_def.configDatasetName):
init_outputs[task_def] = [DatasetRef(config_dataset_type, empty_data_id, run=output_run)]
packages_dataset_name = pipeBase.PipelineDatasetTypes.packagesDatasetName
if packages_dataset_type := dataset_types.initOutputs.get(packages_dataset_name):
global_init_outputs.append(DatasetRef(packages_dataset_type, empty_data_id, run=output_run))

return pipeBase.QuantumGraph(
quanta={task_def: quanta},
initOutputs=init_outputs,
globalInitOutputs=global_init_outputs,
metadata=metadata,
)
skeleton.add_output_edge(quantum_key, dataset_key)
# Add auxiliary outputs (log, metadata).
for write_edge in task_node.iter_all_outputs():
if write_edge.connection_name == output_edge.connection_name:
continue
dataset_key = skeleton.add_dataset_node(write_edge.parent_dataset_type_name, data_id)
skeleton.add_output_edge(quantum_key, dataset_key)
return skeleton
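
Picking up the review thread above about the factor of 4: each HEALPix pixel at order N contains exactly 4 pixels at order N + 1, so a pixel at the constraint order spans 4**(delta) consecutive NEST indices at an order delta levels finer. A minimal, self-contained illustration using only `RangeSet` operations already relied on in this file; the orders and pixel index are made up:

    from lsst.sphgeom import RangeSet

    constraint_order = 7
    quantum_order = 11  # must be >= constraint_order, per the check above

    # One coarse pixel with NEST index 42 at the constraint order.
    ranges = RangeSet(42)

    # Scaling by 4**(quantum_order - constraint_order) = 256 maps it to the
    # 256 fine-order indices it contains: the half-open range [10752, 11008).
    scaled = ranges.scaled(4**(quantum_order - constraint_order))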


class HipsPropertiesSpectralTerm(pexConfig.Config):