diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index 06cf32da..13367b4c 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -665,24 +665,67 @@ void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) {
 }  // namespace
 
+// Helper function to check whether Python is shutting down or finalizing.
+// This centralizes the shutdown-detection logic to avoid code duplication.
+static bool is_python_finalizing() {
+    try {
+        if (Py_IsInitialized() == 0) {
+            return true;  // Python is already shut down
+        }
+
+        py::gil_scoped_acquire gil;
+        py::object sys_module = py::module_::import("sys");
+        if (!sys_module.is_none()) {
+            // Check that the attribute exists before accessing it (for Python version compatibility)
+            if (py::hasattr(sys_module, "is_finalizing")) {
+                py::object finalizing_func = sys_module.attr("is_finalizing");
+                if (!finalizing_func.is_none() && finalizing_func().cast<bool>()) {
+                    return true;  // Python is finalizing
+                }
+            }
+        }
+        return false;
+    } catch (...) {
+        std::cerr << "Error occurred while checking Python finalization state." << std::endl;
+        // Be conservative: don't assume shutdown on an exception.
+        // Only return true when we are certain Python is shutting down.
+        return false;
+    }
+}
+
 // TODO: Revisit GIL considerations if we're using python's logger
 template <typename... Args>
 void LOG(const std::string& formatString, Args&&... args) {
-    py::gil_scoped_acquire gil;  // <---- this ensures safe Python API usage
+    // Check whether Python is shutting down, to avoid a crash during cleanup
+    if (is_python_finalizing()) {
+        return;  // Python is shutting down or finalizing, don't log
+    }
+
+    try {
+        py::gil_scoped_acquire gil;  // <---- this ensures safe Python API usage
 
-    py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")();
-    if (py::isinstance<py::none>(logger)) return;
+        py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")();
+        if (py::isinstance<py::none>(logger)) return;
 
-    try {
-        std::string ddbcFormatString = "[DDBC Bindings log] " + formatString;
-        if constexpr (sizeof...(args) == 0) {
-            logger.attr("debug")(py::str(ddbcFormatString));
-        } else {
-            py::str message = py::str(ddbcFormatString).format(std::forward<Args>(args)...);
-            logger.attr("debug")(message);
+        try {
+            std::string ddbcFormatString = "[DDBC Bindings log] " + formatString;
+            if constexpr (sizeof...(args) == 0) {
+                logger.attr("debug")(py::str(ddbcFormatString));
+            } else {
+                py::str message = py::str(ddbcFormatString).format(std::forward<Args>(args)...);
+                logger.attr("debug")(message);
+            }
+        } catch (const std::exception& e) {
+            std::cerr << "Logging error: " << e.what() << std::endl;
         }
+    } catch (const py::error_already_set& e) {
+        // Python is shutting down or in an inconsistent state; silently ignore
+        (void)e;  // Suppress unused-variable warning
+        return;
     } catch (const std::exception& e) {
-        std::cerr << "Logging error: " << e.what() << std::endl;
+        // Any other error: ignore it to prevent a crash during cleanup
+        (void)e;  // Suppress unused-variable warning
+        return;
     }
 }
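For reference, the guard added above has a direct pure-Python counterpart: sys.is_finalizing() (Python 3.5+) reports interpreter finalization, and the same "bail out before touching the logging machinery" pattern can be sketched without any pybind11 involvement. The safe_debug name and the logger name below are illustrative only, not part of mssql_python:

import logging
import sys

logger = logging.getLogger("mssql_python")  # logger name assumed for illustration


def safe_debug(message: str) -> None:
    # Once the interpreter has started finalizing, logging machinery (locks,
    # handlers, streams) may already be torn down, so skip logging entirely.
    if sys.is_finalizing():
        return
    try:
        logger.debug("[DDBC Bindings log] %s", message)
    except Exception:
        # Never let cleanup-time logging raise into callers.
        pass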
@@ -993,17 +1036,26 @@ SQLSMALLINT SqlHandle::type() const {
  */
 void SqlHandle::free() {
     if (_handle && SQLFreeHandle_ptr) {
-        const char* type_str = nullptr;
-        switch (_type) {
-            case SQL_HANDLE_ENV: type_str = "ENV"; break;
-            case SQL_HANDLE_DBC: type_str = "DBC"; break;
-            case SQL_HANDLE_STMT: type_str = "STMT"; break;
-            case SQL_HANDLE_DESC: type_str = "DESC"; break;
-            default: type_str = "UNKNOWN"; break;
+        // Check whether Python is shutting down, using the centralized helper function
+        bool pythonShuttingDown = is_python_finalizing();
+
+        // CRITICAL FIX: During Python shutdown, don't free STMT handles, because their parent DBC may already have been freed.
+        // This prevents a segfault when handles are freed in the wrong order during interpreter shutdown.
+        // (SQL_HANDLE_ENV = 1, SQL_HANDLE_DBC = 2, SQL_HANDLE_STMT = 3)
+        if (pythonShuttingDown && _type == SQL_HANDLE_STMT) {
+            _handle = nullptr;  // Mark as freed to prevent double-free attempts
+            return;
         }
+
+        // Always clean up ODBC resources, regardless of Python state
         SQLFreeHandle_ptr(_type, _handle);
         _handle = nullptr;
-        // Don't log during destruction - it can cause segfaults during Python shutdown
+
+        // Only log if Python is not shutting down (to avoid a segfault)
+        if (!pythonShuttingDown) {
+            // Don't log during destruction - even in normal cases it can be problematic.
+            // If logging is needed, use explicit close() methods instead.
+        }
     }
 }
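The free() change above works around the fact that destructor order is undefined while the interpreter is shutting down, so a statement (STMT) handle can be freed after its parent connection (DBC) handle is already gone. A complementary Python-side mitigation, sketched here with a hypothetical _open_connections registry rather than the library's real internals, is to close children before parents from an atexit hook, which runs before module teardown:

import atexit

_open_connections = []  # hypothetical registry that connect() would populate


def _close_all_connections():
    # Close statement handles before their parent connection handles while the
    # interpreter (and the ODBC environment) is still fully alive.
    for conn in list(_open_connections):
        try:
            for cur in list(getattr(conn, "_cursors", [])):
                cur.close()   # release SQL_HANDLE_STMT children first
            conn.close()      # then the parent SQL_HANDLE_DBC
        except Exception:
            pass              # never raise during interpreter exit


atexit.register(_close_all_connections)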
diff --git a/tests/test_005_connection_cursor_lifecycle.py b/tests/test_005_connection_cursor_lifecycle.py
index d87c3f21..df777c08 100644
--- a/tests/test_005_connection_cursor_lifecycle.py
+++ b/tests/test_005_connection_cursor_lifecycle.py
@@ -475,3 +475,212 @@ def test_mixed_cursor_cleanup_scenarios(conn_str, tmp_path):
     assert "All tests passed" in result.stdout
     # Should not have error logs
     assert "Exception during cursor cleanup" not in result.stderr
+
+
+def test_sql_syntax_error_no_segfault_on_shutdown(conn_str):
+    """Test that a SQL syntax error doesn't cause a segfault during Python shutdown"""
+    # This test reproduces the exact scenario that was causing segfaults
+    escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"')
+    code = f"""
+from mssql_python import connect
+
+# Create connection
+conn = connect("{escaped_conn_str}")
+cursor = conn.cursor()
+
+# Execute invalid SQL that causes a syntax error - this was triggering the segfault
+cursor.execute("syntax error")
+
+# Don't explicitly close cursor/connection - let Python shutdown handle cleanup
+print("Script completed, shutting down...")  # Never reached: execute() raises above
+# The segfault would happen here, during Python shutdown
+"""
+
+    # Run in a subprocess to catch segfaults
+    result = subprocess.run(
+        [sys.executable, "-c", code],
+        capture_output=True,
+        text=True
+    )
+
+    # Should not segfault (a crash would surface as a negative return code, e.g. -SIGSEGV)
+    assert result.returncode == 1, f"Expected exit code 1 due to syntax error, but got {result.returncode}. STDERR: {result.stderr}"
+
+
+def test_multiple_sql_syntax_errors_no_segfault(conn_str):
+    """Test that multiple SQL syntax errors don't cause a segfault during cleanup"""
+    escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"')
+    code = f"""
+from mssql_python import connect
+
+conn = connect("{escaped_conn_str}")
+
+# Multiple cursors with syntax errors
+cursors = []
+for i in range(3):
+    cursor = conn.cursor()
+    cursors.append(cursor)
+    cursor.execute(f"invalid sql syntax {{i}}")
+
+# Mix of syntax errors and valid queries
+cursor_valid = conn.cursor()
+cursor_valid.execute("SELECT 1")
+cursor_valid.fetchall()
+cursors.append(cursor_valid)
+
+# Don't close anything - test Python shutdown cleanup
+print("Multiple syntax errors handled, shutting down...")
+"""
+
+    result = subprocess.run(
+        [sys.executable, "-c", code],
+        capture_output=True,
+        text=True
+    )
+
+    assert result.returncode == 1, f"Expected exit code 1 due to syntax errors, but got {result.returncode}. STDERR: {result.stderr}"
+
+
+def test_connection_close_during_active_query_no_segfault(conn_str):
+    """Test that closing a connection while a cursor has pending results doesn't cause a segfault"""
+    escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"')
+    code = f"""
+from mssql_python import connect
+
+# Create connection and cursor
+conn = connect("{escaped_conn_str}")
+cursor = conn.cursor()
+
+# Execute a query but don't fetch the results - leave them pending
+cursor.execute("SELECT COUNT(*) FROM sys.objects")
+
+# Close the connection while results are still pending.
+# This tests handle cleanup when the STMT has pending results but the DBC is freed.
+conn.close()
+
+print("Connection closed with pending cursor results")
+# The cursor destructor will run during normal cleanup, not shutdown
+"""
+
+    result = subprocess.run(
+        [sys.executable, "-c", code],
+        capture_output=True,
+        text=True
+    )
+
+    # Should not segfault - should exit cleanly
+    assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}"
+    assert "Connection closed with pending cursor results" in result.stdout
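These tests all rely on the same harness idea: run the risky snippet in a child interpreter so a native crash cannot take down pytest itself. With subprocess.run, a segfaulting child shows up on POSIX as a negative return code (minus the signal number) rather than a normal exit status, which is what the return-code assertions distinguish. A minimal sketch of that harness, with run_snippet as a hypothetical helper name:

import signal
import subprocess
import sys


def run_snippet(code: str) -> subprocess.CompletedProcess:
    # Execute the snippet in a fresh interpreter so a crash stays isolated.
    return subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True)


result = run_snippet("print('ok')")
# 0 = clean exit, 1 = uncaught Python exception, -signal.SIGSEGV (-11 on Linux) = segfault.
assert result.returncode != -signal.SIGSEGV, result.stderr
print(result.stdout.strip())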
+
+
+def test_concurrent_cursor_operations_no_segfault(conn_str):
+    """Test that concurrent cursor operations don't cause segfaults or race conditions"""
+    escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"')
+    code = f"""
+import threading
+from mssql_python import connect
+
+conn = connect("{escaped_conn_str}")
+results = []
+exceptions = []
+
+def worker(thread_id):
+    try:
+        for i in range(15):
+            cursor = conn.cursor()
+            cursor.execute(f"SELECT {{thread_id * 100 + i}} as value")
+            result = cursor.fetchone()
+            results.append(result[0])
+            # Don't explicitly close the cursor - test concurrent destructors
+    except Exception as e:
+        exceptions.append(f"Thread {{thread_id}}: {{e}}")
+
+# Create multiple threads doing concurrent cursor operations
+threads = []
+for i in range(4):
+    t = threading.Thread(target=worker, args=(i,))
+    threads.append(t)
+    t.start()
+
+for t in threads:
+    t.join()
+
+print(f"Completed: {{len(results)}} results, {{len(exceptions)}} exceptions")
+
+# Report any exceptions for debugging
+for exc in exceptions:
+    print(f"Exception: {{exc}}")
+
+print("Concurrent operations completed")
+"""
+
+    result = subprocess.run(
+        [sys.executable, "-c", code],
+        capture_output=True,
+        text=True
+    )
+
+    # Should not segfault
+    assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}"
+    assert "Concurrent operations completed" in result.stdout
+
+    # Check that most operations completed successfully.
+    # Allow for some exceptions due to threading, but there shouldn't be many.
+    output_lines = result.stdout.split('\n')
+    completed_line = [line for line in output_lines if 'Completed:' in line]
+    if completed_line:
+        # Extract the numbers from "Completed: X results, Y exceptions"
+        import re
+        match = re.search(r'Completed: (\d+) results, (\d+) exceptions', completed_line[0])
+        if match:
+            results_count = int(match.group(1))
+            exceptions_count = int(match.group(2))
+            # Should have completed most operations (allow for some threading issues)
+            assert results_count >= 50, f"Too few successful operations: {results_count}"
+            assert exceptions_count <= 10, f"Too many exceptions: {exceptions_count}"
+
+
+def test_aggressive_threading_abrupt_exit_no_segfault(conn_str):
+    """Test that an abrupt exit with active threads and pending queries doesn't cause a segfault"""
+    escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"')
+    code = f"""
+import threading
+import sys
+import time
+from mssql_python import connect
+
+conn = connect("{escaped_conn_str}")
+
+def aggressive_worker(thread_id):
+    '''Worker that creates cursors with pending results and doesn't clean up'''
+    for i in range(8):
+        cursor = conn.cursor()
+        # Execute a query but don't fetch - leave the results pending
+        cursor.execute(f"SELECT COUNT(*) FROM sys.objects WHERE object_id > {{thread_id * 1000 + i}}")
+
+        # Create another cursor immediately without cleaning up the first
+        cursor2 = conn.cursor()
+        cursor2.execute(f"SELECT TOP 3 * FROM sys.objects WHERE object_id > {{thread_id * 1000 + i}}")
+
+        # Don't fetch results, don't close cursors - maximum chaos
+        time.sleep(0.005)  # Let other threads interleave
+
+# Start multiple daemon threads
+for i in range(3):
+    t = threading.Thread(target=aggressive_worker, args=(i,), daemon=True)
+    t.start()
+
+# Let them run briefly, then exit abruptly
+time.sleep(0.3)
+print("Exiting abruptly with active threads and pending queries")
+sys.exit(0)  # Abrupt exit without joining the threads
+"""
+
+    result = subprocess.run(
+        [sys.executable, "-c", code],
+        capture_output=True,
+        text=True
+    )
+
+    # Should not segfault - should exit cleanly even with the abrupt exit
+    assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}"
+    assert "Exiting abruptly with active threads and pending queries" in result.stdout
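The last test leans on a CPython behavior worth keeping in mind when reading the C++ changes: daemon threads are not joined at interpreter exit, so cursors they created are simply abandoned and their native handles are released in whatever order shutdown happens to visit them. A small standalone sketch of that behavior, with no database involved:

import sys
import threading
import time


def background_work():
    # Stands in for a worker holding native resources (e.g. open cursors).
    while True:
        time.sleep(0.01)


# Daemon threads are abandoned rather than joined when the interpreter exits,
# so cleanup code must tolerate running while such threads are still mid-call.
t = threading.Thread(target=background_work, daemon=True)
t.start()

time.sleep(0.05)
print("exiting while the daemon thread is still running")
sys.exit(0)  # hands control to interpreter shutdown; the thread is simply dropped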