Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 22 additions & 29 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,44 +845,37 @@ std::string GetLastErrorMessage();

// TODO: Move this to Python
std::string GetModuleDirectory() {
namespace fs = std::filesystem;
py::object module = py::module::import("mssql_python");
py::object module_path = module.attr("__file__");
std::string module_file = module_path.cast<std::string>();

#ifdef _WIN32
// Windows-specific path handling
char path[MAX_PATH];
errno_t err = strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length());
if (err != 0) {
LOG("GetModuleDirectory: strncpy_s failed copying path - "
"error_code=%d, path_length=%zu",
err, module_file.length());
return {};
}
PathRemoveFileSpecA(path);
return std::string(path);
#else
// macOS/Unix path handling without using std::filesystem
std::string::size_type pos = module_file.find_last_of('/');
if (pos != std::string::npos) {
std::string dir = module_file.substr(0, pos);
return dir;
}
LOG("GetModuleDirectory: Could not extract directory from module path - "
"path='%s'",
module_file.c_str());
return module_file;
#endif
// Use std::filesystem::path for cross-platform path handling
// This properly handles UTF-8 encoded paths on all platforms
fs::path modulePath(module_file);
fs::path parentDir = modulePath.parent_path();

// Log path extraction for observability
LOG("GetModuleDirectory: Extracted directory - "
"original_path='%s', directory='%s'",
module_file.c_str(), parentDir.string().c_str());

// Return UTF-8 encoded string for consistent handling
// If parentDir is empty or invalid, subsequent operations (like LoadDriverLibrary)
// will fail naturally with clear error messages
return parentDir.string();
}

// Platform-agnostic function to load the driver dynamic library
DriverHandle LoadDriverLibrary(const std::string& driverPath) {
LOG("LoadDriverLibrary: Attempting to load ODBC driver from path='%s'", driverPath.c_str());

#ifdef _WIN32
// Windows: Convert string to wide string for LoadLibraryW
std::wstring widePath(driverPath.begin(), driverPath.end());
HMODULE handle = LoadLibraryW(widePath.c_str());
// Windows: Use std::filesystem::path for proper UTF-8 to UTF-16 conversion
// fs::path::c_str() returns wchar_t* on Windows with correct encoding
namespace fs = std::filesystem;
fs::path pathObj(driverPath);
HMODULE handle = LoadLibraryW(pathObj.c_str());
if (!handle) {
LOG("LoadDriverLibrary: LoadLibraryW failed for path='%s' - %s", driverPath.c_str(),
GetLastErrorMessage().c_str());
Expand Down Expand Up @@ -1013,8 +1006,8 @@ DriverHandle LoadDriverOrThrowException() {
fs::path dllDir = fs::path(moduleDir) / "libs" / "windows" / archDir;
fs::path authDllPath = dllDir / "mssql-auth.dll";
if (fs::exists(authDllPath)) {
HMODULE hAuth = LoadLibraryW(
std::wstring(authDllPath.native().begin(), authDllPath.native().end()).c_str());
// Use fs::path::c_str() which returns wchar_t* on Windows with proper encoding
HMODULE hAuth = LoadLibraryW(authDllPath.c_str());
if (hAuth) {
LOG("LoadDriverOrThrowException: mssql-auth.dll loaded "
"successfully from '%s'",
Expand Down
6 changes: 6 additions & 0 deletions tests/test_013_SqlHandle_free_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
import threading
import time

import pytest


class TestHandleFreeShutdown:
"""Test SqlHandle::free() behavior for all handle types during Python shutdown."""

@pytest.mark.stress
def test_aggressive_dbc_segfault_reproduction(self, conn_str):
"""
AGGRESSIVE TEST: Try to reproduce DBC handle segfault during shutdown.
Expand Down Expand Up @@ -157,6 +160,7 @@ def on_exit():
assert result.returncode == 0, f"Process failed. stderr: {result.stderr}"
print(f"PASS: DBC handle cleanup properly skipped during shutdown")

@pytest.mark.stress
def test_force_gc_finalization_order_issue(self, conn_str):
"""
TEST: Force specific GC finalization order to trigger segfault.
Expand Down Expand Up @@ -434,6 +438,7 @@ def test_mixed_handle_cleanup_at_shutdown(self, conn_str):
assert "Connection 3: everything properly closed" in result.stdout
print(f"PASS: Mixed handle cleanup during shutdown")

@pytest.mark.stress
def test_rapid_connection_churn_with_shutdown(self, conn_str):
"""
Test rapid connection creation/deletion followed by shutdown.
Expand Down Expand Up @@ -1087,6 +1092,7 @@ def close(self):
assert "Mixed scenario: PASSED" in result.stdout
print(f"PASS: Cleanup connections mixed scenario")

@pytest.mark.stress
def test_active_connections_thread_safety(self, conn_str):
"""
Test _active_connections thread-safety with concurrent registration.
Expand Down
Loading
Loading