DM-40441: Rewrite HiPS custom QG builder to inherit from new base class. #913
@@ -22,6 +22,7 @@ | |
"""Tasks for making and manipulating HIPS images.""" | ||
|
||
__all__ = ["HighResolutionHipsTask", "HighResolutionHipsConfig", "HighResolutionHipsConnections", | ||
"HighResolutionHipsQuantumGraphBuilder", | ||
"GenerateHipsTask", "GenerateHipsConfig", "GenerateColorHipsTask", "GenerateColorHipsConfig"] | ||
|
||
from collections import defaultdict | ||
|
@@ -41,9 +42,11 @@ | |
|
||
from lsst.sphgeom import RangeSet, HealpixPixelization | ||
from lsst.utils.timer import timeMethod | ||
from lsst.daf.butler import Butler, DataCoordinate, DatasetRef, Quantum | ||
from lsst.daf.butler import Butler | ||
import lsst.pex.config as pexConfig | ||
import lsst.pipe.base as pipeBase | ||
from lsst.pipe.base.quantum_graph_builder import QuantumGraphBuilder | ||
from lsst.pipe.base.quantum_graph_skeleton import QuantumGraphSkeleton, DatasetKey | ||
import lsst.afw.geom as afwGeom | ||
import lsst.afw.math as afwMath | ||
import lsst.afw.image as afwImage | ||
|
@@ -331,19 +334,19 @@ def build_quantum_graph_cli(cls, argv): | |
sys.exit(1) | ||
|
||
pipeline = pipeBase.Pipeline.from_uri(args.pipeline) | ||
expanded_pipeline = list(pipeline.toExpandedPipeline()) | ||
pipeline_graph = pipeline.to_graph() | ||
|
||
if len(expanded_pipeline) != 1: | ||
if len(pipeline_graph.tasks) != 1: | ||
raise RuntimeError(f"Pipeline file {args.pipeline} may only contain one task.") | ||
|
||
(task_def,) = expanded_pipeline | ||
(task_node,) = pipeline_graph.tasks.values() | ||
|
||
butler = Butler(args.butler_config, collections=args.input) | ||
|
||
if args.subparser_name == "segment": | ||
# Do the segmentation | ||
hpix_pixelization = HealpixPixelization(level=args.hpix_build_order) | ||
dataset = task_def.connections.coadd_exposure_handles.name | ||
dataset = task_node.inputs["coadd_exposure_handles"].dataset_type_name | ||
data_ids = set(butler.registry.queryDataIds("tract", datasets=dataset).expanded()) | ||
region_pixels = [] | ||
for data_id in data_ids: | ||
|
@@ -378,15 +381,16 @@ def build_quantum_graph_cli(cls, argv): | |
"time": f"{datetime.now()}", | ||
} | ||
|
||
qg = cls.build_quantum_graph( | ||
task_def, | ||
butler.registry, | ||
args.hpix_build_order, | ||
build_ranges, | ||
builder = HighResolutionHipsQuantumGraphBuilder( | ||
pipeline_graph, | ||
butler, | ||
input_collections=args.input, | ||
output_run=args.output_run, | ||
constraint_order=args.hpix_build_order, | ||
constraint_ranges=build_ranges, | ||
where=args.where, | ||
collections=args.input, | ||
metadata=metadata, | ||
) | ||
qg = builder.build(metadata, attach_datastore_records=True) | ||
qg.saveUri(args.save_qgraph) | ||
|
||
@classmethod | ||
|
@@ -488,95 +492,101 @@ def _make_cli_parser(cls): | |
|
||
return parser | ||
|
||
@classmethod | ||
def build_quantum_graph( | ||
cls, | ||
task_def, | ||
registry, | ||
|
||
class HighResolutionHipsQuantumGraphBuilder(QuantumGraphBuilder): | ||
"""A custom a `lsst.pipe.base.QuantumGraphBuilder` for running | ||
`HighResolutionHipsTask` only. | ||
|
||
This is a temporary workaround for incomplete butler query support for | ||
HEALPix dimensions. | ||
|
||
Parameters | ||
---------- | ||
pipeline_graph : `lsst.pipe.base.PipelineGraph` | ||
Pipeline graph with exactly one task, which must be a configuration | ||
of `HighResolutionHipsTask`. | ||
butler : `lsst.daf.butler.Butler` | ||
Client for the butler data repository. May be read-only. | ||
input_collections : `str` or `Iterable` [ `str` ], optional | ||
Collection or collections to search for input datasets, in order. | ||
If not provided, ``butler.collections`` will be searched. | ||
output_run : `str`, optional | ||
Name of the output collection. If not provided, ``butler.run`` will | ||
be used. | ||
constraint_order : `int` | ||
HEALPix order used to constrain which quanta are generated, via | ||
``constraint_indices``. This should be a coarser grid (smaller | ||
order) than the order used for the task's quantum and output data | ||
IDs, and ideally something between the spatial scale of a patch and | ||
the data repository's "common skypix" system (usually ``htm7``). | ||
constraint_ranges : `lsst.sphgeom.RangeSet` | ||
RangeSet that describes constraint pixels (HEALPix NEST, with order | ||
``constraint_order``) to constrain generated quanta. | ||
where : `str`, optional | ||
A boolean `str` expression of the form accepted by | ||
`Registry.queryDatasets` to constrain input datasets. This may | ||
contain a constraint on tracts, patches, or bands, but not HEALPix | ||
indices. Constraints on tracts and patches should usually be | ||
unnecessary, however - existing coadds that overlap the given | ||
HEALPix indices will be selected without such a constraint, and | ||
providing one may reject some that should normally be included. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
pipeline_graph, | ||
butler, | ||
*, | ||
input_collections=None, | ||
output_run=None, | ||
constraint_order, | ||
constraint_ranges, | ||
where=None, | ||
collections=None, | ||
metadata=None, | ||
where="", | ||
): | ||
"""Generate a `QuantumGraph` for running just this task. | ||
|
||
This is a temporary workaround for incomplete butler query support for | ||
HEALPix dimensions. | ||
super().__init__(pipeline_graph, butler, input_collections=input_collections, output_run=output_run) | ||
self.constraint_order = constraint_order | ||
self.constraint_ranges = constraint_ranges | ||
self.where = where | ||
|
||
Parameters | ||
---------- | ||
task_def : `lsst.pipe.base.TaskDef` | ||
Task definition. | ||
registry : `lsst.daf.butler.Registry` | ||
Client for the butler database. May be read-only. | ||
constraint_order : `int` | ||
HEALPix order used to contrain which quanta are generated, via | ||
``constraint_indices``. This should be a coarser grid (smaller | ||
order) than the order used for the task's quantum and output data | ||
IDs, and ideally something between the spatial scale of a patch or | ||
the data repository's "common skypix" system (usually ``htm7``). | ||
constraint_ranges : `lsst.sphgeom.RangeSet` | ||
RangeSet which describes constraint pixels (HEALPix NEST, with order | ||
constraint_order) to constrain generated quanta. | ||
where : `str`, optional | ||
A boolean `str` expression of the form accepted by | ||
`Registry.queryDatasets` to constrain input datasets. This may | ||
contain a constraint on tracts, patches, or bands, but not HEALPix | ||
indices. Constraints on tracts and patches should usually be | ||
unnecessary, however - existing coadds that overlap the given | ||
HEALpix indices will be selected without such a constraint, and | ||
providing one may reject some that should normally be included. | ||
collections : `str` or `Iterable` [ `str` ], optional | ||
Collection or collections to search for input datasets, in order. | ||
If not provided, ``registry.defaults.collections`` will be | ||
searched. | ||
metadata : `dict` [ `str`, `Any` ] | ||
Graph metadata. It is required to contain "output_run" key with the | ||
name of the output RUN collection. | ||
""" | ||
config = task_def.config | ||
def process_subgraph(self, subgraph): | ||
# Docstring inherited. | ||
(task_node,) = subgraph.tasks.values() | ||
|
||
dataset_types = pipeBase.PipelineDatasetTypes.fromPipeline(pipeline=[task_def], registry=registry) | ||
# Since we know this is the only task in the pipeline, we know there | ||
# is only one overall input and one overall output. | ||
(input_dataset_type,) = dataset_types.inputs | ||
|
||
# Extract the main output dataset type (which needs multiple | ||
# DatasetRefs, and tells us the output HPX level), and make a set of | ||
# what remains for more mechanical handling later. | ||
output_dataset_type = dataset_types.outputs[task_def.connections.hips_exposures.name] | ||
incidental_output_dataset_types = dataset_types.outputs.copy() | ||
incidental_output_dataset_types.remove(output_dataset_type) | ||
# is only one overall input and one regular output. | ||
(input_dataset_type_node,) = subgraph.inputs_of(task_node.label).values() | ||
assert input_dataset_type_node is not None, "PipelineGraph should be resolved by base class." | ||
(output_edge,) = task_node.outputs.values() | ||
output_dataset_type_node = subgraph.dataset_types[output_edge.parent_dataset_type_name] | ||
(hpx_output_dimension,) = ( | ||
registry.dimensions.skypix_dimensions[d] for d in output_dataset_type.dimensions.skypix.names | ||
self.butler.dimensions.skypix_dimensions[d] | ||
for d in output_dataset_type_node.dimensions.skypix.names | ||
) | ||
|
||
constraint_hpx_pixelization = registry.dimensions[f"healpix{constraint_order}"].pixelization | ||
common_skypix_name = registry.dimensions.commonSkyPix.name | ||
common_skypix_pixelization = registry.dimensions.commonSkyPix.pixelization | ||
constraint_hpx_pixelization = ( | ||
self.butler.dimensions.skypix_dimensions[f"healpix{self.constraint_order}"].pixelization | ||
) | ||
common_skypix_name = self.butler.dimensions.commonSkyPix.name | ||
common_skypix_pixelization = self.butler.dimensions.commonSkyPix.pixelization | ||
|
||
# We will need all the pixels at the quantum resolution as well | ||
task_dimensions = registry.dimensions.conform(task_def.connections.dimensions) | ||
(hpx_dimension,) = ( | ||
registry.dimensions.skypix_dimensions[d] for d in task_dimensions.names if d != "band" | ||
self.butler.dimensions.skypix_dimensions[d] for d in task_node.dimensions.names if d != "band" | ||
) | ||
hpx_pixelization = hpx_dimension.pixelization | ||
if hpx_pixelization.level < self.constraint_order: | ||
raise ValueError(f"Quantum order {hpx_pixelization.level} must be < {self.constraint_order}") | ||
hpx_ranges = self.constraint_ranges.scaled(4**(hpx_pixelization.level - self.constraint_order)) | ||
|
||
if hpx_pixelization.level < constraint_order: | ||
raise ValueError(f"Quantum order {hpx_pixelization.level} must be < {constraint_order}") | ||
hpx_ranges = constraint_ranges.scaled(4**(hpx_pixelization.level - constraint_order)) | ||
|
||
# We can be generous in looking for pixels here, because we constraint by actual | ||
# patch regions below. | ||
# We can be generous in looking for pixels here, because we constrain | ||
# by actual patch regions below. | ||
common_skypix_ranges = RangeSet() | ||
for begin, end in constraint_ranges: | ||
for begin, end in self.constraint_ranges: | ||
for hpx_index in range(begin, end): | ||
constraint_hpx_region = constraint_hpx_pixelization.pixel(hpx_index) | ||
common_skypix_ranges |= common_skypix_pixelization.envelope(constraint_hpx_region) | ||
|
||
# To keep the query from getting out of hand (and breaking) we simplify until we have fewer | ||
# than 100 ranges which seems to work fine. | ||
# To keep the query from getting out of hand (and breaking) we simplify | ||
# until we have fewer than 100 ranges which seems to work fine. | ||
for simp in range(1, 10): | ||
if len(common_skypix_ranges) < 100: | ||
break | ||
|
@@ -596,24 +606,26 @@ def build_quantum_graph( | |
where_terms.append(f"({common_skypix_name} >= cpx{n}a AND {common_skypix_name} <= cpx{n}b)") | ||
bind[f"cpx{n}a"] = begin | ||
bind[f"cpx{n}b"] = stop | ||
if where is None: | ||
if not self.where: | ||
Review comment: I know this is pre-existing but you are always calculating the join of where_terms.

    where = " OR ".join(where_terms)
    if self.where:
        where = f"({self.where}) AND ({where})"

might be better?
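For illustration, a minimal standalone sketch of the suggested restructuring; where_terms and user_where below are made-up placeholders (user_where stands in for self.where):

    # Build the OR-joined pixel constraint once, then wrap it in the
    # user-supplied expression only when one was actually given.
    where_terms = [
        "(htm7 >= cpx0a AND htm7 <= cpx0b)",
        "(htm7 >= cpx1a AND htm7 <= cpx1b)",
    ]
    user_where = "band = 'r'"  # may also be empty

    where = " OR ".join(where_terms)
    if user_where:
        where = f"({user_where}) AND ({where})"
    print(where)
    # (band = 'r') AND ((htm7 >= cpx0a AND htm7 <= cpx0b) OR (htm7 >= cpx1a AND htm7 <= cpx1b))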
||
where = " OR ".join(where_terms) | ||
else: | ||
where = f"({where}) AND ({' OR '.join(where_terms)})" | ||
where = f"({self.where}) AND ({' OR '.join(where_terms)})" | ||
# Query for input datasets with this constraint, and ask for expanded | ||
# data IDs because we want regions. Immediately group this by patch so | ||
# we don't do later geometric stuff n_bands more times than we need to. | ||
input_refs = registry.queryDatasets( | ||
input_dataset_type, | ||
input_refs = self.butler.registry.queryDatasets( | ||
input_dataset_type_node.dataset_type, | ||
where=where, | ||
findFirst=True, | ||
collections=collections, | ||
collections=self.input_collections, | ||
bind=bind | ||
).expanded() | ||
inputs_by_patch = defaultdict(set) | ||
patch_dimensions = registry.dimensions.conform(["patch"]) | ||
patch_dimensions = self.butler.dimensions.conform(["patch"]) | ||
for input_ref in input_refs: | ||
inputs_by_patch[input_ref.dataId.subset(patch_dimensions)].add(input_ref) | ||
dataset_key = DatasetKey(input_ref.datasetType.name, input_ref.dataId.required_values) | ||
self.existing_datasets.inputs[dataset_key] = input_ref | ||
inputs_by_patch[input_ref.dataId.subset(patch_dimensions)].add(dataset_key) | ||
if not inputs_by_patch: | ||
message_body = "\n".join(input_refs.explain_no_results()) | ||
raise RuntimeError(f"No inputs found:\n{message_body}") | ||
|
@@ -622,66 +634,47 @@ def build_quantum_graph( | |
# that overlap each one. Use that to associate inputs with output | ||
# pixels, but only for the output pixels we've already identified. | ||
inputs_by_hpx = defaultdict(set) | ||
for patch_data_id, input_refs_for_patch in inputs_by_patch.items(): | ||
for patch_data_id, input_keys_for_patch in inputs_by_patch.items(): | ||
patch_hpx_ranges = hpx_pixelization.envelope(patch_data_id.region) | ||
for begin, end in patch_hpx_ranges & hpx_ranges: | ||
for hpx_index in range(begin, end): | ||
inputs_by_hpx[hpx_index].update(input_refs_for_patch) | ||
# Iterate over the dict we just created and create the actual quanta. | ||
quanta = [] | ||
output_run = metadata["output_run"] | ||
for hpx_index, input_refs_for_hpx_index in inputs_by_hpx.items(): | ||
inputs_by_hpx[hpx_index].update(input_keys_for_patch) | ||
|
||
# Iterate over the dict we just created and create preliminary quanta. | ||
skeleton = QuantumGraphSkeleton([task_node.label]) | ||
for hpx_index, input_keys_for_hpx_index in inputs_by_hpx.items(): | ||
# Group inputs by band. | ||
input_refs_by_band = defaultdict(list) | ||
for input_ref in input_refs_for_hpx_index: | ||
input_refs_by_band[input_ref.dataId["band"]].append(input_ref) | ||
input_keys_by_band = defaultdict(list) | ||
for input_key in input_keys_for_hpx_index: | ||
input_ref = self.existing_datasets.inputs[input_key] | ||
input_keys_by_band[input_ref.dataId["band"]].append(input_key) | ||
# Iterate over bands to make quanta. | ||
for band, input_refs_for_band in input_refs_by_band.items(): | ||
data_id = registry.expandDataId({hpx_dimension: hpx_index, "band": band}) | ||
|
||
for band, input_keys_for_band in input_keys_by_band.items(): | ||
data_id = self.butler.registry.expandDataId({hpx_dimension.name: hpx_index, "band": band}) | ||
quantum_key = skeleton.add_quantum_node(task_node.label, data_id) | ||
# Add inputs to the skeleton | ||
skeleton.add_input_edges(quantum_key, input_keys_for_band) | ||
# Add the regular outputs. | ||
hpx_pixel_ranges = RangeSet(hpx_index) | ||
hpx_output_ranges = hpx_pixel_ranges.scaled(4**(config.hips_order - hpx_pixelization.level)) | ||
output_data_ids = [] | ||
hpx_output_ranges = hpx_pixel_ranges.scaled( | ||
4**(task_node.config.hips_order - hpx_pixelization.level) | ||
Review comment: Is there a note anywhere explaining where the 4 comes from? (It turns up in a couple of places.)
Reply: It's the number of healpixels at level N+1 in each healpixel at level N. I'll add a comment somewhere when I get a chance.
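A small worked example of that factor, using only lsst.sphgeom (already imported by this module): each HEALPix pixel at order N contains exactly 4 pixels at order N + 1, so a RangeSet of indices is converted to a finer order by scaling with 4**(finer_order - coarser_order).

    from lsst.sphgeom import RangeSet

    coarse = RangeSet(10)         # pixel 10 at some order N
    fine = coarse.scaled(4 ** 1)  # its four children at order N + 1
    print(list(fine))             # [(40, 44)] -> pixels 40, 41, 42, 43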
||
) | ||
for begin, end in hpx_output_ranges: | ||
for hpx_output_index in range(begin, end): | ||
output_data_ids.append( | ||
registry.expandDataId({hpx_output_dimension: hpx_output_index, "band": band}) | ||
dataset_key = skeleton.add_dataset_node( | ||
output_dataset_type_node.name, | ||
self.butler.registry.expandDataId( | ||
{hpx_output_dimension: hpx_output_index, "band": band} | ||
), | ||
) | ||
outputs = { | ||
dt: [DatasetRef(dt, data_id, run=output_run)] for dt in incidental_output_dataset_types | ||
} | ||
outputs[output_dataset_type] = [DatasetRef(output_dataset_type, data_id, run=output_run) | ||
for data_id in output_data_ids] | ||
quanta.append( | ||
Quantum( | ||
taskName=task_def.taskName, | ||
taskClass=task_def.taskClass, | ||
dataId=data_id, | ||
initInputs={}, | ||
inputs={input_dataset_type: input_refs_for_band}, | ||
outputs=outputs, | ||
) | ||
) | ||
|
||
if len(quanta) == 0: | ||
raise RuntimeError("Given constraints yielded empty quantum graph.") | ||
|
||
# Define initOutputs refs. | ||
empty_data_id = DataCoordinate.make_empty(registry.dimensions) | ||
init_outputs = {} | ||
global_init_outputs = [] | ||
if config_dataset_type := dataset_types.initOutputs.get(task_def.configDatasetName): | ||
init_outputs[task_def] = [DatasetRef(config_dataset_type, empty_data_id, run=output_run)] | ||
packages_dataset_name = pipeBase.PipelineDatasetTypes.packagesDatasetName | ||
if packages_dataset_type := dataset_types.initOutputs.get(packages_dataset_name): | ||
global_init_outputs.append(DatasetRef(packages_dataset_type, empty_data_id, run=output_run)) | ||
|
||
return pipeBase.QuantumGraph( | ||
quanta={task_def: quanta}, | ||
initOutputs=init_outputs, | ||
globalInitOutputs=global_init_outputs, | ||
metadata=metadata, | ||
) | ||
skeleton.add_output_edge(quantum_key, dataset_key) | ||
# Add auxiliary outputs (log, metadata). | ||
for write_edge in task_node.iter_all_outputs(): | ||
if write_edge.connection_name == output_edge.connection_name: | ||
continue | ||
dataset_key = skeleton.add_dataset_node(write_edge.parent_dataset_type_name, data_id) | ||
skeleton.add_output_edge(quantum_key, dataset_key) | ||
return skeleton | ||
|
||
|
||
class HipsPropertiesSpectralTerm(pexConfig.Config): | ||
|
Review comment: Is there a ticket number for the real solution?
Reply: No, and I'm not sure how I want to do it yet; there are a few options but none seem terribly close. Might be best to drop the "temporary" instead.