Skip to content

Commit

Permalink
feat: ensure serializability for whole logging data passed to the zmq (
Browse files Browse the repository at this point in the history
…#1759)

Backported-from: main (24.09)
Backported-to: 23.09
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
kyujin-cho and achimnol committed Apr 8, 2024
1 parent 1727af3 commit 01b8e3c
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 153 deletions.
1 change: 1 addition & 0 deletions changes/1759.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve exception logging stability by pre-formatting exception objects instead of pickling/unpickling them
20 changes: 0 additions & 20 deletions python.lock
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
// "rich~=13.6",
// "setproctitle~=1.3.2",
// "tabulate~=0.8.9",
// "tblib~=1.7",
// "temporenc~=0.1.0",
// "tenacity>=8.0",
// "textual~=0.52.1",
Expand Down Expand Up @@ -3906,24 +3905,6 @@
"requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7",
"version": "0.8.10"
},
{
"artifacts": [
{
"algorithm": "sha256",
"hash": "289fa7359e580950e7d9743eab36b0691f0310fce64dee7d9c31065b8f723e23",
"url": "https://files.pythonhosted.org/packages/f8/cd/2fad4add11c8837e72f50a30e2bda30e67a10d70462f826b291443a55c7d/tblib-1.7.0-py2.py3-none-any.whl"
},
{
"algorithm": "sha256",
"hash": "059bd77306ea7b419d4f76016aef6d7027cc8a0785579b5aad198803435f882c",
"url": "https://files.pythonhosted.org/packages/d3/41/901ef2e81d7b1e834b9870d416cb09479e175a2be1c4aa1a9dcd0a555293/tblib-1.7.0.tar.gz"
}
],
"project_name": "tblib",
"requires_dists": [],
"requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7",
"version": "1.7.0"
},
{
"artifacts": [
{
Expand Down Expand Up @@ -4787,7 +4768,6 @@
"rich~=13.6",
"setproctitle~=1.3.2",
"tabulate~=0.8.9",
"tblib~=1.7",
"temporenc~=0.1.0",
"tenacity>=8.0",
"textual~=0.52.1",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ SQLAlchemy[postgresql_asyncpg]~=1.4.40
setproctitle~=1.3.2
tabulate~=0.8.9
temporenc~=0.1.0
tblib~=1.7
tenacity>=8.0
tomli~=2.0.1
tomlkit~=0.11.1
Expand Down
156 changes: 98 additions & 58 deletions src/ai/backend/common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import logging.config
import logging.handlers
import os
import pickle
import pprint
import socket
import ssl
import sys
import threading
import time
import traceback
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from contextvars import ContextVar
Expand All @@ -18,11 +18,13 @@
from typing import Any, Mapping, MutableMapping, Optional

import coloredlogs
import graypy
import trafaret as t
import yarl
import zmq
from pythonjsonlogger.jsonlogger import JsonFormatter
from tblib import pickling_support

from ai.backend.common import msgpack

from . import config
from . import validators as tx
Expand Down Expand Up @@ -170,13 +172,6 @@ def emit(self, record):
tags = set()
extra_data = dict()

if record.exc_info:
tags.add("has_exception")
if self.formatter:
extra_data["exception"] = self.formatter.formatException(record.exc_info)
else:
extra_data["exception"] = logging._defaultFormatter.formatException(record.exc_info)

# This log format follows logstash's event format.
log = OrderedDict([
("@timestamp", datetime.now().isoformat()),
Expand All @@ -198,6 +193,61 @@ def emit(self, record):
self._sock.sendall(json.dumps(log).encode("utf-8"))


def format_exception(self, ei):
    """Return exception information as one string without a trailing newline.

    The signature mirrors ``logging.Formatter.formatException`` so that this
    function can be used directly as (or delegated to from) that method;
    ``self`` is accepted for that reason only and is not used.

    :param ei: Either an iterable of already-formatted traceback strings
        (e.g. the output of ``traceback.format_exception()``), or a raw
        ``(type, value, traceback)`` triple as returned by ``sys.exc_info()``.
    :returns: The concatenated traceback text with one trailing newline
        removed, if present.
    """
    try:
        s = "".join(ei)
    except TypeError:
        # Not an iterable of strings, so treat it as a raw exc_info triple and
        # format it the way the standard library formatter would.  Without
        # this, any record carrying a live exception crashes the formatter.
        import traceback

        s = "".join(traceback.format_exception(*ei))
    if s[-1:] == "\n":
        s = s[:-1]
    return s


class SerializedExceptionFormatter(logging.Formatter):
    """A plain ``logging.Formatter`` whose exception text comes from ``format_exception()``."""

    def formatException(self, ei) -> str:
        """Render *ei* by delegating to the module-level ``format_exception()`` helper."""
        rendered = format_exception(self, ei)
        return rendered


class GELFTLSHandler(graypy.GELFTLSHandler):
    """GELF-over-TLS log handler whose sockets are created from an ``ssl.SSLContext``."""

    # TLS context built once in __init__ and reused for every connection.
    ssl_ctx: ssl.SSLContext

    def __init__(self, host, port=12204, validate=False, ca_certs=None, **kwargs):
        """Initialize the GELFTLSHandler

        :param host: GELF TLS input host.
        :type host: str

        :param port: GELF TLS input port.
        :type port: int

        :param validate: If :obj:`True`, validate the Graylog server's
            certificate. In this case specifying ``ca_certs`` is also
            required.
        :type validate: bool

        :param ca_certs: Path to CA bundle file (a directory of hashed CA
            certificates is accepted as well).
        :type ca_certs: str
        """

        # Forward ca_certs to the base class as well.
        # NOTE(review): graypy appears to reject validate=True when it does
        # not receive ca_certs -- confirm against the installed version.
        super().__init__(host, port=port, validate=validate, ca_certs=ca_certs, **kwargs)
        if ca_certs is not None and os.path.isfile(ca_certs):
            # A bundle *file* must be loaded via ``cafile``; handing it to
            # ``capath`` (which expects a directory) loads nothing.
            self.ssl_ctx = ssl.create_default_context(cafile=ca_certs)
        else:
            self.ssl_ctx = ssl.create_default_context(capath=ca_certs)
        if not validate:
            # check_hostname has to be switched off before verify_mode can be
            # lowered to CERT_NONE.
            self.ssl_ctx.check_hostname = False
            self.ssl_ctx.verify_mode = ssl.CERT_NONE

    def makeSocket(self, timeout=1):
        """Create a TLS wrapped socket

        :param timeout: Socket timeout in seconds, applied before connecting.
        :returns: The connected TLS socket.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            sock = self.ssl_ctx.wrap_socket(
                sock,
                server_hostname=self.host,
            )
            sock.connect((self.host, self.port))
        except Exception:
            # Do not leak the file descriptor when wrapping, connecting or the
            # TLS handshake fails; the caller still sees the original error.
            sock.close()
            raise
        return sock


def setup_graylog_handler(config: Mapping[str, Any]) -> Optional[logging.Handler]:
try:
import graypy
Expand All @@ -223,6 +273,9 @@ def setup_graylog_handler(config: Mapping[str, Any]) -> Optional[logging.Handler


class ConsoleFormatter(logging.Formatter):
def formatException(self, ei) -> str:
return format_exception(self, ei)

def formatTime(self, record: logging.LogRecord, datefmt: str = None) -> str:
ct = self.converter(record.created) # type: ignore
if datefmt:
Expand All @@ -234,6 +287,9 @@ def formatTime(self, record: logging.LogRecord, datefmt: str = None) -> str:


class CustomJsonFormatter(JsonFormatter):
def formatException(self, ei) -> str:
return format_exception(self, ei)

def add_fields(
self,
log_record: dict[str, Any], # the manipulated entry object
Expand All @@ -251,6 +307,12 @@ def add_fields(
log_record["level"] = record.levelname.upper()


class ColorizedFormatter(coloredlogs.ColoredFormatter):
    """A ``coloredlogs.ColoredFormatter`` that renders exceptions via ``format_exception()``.

    The exception hook is overridden on this class only.  Assigning to
    ``coloredlogs.logging.Formatter.formatException`` would replace the method
    on the standard library's ``logging.Formatter`` for every formatter in the
    process, as a side effect of merely creating an instance.
    """

    def formatException(self, ei) -> str:
        """Render *ei* by delegating to the module-level ``format_exception()`` helper."""
        return format_exception(self, ei)


class pretty:
"""A simple object wrapper to pretty-format it when formatting the log record."""

Expand All @@ -272,7 +334,7 @@ def setup_console_log_handler(config: Mapping[str, Any]) -> logging.Handler:
if colored is None:
colored = sys.stderr.isatty()
if colored:
console_formatter = coloredlogs.ColoredFormatter(
console_formatter = ColorizedFormatter(
log_formats[drv_config["format"]],
datefmt="%Y-%m-%d %H:%M:%S.%f", # coloredlogs has intrinsic support for msec
field_styles={
Expand Down Expand Up @@ -330,6 +392,9 @@ def log_worker(
logstash_handler = None
graylog_handler = None

# For future reference: when implementing a new kind of logging adapter,
# make sure to adopt our custom `Formatter.formatException()` approach;
# otherwise it won't print out EXCEPTION-level logs (along with the traceback).
if "console" in logging_config["drivers"]:
console_handler = setup_console_log_handler(logging_config)

Expand All @@ -346,8 +411,11 @@ def log_worker(
myhost="hostname", # TODO: implement
)
logstash_handler.setLevel(logging_config["level"])
logstash_handler.setFormatter(SerializedExceptionFormatter())
if "graylog" in logging_config["drivers"]:
graylog_handler = setup_graylog_handler(logging_config)
assert graylog_handler is not None
graylog_handler.setFormatter(SerializedExceptionFormatter())

zctx = zmq.Context()
agg_sock = zctx.socket(zmq.PULL)
Expand All @@ -361,19 +429,10 @@ def log_worker(
data = agg_sock.recv()
if not data:
return
try:
rec = pickle.loads(data)
except (pickle.PickleError, TypeError):
# We have an unpickling error.
# Change into a self-created log record with exception info.
rec = logging.makeLogRecord({
"name": __name__,
"msg": "Cannot unpickle the log record (raw data: %r)",
"levelno": logging.ERROR,
"levelname": "error",
"args": (data,), # attach the original data for inspection
"exc_info": sys.exc_info(),
})
unpacked_data = msgpack.unpackb(data)
if not unpacked_data:
break
rec = logging.makeLogRecord(unpacked_data)
if rec is None:
break
if console_handler:
Expand Down Expand Up @@ -430,39 +489,22 @@ def emit(self, record: Optional[logging.LogRecord]) -> None:
self._fallback(record)
return
# record may be None to signal shutdown.
if record:
log_body = {
"name": record.name,
"pathname": record.pathname,
"lineno": record.lineno,
"msg": record.getMessage(),
"levelno": record.levelno,
"levelname": record.levelname,
}
if record.exc_info:
log_body["exc_info"] = traceback.format_exception(*record.exc_info)
else:
log_body = None
try:
if record is not None and record.exc_info is not None:
pickling_support.install(record.exc_info[1])
pickled_rec = pickle.dumps(record)
except (
pickle.PickleError,
TypeError,
ImportError, # when "Python is likely to be shutting down"
):
# We have a pickling error.
# Change it into a self-created picklable log record with exception info.
if record is not None:
exc_info: Any
if isinstance(record.exc_info, tuple):
exc_info = (
PickledException,
PickledException(repr(record.exc_info[1])), # store stringified repr
record.exc_info[2],
)
else:
exc_info = record.exc_info
record = logging.makeLogRecord({
"name": record.name,
"pathname": record.pathname,
"lineno": record.lineno,
"msg": record.getMessage(),
"levelno": record.levelno,
"levelname": record.levelname,
"exc_info": exc_info,
})
pickled_rec = pickle.dumps(record)
try:
self._sock.send(pickled_rec)
serialized_record = msgpack.packb(log_body)
self._sock.send(serialized_record)
except zmq.ZMQError:
self._fallback(record)

Expand Down Expand Up @@ -514,7 +556,7 @@ def __init__(
log_handlers.append(file_handler)
self.log_config = {
"version": 1,
"disable_existing_loggers": True,
"disable_existing_loggers": False,
"handlers": {
"null": {"class": "logging.NullHandler"},
},
Expand Down Expand Up @@ -600,8 +642,6 @@ def __init__(
}

def __enter__(self):
tx.fix_trafaret_pickle_support() # monkey-patch for pickling trafaret.DataError
pickling_support.install() # enable pickling of tracebacks
self.log_config["handlers"]["relay"] = {
"class": "ai.backend.common.logging.RelayHandler",
"level": self.logging_config["level"],
Expand Down
7 changes: 0 additions & 7 deletions src/ai/backend/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@
)


def fix_trafaret_pickle_support():
def __reduce__(self):
return (type(self), (self.error, self.name, self.value, self.trafaret, self.code))

t.DataError.__reduce__ = __reduce__


class StringLengthMeta(TrafaretMeta):
"""
A metaclass that makes string-like trafarets to have sliced min/max length indicator.
Expand Down
9 changes: 0 additions & 9 deletions tests/common/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(do_skip)


@pytest.fixture(scope="session", autouse=True)
def event_loop():
# uvloop.install()
loop = asyncio.new_event_loop()
# setup_child_watcher()
yield loop
loop.close()


@pytest.fixture(scope="session", autouse=True)
def test_ns():
return f"test-{secrets.token_hex(8)}"
Expand Down

0 comments on commit 01b8e3c

Please sign in to comment.