Update unit tests and import handling
Moved functions in unit tests and cleaned up packageAlerts.
bsmartradio committed Feb 29, 2024
1 parent 2270892 commit 3861bca
Showing 3 changed files with 133 additions and 190 deletions.
20 changes: 6 additions & 14 deletions python/lsst/ap/association/diaPipe.py
@@ -34,7 +34,6 @@

import numpy as np
import pandas as pd
import logging

from lsst.daf.base import DateTime
import lsst.dax.apdb as daxApdb
@@ -51,9 +50,6 @@
PackageAlertsTask)
from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask

_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)


class DiaPipelineConnections(
pipeBase.PipelineTaskConnections,
@@ -545,16 +541,12 @@ def run(self,
["diaObjectId", "diaForcedSourceId"],
drop=False,
inplace=True)
try:
self.alertPackager.run(associatedDiaSources,
diaCalResult.diaObjectCat,
loaderResult.diaSources,
diaForcedSources,
diffIm,
template)
except ValueError as err:
# Continue processing even if alert sending fails
_log.error(err)
self.alertPackager.run(associatedDiaSources,
diaCalResult.diaObjectCat,
loaderResult.diaSources,
diaForcedSources,
diffIm,
template)

return pipeBase.Struct(apdbMarker=self.config.apdb.value,
associatedDiaSources=associatedDiaSources,
204 changes: 79 additions & 125 deletions python/lsst/ap/association/packageAlerts.py
@@ -24,13 +24,20 @@
import io
import os
import sys
import logging

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import pandas as pd
import struct
import fastavro
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
import confluent_kafka
from confluent_kafka import KafkaException
except ImportError:
confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
@@ -39,13 +46,6 @@
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
import fastavro

"""Methods for packaging Apdb and Pipelines data into Avro alerts.
"""

_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)


class PackageAlertsConfig(pexConfig.Config):
@@ -71,14 +71,14 @@ class PackageAlertsConfig(pexConfig.Config):

doProduceAlerts = pexConfig.Field(
dtype=bool,
doc="Turn on alert production to kafka if true. Set to false by default",
doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
default=False,
)

doWriteAlerts = pexConfig.Field(
dtype=bool,
doc="Write alerts to disk if true. Set to true by default",
default=True,
doc="Write alerts to disk if true.",
default=False,
)
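For context, both switches are plain pexConfig boolean fields, so a configuration override can flip them independently (the run() docstring below notes they are independent of one another). A minimal sketch, assuming the LSST stack is available and using the module path implied by the file location above; the write location is a hypothetical value:

from lsst.ap.association.packageAlerts import PackageAlertsConfig

config = PackageAlertsConfig()
config.doProduceAlerts = True              # stream alerts to Kafka; needs confluent_kafka and the AP_KAFKA_* variables
config.doWriteAlerts = True                # also write Avro files to disk; default is now False per this diff
config.alertWriteLocation = "/tmp/alerts"  # hypothetical output directory

In a full DiaPipeline configuration these fields would be reached through the alert-packaging subtask's config (the task attribute is ``alertPackager`` in diaPipe.py above).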


@@ -96,51 +96,46 @@ def __init__(self, **kwargs):
os.makedirs(self.config.alertWriteLocation, exist_ok=True)

if self.config.doProduceAlerts:
if confluent_kafka is not None:
self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
if not self.password:
raise ValueError("Kafka password environment variable was not set.")
self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
if not self.username:
raise ValueError("Kafka username environment variable was not set.")
self.server = os.getenv("AP_KAFKA_SERVER")
if not self.server:
raise ValueError("Kafka server environment variable was not set.")
self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
if not self.kafkaTopic:
raise ValueError("Kafka topic environment variable was not set.")

# confluent_kafka configures all of its classes with dictionaries. This one
# sets up the bare minimum that is needed.
self.kafkaConfig = {
# This is the URL to use to connect to the Kafka cluster.
"bootstrap.servers": self.server,
# These next two properties tell the Kafka client about the specific
# authentication and authorization protocols that should be used when
# connecting.
"security.protocol": "SASL_PLAINTEXT",
"sasl.mechanisms": "SCRAM-SHA-512",
# The sasl.username and sasl.password are passed through over
# SCRAM-SHA-512 auth to connect to the cluster. The username is not
# sensitive, but the password is (of course) a secret value which
# should never be committed to source code.
"sasl.username": self.username,
"sasl.password": self.password,
# Batch size limits the largest size of a kafka alert that can be sent.
                # We set the batch size to 2 MB (2097152 bytes).
"batch.size": 2097152,
"linger.ms": 5,
}
self.producer = confluent_kafka.Producer(**self.kafkaConfig)

self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
self.server = os.getenv("AP_KAFKA_SERVER")
self.kafka_topic = os.getenv("AP_KAFKA_TOPIC")
# confluent_kafka configures all of its classes with dictionaries. This one
# sets up the bare minimum that is needed.
self.kafka_config = {
# This is the URL to use to connect to the Kafka cluster.
"bootstrap.servers": self.server,
# These next two properties tell the Kafka client about the specific
# authentication and authorization protocols that should be used when
# connecting.
"security.protocol": "SASL_PLAINTEXT",
"sasl.mechanisms": "SCRAM-SHA-512",
# The sasl.username and sasl.password are passed through over
# SCRAM-SHA-512 auth to connect to the cluster. The username is not
# sensitive, but the password is (of course) a secret value which
# should never be committed to source code.
"sasl.username": self.username,
"sasl.password": self.password,
# Batch size limits the largest size of a kafka alert that can be sent.
# We set the batch size to 2 Mb.
"batch.size": 2097152,
"linger.ms": 5,
}

try:
from confluent_kafka import KafkaException
self.kafka_exception = KafkaException
import confluent_kafka
except ImportError as error:
error.add_note("Could not import confluent_kafka. Alerts will not be sent "
"to the alert stream")
_log.error(error)

if not self.password:
raise ValueError("Kafka password environment variable was not set.")
if not self.username:
raise ValueError("Kafka username environment variable was not set.")
if not self.server:
raise ValueError("Kafka server environment variable was not set.")
if not self.kafka_topic:
raise ValueError("Kafka topic environment variable was not set.")
self.producer = confluent_kafka.Producer(**self.kafka_config)
else:
raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
"the environment. Alerts will not be sent to the alert stream.")

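For context, when ``doProduceAlerts`` is enabled ``__init__`` reads four environment variables, raising ``ValueError`` if any is unset and ``RuntimeError`` if confluent_kafka cannot be imported. A minimal sketch of setting them before the task is constructed; all values are placeholders, not real credentials:

import os

os.environ["AP_KAFKA_PRODUCER_USERNAME"] = "ap-producer"        # placeholder username
os.environ["AP_KAFKA_PRODUCER_PASSWORD"] = "not-a-real-secret"  # placeholder; never commit real secrets
os.environ["AP_KAFKA_SERVER"] = "kafka.example.org:9092"        # placeholder broker address
os.environ["AP_KAFKA_TOPIC"] = "example-alerts-topic"           # placeholder topic name

With these set, the task builds a confluent_kafka.Producer using SASL_PLAINTEXT with SCRAM-SHA-512 and the 2 MB batch size configured in the dictionary above.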
@timeMethod
def run(self,
@@ -153,6 +148,10 @@ def run(self,
):
"""Package DiaSources/Object and exposure data into Avro alerts.
Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
and written to disk if ``doWriteAlerts`` is set. Both can be set at the
same time, and are independent of one another.
Writes Avro alerts to a location determined by the
``alertWriteLocation`` configurable.
@@ -223,21 +222,13 @@
diffImCutout,
templateCutout))

if self.config.doProduceAlerts and "confluent_kafka" in sys.modules:
if self.config.doProduceAlerts:
self.produceAlerts(alerts, ccdVisitId)

elif self.config.doProduceAlerts and "confluent_kafka" not in sys.modules:
raise Exception("Produce alerts is set but confluent_kafka is not in the environment.")

if self.config.doWriteAlerts:
with open(os.path.join(self.config.alertWriteLocation,
f"{ccdVisitId}.avro"),
"wb") as f:
with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
self.alertSchema.store_alerts(f, alerts)

if not self.config.doProduceAlerts and not self.config.doWriteAlerts:
raise Exception("Neither produce alerts nor write alerts is set.")

def _patchDiaSources(self, diaSources):
"""Add the ``programId`` column to the data.
@@ -249,7 +240,7 @@ def _patchDiaSources(self, diaSources):
diaSources["programId"] = 0

def createDiaSourceExtent(self, bboxSize):
"""Create a extent for a box for the cutouts given the size of the
"""Create an extent for a box for the cutouts given the size of the
square BBox that covers the source footprint.
Parameters
@@ -270,19 +261,29 @@ def createDiaSourceExtent(self, bboxSize):
return extent

def produceAlerts(self, alerts, ccdVisitId):
"""Serialize alerts and send them to the alert stream using
confluent_kafka's producer.
Parameters
----------
alerts : `dict`
Dictionary of alerts to be sent to the alert stream.
ccdVisitId : `int`
ccdVisitId of the alerts sent to the alert stream. Used to write
out alerts which fail to be sent to the alert stream.
"""
for alert in alerts:
alert_bytes = self._serialize_alert(alert, schema=self.alertSchema.definition, schema_id=1)
alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
try:
self.producer.produce(self.kafka_topic, alert_bytes, callback=self._delivery_callback)
self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
self.producer.flush()

except self.kafka_exception as e:
_log.error('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alert_bytes)))
except KafkaException as e:
self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

with open(os.path.join(self.config.alertWriteLocation,
f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
f.write(alert_bytes)
f.write(alertBytes)

self.producer.flush()

@@ -466,11 +467,11 @@ def streamCcdDataToBytes(self, cutout):
cutoutBytes = streamer.getvalue()
return cutoutBytes

def _serialize_alert(self, alert, schema=None, schema_id=0):
def _serializeAlert(self, alert, schema=None, schema_id=0):
"""Serialize an alert to a byte sequence for sending to Kafka.
Parameters
----------`
----------
alert : `dict`
An alert payload to be serialized.
schema : `dict`, optional
@@ -479,7 +480,7 @@ def _serialize_alert(self, alert, schema=None, schema_id=0):
schema_id : `int`, optional
The Confluent Schema Registry ID of the schema. By default, 0 (an
invalid ID) is used, indicating that the schema is not registered.
`
Returns
-------
serialized : `bytes`
@@ -491,38 +492,12 @@

buf = io.BytesIO()
# TODO: Use a proper schema versioning system (DM-42606)
buf.write(self._serialize_confluent_wire_header(schema_id))
buf.write(self._serializeConfluentWireHeader(schema_id))
fastavro.schemaless_writer(buf, schema, alert)
return buf.getvalue()

def _deserialize_alert(self, alert_bytes, schema=None):
"""Deserialize an alert message from Kafka.
Paramaters
----------
alert_bytes : `bytes`
Binary-encoding serialized Avro alert, including Confluent Wire
Format prefix.
schema : `dict`, optional
An Avro schema definition describing how to encode `alert`. By default,
the latest schema is used.
Returns
-------
alert : `dict`
An alert payload.
"""
if schema is None:
schema = self.alertSchema.definition

header_bytes = alert_bytes[:5]
version = self._deserialize_confluent_wire_header(header_bytes)
assert version == 0
content_bytes = io.BytesIO(alert_bytes[5:])
return fastavro.schemaless_reader(content_bytes, schema)

@staticmethod
def _serialize_confluent_wire_header(schema_version):
def _serializeConfluentWireHeader(schema_version):
"""Returns the byte prefix for Confluent Wire Format-style Kafka messages.
Parameters
@@ -545,29 +520,8 @@ def _serialize_confluent_wire_header(schema_version):
ConfluentWireFormatHeader = struct.Struct(">bi")
return ConfluentWireFormatHeader.pack(0, schema_version)

@staticmethod
def _deserialize_confluent_wire_header(raw):
"""Parses the byte prefix for Confluent Wire Format-style Kafka messages.
Parameters
----------
raw : `bytes`
The 5-byte encoded message prefix.
Returns
-------
schema_version : `int`
A version number which indicates the Confluent Schema Registry ID
number of the Avro schema used to encode the message that follows this
header.
"""
ConfluentWireFormatHeader = struct.Struct(">bi")
_, version = ConfluentWireFormatHeader.unpack(raw)
return version
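For context, the Confluent Wire Format prefix written by ``_serializeConfluentWireHeader`` is five bytes: a zero magic byte followed by a big-endian four-byte schema ID, so ``struct.Struct(">bi").pack(0, 1)`` yields ``b'\x00\x00\x00\x00\x01'``. A consumer-side sketch of the inverse, mirroring the deserialization helpers this commit removes; the function name is illustrative:

import io
import struct

import fastavro

def decodeAlert(alertBytes, schema):
    # Strip the 5-byte Confluent Wire Format header, check the magic byte
    # (the serializer always writes 0), then decode the schemaless Avro
    # payload with the same schema used by _serializeAlert.
    magic, schemaId = struct.Struct(">bi").unpack(alertBytes[:5])
    assert magic == 0
    return fastavro.schemaless_reader(io.BytesIO(alertBytes[5:]), schema)

The same decoding applies to the per-alert ``{ccdVisitId}_{alertId}.avro`` files that ``produceAlerts`` writes when a Kafka send fails, since they contain the identical serialized bytes.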

def _delivery_callback(self, err, msg):
if err:
_log.debug('%% Message failed delivery: %s\n' % err)
self.log.warning('Message failed delivery: %s\n' % err)
else:
_log.debug('%% Message delivered to %s [%d] @ %d\n' %
(msg.topic(), msg.partition(), msg.offset()))
self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())
