diff --git a/README.md b/README.md index 52de722..d98fa0c 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,7 @@ from mcpd import McpdClient # Initialize the client with your mcpd API endpoint. # api_key is optional and sends an 'MCPD-API-KEY' header. # server_health_cache_ttl is optional and sets the time in seconds to cache a server health response. +# logger is optional and allows you to provide a custom logger implementation (see Logging section). client = McpdClient(api_endpoint="http://localhost:8090", api_key="optional-key", server_health_cache_ttl=10) ``` @@ -182,6 +183,101 @@ client = McpdClient(api_endpoint="http://localhost:8090", api_key="optional-key" * `client.is_server_healthy(server_name: str) -> bool` - Checks if the specified server is healthy and can handle requests. +## Logging + +The SDK includes built-in logging infrastructure that can be enabled via the `MCPD_LOG_LEVEL` environment variable. Logging is disabled by default to avoid contaminating stdout/stderr. + +> [!IMPORTANT] +> Only enable `MCPD_LOG_LEVEL` in non-MCP-server contexts. MCP servers can use stdout for JSON-RPC communication, +> and any logging output will break the protocol. + +### Available Log Levels + +Set the `MCPD_LOG_LEVEL` environment variable to one of the following values (from most to least verbose): + +* `trace` - Most verbose logging (includes all levels below) +* `debug` - Debug-level logging +* `info` - Informational logging +* `warn` - Warning-level logging (recommended for most use cases) +* `error` - Error-level logging only +* `off` - Disable all logging (default) + +### Example Usage + +```bash +# Enable warning-level logging +export MCPD_LOG_LEVEL=warn +python your_script.py +``` + +```python +from mcpd import McpdClient + +# Warnings will be logged to stderr when MCPD_LOG_LEVEL=warn +client = McpdClient(api_endpoint="http://localhost:8090") + +# For example, the SDK will log warnings for: +# - Non-existent servers when calling agent_tools() +# - Unhealthy servers when calling agent_tools() +# - Servers that become unavailable during tool fetching +``` + +### Custom Logger + +You can provide your own logger implementation that implements the `Logger` protocol: + +```python +import sys +from mcpd import McpdClient, Logger + +class CustomLogger: + """Custom logger that writes to stderr (safe for MCP server contexts).""" + + def trace(self, msg: str, *args: object) -> None: + print(f"TRACE: {msg % args}", file=sys.stderr) + + def debug(self, msg: str, *args: object) -> None: + print(f"DEBUG: {msg % args}", file=sys.stderr) + + def info(self, msg: str, *args: object) -> None: + print(f"INFO: {msg % args}", file=sys.stderr) + + def warn(self, msg: str, *args: object) -> None: + print(f"WARN: {msg % args}", file=sys.stderr) + + def error(self, msg: str, *args: object) -> None: + print(f"ERROR: {msg % args}", file=sys.stderr) + +# Use custom logger +client = McpdClient( + api_endpoint="http://localhost:8090", + logger=CustomLogger() +) +``` + +You can also provide a partial logger implementation. Any omitted methods will fall back to the default logger (which respects `MCPD_LOG_LEVEL`): + +```python +import sys + +class PartialLogger: + """Partial logger - only override warn/error, others use default.""" + + def warn(self, msg: str, *args: object) -> None: + # Custom warning handler (writes to stderr). + print(f"CUSTOM WARN: {msg % args}", file=sys.stderr) + + def error(self, msg: str, *args: object) -> None: + # Custom error handler (writes to stderr). + print(f"CUSTOM ERROR: {msg % args}", file=sys.stderr) + # trace, debug, info use default logger (respects MCPD_LOG_LEVEL) + +client = McpdClient( + api_endpoint="http://localhost:8090", + logger=PartialLogger() +) +``` + ## Error Handling All SDK-level errors, including HTTP and connection errors, will raise a `McpdError` exception. diff --git a/src/mcpd/__init__.py b/src/mcpd/__init__.py index 95c9338..0c8d6b4 100644 --- a/src/mcpd/__init__.py +++ b/src/mcpd/__init__.py @@ -12,6 +12,7 @@ - Comprehensive error handling: Detailed exceptions for different failure modes """ +from ._logger import Logger, LogLevel from .exceptions import ( AuthenticationError, ConnectionError, @@ -28,6 +29,8 @@ __all__ = [ "McpdClient", "HealthStatus", + "Logger", + "LogLevel", "McpdError", "AuthenticationError", "ConnectionError", diff --git a/src/mcpd/_logger.py b/src/mcpd/_logger.py new file mode 100644 index 0000000..c6318b3 --- /dev/null +++ b/src/mcpd/_logger.py @@ -0,0 +1,268 @@ +"""Internal logging infrastructure for the mcpd SDK. + +This module provides a logging shim controlled by the MCPD_LOG_LEVEL environment +variable. Logging is disabled by default to avoid contaminating stdout/stderr in +MCP server contexts. + +CRITICAL: Only enable MCPD_LOG_LEVEL in non-MCP-server contexts. MCP servers use +stdout for JSON-RPC communication, and any logging output will break the protocol. +""" + +import logging +import os +from enum import Enum +from typing import Protocol + + +class LogLevel(str, Enum): + """Valid log level values for MCPD_LOG_LEVEL environment variable. + + Aligns with mcpd server binary log levels for consistency across the mcpd ecosystem. + """ + + TRACE = "trace" + DEBUG = "debug" + INFO = "info" + WARN = "warn" + ERROR = "error" + OFF = "off" + + +class Logger(Protocol): + """Logger protocol defining the SDK's logging interface. + + This protocol matches standard logging levels and allows custom logger injection. + All methods accept a message and optional formatting arguments. + """ + + def trace(self, msg: str, *args: object) -> None: + """Log a trace-level message (most verbose).""" + ... + + def debug(self, msg: str, *args: object) -> None: + """Log a debug-level message.""" + ... + + def info(self, msg: str, *args: object) -> None: + """Log an info-level message.""" + ... + + def warn(self, msg: str, *args: object) -> None: + """Log a warning-level message.""" + ... + + def error(self, msg: str, *args: object) -> None: + """Log an error-level message.""" + ... + + +# Custom TRACE level (below DEBUG=10). +_TRACE = 5 +logging.addLevelName(_TRACE, "TRACE") + +_RANKS: dict[str, int] = { + LogLevel.TRACE.value: _TRACE, + LogLevel.DEBUG.value: logging.DEBUG, + LogLevel.INFO.value: logging.INFO, + LogLevel.WARN.value: logging.WARNING, + "warning": logging.WARNING, # Alias for backwards compatibility. + LogLevel.ERROR.value: logging.ERROR, + LogLevel.OFF.value: 1000, # Higher than any standard level. +} + + +def _resolve_log_level(raw: str | None) -> str: + """Resolve the log level from environment variable value. + + Args: + raw: Raw value from MCPD_LOG_LEVEL environment variable. + + Returns: + Valid log level string matching LogLevel enum values. + Returns LogLevel.OFF.value if raw is None, empty, or not a valid level. + """ + candidate = raw.strip().lower() if raw else None + return candidate if candidate and candidate in _RANKS else LogLevel.OFF.value + + +def _get_level() -> str: + """Get the current log level from environment variable (lazy evaluation). + + This function is called on each log statement to support dynamic level changes. + + Note: + Dynamic level changes can facilitate testing. + + Returns: + The resolved log level string. + """ + return _resolve_log_level(os.getenv("MCPD_LOG_LEVEL")) + + +def _create_default_logger() -> Logger: + """Create the default logger with lazy level evaluation. + + Returns: + A Logger instance that checks MCPD_LOG_LEVEL on each log call, + enabling dynamic level changes without module reloading. + """ + # Create logger and handler once (not per-call). + _logger = logging.getLogger(__name__) + + if not _logger.handlers: + # Add stderr handler (default for StreamHandler). + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + _logger.addHandler(handler) + _logger.propagate = False + + class _DefaultLogger: + """Default logger that checks level on each call (lazy evaluation).""" + + def trace(self, msg: str, *args: object) -> None: + """Log trace-level message.""" + lvl = _get_level() + if lvl != LogLevel.OFF.value and _RANKS[lvl] <= _RANKS[LogLevel.TRACE.value]: + _logger.setLevel(_TRACE) + _logger.log(_TRACE, msg, *args) + + def debug(self, msg: str, *args: object) -> None: + """Log debug-level message.""" + lvl = _get_level() + if lvl != LogLevel.OFF.value and _RANKS[lvl] <= _RANKS[LogLevel.DEBUG.value]: + _logger.setLevel(logging.DEBUG) + _logger.debug(msg, *args) + + def info(self, msg: str, *args: object) -> None: + """Log info-level message.""" + lvl = _get_level() + if lvl != LogLevel.OFF.value and _RANKS[lvl] <= _RANKS[LogLevel.INFO.value]: + _logger.setLevel(logging.INFO) + _logger.info(msg, *args) + + def warn(self, msg: str, *args: object) -> None: + """Log warning-level message.""" + lvl = _get_level() + if lvl != LogLevel.OFF.value and _RANKS[lvl] <= _RANKS[LogLevel.WARN.value]: + _logger.setLevel(logging.WARNING) + _logger.warning(msg, *args) + + def error(self, msg: str, *args: object) -> None: + """Log error-level message.""" + lvl = _get_level() + if lvl != LogLevel.OFF.value and _RANKS[lvl] <= _RANKS[LogLevel.ERROR.value]: + _logger.setLevel(logging.ERROR) + _logger.error(msg, *args) + + return _DefaultLogger() + + +class _PartialLoggerWrapper: + """Wrapper that combines partial custom logger with default logger fallback. + + This enables partial logger implementations where users can override specific + methods while keeping defaults for others. + """ + + def __init__(self, custom: object, default: Logger) -> None: + """Initialize the wrapper. + + Args: + custom: Partial logger implementation (may not have all methods). + default: Default logger to use for missing methods. + """ + self._custom = custom + self._default = default + + def trace(self, msg: str, *args: object) -> None: + """Log trace-level message.""" + if hasattr(self._custom, LogLevel.TRACE.value): + self._custom.trace(msg, *args) + else: + self._default.trace(msg, *args) + + def debug(self, msg: str, *args: object) -> None: + """Log debug-level message.""" + if hasattr(self._custom, LogLevel.DEBUG.value): + self._custom.debug(msg, *args) + else: + self._default.debug(msg, *args) + + def info(self, msg: str, *args: object) -> None: + """Log info-level message.""" + if hasattr(self._custom, LogLevel.INFO.value): + self._custom.info(msg, *args) + else: + self._default.info(msg, *args) + + def warn(self, msg: str, *args: object) -> None: + """Log warning-level message.""" + if hasattr(self._custom, LogLevel.WARN.value): + self._custom.warn(msg, *args) + else: + self._default.warn(msg, *args) + + def error(self, msg: str, *args: object) -> None: + """Log error-level message.""" + if hasattr(self._custom, LogLevel.ERROR.value): + self._custom.error(msg, *args) + else: + self._default.error(msg, *args) + + +def create_logger(impl: Logger | object | None = None) -> Logger: + """Create a logger, optionally using a custom implementation. + + This function allows SDK users to inject their own logger implementation. + Supports partial implementations - any omitted methods will fall back to the + default logger, which respects the MCPD_LOG_LEVEL environment variable. + + Args: + impl: Optional custom Logger implementation or partial implementation. + If None, uses the default logger controlled by MCPD_LOG_LEVEL. + If partially provided, custom methods are used and omitted methods + fall back to default logger (which respects MCPD_LOG_LEVEL). + + Returns: + A Logger instance with all methods implemented. + + Example: + >>> # Use default logger (controlled by MCPD_LOG_LEVEL). + >>> logger = create_logger() + >>> + >>> # Full custom logger. + >>> class MyLogger: + ... def trace(self, msg, *args): pass + ... def debug(self, msg, *args): pass + ... def info(self, msg, *args): pass + ... def warn(self, msg, *args): print(f"WARN: {msg % args}") + ... def error(self, msg, *args): print(f"ERROR: {msg % args}") + >>> logger = create_logger(MyLogger()) + >>> + >>> # Partial logger: custom warn/error, default (MCPD_LOG_LEVEL-aware) for others. + >>> class PartialLogger: + ... def warn(self, msg, *args): print(f"WARN: {msg % args}") + ... def error(self, msg, *args): print(f"ERROR: {msg % args}") + ... # trace, debug, info use default logger (respects MCPD_LOG_LEVEL) + >>> logger = create_logger(PartialLogger()) + """ + if impl is None: + return _default_logger + + # Check if it's a full Logger implementation (has all required methods). + required_methods = [ + LogLevel.TRACE.value, + LogLevel.DEBUG.value, + LogLevel.INFO.value, + LogLevel.WARN.value, + LogLevel.ERROR.value, + ] + if all(hasattr(impl, method) for method in required_methods): + return impl + + # Partial implementation - wrap with fallback to default logger. + return _PartialLoggerWrapper(impl, _default_logger) + + +# Module-level default logger (created at import time). +_default_logger: Logger = _create_default_logger() diff --git a/src/mcpd/mcpd_client.py b/src/mcpd/mcpd_client.py index 579d805..f8b5564 100644 --- a/src/mcpd/mcpd_client.py +++ b/src/mcpd/mcpd_client.py @@ -18,6 +18,7 @@ import requests from cachetools import TTLCache, cached +from ._logger import Logger, create_logger from .dynamic_caller import DynamicCaller from .exceptions import ( AuthenticationError, @@ -109,7 +110,13 @@ class McpdClient: """Maximum number of server health entries to cache. Prevents unbounded memory growth while allowing legitimate large-scale monitoring.""" - def __init__(self, api_endpoint: str, api_key: str | None = None, server_health_cache_ttl: float = 10): + def __init__( + self, + api_endpoint: str, + api_key: str | None = None, + server_health_cache_ttl: float = 10, + logger: Logger | None = None, + ) -> None: """Initialize a new McpdClient instance. Args: @@ -119,6 +126,8 @@ def __init__(self, api_endpoint: str, api_key: str | None = None, server_health_ will be included in all requests as "Authorization: Bearer {api_key}". server_health_cache_ttl: Time to live in seconds for the cache of the server health API calls. A value of 0 means no caching. + logger: Optional custom Logger implementation. If None, uses the default logger + controlled by the MCPD_LOG_LEVEL environment variable. Raises: ValueError: If api_endpoint is empty or invalid. @@ -140,6 +149,7 @@ def __init__(self, api_endpoint: str, api_key: str | None = None, server_health_ self._session = requests.Session() # Initialize components + self._logger = create_logger(logger) self._function_builder = FunctionBuilder(self) # Set up authentication @@ -524,9 +534,9 @@ def _agent_tools(self) -> list[_AgentFunction]: for server_name in healthy_servers: try: tool_schemas = self.tools(server_name=server_name) - except (ConnectionError, TimeoutError, AuthenticationError, ServerNotFoundError, McpdError): + except (ConnectionError, TimeoutError, AuthenticationError, ServerNotFoundError, McpdError) as e: # These servers were reported as healthy, so failures for schemas would be unexpected. - # TODO: self._logger.warn(f"Failed to retrieve tool schema for server '{name}'") # include exception + self._logger.warn("Server '%s' became unavailable or unhealthy during tool fetch: %s", server_name, e) continue for tool_schema in tool_schemas: @@ -557,11 +567,15 @@ def is_valid(name: str) -> bool: health = health_map.get(name) if not health: - # TODO: self._logger.warn(f"Skipping non-existent server '{name}'") + self._logger.warn("Skipping non-existent server '%s'", name) + return False + + status = health.get("status") + if not HealthStatus.is_healthy(status): + self._logger.warn("Skipping unhealthy server '%s' with status '%s'", name, status) return False - return HealthStatus.is_healthy(health.get("status")) - # TODO: self._logger.warn(f"Skipping unhealthy server '{name}' with status '{status}'") + return True return [name for name in server_names if is_valid(name)] diff --git a/tests/unit/test_mcpd_client.py b/tests/unit/test_mcpd_client.py index 51dd5d0..5c5aba9 100644 --- a/tests/unit/test_mcpd_client.py +++ b/tests/unit/test_mcpd_client.py @@ -963,3 +963,187 @@ def test_clear_server_health_cache(self, mock_get): result2 = client.server_health("test_server") assert result2 == {"name": "test_server", "status": "ok"} assert mock_get.call_count == 2 + + +class TestLogger: + """Tests for logger integration in McpdClient.""" + + def test_logger_initialization_default(self, monkeypatch): + """Test that client initializes with default logger.""" + monkeypatch.setenv("MCPD_LOG_LEVEL", "warn") + client = McpdClient(api_endpoint="http://localhost:8090") + assert client._logger is not None + assert hasattr(client._logger, "warn") + assert hasattr(client._logger, "error") + + def test_logger_initialization_custom(self): + """Test that client accepts custom logger.""" + + class CustomLogger: + def __init__(self): + self.warnings = [] + self.errors = [] + + def trace(self, msg: str, *args: object) -> None: + pass + + def debug(self, msg: str, *args: object) -> None: + pass + + def info(self, msg: str, *args: object) -> None: + pass + + def warn(self, msg: str, *args: object) -> None: + self.warnings.append(msg % args) + + def error(self, msg: str, *args: object) -> None: + self.errors.append(msg % args) + + custom_logger = CustomLogger() + client = McpdClient(api_endpoint="http://localhost:8090", logger=custom_logger) + assert client._logger is custom_logger + + @patch.object(McpdClient, "server_health") + def test_get_healthy_servers_logs_nonexistent(self, mock_health, monkeypatch): + """Test that _get_healthy_servers logs warning for non-existent server.""" + monkeypatch.setenv("MCPD_LOG_LEVEL", "warn") + + class CustomLogger: + def __init__(self): + self.warnings = [] + + def trace(self, msg: str, *args: object) -> None: + pass + + def debug(self, msg: str, *args: object) -> None: + pass + + def info(self, msg: str, *args: object) -> None: + pass + + def warn(self, msg: str, *args: object) -> None: + self.warnings.append(msg % args) + + def error(self, msg: str, *args: object) -> None: + pass + + custom_logger = CustomLogger() + client = McpdClient(api_endpoint="http://localhost:8090", logger=custom_logger) + + # Mock health check to return only server1, not server2. + mock_health.return_value = { + "server1": {"status": "ok"}, + } + + result = client._get_healthy_servers(["server1", "server2"]) + + # Should only return server1. + assert result == ["server1"] + + # Should log warning for non-existent server2. + assert len(custom_logger.warnings) == 1 + assert "Skipping non-existent server 'server2'" in custom_logger.warnings[0] + + @patch.object(McpdClient, "server_health") + def test_get_healthy_servers_logs_unhealthy(self, mock_health, monkeypatch): + """Test that _get_healthy_servers logs warning for unhealthy server.""" + monkeypatch.setenv("MCPD_LOG_LEVEL", "warn") + + class CustomLogger: + def __init__(self): + self.warnings = [] + + def trace(self, msg: str, *args: object) -> None: + pass + + def debug(self, msg: str, *args: object) -> None: + pass + + def info(self, msg: str, *args: object) -> None: + pass + + def warn(self, msg: str, *args: object) -> None: + self.warnings.append(msg % args) + + def error(self, msg: str, *args: object) -> None: + pass + + custom_logger = CustomLogger() + client = McpdClient(api_endpoint="http://localhost:8090", logger=custom_logger) + + # Mock health check with unhealthy server. + mock_health.return_value = { + "server1": {"status": "ok"}, + "server2": {"status": "timeout"}, + } + + result = client._get_healthy_servers(["server1", "server2"]) + + # Should only return server1. + assert result == ["server1"] + + # Should log warning for unhealthy server2. + assert len(custom_logger.warnings) == 1 + assert "Skipping unhealthy server 'server2' with status 'timeout'" in custom_logger.warnings[0] + + @patch.object(McpdClient, "servers") + @patch.object(McpdClient, "tools") + @patch.object(McpdClient, "server_health") + def test_agent_tools_logs_tool_fetch_error(self, mock_health, mock_tools, mock_servers, monkeypatch): + """Test that _agent_tools logs warning when tool fetch fails.""" + monkeypatch.setenv("MCPD_LOG_LEVEL", "warn") + + class CustomLogger: + def __init__(self): + self.warnings = [] + + def trace(self, msg: str, *args: object) -> None: + pass + + def debug(self, msg: str, *args: object) -> None: + pass + + def info(self, msg: str, *args: object) -> None: + pass + + def warn(self, msg: str, *args: object) -> None: + self.warnings.append(msg % args) + + def error(self, msg: str, *args: object) -> None: + pass + + custom_logger = CustomLogger() + client = McpdClient(api_endpoint="http://localhost:8090", logger=custom_logger) + + mock_servers.return_value = ["server1", "server2"] + mock_health.return_value = { + "server1": {"status": "ok"}, + "server2": {"status": "ok"}, + } + + # server1 succeeds, server2 fails during tool fetch. + def tools_side_effect(server_name=None): + if server_name == "server1": + return [{"name": "tool1", "description": "Tool 1"}] + elif server_name == "server2": + raise McpdError("Connection failed") + return [] + + mock_tools.side_effect = tools_side_effect + + with patch.object(client._function_builder, "create_function_from_schema") as mock_create: + mock_func = Mock() + mock_func._server_name = "server1" + mock_func._tool_name = "tool1" + mock_create.return_value = mock_func + + result = client.agent_tools() + + # Should only return tool from server1. + assert len(result) == 1 + assert result[0]._server_name == "server1" + + # Should log warning for server2 tool fetch failure. + assert len(custom_logger.warnings) == 1 + assert "Server 'server2' became unavailable or unhealthy during tool fetch" in custom_logger.warnings[0] + assert "Connection failed" in custom_logger.warnings[0]