From 43d05aab661de246cee077d560b9be73c5aff14b Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Mon, 24 Jun 2024 15:16:56 -0700 Subject: [PATCH 01/25] Sketch out guard log. --- guardrails/guard_call_logging.py | 91 ++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 guardrails/guard_call_logging.py diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py new file mode 100644 index 000000000..de725419f --- /dev/null +++ b/guardrails/guard_call_logging.py @@ -0,0 +1,91 @@ +import os +import sqlite3 +import threading +from dataclasses import dataclass, fields + + +@dataclass +class GuardCallLogEntry: + # Keep in sync with the table creation. + guardname: str + start_time: float + end_time: float + prevalidate_text: str + postvalidate_text: str + exception_text: str + log_level: int + + +class _SyncStructuredLogHandler: + LOG_TABLE = "guard_logs" + CREATE_COMMAND = """""" + INSERT_COMMAND = """ + INSERT INTO ? VALUES ( + :guard_name, :start_time, :end_time, :prevalidate_text, :exception_text, + :log_level + ) + """ + + def __init__(self, default_log_path: os.PathLike): + self.db = sqlite3.connect(default_log_path, ) + cursor = self.db.cursor() + # Generate table rows from GuardCallLogEntry. + create_fields = "" + for field in fields(GuardCallLogEntry): + create_fields += field.name + create_fields += " " + if field.type == int: + create_fields += "INTEGER" + elif field.type == float: + create_fields += "REAL" + elif field.type == str: + create_fields += "TEXT" + create_fields += "," + create_fields.removesuffix(",") # Remove the spurious trailing ','. 
+ cursor.execute( + "CREATE TABLE IF NOT EXISTS " + + _SyncStructuredLogHandler.LOG_TABLE + + f"({fields});" + ) + + def log_entry(self, entry: GuardCallLogEntry): + cursor = self.db.cursor() + cursor.execute("""INSERT""") + + def log( + self, + guard_name: str, + start_time: float, + end_time: float, + prevalidate_text: str, + exception_text: str, + log_level: int, + ): + cursor = self.db.cursor() + cursor.execute(INSERT_COMMAND, ( + _SyncStructuredLogHandler.LOG_TABLE, + guard_name, + start_time, + end_time, + prevalidate_text, + exception_text, + log_level + )) + + +class SyncStructuredLogHandlerSingleton(_SyncStructuredLogHandler): + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: # Yes, two 'if' checks to avoid mutex contention. + # This only runs if we definitely need to lock. + with cls._lock: + if cls._instance is None: + cls._instance = cls.create() + return cls._instance + + @classmethod + def create(cls) -> _SyncStructuredLogHandler: + return _SyncStructuredLogHandler() + From 35df11cd7c8d905a4f4a9f621f0246aa80889aae Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Tue, 25 Jun 2024 12:01:01 -0700 Subject: [PATCH 02/25] Add some experimental code sketches. --- guardrails/guard_call_logging.py | 138 ++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 47 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index de725419f..be2c77781 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -1,56 +1,79 @@ import os import sqlite3 import threading -from dataclasses import dataclass, fields +from dataclasses import dataclass +from typing import Iterator, Optional @dataclass -class GuardCallLogEntry: - # Keep in sync with the table creation. 
- guardname: str +class GuardLogEntry: + guard_name: str start_time: float end_time: float - prevalidate_text: str - postvalidate_text: str - exception_text: str log_level: int + id: int = -1 + prevalidate_text: str = "" + postvalidate_text: str = "" + exception_message: str = "" class _SyncStructuredLogHandler: - LOG_TABLE = "guard_logs" - CREATE_COMMAND = """""" + CREATE_COMMAND = """ + CREATE TABLE IF NOT EXISTS guard_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + guard_name TEXT, + start_time REAL, + end_time REAL, + prevalidate_text TEXT, + postvalidate_text TEXT, + exception_message TEXT, + log_level INTEGER + ); + """ INSERT_COMMAND = """ - INSERT INTO ? VALUES ( - :guard_name, :start_time, :end_time, :prevalidate_text, :exception_text, - :log_level - ) + INSERT INTO guard_logs ( + guard_name, start_time, end_time, prevalidate_text, postvalidate_text, + exception_message, log_level + ) VALUES ( + :guard_name, :start_time, :end_time, :prevalidate_text, :postvalidate_text, + :exception_message, :log_level + ); """ - def __init__(self, default_log_path: os.PathLike): - self.db = sqlite3.connect(default_log_path, ) - cursor = self.db.cursor() - # Generate table rows from GuardCallLogEntry. - create_fields = "" - for field in fields(GuardCallLogEntry): - create_fields += field.name - create_fields += " " - if field.type == int: - create_fields += "INTEGER" - elif field.type == float: - create_fields += "REAL" - elif field.type == str: - create_fields += "TEXT" - create_fields += "," - create_fields.removesuffix(",") # Remove the spurious trailing ','. 
- cursor.execute( - "CREATE TABLE IF NOT EXISTS " + - _SyncStructuredLogHandler.LOG_TABLE + - f"({fields});" - ) + def __init__(self, log_path: os.PathLike, read_mode: bool): + self.readonly = read_mode + if read_mode: + self.db = _SyncStructuredLogHandler._get_read_connection(log_path) + else: + self.db = _SyncStructuredLogHandler._get_write_connection(log_path) - def log_entry(self, entry: GuardCallLogEntry): - cursor = self.db.cursor() - cursor.execute("""INSERT""") + @classmethod + def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: + try: + db = sqlite3.connect(log_path, isolation_level=None) + db.execute('PRAGMA journal_mode = wal') + db.execute('PRAGMA synchronous = OFF') + # isolation_level = None and pragma WAL means we can READ from the DB + # while threads using it are writing. Synchronous off puts us on the + # highway to the danger zone, depending on how willing we are to lose log + # messages in the event of a guard crash. + except sqlite3.OperationalError as e: + #logging.exception("Unable to connect to guard log handler.") + raise e + with db: + db.execute(_SyncStructuredLogHandler.CREATE_COMMAND) + return db + + @classmethod + def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: + # A bit of a hack to open in read-only mode... 
+ db = sqlite3.connect( + "file:" + log_path + "?mode=ro", + isolation_level=None, + uri=True + ) + db.row_factory = sqlite3.Row + return db def log( self, @@ -58,20 +81,41 @@ def log( start_time: float, end_time: float, prevalidate_text: str, + postvalidate_text: str, exception_text: str, log_level: int, ): - cursor = self.db.cursor() - cursor.execute(INSERT_COMMAND, ( - _SyncStructuredLogHandler.LOG_TABLE, - guard_name, - start_time, - end_time, - prevalidate_text, - exception_text, - log_level - )) + assert not self.readonly + with self.db: + self.db.execute(_SyncStructuredLogHandler.INSERT_COMMAND, dict( + guard_name=guard_name, + start_time=start_time, + end_time=end_time, + prevalidate_text=prevalidate_text, + postvalidate_text=postvalidate_text, + exception_message=exception_text, + log_level=log_level + )) + def tail_logs(self, start_offset_idx: int) -> Iterator[GuardLogEntry]: + last_idx = start_offset_idx + cursor = self.db.cursor() + sql = """ + SELECT + guard_name, start_time, end_time, prevalidate_text, postvalidate_text, + exception_message, log_level + FROM guard_logs + WHERE id > ? + ORDER BY start_time; + """ + cursor.execute("SELECT 1 LIMIT 0;") + while True: + for row in cursor: + last_entry = GuardLogEntry(**row) + last_idx = last_entry.id + yield last_entry + # If we're here we've run out of entries to tail. + cursor.execute(sql, (last_idx,)) class SyncStructuredLogHandlerSingleton(_SyncStructuredLogHandler): _instance = None From 9598ab313fd9fcb7b689b2ade816280083d48564 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Tue, 25 Jun 2024 13:19:52 -0700 Subject: [PATCH 03/25] Add docstrings. Small improvements. 
--- guardrails/guard_call_logging.py | 58 +++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index be2c77781..1abc77e9a 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -1,8 +1,37 @@ +""" +guard_call_logging.py + +A set of tools to track the behavior of guards, specifically with the intent of +collating the pre/post validation text and timing of guard calls. Uses a singleton to +share write access to a SQLite database across threads. + +# Reading logs (basic): +reader = SyncStructuredLogHandlerSingleton.get_reader() +for t in reader.tail_logs(): + print(t) + +# Reading logs (advanced): +reader = SyncStructuredLogHandlerSingleton.get_reader() +reader.db.execute("SELECT * FROM guard_logs;") # Arbitrary SQL support. + +# Saving logs +writer = SynbcStructuredLogHandlerSingleton() +writer.log( + "my_guard_name", start, end, "Raw LLM Output Text", "Sanitized", "exception?", 0 +) + +""" + import os import sqlite3 import threading -from dataclasses import dataclass -from typing import Iterator, Optional +from dataclasses import dataclass, asdict +from typing import Iterator + +# We should support logging a validation outcome, too. 
+# from guardrails.classes import ValidationOutcome + +LOG_FILENAME = "guardrails_calls.db" @dataclass @@ -97,13 +126,21 @@ def log( log_level=log_level )) - def tail_logs(self, start_offset_idx: int) -> Iterator[GuardLogEntry]: + def log_entry(self, guard_log_entry): + assert not self.readonly + with self.db: + self.db.execute( + _SyncStructuredLogHandler.INSERT_COMMAND, + asdict(guard_log_entry) + ) + + def tail_logs(self, start_offset_idx: int = 0) -> Iterator[GuardLogEntry]: last_idx = start_offset_idx cursor = self.db.cursor() sql = """ SELECT - guard_name, start_time, end_time, prevalidate_text, postvalidate_text, - exception_message, log_level + id, guard_name, start_time, end_time, prevalidate_text, + postvalidate_text, exception_message, log_level FROM guard_logs WHERE id > ? ORDER BY start_time; @@ -117,6 +154,7 @@ def tail_logs(self, start_offset_idx: int) -> Iterator[GuardLogEntry]: # If we're here we've run out of entries to tail. cursor.execute(sql, (last_idx,)) + class SyncStructuredLogHandlerSingleton(_SyncStructuredLogHandler): _instance = None _lock = threading.Lock() @@ -126,10 +164,14 @@ def __new__(cls): # This only runs if we definitely need to lock. with cls._lock: if cls._instance is None: - cls._instance = cls.create() + cls._instance = cls._create() return cls._instance @classmethod - def create(cls) -> _SyncStructuredLogHandler: - return _SyncStructuredLogHandler() + def _create(cls) -> _SyncStructuredLogHandler: + return _SyncStructuredLogHandler(LOG_FILENAME, read_mode=False) + + @classmethod + def get_reader(cls) -> _SyncStructuredLogHandler: + return _SyncStructuredLogHandler(LOG_FILENAME, read_mode=True) From dd024364712d7a3472c64c17161a5b4dd89fd5eb Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Tue, 25 Jun 2024 15:43:46 -0700 Subject: [PATCH 04/25] Update: this is no longer the approach we're going to use. We're switching to using the telemetry. 
--- guardrails/guard_call_logging.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 1abc77e9a..658f35ebc 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -25,11 +25,13 @@ import os import sqlite3 import threading +import time from dataclasses import dataclass, asdict from typing import Iterator -# We should support logging a validation outcome, too. -# from guardrails.classes import ValidationOutcome +from guardrails.classes import ValidationOutcome +from guardrails.classes.history import Call + LOG_FILENAME = "guardrails_calls.db" @@ -175,3 +177,29 @@ def _create(cls) -> _SyncStructuredLogHandler: def get_reader(cls) -> _SyncStructuredLogHandler: return _SyncStructuredLogHandler(LOG_FILENAME, read_mode=True) + +class LoggedCall: + def __init__(self, call_log: Call): + # Have to wait until the actual call to get the sync/async method. + self.log_handler = None + self.called_fn = call_log + self.start_time = None + self.end_time = None + + def report_outcome(self, result: ValidationOutcome): + pass + + def __enter__(self): + self.log_handler = SyncStructuredLogHandlerSingleton() + self.start_time = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + # Log the outcome of the operation. + self.log_handler.log() + + def __aenter__(self): + self.start_time = time.time() + + def __aexit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() From fda35ad27fda4da5bbe106748b4b95b50c4c1942 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Wed, 26 Jun 2024 12:03:23 -0700 Subject: [PATCH 05/25] Start integration with spans. 
--- guardrails/cli/__init__.py | 2 + guardrails/cli/watch.py | 48 +++++++++++++++++ guardrails/guard_call_logging.py | 38 ++++++++++++-- guardrails/utils/telemetry_utils.py | 11 ++++ tests/log_multithreaded_experiment.py | 76 +++++++++++++++++++++++++++ 5 files changed, 171 insertions(+), 4 deletions(-) create mode 100644 guardrails/cli/watch.py create mode 100644 tests/log_multithreaded_experiment.py diff --git a/guardrails/cli/__init__.py b/guardrails/cli/__init__.py index d16b49c17..0d2c9f92d 100644 --- a/guardrails/cli/__init__.py +++ b/guardrails/cli/__init__.py @@ -3,6 +3,8 @@ import guardrails.cli.validate # noqa from guardrails.cli.guardrails import guardrails as cli from guardrails.cli.hub import hub_command +from guardrails.cli.watch import watch_command + cli.add_typer( hub_command, name="hub", help="Manage validators installed from the Guardrails Hub." diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py new file mode 100644 index 000000000..11de84c7e --- /dev/null +++ b/guardrails/cli/watch.py @@ -0,0 +1,48 @@ +import time +from typing import Optional + +import rich +import typer + +from guardrails.cli.guardrails import guardrails as gr_cli +from guardrails.guard_call_logging import ( + GuardLogEntry, + SyncStructuredLogHandlerSingleton, +) + + +def _print_and_format_log_message(log_msg: GuardLogEntry): + rich.print(log_msg) + + +@gr_cli.command(name="watch") +def watch_command( + num_lines: int = typer.Option( + default=0, + help="Print the last n most recent lines. If omitted, will print all history." + ), + refresh_frequency: Optional[float] = typer.Option( + default=1.0, + help="How long (in seconds) should the watch command wait between outputs." + ), + follow: bool = typer.Option( + False, + "--follow", + help="Continuously read the last output commands", + ), + log_path_override: Optional[str] = typer.Option( + default=None, + help="Specify a path to the log output file." 
+ ), +): + if log_path_override is not None: + log_reader = SyncStructuredLogHandlerSingleton.get_reader(log_path_override) + else: + log_reader = SyncStructuredLogHandlerSingleton.get_reader() + + while True: + for log_msg in log_reader.tail_logs(-num_lines): + _print_and_format_log_message(log_msg) + if not follow: + return + time.sleep(refresh_frequency) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 658f35ebc..1ae6f2e76 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -22,6 +22,7 @@ """ +import datetime import os import sqlite3 import threading @@ -36,6 +37,27 @@ LOG_FILENAME = "guardrails_calls.db" +# Handle timestamp -> sqlite map: +def adapt_datetime(val): + """Adapt datetime.datetime to Unix timestamp.""" + # return val.isoformat() # If we want to go to datetime/isoformat... + return int(val.timestamp()) + + +sqlite3.register_adapter(datetime.datetime, adapt_datetime) + + +def convert_timestamp(val): + """Convert Unix epoch timestamp to datetime.datetime object.""" + # To go to datetime.datetime: + # return datetime.datetime.fromisoformat(val.decode()) + return datetime.datetime.fromtimestamp(int(val)) + + +sqlite3.register_converter("timestamp", convert_timestamp) + + + @dataclass class GuardLogEntry: guard_name: str @@ -139,6 +161,14 @@ def log_entry(self, guard_log_entry): def tail_logs(self, start_offset_idx: int = 0) -> Iterator[GuardLogEntry]: last_idx = start_offset_idx cursor = self.db.cursor() + if last_idx < 0: + # We're indexing from the end, so do a quick check. 
+ cursor.execute( + "SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ?;", + (-last_idx,) + ) + for row in cursor: + last_idx = row['id'] sql = """ SELECT id, guard_name, start_time, end_time, prevalidate_text, @@ -170,12 +200,12 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls) -> _SyncStructuredLogHandler: - return _SyncStructuredLogHandler(LOG_FILENAME, read_mode=False) + def _create(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: + return _SyncStructuredLogHandler(path, read_mode=False) @classmethod - def get_reader(cls) -> _SyncStructuredLogHandler: - return _SyncStructuredLogHandler(LOG_FILENAME, read_mode=True) + def get_reader(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: + return _SyncStructuredLogHandler(path, read_mode=True) class LoggedCall: diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 76f642d7c..1c794606d 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -7,6 +7,7 @@ from opentelemetry.context import Context from opentelemetry.trace import StatusCode, Tracer +from guardrails.guard_call_logging import SyncStructuredLogHandlerSingleton from guardrails.stores.context import get_tracer as get_context_tracer from guardrails.stores.context import get_tracer_context from guardrails.utils.casting_utils import to_string @@ -100,6 +101,16 @@ def trace_validator_result( "instance_id": instance_id, **kwargs, } + call_logger = SyncStructuredLogHandlerSingleton() + call_logger.log( + validator_name, + start_time, + end_time, + to_string(value_before_validation), + to_string(value_after_validation), + result_type, + 0, + ) current_span.add_event( f"{validator_name}_result", {k: v for k, v in event.items() if v is not None}, diff --git a/tests/log_multithreaded_experiment.py b/tests/log_multithreaded_experiment.py new file mode 100644 index 000000000..c8450bc57 --- /dev/null +++ 
b/tests/log_multithreaded_experiment.py @@ -0,0 +1,76 @@ +import random +import sys +import time +from multiprocessing import Pool + +from guardrails.guard_call_logging import SyncStructuredLogHandlerSingleton + + +DELAY = 0.1 +hoisted_logger = SyncStructuredLogHandlerSingleton() + + +def main(num_threads: int, num_log_messages: int): + log_levels = list() + for _ in range(num_log_messages): + log_levels.append(random.randint(0, 5)) + print("Trying with hoisted logger:") + with Pool(num_threads) as pool: + pool.map(log_with_hoisted_logger, log_levels) + print("Trying with acquired logger:") + with Pool(num_threads) as pool: + pool.map(log_with_acquired_singleton, log_levels) + + +def log_with_hoisted_logger(log_level: int): + start = time.time() + end = time.time() + hoisted_logger.log( + "hoisted_logger", + start, + end, + "Kept logger from hoisted.", + "Success.", + "", + log_level + ) + time.sleep(DELAY) + + +def log_with_acquired_singleton(log_level: int): + # Try grabbing a reference to the sync writer. 
+ start = time.time() + log = SyncStructuredLogHandlerSingleton() + end = time.time() + log.log( + "acquired_logger", + start, + end, + "Got logger with acquired singleton.", + "It worked.", + "", + log_level + ) + time.sleep(DELAY) + + +if __name__ == '__main__': + if "--help" in sys.argv: + print("Optional args: --num_threads, --num_log_messages") + else: + thread_count = 4 + try: + num_threads_arg_pos = sys.argv.index('--num_threads') + if num_threads_arg_pos != -1: + thread_count = int(sys.argv[num_threads_arg_pos + 1]) + except Exception: + pass + log_message_count = 10 + try: + num_log_messages_pos = sys.argv.index('--num_log_messages') + if num_log_messages_pos != -1: + log_message_count = int(sys.argv[num_log_messages_pos + 1]) + except Exception: + pass + + main(num_threads=thread_count, num_log_messages=log_message_count) From fdf89ec1b0573589974408e989000ac66b5b088a Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Wed, 26 Jun 2024 13:47:17 -0700 Subject: [PATCH 06/25] Integrate logging with trace. 
--- guardrails/guard_call_logging.py | 56 +++++++++++++-------------- guardrails/utils/telemetry_utils.py | 13 ++----- tests/log_multithreaded_experiment.py | 25 +++++++++++- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 1ae6f2e76..73a699899 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -31,12 +31,15 @@ from typing import Iterator from guardrails.classes import ValidationOutcome +from guardrails.utils.casting_utils import to_string from guardrails.classes.history import Call +from guardrails.classes.validation.validator_logs import ValidatorLogs LOG_FILENAME = "guardrails_calls.db" +# These adapters make it more convenient to add data into our log DB: # Handle timestamp -> sqlite map: def adapt_datetime(val): """Adapt datetime.datetime to Unix timestamp.""" @@ -57,7 +60,8 @@ def convert_timestamp(val): sqlite3.register_converter("timestamp", convert_timestamp) - +# This class makes it slightly easier to be selective about how we pull data. +# While it's not the ultimate contract/DB schema, it helps with typing and improves dx. @dataclass class GuardLogEntry: guard_name: str @@ -70,6 +74,8 @@ class GuardLogEntry: exception_message: str = "" +# This structured handler shouldn't be used directly, since it's touching a SQLite db. +# Instead, use the singleton or the async singleton. 
class _SyncStructuredLogHandler: CREATE_COMMAND = """ CREATE TABLE IF NOT EXISTS guard_logs ( @@ -158,6 +164,21 @@ def log_entry(self, guard_log_entry): asdict(guard_log_entry) ) + def log_validator(self, vlog: ValidatorLogs): + assert not self.readonly + maybe_outcome = str(vlog.validation_result.outcome) \ + if hasattr(vlog.validation_result, "outcome") else "" + with self.db: + self.db.execute(_SyncStructuredLogHandler.INSERT_COMMAND, dict( + guard_name=vlog.validator_name, + start_time=vlog.start_time if vlog.start_time else None, + end_time=vlog.end_time if vlog.end_time else 0.0, + prevalidate_text=to_string(vlog.value_before_validation), + postvalidate_text=to_string(vlog.value_after_validation), + exception_message=maybe_outcome, + log_level=0 + )) + def tail_logs(self, start_offset_idx: int = 0) -> Iterator[GuardLogEntry]: last_idx = start_offset_idx cursor = self.db.cursor() @@ -192,8 +213,10 @@ class SyncStructuredLogHandlerSingleton(_SyncStructuredLogHandler): _lock = threading.Lock() def __new__(cls): - if cls._instance is None: # Yes, two 'if' checks to avoid mutex contention. - # This only runs if we definitely need to lock. + if cls._instance is None: + # We run two 'if None' checks so we don't have to call the mutex check for + # the cases where there's obviously no handler. Only do a check if there + # MIGHT not be a handler instantiated. with cls._lock: if cls._instance is None: cls._instance = cls._create() @@ -206,30 +229,3 @@ def _create(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: @classmethod def get_reader(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: return _SyncStructuredLogHandler(path, read_mode=True) - - -class LoggedCall: - def __init__(self, call_log: Call): - # Have to wait until the actual call to get the sync/async method. 
- self.log_handler = None - self.called_fn = call_log - self.start_time = None - self.end_time = None - - def report_outcome(self, result: ValidationOutcome): - pass - - def __enter__(self): - self.log_handler = SyncStructuredLogHandlerSingleton() - self.start_time = time.time() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.end_time = time.time() - # Log the outcome of the operation. - self.log_handler.log() - - def __aenter__(self): - self.start_time = time.time() - - def __aexit__(self, exc_type, exc_val, exc_tb): - self.end_time = time.time() diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 1c794606d..f0738afee 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -101,16 +101,9 @@ def trace_validator_result( "instance_id": instance_id, **kwargs, } - call_logger = SyncStructuredLogHandlerSingleton() - call_logger.log( - validator_name, - start_time, - end_time, - to_string(value_before_validation), - to_string(value_after_validation), - result_type, - 0, - ) + + SyncStructuredLogHandlerSingleton().log_validator(validator_log) + current_span.add_event( f"{validator_name}_result", {k: v for k, v in event.items() if v is not None}, diff --git a/tests/log_multithreaded_experiment.py b/tests/log_multithreaded_experiment.py index c8450bc57..90e5fb58a 100644 --- a/tests/log_multithreaded_experiment.py +++ b/tests/log_multithreaded_experiment.py @@ -14,12 +14,17 @@ def main(num_threads: int, num_log_messages: int): log_levels = list() for _ in range(num_log_messages): log_levels.append(random.randint(0, 5)) - print("Trying with hoisted logger:") + print("Trying with hoisted logger:", end=" ") with Pool(num_threads) as pool: pool.map(log_with_hoisted_logger, log_levels) - print("Trying with acquired logger:") + print("Done.") + print("Trying with acquired logger:", end=" ") with Pool(num_threads) as pool: pool.map(log_with_acquired_singleton, log_levels) + print("Done.") + 
print("Trying with guard: ", end=" ") + log_from_inside_guard() + print("Done") def log_with_hoisted_logger(log_level: int): @@ -54,6 +59,22 @@ def log_with_acquired_singleton(log_level: int): time.sleep(DELAY) +def log_from_inside_guard(): + from pydantic import BaseModel + from guardrails import Guard + from transformers import pipeline + + model = pipeline("text-generation", "gpt2") + + class Foo(BaseModel): + bar: int + + guard = Guard.from_pydantic(Foo) + # This may not trigger the thing: + guard.validate('{"bar": 42}') + guard(model, prompt="Hi") + + if __name__ == '__main__': if "--help" in sys.argv: print("Optional args: --num_threads, --num_log_messages") From fceb95c085fdf942d12c002d75d635b6126ee0f4 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Wed, 26 Jun 2024 15:05:42 -0700 Subject: [PATCH 07/25] Output as table. --- guardrails/cli/watch.py | 91 ++++++++++++++++++++++++-------- guardrails/guard_call_logging.py | 23 ++++++-- 2 files changed, 90 insertions(+), 24 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 11de84c7e..424023f20 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -1,7 +1,11 @@ +import sqlite3 import time +from dataclasses import asdict from typing import Optional import rich +import rich.console +import rich.table import typer from guardrails.cli.guardrails import guardrails as gr_cli @@ -11,23 +15,19 @@ ) -def _print_and_format_log_message(log_msg: GuardLogEntry): - rich.print(log_msg) - - @gr_cli.command(name="watch") def watch_command( + plain: bool = typer.Option( + default=False, + is_flag=True, + help="Do not use any rich formatting, instead printing each entry on a line." + ), num_lines: int = typer.Option( default=0, help="Print the last n most recent lines. If omitted, will print all history." ), - refresh_frequency: Optional[float] = typer.Option( - default=1.0, - help="How long (in seconds) should the watch command wait between outputs." 
- ), follow: bool = typer.Option( - False, - "--follow", + default=False, help="Continuously read the last output commands", ), log_path_override: Optional[str] = typer.Option( @@ -35,14 +35,63 @@ def watch_command( help="Specify a path to the log output file." ), ): - if log_path_override is not None: - log_reader = SyncStructuredLogHandlerSingleton.get_reader(log_path_override) - else: - log_reader = SyncStructuredLogHandlerSingleton.get_reader() - - while True: - for log_msg in log_reader.tail_logs(-num_lines): - _print_and_format_log_message(log_msg) - if not follow: - return - time.sleep(refresh_frequency) + # Open a reader for the log path: + log_reader = None + while log_reader is None: + try: + if log_path_override is not None: + log_reader = SyncStructuredLogHandlerSingleton.get_reader(log_path_override) + else: + log_reader = SyncStructuredLogHandlerSingleton.get_reader() + except sqlite3.OperationalError: + print("Logfile not found. Retrying.") + time.sleep(1) + + # If we are using fancy outputs, grab a console ref and prep a table. + if not plain: + console, table = _setup_console_table() + + # Spin while tailing, breaking if we aren't continuously tailing. 
+ for log_msg in log_reader.tail_logs(-num_lines, follow): + if plain: + _print_and_format_plain(log_msg) + else: + _update_table(log_msg, console, table) + + +def _setup_console_table(): + console = rich.console.Console() + table = rich.table.Table( + show_header=True, + header_style="bold", + ) + table.add_column("ID") + table.add_column("Name") + table.add_column("Start Time") + table.add_column("End Time") + table.add_column("Time Delta") + table.add_column("Prevalidate Text") + table.add_column("Postvalidate Text") + table.add_column("Result Text") + return console, table + + +def _update_table(log_msg: GuardLogEntry, console, table): + table.add_row( + str(log_msg.id), + str(log_msg.guard_name), + str(log_msg.start_time), + str(log_msg.end_time), + str(log_msg.timedelta), + str(log_msg.prevalidate_text), + str(log_msg.postvalidate_text), + str(log_msg.exception_message), + ) + console.print(table) + + +def _print_and_format_plain(log_msg: GuardLogEntry) -> None: + str_builder = list() + for k, v in asdict(log_msg).items(): + str_builder.append(f"{k}: {v}") + print("\t ".join(str_builder)) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 73a699899..3c2e1f1d2 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -73,6 +73,10 @@ class GuardLogEntry: postvalidate_text: str = "" exception_message: str = "" + @property + def timedelta(self): + return self.end_time - self.start_time + # This structured handler shouldn't be used directly, since it's touching a SQLite db. # Instead, use the singleton or the async singleton. @@ -179,7 +183,18 @@ def log_validator(self, vlog: ValidatorLogs): log_level=0 )) - def tail_logs(self, start_offset_idx: int = 0) -> Iterator[GuardLogEntry]: + def tail_logs( + self, + start_offset_idx: int = 0, + follow: bool = False + ) -> Iterator[GuardLogEntry]: + """Returns an iterator to generate GuardLogEntries. 
+ @param start_offset_idx int : Start printing entries after this IDX. If + negative, this will instead start printing the LAST start_offset_idx entries. + @param follow : If follow is True, will re-check the database for new entries + after the first batch is complete. If False (default), will return when entries + are exhausted. + """ last_idx = start_offset_idx cursor = self.db.cursor() if last_idx < 0: @@ -198,13 +213,15 @@ def tail_logs(self, start_offset_idx: int = 0) -> Iterator[GuardLogEntry]: WHERE id > ? ORDER BY start_time; """ - cursor.execute("SELECT 1 LIMIT 0;") + cursor.execute(sql, (last_idx,)) while True: for row in cursor: last_entry = GuardLogEntry(**row) last_idx = last_entry.id yield last_entry - # If we're here we've run out of entries to tail. + if not follow: + return + # If we're here we've run out of entries to tail. Fetch more: cursor.execute(sql, (last_idx,)) From 6c9348e3e448743ab6173aa8b233ffc941dafe4c Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Wed, 26 Jun 2024 15:07:57 -0700 Subject: [PATCH 08/25] Output as table. --- guardrails/cli/watch.py | 47 ++++++----------------------------------- 1 file changed, 7 insertions(+), 40 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 424023f20..8af6568e5 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -1,11 +1,10 @@ +import json import sqlite3 import time from dataclasses import asdict from typing import Optional import rich -import rich.console -import rich.table import typer from guardrails.cli.guardrails import guardrails as gr_cli @@ -48,50 +47,18 @@ def watch_command( time.sleep(1) # If we are using fancy outputs, grab a console ref and prep a table. + output_fn = _print_and_format_plain if not plain: - console, table = _setup_console_table() + output_fn = _print_fancy # Spin while tailing, breaking if we aren't continuously tailing. 
for log_msg in log_reader.tail_logs(-num_lines, follow): - if plain: - _print_and_format_plain(log_msg) - else: - _update_table(log_msg, console, table) + output_fn(log_msg) -def _setup_console_table(): - console = rich.console.Console() - table = rich.table.Table( - show_header=True, - header_style="bold", - ) - table.add_column("ID") - table.add_column("Name") - table.add_column("Start Time") - table.add_column("End Time") - table.add_column("Time Delta") - table.add_column("Prevalidate Text") - table.add_column("Postvalidate Text") - table.add_column("Result Text") - return console, table - - -def _update_table(log_msg: GuardLogEntry, console, table): - table.add_row( - str(log_msg.id), - str(log_msg.guard_name), - str(log_msg.start_time), - str(log_msg.end_time), - str(log_msg.timedelta), - str(log_msg.prevalidate_text), - str(log_msg.postvalidate_text), - str(log_msg.exception_message), - ) - console.print(table) +def _print_fancy(log_msg: GuardLogEntry): + rich.print(log_msg) def _print_and_format_plain(log_msg: GuardLogEntry) -> None: - str_builder = list() - for k, v in asdict(log_msg).items(): - str_builder.append(f"{k}: {v}") - print("\t ".join(str_builder)) + print(json.dumps(asdict(log_msg))) From c44b57327da200c814ea1238f6a1c95b01b873a2 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 10:47:35 -0700 Subject: [PATCH 09/25] Make sure logging works across async, multiple threads, and multiple processes. 
--- guardrails/cli/watch.py | 6 +-- guardrails/guard_call_logging.py | 58 ++++++++++++++++++++------- guardrails/utils/telemetry_utils.py | 4 +- tests/log_multithreaded_experiment.py | 53 +++++++++++++----------- 4 files changed, 78 insertions(+), 43 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 8af6568e5..841c30955 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -10,7 +10,7 @@ from guardrails.cli.guardrails import guardrails as gr_cli from guardrails.guard_call_logging import ( GuardLogEntry, - SyncStructuredLogHandlerSingleton, + SyncTraceHandler, ) @@ -39,9 +39,9 @@ def watch_command( while log_reader is None: try: if log_path_override is not None: - log_reader = SyncStructuredLogHandlerSingleton.get_reader(log_path_override) + log_reader = SyncTraceHandler.get_reader(log_path_override) else: - log_reader = SyncStructuredLogHandlerSingleton.get_reader() + log_reader = SyncTraceHandler.get_reader() except sqlite3.OperationalError: print("Logfile not found. Retrying.") time.sleep(1) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 3c2e1f1d2..c93e47bd8 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -26,7 +26,6 @@ import os import sqlite3 import threading -import time from dataclasses import dataclass, asdict from typing import Iterator @@ -78,9 +77,31 @@ def timedelta(self): return self.end_time - self.start_time +class _NoopTraceHandler: + def __init__(self, log_path: os.PathLike, read_mode: bool): + pass + + def log(self, *args, **kwargs): + pass + + def log_entry(self, guard_log_entry: GuardLogEntry): + pass + + def log_validator(self, vlog: ValidatorLogs): + pass + + def tail_logs( + self, + start_offset_idx: int = 0, + follow: bool = False, + ) -> Iterator[GuardLogEntry]: + return [] + + + # This structured handler shouldn't be used directly, since it's touching a SQLite db. 
# Instead, use the singleton or the async singleton. -class _SyncStructuredLogHandler: +class _SyncTraceHandler: CREATE_COMMAND = """ CREATE TABLE IF NOT EXISTS guard_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -106,14 +127,18 @@ class _SyncStructuredLogHandler: def __init__(self, log_path: os.PathLike, read_mode: bool): self.readonly = read_mode if read_mode: - self.db = _SyncStructuredLogHandler._get_read_connection(log_path) + self.db = _SyncTraceHandler._get_read_connection(log_path) else: - self.db = _SyncStructuredLogHandler._get_write_connection(log_path) + self.db = _SyncTraceHandler._get_write_connection(log_path) @classmethod def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: try: - db = sqlite3.connect(log_path, isolation_level=None) + db = sqlite3.connect( + log_path, + isolation_level=None, + check_same_thread=False, + ) db.execute('PRAGMA journal_mode = wal') db.execute('PRAGMA synchronous = OFF') # isolation_level = None and pragma WAL means we can READ from the DB @@ -124,7 +149,7 @@ def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: #logging.exception("Unable to connect to guard log handler.") raise e with db: - db.execute(_SyncStructuredLogHandler.CREATE_COMMAND) + db.execute(_SyncTraceHandler.CREATE_COMMAND) return db @classmethod @@ -150,7 +175,7 @@ def log( ): assert not self.readonly with self.db: - self.db.execute(_SyncStructuredLogHandler.INSERT_COMMAND, dict( + self.db.execute(_SyncTraceHandler.INSERT_COMMAND, dict( guard_name=guard_name, start_time=start_time, end_time=end_time, @@ -160,11 +185,11 @@ def log( log_level=log_level )) - def log_entry(self, guard_log_entry): + def log_entry(self, guard_log_entry: GuardLogEntry): assert not self.readonly with self.db: self.db.execute( - _SyncStructuredLogHandler.INSERT_COMMAND, + _SyncTraceHandler.INSERT_COMMAND, asdict(guard_log_entry) ) @@ -173,7 +198,7 @@ def log_validator(self, vlog: ValidatorLogs): maybe_outcome = 
str(vlog.validation_result.outcome) \ if hasattr(vlog.validation_result, "outcome") else "" with self.db: - self.db.execute(_SyncStructuredLogHandler.INSERT_COMMAND, dict( + self.db.execute(_SyncTraceHandler.INSERT_COMMAND, dict( guard_name=vlog.validator_name, start_time=vlog.start_time if vlog.start_time else None, end_time=vlog.end_time if vlog.end_time else 0.0, @@ -225,7 +250,10 @@ def tail_logs( cursor.execute(sql, (last_idx,)) -class SyncStructuredLogHandlerSingleton(_SyncStructuredLogHandler): +class SyncTraceHandler(_SyncTraceHandler): + """SyncTraceHandler wraps the internal _SyncTraceHandler to make it multi-thread + safe. Coupled with some write ahead journaling in the _SyncTrace internal, we have + a faux-multi-write multi-read interface for SQLite.""" _instance = None _lock = threading.Lock() @@ -240,9 +268,9 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: - return _SyncStructuredLogHandler(path, read_mode=False) + def _create(cls, path: os.PathLike = LOG_FILENAME) -> _SyncTraceHandler: + return _SyncTraceHandler(path, read_mode=False) @classmethod - def get_reader(cls, path: os.PathLike = LOG_FILENAME) -> _SyncStructuredLogHandler: - return _SyncStructuredLogHandler(path, read_mode=True) + def get_reader(cls, path: os.PathLike = LOG_FILENAME) -> _SyncTraceHandler: + return _SyncTraceHandler(path, read_mode=True) diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index f0738afee..163cf608b 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -7,7 +7,7 @@ from opentelemetry.context import Context from opentelemetry.trace import StatusCode, Tracer -from guardrails.guard_call_logging import SyncStructuredLogHandlerSingleton +from guardrails.guard_call_logging import SyncTraceHandler from guardrails.stores.context import get_tracer as get_context_tracer from guardrails.stores.context import 
get_tracer_context from guardrails.utils.casting_utils import to_string @@ -102,7 +102,7 @@ def trace_validator_result( **kwargs, } - SyncStructuredLogHandlerSingleton().log_validator(validator_log) + SyncTraceHandler().log_validator(validator_log) current_span.add_event( f"{validator_name}_result", diff --git a/tests/log_multithreaded_experiment.py b/tests/log_multithreaded_experiment.py index 90e5fb58a..be3835dfa 100644 --- a/tests/log_multithreaded_experiment.py +++ b/tests/log_multithreaded_experiment.py @@ -1,29 +1,52 @@ +import asyncio +import concurrent.futures import random import sys import time from multiprocessing import Pool -from guardrails.guard_call_logging import SyncStructuredLogHandlerSingleton +from guardrails.guard_call_logging import SyncTraceHandler DELAY = 0.1 -hoisted_logger = SyncStructuredLogHandlerSingleton() +hoisted_logger = SyncTraceHandler() def main(num_threads: int, num_log_messages: int): log_levels = list() for _ in range(num_log_messages): log_levels.append(random.randint(0, 5)) - print("Trying with hoisted logger:", end=" ") + print("Multiprocessing: Trying with hoisted logger:", end=" ") with Pool(num_threads) as pool: pool.map(log_with_hoisted_logger, log_levels) print("Done.") - print("Trying with acquired logger:", end=" ") + print("Multiprocessing: Trying with acquired logger:", end=" ") with Pool(num_threads) as pool: pool.map(log_with_acquired_singleton, log_levels) print("Done.") - print("Trying with guard: ", end=" ") - log_from_inside_guard() + print("Multithreading: Trying with hoisted logger:", end=" ") + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for level in log_levels: + out = executor.submit(log_with_hoisted_logger, level) + out.result() + #executor.map(log_with_hoisted_logger, log_levels) + print("Done") + print("Multithreading: Trying with acquired logger:", end=" ") + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for level in 
log_levels: + out = executor.submit(log_with_acquired_singleton, level) + out.result() + #executor.map(log_with_acquired_singleton, log_levels) + print("Done") + print("Asyncio: Trying with hoisted logger:", end=" ") + async def do_it(level: int): + log_with_hoisted_logger(level) + async def do_it_again(level: int): + log_with_acquired_singleton(level) + for level in log_levels: + asyncio.run(do_it(level)) + for level in log_levels: + asyncio.run(do_it_again(level)) print("Done") @@ -45,7 +68,7 @@ def log_with_hoisted_logger(log_level: int): def log_with_acquired_singleton(log_level: int): # Try grabbing a reference to the sync writer. start = time.time() - log = SyncStructuredLogHandlerSingleton() + log = SyncTraceHandler() end = time.time() log.log( "acquired_logger", @@ -59,22 +82,6 @@ def log_with_acquired_singleton(log_level: int): time.sleep(DELAY) -def log_from_inside_guard(): - from pydantic import BaseModel - from guardrails import Guard - from transformers import pipeline - - model = pipeline("text-generation", "gpt2") - - class Foo(BaseModel): - bar: int - - guard = Guard.from_pydantic(Foo) - # This may not trigger the thing: - guard.validate('{"bar": 42}') - guard(model, prompt="Hi") - - if __name__ == '__main__': if "--help" in sys.argv: print("Optional args: --num_threads, --num_log_messages") From c44799f95b3143171866f8f69a88dd839b8a2ebc Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 11:14:12 -0700 Subject: [PATCH 10/25] Move test for guard logger to the right place. 
--- tests/log_multithreaded_experiment.py | 104 -------------------------- tests/unit_tests/test_guard_log.py | 95 +++++++++++++++++++++++ 2 files changed, 95 insertions(+), 104 deletions(-) delete mode 100644 tests/log_multithreaded_experiment.py create mode 100644 tests/unit_tests/test_guard_log.py diff --git a/tests/log_multithreaded_experiment.py b/tests/log_multithreaded_experiment.py deleted file mode 100644 index be3835dfa..000000000 --- a/tests/log_multithreaded_experiment.py +++ /dev/null @@ -1,104 +0,0 @@ -import asyncio -import concurrent.futures -import random -import sys -import time -from multiprocessing import Pool - -from guardrails.guard_call_logging import SyncTraceHandler - - -DELAY = 0.1 -hoisted_logger = SyncTraceHandler() - - -def main(num_threads: int, num_log_messages: int): - log_levels = list() - for _ in range(num_log_messages): - log_levels.append(random.randint(0, 5)) - print("Multiprocessing: Trying with hoisted logger:", end=" ") - with Pool(num_threads) as pool: - pool.map(log_with_hoisted_logger, log_levels) - print("Done.") - print("Multiprocessing: Trying with acquired logger:", end=" ") - with Pool(num_threads) as pool: - pool.map(log_with_acquired_singleton, log_levels) - print("Done.") - print("Multithreading: Trying with hoisted logger:", end=" ") - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for level in log_levels: - out = executor.submit(log_with_hoisted_logger, level) - out.result() - #executor.map(log_with_hoisted_logger, log_levels) - print("Done") - print("Multithreading: Trying with acquired logger:", end=" ") - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for level in log_levels: - out = executor.submit(log_with_acquired_singleton, level) - out.result() - #executor.map(log_with_acquired_singleton, log_levels) - print("Done") - print("Asyncio: Trying with hoisted logger:", end=" ") - async def do_it(level: int): - 
log_with_hoisted_logger(level) - async def do_it_again(level: int): - log_with_acquired_singleton(level) - for level in log_levels: - asyncio.run(do_it(level)) - for level in log_levels: - asyncio.run(do_it_again(level)) - print("Done") - - -def log_with_hoisted_logger(log_level: int): - start = time.time() - end = time.time() - hoisted_logger.log( - "hoisted_logger", - start, - end, - "Kept logger from hoisted.", - "Success.", - "", - log_level - ) - time.sleep(DELAY) - - -def log_with_acquired_singleton(log_level: int): - # Try grabbing a reference to the sync writer. - start = time.time() - log = SyncTraceHandler() - end = time.time() - log.log( - "acquired_logger", - start, - end, - "Got logger with acquired singleton.", - "It worked.", - "", - log_level - ) - time.sleep(DELAY) - - -if __name__ == '__main__': - if "--help" in sys.argv: - print("Optional args: --num_threads, --num_log_messages") - else: - thread_count = 4 - try: - num_threads_arg_pos = sys.argv.index('--num_threads') - if num_threads_arg_pos != -1: - thread_count = int(sys.argv[num_threads_arg_pos + 1]) - except Exception: - pass - log_message_count = 10 - try: - num_log_messages_pos = sys.argv.index('--num_log_messages') - if num_log_messages_pos != -1: - log_message_count = int(sys.argv[num_log_messages_pos + 1]) - except Exception: - pass - - main(num_threads=thread_count, num_log_messages=log_message_count) diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py new file mode 100644 index 000000000..5f033d019 --- /dev/null +++ b/tests/unit_tests/test_guard_log.py @@ -0,0 +1,95 @@ +import asyncio +import concurrent.futures +import random +import sys +import time +from multiprocessing import Pool + +from guardrails.guard_call_logging import SyncTraceHandler + +NUM_THREADS = 4 + +STOCK_MESSAGES = [ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", + "Ut enim ad minim veniam", + 
"quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat", + "Heeeey Macarena.", + "Excepteur sint occaecat cupidatat non proident,", + "sunt in culpa qui officia deserunt mollit anim id est laborum." +] + + +# This is hoisted for testing to see how well we share. +_trace_logger = SyncTraceHandler() + + +def test_multiprocessing_hoisted(): + """Preallocate a shared trace handler and try to log from multiple subprocesses.""" + with Pool(NUM_THREADS) as pool: + pool.map(_hoisted_logger, ["multiproc_hoist" + msg for msg in STOCK_MESSAGES]) + + +def test_multiprocessing_acquired(): + with Pool(NUM_THREADS) as pool: + pool.map(_acquired_logger, ["multiproc_acq" + msg for msg in STOCK_MESSAGES]) + + +def test_multithreading_hoisted(): + with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + for msg in STOCK_MESSAGES: + out = executor.submit(_hoisted_logger, "multithread_hoist" + msg) + out.result() + #executor.map(log_with_hoisted_logger, log_levels) + + +def test_multithreading_acquired(): + with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + for msg in STOCK_MESSAGES: + out = executor.submit(_acquired_logger, "multithread_acq" + msg) + out.result() + + +def test_asyncio_hoisted(): + async def do_it(msg: str): + _hoisted_logger(msg) + + for m in STOCK_MESSAGES: + asyncio.run(do_it("async_hoisted" + m)) + + +def test_asyncio_acquired(): + async def do_it_again(msg: str): + _acquired_logger(msg) + + for m in STOCK_MESSAGES: + asyncio.run(do_it_again("async_acq" + m)) + + +def _hoisted_logger(msg: str): + _trace_logger.log( + "hoisted", + time.time(), + time.time(), + "Testing the behavior of a hoisted logger.", + msg, + "", + 0 + ) + + +def _acquired_logger(msg): + # Note that the trace logger is acquired INSIDE the method: + start = time.time() + trace_logger = SyncTraceHandler() + end = time.time() + trace_logger.log( + "acquired", + start, + end, + "Testing behavior of an acquired 
logger.", + msg, + "", + 0 + ) From c7f8f7907e13a6ac04758bd817e18d587bdbc1b8 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:02:31 -0700 Subject: [PATCH 11/25] Write to tempfile rather than local directory. Add method to truncate logs. --- guardrails/cli/watch.py | 6 ++-- guardrails/guard_call_logging.py | 54 ++++++++++++++++++++--------- guardrails/utils/telemetry_utils.py | 4 +-- tests/unit_tests/test_guard_log.py | 6 ++-- 4 files changed, 45 insertions(+), 25 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 841c30955..c966c93a8 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -10,7 +10,7 @@ from guardrails.cli.guardrails import guardrails as gr_cli from guardrails.guard_call_logging import ( GuardLogEntry, - SyncTraceHandler, + TraceHandler, ) @@ -39,9 +39,9 @@ def watch_command( while log_reader is None: try: if log_path_override is not None: - log_reader = SyncTraceHandler.get_reader(log_path_override) + log_reader = TraceHandler.get_reader(log_path_override) else: - log_reader = SyncTraceHandler.get_reader() + log_reader = TraceHandler.get_reader() except sqlite3.OperationalError: print("Logfile not found. Retrying.") time.sleep(1) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index c93e47bd8..a54ebc403 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -25,17 +25,19 @@ import datetime import os import sqlite3 +import tempfile import threading from dataclasses import dataclass, asdict from typing import Iterator -from guardrails.classes import ValidationOutcome from guardrails.utils.casting_utils import to_string -from guardrails.classes.history import Call from guardrails.classes.validation.validator_logs import ValidatorLogs +# TODO: We should read this from guardrailsrc. 
+LOG_RETENTION_LIMIT = 1000000 LOG_FILENAME = "guardrails_calls.db" +LOGFILE_PATH = os.path.join(tempfile.gettempdir(), LOG_FILENAME) # These adapters make it more convenient to add data into our log DB: @@ -77,7 +79,8 @@ def timedelta(self): return self.end_time - self.start_time -class _NoopTraceHandler: +class _BaseTraceHandler: + """The base TraceHandler only pads out the methods. It's effectively a Noop""" def __init__(self, log_path: os.PathLike, read_mode: bool): pass @@ -101,7 +104,7 @@ def tail_logs( # This structured handler shouldn't be used directly, since it's touching a SQLite db. # Instead, use the singleton or the async singleton. -class _SyncTraceHandler: +class _SQLiteTraceHandler(_BaseTraceHandler): CREATE_COMMAND = """ CREATE TABLE IF NOT EXISTS guard_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -125,11 +128,16 @@ class _SyncTraceHandler: """ def __init__(self, log_path: os.PathLike, read_mode: bool): + self._log_path = log_path # Read-only value. self.readonly = read_mode if read_mode: - self.db = _SyncTraceHandler._get_read_connection(log_path) + self.db = _SQLiteTraceHandler._get_read_connection(log_path) else: - self.db = _SyncTraceHandler._get_write_connection(log_path) + self.db = _SQLiteTraceHandler._get_write_connection(log_path) + + @property + def log_path(self): + return self._log_path @classmethod def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: @@ -149,7 +157,7 @@ def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: #logging.exception("Unable to connect to guard log handler.") raise e with db: - db.execute(_SyncTraceHandler.CREATE_COMMAND) + db.execute(_SQLiteTraceHandler.CREATE_COMMAND) return db @classmethod @@ -163,6 +171,18 @@ def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: db.row_factory = sqlite3.Row return db + def _truncate(self, keep_n: int = LOG_RETENTION_LIMIT): + assert not self.readonly + self.db.execute( + """ + DELETE FROM guard_logs + 
WHERE id <= ( + SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ? + ); + """, + (keep_n,) + ) + def log( self, guard_name: str, @@ -175,7 +195,7 @@ def log( ): assert not self.readonly with self.db: - self.db.execute(_SyncTraceHandler.INSERT_COMMAND, dict( + self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, dict( guard_name=guard_name, start_time=start_time, end_time=end_time, @@ -189,7 +209,7 @@ def log_entry(self, guard_log_entry: GuardLogEntry): assert not self.readonly with self.db: self.db.execute( - _SyncTraceHandler.INSERT_COMMAND, + _SQLiteTraceHandler.INSERT_COMMAND, asdict(guard_log_entry) ) @@ -198,7 +218,7 @@ def log_validator(self, vlog: ValidatorLogs): maybe_outcome = str(vlog.validation_result.outcome) \ if hasattr(vlog.validation_result, "outcome") else "" with self.db: - self.db.execute(_SyncTraceHandler.INSERT_COMMAND, dict( + self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, dict( guard_name=vlog.validator_name, start_time=vlog.start_time if vlog.start_time else None, end_time=vlog.end_time if vlog.end_time else 0.0, @@ -214,7 +234,7 @@ def tail_logs( follow: bool = False ) -> Iterator[GuardLogEntry]: """Returns an iterator to generate GuardLogEntries. - @param start_offset_idx int : Start printing entries after this IDX. If + @param start_offset_idx : Start printing entries after this IDX. If negative, this will instead start printing the LAST start_offset_idx entries. @param follow : If follow is True, will re-check the database for new entries after the first batch is complete. If False (default), will return when entries @@ -250,8 +270,8 @@ def tail_logs( cursor.execute(sql, (last_idx,)) -class SyncTraceHandler(_SyncTraceHandler): - """SyncTraceHandler wraps the internal _SyncTraceHandler to make it multi-thread +class TraceHandler(_SQLiteTraceHandler): + """TraceHandler wraps the internal _SQLiteTraceHandler to make it multi-thread safe. 
Coupled with some write ahead journaling in the _SyncTrace internal, we have a faux-multi-write multi-read interface for SQLite.""" _instance = None @@ -268,9 +288,9 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls, path: os.PathLike = LOG_FILENAME) -> _SyncTraceHandler: - return _SyncTraceHandler(path, read_mode=False) + def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _SQLiteTraceHandler: + return _SQLiteTraceHandler(path, read_mode=False) @classmethod - def get_reader(cls, path: os.PathLike = LOG_FILENAME) -> _SyncTraceHandler: - return _SyncTraceHandler(path, read_mode=True) + def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _SQLiteTraceHandler: + return _SQLiteTraceHandler(path, read_mode=True) diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 163cf608b..19aae63c6 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -7,7 +7,7 @@ from opentelemetry.context import Context from opentelemetry.trace import StatusCode, Tracer -from guardrails.guard_call_logging import SyncTraceHandler +from guardrails.guard_call_logging import TraceHandler from guardrails.stores.context import get_tracer as get_context_tracer from guardrails.stores.context import get_tracer_context from guardrails.utils.casting_utils import to_string @@ -102,7 +102,7 @@ def trace_validator_result( **kwargs, } - SyncTraceHandler().log_validator(validator_log) + TraceHandler().log_validator(validator_log) current_span.add_event( f"{validator_name}_result", diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index 5f033d019..28c04e6fe 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -5,7 +5,7 @@ import time from multiprocessing import Pool -from guardrails.guard_call_logging import SyncTraceHandler +from guardrails.guard_call_logging import TraceHandler NUM_THREADS = 4 @@ -22,7 +22,7 @@ # This is hoisted for 
testing to see how well we share. -_trace_logger = SyncTraceHandler() +_trace_logger = TraceHandler() def test_multiprocessing_hoisted(): @@ -82,7 +82,7 @@ def _hoisted_logger(msg: str): def _acquired_logger(msg): # Note that the trace logger is acquired INSIDE the method: start = time.time() - trace_logger = SyncTraceHandler() + trace_logger = TraceHandler() end = time.time() trace_logger.log( "acquired", From cd5f7c4a6af2b2a80bdf17154d02b49678ee855e Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:07:26 -0700 Subject: [PATCH 12/25] Reformat. --- guardrails/cli/watch.py | 7 +- guardrails/guard_call_logging.py | 101 +++++++++++++++-------------- tests/unit_tests/test_guard_log.py | 14 ++-- 3 files changed, 59 insertions(+), 63 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index c966c93a8..286067508 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -19,19 +19,18 @@ def watch_command( plain: bool = typer.Option( default=False, is_flag=True, - help="Do not use any rich formatting, instead printing each entry on a line." + help="Do not use any rich formatting, instead printing each entry on a line.", ), num_lines: int = typer.Option( default=0, - help="Print the last n most recent lines. If omitted, will print all history." + help="Print the last n most recent lines. If omitted, will print all history.", ), follow: bool = typer.Option( default=False, help="Continuously read the last output commands", ), log_path_override: Optional[str] = typer.Option( - default=None, - help="Specify a path to the log output file." + default=None, help="Specify a path to the log output file." 
), ): # Open a reader for the log path: diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index a54ebc403..ddb77fa2f 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -81,6 +81,7 @@ def timedelta(self): class _BaseTraceHandler: """The base TraceHandler only pads out the methods. It's effectively a Noop""" + def __init__(self, log_path: os.PathLike, read_mode: bool): pass @@ -94,14 +95,13 @@ def log_validator(self, vlog: ValidatorLogs): pass def tail_logs( - self, - start_offset_idx: int = 0, - follow: bool = False, + self, + start_offset_idx: int = 0, + follow: bool = False, ) -> Iterator[GuardLogEntry]: return [] - # This structured handler shouldn't be used directly, since it's touching a SQLite db. # Instead, use the singleton or the async singleton. class _SQLiteTraceHandler(_BaseTraceHandler): @@ -147,14 +147,14 @@ def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: isolation_level=None, check_same_thread=False, ) - db.execute('PRAGMA journal_mode = wal') - db.execute('PRAGMA synchronous = OFF') + db.execute("PRAGMA journal_mode = wal") + db.execute("PRAGMA synchronous = OFF") # isolation_level = None and pragma WAL means we can READ from the DB # while threads using it are writing. Synchronous off puts us on the # highway to the danger zone, depending on how willing we are to lose log # messages in the event of a guard crash. except sqlite3.OperationalError as e: - #logging.exception("Unable to connect to guard log handler.") + # logging.exception("Unable to connect to guard log handler.") raise e with db: db.execute(_SQLiteTraceHandler.CREATE_COMMAND) @@ -164,9 +164,7 @@ def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: # A bit of a hack to open in read-only mode... 
db = sqlite3.connect( - "file:" + log_path + "?mode=ro", - isolation_level=None, - uri=True + "file:" + log_path + "?mode=ro", isolation_level=None, uri=True ) db.row_factory = sqlite3.Row return db @@ -176,62 +174,66 @@ def _truncate(self, keep_n: int = LOG_RETENTION_LIMIT): self.db.execute( """ DELETE FROM guard_logs - WHERE id <= ( + WHERE id < ( SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ? ); """, - (keep_n,) + (keep_n,), ) def log( - self, - guard_name: str, - start_time: float, - end_time: float, - prevalidate_text: str, - postvalidate_text: str, - exception_text: str, - log_level: int, + self, + guard_name: str, + start_time: float, + end_time: float, + prevalidate_text: str, + postvalidate_text: str, + exception_text: str, + log_level: int, ): assert not self.readonly with self.db: - self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, dict( - guard_name=guard_name, - start_time=start_time, - end_time=end_time, - prevalidate_text=prevalidate_text, - postvalidate_text=postvalidate_text, - exception_message=exception_text, - log_level=log_level - )) + self.db.execute( + _SQLiteTraceHandler.INSERT_COMMAND, + dict( + guard_name=guard_name, + start_time=start_time, + end_time=end_time, + prevalidate_text=prevalidate_text, + postvalidate_text=postvalidate_text, + exception_message=exception_text, + log_level=log_level, + ), + ) def log_entry(self, guard_log_entry: GuardLogEntry): assert not self.readonly with self.db: - self.db.execute( - _SQLiteTraceHandler.INSERT_COMMAND, - asdict(guard_log_entry) - ) + self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, asdict(guard_log_entry)) def log_validator(self, vlog: ValidatorLogs): assert not self.readonly - maybe_outcome = str(vlog.validation_result.outcome) \ - if hasattr(vlog.validation_result, "outcome") else "" + maybe_outcome = ( + str(vlog.validation_result.outcome) + if hasattr(vlog.validation_result, "outcome") + else "" + ) with self.db: - self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, 
dict( - guard_name=vlog.validator_name, - start_time=vlog.start_time if vlog.start_time else None, - end_time=vlog.end_time if vlog.end_time else 0.0, - prevalidate_text=to_string(vlog.value_before_validation), - postvalidate_text=to_string(vlog.value_after_validation), - exception_message=maybe_outcome, - log_level=0 - )) + self.db.execute( + _SQLiteTraceHandler.INSERT_COMMAND, + dict( + guard_name=vlog.validator_name, + start_time=vlog.start_time if vlog.start_time else None, + end_time=vlog.end_time if vlog.end_time else 0.0, + prevalidate_text=to_string(vlog.value_before_validation), + postvalidate_text=to_string(vlog.value_after_validation), + exception_message=maybe_outcome, + log_level=0, + ), + ) def tail_logs( - self, - start_offset_idx: int = 0, - follow: bool = False + self, start_offset_idx: int = 0, follow: bool = False ) -> Iterator[GuardLogEntry]: """Returns an iterator to generate GuardLogEntries. @param start_offset_idx : Start printing entries after this IDX. If @@ -246,10 +248,10 @@ def tail_logs( # We're indexing from the end, so do a quick check. cursor.execute( "SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ?;", - (-last_idx,) + (-last_idx,), ) for row in cursor: - last_idx = row['id'] + last_idx = row["id"] sql = """ SELECT id, guard_name, start_time, end_time, prevalidate_text, @@ -274,6 +276,7 @@ class TraceHandler(_SQLiteTraceHandler): """TraceHandler wraps the internal _SQLiteTraceHandler to make it multi-thread safe. 
Coupled with some write ahead journaling in the _SyncTrace internal, we have a faux-multi-write multi-read interface for SQLite.""" + _instance = None _lock = threading.Lock() diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index 28c04e6fe..524bccb21 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -17,7 +17,7 @@ "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat", "Heeeey Macarena.", "Excepteur sint occaecat cupidatat non proident,", - "sunt in culpa qui officia deserunt mollit anim id est laborum." + "sunt in culpa qui officia deserunt mollit anim id est laborum.", ] @@ -41,7 +41,7 @@ def test_multithreading_hoisted(): for msg in STOCK_MESSAGES: out = executor.submit(_hoisted_logger, "multithread_hoist" + msg) out.result() - #executor.map(log_with_hoisted_logger, log_levels) + # executor.map(log_with_hoisted_logger, log_levels) def test_multithreading_acquired(): @@ -75,7 +75,7 @@ def _hoisted_logger(msg: str): "Testing the behavior of a hoisted logger.", msg, "", - 0 + 0, ) @@ -85,11 +85,5 @@ def _acquired_logger(msg): trace_logger = TraceHandler() end = time.time() trace_logger.log( - "acquired", - start, - end, - "Testing behavior of an acquired logger.", - msg, - "", - 0 + "acquired", start, end, "Testing behavior of an acquired logger.", msg, "", 0 ) From f49e8fcf7b076c9db204783b54a68816e40e623c Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:13:24 -0700 Subject: [PATCH 13/25] Default to follow (by request). Remove unused log level. --- guardrails/cli/watch.py | 2 +- guardrails/guard_call_logging.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 286067508..ecacfbccb 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -26,7 +26,7 @@ def watch_command( help="Print the last n most recent lines. 
If omitted, will print all history.", ), follow: bool = typer.Option( - default=False, + default=True, help="Continuously read the last output commands", ), log_path_override: Optional[str] = typer.Option( diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index ddb77fa2f..360b2d692 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -35,7 +35,7 @@ # TODO: We should read this from guardrailsrc. -LOG_RETENTION_LIMIT = 1000000 +LOG_RETENTION_LIMIT = 100000 LOG_FILENAME = "guardrails_calls.db" LOGFILE_PATH = os.path.join(tempfile.gettempdir(), LOG_FILENAME) @@ -68,7 +68,6 @@ class GuardLogEntry: guard_name: str start_time: float end_time: float - log_level: int id: int = -1 prevalidate_text: str = "" postvalidate_text: str = "" @@ -113,8 +112,7 @@ class _SQLiteTraceHandler(_BaseTraceHandler): end_time REAL, prevalidate_text TEXT, postvalidate_text TEXT, - exception_message TEXT, - log_level INTEGER + exception_message TEXT ); """ INSERT_COMMAND = """ @@ -123,7 +121,7 @@ class _SQLiteTraceHandler(_BaseTraceHandler): exception_message, log_level ) VALUES ( :guard_name, :start_time, :end_time, :prevalidate_text, :postvalidate_text, - :exception_message, :log_level + :exception_message ); """ @@ -189,7 +187,6 @@ def log( prevalidate_text: str, postvalidate_text: str, exception_text: str, - log_level: int, ): assert not self.readonly with self.db: @@ -202,7 +199,6 @@ def log( prevalidate_text=prevalidate_text, postvalidate_text=postvalidate_text, exception_message=exception_text, - log_level=log_level, ), ) @@ -228,7 +224,6 @@ def log_validator(self, vlog: ValidatorLogs): prevalidate_text=to_string(vlog.value_before_validation), postvalidate_text=to_string(vlog.value_after_validation), exception_message=maybe_outcome, - log_level=0, ), ) @@ -255,7 +250,7 @@ def tail_logs( sql = """ SELECT id, guard_name, start_time, end_time, prevalidate_text, - postvalidate_text, exception_message, log_level + 
postvalidate_text, exception_message FROM guard_logs WHERE id > ? ORDER BY start_time; From 9b1b512cb08064bc28148f4340d061af05384cd0 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:21:01 -0700 Subject: [PATCH 14/25] Fix doctest. Update docstring. --- guardrails/guard_call_logging.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 360b2d692..ba9b4b221 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -5,21 +5,24 @@ collating the pre/post validation text and timing of guard calls. Uses a singleton to share write access to a SQLite database across threads. +By default, logs will be created in a temporary directory. This can be overridden by +setting GUARDRAILS_LOG_FILE_PATH in the environment. tracehandler.log_path will give +the full path of the current log file. + # Reading logs (basic): -reader = SyncStructuredLogHandlerSingleton.get_reader() -for t in reader.tail_logs(): - print(t) +>>> reader = TraceHandler.get_reader() +>>> for t in reader.tail_logs(): +>>> print(t) # Reading logs (advanced): -reader = SyncStructuredLogHandlerSingleton.get_reader() -reader.db.execute("SELECT * FROM guard_logs;") # Arbitrary SQL support. +>>> reader = TraceHandler.get_reader() +>>> reader.db.execute("SELECT * FROM guard_logs;") # Arbitrary SQL support. # Saving logs -writer = SynbcStructuredLogHandlerSingleton() -writer.log( - "my_guard_name", start, end, "Raw LLM Output Text", "Sanitized", "exception?", 0 -) - +>>> writer = TraceHandler() +>>> writer.log( +>>> "my_guard_name", 0.0, 1.0, "Raw LLM Output Text", "Sanitized", "exception?" +>>> ) """ import datetime @@ -37,7 +40,10 @@ # TODO: We should read this from guardrailsrc. 
LOG_RETENTION_LIMIT = 100000 LOG_FILENAME = "guardrails_calls.db" -LOGFILE_PATH = os.path.join(tempfile.gettempdir(), LOG_FILENAME) +LOGFILE_PATH = os.environ.get( + "GUARDRAILS_LOG_FILE_PATH", # Document this environment variable. + os.path.join(tempfile.gettempdir(), LOG_FILENAME), +) # These adapters make it more convenient to add data into our log DB: @@ -118,7 +124,7 @@ class _SQLiteTraceHandler(_BaseTraceHandler): INSERT_COMMAND = """ INSERT INTO guard_logs ( guard_name, start_time, end_time, prevalidate_text, postvalidate_text, - exception_message, log_level + exception_message ) VALUES ( :guard_name, :start_time, :end_time, :prevalidate_text, :postvalidate_text, :exception_message From f10a6d5724d0a113fb918687f3e50ddf9fd39f78 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:26:39 -0700 Subject: [PATCH 15/25] Format. --- guardrails/guard_call_logging.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index ba9b4b221..e2a31f98e 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -292,9 +292,10 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _SQLiteTraceHandler: - return _SQLiteTraceHandler(path, read_mode=False) + def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: + # return _SQLiteTraceHandler(path, read_mode=False) + return _BaseTraceHandler(path, read_mode=False) @classmethod - def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _SQLiteTraceHandler: + def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: return _SQLiteTraceHandler(path, read_mode=True) From 3d12e665239db5e18a23dc25755ae6808cbd7f58 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:28:47 -0700 Subject: [PATCH 16/25] Relint. 
--- guardrails/cli/__init__.py | 2 +- tests/unit_tests/test_guard_log.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/guardrails/cli/__init__.py b/guardrails/cli/__init__.py index 0d2c9f92d..b10af851b 100644 --- a/guardrails/cli/__init__.py +++ b/guardrails/cli/__init__.py @@ -3,7 +3,7 @@ import guardrails.cli.validate # noqa from guardrails.cli.guardrails import guardrails as cli from guardrails.cli.hub import hub_command -from guardrails.cli.watch import watch_command +from guardrails.cli.watch import watch_command # noqa: F401 cli.add_typer( diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index 524bccb21..c47fc3718 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -1,7 +1,5 @@ import asyncio import concurrent.futures -import random -import sys import time from multiprocessing import Pool From 6f1a3791f6c728350ec124351ad648d8f98d011c Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 12:30:02 -0700 Subject: [PATCH 17/25] Accidentally only returned Noop handler. --- guardrails/guard_call_logging.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index e2a31f98e..4b0780b37 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -293,8 +293,7 @@ def __new__(cls): @classmethod def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: - # return _SQLiteTraceHandler(path, read_mode=False) - return _BaseTraceHandler(path, read_mode=False) + return _SQLiteTraceHandler(path, read_mode=False) @classmethod def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: From 6b66bda26e919c03a450d4734da2b3f8a944eb6a Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 13:08:13 -0700 Subject: [PATCH 18/25] Remove error level and reformat. 
--- tests/unit_tests/test_guard_log.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index c47fc3718..0fde2ad0b 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -73,7 +73,6 @@ def _hoisted_logger(msg: str): "Testing the behavior of a hoisted logger.", msg, "", - 0, ) @@ -83,5 +82,5 @@ def _acquired_logger(msg): trace_logger = TraceHandler() end = time.time() trace_logger.log( - "acquired", start, end, "Testing behavior of an acquired logger.", msg, "", 0 + "acquired", start, end, "Testing behavior of an acquired logger.", msg, "" ) From 5da2e5d431363bfc5bac80d9f97337bc3a72c8c3 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 13:38:00 -0700 Subject: [PATCH 19/25] Linting. --- guardrails/cli/watch.py | 3 ++- guardrails/guard_call_logging.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index ecacfbccb..1110ce893 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -38,7 +38,8 @@ def watch_command( while log_reader is None: try: if log_path_override is not None: - log_reader = TraceHandler.get_reader(log_path_override) + # For the noqa: Strings are pathlike. Ignore conversion issue. + log_reader = TraceHandler.get_reader(log_path_override) # noqa else: log_reader = TraceHandler.get_reader() except sqlite3.OperationalError: diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index 4b0780b37..b0beb1f17 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -104,7 +104,7 @@ def tail_logs( start_offset_idx: int = 0, follow: bool = False, ) -> Iterator[GuardLogEntry]: - return [] + yield from [] # This structured handler shouldn't be used directly, since it's touching a SQLite db. 
@@ -168,7 +168,7 @@ def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: # A bit of a hack to open in read-only mode... db = sqlite3.connect( - "file:" + log_path + "?mode=ro", isolation_level=None, uri=True + "file:" + str(log_path) + "?mode=ro", isolation_level=None, uri=True ) db.row_factory = sqlite3.Row return db From 8e2ac75cf6e13066343809f137bd2a1441ed28d5 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 13:52:00 -0700 Subject: [PATCH 20/25] Fix lint and pyright issues. --- guardrails/cli/watch.py | 3 +-- guardrails/guard_call_logging.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 1110ce893..65dfa58fb 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -38,8 +38,7 @@ def watch_command( while log_reader is None: try: if log_path_override is not None: - # For the noqa: Strings are pathlike. Ignore conversion issue. - log_reader = TraceHandler.get_reader(log_path_override) # noqa + log_reader = TraceHandler.get_reader(log_path_override) # type: ignore else: log_reader = TraceHandler.get_reader() except sqlite3.OperationalError: diff --git a/guardrails/guard_call_logging.py b/guardrails/guard_call_logging.py index b0beb1f17..213ea149e 100644 --- a/guardrails/guard_call_logging.py +++ b/guardrails/guard_call_logging.py @@ -88,7 +88,7 @@ class _BaseTraceHandler: """The base TraceHandler only pads out the methods. 
It's effectively a Noop""" def __init__(self, log_path: os.PathLike, read_mode: bool): - pass + self.db = None def log(self, *args, **kwargs): pass @@ -217,7 +217,10 @@ def log_validator(self, vlog: ValidatorLogs): assert not self.readonly maybe_outcome = ( str(vlog.validation_result.outcome) - if hasattr(vlog.validation_result, "outcome") + if ( + vlog.validation_result is not None + and hasattr(vlog.validation_result, "outcome") + ) else "" ) with self.db: @@ -292,9 +295,11 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: + def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: # type: ignore return _SQLiteTraceHandler(path, read_mode=False) + # To disable logging: + # return _BaseTraceHandler(path, read_mode=False) @classmethod - def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: + def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: # type: ignore return _SQLiteTraceHandler(path, read_mode=True) From 6e38053500e692fcc9303c806f8eaf25a9042d5e Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 14:34:30 -0700 Subject: [PATCH 21/25] PR Feedback: Move guard_call_logging to tracing. --- guardrails/call_tracing/__init__.py | 3 +++ guardrails/{ => call_tracing}/guard_call_logging.py | 0 guardrails/cli/watch.py | 2 +- guardrails/utils/telemetry_utils.py | 2 +- tests/unit_tests/test_guard_log.py | 2 +- 5 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 guardrails/call_tracing/__init__.py rename guardrails/{ => call_tracing}/guard_call_logging.py (100%) diff --git a/guardrails/call_tracing/__init__.py b/guardrails/call_tracing/__init__.py new file mode 100644 index 000000000..685a830bc --- /dev/null +++ b/guardrails/call_tracing/__init__.py @@ -0,0 +1,3 @@ +""" +For tracing (logging) and reporting the timing of Guard and Validator calls. 
+""" \ No newline at end of file diff --git a/guardrails/guard_call_logging.py b/guardrails/call_tracing/guard_call_logging.py similarity index 100% rename from guardrails/guard_call_logging.py rename to guardrails/call_tracing/guard_call_logging.py diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 65dfa58fb..568a396e1 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -8,7 +8,7 @@ import typer from guardrails.cli.guardrails import guardrails as gr_cli -from guardrails.guard_call_logging import ( +from guardrails.call_tracing.guard_call_logging import ( GuardLogEntry, TraceHandler, ) diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 19aae63c6..e0dec867e 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -7,7 +7,7 @@ from opentelemetry.context import Context from opentelemetry.trace import StatusCode, Tracer -from guardrails.guard_call_logging import TraceHandler +from guardrails.call_tracing.guard_call_logging import TraceHandler from guardrails.stores.context import get_tracer as get_context_tracer from guardrails.stores.context import get_tracer_context from guardrails.utils.casting_utils import to_string diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index 0fde2ad0b..89f4decc3 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -3,7 +3,7 @@ import time from multiprocessing import Pool -from guardrails.guard_call_logging import TraceHandler +from guardrails.call_tracing.guard_call_logging import TraceHandler NUM_THREADS = 4 From 46bc48a8f0eb2aeb3d1f941a5f71296b7f363f41 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 14:47:23 -0700 Subject: [PATCH 22/25] PR Feedback: Move things to individual namespaced files. 
--- guardrails/call_tracing/__init__.py | 6 +- guardrails/call_tracing/guard_call_logging.py | 250 +----------------- .../call_tracing/sqlite_trace_handler.py | 212 +++++++++++++++ guardrails/call_tracing/trace_entry.py | 16 ++ guardrails/call_tracing/tracer_mixin.py | 28 ++ guardrails/cli/watch.py | 6 +- guardrails/utils/telemetry_utils.py | 2 +- 7 files changed, 272 insertions(+), 248 deletions(-) create mode 100644 guardrails/call_tracing/sqlite_trace_handler.py create mode 100644 guardrails/call_tracing/trace_entry.py create mode 100644 guardrails/call_tracing/tracer_mixin.py diff --git a/guardrails/call_tracing/__init__.py b/guardrails/call_tracing/__init__.py index 685a830bc..c9d214a87 100644 --- a/guardrails/call_tracing/__init__.py +++ b/guardrails/call_tracing/__init__.py @@ -1,3 +1,7 @@ """ For tracing (logging) and reporting the timing of Guard and Validator calls. -""" \ No newline at end of file +""" +from guardrails.call_tracing.trace_entry import GuardTraceEntry +from guardrails.call_tracing.guard_call_logging import TraceHandler + +__all__ = ['GuardTraceEntry', 'TraceHandler'] \ No newline at end of file diff --git a/guardrails/call_tracing/guard_call_logging.py b/guardrails/call_tracing/guard_call_logging.py index 213ea149e..cb1188681 100644 --- a/guardrails/call_tracing/guard_call_logging.py +++ b/guardrails/call_tracing/guard_call_logging.py @@ -25,20 +25,14 @@ >>> ) """ -import datetime import os -import sqlite3 import tempfile import threading -from dataclasses import dataclass, asdict -from typing import Iterator - -from guardrails.utils.casting_utils import to_string -from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.call_tracing.sqlite_trace_handler import SQLiteTraceHandler +from guardrails.call_tracing.tracer_mixin import TracerMixin # TODO: We should read this from guardrailsrc. 
-LOG_RETENTION_LIMIT = 100000 LOG_FILENAME = "guardrails_calls.db" LOGFILE_PATH = os.environ.get( "GUARDRAILS_LOG_FILE_PATH", # Document this environment variable. @@ -46,237 +40,7 @@ ) -# These adapters make it more convenient to add data into our log DB: -# Handle timestamp -> sqlite map: -def adapt_datetime(val): - """Adapt datetime.datetime to Unix timestamp.""" - # return val.isoformat() # If we want to go to datetime/isoformat... - return int(val.timestamp()) - - -sqlite3.register_adapter(datetime.datetime, adapt_datetime) - - -def convert_timestamp(val): - """Convert Unix epoch timestamp to datetime.datetime object.""" - # To go to datetime.datetime: - # return datetime.datetime.fromisoformat(val.decode()) - return datetime.datetime.fromtimestamp(int(val)) - - -sqlite3.register_converter("timestamp", convert_timestamp) - - -# This class makes it slightly easier to be selective about how we pull data. -# While it's not the ultimate contract/DB schema, it helps with typing and improves dx. -@dataclass -class GuardLogEntry: - guard_name: str - start_time: float - end_time: float - id: int = -1 - prevalidate_text: str = "" - postvalidate_text: str = "" - exception_message: str = "" - - @property - def timedelta(self): - return self.end_time - self.start_time - - -class _BaseTraceHandler: - """The base TraceHandler only pads out the methods. It's effectively a Noop""" - - def __init__(self, log_path: os.PathLike, read_mode: bool): - self.db = None - - def log(self, *args, **kwargs): - pass - - def log_entry(self, guard_log_entry: GuardLogEntry): - pass - - def log_validator(self, vlog: ValidatorLogs): - pass - - def tail_logs( - self, - start_offset_idx: int = 0, - follow: bool = False, - ) -> Iterator[GuardLogEntry]: - yield from [] - - -# This structured handler shouldn't be used directly, since it's touching a SQLite db. -# Instead, use the singleton or the async singleton. 
-class _SQLiteTraceHandler(_BaseTraceHandler): - CREATE_COMMAND = """ - CREATE TABLE IF NOT EXISTS guard_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - guard_name TEXT, - start_time REAL, - end_time REAL, - prevalidate_text TEXT, - postvalidate_text TEXT, - exception_message TEXT - ); - """ - INSERT_COMMAND = """ - INSERT INTO guard_logs ( - guard_name, start_time, end_time, prevalidate_text, postvalidate_text, - exception_message - ) VALUES ( - :guard_name, :start_time, :end_time, :prevalidate_text, :postvalidate_text, - :exception_message - ); - """ - - def __init__(self, log_path: os.PathLike, read_mode: bool): - self._log_path = log_path # Read-only value. - self.readonly = read_mode - if read_mode: - self.db = _SQLiteTraceHandler._get_read_connection(log_path) - else: - self.db = _SQLiteTraceHandler._get_write_connection(log_path) - - @property - def log_path(self): - return self._log_path - - @classmethod - def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: - try: - db = sqlite3.connect( - log_path, - isolation_level=None, - check_same_thread=False, - ) - db.execute("PRAGMA journal_mode = wal") - db.execute("PRAGMA synchronous = OFF") - # isolation_level = None and pragma WAL means we can READ from the DB - # while threads using it are writing. Synchronous off puts us on the - # highway to the danger zone, depending on how willing we are to lose log - # messages in the event of a guard crash. - except sqlite3.OperationalError as e: - # logging.exception("Unable to connect to guard log handler.") - raise e - with db: - db.execute(_SQLiteTraceHandler.CREATE_COMMAND) - return db - - @classmethod - def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: - # A bit of a hack to open in read-only mode... 
- db = sqlite3.connect( - "file:" + str(log_path) + "?mode=ro", isolation_level=None, uri=True - ) - db.row_factory = sqlite3.Row - return db - - def _truncate(self, keep_n: int = LOG_RETENTION_LIMIT): - assert not self.readonly - self.db.execute( - """ - DELETE FROM guard_logs - WHERE id < ( - SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ? - ); - """, - (keep_n,), - ) - - def log( - self, - guard_name: str, - start_time: float, - end_time: float, - prevalidate_text: str, - postvalidate_text: str, - exception_text: str, - ): - assert not self.readonly - with self.db: - self.db.execute( - _SQLiteTraceHandler.INSERT_COMMAND, - dict( - guard_name=guard_name, - start_time=start_time, - end_time=end_time, - prevalidate_text=prevalidate_text, - postvalidate_text=postvalidate_text, - exception_message=exception_text, - ), - ) - - def log_entry(self, guard_log_entry: GuardLogEntry): - assert not self.readonly - with self.db: - self.db.execute(_SQLiteTraceHandler.INSERT_COMMAND, asdict(guard_log_entry)) - - def log_validator(self, vlog: ValidatorLogs): - assert not self.readonly - maybe_outcome = ( - str(vlog.validation_result.outcome) - if ( - vlog.validation_result is not None - and hasattr(vlog.validation_result, "outcome") - ) - else "" - ) - with self.db: - self.db.execute( - _SQLiteTraceHandler.INSERT_COMMAND, - dict( - guard_name=vlog.validator_name, - start_time=vlog.start_time if vlog.start_time else None, - end_time=vlog.end_time if vlog.end_time else 0.0, - prevalidate_text=to_string(vlog.value_before_validation), - postvalidate_text=to_string(vlog.value_after_validation), - exception_message=maybe_outcome, - ), - ) - - def tail_logs( - self, start_offset_idx: int = 0, follow: bool = False - ) -> Iterator[GuardLogEntry]: - """Returns an iterator to generate GuardLogEntries. - @param start_offset_idx : Start printing entries after this IDX. If - negative, this will instead start printing the LAST start_offset_idx entries. 
- @param follow : If follow is True, will re-check the database for new entries - after the first batch is complete. If False (default), will return when entries - are exhausted. - """ - last_idx = start_offset_idx - cursor = self.db.cursor() - if last_idx < 0: - # We're indexing from the end, so do a quick check. - cursor.execute( - "SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ?;", - (-last_idx,), - ) - for row in cursor: - last_idx = row["id"] - sql = """ - SELECT - id, guard_name, start_time, end_time, prevalidate_text, - postvalidate_text, exception_message - FROM guard_logs - WHERE id > ? - ORDER BY start_time; - """ - cursor.execute(sql, (last_idx,)) - while True: - for row in cursor: - last_entry = GuardLogEntry(**row) - last_idx = last_entry.id - yield last_entry - if not follow: - return - # If we're here we've run out of entries to tail. Fetch more: - cursor.execute(sql, (last_idx,)) - - -class TraceHandler(_SQLiteTraceHandler): +class TraceHandler(TracerMixin): """TraceHandler wraps the internal _SQLiteTraceHandler to make it multi-thread safe. 
Coupled with some write ahead journaling in the _SyncTrace internal, we have a faux-multi-write multi-read interface for SQLite.""" @@ -295,11 +59,11 @@ def __new__(cls): return cls._instance @classmethod - def _create(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: # type: ignore - return _SQLiteTraceHandler(path, read_mode=False) + def _create(cls, path: os.PathLike = LOGFILE_PATH) -> TracerMixin: # type: ignore + return SQLiteTraceHandler(path, read_mode=False) # To disable logging: # return _BaseTraceHandler(path, read_mode=False) @classmethod - def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> _BaseTraceHandler: # type: ignore - return _SQLiteTraceHandler(path, read_mode=True) + def get_reader(cls, path: os.PathLike = LOGFILE_PATH) -> TracerMixin: # type: ignore + return SQLiteTraceHandler(path, read_mode=True) diff --git a/guardrails/call_tracing/sqlite_trace_handler.py b/guardrails/call_tracing/sqlite_trace_handler.py new file mode 100644 index 000000000..551e43a9f --- /dev/null +++ b/guardrails/call_tracing/sqlite_trace_handler.py @@ -0,0 +1,212 @@ +import datetime +import os +import sqlite3 +import time +from dataclasses import asdict +from typing import Iterator + +from guardrails.call_tracing.trace_entry import GuardTraceEntry +from guardrails.call_tracing.tracer_mixin import TracerMixin +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.utils.casting_utils import to_string + + +LOG_RETENTION_LIMIT = 100000 +TIME_BETWEEN_CLEANUPS = 10.0 # Seconds + + +# These adapters make it more convenient to add data into our log DB: +# Handle timestamp -> sqlite map: +def adapt_datetime(val): + """Adapt datetime.datetime to Unix timestamp.""" + # return val.isoformat() # If we want to go to datetime/isoformat... 
+ return int(val.timestamp()) + + +sqlite3.register_adapter(datetime.datetime, adapt_datetime) + + +def convert_timestamp(val): + """Convert Unix epoch timestamp to datetime.datetime object.""" + # To go to datetime.datetime: + # return datetime.datetime.fromisoformat(val.decode()) + return datetime.datetime.fromtimestamp(int(val)) + + +sqlite3.register_converter("timestamp", convert_timestamp) + + +# This structured handler shouldn't be used directly, since it's touching a SQLite db. +# Instead, use the singleton or the async singleton. +class SQLiteTraceHandler(TracerMixin): + CREATE_COMMAND = """ + CREATE TABLE IF NOT EXISTS guard_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + guard_name TEXT, + start_time REAL, + end_time REAL, + prevalidate_text TEXT, + postvalidate_text TEXT, + exception_message TEXT + ); + """ + INSERT_COMMAND = """ + INSERT INTO guard_logs ( + guard_name, start_time, end_time, prevalidate_text, postvalidate_text, + exception_message + ) VALUES ( + :guard_name, :start_time, :end_time, :prevalidate_text, :postvalidate_text, + :exception_message + ); + """ + + def __init__(self, log_path: os.PathLike, read_mode: bool): + self._log_path = log_path # Read-only value. + self.last_cleanup = time.time() + self.readonly = read_mode + if read_mode: + self.db = SQLiteTraceHandler._get_read_connection(log_path) + else: + self.db = SQLiteTraceHandler._get_write_connection(log_path) + + @property + def log_path(self): + return self._log_path + + @classmethod + def _get_write_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: + try: + db = sqlite3.connect( + log_path, + isolation_level=None, + check_same_thread=False, + ) + db.execute("PRAGMA journal_mode = wal") + db.execute("PRAGMA synchronous = OFF") + # isolation_level = None and pragma WAL means we can READ from the DB + # while threads using it are writing. 
Synchronous off puts us on the + # highway to the danger zone, depending on how willing we are to lose log + # messages in the event of a guard crash. + except sqlite3.OperationalError as e: + # logging.exception("Unable to connect to guard log handler.") + raise e + with db: + db.execute(SQLiteTraceHandler.CREATE_COMMAND) + return db + + @classmethod + def _get_read_connection(cls, log_path: os.PathLike) -> sqlite3.Connection: + # A bit of a hack to open in read-only mode... + db = sqlite3.connect( + "file:" + str(log_path) + "?mode=ro", isolation_level=None, uri=True + ) + db.row_factory = sqlite3.Row + return db + + def _truncate(self, force: bool = False, keep_n: int = LOG_RETENTION_LIMIT): + assert not self.readonly + now = time.time() + if force or (now - self.last_cleanup > TIME_BETWEEN_CLEANUPS): + self.last_cleanup = now + self.db.execute( + """ + DELETE FROM guard_logs + WHERE id < ( + SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ? + ); + """, + (keep_n,), + ) + + def log( + self, + guard_name: str, + start_time: float, + end_time: float, + prevalidate_text: str, + postvalidate_text: str, + exception_text: str, + ): + assert not self.readonly + with self.db: + self.db.execute( + SQLiteTraceHandler.INSERT_COMMAND, + dict( + guard_name=guard_name, + start_time=start_time, + end_time=end_time, + prevalidate_text=prevalidate_text, + postvalidate_text=postvalidate_text, + exception_message=exception_text, + ), + ) + self._truncate() + + def log_entry(self, guard_log_entry: GuardTraceEntry): + assert not self.readonly + with self.db: + self.db.execute(SQLiteTraceHandler.INSERT_COMMAND, asdict(guard_log_entry)) + self._truncate() + + def log_validator(self, vlog: ValidatorLogs): + assert not self.readonly + maybe_outcome = ( + str(vlog.validation_result.outcome) + if ( + vlog.validation_result is not None + and hasattr(vlog.validation_result, "outcome") + ) + else "" + ) + with self.db: + self.db.execute( + SQLiteTraceHandler.INSERT_COMMAND, + dict( 
+ guard_name=vlog.validator_name, + start_time=vlog.start_time if vlog.start_time else None, + end_time=vlog.end_time if vlog.end_time else 0.0, + prevalidate_text=to_string(vlog.value_before_validation), + postvalidate_text=to_string(vlog.value_after_validation), + exception_message=maybe_outcome, + ), + ) + self._truncate() + + def tail_logs( + self, start_offset_idx: int = 0, follow: bool = False + ) -> Iterator[GuardTraceEntry]: + """Returns an iterator to generate GuardLogEntries. + @param start_offset_idx : Start printing entries after this IDX. If + negative, this will instead start printing the LAST start_offset_idx entries. + @param follow : If follow is True, will re-check the database for new entries + after the first batch is complete. If False (default), will return when entries + are exhausted. + """ + last_idx = start_offset_idx + cursor = self.db.cursor() + if last_idx < 0: + # We're indexing from the end, so do a quick check. + cursor.execute( + "SELECT id FROM guard_logs ORDER BY id DESC LIMIT 1 OFFSET ?;", + (-last_idx,), + ) + for row in cursor: + last_idx = row["id"] + sql = """ + SELECT + id, guard_name, start_time, end_time, prevalidate_text, + postvalidate_text, exception_message + FROM guard_logs + WHERE id > ? + ORDER BY start_time; + """ + cursor.execute(sql, (last_idx,)) + while True: + for row in cursor: + last_entry = GuardTraceEntry(**row) + last_idx = last_entry.id + yield last_entry + if not follow: + return + # If we're here we've run out of entries to tail. 
Fetch more: + cursor.execute(sql, (last_idx,)) diff --git a/guardrails/call_tracing/trace_entry.py b/guardrails/call_tracing/trace_entry.py new file mode 100644 index 000000000..45fda431a --- /dev/null +++ b/guardrails/call_tracing/trace_entry.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +@dataclass +class GuardTraceEntry: + guard_name: str + start_time: float + end_time: float + id: int = -1 + prevalidate_text: str = "" + postvalidate_text: str = "" + exception_message: str = "" + + @property + def timedelta(self): + return self.end_time - self.start_time diff --git a/guardrails/call_tracing/tracer_mixin.py b/guardrails/call_tracing/tracer_mixin.py new file mode 100644 index 000000000..f0aa92fea --- /dev/null +++ b/guardrails/call_tracing/tracer_mixin.py @@ -0,0 +1,28 @@ +import os +from typing import Iterator + +from guardrails.call_tracing.trace_entry import GuardTraceEntry +from guardrails.classes.validation.validator_logs import ValidatorLogs + + +class TracerMixin: + """The pads out the methods but is otherwise a noop.""" + + def __init__(self, log_path: os.PathLike, read_mode: bool): + self.db = None + + def log(self, *args, **kwargs): + pass + + def log_entry(self, guard_log_entry: GuardTraceEntry): + pass + + def log_validator(self, vlog: ValidatorLogs): + pass + + def tail_logs( + self, + start_offset_idx: int = 0, + follow: bool = False, + ) -> Iterator[GuardTraceEntry]: + yield from [] diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index 568a396e1..ad1725d0a 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -9,9 +9,9 @@ from guardrails.cli.guardrails import guardrails as gr_cli from guardrails.call_tracing.guard_call_logging import ( - GuardLogEntry, TraceHandler, ) +from guardrails.call_tracing.trace_entry import GuardTraceEntry @gr_cli.command(name="watch") @@ -55,9 +55,9 @@ def watch_command( output_fn(log_msg) -def _print_fancy(log_msg: GuardLogEntry): +def _print_fancy(log_msg: GuardTraceEntry): 
rich.print(log_msg) -def _print_and_format_plain(log_msg: GuardLogEntry) -> None: +def _print_and_format_plain(log_msg: GuardTraceEntry) -> None: print(json.dumps(asdict(log_msg))) diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index e0dec867e..9702f8a4e 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -7,7 +7,7 @@ from opentelemetry.context import Context from opentelemetry.trace import StatusCode, Tracer -from guardrails.call_tracing.guard_call_logging import TraceHandler +from guardrails.call_tracing import TraceHandler from guardrails.stores.context import get_tracer as get_context_tracer from guardrails.stores.context import get_tracer_context from guardrails.utils.casting_utils import to_string From 62ae667776cf0456e24dece29a26a6ece72c1a54 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 15:02:08 -0700 Subject: [PATCH 23/25] Move around and clean up based on PR feedback. --- guardrails/call_tracing/__init__.py | 9 +++++++-- .../call_tracing/sqlite_trace_handler.py | 18 ++++++++++++++++++ guardrails/call_tracing/trace_entry.py | 15 ++++++++++++--- ...{guard_call_logging.py => trace_handler.py} | 2 +- guardrails/call_tracing/tracer_mixin.py | 8 ++++++++ guardrails/cli/watch.py | 5 +---- tests/unit_tests/test_guard_log.py | 2 +- 7 files changed, 48 insertions(+), 11 deletions(-) rename guardrails/call_tracing/{guard_call_logging.py => trace_handler.py} (99%) diff --git a/guardrails/call_tracing/__init__.py b/guardrails/call_tracing/__init__.py index c9d214a87..76e6a3754 100644 --- a/guardrails/call_tracing/__init__.py +++ b/guardrails/call_tracing/__init__.py @@ -1,7 +1,12 @@ """ For tracing (logging) and reporting the timing of Guard and Validator calls. + +sqlite_trace_handler defines most of the actual implementation methods. +trace_handler provides the singleton that's used for fast global access across threads. 
+tracer_mixin defines the interface and can act as a noop. +trace_entry is just a helpful dataclass. """ from guardrails.call_tracing.trace_entry import GuardTraceEntry -from guardrails.call_tracing.guard_call_logging import TraceHandler +from guardrails.call_tracing.trace_handler import TraceHandler -__all__ = ['GuardTraceEntry', 'TraceHandler'] \ No newline at end of file +__all__ = ['GuardTraceEntry', 'TraceHandler'] diff --git a/guardrails/call_tracing/sqlite_trace_handler.py b/guardrails/call_tracing/sqlite_trace_handler.py index 551e43a9f..22ce08c14 100644 --- a/guardrails/call_tracing/sqlite_trace_handler.py +++ b/guardrails/call_tracing/sqlite_trace_handler.py @@ -1,3 +1,21 @@ +""" +sqlite_trace_handler.py + +This is the metaphorical bread and butter of our tracing implementation, or at least the +butter. It wraps a SQLite database and configures it to be 'agreeable' in multithreaded +situations. Normally, when sharing across threads and instances one should consider +using a larger database solution like Postgres, but in this case we only care about +_supporting_ writing from multiple places. We don't expect it will be the norm. +We care about (1) not negatively impacting performance, (2) not crashing when used in +unusual ways, and (3) not losing data when possible. + +The happy path should be reasonably performant. The unhappy path should not crash. + +The other part of the multithreaded support comes from the public trace_handler, which +uses a singleton pattern to only have a single instance of the database per-thread. +If we _do_ somehow end up shared across threads, the journaling settings and writeahead +should protect us from odd behavior. 
@dataclass
class GuardTraceEntry:
    """One row of the guard-call trace log.

    Every field defaults so partially populated entries can be built up
    incrementally; ``id`` stays -1 until the database assigns a real rowid.
    """

    id: int = -1
    guard_name: str = ""
    start_time: float = 0.0
    end_time: float = 0.0
    prevalidate_text: str = ""
    postvalidate_text: str = ""
    exception_message: str = ""

    @property
    def timedelta(self) -> float:
        """Seconds elapsed between the start and the end of the traced call."""
        return self.end_time - self.start_time
+""" + import os from typing import Iterator diff --git a/guardrails/cli/watch.py b/guardrails/cli/watch.py index ad1725d0a..ba163b976 100644 --- a/guardrails/cli/watch.py +++ b/guardrails/cli/watch.py @@ -8,10 +8,7 @@ import typer from guardrails.cli.guardrails import guardrails as gr_cli -from guardrails.call_tracing.guard_call_logging import ( - TraceHandler, -) -from guardrails.call_tracing.trace_entry import GuardTraceEntry +from guardrails.call_tracing import GuardTraceEntry, TraceHandler @gr_cli.command(name="watch") diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index 89f4decc3..ee3746576 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -3,7 +3,7 @@ import time from multiprocessing import Pool -from guardrails.call_tracing.guard_call_logging import TraceHandler +from guardrails.call_tracing import TraceHandler NUM_THREADS = 4 From bf5c8d865f462f29824b520d1c26ace625e7a479 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 15:03:01 -0700 Subject: [PATCH 24/25] Remove the macarena from unit tests. --- tests/unit_tests/test_guard_log.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/test_guard_log.py b/tests/unit_tests/test_guard_log.py index ee3746576..fc332f374 100644 --- a/tests/unit_tests/test_guard_log.py +++ b/tests/unit_tests/test_guard_log.py @@ -13,7 +13,6 @@ "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", "Ut enim ad minim veniam", "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat", - "Heeeey Macarena.", "Excepteur sint occaecat cupidatat non proident,", "sunt in culpa qui officia deserunt mollit anim id est laborum.", ] From 36f51774849e1e9827634ff525e899ce578b0573 Mon Sep 17 00:00:00 2001 From: Joseph Catrambone Date: Thu, 27 Jun 2024 15:09:46 -0700 Subject: [PATCH 25/25] Linting. 
--- guardrails/call_tracing/__init__.py | 3 ++- guardrails/call_tracing/sqlite_trace_handler.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/guardrails/call_tracing/__init__.py b/guardrails/call_tracing/__init__.py index 76e6a3754..62078ab46 100644 --- a/guardrails/call_tracing/__init__.py +++ b/guardrails/call_tracing/__init__.py @@ -6,7 +6,8 @@ tracer_mixin defines the interface and can act as a noop. trace_entry is just a helpful dataclass. """ + from guardrails.call_tracing.trace_entry import GuardTraceEntry from guardrails.call_tracing.trace_handler import TraceHandler -__all__ = ['GuardTraceEntry', 'TraceHandler'] +__all__ = ["GuardTraceEntry", "TraceHandler"] diff --git a/guardrails/call_tracing/sqlite_trace_handler.py b/guardrails/call_tracing/sqlite_trace_handler.py index 22ce08c14..2831c05d8 100644 --- a/guardrails/call_tracing/sqlite_trace_handler.py +++ b/guardrails/call_tracing/sqlite_trace_handler.py @@ -16,6 +16,7 @@ If we _do_ somehow end up shared across threads, the journaling settings and writeahead should protect us from odd behavior. """ + import datetime import os import sqlite3