Update unit tests and error logging
Update logging

Update unit tests
bsmartradio committed Feb 15, 2024
1 parent b6e01f0 commit 64a723f
Showing 2 changed files with 118 additions and 97 deletions.
104 changes: 54 additions & 50 deletions python/lsst/ap/association/packageAlerts.py
@@ -40,17 +40,9 @@
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
import fastavro
from lsst.alert.packet import Schema
try:
from confluent_kafka import KafkaException
import confluent_kafka
except ImportError as error:
error.add_note("Could not import confluent_kafka. Alerts will not be sent to the alert stream")

"""Methods for packaging Apdb and Pipelines data into Avro alerts.
"""
_ConfluentWireFormatHeader = struct.Struct(">bi")
latest_schema = Schema.from_file().definition

_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)
@@ -103,31 +95,52 @@ def __init__(self, **kwargs):
self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
os.makedirs(self.config.alertWriteLocation, exist_ok=True)

self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
self.server = os.getenv("AP_KAFKA_SERVER")
# confluent_kafka configures all of its classes with dictionaries. This one
# sets up the bare minimum that is needed.
self.kafka_config = {
# This is the URL to use to connect to the Kafka cluster.
"bootstrap.servers": self.server,
# These next two properties tell the Kafka client about the specific
# authentication and authorization protocols that should be used when
# connecting.
"security.protocol": "SASL_PLAINTEXT",
"sasl.mechanisms": "SCRAM-SHA-512",
# The sasl.username and sasl.password are passed through over
# SCRAM-SHA-512 auth to connect to the cluster. The username is not
# sensitive, but the password is (of course) a secret value which
# should never be committed to source code.
"sasl.username": self.username,
"sasl.password": self.password,
# Batch size limits the largest size of a kafka alert that can be sent.
# We set the batch size to 2 Mb.
"batch.size": 2097152,
"linger.ms": 5,
}
self.kafka_topic = os.getenv("AP_KAFKA_TOPIC")
if self.config.doProduceAlerts:

self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
self.server = os.getenv("AP_KAFKA_SERVER")
self.kafka_topic = os.getenv("AP_KAFKA_TOPIC")
# confluent_kafka configures all of its classes with dictionaries. This one
# sets up the bare minimum that is needed.
self.kafka_config = {
# This is the URL to use to connect to the Kafka cluster.
"bootstrap.servers": self.server,
# These next two properties tell the Kafka client about the specific
# authentication and authorization protocols that should be used when
# connecting.
"security.protocol": "SASL_PLAINTEXT",
"sasl.mechanisms": "SCRAM-SHA-512",
# The sasl.username and sasl.password are passed through over
# SCRAM-SHA-512 auth to connect to the cluster. The username is not
# sensitive, but the password is (of course) a secret value which
# should never be committed to source code.
"sasl.username": self.username,
"sasl.password": self.password,
# Batch size limits the largest size of a kafka alert that can be sent.
# We set the batch size to 2 Mb.
"batch.size": 2097152,
"linger.ms": 5,
}

try:
from confluent_kafka import KafkaException
self.kafka_exception = KafkaException
import confluent_kafka
except ImportError as error:
error.add_note("Could not import confluent_kafka. Alerts will not be sent "
"to the alert stream")
_log.error(error)

if not self.password:
raise ValueError("Kafka password environment variable was not set.")
if not self.username:
raise ValueError("Kafka username environment variable was not set.")
if not self.server:
raise ValueError("Kafka server environment variable was not set.")
if not self.kafka_topic:
raise ValueError("Kafka topic environment variable was not set.")
self.producer = confluent_kafka.Producer(**self.kafka_config)

@timeMethod
def run(self,
@@ -258,31 +271,20 @@ def createDiaSourceExtent(self, bboxSize):

def produceAlerts(self, alerts, ccdVisitId):

if not self.password:
raise ValueError("Kafka password environment variable was not set.")
if not self.username:
raise ValueError("Kafka username environment variable was not set.")
if not self.server:
raise ValueError("Kafka server environment variable was not set.")
if not self.kafka_topic:
raise ValueError("Kafka topic environment variable was not set.")
p = confluent_kafka.Producer(**self.kafka_config)
topic = self.kafka_topic

for alert in alerts:
alert_bytes = self._serialize_alert(alert, schema=self.alertSchema.definition, schema_id=1)
try:
p.produce(topic, alert_bytes, callback=self._delivery_callback)
p.flush()
self.producer.produce(self.kafka_topic, alert_bytes, callback=self._delivery_callback)
self.producer.flush()

except KafkaException as e:
except self.kafka_exception as e:
_log.error('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alert_bytes)))

with open(os.path.join(self.config.alertWriteLocation,
f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
f.write(alert_bytes)

p.flush()
self.producer.flush()

def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
"""Grab an image as a cutout and return a calibrated CCDData image.
@@ -540,7 +542,8 @@ def _serialize_confluent_wire_header(schema_version):
The Confluent Wire Format is described more fully here:
https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
"""
return _ConfluentWireFormatHeader.pack(0, schema_version)
ConfluentWireFormatHeader = struct.Struct(">bi")
return ConfluentWireFormatHeader.pack(0, schema_version)

@staticmethod
def _deserialize_confluent_wire_header(raw):
@@ -558,7 +561,8 @@ def _deserialize_confluent_wire_header(raw):
number of the Avro schema used to encode the message that follows this
header.
"""
_, version = _ConfluentWireFormatHeader.unpack(raw)
ConfluentWireFormatHeader = struct.Struct(">bi")
_, version = ConfluentWireFormatHeader.unpack(raw)
return version

def _delivery_callback(self, err, msg):
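For context on the serialization helpers above: each alert is framed in the Confluent Wire Format before it is produced to Kafka, i.e. a zero magic byte and a big-endian 4-byte schema ID packed with struct.Struct(">bi"), followed by the schemaless Avro payload. The body of _serialize_alert is not shown in this diff, so the sketch below is only an assumption of how the pieces fit together; the names serialize_alert/deserialize_alert and the fastavro calls are illustrative, not this task's API.

import io
import struct

import fastavro

# Confluent wire format header: 1-byte magic (always 0) + 4-byte big-endian schema ID.
_ConfluentWireFormatHeader = struct.Struct(">bi")


def serialize_alert(alert, schema, schema_id=1):
    # Pack the 5-byte header, then append the schemaless Avro payload.
    header = _ConfluentWireFormatHeader.pack(0, schema_id)
    payload = io.BytesIO()
    fastavro.schemaless_writer(payload, schema, alert)
    return header + payload.getvalue()


def deserialize_alert(raw, schema):
    # Unpack the header to recover the schema version, then decode the payload.
    _, schema_id = _ConfluentWireFormatHeader.unpack(raw[:5])
    alert = fastavro.schemaless_reader(io.BytesIO(raw[5:]), schema)
    return schema_id, alert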
111 changes: 64 additions & 47 deletions tests/test_packageAlerts.py
Expand Up @@ -30,10 +30,7 @@
import sys
from astropy import wcs
from astropy.nddata import CCDData
try:
import confluent_kafka
except ImportError as error:
error.msg += ("Could not import confluent_kafka. Alerts will not be sent to the alert stream")
import logging

from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
from lsst.afw.cameraGeom.testUtils import DetectorWrapper
@@ -46,6 +43,15 @@
import lsst.utils.tests
import utils_tests

_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)

try:
import confluent_kafka # noqa: F401
from confluent_kafka import KafkaException
except ModuleNotFoundError as e:
_log.error('Kafka module not found: {}'.format(e))


def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
"""Run object and source catalogs through the Apdb to get the correct
@@ -105,32 +111,37 @@ def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):


def mock_alert(alert_id):
"""Generate a minimal mock alert. """
"""Generate a minimal mock alert.
"""
return {
"alertId": alert_id,
"diaSource": {
# Below are all the required fields. Set them to zero.
"midpointMjdTai": 0,
"diaSourceId": 0,
"ccdVisitId": 0,
"filterName": "",
# Below are all the required fields containing random values.
"midpointMjdTai": 5,
"diaSourceId": 4,
"ccdVisitId": 2,
"band": 'g',
"programId": 0,
"ra": 0,
"dec": 0,
"x": 0,
"y": 0,
"apFlux": 0,
"apFluxErr": 0,
"snr": 0,
"psfFlux": 0,
"psfFluxErr": 0,
"ra": 12.5,
"dec": -16.9,
"x": 15.7,
"y": 89.8,
"apFlux": 54.85,
"apFluxErr": 70.0,
"snr": 6.7,
"psfFlux": 700.0,
"psfFluxErr": 90.0,
"flags": 0,
}
}


class TestPackageAlerts(lsst.utils.tests.TestCase):
kafka_enabled = "confluent_kafka" in sys.modules

def __init__(self, *args, **kwargs):
TestPackageAlerts.kafka_enabled = "confluent_kafka" in sys.modules
_log.debug('TestPackageAlerts: kafka_enabled={}'.format(self.kafka_enabled))
super(TestPackageAlerts, self).__init__(*args, **kwargs)

def setUp(self):
patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
@@ -323,91 +334,94 @@ def test_produceAlerts_empty_password(self):
""" Test that produceAlerts raises if the password is empty or missing.
"""
self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka password"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka password"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

def test_produceAlerts_empty_username(self):
""" Test that produceAlerts raises if the username is empty or missing.
"""
self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka username"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

del self.environ['AP_KAFKA_PRODUCER_USERNAME']
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka username"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

def test_produceAlerts_empty_server(self):
""" Test that produceAlerts raises if the server is empty or missing.
"""
self.environ['AP_KAFKA_SERVER'] = ""
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka server"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

del self.environ['AP_KAFKA_SERVER']
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka server"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

def test_produceAlerts_empty_topic(self):
""" Test that produceAlerts raises if the topic is empty or missing.
"""
self.environ['AP_KAFKA_TOPIC'] = ""
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka topic"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

del self.environ['AP_KAFKA_TOPIC']
task = PackageAlertsTask()
with self.assertRaisesRegex(ValueError, "Kafka topic"):
task.produceAlerts(None, None)
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig) # noqa: F841

@patch('confluent_kafka.Producer')
@unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
@unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
def test_produceAlerts_success(self, mock_producer):
""" Test that produceAlerts calls the producer on all provided alerts
when the alerts are all under the batch size limit.
"""

task = PackageAlertsTask()
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig)
alerts = [mock_alert(1), mock_alert(2)]
ccdVisitId = 123

# Create a variable and assign it an instance of the patched kafka producer
producer_instance = mock_producer.return_value
producer_instance.produce = Mock()
producer_instance.flush = Mock()
task.produceAlerts(alerts, ccdVisitId)
packageAlerts.produceAlerts(alerts, ccdVisitId)

self.assertEqual(producer_instance.produce.call_count, len(alerts))
self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)

@patch('confluent_kafka.Producer')
@unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
@unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
def test_produceAlerts_one_failure(self, mock_producer):
""" Test that produceAlerts correctly fails on one alert
and is writing the failure to disk.
"""
counter = 0

# confluent_kafka is not visible to mock_producer and needs to be
# re-imported here.
def mock_produce(*args, **kwargs):
nonlocal counter
counter += 1
if counter == 2:
raise confluent_kafka.KafkaException
raise KafkaException
else:
return

task = PackageAlertsTask()
packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig)

patcher = patch("builtins.open")
patch_open = patcher.start()
@@ -418,7 +432,7 @@ def mock_produce(*args, **kwargs):
producer_instance.produce = Mock(side_effect=mock_produce)
producer_instance.flush = Mock()

task.produceAlerts(alerts, ccdVisitId)
packageAlerts.produceAlerts(alerts, ccdVisitId)

self.assertEqual(producer_instance.produce.call_count, len(alerts))
self.assertEqual(patch_open.call_count, 1)
@@ -479,12 +493,12 @@ def testRun_without_produce(self):
shutil.rmtree(tempdir)

@patch.object(PackageAlertsTask, 'produceAlerts')
@unittest.skipUnless("confluent_kafka" in sys.modules, "Test requires confluent_kafka.")
def testRun_with_produce(self, mock_produceAlerts):
@patch('confluent_kafka.Producer')
@unittest.skipIf("confluent_kafka" not in sys.modules, 'Kafka is not enabled')
def testRun_with_produce(self, mock_produceAlerts, mock_producer):
"""Test that packageAlerts calls produceAlerts when doProduceAlerts
is set to True.
"""

packConfig = PackageAlertsConfig(doProduceAlerts=True)
packageAlerts = PackageAlertsTask(config=packConfig)

@@ -521,6 +535,9 @@ def test_serialize_alert_round_trip(self, **kwargs):
alert = mock_alert(1)
serialized = PackageAlertsTask._serialize_alert(packageAlerts, alert)
deserialized = PackageAlertsTask._deserialize_alert(packageAlerts, serialized)

for field in alert['diaSource']:
self.assertAlmostEqual(alert['diaSource'][field], deserialized['diaSource'][field], places=7)
self.assertEqual(1, deserialized["alertId"])


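A note on the mocking pattern in these tests: because PackageAlertsTask now constructs its producer in __init__ when doProduceAlerts is set, patching confluent_kafka.Producer means the object the task holds is mock_producer.return_value, so the assertions are made against that single shared mock. A minimal standalone sketch of the pattern, assuming the AP_KAFKA_* environment variables are patched as in setUp and reusing this module's mock_alert helper:

from unittest.mock import Mock, patch

from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask

with patch("confluent_kafka.Producer") as mock_producer:
    # The instance created inside PackageAlertsTask.__init__ is the mock's return_value.
    producer_instance = mock_producer.return_value
    producer_instance.produce = Mock()
    producer_instance.flush = Mock()

    packConfig = PackageAlertsConfig(doProduceAlerts=True)
    packageAlerts = PackageAlertsTask(config=packConfig)

    packageAlerts.produceAlerts([mock_alert(1), mock_alert(2)], ccdVisitId=123)

    # produce() is called once per alert; flush() once per alert plus a final flush.
    assert producer_instance.produce.call_count == 2
    assert producer_instance.flush.call_count == 3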