diff --git a/tools/tokenserver/process_account_events.py b/tools/tokenserver/process_account_events.py
index 69cf700cf7..648d8f9cb8 100644
--- a/tools/tokenserver/process_account_events.py
+++ b/tools/tokenserver/process_account_events.py
@@ -40,17 +40,25 @@
 from database import Database
 
-logger = logging.getLogger("tokenserver.scripts.process_account_deletions")
+# Logging is initialized in `main` by `util.configure_script_logging()`
+# Please do not call `logging.basicConfig()` before then, since this may
+# cause duplicate error messages to be generated.
+APP_LABEL = "tokenserver.scripts.process_account_events"
 
 
-def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
+def process_account_events(
+        queue_name,
+        aws_region=None,
+        queue_wait_time=20,
+        metrics=None):
     """Process account events from an SQS queue.
 
     This function polls the specified SQS queue for account-realted events,
     processing each as it is found.  It polls indefinitely and does not
     return; to interrupt execution you'll need to e.g. SIGINT the process.
     """
-    logger.info("Processing account events from %s", queue_name)
+    logger = logging.getLogger(APP_LABEL)
+    logger.info(f"Processing account events from {queue_name}")
     database = Database()
     try:
         # Connect to the SQS queue.
@@ -69,7 +77,7 @@ def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
             msg = queue.read(wait_time_seconds=queue_wait_time)
             if msg is None:
                 continue
-            process_account_event(database, msg.get_body())
+            process_account_event(database, msg.get_body(), metrics=metrics)
             # This intentionally deletes the event even if it was some
             # unrecognized type.  Not point leaving a backlog.
             queue.delete_message(msg)
@@ -78,9 +86,10 @@ def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
         raise
 
 
-def process_account_event(database, body):
+def process_account_event(database, body, metrics=None):
     """Parse and process a single account event."""
     # Try very hard not to error out if there's junk in the queue.
+    logger = logging.getLogger(APP_LABEL)
     email = None
     event_type = None
     generation = None
@@ -105,23 +114,30 @@ def process_account_event(database, body):
         logger.exception("Invalid account message: %s", e)
     else:
         if email is not None:
