Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-23032: CreateSDM functor for bitpacking mutiple flag columns #107

Merged
merged 3 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
117 changes: 116 additions & 1 deletion python/lsst/ap/association/transformDiaSourceCatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import numpy as np
import os
import yaml

from lsst.daf.base import DateTime
import lsst.pex.config as pexConfig
Expand All @@ -40,6 +41,11 @@ class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
defaultTemplates={"coaddName": "deep", "fakesType": ""}):
"""Butler connections for TransformDiaSourceCatalogTask.
"""
diaSourceSchema = connTypes.InitInput(
doc="Schema for DIASource catalog output by ImageDifference.",
storageClass="SourceCatalog",
name="{fakesType}{coaddName}Diff_diaSrc_schema",
)
diaSourceCat = connTypes.Input(
doc="Catalog of DiaSources produced during image differencing.",
name="{fakesType}{coaddName}Diff_diaSrc",
Expand All @@ -64,6 +70,13 @@ class TransformDiaSourceCatalogConfig(pipeBase.PipelineTaskConfig,
pipelineConnections=TransformDiaSourceCatalogConnections):
"""
"""
flagMap = pexConfig.Field(
dtype=str,
doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
default=os.path.join(getPackageDir("ap_association"),
"data",
"association-flag-map.yaml"),
)
functorFile = pexConfig.Field(
dtype=str,
doc='Path to YAML file specifying Science DataModel functors to use '
Expand All @@ -85,10 +98,38 @@ class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):

ConfigClass = TransformDiaSourceCatalogConfig
_DefaultName = "transformDiaSourceCatalog"
RunnerClass = pipeBase.ButlerInitializedTaskRunner

def __init__(self, **kwargs):
def __init__(self, initInputs, **kwargs):
super().__init__(**kwargs)
self.funcs = self.getFunctors()
self.inputSchema = initInputs['diaSourceSchema'].schema
self._create_bit_pack_mappings()

def _create_bit_pack_mappings(self):
"""Setup all flag bit packings.
"""
self.bit_pack_columns = []
with open(self.config.flagMap) as yaml_stream:
table_list = list(yaml.safe_load_all(yaml_stream))
for table in table_list:
if table['tableName'] == 'DiaSource':
self.bit_pack_columns = table['columns']
break
Comment on lines +115 to +118
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need the for loop here, since you're just iterating through until you find the table with the correct name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering the yaml file could have flags for different tables (DiaObject, SSObject) and no enforced ordering, it makes sense to just go through the named tables and load the one we want.


# Test that all flags requested are present in the input schemas.
# Output schemas are flexible, however if names are not specified in
# the Apdb schema, flag columns will not be persisted.
for outputFlag in self.bit_pack_columns:
bitList = outputFlag['bitList']
for bit in bitList:
try:
self.inputSchema.find(bit['name'])
except KeyError:
raise KeyError(
"Requested column %s not found in input DiaSource "
"schema. Please check that the requested input "
"column exists." % bit['name'])

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
Expand Down Expand Up @@ -135,6 +176,7 @@ def run(self,
diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
diaSourceDf["diaObjectId"] = 0
diaSourceDf["pixelId"] = 0
self.bitPackFlags(diaSourceDf)

df = self.transform(band,
ParquetTable(dataFrame=diaSourceDf),
Expand Down Expand Up @@ -179,3 +221,76 @@ def computeBBoxSizes(self, inputCatalog):
outputBBoxSizes[idx] = bboxSize

return outputBBoxSizes

def bitPackFlags(self, df):
"""Pack requested flag columns in inputRecord into single columns in
outputRecord.

Parameters
----------
df : `pandas.DataFrame`
DataFrame to read bits from and pack them into.
"""
for outputFlag in self.bit_pack_columns:
bitList = outputFlag['bitList']
value = np.zeros(len(df), dtype=np.uint64)
for bit in bitList:
# Hard type the bit arrays.
value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
df[outputFlag['columnName']] = value


class UnpackApdbFlags:
"""Class for unpacking bits from integer flag fields stored in the Apdb.

Attributes
----------
flag_map_file : `str`
Absolute or relative path to a yaml file specifiying mappings of flags
to integer bits.
table_name : `str`
Name of the Apdb table the integer bit data are coming from.
"""

def __init__(self, flag_map_file, table_name):
self.bit_pack_columns = []
with open(flag_map_file) as yaml_stream:
table_list = list(yaml.safe_load_all(yaml_stream))
for table in table_list:
if table['tableName'] == table_name:
self.bit_pack_columns = table['columns']
break

self.output_flag_columns = {}

for column in self.bit_pack_columns:
names = []
for bit in column["bitList"]:
names.append((bit["name"], bool))
self.output_flag_columns[column["columnName"]] = names

def unpack(self, input_flag_values, flag_name):
"""Determine individual boolean flags from an input array of unsigned
ints.

Parameters
----------
input_flag_values : array-like of type uint
Input integer flags to unpack.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Input" does not need to be capitalized.

flag_name : `str`
Apdb column name of integer flags to unpack. Names of packed int
flags are given by the flag_map_file.

Returns
-------
output_flags : `numpy.ndarray`
Numpy named tuple of booleans.
"""
bit_names_types = self.output_flag_columns[flag_name]
output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

for bit_idx, (bit_name, dtypes) in enumerate(bit_names_types):
masked_bits = np.bitwise_and(input_flag_values, 2 ** bit_idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove spaces in 2 ** bit_idx

output_flags[bit_name] = masked_bits

return output_flags
2 changes: 1 addition & 1 deletion tests/test-flag-map.yaml → tests/data/test-flag-map.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file lists columns for bit-packing using columns for use in the testing
#
#
tableName: DiaSource
doc: "Flag packing definitions for the DiaSource table."
columns:
Expand Down
3 changes: 3 additions & 0 deletions tests/data/testDiaSource.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ funcs:
filterName:
functor: Column
args: filterName
flags:
functor: Column
args: flags
1 change: 1 addition & 0 deletions tests/test_diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _makeDefaultConfig(cls, doPackageAlerts=False):
config.diaSourceDpddifier.flagMap = os.path.join(
getPackageDir("ap_association"),
"tests",
"data",
"test-flag-map.yaml")
config.doPackageAlerts = doPackageAlerts
return config
Expand Down
1 change: 1 addition & 0 deletions tests/test_map_ap_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def _create_map_dia_source_config(self):
configurable.flagMap = os.path.join(
getPackageDir("ap_association"),
"tests",
"data",
"test-flag-map.yaml")

return configurable
Expand Down
80 changes: 74 additions & 6 deletions tests/test_transformDiaSourceCatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
import lsst.afw.image as afwImage
import lsst.geom as geom
import lsst.meas.base.tests as measTests
from lsst.pipe.base import Struct
from lsst.utils import getPackageDir
import lsst.utils.tests

from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags


class TestTransformDiaSourceCatalogTask(unittest.TestCase):

Expand All @@ -45,7 +48,12 @@ def setUp(self):
for srcIdx in range(nSources):
dataset.addSource(100000.0, geom.Point2D(self.xyLoc, self.xyLoc))
schema = dataset.makeMinimalSchema()
schema.addField("base_PixelFlags_flag", type="Flag")
schema.addField("base_PixelFlags_flag_offimage", type="Flag")
self.exposure, self.inputCatalog = dataset.realize(10.0, schema, randomSeed=1234)
# Make up expected task inputs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please re-word this comment? I find the meaning ambiguous.

self.initInputs = {"diaSourceSchema": Struct(schema=schema)}
self.initInputsBadFlags = {"diaSourceSchema": Struct(schema=dataset.makeMinimalSchema())}

self.expId = 4321
self.date = dafBase.DateTime(nsecs=1400000000 * 10**9)
Expand All @@ -67,11 +75,18 @@ def test_run(self):
"""Test output dataFrame is created and values are correctly inserted
from the exposure.
"""
transformConfig = TransformDiaSourceCatalogConfig()
transformConfig.functorFile = os.path.join(getPackageDir("ap_association"),
"tests/data/",
"testDiaSource.yaml")
transformTask = TransformDiaSourceCatalogTask(config=transformConfig)
transConfig = TransformDiaSourceCatalogConfig()
transConfig.flagMap = os.path.join(
getPackageDir("ap_association"),
"tests",
"data",
"test-flag-map.yaml")
transConfig.functorFile = os.path.join(getPackageDir("ap_association"),
"tests",
"data",
"testDiaSource.yaml")
transformTask = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
config=transConfig)
result = transformTask.run(self.inputCatalog,
self.exposure,
self.filterName,
Expand All @@ -87,15 +102,68 @@ def test_run(self):
self.assertEqual(src["pixelId"], 0)
self.assertEqual(src["diaObjectId"], 0)

def test_run_dia_source_wrong_flags(self):
"""Test that the proper errors are thrown when requesting flag columns
that are not in the input schema.
"""
with self.assertRaises(KeyError):
TransformDiaSourceCatalogTask(initInputs=self.initInputsBadFlags)

def test_computeBBoxSize(self):
"""Test the values created for diaSourceBBox.
"""
transform = TransformDiaSourceCatalogTask()
transConfig = TransformDiaSourceCatalogConfig()
transConfig.flagMap = os.path.join(
getPackageDir("ap_association"),
"tests",
"data",
"test-flag-map.yaml")
transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
config=transConfig)
bboxArray = transform.computeBBoxSizes(self.inputCatalog)

# Default in catalog is 18.
self.assertEqual(bboxArray[0], self.bboxSize)

def test_bit_unpacker(self):
"""Test that the integer bit packer is functioning correctly.
"""
transConfig = TransformDiaSourceCatalogConfig()
transConfig.flagMap = os.path.join(
getPackageDir("ap_association"),
"tests",
"data",
"test-flag-map.yaml")
transConfig.functorFile = os.path.join(getPackageDir("ap_association"),
"tests",
"data",
"testDiaSource.yaml")
transform = TransformDiaSourceCatalogTask(initInputs=self.initInputs,
config=transConfig)
for idx, obj in enumerate(self.inputCatalog):
if idx in [1, 3, 5]:
obj.set("base_PixelFlags_flag", 1)
if idx in [1, 4, 6]:
obj.set("base_PixelFlags_flag_offimage", 1)
outputCatalog = transform.run(self.inputCatalog,
self.exposure,
self.filterName,
ccdVisitId=self.expId).diaSourceTable

unpacker = UnpackApdbFlags(transConfig.flagMap, "DiaSource")
flag_values = unpacker.unpack(outputCatalog["flags"], "flags")

for idx, flag in enumerate(flag_values):
if idx in [1, 3, 5]:
self.assertTrue(flag['base_PixelFlags_flag'])
else:
self.assertFalse(flag['base_PixelFlags_flag'])

if idx in [1, 4, 6]:
self.assertTrue(flag['base_PixelFlags_flag_offimage'])
else:
self.assertFalse(flag['base_PixelFlags_flag_offimage'])


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down