-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DM-23032: CreateSDM functor for bit-packing multiple flag columns #107
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
|
||
import numpy as np | ||
import os | ||
import yaml | ||
|
||
from lsst.daf.base import DateTime | ||
import lsst.pex.config as pexConfig | ||
|
@@ -40,6 +41,11 @@ class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections, | |
defaultTemplates={"coaddName": "deep", "fakesType": ""}): | ||
"""Butler connections for TransformDiaSourceCatalogTask. | ||
""" | ||
diaSourceSchema = connTypes.InitInput( | ||
doc="Schema for DIASource catalog output by ImageDifference.", | ||
storageClass="SourceCatalog", | ||
name="{fakesType}{coaddName}Diff_diaSrc_schema", | ||
) | ||
diaSourceCat = connTypes.Input( | ||
doc="Catalog of DiaSources produced during image differencing.", | ||
name="{fakesType}{coaddName}Diff_diaSrc", | ||
|
@@ -64,6 +70,13 @@ class TransformDiaSourceCatalogConfig(pipeBase.PipelineTaskConfig, | |
pipelineConnections=TransformDiaSourceCatalogConnections): | ||
""" | ||
""" | ||
flagMap = pexConfig.Field( | ||
dtype=str, | ||
doc="Yaml file specifying SciencePipelines flag fields to bit packs.", | ||
default=os.path.join(getPackageDir("ap_association"), | ||
"data", | ||
"association-flag-map.yaml"), | ||
) | ||
functorFile = pexConfig.Field( | ||
dtype=str, | ||
doc='Path to YAML file specifying Science DataModel functors to use ' | ||
|
@@ -85,10 +98,38 @@ class TransformDiaSourceCatalogTask(TransformCatalogBaseTask): | |
|
||
ConfigClass = TransformDiaSourceCatalogConfig | ||
_DefaultName = "transformDiaSourceCatalog" | ||
RunnerClass = pipeBase.ButlerInitializedTaskRunner | ||
|
||
def __init__(self, **kwargs): | ||
def __init__(self, initInputs, **kwargs): | ||
super().__init__(**kwargs) | ||
self.funcs = self.getFunctors() | ||
self.inputSchema = initInputs['diaSourceSchema'].schema | ||
self._create_bit_pack_mappings() | ||
|
||
def _create_bit_pack_mappings(self): | ||
"""Setup all flag bit packings. | ||
""" | ||
self.bit_pack_columns = [] | ||
with open(self.config.flagMap) as yaml_stream: | ||
table_list = list(yaml.safe_load_all(yaml_stream)) | ||
for table in table_list: | ||
if table['tableName'] == 'DiaSource': | ||
self.bit_pack_columns = table['columns'] | ||
break | ||
|
||
# Test that all flags requested are present in the input schemas. | ||
# Output schemas are flexible, however if names are not specified in | ||
# the Apdb schema, flag columns will not be persisted. | ||
for outputFlag in self.bit_pack_columns: | ||
bitList = outputFlag['bitList'] | ||
for bit in bitList: | ||
try: | ||
self.inputSchema.find(bit['name']) | ||
except KeyError: | ||
raise KeyError( | ||
"Requested column %s not found in input DiaSource " | ||
"schema. Please check that the requested input " | ||
"column exists." % bit['name']) | ||
|
||
def runQuantum(self, butlerQC, inputRefs, outputRefs): | ||
inputs = butlerQC.get(inputRefs) | ||
|
@@ -135,6 +176,7 @@ def run(self, | |
diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD) | ||
diaSourceDf["diaObjectId"] = 0 | ||
diaSourceDf["pixelId"] = 0 | ||
self.bitPackFlags(diaSourceDf) | ||
|
||
df = self.transform(band, | ||
ParquetTable(dataFrame=diaSourceDf), | ||
|
@@ -179,3 +221,76 @@ def computeBBoxSizes(self, inputCatalog): | |
outputBBoxSizes[idx] = bboxSize | ||
|
||
return outputBBoxSizes | ||
|
||
def bitPackFlags(self, df): | ||
"""Pack requested flag columns in inputRecord into single columns in | ||
outputRecord. | ||
|
||
Parameters | ||
---------- | ||
df : `pandas.DataFrame` | ||
DataFrame to read bits from and pack them into. | ||
""" | ||
for outputFlag in self.bit_pack_columns: | ||
bitList = outputFlag['bitList'] | ||
value = np.zeros(len(df), dtype=np.uint64) | ||
for bit in bitList: | ||
# Hard type the bit arrays. | ||
value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64) | ||
df[outputFlag['columnName']] = value | ||
|
||
|
||
class UnpackApdbFlags: | ||
"""Class for unpacking bits from integer flag fields stored in the Apdb. | ||
|
||
Attributes | ||
---------- | ||
flag_map_file : `str` | ||
Absolute or relative path to a yaml file specifiying mappings of flags | ||
to integer bits. | ||
table_name : `str` | ||
Name of the Apdb table the integer bit data are coming from. | ||
""" | ||
|
||
def __init__(self, flag_map_file, table_name): | ||
self.bit_pack_columns = [] | ||
with open(flag_map_file) as yaml_stream: | ||
table_list = list(yaml.safe_load_all(yaml_stream)) | ||
for table in table_list: | ||
if table['tableName'] == table_name: | ||
self.bit_pack_columns = table['columns'] | ||
break | ||
|
||
self.output_flag_columns = {} | ||
|
||
for column in self.bit_pack_columns: | ||
names = [] | ||
for bit in column["bitList"]: | ||
names.append((bit["name"], bool)) | ||
self.output_flag_columns[column["columnName"]] = names | ||
|
||
def unpack(self, input_flag_values, flag_name): | ||
"""Determine individual boolean flags from an input array of unsigned | ||
ints. | ||
|
||
Parameters | ||
---------- | ||
input_flag_values : array-like of type uint | ||
Input integer flags to unpack. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Input" does not need to be capitalized. |
||
flag_name : `str` | ||
Apdb column name of integer flags to unpack. Names of packed int | ||
flags are given by the flag_map_file. | ||
|
||
Returns | ||
------- | ||
output_flags : `numpy.ndarray` | ||
Numpy named tuple of booleans. | ||
""" | ||
bit_names_types = self.output_flag_columns[flag_name] | ||
output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types) | ||
|
||
for bit_idx, (bit_name, dtypes) in enumerate(bit_names_types): | ||
masked_bits = np.bitwise_and(input_flag_values, 2 ** bit_idx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove spaces in |
||
output_flags[bit_name] = masked_bits | ||
|
||
return output_flags |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,3 +20,6 @@ funcs: | |
filterName: | ||
functor: Column | ||
args: filterName | ||
flags: | ||
functor: Column | ||
args: flags |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,9 +29,12 @@ | |
import lsst.afw.image as afwImage | ||
import lsst.geom as geom | ||
import lsst.meas.base.tests as measTests | ||
from lsst.pipe.base import Struct | ||
from lsst.utils import getPackageDir | ||
import lsst.utils.tests | ||
|
||
from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags | ||
|
||
|
||
class TestTransformDiaSourceCatalogTask(unittest.TestCase): | ||
|
||
|
@@ -45,7 +48,12 @@ def setUp(self): | |
for srcIdx in range(nSources): | ||
dataset.addSource(100000.0, geom.Point2D(self.xyLoc, self.xyLoc)) | ||
schema = dataset.makeMinimalSchema() | ||
schema.addField("base_PixelFlags_flag", type="Flag") | ||
schema.addField("base_PixelFlags_flag_offimage", type="Flag") | ||
self.exposure, self.inputCatalog = dataset.realize(10.0, schema, randomSeed=1234) | ||
# Make up expected task inputs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please re-word this comment? I find the meaning ambiguous. |
||
self.initInputs = {"diaSourceSchema": Struct(schema=schema)} | ||
self.initInputsBadFlags = {"diaSourceSchema": Struct(schema=dataset.makeMinimalSchema())} | ||
|
||
self.expId = 4321 | ||
self.date = dafBase.DateTime(nsecs=1400000000 * 10**9) | ||
|
@@ -67,11 +75,18 @@ def test_run(self): | |
"""Test output dataFrame is created and values are correctly inserted | ||
from the exposure. | ||
""" | ||
transformConfig = TransformDiaSourceCatalogConfig() | ||
transformConfig.functorFile = os.path.join(getPackageDir("ap_association"), | ||
"tests/data/", | ||
"testDiaSource.yaml") | ||
transformTask = TransformDiaSourceCatalogTask(config=transformConfig) | ||
transConfig = TransformDiaSourceCatalogConfig() | ||
transConfig.flagMap = os.path.join( | ||
getPackageDir("ap_association"), | ||
"tests", | ||
"data", | ||
"test-flag-map.yaml") | ||
transConfig.functorFile = os.path.join(getPackageDir("ap_association"), | ||
"tests", | ||
"data", | ||
"testDiaSource.yaml") | ||
transformTask = TransformDiaSourceCatalogTask(initInputs=self.initInputs, | ||
config=transConfig) | ||
result = transformTask.run(self.inputCatalog, | ||
self.exposure, | ||
self.filterName, | ||
|
@@ -87,15 +102,68 @@ def test_run(self): | |
self.assertEqual(src["pixelId"], 0) | ||
self.assertEqual(src["diaObjectId"], 0) | ||
|
||
def test_run_dia_source_wrong_flags(self): | ||
"""Test that the proper errors are thrown when requesting flag columns | ||
that are not in the input schema. | ||
""" | ||
with self.assertRaises(KeyError): | ||
TransformDiaSourceCatalogTask(initInputs=self.initInputsBadFlags) | ||
|
||
def test_computeBBoxSize(self): | ||
"""Test the values created for diaSourceBBox. | ||
""" | ||
transform = TransformDiaSourceCatalogTask() | ||
transConfig = TransformDiaSourceCatalogConfig() | ||
transConfig.flagMap = os.path.join( | ||
getPackageDir("ap_association"), | ||
"tests", | ||
"data", | ||
"test-flag-map.yaml") | ||
transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs, | ||
config=transConfig) | ||
bboxArray = transform.computeBBoxSizes(self.inputCatalog) | ||
|
||
# Default in catalog is 18. | ||
self.assertEqual(bboxArray[0], self.bboxSize) | ||
|
||
def test_bit_unpacker(self): | ||
"""Test that the integer bit packer is functioning correctly. | ||
""" | ||
transConfig = TransformDiaSourceCatalogConfig() | ||
transConfig.flagMap = os.path.join( | ||
getPackageDir("ap_association"), | ||
"tests", | ||
"data", | ||
"test-flag-map.yaml") | ||
transConfig.functorFile = os.path.join(getPackageDir("ap_association"), | ||
"tests", | ||
"data", | ||
"testDiaSource.yaml") | ||
transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs, | ||
config=transConfig) | ||
for idx, obj in enumerate(self.inputCatalog): | ||
if idx in [1, 3, 5]: | ||
obj.set("base_PixelFlags_flag", 1) | ||
if idx in [1, 4, 6]: | ||
obj.set("base_PixelFlags_flag_offimage", 1) | ||
outputCatalog = transform.run(self.inputCatalog, | ||
self.exposure, | ||
self.filterName, | ||
ccdVisitId=self.expId).diaSourceTable | ||
|
||
unpacker = UnpackApdbFlags(transConfig.flagMap, "DiaSource") | ||
flag_values = unpacker.unpack(outputCatalog["flags"], "flags") | ||
|
||
for idx, flag in enumerate(flag_values): | ||
if idx in [1, 3, 5]: | ||
self.assertTrue(flag['base_PixelFlags_flag']) | ||
else: | ||
self.assertFalse(flag['base_PixelFlags_flag']) | ||
|
||
if idx in [1, 4, 6]: | ||
self.assertTrue(flag['base_PixelFlags_flag_offimage']) | ||
else: | ||
self.assertFalse(flag['base_PixelFlags_flag_offimage']) | ||
|
||
|
||
class MemoryTester(lsst.utils.tests.MemoryTestCase): | ||
pass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need the `for` loop here, since you're just iterating through until you find the table with the correct name?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Considering the yaml file could have flags for different tables (DiaObject, SSObject) and no enforced ordering, it makes sense to just go through the named tables and load the one we want.