-            if event_type == "delete":
-                # Mark the user as retired.
-                # Actual cleanup is done by a separate process.
-                logger.info("Processing account delete for %r", email)
-                database.retire_user(email)
-            elif event_type == "reset":
-                logger.info("Processing account reset for %r", email)
-                update_generation_number(database, email, generation)
-            elif event_type == "passwordChange":
-                logger.info("Processing password change for %r", email)
-                update_generation_number(database, email, generation)
-            else:
-                logger.warning("Dropping unknown event type %r",
-                               event_type)
-
-
-def update_generation_number(database, email, generation):
+            record_metric = True
+            match event_type:
+                case "delete":
+                    # Mark the user as retired.
+                    # Actual cleanup is done by a separate process.
+ logger.info("Processing account delete for %r", email) + database.retire_user(email) + case "reset": + logger.info("Processing account reset for %r", email) + update_generation_number( + database, email, generation, metrics=metrics) + case "passwordChange": + logger.info("Processing password change for %r", email) + update_generation_number( + database, email, generation, metrics=metrics) + case _: + record_metric = False + logger.warning("Dropping unknown event type %r", + event_type) + if record_metric and metrics: + metrics.incr(event_type) + + +def update_generation_number(database, email, generation, metrics=None): """Update the maximum recorded generation number for the given user. When the FxA server sends us an update to the user's generation @@ -145,6 +161,8 @@ def update_generation_number(database, email, generation): user = database.get_user(email) if user is not None: database.update_user(user, generation - 1) + if metrics: + metrics.incr("decr_generation") def main(args=None): @@ -161,17 +179,41 @@ def main(args=None): help="Number of seconds to wait for jobs on the queue") parser.add_option("-v", "--verbose", action="count", dest="verbosity", help="Control verbosity of log messages") + parser.add_option("", "--human_logs", action="store_true", + help="Human readable logs") + parser.add_option( + "", + "--metric_host", + default=None, + help="Metric host name" + ) + parser.add_option( + "", + "--metric_port", + default=None, + help="Metric host port" + ) opts, args = parser.parse_args(args) + # set up logging + logger = util.configure_script_logging(opts, logger_name=APP_LABEL) + + logger.info("Starting up..") + + # set up metrics: + metrics = util.Metrics(opts, namespace="tokenserver") + if len(args) != 1: parser.print_usage() return 1 - util.configure_script_logging(opts) - queue_name = args[0] - process_account_events(queue_name, opts.aws_region, opts.queue_wait_time) + process_account_events( + queue_name, + opts.aws_region, + opts.queue_wait_time, + metrics=metrics) return 0 diff --git a/tools/tokenserver/purge_old_records.py b/tools/tokenserver/purge_old_records.py index 89f90a59f6..930b758631 100644 --- a/tools/tokenserver/purge_old_records.py +++ b/tools/tokenserver/purge_old_records.py @@ -15,11 +15,11 @@ """ +import backoff import binascii import hawkauthlib import logging import optparse -import os import random import requests import time @@ -29,12 +29,11 @@ from database import Database from util import format_key_id -LOGGER = "tokenserver.scripts.purge_old_records" -logger = logging.getLogger(LOGGER) -log_level = os.environ.get("PYTHON_LOG", "INFO").upper() -logger.setLevel(log_level) -logger.debug(f"Setting level to {log_level}") +# Logging is initialized in `main` by `util.configure_script_logging()` +# Please do not call `logging.basicConfig()` before then, since this may +# cause duplicate error messages to be generated. +LOGGER = "tokenserver.scripts.purge_old_records" PATTERN = "{node}/1.5/{uid}" @@ -50,6 +49,7 @@ def purge_old_records( force=False, override_node=None, uid_range=None, + metrics=None, ): """Purge old records from the database. @@ -63,6 +63,7 @@ def purge_old_records( a (likely) different set of records to work on. A cheap, imperfect randomization. 
""" + logger = logging.getLogger(LOGGER) logger.info("Purging old user records") try: database = Database() @@ -103,7 +104,11 @@ def purge_old_records( row.node ) if not dryrun: - database.delete_user_record(row.uid) + if metrics: + metrics.incr( + "delete_user", + tags={"type": "nodeless"}) + retryable(database.delete_user_record, row.uid) # NOTE: only delete_user+service_data calls count # against the counter elif not row.downed: @@ -112,17 +117,28 @@ def purge_old_records( row.uid, row.node) if not dryrun: - delete_service_data( + retryable( + delete_service_data, row, secret, timeout=request_timeout, - dryrun=dryrun + dryrun=dryrun, + metrics=metrics, ) - database.delete_user_record(row.uid) + if metrics: + metrics.incr("delete_data") + retryable( + database.delete_user_record, + row.uid) + if metrics: + metrics.incr( + "delete_user", + tags={"type": "not_down"} + ) counter += 1 elif force: delete_sd = not points_to_active( - database, row, override_node) + database, row, override_node, metrics=metrics) logger.info( "Forcing tokenserver record delete: " f"{row.uid} on {row.node} " @@ -137,29 +153,36 @@ def purge_old_records( # request refers to a node not contained by # the existing data set. # (The call mimics a user DELETE request.) - try: - delete_service_data( + retryable( + delete_service_data, row, secret, timeout=request_timeout, dryrun=dryrun, # if an override was specifed, # use that node ID - override_node=override_node - ) - except requests.HTTPError: - logger.warn( - "Delete failed for user " - f"{row.uid} [{row.node}]" + override_node=override_node, + metrics=metrics, ) - if override_node: - # Assume the override_node should be - # reachable - raise - database.delete_user_record(row.uid) + if metrics: + metrics.incr( + "delete_data", + tags={"type": "force"} + ) + + retryable( + database.delete_user_record, + row.uid) + if metrics: + metrics.incr( + "delete_data", + tags={"type": "force"} + ) counter += 1 if max_records and counter >= max_records: logger.info("Reached max_records, exiting") + if metrics: + metrics.incr("max_records") return True if len(rows) < max_per_loop: break @@ -172,7 +195,8 @@ def purge_old_records( def delete_service_data( - user, secret, timeout=60, dryrun=False, override_node=None): + user, secret, timeout=60, dryrun=False, override_node=None, + metrics=None): """Send a data-deletion request to the user's service node. This is a little bit of hackery to cause the user's service node to @@ -202,10 +226,25 @@ def delete_service_data( return resp = requests.delete(endpoint, auth=auth, timeout=timeout) if resp.status_code >= 400 and resp.status_code != 404: + if metrics: + metrics.incr("error.gone") resp.raise_for_status() -def points_to_active(database, replaced_at_row, override_node): +def retry_giveup(e): + return 500 <= e.response.status_code < 505 + + +@backoff.on_exception( + backoff.expo, + requests.HTTPError, + giveup=retry_giveup + ) +def retryable(fn, *args, **kwargs): + fn(*args, **kwargs) + + +def points_to_active(database, replaced_at_row, override_node, metrics=None): """Determine if a `replaced_at` user record has the same generation/client_state as their active record. 
@@ -232,7 +271,10 @@ def points_to_active(database, replaced_at_row, override_node):
             replaced_at_row_keys_changed_at or replaced_at_row.generation,
             binascii.unhexlify(replaced_at_row.client_state),
         )
-        return user_fxa_kid == replaced_at_row_fxa_kid
+        override = user_fxa_kid == replaced_at_row_fxa_kid
+        if override and metrics:
+            metrics.incr("override")
+        return override
     return False
 
 
@@ -254,7 +296,6 @@ def main(args=None):
     This function parses command-line arguments and passes them on
     to the purge_old_records() function.
     """
-    logger = logging.getLogger(LOGGER)
     usage = "usage: %prog [options] secret"
     parser = optparse.OptionParser(usage=usage)
     parser.add_option(
@@ -339,9 +380,33 @@
         default=None,
         help="End of UID range to check"
     )
+    parser.add_option(
+        "",
+        "--human_logs",
+        action="store_true",
+        help="Human readable logs"
+    )
+    parser.add_option(
+        "",
+        "--metric_host",
+        default=None,
+        help="Metric host name"
+    )
+    parser.add_option(
+        "",
+        "--metric_port",
+        default=None,
+        help="Metric host port"
+    )
 
     opts, args = parser.parse_args(args)
 
+    # set up logging
+    logger = util.configure_script_logging(opts, logger_name=LOGGER)
+
+    # set up metrics:
+    metrics = util.Metrics(opts, namespace="tokenserver")
+
     if len(args) == 0:
         parser.print_usage()
         return 1
@@ -350,8 +415,6 @@
     secret = args[-1]
     logger.debug(f"Secret: {secret}")
 
-    util.configure_script_logging(opts)
-
     uid_range = None
     if opts.range_start or opts.range_end:
         uid_range = (opts.range_start, opts.range_end)
@@ -368,6 +431,7 @@
         force=opts.force,
         override_node=opts.override_node,
         uid_range=uid_range,
+        metrics=metrics,
     )
     if not opts.oneshot:
         while True:
