diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 30efd6d99..b117c1714 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -7,11 +7,8 @@ import platform import struct from typing import Tuple, Dict, Optional, Union -from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python.constants import AuthType -logger = get_logger() - class AADAuth: """Handles Azure Active Directory authentication""" diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 8456ef92d..d1ed6e78c 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -13,16 +13,12 @@ import weakref import re from mssql_python.cursor import Cursor -from mssql_python.logging_config import get_logger, ENABLE_LOGGING -from mssql_python.constants import ConstantsDDBC as ddbc_sql_const -from mssql_python.helpers import add_driver_to_connection_str, check_error +from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, log from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager -from mssql_python.exceptions import DatabaseError, InterfaceError +from mssql_python.exceptions import InterfaceError from mssql_python.auth import process_connection_string -logger = get_logger() - class Connection: """ @@ -126,8 +122,7 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st continue conn_str += f"{key}={value};" - if ENABLE_LOGGING: - logger.info("Final connection string: %s", conn_str) + log('info', "Final connection string: %s", sanitize_connection_string(conn_str)) return conn_str @@ -150,8 +145,7 @@ def autocommit(self, value: bool) -> None: None """ self.setautocommit(value) - if ENABLE_LOGGING: - logger.info("Autocommit mode set to %s.", value) + log('info', "Autocommit mode set to %s.", value) def setautocommit(self, value: bool = True) -> None: """ @@ -189,6 +183,7 @@ def cursor(self) -> Cursor: ) cursor = Cursor(self) + 
self._cursors.add(cursor) # Track the cursor return cursor def commit(self) -> None: @@ -205,8 +200,7 @@ def commit(self) -> None: """ # Commit the current transaction self._conn.commit() - if ENABLE_LOGGING: - logger.info("Transaction committed successfully.") + log('info', "Transaction committed successfully.") def rollback(self) -> None: """ @@ -221,8 +215,7 @@ def rollback(self) -> None: """ # Roll back the current transaction self._conn.rollback() - if ENABLE_LOGGING: - logger.info("Transaction rolled back successfully.") + log('info', "Transaction rolled back successfully.") def close(self) -> None: """ @@ -246,7 +239,7 @@ def close(self) -> None: # Convert to list to avoid modification during iteration cursors_to_close = list(self._cursors) close_errors = [] - + for cursor in cursors_to_close: try: if not cursor.closed: @@ -254,12 +247,11 @@ def close(self) -> None: except Exception as e: # Collect errors but continue closing other cursors close_errors.append(f"Error closing cursor: {e}") - if ENABLE_LOGGING: - logger.warning(f"Error closing cursor: {e}") + log('warning', f"Error closing cursor: {e}") # If there were errors closing cursors, log them but continue - if close_errors and ENABLE_LOGGING: - logger.warning(f"Encountered {len(close_errors)} errors while closing cursors") + if close_errors: + log('warning', f"Encountered {len(close_errors)} errors while closing cursors") # Clear the cursor set explicitly to release any internal references self._cursors.clear() @@ -270,16 +262,14 @@ def close(self) -> None: self._conn.close() self._conn = None except Exception as e: - if ENABLE_LOGGING: - logger.error(f"Error closing database connection: {e}") + log('error', f"Error closing database connection: {e}") # Re-raise the connection close error as it's more critical raise finally: # Always mark as closed, even if there were errors self._closed = True - if ENABLE_LOGGING: - logger.info("Connection closed successfully.") + log('info', "Connection closed 
successfully.") def __del__(self): """ @@ -287,9 +277,9 @@ def __del__(self): This is a safety net to ensure resources are cleaned up even if close() was not called explicitly. """ - if not self._closed: + if "_closed" not in self.__dict__ or not self._closed: try: self.close() except Exception as e: - if ENABLE_LOGGING: - logger.error(f"Error during connection cleanup in __del__: {e}") + # Don't raise exceptions from __del__ to avoid issues during garbage collection + log('error', f"Error during connection cleanup: {e}") \ No newline at end of file diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 7c4e1efdc..ed1bb70dc 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -14,12 +14,11 @@ import datetime from typing import List, Union from mssql_python.constants import ConstantsDDBC as ddbc_sql_const -from mssql_python.helpers import check_error -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings +from mssql_python.exceptions import InterfaceError from .row import Row -logger = get_logger() class Cursor: """ @@ -415,10 +414,7 @@ def _initialize_cursor(self) -> None: """ Initialize the DDBC statement handle. """ - # Allocate the DDBC statement handle self._allocate_statement_handle() - # Add the cursor to the connection's cursor set - self.connection._cursors.add(self) def _allocate_statement_handle(self): """ @@ -426,25 +422,14 @@ def _allocate_statement_handle(self): """ self.hstmt = self.connection._conn.alloc_statement_handle() - def _free_cursor(self) -> None: + def _reset_cursor(self) -> None: """ - Free the DDBC statement handle and remove the cursor from the connection's cursor set. + Reset the DDBC statement handle. 
""" if self.hstmt: self.hstmt.free() self.hstmt = None - if ENABLE_LOGGING: - logger.debug("SQLFreeHandle succeeded") - # We don't need to remove the cursor from the connection's cursor set here, - # as it is a weak reference and will be automatically removed - # when the cursor is garbage collected. - - def _reset_cursor(self) -> None: - """ - Reset the DDBC statement handle. - """ - # Free the current cursor if it exists - self._free_cursor() + log('debug', "SQLFreeHandle succeeded") # Reinitialize the statement handle self._initialize_cursor() @@ -461,8 +446,7 @@ def close(self) -> None: if self.hstmt: self.hstmt.free() self.hstmt = None - if ENABLE_LOGGING: - logger.debug("SQLFreeHandle succeeded") + log('debug', "SQLFreeHandle succeeded") self.closed = True def _check_closed(self): @@ -596,15 +580,14 @@ def execute( # Executing a new statement. Reset is_stmt_prepared to false self.is_stmt_prepared = [False] - if ENABLE_LOGGING: - logger.debug("Executing query: %s", operation) - for i, param in enumerate(parameters): - logger.debug( - """Parameter number: %s, Parameter: %s, - Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", - i + 1, - param, - str(type(param)), + log('debug', "Executing query: %s", operation) + for i, param in enumerate(parameters): + log('debug', + """Parameter number: %s, Parameter: %s, + Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", + i + 1, + param, + str(type(param)), parameters_type[i].paramSQLType, parameters_type[i].paramCType, parameters_type[i].columnSize, @@ -709,10 +692,9 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: ) columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters) - if ENABLE_LOGGING: - logger.info("Executing batch query with %d parameter sets:\n%s", - len(seq_of_parameters),"\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) - ) + log('info', "Executing batch query with %d parameter sets:\n%s", 
+ len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) + ) # Execute batched statement ret = ddbc_bindings.SQLExecuteMany( @@ -784,6 +766,7 @@ def fetchall(self) -> List[Row]: # Fetch raw data rows_data = [] ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + # Convert raw data to Row objects return [Row(row_data, self.description) for row_data in rows_data] @@ -805,15 +788,16 @@ def nextset(self) -> Union[bool, None]: if ret == ddbc_sql_const.SQL_NO_DATA.value: return False return True - + def __del__(self): """ Destructor to ensure the cursor is closed when it is no longer needed. This is a safety net to ensure resources are cleaned up even if close() was not called explicitly. """ - if not self.closed: + if "closed" not in self.__dict__ or not self.closed: try: self.close() except Exception as e: - logger.error(f"Error closing cursor: {e}") + # Don't raise an exception in __del__, just log it + log('error', "Error during cursor cleanup in __del__: %s", e) \ No newline at end of file diff --git a/mssql_python/exceptions.py b/mssql_python/exceptions.py index c2307a5f5..308a85690 100644 --- a/mssql_python/exceptions.py +++ b/mssql_python/exceptions.py @@ -4,7 +4,7 @@ This module contains custom exception classes for the mssql_python package. These classes are used to raise exceptions when an error occurs while executing a query. 
""" -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger logger = get_logger() @@ -621,7 +621,7 @@ def truncate_error_message(error_message: str) -> str: string_third = string_second[string_second.index("]") + 1 :] return string_first + string_third except Exception as e: - if ENABLE_LOGGING: + if logger: logger.error("Error while truncating error message: %s",e) return error_message @@ -641,7 +641,7 @@ def raise_exception(sqlstate: str, ddbc_error: str) -> None: """ exception_class = sqlstate_to_exception(sqlstate, ddbc_error) if exception_class: - if ENABLE_LOGGING: + if logger: logger.error(exception_class) raise exception_class raise DatabaseError( diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index cffb06467..267ede75c 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -6,7 +6,7 @@ from mssql_python import ddbc_bindings from mssql_python.exceptions import raise_exception -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger import platform from pathlib import Path from mssql_python.ddbc_bindings import normalize_architecture @@ -73,7 +73,7 @@ def check_error(handle_type, handle, ret): """ if ret < 0: error_info = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret) - if ENABLE_LOGGING: + if logger: logger.error("Error: %s", error_info.ddbcErrorMsg) raise_exception(error_info.sqlState, error_info.ddbcErrorMsg) @@ -184,3 +184,31 @@ def get_driver_path(module_dir, architecture): raise RuntimeError(f"ODBC driver not found at: {driver_path_str}") return driver_path_str + + +def sanitize_connection_string(conn_str: str) -> str: + """ + Sanitize the connection string by removing sensitive information. + Args: + conn_str (str): The connection string to sanitize. + Returns: + str: The sanitized connection string. 
+ """ + # Remove sensitive information from the connection string, Pwd section + # Replace Pwd=...; or Pwd=... (end of string) with Pwd=***; + import re + return re.sub(r"(Pwd\s*=\s*)[^;]*", r"\1***", conn_str, flags=re.IGNORECASE) + + +def log(level: str, message: str, *args) -> None: + """ + Universal logging helper that gets a fresh logger instance. + + Args: + level: Log level ('debug', 'info', 'warning', 'error') + message: Log message with optional format placeholders + *args: Arguments for message formatting + """ + logger = get_logger() + if logger: + getattr(logger, level)(message, *args) \ No newline at end of file diff --git a/mssql_python/logging_config.py b/mssql_python/logging_config.py index d0952724f..2e9eaaeaf 100644 --- a/mssql_python/logging_config.py +++ b/mssql_python/logging_config.py @@ -8,58 +8,157 @@ from logging.handlers import RotatingFileHandler import os import sys +import datetime -ENABLE_LOGGING = False +class LoggingManager: + """ + Singleton class to manage logging configuration for the mssql_python package. + This class provides a centralized way to manage logging configuration and replaces + the previous approach using global variables. + """ + _instance = None + _initialized = False + _logger = None + _log_file = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(LoggingManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + self._initialized = True + self._enabled = False + + @classmethod + def is_logging_enabled(cls): + """Class method to check if logging is enabled for backward compatibility""" + if cls._instance is None: + return False + return cls._instance._enabled + + @property + def enabled(self): + """Check if logging is enabled""" + return self._enabled + + @property + def log_file(self): + """Get the current log file path""" + return self._log_file + + def setup(self, mode="file", log_level=logging.DEBUG): + """ + Set up logging configuration. 
+ + This method configures the logging settings for the application. + It sets the log level, format, and log file location. + + Args: + mode (str): The logging mode ('file' or 'stdout'). + log_level (int): The logging level (default: logging.DEBUG). + """ + # Enable logging + self._enabled = True + + # Create a logger for mssql_python module + # Use a consistent logger name to ensure we're using the same logger throughout + self._logger = logging.getLogger("mssql_python") + self._logger.setLevel(log_level) + + # Configure the root logger to ensure all messages are captured + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + + # Make sure the logger propagates to the root logger + self._logger.propagate = True + + # Clear any existing handlers to avoid duplicates during re-initialization + if self._logger.handlers: + self._logger.handlers.clear() + + # Construct the path to the log file + # Directory for log files - currentdir/logs + current_dir = os.path.dirname(os.path.abspath(__file__)) + log_dir = os.path.join(current_dir, 'logs') + # exist_ok=True allows the directory to be created if it doesn't exist + os.makedirs(log_dir, exist_ok=True) + + # Generate timestamp-based filename for better sorting and organization + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + self._log_file = os.path.join(log_dir, f'mssql_python_trace_{timestamp}_{os.getpid()}.log') + + # Create a log handler to log to driver specific file + # By default we only want to log to a file, max size 500MB, and keep 5 backups + file_handler = RotatingFileHandler(self._log_file, maxBytes=512*1024*1024, backupCount=5) + file_handler.setLevel(log_level) + + # Create a custom formatter that adds [Python Layer log] prefix only to non-DDBC messages + class PythonLayerFormatter(logging.Formatter): + def format(self, record): + message = record.getMessage() + # Don't add [Python Layer log] prefix if the message already has [DDBC Bindings log] or [Python Layer log] + if 
"[DDBC Bindings log]" not in message and "[Python Layer log]" not in message: + # Create a copy of the record to avoid modifying the original + new_record = logging.makeLogRecord(record.__dict__) + new_record.msg = f"[Python Layer log] {record.msg}" + return super().format(new_record) + return super().format(record) + + # Use our custom formatter + formatter = PythonLayerFormatter('%(asctime)s - %(levelname)s - %(filename)s - %(message)s') + file_handler.setFormatter(formatter) + self._logger.addHandler(file_handler) + + if mode == 'stdout': + # If the mode is stdout, then we want to log to the console as well + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(log_level) + # Use the same smart formatter + stdout_handler.setFormatter(formatter) + self._logger.addHandler(stdout_handler) + elif mode != 'file': + raise ValueError(f'Invalid logging mode: {mode}') + + return self._logger + + def get_logger(self): + """ + Get the logger instance. + + Returns: + logging.Logger: The logger instance, or None if logging is not enabled. + """ + if not self.enabled: + # If logging is not enabled, return None + return None + return self._logger + + +# Create a singleton instance +_manager = LoggingManager() def setup_logging(mode="file", log_level=logging.DEBUG): """ Set up logging configuration. - - This method configures the logging settings for the application. - It sets the log level, format, and log file location. - + + This is a wrapper around the LoggingManager.setup method for backward compatibility. + Args: mode (str): The logging mode ('file' or 'stdout'). log_level (int): The logging level (default: logging.DEBUG). 
""" - global ENABLE_LOGGING - ENABLE_LOGGING = True - - # Create a logger for mssql_python module - logger = logging.getLogger(__name__) - logger.setLevel(log_level) - - # Construct the path to the log file - # TODO: Use a different dir to dump log file - current_dir = os.path.dirname(os.path.abspath(__file__)) - log_file = os.path.join(current_dir, f'mssql_python_trace_{os.getpid()}.log') - - # Create a log handler to log to driver specific file - # By default we only want to log to a file, max size 500MB, and keep 5 backups - # TODO: Rotate files based on time too? Ex: everyday - file_handler = RotatingFileHandler(log_file, maxBytes=512*1024*1024, backupCount=5) - file_handler.setLevel(log_level) - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s - %(message)s') - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - if mode == 'stdout': - # If the mode is stdout, then we want to log to the console as well - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setLevel(log_level) - stdout_handler.setFormatter(formatter) - logger.addHandler(stdout_handler) - elif mode != 'file': - raise ValueError(f'Invalid logging mode: {mode}') + return _manager.setup(mode, log_level) def get_logger(): """ Get the logger instance. + + This is a wrapper around the LoggingManager.get_logger method for backward compatibility. Returns: logging.Logger: The logger instance. """ - if not ENABLE_LOGGING: - return None - return logging.getLogger(__name__) + return _manager.get_logger() \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5568ded00..49c7c7af4 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -539,15 +539,22 @@ void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) { // TODO: Revisit GIL considerations if we're using python's logger template void LOG(const std::string& formatString, Args&&... 
args) { - // TODO: Try to do this string concatenation at compile time - std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; - static py::object logging = py::module_::import("mssql_python.logging_config") - .attr("get_logger")(); - if (py::isinstance(logging)) { - return; + py::gil_scoped_acquire gil; // <---- this ensures safe Python API usage + + py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")(); + if (py::isinstance(logger)) return; + + try { + std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; + if constexpr (sizeof...(args) == 0) { + logger.attr("debug")(py::str(ddbcFormatString)); + } else { + py::str message = py::str(ddbcFormatString).format(std::forward(args)...); + logger.attr("debug")(message); + } + } catch (const std::exception& e) { + std::cerr << "Logging error: " << e.what() << std::endl; } - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); - logging.attr("debug")(message); } // TODO: Add more nuanced exception classes @@ -668,17 +675,19 @@ DriverHandle LoadDriverOrThrowException() { (archStr == "arm64") ? "arm64" : "x86"; - fs::path dllDir = fs::path(moduleDir) / "libs" / archDir; + fs::path dllDir = fs::path(moduleDir) / "libs" / "windows" / archDir; fs::path authDllPath = dllDir / "mssql-auth.dll"; if (fs::exists(authDllPath)) { HMODULE hAuth = LoadLibraryW(std::wstring(authDllPath.native().begin(), authDllPath.native().end()).c_str()); if (hAuth) { - LOG("Authentication DLL loaded: {}", authDllPath.string()); + LOG("mssql-auth.dll loaded: {}", authDllPath.string()); } else { LOG("Failed to load mssql-auth.dll: {}", GetLastErrorMessage()); + ThrowStdException("Failed to load mssql-auth.dll. Please ensure it is present in the expected directory."); } } else { LOG("Note: mssql-auth.dll not found. This is OK if Entra ID is not in use."); + ThrowStdException("mssql-auth.dll not found. 
If you are using Entra ID, please ensure it is present."); } #endif @@ -689,6 +698,10 @@ DriverHandle LoadDriverOrThrowException() { DriverHandle handle = LoadDriverLibrary(driverPath.string()); if (!handle) { LOG("Failed to load driver: {}", GetLastErrorMessage()); + // If this happens in linux, suggest installing libltdl7 + #ifdef __linux__ + ThrowStdException("Failed to load ODBC driver. If you are on Linux, please install libltdl7 package."); + #endif ThrowStdException("Failed to load ODBC driver. Please check installation."); } LOG("Driver library successfully loaded."); @@ -776,6 +789,13 @@ SQLSMALLINT SqlHandle::type() const { return _type; } +/* + * IMPORTANT: Never log in destructors - it causes segfaults. + * During program exit, C++ destructors may run AFTER Python shuts down. + * LOG() tries to acquire Python GIL and call Python functions, which crashes + * if Python is already gone. Keep destructors simple - just free resources. + * If you need destruction logs, use explicit close() methods instead. + */ void SqlHandle::free() { if (_handle && SQLFreeHandle_ptr) { const char* type_str = nullptr; @@ -788,9 +808,7 @@ void SqlHandle::free() { } SQLFreeHandle_ptr(_type, _handle); _handle = nullptr; - std::stringstream ss; - ss << "Freed SQL Handle of type: " << type_str; - LOG(ss.str()); + // Don't log during destruction - it can cause segfaults during Python shutdown } } diff --git a/mssql_python/pybind/unix_utils.cpp b/mssql_python/pybind/unix_utils.cpp index 681cdccc8..c98a9e090 100644 --- a/mssql_python/pybind/unix_utils.cpp +++ b/mssql_python/pybind/unix_utils.cpp @@ -14,19 +14,22 @@ const size_t kUcsLength = 2; // SQLWCHAR is 2 bytes on all platform // TODO: Make Logger a separate module and import it across the project template void LOG(const std::string& formatString, Args&&... 
args) { - // Get the logger each time instead of caching it to ensure we get the latest state - py::object logging_module = py::module_::import("mssql_python.logging_config"); - py::object logger = logging_module.attr("get_logger")(); - - // If logger is None, don't try to log - if (py::isinstance(logger)) { - return; + py::gil_scoped_acquire gil; // <---- this ensures safe Python API usage + + py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")(); + if (py::isinstance(logger)) return; + + try { + std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; + if constexpr (sizeof...(args) == 0) { + logger.attr("debug")(py::str(ddbcFormatString)); + } else { + py::str message = py::str(ddbcFormatString).format(std::forward(args)...); + logger.attr("debug")(message); + } + } catch (const std::exception& e) { + std::cerr << "Logging error: " << e.what() << std::endl; } - - // Format the message and log it - std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); - logger.attr("debug")(message); } // Function to convert SQLWCHAR strings to std::wstring on macOS diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 8fd0a8213..bef238151 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -10,10 +10,11 @@ Note: The cursor function is not yet implemented, so related tests are commented out. 
""" +from mssql_python.exceptions import InterfaceError import pytest import time from mssql_python import Connection, connect, pooling - + def drop_table_if_exists(cursor, table_name): """Drop the table if it exists""" try: diff --git a/tests/test_005_connection_cursor_lifecycle.py b/tests/test_005_connection_cursor_lifecycle.py index 12f7405ac..5fa3d56cd 100644 --- a/tests/test_005_connection_cursor_lifecycle.py +++ b/tests/test_005_connection_cursor_lifecycle.py @@ -72,9 +72,11 @@ def test_cursor_cleanup_without_close(conn_str): def test_no_segfault_on_gc(conn_str): """Test that no segmentation fault occurs during garbage collection""" - code = """ + # Properly escape the connection string for embedding in code + escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') + code = f""" from mssql_python import connect -conn = connect(\"""" + conn_str + """\") +conn = connect("{escaped_conn_str}") cursors = [conn.cursor() for _ in range(5)] for cur in cursors: cur.execute("SELECT 1") diff --git a/tests/test_007_logging.py b/tests/test_007_logging.py index e78c29eb5..fc9907acf 100644 --- a/tests/test_007_logging.py +++ b/tests/test_007_logging.py @@ -1,25 +1,53 @@ import logging import os import pytest -from mssql_python.logging_config import setup_logging, get_logger, ENABLE_LOGGING +import glob +from mssql_python.logging_config import setup_logging, get_logger, LoggingManager def get_log_file_path(): + # Get the LoggingManager singleton instance + manager = LoggingManager() + # If logging is enabled, return the actual log file path + if manager.enabled and manager.log_file: + return manager.log_file + # For fallback/cleanup, try to find existing log files in the logs directory repo_root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + log_dir = os.path.join(repo_root_dir, "mssql_python", "logs") + os.makedirs(log_dir, exist_ok=True) + + # Try to find existing log files + log_files = glob.glob(os.path.join(log_dir, 
"mssql_python_trace_*.log")) + if log_files: + # Return the most recently created log file + return max(log_files, key=os.path.getctime) + + # Fallback to default pattern pid = os.getpid() - log_file = os.path.join(repo_root_dir, "mssql_python", f"mssql_python_trace_{pid}.log") - return log_file + return os.path.join(log_dir, f"mssql_python_trace_{pid}.log") @pytest.fixture def cleanup_logger(): """Cleanup logger & log files before and after each test""" def cleanup(): + # Get the LoggingManager singleton instance + manager = LoggingManager() logger = get_logger() if logger is not None: logger.handlers.clear() - log_file_path = get_log_file_path() - if os.path.exists(log_file_path): - os.remove(log_file_path) - ENABLE_LOGGING = False + + # Try to remove the actual log file if it exists + try: + log_file_path = get_log_file_path() + if os.path.exists(log_file_path): + os.remove(log_file_path) + except: + pass # Ignore errors during cleanup + + # Reset the LoggingManager instance + manager._enabled = False + manager._initialized = False + manager._logger = None + manager._log_file = None # Perform cleanup before the test cleanup() yield @@ -29,9 +57,11 @@ def cleanup(): def test_no_logging(cleanup_logger): """Test that logging is off by default""" try: + # Get the LoggingManager singleton instance + manager = LoggingManager() logger = get_logger() assert logger is None - assert ENABLE_LOGGING == False + assert manager.enabled == False except Exception as e: pytest.fail(f"Logging not off by default. 
Error: {e}") @@ -41,7 +71,8 @@ def test_setup_logging(cleanup_logger): setup_logging() # This must enable logging logger = get_logger() assert logger is not None - assert logger == logging.getLogger('mssql_python.logging_config') + # Fix: Check for the correct logger name + assert logger == logging.getLogger('mssql_python') assert logger.level == logging.DEBUG # DEBUG level except Exception as e: pytest.fail(f"Logging setup failed: {e}") @@ -84,4 +115,115 @@ def test_logging_in_stdout_mode(cleanup_logger, capsys): captured_stdout = capsys.readouterr().out assert test_message in captured_stdout, "Log message not found in stdout" except Exception as e: - pytest.fail(f"Logging in stdout mode failed: {e}") \ No newline at end of file + pytest.fail(f"Logging in stdout mode failed: {e}") + +def test_python_layer_prefix(cleanup_logger): + """Test that Python layer logs have the correct prefix""" + try: + setup_logging() + logger = get_logger() + assert logger is not None + + # Log a test message + test_message = "This is a Python layer test message" + logger.info(test_message) + + # Check if the log file contains the message with [Python Layer log] prefix + log_file_path = get_log_file_path() + with open(log_file_path, 'r') as f: + log_content = f.read() + + # The logged message should have the Python Layer prefix + assert "[Python Layer log]" in log_content, "Python Layer log prefix not found" + assert test_message in log_content, "Test message not found in log file" + except Exception as e: + pytest.fail(f"Python layer prefix test failed: {e}") + +def test_different_log_levels(cleanup_logger): + """Test that different log levels work correctly""" + try: + setup_logging() + logger = get_logger() + assert logger is not None + + # Log messages at different levels + debug_msg = "This is a DEBUG message" + info_msg = "This is an INFO message" + warning_msg = "This is a WARNING message" + error_msg = "This is an ERROR message" + + logger.debug(debug_msg) + 
logger.info(info_msg) + logger.warning(warning_msg) + logger.error(error_msg) + + # Check if the log file contains all messages + log_file_path = get_log_file_path() + with open(log_file_path, 'r') as f: + log_content = f.read() + + assert debug_msg in log_content, "DEBUG message not found in log file" + assert info_msg in log_content, "INFO message not found in log file" + assert warning_msg in log_content, "WARNING message not found in log file" + assert error_msg in log_content, "ERROR message not found in log file" + + # Also check for level indicators in the log + assert "DEBUG" in log_content, "DEBUG level not found in log file" + assert "INFO" in log_content, "INFO level not found in log file" + assert "WARNING" in log_content, "WARNING level not found in log file" + assert "ERROR" in log_content, "ERROR level not found in log file" + except Exception as e: + pytest.fail(f"Log levels test failed: {e}") + +def test_singleton_behavior(cleanup_logger): + """Test that LoggingManager behaves as a singleton""" + try: + # Create multiple instances of LoggingManager + manager1 = LoggingManager() + manager2 = LoggingManager() + + # They should be the same instance + assert manager1 is manager2, "LoggingManager instances are not the same" + + # Enable logging through one instance + manager1._enabled = True + + # The other instance should reflect this change + assert manager2.enabled == True, "Singleton state not shared between instances" + + # Reset for cleanup + manager1._enabled = False + except Exception as e: + pytest.fail(f"Singleton behavior test failed: {e}") + +def test_timestamp_in_log_filename(cleanup_logger): + """Test that log filenames include timestamps""" + try: + setup_logging() + + # Get the log file path + log_file_path = get_log_file_path() + filename = os.path.basename(log_file_path) + + # Extract parts of the filename + parts = filename.split('_') + + # The filename should follow the pattern: mssql_python_trace_YYYYMMDD_HHMMSS_PID.log + # Fix: 
Account for the fact that "mssql_python" contains an underscore + assert parts[0] == "mssql", "Incorrect filename prefix part 1" + assert parts[1] == "python", "Incorrect filename prefix part 2" + assert parts[2] == "trace", "Incorrect filename part" + + # Check date format (YYYYMMDD) + date_part = parts[3] + assert len(date_part) == 8 and date_part.isdigit(), "Date format incorrect in filename" + + # Check time format (HHMMSS) + time_part = parts[4] + assert len(time_part) == 6 and time_part.isdigit(), "Time format incorrect in filename" + + # Process ID should be the last part before .log + pid_part = parts[5].split('.')[0] + assert pid_part.isdigit(), "Process ID not found in filename" + except Exception as e: + pytest.fail(f"Timestamp in filename test failed: {e}") \ No newline at end of file