Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions doc/playbook.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,22 @@ To create or edit the Cloud Run service in the Google Cloud Console:
* Select the container image URL from "Artifact Registry > prompt-proto-service"
* In the Variables & Secrets tab, set the following required parameters:

* RUBIN_INSTRUMENT: full instrument class name, including module path
* RUBIN_INSTRUMENT: the "short" instrument name
* PUBSUB_VERIFICATION_TOKEN: choose an arbitrary string matching the Pub/Sub endpoint URL below
* IMAGE_BUCKET: bucket containing raw images (``rubin-prompt-proto-main``)
* CALIB_REPO: repo containing calibrations (and templates)
* CALIB_REPO: URI to repo containing calibrations (and templates)
* IP_APDB: IP address and port of the APDB (see `Databases`_, below)
* IP_REGISTRY: IP address and port of the registry database (see `Databases`_)
* DB_APDB: PostgreSQL database name for the APDB
* DB_REGISTRY: PostgreSQL database name for the registry database

* There is also one optional parameter:
* There are also five optional parameters:

* IMAGE_TIMEOUT: timeout in seconds to wait for raw image, default 50 sec.
* LOCAL_REPOS: absolute path (in the container) where local repos are created, default ``/tmp``.
* USER_APDB: database user for the APDB, default "postgres"
* USER_REGISTRY: database user for the registry database, default "postgres"
* NAMESPACE_APDB: the database namespace for the APDB, defaults to the DB's default namespace

* One variable is set by Cloud Run and should not be overridden:

Expand Down
118 changes: 61 additions & 57 deletions python/activator/activator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,37 @@
import json
import logging
import os
import re
import time
from typing import Optional, Tuple

from flask import Flask, request
from google.cloud import pubsub_v1, storage

from lsst.daf.butler import Butler
from lsst.obs.base import Instrument
from .logger import GCloudStructuredLogFormatter
from .logger import setup_google_logger
from .make_pgpass import make_pgpass
from .middleware_interface import MiddlewareInterface
from .raw import RAW_REGEXP
from .middleware_interface import get_central_butler, MiddlewareInterface
from .raw import Snap
from .visit import Visit

PROJECT_ID = "prompt-proto"

verification_token = os.environ["PUBSUB_VERIFICATION_TOKEN"]
# The full instrument class name, including module path.
config_instrument = os.environ["RUBIN_INSTRUMENT"]
active_instrument = Instrument.from_string(config_instrument)
# The short name for the instrument.
instrument_name = os.environ["RUBIN_INSTRUMENT"]
# URI to the main repository containing calibs and templates
calib_repo = os.environ["CALIB_REPO"]
# Bucket name (not URI) containing raw images
image_bucket = os.environ["IMAGE_BUCKET"]
# Time to wait for raw image upload, in seconds.
# Environment variables are always strings, so convert explicitly; otherwise a
# configured IMAGE_TIMEOUT would be a `str` while the default is an `int`.
timeout = int(os.environ.get("IMAGE_TIMEOUT", 50))
# Absolute path on this worker's system where local repos may be created
local_repos = os.environ.get("LOCAL_REPOS", "/tmp")

# Set up logging for all modules used by this worker.
log_handler = logging.StreamHandler()
log_handler.setFormatter(GCloudStructuredLogFormatter(
labels={"instrument": active_instrument.getName()},
))
logging.basicConfig(handlers=[log_handler])
setup_google_logger(
labels={"instrument": instrument_name},
)
_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)
logging.captureWarnings(True)


# Write PostgreSQL credentials.
Expand All @@ -71,25 +68,17 @@
subscriber = pubsub_v1.SubscriberClient()
topic_path = subscriber.topic_path(
PROJECT_ID,
f"{active_instrument.getName()}-image",
f"{instrument_name}-image",
)
subscription = None

storage_client = storage.Client()

