From 25d6ef09d89965c762bd02a4d3d11c83eb00d846 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 14:23:04 +0530 Subject: [PATCH 1/4] refactor native layer and create reusable components --- mssql_python/pybind/ddbc_bindings.cpp | 492 +++++++++++++------------- mssql_python/pybind/ddbc_bindings.h | 138 ++++++++ 2 files changed, 378 insertions(+), 252 deletions(-) create mode 100644 mssql_python/pybind/ddbc_bindings.h diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 742cd556..e181c7df 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -9,22 +9,18 @@ #include #include // std::setw, std::setfill #include -#include #include // std::forward // Replace std::filesystem usage with Windows-specific headers #include #pragma comment(lib, "shlwapi.lib") +#include "ddbc_bindings.h" #include #include #include #include // Add this line for datetime support #include -#include // windows.h needs to be included before sql.h -#include -#include - namespace py = pybind11; using namespace pybind11::literals; @@ -115,59 +111,6 @@ struct ErrorInfo { std::wstring ddbcErrorMsg; }; - -//------------------------------------------------------------------------------------------------- -// Function pointer typedefs -//------------------------------------------------------------------------------------------------- - -// Handle APIs -typedef SQLRETURN (*SQLAllocHandleFunc)(SQLSMALLINT, SQLHANDLE, SQLHANDLE*); -typedef SQLRETURN (*SQLSetEnvAttrFunc)(SQLHANDLE, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLSetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLSetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLGetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER, - SQLINTEGER*); - -// Connection and Execution APIs -typedef SQLRETURN (*SQLDriverConnectFunc)(SQLHANDLE, SQLHWND, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLSMALLINT*, SQLUSMALLINT); -typedef SQLRETURN (*SQLExecDirectFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (*SQLPrepareFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (*SQLBindParameterFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLSMALLINT, - SQLSMALLINT, SQLULEN, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLExecuteFunc)(SQLHANDLE); -typedef SQLRETURN (*SQLRowCountFunc)(SQLHSTMT, SQLLEN*); -typedef SQLRETURN (*SQLSetDescFieldFunc)(SQLHDESC, SQLSMALLINT, SQLSMALLINT, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLGetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); - -// Data retrieval APIs -typedef SQLRETURN (*SQLFetchFunc)(SQLHANDLE); -typedef SQLRETURN (*SQLFetchScrollFunc)(SQLHANDLE, SQLSMALLINT, SQLLEN); -typedef SQLRETURN (*SQLGetDataFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLNumResultColsFunc)(SQLHSTMT, SQLSMALLINT*); -typedef SQLRETURN (*SQLBindColFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLSMALLINT*, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, - SQLSMALLINT*); -typedef SQLRETURN (*SQLMoreResultsFunc)(SQLHSTMT); -typedef SQLRETURN (*SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, - SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); - -// Transaction APIs -typedef SQLRETURN (*SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); - -// Disconnect/free APIs -typedef SQLRETURN (*SQLFreeHandleFunc)(SQLSMALLINT, SQLHANDLE); -typedef SQLRETURN (*SQLDisconnectFunc)(SQLHDBC); -typedef SQLRETURN (*SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); - -// Diagnostic APIs -typedef SQLRETURN (*SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, - SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); - //------------------------------------------------------------------------------------------------- // Function pointer initialization //------------------------------------------------------------------------------------------------- @@ -210,183 +153,8 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; // Diagnostic APIs SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; -// Move GetModuleDirectory outside namespace to resolve ambiguity -std::string GetModuleDirectory() { - py::object module = py::module::import("mssql_python"); - py::object module_path = module.attr("__file__"); - std::string module_file = module_path.cast(); - - char path[MAX_PATH]; - strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); - PathRemoveFileSpecA(path); - return std::string(path); -} - -// Smart wrapper around SQLHANDLE -class SqlHandle { -public: - SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) : _type(type), _handle(rawHandle) {} - // Optional: global flag to disable cleanup during shutdown - ~SqlHandle() { - // Note: Destructor is intentionally a no-op. Python owns the lifecycle. - // Native ODBC handles must be explicitly released by calling `free()` directly from Python. - // This avoids nondeterministic crashes during GC or shutdown during pytest. - // Read the documentation for more details (https://aka.ms/CPPvsPythonGC) - } - void free() { - if (_handle && SQLFreeHandle_ptr) { - SQLFreeHandle_ptr(_type, _handle); - _handle = nullptr; - } - } - SQLHANDLE get() const { return _handle; } -private: - SQLSMALLINT _type; - SQLHANDLE _handle; -}; -using SqlHandlePtr = std::shared_ptr; - namespace { -// TODO: Revisit GIL considerations if we're using python's logger -template -void LOG(const std::string& formatString, Args&&... args) { - // TODO: Try to do this string concatenation at compile time - std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; - static py::object logging = py::module_::import("mssql_python.logging_config") - .attr("get_logger")(); - if (py::isinstance(logging)) { - return; - } - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); - logging.attr("debug")(message); -} - -// TODO: Add more nuanced exception classes -void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } - -// Helper to load the driver -// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit -// linking to load this DLL. It will simplify the code a lot. -std::wstring LoadDriverOrThrowException(const std::wstring& modulePath = L"") { - std::wstring ddbcModulePath = modulePath; - if (ddbcModulePath.empty()) { - // Get the module path if not provided - std::string path = GetModuleDirectory(); - ddbcModulePath = std::wstring(path.begin(), path.end()); - } - - std::wstring dllDir = ddbcModulePath; - dllDir += L"\\libs\\"; - - // Convert ARCHITECTURE macro to wstring - std::wstring archStr(ARCHITECTURE, ARCHITECTURE + strlen(ARCHITECTURE)); - - // Map architecture identifiers to correct subdirectory names - std::wstring archDir; - if (archStr == L"win64" || archStr == L"amd64" || archStr == L"x64") { - archDir = L"x64"; - } else if (archStr == L"arm64") { - archDir = L"arm64"; - } else { - archDir = L"x86"; - } - dllDir += archDir; - std::wstring mssqlauthDllPath = dllDir + L"\\mssql-auth.dll"; - dllDir += L"\\msodbcsql18.dll"; - - // Preload mssql-auth.dll from the same path if available - // TODO: Only load mssql-auth.dll if using Entra ID Authentication modes (Active Directory modes) - HMODULE hAuthModule = LoadLibraryW(mssqlauthDllPath.c_str()); - if (hAuthModule) { - LOG("Authentication library loaded successfully from - {}", mssqlauthDllPath.c_str()); - } else { - LOG("Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication.", mssqlauthDllPath.c_str()); - } - - // Convert wstring to string for logging - std::string dllDirStr(dllDir.begin(), dllDir.end()); - LOG("Attempting to load driver from - {}", dllDirStr); - - HMODULE hModule = LoadLibraryW(dllDir.c_str()); - if (!hModule) { - // Failed to load the DLL, get the error message - DWORD error = GetLastError(); - char* messageBuffer = nullptr; - size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, - 0, - NULL - ); - std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; - LocalFree(messageBuffer); - - // Log the error message - LOG("Failed to load the driver with error code: {} - {}", error, errorMessage); - ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly."); - } - - // If we got here, we've successfully loaded the DLL. Now get the function pointers. - // Environment and handle function loading - SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress(hModule, "SQLAllocHandle"); - SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress(hModule, "SQLSetEnvAttr"); - SQLSetConnectAttr_ptr = (SQLSetConnectAttrFunc)GetProcAddress(hModule, "SQLSetConnectAttrW"); - SQLSetStmtAttr_ptr = (SQLSetStmtAttrFunc)GetProcAddress(hModule, "SQLSetStmtAttrW"); - SQLGetConnectAttr_ptr = (SQLGetConnectAttrFunc)GetProcAddress(hModule, "SQLGetConnectAttrW"); - - // Connection and statement function loading - SQLDriverConnect_ptr = (SQLDriverConnectFunc)GetProcAddress(hModule, "SQLDriverConnectW"); - SQLExecDirect_ptr = (SQLExecDirectFunc)GetProcAddress(hModule, "SQLExecDirectW"); - SQLPrepare_ptr = (SQLPrepareFunc)GetProcAddress(hModule, "SQLPrepareW"); - SQLBindParameter_ptr = (SQLBindParameterFunc)GetProcAddress(hModule, "SQLBindParameter"); - SQLExecute_ptr = (SQLExecuteFunc)GetProcAddress(hModule, "SQLExecute"); - SQLRowCount_ptr = (SQLRowCountFunc)GetProcAddress(hModule, "SQLRowCount"); - SQLGetStmtAttr_ptr = (SQLGetStmtAttrFunc)GetProcAddress(hModule, "SQLGetStmtAttrW"); - SQLSetDescField_ptr = (SQLSetDescFieldFunc)GetProcAddress(hModule, "SQLSetDescFieldW"); - - // Fetch and data retrieval function loading - SQLFetch_ptr = (SQLFetchFunc)GetProcAddress(hModule, "SQLFetch"); - SQLFetchScroll_ptr = (SQLFetchScrollFunc)GetProcAddress(hModule, "SQLFetchScroll"); - SQLGetData_ptr = (SQLGetDataFunc)GetProcAddress(hModule, "SQLGetData"); - SQLNumResultCols_ptr = (SQLNumResultColsFunc)GetProcAddress(hModule, "SQLNumResultCols"); - SQLBindCol_ptr = (SQLBindColFunc)GetProcAddress(hModule, "SQLBindCol"); - SQLDescribeCol_ptr = (SQLDescribeColFunc)GetProcAddress(hModule, "SQLDescribeColW"); - SQLMoreResults_ptr = (SQLMoreResultsFunc)GetProcAddress(hModule, "SQLMoreResults"); - SQLColAttribute_ptr = (SQLColAttributeFunc)GetProcAddress(hModule, "SQLColAttributeW"); - - // Transaction functions loading - SQLEndTran_ptr = (SQLEndTranFunc)GetProcAddress(hModule, "SQLEndTran"); - - // Disconnect and free functions loading - SQLFreeHandle_ptr = (SQLFreeHandleFunc)GetProcAddress(hModule, "SQLFreeHandle"); - SQLDisconnect_ptr = (SQLDisconnectFunc)GetProcAddress(hModule, "SQLDisconnect"); - SQLFreeStmt_ptr = (SQLFreeStmtFunc)GetProcAddress(hModule, "SQLFreeStmt"); - - // Diagnostic record function Loading - SQLGetDiagRec_ptr = (SQLGetDiagRecFunc)GetProcAddress(hModule, "SQLGetDiagRecW"); - - bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && - SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && - SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && - SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr && - SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr && - SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && - SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr && - SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; - - if (!success) { - LOG("Failed to load required function pointers from driver - {}", dllDirStr); - ThrowStdException("Failed to load required function pointers from driver"); - } - LOG("Successfully loaded function pointers from driver"); - - return dllDir; -} - const char* GetSqlCTypeAsString(const SQLSMALLINT cType) { switch (cType) { STRINGIFY_FOR_CASE(SQL_C_CHAR); @@ -702,11 +470,214 @@ void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) { } // namespace +// TODO: Revisit GIL considerations if we're using python's logger +template +void LOG(const std::string& formatString, Args&&... args) { + // TODO: Try to do this string concatenation at compile time + std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; + static py::object logging = py::module_::import("mssql_python.logging_config") + .attr("get_logger")(); + if (py::isinstance(logging)) { + return; + } + py::str message = py::str(ddbcFormatString).format(std::forward(args)...); + logging.attr("debug")(message); +} + +// TODO: Add more nuanced exception classes +void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } + +std::string GetModuleDirectory() { + py::object module = py::module::import("mssql_python"); + py::object module_path = module.attr("__file__"); + std::string module_file = module_path.cast(); + + char path[MAX_PATH]; + strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); + PathRemoveFileSpecA(path); + return std::string(path); +} + +// Helper to load the driver +// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit +// linking to load this DLL. It will simplify the code a lot. +std::wstring LoadDriverOrThrowException() { + const std::wstring& modulePath = L""; + std::wstring ddbcModulePath = modulePath; + if (ddbcModulePath.empty()) { + // Get the module path if not provided + std::string path = GetModuleDirectory(); + ddbcModulePath = std::wstring(path.begin(), path.end()); + } + + std::wstring dllDir = ddbcModulePath; + dllDir += L"\\libs\\"; + + // Convert ARCHITECTURE macro to wstring + std::wstring archStr(ARCHITECTURE, ARCHITECTURE + strlen(ARCHITECTURE)); + + // Map architecture identifiers to correct subdirectory names + std::wstring archDir; + if (archStr == L"win64" || archStr == L"amd64" || archStr == L"x64") { + archDir = L"x64"; + } else if (archStr == L"arm64") { + archDir = L"arm64"; + } else { + archDir = L"x86"; + } + dllDir += archDir; + std::wstring mssqlauthDllPath = dllDir + L"\\mssql-auth.dll"; + dllDir += L"\\msodbcsql18.dll"; + + // Preload mssql-auth.dll from the same path if available + // TODO: Only load mssql-auth.dll if using Entra ID Authentication modes (Active Directory modes) + HMODULE hAuthModule = LoadLibraryW(mssqlauthDllPath.c_str()); + if (hAuthModule) { + LOG("Authentication library loaded successfully from - {}", mssqlauthDllPath.c_str()); + } else { + LOG("Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication.", mssqlauthDllPath.c_str()); + } + + // Convert wstring to string for logging + std::string dllDirStr(dllDir.begin(), dllDir.end()); + LOG("Attempting to load driver from - {}", dllDirStr); + + HMODULE hModule = LoadLibraryW(dllDir.c_str()); + if (!hModule) { + // Failed to load the DLL, get the error message + DWORD error = GetLastError(); + char* messageBuffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + error, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&messageBuffer, + 0, + NULL + ); + std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; + LocalFree(messageBuffer); + + // Log the error message + LOG("Failed to load the driver with error code: {} - {}", error, errorMessage); + ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly."); + } + + // If we got here, we've successfully loaded the DLL. Now get the function pointers. + // Environment and handle function loading + SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress(hModule, "SQLAllocHandle"); + SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress(hModule, "SQLSetEnvAttr"); + SQLSetConnectAttr_ptr = (SQLSetConnectAttrFunc)GetProcAddress(hModule, "SQLSetConnectAttrW"); + SQLSetStmtAttr_ptr = (SQLSetStmtAttrFunc)GetProcAddress(hModule, "SQLSetStmtAttrW"); + SQLGetConnectAttr_ptr = (SQLGetConnectAttrFunc)GetProcAddress(hModule, "SQLGetConnectAttrW"); + + // Connection and statement function loading + SQLDriverConnect_ptr = (SQLDriverConnectFunc)GetProcAddress(hModule, "SQLDriverConnectW"); + SQLExecDirect_ptr = (SQLExecDirectFunc)GetProcAddress(hModule, "SQLExecDirectW"); + SQLPrepare_ptr = (SQLPrepareFunc)GetProcAddress(hModule, "SQLPrepareW"); + SQLBindParameter_ptr = (SQLBindParameterFunc)GetProcAddress(hModule, "SQLBindParameter"); + SQLExecute_ptr = (SQLExecuteFunc)GetProcAddress(hModule, "SQLExecute"); + SQLRowCount_ptr = (SQLRowCountFunc)GetProcAddress(hModule, "SQLRowCount"); + SQLGetStmtAttr_ptr = (SQLGetStmtAttrFunc)GetProcAddress(hModule, "SQLGetStmtAttrW"); + SQLSetDescField_ptr = (SQLSetDescFieldFunc)GetProcAddress(hModule, "SQLSetDescFieldW"); + + // Fetch and data retrieval function loading + SQLFetch_ptr = (SQLFetchFunc)GetProcAddress(hModule, "SQLFetch"); + SQLFetchScroll_ptr = (SQLFetchScrollFunc)GetProcAddress(hModule, "SQLFetchScroll"); + SQLGetData_ptr = (SQLGetDataFunc)GetProcAddress(hModule, "SQLGetData"); + SQLNumResultCols_ptr = (SQLNumResultColsFunc)GetProcAddress(hModule, "SQLNumResultCols"); + SQLBindCol_ptr = (SQLBindColFunc)GetProcAddress(hModule, "SQLBindCol"); + SQLDescribeCol_ptr = (SQLDescribeColFunc)GetProcAddress(hModule, "SQLDescribeColW"); + SQLMoreResults_ptr = (SQLMoreResultsFunc)GetProcAddress(hModule, "SQLMoreResults"); + SQLColAttribute_ptr = (SQLColAttributeFunc)GetProcAddress(hModule, "SQLColAttributeW"); + + // Transaction functions loading + SQLEndTran_ptr = (SQLEndTranFunc)GetProcAddress(hModule, "SQLEndTran"); + + // Disconnect and free functions loading + SQLFreeHandle_ptr = (SQLFreeHandleFunc)GetProcAddress(hModule, "SQLFreeHandle"); + SQLDisconnect_ptr = (SQLDisconnectFunc)GetProcAddress(hModule, "SQLDisconnect"); + SQLFreeStmt_ptr = (SQLFreeStmtFunc)GetProcAddress(hModule, "SQLFreeStmt"); + + // Diagnostic record function Loading + SQLGetDiagRec_ptr = (SQLGetDiagRecFunc)GetProcAddress(hModule, "SQLGetDiagRecW"); + + bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && + SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && + SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && + SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr && + SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr && + SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && + SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr && + SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + + if (!success) { + LOG("Failed to load required function pointers from driver - {}", dllDirStr); + ThrowStdException("Failed to load required function pointers from driver"); + } + LOG("Successfully loaded function pointers from driver"); + + return dllDir; +} + +// DriverLoader definition +DriverLoader::DriverLoader() : m_driverLoaded(false) {} + +DriverLoader& DriverLoader::getInstance() { + static DriverLoader instance; + return instance; +} + +void DriverLoader::loadDriver() { + if (!m_driverLoaded) { + LoadDriverOrThrowException(); + m_driverLoaded = true; + } +} + +// SqlHandle definition +SqlHandle::SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) + : _type(type), _handle(rawHandle) {} + +// Note: Destructor is intentionally a no-op. Python owns the lifecycle. +// Native ODBC handles must be explicitly released by calling `free()` directly from Python. +// This avoids nondeterministic crashes during GC or shutdown during pytest. +// Read the documentation for more details (https://aka.ms/CPPvsPythonGC) +SqlHandle::~SqlHandle() {} + +SQLHANDLE SqlHandle::get() const { + return _handle; +} + +SQLSMALLINT SqlHandle::type() const { + return _type; +} + +void SqlHandle::free() { + if (_handle && SQLFreeHandle_ptr) { + const char* type_str = nullptr; + switch (_type) { + case SQL_HANDLE_ENV: type_str = "ENV"; break; + case SQL_HANDLE_DBC: type_str = "DBC"; break; + case SQL_HANDLE_STMT: type_str = "STMT"; break; + case SQL_HANDLE_DESC: type_str = "DESC"; break; + default: type_str = "UNKNOWN"; break; + } + SQLFreeHandle_ptr(_type, _handle); + _handle = nullptr; + std::stringstream ss; + ss << "Freed SQL Handle of type: " << type_str; + LOG(ss.str()); + } +} + // Wrap SQLAllocHandle SQLRETURN SQLAllocHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr InputHandle, SqlHandlePtr& OutputHandle) { LOG("Allocate SQL Handle"); if (!SQLAllocHandle_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLHANDLE rawOutputHandle = nullptr; @@ -724,7 +695,8 @@ SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intpt SQLINTEGER StringLength) { LOG("Set SQL environment Attribute"); if (!SQLSetEnvAttr_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } // TODO: Does ValuePtr need to be converted from Python to C++ object? @@ -740,7 +712,8 @@ SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attri py::object ValuePtr) { LOG("Set SQL Connection Attribute"); if (!SQLSetConnectAttr_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } // Print the type of ValuePtr and attribute value - helpful for debugging @@ -796,7 +769,8 @@ SQLRETURN SQLSetStmtAttr_wrap(SqlHandlePtr StatementHandle, SQLINTEGER Attribute SQLINTEGER StringLength) { LOG("Set SQL Statement Attribute"); if (!SQLSetConnectAttr_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } // TODO: Does ValuePtr need to be converted from Python to C++ object? @@ -813,7 +787,8 @@ SQLRETURN SQLSetStmtAttr_wrap(SqlHandlePtr StatementHandle, SQLINTEGER Attribute SQLINTEGER SQLGetConnectionAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER attribute) { LOG("Get SQL COnnection Attribute"); if (!SQLGetConnectAttr_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLINTEGER stringLength; @@ -838,7 +813,8 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET SQLHANDLE rawHandle = handle->get(); if (!SQL_SUCCEEDED(retcode)) { if (!SQLGetDiagRec_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLWCHAR sqlState[6], message[SQL_MAX_MESSAGE_LENGTH]; @@ -861,7 +837,8 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET SQLRETURN SQLDriverConnect_wrap(SqlHandlePtr ConnectionHandle, intptr_t WindowHandle, const std::wstring& ConnectionString) { LOG("Driver Connect to MSSQL"); if (!SQLDriverConnect_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLRETURN ret = SQLDriverConnect_ptr(ConnectionHandle->get(), reinterpret_cast(WindowHandle), @@ -877,7 +854,8 @@ SQLRETURN SQLDriverConnect_wrap(SqlHandlePtr ConnectionHandle, intptr_t WindowHa SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); if (!SQLExecDirect_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), const_cast(Query.c_str()), SQL_NTS); @@ -897,7 +875,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, py::list& isStmtPrepared, const bool usePrepare = true) { LOG("Execute SQL Query - {}", query.c_str()); if (!SQLPrepare_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } assert(SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && SQLExecDirect_ptr); @@ -970,7 +949,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) { LOG("Get number of columns in result set"); if (!SQLNumResultCols_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLSMALLINT columnCount; @@ -983,7 +963,8 @@ SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) { SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMetadata) { LOG("Get column description"); if (!SQLDescribeCol_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLSMALLINT ColumnCount; @@ -1024,7 +1005,8 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { LOG("Fetch next row"); if (!SQLFetch_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } return SQLFetch_ptr(StatementHandle->get()); @@ -1035,7 +1017,8 @@ SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row) { LOG("Get data from columns"); if (!SQLGetData_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLRETURN ret; @@ -2015,7 +1998,8 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row) { SQLRETURN SQLMoreResults_wrap(SqlHandlePtr StatementHandle) { LOG("Check for more results"); if (!SQLMoreResults_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } return SQLMoreResults_ptr(StatementHandle->get()); @@ -2025,7 +2009,8 @@ SQLRETURN SQLMoreResults_wrap(SqlHandlePtr StatementHandle) { SQLRETURN SQLEndTran_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle, SQLSMALLINT CompletionType) { LOG("End SQL Transaction"); if (!SQLEndTran_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } return SQLEndTran_ptr(HandleType, Handle->get(), CompletionType); @@ -2035,7 +2020,8 @@ SQLRETURN SQLEndTran_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle, SQLSMALLI SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { LOG("Free SQL handle"); if (!SQLAllocHandle_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLRETURN ret = SQLFreeHandle_ptr(HandleType, Handle->get()); @@ -2049,7 +2035,8 @@ SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { SQLRETURN SQLDisconnect_wrap(SqlHandlePtr ConnectionHandle) { LOG("Disconnect from MSSQL"); if (!SQLDisconnect_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } return SQLDisconnect_ptr(ConnectionHandle->get()); @@ -2059,7 +2046,8 @@ SQLRETURN SQLDisconnect_wrap(SqlHandlePtr ConnectionHandle) { SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) { LOG("Get number of row affected by last execute"); if (!SQLRowCount_ptr) { - LoadDriverOrThrowException(); + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver } SQLLEN rowCount; @@ -2156,7 +2144,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { try { // Try loading the ODBC driver when the module is imported - LoadDriverOrThrowException(); + DriverLoader::getInstance().loadDriver(); // Load the driver } catch (const std::exception& e) { // Log the error but don't throw - let the error happen when functions are called LOG("Failed to load ODBC driver during module initialization: {}", e.what()); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h new file mode 100644 index 00000000..9d21683f --- /dev/null +++ b/mssql_python/pybind/ddbc_bindings.h @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include + +//------------------------------------------------------------------------------------------------- +// Function pointer typedefs +//------------------------------------------------------------------------------------------------- + +// Handle APIs +typedef SQLRETURN (SQL_API* SQLAllocHandleFunc)(SQLSMALLINT, SQLHANDLE, SQLHANDLE*); +typedef SQLRETURN (SQL_API* SQLSetEnvAttrFunc)(SQLHANDLE, SQLINTEGER, SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLSetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLSetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLGetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); + +// Connection and Execution APIs +typedef SQLRETURN (SQL_API* SQLDriverConnectFunc)(SQLHANDLE, SQLHWND, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLSMALLINT*, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLExecDirectFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLPrepareFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLBindParameterFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLSMALLINT, + SQLSMALLINT, SQLULEN, SQLSMALLINT, SQLPOINTER, SQLLEN, + SQLLEN*); +typedef SQLRETURN (SQL_API* SQLExecuteFunc)(SQLHANDLE); +typedef SQLRETURN (SQL_API* SQLRowCountFunc)(SQLHSTMT, SQLLEN*); +typedef SQLRETURN (SQL_API* SQLSetDescFieldFunc)(SQLHDESC, SQLSMALLINT, SQLSMALLINT, SQLPOINTER, SQLINTEGER); +typedef SQLRETURN (SQL_API* SQLGetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); + +// Data retrieval APIs +typedef SQLRETURN (SQL_API* SQLFetchFunc)(SQLHANDLE); +typedef SQLRETURN (SQL_API* SQLFetchScrollFunc)(SQLHANDLE, SQLSMALLINT, SQLLEN); +typedef SQLRETURN (SQL_API* SQLGetDataFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, + SQLLEN*); +typedef SQLRETURN (SQL_API* SQLNumResultColsFunc)(SQLHSTMT, SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLBindColFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, + SQLLEN*); +typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLSMALLINT*, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, + SQLSMALLINT*); +typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); +typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, + SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); + +// Transaction APIs +typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); + +// Disconnect/free APIs +typedef SQLRETURN (SQL_API* SQLFreeHandleFunc)(SQLSMALLINT, SQLHANDLE); +typedef SQLRETURN (SQL_API* SQLDisconnectFunc)(SQLHDBC); +typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); + +// Diagnostic APIs +typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, + SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); + +//------------------------------------------------------------------------------------------------- +// Extern function pointer declarations (defined in ddbc_bindings.cpp) +//------------------------------------------------------------------------------------------------- + +// Handle APIs +extern SQLAllocHandleFunc SQLAllocHandle_ptr; +extern SQLSetEnvAttrFunc SQLSetEnvAttr_ptr; +extern SQLSetConnectAttrFunc SQLSetConnectAttr_ptr; +extern SQLSetStmtAttrFunc SQLSetStmtAttr_ptr; +extern SQLGetConnectAttrFunc SQLGetConnectAttr_ptr; + +// Connection and Execution APIs +extern SQLDriverConnectFunc SQLDriverConnect_ptr; +extern SQLExecDirectFunc SQLExecDirect_ptr; +extern SQLPrepareFunc SQLPrepare_ptr; +extern SQLBindParameterFunc SQLBindParameter_ptr; +extern SQLExecuteFunc SQLExecute_ptr; +extern SQLRowCountFunc SQLRowCount_ptr; +extern SQLSetDescFieldFunc SQLSetDescField_ptr; +extern SQLGetStmtAttrFunc SQLGetStmtAttr_ptr; + +// Data retrieval APIs +extern SQLFetchFunc SQLFetch_ptr; +extern SQLFetchScrollFunc SQLFetchScroll_ptr; +extern SQLGetDataFunc SQLGetData_ptr; +extern SQLNumResultColsFunc SQLNumResultCols_ptr; +extern SQLBindColFunc SQLBindCol_ptr; +extern SQLDescribeColFunc SQLDescribeCol_ptr; +extern SQLMoreResultsFunc SQLMoreResults_ptr; +extern SQLColAttributeFunc SQLColAttribute_ptr; + +// Transaction APIs +extern SQLEndTranFunc SQLEndTran_ptr; + +// Disconnect/free APIs +extern SQLFreeHandleFunc SQLFreeHandle_ptr; +extern SQLDisconnectFunc SQLDisconnect_ptr; +extern SQLFreeStmtFunc SQLFreeStmt_ptr; + +// Diagnostic APIs +extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; + + +// -- Logging utility -- +template +void LOG(const std::string& formatString, Args&&... args); + +// -- Exception helper -- +void ThrowStdException(const std::string& message); + +// -- Driver loader -- +std::wstring LoadDriverOrThrowException(); + +// -- Singleton wrapper -- +class DriverLoader { + public: + static DriverLoader& getInstance(); + void loadDriver(); + private: + DriverLoader(); + bool m_driverLoaded; + }; + +// -- SqlHandle wrapper -- +class SqlHandle { + public: + SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle); + ~SqlHandle(); + SQLHANDLE get() const; + SQLSMALLINT type() const; + void free(); + private: + SQLSMALLINT _type; + SQLHANDLE _handle; + }; + using SqlHandlePtr = std::shared_ptr; \ No newline at end of file From f4c9c30f7e59707d3247a3caf08389ec11178cff Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 14:25:46 +0530 Subject: [PATCH 2/4] add newline --- mssql_python/pybind/ddbc_bindings.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 9d21683f..7adc84a4 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -135,4 +135,5 @@ class SqlHandle { SQLSMALLINT _type; SQLHANDLE _handle; }; - using SqlHandlePtr = std::shared_ptr; \ No newline at end of file + using SqlHandlePtr = std::shared_ptr; + \ No newline at end of file From 6c823780516db7f8acf0c63097fd60fec309ced7 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 15:53:50 +0530 Subject: [PATCH 3/4] Delete copy/assign for DriverLoader singleton --- mssql_python/pybind/ddbc_bindings.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 7adc84a4..5a5d9125 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -120,6 +120,8 @@ class DriverLoader { void loadDriver(); private: DriverLoader(); + DriverLoader(const DriverLoader&) = delete; + DriverLoader& operator=(const DriverLoader&) = delete; bool m_driverLoaded; }; @@ -136,4 +138,3 @@ class SqlHandle { SQLHANDLE _handle; }; using SqlHandlePtr = std::shared_ptr; - \ No newline at end of file From 11821907a1c3a4651605901ce4c75861838775fa Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 14:20:35 +0530 Subject: [PATCH 4/4] resolve review comments --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- mssql_python/pybind/ddbc_bindings.h | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index e181c7df..493b6269 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -530,7 +530,6 @@ std::wstring LoadDriverOrThrowException() { dllDir += L"\\msodbcsql18.dll"; // Preload mssql-auth.dll from the same path if available - // TODO: Only load mssql-auth.dll if using Entra ID Authentication modes (Active Directory modes) HMODULE hAuthModule = LoadLibraryW(mssqlauthDllPath.c_str()); if (hAuthModule) { LOG("Authentication library loaded successfully from - {}", mssqlauthDllPath.c_str()); @@ -2144,6 +2143,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { try { // Try loading the ODBC driver when the module is imported + LOG("Loading ODBC driver"); DriverLoader::getInstance().loadDriver(); // Load the driver } catch (const std::exception& e) { // Log the error but don't throw - let the error happen when functions are called diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 5a5d9125..81801379 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be +// taken up in future. + #pragma once #include @@ -110,10 +113,20 @@ void LOG(const std::string& formatString, Args&&... args); // -- Exception helper -- void ThrowStdException(const std::string& message); -// -- Driver loader -- +//------------------------------------------------------------------------------------------------- +// Loads the ODBC driver and resolves function pointers. +// Throws if loading or resolution fails. +//------------------------------------------------------------------------------------------------- std::wstring LoadDriverOrThrowException(); -// -- Singleton wrapper -- +//------------------------------------------------------------------------------------------------- +// DriverLoader (Singleton) +// +// Ensures the ODBC driver and all function pointers are loaded exactly once across the process. +// This avoids redundant work and ensures thread-safe, centralized initialization. +// +// Not copyable or assignable. +//------------------------------------------------------------------------------------------------- class DriverLoader { public: static DriverLoader& getInstance(); @@ -125,7 +138,12 @@ class DriverLoader { bool m_driverLoaded; }; -// -- SqlHandle wrapper -- +//------------------------------------------------------------------------------------------------- +// SqlHandle +// +// RAII wrapper around ODBC handles (ENV, DBC, STMT). +// Use `std::shared_ptr` (alias: SqlHandlePtr) for shared ownership. +//------------------------------------------------------------------------------------------------- class SqlHandle { public: SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle);