doc/playbook.rst (3 changes: 2 additions, 1 deletion)
@@ -138,7 +138,8 @@ Add text as follows.

Select **Publish Release**.

The `ci-release.yaml <https://github.com/lsst-dm/prompt_processing/actions/workflows/ci-release.yaml>`_ GitHub Actions workflow uploads the new release to GitHub packages.
The `Release CI <https://github.com/lsst-dm/prompt_processing/actions/workflows/ci-release.yaml>`_ GitHub Actions workflow uploads the new release to GitHub Packages.
This may take a few minutes, and the release is not usable until it succeeds.

3. Tag the release

python/activator/activator.py (299 changes: 223 additions, 76 deletions)
@@ -34,6 +34,7 @@
import uuid
import yaml

import astropy.time
import boto3
from botocore.handlers import validate_bucket_name
import cloudevents.http
@@ -74,10 +75,12 @@
image_timeout = int(os.environ.get("IMAGE_TIMEOUT", 20))
# Absolute path on this worker's system where local repos may be created
local_repos = os.environ.get("LOCAL_REPOS", "/tmp")
# Kafka server
# Kafka server for raw notifications
kafka_cluster = os.environ["KAFKA_CLUSTER"]
# Kafka group; must be worker-unique to keep workers from "stealing" messages for others.
kafka_group_id = str(uuid.uuid4())
# The time (in seconds) after which to ignore old nextVisit messages.
visit_expire = float(os.environ.get("MESSAGE_EXPIRATION", 3600))
# The topic on which to listen for updates to image_bucket
bucket_topic = os.environ.get("BUCKET_TOPIC", "rubin-prompt-processing")
# Offset for Kafka bucket notification.
@@ -188,17 +191,171 @@ def _get_local_cache():
return make_local_cache()


def _make_redis_streams_client():
"""Create a new Redis client.
class RedisStreamSession:
"""The use of a single Redis Stream by a single consumer.

Returns
-------
redis_client : `redis.Redis`
Initialized Redis client.
A "session" may include multiple connections to Redis Streams. This object
automatically opens connections as needed, and closes them when they may no
longer be safe to use. Connections may also be closed manually by calling `close`.

Parameters
----------
host : `str`
The address of the Redis Streams cluster.
stream : `str`
The name of the stream to listen to.
consumer_id : `str`
The unique Redis Streams consumer for this session.
consumer_group : `str`
The Redis Stream consumer group.
connect : `bool`, optional
Whether to connect to ``stream`` on construction. If `False`, the
connection is deferred until a stream operation is needed.
"""
redis_host = redis_stream_host
redis_client = redis.Redis(host=redis_host)
return redis_client

# invariant: self.client is either an open Redis object, or `None`

def __init__(self, host, stream, consumer_id, consumer_group, *, connect=True):
self.host = host
self.stream = stream
self.groupname = consumer_group
self.consumername = consumer_id
self.client = None
if connect:
self._ensure_connection()

def _make_redis_streams_client(self):
"""Create a new Redis client.

Returns
-------
redis_client : `redis.Redis`
Initialized Redis client.
"""
return redis.Redis(host=self.host)

@staticmethod
def _close_on_error(func):
"""A decorator that closes the Redis client on a connection error.

This is a safety measure: if the caller aborts, the client needs to be
cleaned up; if they can recover, it's best to do so with a fresh
connection.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
# TODO Review Redis Errors and determine what should be retriable.
except redis.exceptions.RedisError:
self.close()
raise
return wrapper

@_close_on_error
def _ensure_connection(self):
"""Check for a valid connection to the stream, opening a new one
if necessary.

After this method returns, ``self.client`` is guaranteed non-`None`.

Raises
------
redis.exceptions.RedisError
Raised if the client could not connect or an existing connection
has gone bad.
"""
if not self.client:
self.client = self._make_redis_streams_client()
_log.debug("Redis Streams client setup")
self.client.ping()

@_close_on_error
def acknowledge(self, message_id):
"""Acknowledge receipt of a message.

Parameters
----------
message_id : `str`
The message to acknowledge.
"""
self._ensure_connection()
self.client.xack(self.stream, self.groupname, message_id)

def close(self):
"""Close the session's active connection, if it has one.

This method is idempotent.
"""
if self.client:
self.client.close()
self.client = None

@_close_on_error
def read_message(self):
"""Attempt to read one message from the stream.

Returns
-------
message_id : `str` or `None`
The Redis Streams message ID. `None` if no message was read.
message : `dict` [`str`, `str`]
The message contents. Empty if no message was read.

Raises
------
redis.exceptions.RedisError
Raised if the stream could not be read.
ValueError
Raised if a message was received but the message was invalid.
Invalid messages cannot be acknowledged (a valid message ID might
not exist) and do not close the stream.
"""
self._ensure_connection()
raw_message = self.client.xreadgroup(
streams={self.stream: ">"}, # ">" requests only messages never delivered to this group's consumers
consumername=self.consumername,
groupname=self.groupname,
count=1 # Read one message at a time
)

if not raw_message:
return None, {}
else:
return self._decode_redis_streams_message(raw_message)

@staticmethod
def _decode_redis_streams_message(fan_out_message):
"""Decode redis streams message from binary.

Parameters
----------
fan_out_message
Fan out message, as a list of dicts.

Returns
-------
redis_streams_message_id : `str`
Redis streams message id decoded from bytes.
fan_out_visit_decoded : `dict` [`str`, `str`]
Fan out visit message decoded from bytes.

Raises
------
ValueError
Raised if the message could not be decoded.
"""
try:
# Decode redis streams message id
redis_streams_message_id = (fan_out_message[0][1][0][0]).decode("utf-8")
# Decode and unpack fan out message from redis stream
fan_out_visit_bytes = fan_out_message[0][1][0][1]
fan_out_visit_decoded = {key.decode("utf-8"): value.decode("utf-8")
for key, value in fan_out_visit_bytes.items()}
return redis_streams_message_id, fan_out_visit_decoded
except (LookupError, UnicodeError) as e:
raise ValueError("Invalid redis stream message") from e
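

A minimal usage sketch of the session lifecycle, with hypothetical host, stream, and group names (the real values come from environment variables elsewhere in this module):

    session = RedisStreamSession(
        host="redis.example.org",            # hypothetical
        stream="next-visit-stream",          # hypothetical
        consumer_id=str(uuid.uuid4()),       # worker-unique, as in keda_start
        consumer_group="prompt-processing",  # hypothetical
    )
    try:
        message_id, payload = session.read_message()
        if message_id:
            session.acknowledge(message_id)
            visit = FannedOutVisit.from_dict(payload)
    finally:
        session.close()  # idempotent; safe even after an error path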


def _time_since(start_time):
@@ -217,30 +374,6 @@ def _time_since(start_time):
return time.time() - start_time


def _decode_redis_streams_message(fan_out_message):
"""Decode redis streams message from binary.

Parameters
----------
fan_out_message
Fan out message, as a list of dicts.

Returns
-------
redis_streams_message_id : `str`
Redis streams message id decoded from bytes.
fan_out_visit_decoded : `dict` [`str`, `str`]
Fan out visit message decoded from bytes.
"""
# Decode redis streams message id
redis_streams_message_id = (fan_out_message[0][1][0][0]).decode("utf-8")
# Decode and unpack fan out message from redis stream
fan_out_visit_bytes = fan_out_message[0][1][0][1]
fan_out_visit_decoded = {key.decode("utf-8"): value.decode("utf-8")
for key, value in fan_out_visit_bytes.items()}
return redis_streams_message_id, fan_out_visit_decoded


def _calculate_time_since_fan_out_message_delivered(redis_streams_message_id):
"""Calculates time from fan out message to when message is unpacked
in prompt processing.
Expand All @@ -261,6 +394,37 @@ def _calculate_time_since_fan_out_message_delivered(redis_streams_message_id):
return _time_since(message_timestamp/1000.0)
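
The division by 1000 relies on the Redis Streams ID convention: an ID is a millisecond Unix timestamp, a dash, and a sequence number. A sketch of the extraction, using a hypothetical ID (the collapsed lines above presumably compute the equivalent):

    # Redis Streams IDs are "<milliseconds>-<sequence>".
    message_id = "1526919030474-0"  # hypothetical example
    message_timestamp = int(message_id.split("-")[0])  # ms since the epoch
    age = _time_since(message_timestamp / 1000.0)  # convert ms to s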


def is_processable(visit, expire) -> bool:
"""Test whether a nextVisit message should be processed, or rejected out
of hand.

This function emits explanatory logs as a side effect.

Parameters
----------
visit : `FannedOutVisit`
The nextVisit message to consider processing.
expire : `float`
The maximum age, in seconds, beyond which a message is ignored.

Returns
-------
handleable : `bool`
`True` if the message can be processed, `False` otherwise.
"""
# sndStamp is the visit publication time, in seconds since 1970-01-01 TAI
# For expirations of a few minutes the TAI-UTC difference is significant!
published = astropy.time.Time(visit.private_sndStamp, format="unix_tai").utc.unix
age = round(_time_since(published)) # Microsecond precision is distracting
if age > expire:
_log.warning("Message published on %s UTC is %s old, ignoring.",
time.ctime(published),
astropy.time.TimeDelta(age, format='sec').quantity_str
)
return False
return True
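
A sketch of why the scale conversion matters, using only the astropy calls that appear above; the naive subtraction is off by the accumulated TAI-UTC difference (tens of seconds), which is comparable to a short expiration window:

    import time
    import astropy.time

    # Pretend a visit was published right now. sndStamp counts seconds since
    # 1970-01-01 TAI, so it runs tens of seconds ahead of a POSIX (UTC) clock.
    snd_stamp = astropy.time.Time.now().unix_tai
    published = astropy.time.Time(snd_stamp, format="unix_tai").utc.unix
    print(time.time() - snd_stamp)   # naive age: noticeably negative, i.e. wrong
    print(time.time() - published)   # converted age: ~0 seconds, as expected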


def create_app():
try:
setup_usdf_logger(
@@ -320,15 +484,13 @@ def keda_start():
_get_central_butler()
_get_local_repo()

# Setup redis client connection. Setup before while loop to avoid performance
# issues of constantly resetting up client connection
redis_client = _make_redis_streams_client()
try:
redis_client.ping()
except redis.exceptions.RedisError:
# Startup handler will quit; make sure the client is cleaned up for that.
redis_client.close()
raise
redis_session = RedisStreamSession(
redis_stream_host,
redis_stream_name,
redis_group_id,
redis_stream_consumer_group,
connect=True,
)

_log.info("Worker ready to handle requests.")

@@ -346,46 +508,37 @@ def keda_start():
and (_time_since(fan_out_listen_start_time) < fanned_out_msg_listen_timeout):

try:
fan_out_message = redis_client.xreadgroup(
streams={redis_stream_name: ">"}, # Read new messages (">" for pending messages)
consumername=redis_group_id,
groupname=redis_stream_consumer_group,
count=1 # Read one message at a time
)
redis_streams_message_id, fan_out_visit_decoded = redis_session.read_message()
processing_start = time.time()
processing_result = "Unknown"

if not fan_out_message:
if not redis_streams_message_id:
continue
else:

redis_streams_message_id, fan_out_visit_decoded = _decode_redis_streams_message(
fan_out_message)
# TODO: Revisit acknowledgement policy for old messages once fan-out service exists
redis_session.acknowledge(redis_streams_message_id)

# Ack the redis stream message and close redis stream client
# TODO Consider moving xack after process visit completes for catch up processing.
redis_client.xack(redis_stream_name,
redis_stream_consumer_group,
redis_streams_message_id)
redis_client.close()
expected_visit = FannedOutVisit.from_dict(fan_out_visit_decoded)
_log.debug("Unpacked message as %r.", expected_visit)
if is_processable(expected_visit, visit_expire):
# Processing can take a long time, and long-lived connections are ill-behaved
redis_session.close()
else:
continue

# TODO Review Redis Errors and determine what should be retriable.
except redis.exceptions.RedisError as e:
_log.critical("Redis Streams error; aborting.")
_log.exception(e)
redis_client.close()
sys.exit(1)
except (LookupError, UnicodeError) as e:
except ValueError as e:
_log.error("Invalid redis stream message %s", e)
fan_out_listen_start_time = time.time()
continue

with instances_processing_gauge.track_inprogress():
try:

expected_visit = FannedOutVisit.from_dict(fan_out_visit_decoded)
_log.debug("Unpacked message as %r.", expected_visit)

consumer_polls_with_message += 1
if consumer_polls_with_message >= 1:
fan_out_listen_time = _time_since(fan_out_listen_start_time)
@@ -417,18 +570,9 @@ def keda_start():
_time_since(processing_start), processing_result)
_log.info("Processing completed for %s.", socket.gethostname())

# Reset timer for fan out message polling and start redis client for next poll
# Reset timer for fan out message polling
_log.info("Starting next visit fan out event consumer poll")
fan_out_listen_start_time = time.time()
try:
redis_client = _make_redis_streams_client()
redis_client.ping()
_log.info("Redis Streams client setup for continued polling")
except Exception as e:
_log.critical("Redis Streams client unexpected error in continued polling; aborting")
_log.exception(e)
redis_client.close()
sys.exit(1)

if _time_since(fan_out_listen_start_time) >= fanned_out_msg_listen_timeout:
_log.info("No messages received in %f seconds, shutting down.", fanned_out_msg_listen_timeout)
@@ -617,8 +761,11 @@ def next_visit_handler() -> tuple[str, int]:
except ValueError as e:
_log.exception("Bad Request")
return f"Bad Request: {e}", 400
process_visit(expected_visit)
return "Pipeline executed", 200
if is_processable(expected_visit, visit_expire):
process_visit(expected_visit)
return "Pipeline executed", 200
else:
return "Stale request, ignoring", 403
except GracefulShutdownInterrupt:
# Safety net to minimize chance of interrupt propagating out of the worker.
# Ideally, this would be a Flask.errorhandler, but Flask ignores BaseExceptions.