From f3dc1d50d361e621ac9d3b82933329dc28a8d862 Mon Sep 17 00:00:00 2001 From: Tom Monk Date: Sat, 23 Mar 2024 16:52:08 -0700 Subject: [PATCH] Fix issue 2485 enable caching for get_logger calls Cache one Logger object per Python logger name in LoggingHandler --- .../sdk/_logs/_internal/__init__.py | 20 +++++++++++-------- opentelemetry-sdk/tests/logs/test_export.py | 15 ++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/_logs/_internal/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/_logs/_internal/__init__.py index 8ba0dae6f2e..04c32e7514a 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/_logs/_internal/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/_logs/_internal/__init__.py @@ -19,6 +19,7 @@ import logging import threading import traceback +from functools import lru_cache from os import environ from time import time_ns from typing import Any, Callable, Optional, Tuple, Union # noqa @@ -448,9 +449,6 @@ def __init__( ) -> None: super().__init__(level=level) self._logger_provider = logger_provider or get_logger_provider() - self._logger = get_logger( - __name__, logger_provider=self._logger_provider - ) @staticmethod def _get_attributes(record: logging.LogRecord) -> Attributes: @@ -530,6 +528,9 @@ def _translate(self, record: logging.LogRecord) -> LogRecord: "WARN" if record.levelname == "WARNING" else record.levelname ) + logger = get_logger( + record.name, logger_provider=self._logger_provider + ) return LogRecord( timestamp=timestamp, observed_timestamp=observered_timestamp, @@ -539,7 +540,7 @@ def _translate(self, record: logging.LogRecord) -> LogRecord: severity_text=level_name, severity_number=severity_number, body=body, - resource=self._logger.resource, + resource=logger.resource, attributes=attributes, ) @@ -549,15 +550,17 @@ def emit(self, record: logging.LogRecord) -> None: The record is translated to OTel format, and then sent across the pipeline. """ - if not isinstance(self._logger, NoOpLogger): - self._logger.emit(self._translate(record)) + logger = get_logger( + record.name, logger_provider=self._logger_provider + ) + if not isinstance(logger, NoOpLogger): + logger.emit(self._translate(record)) def flush(self) -> None: """ Flushes the logging output. Skip flushing if logger is NoOp. """ - if not isinstance(self._logger, NoOpLogger): - self._logger_provider.force_flush() + self._logger_provider.force_flush() class Logger(APILogger): @@ -618,6 +621,7 @@ def __init__( def resource(self): return self._resource + @lru_cache(maxsize=None) def get_logger( self, name: str, diff --git a/opentelemetry-sdk/tests/logs/test_export.py b/opentelemetry-sdk/tests/logs/test_export.py index d48dcf8e242..69100d863ef 100644 --- a/opentelemetry-sdk/tests/logs/test_export.py +++ b/opentelemetry-sdk/tests/logs/test_export.py @@ -71,6 +71,7 @@ def test_simple_log_record_processor_default_level(self): self.assertEqual( warning_log_record.severity_number, SeverityNumber.WARN ) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "default_level") def test_simple_log_record_processor_custom_level(self): exporter = InMemoryLogExporter() @@ -104,6 +105,8 @@ def test_simple_log_record_processor_custom_level(self): self.assertEqual( fatal_log_record.severity_number, SeverityNumber.FATAL ) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "custom_level") + self.assertEqual(finished_logs[1].instrumentation_scope.name, "custom_level") def test_simple_log_record_processor_trace_correlation(self): exporter = InMemoryLogExporter() @@ -129,6 +132,7 @@ def test_simple_log_record_processor_trace_correlation(self): self.assertEqual( log_record.trace_flags, INVALID_SPAN_CONTEXT.trace_flags ) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "trace_correlation") exporter.clear() tracer = trace.TracerProvider().get_tracer(__name__) @@ -140,6 +144,7 @@ def test_simple_log_record_processor_trace_correlation(self): self.assertEqual(log_record.body, "Critical message within span") self.assertEqual(log_record.severity_text, "CRITICAL") self.assertEqual(log_record.severity_number, SeverityNumber.FATAL) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "trace_correlation") span_context = span.get_span_context() self.assertEqual(log_record.trace_id, span_context.trace_id) self.assertEqual(log_record.span_id, span_context.span_id) @@ -166,6 +171,7 @@ def test_simple_log_record_processor_shutdown(self): self.assertEqual( warning_log_record.severity_number, SeverityNumber.WARN ) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "shutdown") exporter.clear() logger_provider.shutdown() with self.assertLogs(level=logging.WARNING): @@ -206,6 +212,8 @@ def test_simple_log_record_processor_different_msg_types(self): for item in finished_logs ] self.assertEqual(expected, emitted) + for item in finished_logs: + self.assertEqual(item.instrumentation_scope.name, "different_msg_types") class TestBatchLogRecordProcessor(ConcurrencyTestBase): @@ -379,6 +387,8 @@ def test_shutdown(self): for item in finished_logs ] self.assertEqual(expected, emitted) + for item in finished_logs: + self.assertEqual(item.instrumentation_scope.name, "shutdown") def test_force_flush(self): exporter = InMemoryLogExporter() @@ -398,6 +408,7 @@ def test_force_flush(self): log_record = finished_logs[0].log_record self.assertEqual(log_record.body, "Earth is burning") self.assertEqual(log_record.severity_number, SeverityNumber.FATAL) + self.assertEqual(finished_logs[0].instrumentation_scope.name, "force_flush") def test_log_record_processor_too_many_logs(self): exporter = InMemoryLogExporter() @@ -416,6 +427,8 @@ def test_log_record_processor_too_many_logs(self): self.assertTrue(log_record_processor.force_flush()) finised_logs = exporter.get_finished_logs() self.assertEqual(len(finised_logs), 1000) + for item in finised_logs: + self.assertEqual(item.instrumentation_scope.name, "many_logs") def test_with_multiple_threads(self): exporter = InMemoryLogExporter() @@ -443,6 +456,8 @@ def bulk_log_and_flush(num_logs): finished_logs = exporter.get_finished_logs() self.assertEqual(len(finished_logs), 2415) + for item in finished_logs: + self.assertEqual(item.instrumentation_scope.name, "threads") @unittest.skipUnless( hasattr(os, "fork"),