# Initialize middleware interface.
# TODO: this should not be done in activator.py, which is supposed to have only
# framework/messaging support (ideally, it should not contain any LSST imports).
# However, we don't want MiddlewareInterface to need to know details like where
# the central repo is located, either, so perhaps we need a new module.
central_butler = Butler(calib_repo,
collections=[active_instrument.makeCollectionName("defaults")],
writeable=True,
inferDefaults=False)
repo = f"/tmp/butler-{os.getpid()}"
butler = Butler(Butler.makeRepo(repo), writeable=True)
_log.info("Created local Butler repo at %s.", repo)
mwi = MiddlewareInterface(central_butler, image_bucket, config_instrument, butler)
mwi = MiddlewareInterface(get_central_butler(calib_repo, instrument_name),
image_bucket,
instrument_name,
local_repos)


def check_for_snap(
Expand Down Expand Up @@ -123,6 +112,35 @@ def check_for_snap(
return blobs[0].name


def parse_next_visit(http_request):
    """Parse a next_visit event and extract its data.

    Parameters
    ----------
    http_request : `flask.Request`
        The request to be parsed.

    Returns
    -------
    next_visit : `activator.visit.Visit`
        The next_visit message contained in the request.

    Raises
    ------
    ValueError
        Raised if ``http_request`` is not a valid message. This includes a
        missing or malformed envelope, a missing or undecodable ``data``
        field, a payload that is not a JSON object, and a payload whose
        fields cannot be used to construct a `Visit`.
    """
    envelope = http_request.get_json()
    if not envelope:
        raise ValueError("no Pub/Sub message received")
    if not isinstance(envelope, dict) or "message" not in envelope:
        raise ValueError("invalid Pub/Sub message format")

    # A missing "data" key (KeyError), or a "message"/"data" of the wrong
    # type (TypeError), is a malformed message; report it the same way as
    # the other format problems so callers need only catch ValueError.
    # Bad base64 already raises binascii.Error, a ValueError subclass.
    try:
        payload = base64.b64decode(envelope["message"]["data"])
    except (KeyError, TypeError) as e:
        raise ValueError("invalid Pub/Sub message format") from e

    # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
    data = json.loads(payload)
    if not isinstance(data, dict):
        raise ValueError("next_visit payload is not a JSON object")

    # Missing or unexpected fields surface as TypeError from the constructor.
    try:
        return Visit(**data)
    except TypeError as e:
        raise ValueError(f"invalid next_visit payload: {e}") from e


@app.route("/next-visit", methods=["POST"])
def next_visit_handler() -> Tuple[str, int]:
"""A Flask view function for handling next-visit events.
Expand All @@ -145,22 +163,13 @@ def next_visit_handler() -> Tuple[str, int]:
)
_log.debug(f"Created subscription '{subscription.name}'")
try:
envelope = request.get_json()
if not envelope:
msg = "no Pub/Sub message received"
try:
expected_visit = parse_next_visit(request)
except ValueError as msg:
_log.warn(f"error: '{msg}'")
return f"Bad Request: {msg}", 400

if not isinstance(envelope, dict) or "message" not in envelope:
msg = "invalid Pub/Sub message format"
_log.warn(f"error: '{msg}'")
return f"Bad Request: {msg}", 400

payload = base64.b64decode(envelope["message"]["data"])
data = json.loads(payload)
expected_visit = Visit(**data)
assert expected_visit.instrument == active_instrument.getName(), \
f"Expected {active_instrument.getName()}, received {expected_visit.instrument}."
assert expected_visit.instrument == instrument_name, \
f"Expected {instrument_name}, received {expected_visit.instrument}."
expid_set = set()

# Copy calibrations for this detector/visit
Expand All @@ -175,9 +184,10 @@ def next_visit_handler() -> Tuple[str, int]:
expected_visit.detector,
)
if oid:
m = re.match(RAW_REGEXP, oid)
raw_info = Snap.from_oid(oid)
_log.debug("Found %r already present", raw_info)
mwi.ingest_image(expected_visit, oid)
expid_set.add(m.group('expid'))
expid_set.add(raw_info.exp_id)

_log.debug(f"Waiting for snaps from {expected_visit}.")
start = time.time()
Expand All @@ -202,20 +212,14 @@ def next_visit_handler() -> Tuple[str, int]:
for received in response.received_messages:
ack_list.append(received.ack_id)
oid = received.message.attributes["objectId"]
m = re.match(RAW_REGEXP, oid)
if m:
instrument, detector, group, snap, expid = m.groups()
_log.debug("instrument, detector, group, snap, expid = %s", m.groups())
if (
instrument == expected_visit.instrument
and int(detector) == int(expected_visit.detector)
and group == str(expected_visit.group)
and int(snap) < int(expected_visit.snaps)
):
try:
raw_info = Snap.from_oid(oid)
_log.debug("Received %r", raw_info)
if raw_info.is_consistent(expected_visit):
# Ingest the snap
mwi.ingest_image(oid)
expid_set.add(expid)
else:
expid_set.add(raw_info.exp_id)
except ValueError:
_log.error(f"Failed to match object id '{oid}'")
subscriber.acknowledge(subscription=subscription.name, ack_ids=ack_list)

Expand Down
28 changes: 27 additions & 1 deletion python/activator/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,38 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["GCloudStructuredLogFormatter"]
__all__ = ["GCloudStructuredLogFormatter", "setup_google_logger"]

import json
import logging


# TODO: replace with something more extensible, once we know what needs to
# vary besides the formatter (handler type?).
def setup_google_logger(labels=None):
    """Set global logging settings for prompt_prototype.

    A stream handler formatted by `GCloudStructuredLogFormatter` is passed
    to `logging.basicConfig`, and Python warnings are captured so that they
    are emitted through the logging system as well.

    Parameters
    ----------
    labels : `dict` [`str`, `str`], optional
        Any metadata that should be attached to all logs. See
        ``LogEntry.labels`` in Google Cloud REST API documentation.

    Returns
    -------
    handler : `logging.Handler`
        The handler passed to the root logger configuration.
    """
    structured_handler = logging.StreamHandler()
    formatter = GCloudStructuredLogFormatter(labels)
    structured_handler.setFormatter(formatter)

    # basicConfig installs the handler on the root logger; it leaves an
    # already-configured root logger untouched.
    logging.basicConfig(handlers=[structured_handler])
    # Route warnings.warn() output through logging so it is formatted too.
    logging.captureWarnings(True)
    return structured_handler


class GCloudStructuredLogFormatter(logging.Formatter):
"""A formatter that can be parsed by the Google Cloud logging agent.
Expand Down
9 changes: 6 additions & 3 deletions python/activator/make_pgpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import stat


PSQL_DB = "postgres"
PSQL_USER = "postgres"


Expand All @@ -44,15 +43,19 @@ def make_pgpass():
"""
try:
ip_apdb = os.environ["IP_APDB"]
db_apdb = os.environ["DB_APDB"]
user_apdb = os.environ.get("USER_APDB", PSQL_USER)
pass_apdb = os.environ["PSQL_APDB_PASS"]
ip_registry = os.environ["IP_REGISTRY"]
db_registry = os.environ["DB_REGISTRY"]
user_registry = os.environ.get("USER_REGISTRY", PSQL_USER)
pass_registry = os.environ["PSQL_REGISTRY_PASS"]
except KeyError as e:
raise RuntimeError("Addresses and passwords have not been configured") from e

filename = os.path.join(os.environ["HOME"], ".pgpass")
with open(filename, mode="wt") as file:
file.write(f"{ip_apdb}:{PSQL_DB}:{PSQL_USER}:{pass_apdb}\n")
file.write(f"{ip_registry}:{PSQL_DB}:{PSQL_USER}:{pass_registry}\n")
file.write(f"{ip_apdb}:{db_apdb}:{user_apdb}:{pass_apdb}\n")
file.write(f"{ip_registry}:{db_registry}:{user_registry}:{pass_registry}\n")
# Only user may access the file
os.chmod(filename, stat.S_IRUSR)
Loading