@@ -388,6 +452,7 @@
                 force=opts.force,
                 override_node=opts.override_node,
                 uid_range=uid_range,
+                metrics=metrics,
             )
 
     return 0
diff --git a/tools/tokenserver/requirements.txt b/tools/tokenserver/requirements.txt
index 2fc4fa3752..809a747315 100644
--- a/tools/tokenserver/requirements.txt
+++ b/tools/tokenserver/requirements.txt
@@ -6,3 +6,6 @@ sqlalchemy==1.4.46
 testfixtures
 tokenlib==2.0.0
 PyBrowserID==0.14.0
+datadog
+backoff
+
diff --git a/tools/tokenserver/util.py b/tools/tokenserver/util.py
index 2810da51e7..c87857bd50 100644
--- a/tools/tokenserver/util.py
+++ b/tools/tokenserver/util.py
@@ -10,6 +10,11 @@
 import sys
 import time
 import logging
+import os
+import json
+from datetime import datetime
+
+from datadog import initialize, statsd
 
 from browserid.utils import encode_bytes as encode_bytes_b64
 
@@ -23,26 +28,58 @@ def run_script(main):
     sys.exit(exitcode)
 
 
-def configure_script_logging(opts=None):
+def configure_script_logging(opts=None, logger_name=""):
     """Configure stdlib logging to produce output from the script.
 
     This basically configures logging to send messages to stderr, with
     formatting that's more for human readability than machine parsing.
     It also takes care of the --verbosity command-line option.
""" - if not opts or not opts.verbosity: - loglevel = logging.WARNING - elif opts.verbosity == 1: - loglevel = logging.INFO + + verbosity = ( + opts and getattr( + opts, "verbosity", logging.NOTSET)) or logging.NOTSET + logger = logging.getLogger(logger_name) + level = os.environ.get("PYTHON_LOG", "").upper() or \ + max(logging.DEBUG, logging.WARNING - (verbosity * 10)) or \ + logger.getEffectiveLevel() + + # if we've previously setup a handler, adjust it instead + if logger.hasHandlers(): + handler = logger.handlers[0] else: - loglevel = logging.DEBUG + handler = logging.StreamHandler() - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter("%(message)s")) - handler.setLevel(loglevel) + formatter = GCP_JSON_Formatter() + # if we've opted for "human_logs", specify a simpler message. + if opts: + if getattr(opts, "human_logs", None): + formatter = logging.Formatter( + "{levelname:<8s}: {message}", + style="{") + + handler.setFormatter(formatter) + handler.setLevel(level) logger = logging.getLogger("") logger.addHandler(handler) - logger.setLevel(loglevel) + logger.setLevel(level) + return logger + + +# We need to reformat a few things to get the record to display correctly +# This includes "escaping" the message as well as converting the timestamp +# into a parsable format. +class GCP_JSON_Formatter(logging.Formatter): + + def format(self, record): + return json.dumps({ + "severity": record.levelname, + "message": record.getMessage(), + "timestamp": datetime.fromtimestamp( + record.created).strftime( + "%Y-%m-%dT%H:%M:%SZ" # RFC3339 + ), + }) def format_key_id(keys_changed_at, key_hash): @@ -56,3 +93,21 @@ def format_key_id(keys_changed_at, key_hash): def get_timestamp(): """Get current timestamp in milliseconds.""" return int(time.time() * 1000) + + +class Metrics(): + + def __init__(self, opts, namespace=""): + options = dict( + namespace=namespace, + statsd_namespace=namespace, + statsd_host=getattr( + opts, "metric_host", os.environ.get("METRIC_HOST")), + statsd_port=getattr( + opts, "metric_port", os.environ.get("METRIC_PORT")), + ) + self.prefix = options.get("namespace") + initialize(**options) + + def incr(self, label, tags=None): + statsd.increment(label, tags=tags)