From 4591b558b7fce2d47cf8c8fabf1512153f234478 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 14:23:04 +0530 Subject: [PATCH 01/18] refactor native layer and create reusable components --- mssql_python/pybind/ddbc_bindings.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 81801379..33579a0a 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -1,9 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +<<<<<<< HEAD // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in future. +======= +>>>>>>> 25d6ef0 (refactor native layer and create reusable components) #pragma once #include From e375b202ef86e280346344b9f1a3fc309a78d8fe Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 14:25:46 +0530 Subject: [PATCH 02/18] add newline --- mssql_python/pybind/ddbc_bindings.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 33579a0a..44bb29e6 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -159,3 +159,4 @@ class SqlHandle { SQLHANDLE _handle; }; using SqlHandlePtr = std::shared_ptr; + \ No newline at end of file From b5a7d5a0ecb6a92f556fffe331de7e8028992a62 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 15:53:50 +0530 Subject: [PATCH 03/18] Delete copy/assign for DriverLoader singleton --- mssql_python/pybind/ddbc_bindings.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 44bb29e6..33579a0a 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -159,4 +159,3 @@ class SqlHandle { SQLHANDLE _handle; }; using SqlHandlePtr = std::shared_ptr; - \ No newline at end of file From fc5903a7e8b62ac2c1d3dc6fb7e75038f79e2c77 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 14:20:35 +0530 Subject: [PATCH 04/18] resolve review comments --- mssql_python/pybind/ddbc_bindings.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 33579a0a..9bf4a754 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -1,12 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -<<<<<<< HEAD // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in future. -======= ->>>>>>> 25d6ef0 (refactor native layer and create reusable components) +>>>>>>> 1182190 (resolve review comments) #pragma once #include From 998b25e1905ab4c43e33377a677fe4e32d511a12 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 19:54:42 +0530 Subject: [PATCH 05/18] initial edit --- mssql_python/pybind/connection/connection.cpp | 48 +++++ mssql_python/pybind/ddbc_bindings.cpp | 195 ++---------------- mssql_python/pybind/ddbc_bindings.h | 1 - 3 files changed, 61 insertions(+), 183 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index aa577c60..33dec23e 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -6,6 +6,10 @@ #include "connection.h" #include +#include + +#include + //------------------------------------------------------------------------------------------------- // Implements the Connection class declared in connection.h. @@ -114,11 +118,55 @@ bool Connection::getAutocommit() const { return value == SQL_AUTOCOMMIT_ON; } +<<<<<<< HEAD SqlHandlePtr Connection::allocStatementHandle() { if (!_dbc_handle) { throw std::runtime_error("Connection handle not allocated"); } LOG("Allocating statement handle"); +======= +SQLRETURN set_attribute(SQLINTEGER Attribute, py::object ValuePtr) { + LOG("Set SQL Connection Attribute"); + if (!SQLSetConnectAttr_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver + } + + // Print the type of ValuePtr and attribute value - helpful for debugging + LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); + + SQLPOINTER value = 0; + SQLINTEGER length = 0; + + if (py::isinstance(ValuePtr)) { + // Handle integer values + int intValue = ValuePtr.cast(); + value = reinterpret_cast(intValue); + length = SQL_IS_INTEGER; // Integer values don't require a length + } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { + // Handle byte or bytearray values (like access tokens) + // Store in static buffer to ensure memory remains valid during connection + static std::vector bytesBuffers; + bytesBuffers.push_back(ValuePtr.cast()); + value = const_cast(bytesBuffers.back().c_str()); + length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) + } else { + LOG("Unsupported ValuePtr type"); + return SQL_ERROR; + } + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), Attribute, value, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set Connection attribute"); + } + else { + LOG("Set Connection attribute successfully"); + } + return ret; +} + +SqlHandlePtr Connection::alloc_statement_handle() { +>>>>>>> 52a2dba (initial edit) SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbc_handle->get(), &stmt); if (!SQL_SUCCEEDED(ret)) { diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 493b6269..b26fd6a6 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -671,134 +671,6 @@ void SqlHandle::free() { } } -// Wrap SQLAllocHandle -SQLRETURN SQLAllocHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr InputHandle, SqlHandlePtr& OutputHandle) { - LOG("Allocate SQL Handle"); - if (!SQLAllocHandle_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - SQLHANDLE rawOutputHandle = nullptr; - SQLRETURN ret = SQLAllocHandle_ptr(HandleType, InputHandle ? InputHandle->get() : nullptr, &rawOutputHandle); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to allocate handle"); - return ret; - } - OutputHandle = std::make_shared(HandleType, rawOutputHandle); - return ret; -} - -// Wrap SQLSetEnvAttr -SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL environment Attribute"); - if (!SQLSetEnvAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - SQLRETURN ret = SQLSetEnvAttr_ptr(EnvHandle->get(), Attribute, reinterpret_cast(ValuePtr), StringLength); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set environment attribute"); - } - return ret; -} - -// Wrap SQLSetConnectAttr -SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, - py::object ValuePtr) { - LOG("Set SQL Connection Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // Print the type of ValuePtr and attribute value - helpful for debugging - LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); - - SQLPOINTER value = 0; - SQLINTEGER length = 0; - - if (py::isinstance(ValuePtr)) { - // Handle integer values - int intValue = ValuePtr.cast(); - value = reinterpret_cast(intValue); - length = SQL_IS_INTEGER; // Integer values don't require a length - // } else if (py::isinstance(ValuePtr)) { - // // Handle Unicode string values - // static std::wstring unicodeValueBuffer; - // unicodeValueBuffer = ValuePtr.cast(); - // value = const_cast(unicodeValueBuffer.c_str()); - // length = SQL_NTS; // Indicates null-terminated string - } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // Handle byte or bytearray values (like access tokens) - // Store in static buffer to ensure memory remains valid during connection - static std::vector bytesBuffers; - bytesBuffers.push_back(ValuePtr.cast()); - value = const_cast(bytesBuffers.back().c_str()); - length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) - // } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // // Handle list or tuple values - // LOG("ValuePtr is a sequence (list or tuple)"); - // for (py::handle item : ValuePtr) { - // LOG("Processing item in sequence"); - // SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow(item)); - // if (!SQL_SUCCEEDED(ret)) { - // LOG("Failed to set attribute for item in sequence"); - // return ret; - // } - // } - } else { - LOG("Unsupported ValuePtr type"); - return SQL_ERROR; - } - - SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Connection attribute"); - } - LOG("Set Connection attribute successfully"); - return ret; -} - -// Wrap SQLSetStmtAttr -SQLRETURN SQLSetStmtAttr_wrap(SqlHandlePtr StatementHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL Statement Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - SQLRETURN ret = SQLSetStmtAttr_ptr(StatementHandle->get(), Attribute, reinterpret_cast(ValuePtr), StringLength); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Statement attribute"); - } - return ret; -} - -// Wrap SQLGetConnectionAttrA -// Currently only supports retrieval of int-valued attributes -// TODO: add support to retrieve all types of attributes -SQLINTEGER SQLGetConnectionAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER attribute) { - LOG("Get SQL COnnection Attribute"); - if (!SQLGetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - SQLINTEGER stringLength; - SQLINTEGER intValue; - - // Try to get the attribute as an integer - SQLGetConnectAttr_ptr(ConnectionHandle->get(), attribute, &intValue, - sizeof(SQLINTEGER), &stringLength); - return intValue; -} - // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -832,23 +704,6 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET return errorInfo; } -// Wrap SQLDriverConnect -SQLRETURN SQLDriverConnect_wrap(SqlHandlePtr ConnectionHandle, intptr_t WindowHandle, const std::wstring& ConnectionString) { - LOG("Driver Connect to MSSQL"); - if (!SQLDriverConnect_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - SQLRETURN ret = SQLDriverConnect_ptr(ConnectionHandle->get(), - reinterpret_cast(WindowHandle), - const_cast(ConnectionString.c_str()), SQL_NTS, nullptr, - 0, nullptr, SQL_DRIVER_NOPROMPT); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to connect to DB"); - } - return ret; -} - // Wrap SQLExecDirect SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); @@ -2004,17 +1859,6 @@ SQLRETURN SQLMoreResults_wrap(SqlHandlePtr StatementHandle) { return SQLMoreResults_ptr(StatementHandle->get()); } -// Wrap SQLEndTran -SQLRETURN SQLEndTran_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle, SQLSMALLINT CompletionType) { - LOG("End SQL Transaction"); - if (!SQLEndTran_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - return SQLEndTran_ptr(HandleType, Handle->get(), CompletionType); -} - // Wrap SQLFreeHandle SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { LOG("Free SQL handle"); @@ -2030,17 +1874,6 @@ SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { return ret; } -// Wrap SQLDisconnect -SQLRETURN SQLDisconnect_wrap(SqlHandlePtr ConnectionHandle) { - LOG("Disconnect from MSSQL"); - if (!SQLDisconnect_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - return SQLDisconnect_ptr(ConnectionHandle->get()); -} - // Wrap SQLRowCount SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) { LOG("Get number of row affected by last execute"); @@ -2102,22 +1935,21 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "SqlHandle") .def("free", &SqlHandle::free); - - m.def("DDBCSQLAllocHandle", [](SQLSMALLINT HandleType, SqlHandlePtr InputHandle = nullptr) { - SqlHandlePtr OutputHandle; - SQLRETURN rc = SQLAllocHandle_wrap(HandleType, InputHandle, OutputHandle); - return py::make_tuple(rc, OutputHandle); - }, "Allocate an environment, connection, statement, or descriptor handle"); - m.def("DDBCSQLSetEnvAttr", &SQLSetEnvAttr_wrap, - "Set an attribute that governs aspects of environments"); + py::class_(m, "Connection") + .def(py::init(), py::arg("conn_str")) + .def("connect", &Connection::connect, "Establish a connection to the database") + .def("close", &Connection::close, "Close the connection") + .def("commit", [](Connection& self) { + self.end_transaction(SQL_COMMIT); + }) + .def("rollback", [](Connection& self) { + self.end_transaction(SQL_ROLLBACK)}) + .def("set_autocommit", &Connection::set_autocommit) + .def("get_autocommit", &Connection::get_autocommit) + .def("set_attribute", &Connection::set_attribute); + .def("alloc_statement_handle", &Connection::alloc_statement_handle); m.def("DDBCSQLSetConnectAttr", &SQLSetConnectAttr_wrap, "Set an attribute that governs aspects of connections"); - m.def("DDBCSQLSetStmtAttr", &SQLSetStmtAttr_wrap, - "Set an attribute that governs aspects of statements"); - m.def("DDBCSQLGetConnectionAttr", &SQLGetConnectionAttr_wrap, - "Get an attribute that governs aspects of connections"); - m.def("DDBCSQLDriverConnect", &SQLDriverConnect_wrap, - "Connect to a data source with a connection string"); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, @@ -2135,7 +1967,6 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); m.def("DDBCSQLEndTran", &SQLEndTran_wrap, "End a transaction"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); - m.def("DDBCSQLDisconnect", &SQLDisconnect_wrap, "Disconnect from a data source"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); // Add a version attribute diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 9bf4a754..81801379 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -4,7 +4,6 @@ // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in future. ->>>>>>> 1182190 (resolve review comments) #pragma once #include From 7f9304bae893967539a49a4912e6f8408b864cd9 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Mon, 19 May 2025 22:46:55 +0530 Subject: [PATCH 06/18] working flow with c++ connection class --- mssql_python/connection.py | 321 ++++++------------ mssql_python/cursor.py | 12 +- mssql_python/pybind/connection/connection.cpp | 20 +- mssql_python/pybind/ddbc_bindings.cpp | 48 ++- mssql_python/pybind/ddbc_bindings.h | 4 +- 5 files changed, 162 insertions(+), 243 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 8031a26a..f4fc39b5 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -53,14 +53,20 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef preparing it for further operations such as connecting to the database, executing queries, etc. """ - self.henv = None - self.hdbc = None self.connection_str = self._construct_connection_string( connection_str, **kwargs ) - self._attrs_before = attrs_before - self._autocommit = autocommit # Initialize _autocommit before calling _initializer - self._initializer() + # self._attrs_before = attrs_before + # self._autocommit = autocommit + # if self._attrs_before != {}: + # self._apply_attrs_before() # Apply pre-connection attributes + # if self._autocommit: + # self.setautocommit(autocommit) + print("Connection string:", self.connection_str) + self._conn = ddbc_bindings.Connection(self.connection_str) + self._conn.connect() + print("Connection established") + self._autocommit = autocommit self.setautocommit(autocommit) def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str: @@ -100,177 +106,90 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st logger.info("Final connection string: %s", conn_str) return conn_str - - def _is_closed(self) -> bool: - """ - Check if the connection is closed. - - Returns: - bool: True if the connection is closed, False otherwise. - """ - return self.hdbc is None - def _initializer(self) -> None: - """ - Initialize the environment and connection handles. - - This method is responsible for setting up the environment and connection - handles, allocating memory for them, and setting the necessary attributes. - It should be called before establishing a connection to the database. - """ - self._allocate_environment_handle() - self._set_environment_attributes() - self._allocate_connection_handle() - if self._attrs_before != {}: - self._apply_attrs_before() # Apply pre-connection attributes - if self._autocommit: - self._set_connection_attributes( - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value, - ) - self._connect_to_db() - - def _apply_attrs_before(self): - """ - Apply specific pre-connection attributes. - Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN) - if present in `self._attrs_before`. Other attributes are ignored. - - Returns: - bool: True. - """ - - if ENABLE_LOGGING: - logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before) - - if not isinstance(self._attrs_before, dict): - if self._attrs_before is not None and ENABLE_LOGGING: - logger.warning( - f"_attrs_before is of type {type(self._attrs_before).__name__}, " - f"expected dict. Skipping attribute application." - ) - elif self._attrs_before is None and ENABLE_LOGGING: - logger.debug("_attrs_before is None. No pre-connection attributes to apply.") - return True # Exit if _attrs_before is not a dictionary or is None - - for key, value in self._attrs_before.items(): - ikey = None - if isinstance(key, int): - ikey = key - elif isinstance(key, str) and key.isdigit(): - try: - ikey = int(key) - except ValueError: - if ENABLE_LOGGING: - logger.debug( - f"Skipping attribute with key '{key}' in attrs_before: " - f"could not convert string to int." - ) - continue # Skip if string key is not a valid integer - else: - if ENABLE_LOGGING: - logger.debug( - f"Skipping attribute with key '{key}' in attrs_before due to " - f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int." - ) - continue # Skip keys that are not int or string representation of an int - - if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value: - if ENABLE_LOGGING: - logger.info( - f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it." - ) - self._set_connection_attributes(ikey, value) - if ENABLE_LOGGING: - logger.info( - f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed." - ) - # If you expect only one such key, you could add 'break' here. - else: - if ENABLE_LOGGING: - logger.debug( - f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before " - f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})." - ) - return True - - def _allocate_environment_handle(self): - """ - Allocate the environment handle. - """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_ENV.value, # SQL environment handle type - None - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, handle, ret) - self.henv = handle - - def _set_environment_attributes(self): - """ - Set the environment attributes. - """ - ret = ddbc_bindings.DDBCSQLSetEnvAttr( - self.henv, # Use the wrapper class - ddbc_sql_const.SQL_ATTR_DDBC_VERSION.value, # Attribute - ddbc_sql_const.SQL_OV_DDBC3_80.value, # String Length - 0, # Null-terminated string - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, self.henv, ret) - - def _allocate_connection_handle(self): - """ - Allocate the connection handle. - """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_DBC.value, # SQL connection handle type - self.henv - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, handle, ret) - self.hdbc = handle - - def _set_connection_attributes(self, ikey: int, ivalue: any) -> None: - """ - Set the connection attributes before connecting. - - Args: - ikey (int): The attribute key to set. - ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode. - vallen (int): The length of the value. - - Raises: - DatabaseError: If there is an error while setting the connection attribute. - """ - - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc, # Connection handle - ikey, # Attribute - ivalue, # Value - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - - def _connect_to_db(self) -> None: - """ - Establish a connection to the database. - - This method is responsible for creating a connection to the specified database. - It does not take any arguments and does not return any value. The connection - details such as database name, user credentials, host, and port should be - configured within the class or passed during the class instantiation. - - Raises: - DatabaseError: If there is an error while trying to connect to the database. - InterfaceError: If there is an error related to the database interface. - """ - if ENABLE_LOGGING: - logger.info("Connecting to the database") - ret = ddbc_bindings.DDBCSQLDriverConnect( - self.hdbc, # Connection handle (wrapper) - 0, # Window handle - self.connection_str, # Connection string - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - if ENABLE_LOGGING: - logger.info("Connection established successfully.") + # def _apply_attrs_before(self): + # """ + # Apply specific pre-connection attributes. + # Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN) + # if present in `self._attrs_before`. Other attributes are ignored. + + # Returns: + # bool: True. + # """ + + # if ENABLE_LOGGING: + # logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before) + + # if not isinstance(self._attrs_before, dict): + # if self._attrs_before is not None and ENABLE_LOGGING: + # logger.warning( + # f"_attrs_before is of type {type(self._attrs_before).__name__}, " + # f"expected dict. Skipping attribute application." + # ) + # elif self._attrs_before is None and ENABLE_LOGGING: + # logger.debug("_attrs_before is None. No pre-connection attributes to apply.") + # return True # Exit if _attrs_before is not a dictionary or is None + + # for key, value in self._attrs_before.items(): + # ikey = None + # if isinstance(key, int): + # ikey = key + # elif isinstance(key, str) and key.isdigit(): + # try: + # ikey = int(key) + # except ValueError: + # if ENABLE_LOGGING: + # logger.debug( + # f"Skipping attribute with key '{key}' in attrs_before: " + # f"could not convert string to int." + # ) + # continue # Skip if string key is not a valid integer + # else: + # if ENABLE_LOGGING: + # logger.debug( + # f"Skipping attribute with key '{key}' in attrs_before due to " + # f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int." + # ) + # continue # Skip keys that are not int or string representation of an int + + # if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value: + # if ENABLE_LOGGING: + # logger.info( + # f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it." + # ) + # self._conn.set_attribute(ikey, value) + # if ENABLE_LOGGING: + # logger.info( + # f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed." + # ) + # # If you expect only one such key, you could add 'break' here. + # else: + # if ENABLE_LOGGING: + # logger.debug( + # f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before " + # f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})." + # ) + # return True + + # def _set_connection_attributes(self, ikey: int, ivalue: any) -> None: + # """ + # Set the connection attributes before connecting. + + # Args: + # ikey (int): The attribute key to set. + # ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode. + # vallen (int): The length of the value. + + # Raises: + # DatabaseError: If there is an error while setting the connection attribute. + # """ + + # ret = ddbc_bindings.DDBCSQLSetConnectAttr( + # self.hdbc, # Connection handle + # ikey, # Attribute + # ivalue, # Value + # ) + # check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) @property def autocommit(self) -> bool: @@ -279,14 +198,7 @@ def autocommit(self) -> bool: Returns: bool: True if autocommit is enabled, False otherwise. """ - autocommit_mode = ddbc_bindings.DDBCSQLGetConnectionAttr( - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ) - check_error( - ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, autocommit_mode - ) - return autocommit_mode == ddbc_sql_const.SQL_AUTOCOMMIT_ON.value + return self._conn.get_autocommit() @autocommit.setter def autocommit(self, value: bool) -> None: @@ -299,17 +211,7 @@ def autocommit(self, value: bool) -> None: Raises: DatabaseError: If there is an error while setting the autocommit mode. """ - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc, # Connection handle - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ( - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value - if value - else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value - ), # Value - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - self._autocommit = value + self.setautocommit(value) if ENABLE_LOGGING: logger.info("Autocommit mode set to %s.", value) @@ -323,7 +225,8 @@ def setautocommit(self, value: bool = True) -> None: Raises: DatabaseError: If there is an error while setting the autocommit mode. """ - self.autocommit = value + self._conn.set_autocommit(value) + self._autocommit = value def cursor(self) -> Cursor: """ @@ -340,9 +243,9 @@ def cursor(self) -> Cursor: DatabaseError: If there is an error while creating the cursor. InterfaceError: If there is an error related to the database interface. """ - if self._is_closed(): - # Cannot create a cursor if the connection is closed - raise Exception("Connection is closed. Cannot create cursor.") + # if self._is_closed(): + # # Cannot create a cursor if the connection is closed + # raise Exception("Connection is closed. Cannot create cursor.") return Cursor(self) def commit(self) -> None: @@ -362,12 +265,7 @@ def commit(self) -> None: raise Exception("Connection is closed. Cannot commit.") # Commit the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_COMMIT.value, # Commit the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) + self._conn.commit() if ENABLE_LOGGING: logger.info("Transaction committed successfully.") @@ -387,12 +285,7 @@ def rollback(self) -> None: raise Exception("Connection is closed. Cannot roll back.") # Roll back the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_ROLLBACK.value, # Roll back the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) + self._conn.rollback() if ENABLE_LOGGING: logger.info("Transaction rolled back successfully.") @@ -412,13 +305,7 @@ def close(self) -> None: if self._is_closed(): # Connection is already closed return - # Disconnect from the database - ret = ddbc_bindings.DDBCSQLDisconnect(self.hdbc) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - - # Set the reference to None to trigger destructor - self.hdbc.free() - self.hdbc = None + self._conn.close() if ENABLE_LOGGING: logger.info("Connection closed successfully.") diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c038ea7e..53ac36e8 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -48,8 +48,6 @@ def __init__(self, connection) -> None: Args: connection: Database connection object. """ - if connection.hdbc is None: - raise Exception("Connection is closed. Cannot create a cursor.") self.connection = connection # self.connection.autocommit = False self.hstmt = None @@ -417,19 +415,14 @@ def _allocate_statement_handle(self): """ Allocate the DDBC statement handle. """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_STMT.value, - self.connection.hdbc - ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, handle, ret) - self.hstmt = handle + self.connection._conn.alloc_statement_handle() + print(f"Statement handle: {self.hstmt}") def _reset_cursor(self) -> None: """ Reset the DDBC statement handle. """ if self.hstmt: - self.hstmt.free() # Free the existing statement handle self.hstmt = None if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") @@ -447,7 +440,6 @@ def close(self) -> None: raise Exception("Cursor is already closed.") if self.hstmt: - self.hstmt.free() self.hstmt = None if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 33dec23e..2bbed27e 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -8,16 +8,23 @@ #include #include +<<<<<<< HEAD #include - //------------------------------------------------------------------------------------------------- // Implements the Connection class declared in connection.h. // This class wraps low-level ODBC operations like connect/disconnect, // transaction control, and autocommit configuration. //------------------------------------------------------------------------------------------------- +<<<<<<< HEAD Connection::Connection(const std::wstring& conn_str, bool autocommit) : _conn_str(conn_str) , _autocommit(autocommit) {} +======= +Connection::Connection(const std::wstring& conn_str, bool autocommit) : _conn_str(conn_str) , _autocommit(autocommit) {} +======= +Connection::Connection(const std::wstring& conn_str) : _conn_str(conn_str) {} +>>>>>>> fcd64d4 (working flow with c++ connection class) +>>>>>>> 29a7a65 (working flow with c++ connection class) Connection::~Connection() { close(); // Ensure the connection is closed when the object is destroyed. @@ -118,13 +125,6 @@ bool Connection::getAutocommit() const { return value == SQL_AUTOCOMMIT_ON; } -<<<<<<< HEAD -SqlHandlePtr Connection::allocStatementHandle() { - if (!_dbc_handle) { - throw std::runtime_error("Connection handle not allocated"); - } - LOG("Allocating statement handle"); -======= SQLRETURN set_attribute(SQLINTEGER Attribute, py::object ValuePtr) { LOG("Set SQL Connection Attribute"); if (!SQLSetConnectAttr_ptr) { @@ -165,8 +165,8 @@ SQLRETURN set_attribute(SQLINTEGER Attribute, py::object ValuePtr) { return ret; } -SqlHandlePtr Connection::alloc_statement_handle() { ->>>>>>> 52a2dba (initial edit) +SqlHandlePtr Connection::allocStatementHandle() { + LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbc_handle->get(), &stmt); if (!SQL_SUCCEEDED(ret)) { diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b26fd6a6..b9aa0ea3 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -21,6 +21,7 @@ #include #include // Add this line for datetime support #include +#include "connection/connection.h" namespace py = pybind11; using namespace pybind11::literals; @@ -671,6 +672,46 @@ void SqlHandle::free() { } } +// Wrap SQLSetConnectAttr +SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, + py::object ValuePtr) { + LOG("Set SQL Connection Attribute"); + if (!SQLSetConnectAttr_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver + } + + // Print the type of ValuePtr and attribute value - helpful for debugging + LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); + + SQLPOINTER value = 0; + SQLINTEGER length = 0; + + if (py::isinstance(ValuePtr)) { + // Handle integer values + int intValue = ValuePtr.cast(); + value = reinterpret_cast(intValue); + length = SQL_IS_INTEGER; // Integer values don't require a length + } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { + // Handle byte or bytearray values (like access tokens) + // Store in static buffer to ensure memory remains valid during connection + static std::vector bytesBuffers; + bytesBuffers.push_back(ValuePtr.cast()); + value = const_cast(bytesBuffers.back().c_str()); + length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) + } else { + LOG("Unsupported ValuePtr type"); + return SQL_ERROR; + } + + SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set Connection attribute"); + } + LOG("Set Connection attribute successfully"); + return ret; +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -1936,20 +1977,20 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "SqlHandle") .def("free", &SqlHandle::free); py::class_(m, "Connection") - .def(py::init(), py::arg("conn_str")) + .def(py::init(), py::arg("conn_str")) .def("connect", &Connection::connect, "Establish a connection to the database") .def("close", &Connection::close, "Close the connection") .def("commit", [](Connection& self) { self.end_transaction(SQL_COMMIT); }) .def("rollback", [](Connection& self) { - self.end_transaction(SQL_ROLLBACK)}) + self.end_transaction(SQL_ROLLBACK);}) .def("set_autocommit", &Connection::set_autocommit) .def("get_autocommit", &Connection::get_autocommit) - .def("set_attribute", &Connection::set_attribute); .def("alloc_statement_handle", &Connection::alloc_statement_handle); m.def("DDBCSQLSetConnectAttr", &SQLSetConnectAttr_wrap, "Set an attribute that governs aspects of connections"); + m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, @@ -1965,7 +2006,6 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); - m.def("DDBCSQLEndTran", &SQLEndTran_wrap, "End a transaction"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 81801379..69086967 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -106,11 +106,11 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr; extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; -// -- Logging utility -- +// Logging utility template void LOG(const std::string& formatString, Args&&... args); -// -- Exception helper -- +// Throws a std::runtime_error with the given message void ThrowStdException(const std::string& message); //------------------------------------------------------------------------------------------------- From 65d19a8af4ced99c4d3ab1eb9fd3a72e04541ba3 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 00:38:50 +0530 Subject: [PATCH 07/18] working interation with access token --- mssql_python/connection.py | 115 +----------------- mssql_python/cursor.py | 3 +- mssql_python/pybind/connection/connection.cpp | 55 +++++++-- mssql_python/pybind/connection/connection.h | 7 +- mssql_python/pybind/ddbc_bindings.cpp | 98 +++++++-------- mssql_python/pybind/ddbc_bindings.h | 10 ++ 6 files changed, 107 insertions(+), 181 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f4fc39b5..524a970d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -56,16 +56,9 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef self.connection_str = self._construct_connection_string( connection_str, **kwargs ) - # self._attrs_before = attrs_before - # self._autocommit = autocommit - # if self._attrs_before != {}: - # self._apply_attrs_before() # Apply pre-connection attributes - # if self._autocommit: - # self.setautocommit(autocommit) - print("Connection string:", self.connection_str) - self._conn = ddbc_bindings.Connection(self.connection_str) - self._conn.connect() - print("Connection established") + self._attrs_before = attrs_before or {} + self._conn = ddbc_bindings.Connection(self.connection_str, autocommit) + self._conn.connect(self._attrs_before) self._autocommit = autocommit self.setautocommit(autocommit) @@ -107,90 +100,6 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st return conn_str - # def _apply_attrs_before(self): - # """ - # Apply specific pre-connection attributes. - # Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN) - # if present in `self._attrs_before`. Other attributes are ignored. - - # Returns: - # bool: True. - # """ - - # if ENABLE_LOGGING: - # logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before) - - # if not isinstance(self._attrs_before, dict): - # if self._attrs_before is not None and ENABLE_LOGGING: - # logger.warning( - # f"_attrs_before is of type {type(self._attrs_before).__name__}, " - # f"expected dict. Skipping attribute application." - # ) - # elif self._attrs_before is None and ENABLE_LOGGING: - # logger.debug("_attrs_before is None. No pre-connection attributes to apply.") - # return True # Exit if _attrs_before is not a dictionary or is None - - # for key, value in self._attrs_before.items(): - # ikey = None - # if isinstance(key, int): - # ikey = key - # elif isinstance(key, str) and key.isdigit(): - # try: - # ikey = int(key) - # except ValueError: - # if ENABLE_LOGGING: - # logger.debug( - # f"Skipping attribute with key '{key}' in attrs_before: " - # f"could not convert string to int." - # ) - # continue # Skip if string key is not a valid integer - # else: - # if ENABLE_LOGGING: - # logger.debug( - # f"Skipping attribute with key '{key}' in attrs_before due to " - # f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int." - # ) - # continue # Skip keys that are not int or string representation of an int - - # if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value: - # if ENABLE_LOGGING: - # logger.info( - # f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it." - # ) - # self._conn.set_attribute(ikey, value) - # if ENABLE_LOGGING: - # logger.info( - # f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed." - # ) - # # If you expect only one such key, you could add 'break' here. - # else: - # if ENABLE_LOGGING: - # logger.debug( - # f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before " - # f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})." - # ) - # return True - - # def _set_connection_attributes(self, ikey: int, ivalue: any) -> None: - # """ - # Set the connection attributes before connecting. - - # Args: - # ikey (int): The attribute key to set. - # ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode. - # vallen (int): The length of the value. - - # Raises: - # DatabaseError: If there is an error while setting the connection attribute. - # """ - - # ret = ddbc_bindings.DDBCSQLSetConnectAttr( - # self.hdbc, # Connection handle - # ikey, # Attribute - # ivalue, # Value - # ) - # check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - @property def autocommit(self) -> bool: """ @@ -208,8 +117,6 @@ def autocommit(self, value: bool) -> None: value (bool): True to enable autocommit, False to disable it. Returns: None - Raises: - DatabaseError: If there is an error while setting the autocommit mode. """ self.setautocommit(value) if ENABLE_LOGGING: @@ -243,9 +150,6 @@ def cursor(self) -> Cursor: DatabaseError: If there is an error while creating the cursor. InterfaceError: If there is an error related to the database interface. """ - # if self._is_closed(): - # # Cannot create a cursor if the connection is closed - # raise Exception("Connection is closed. Cannot create cursor.") return Cursor(self) def commit(self) -> None: @@ -260,10 +164,6 @@ def commit(self) -> None: Raises: DatabaseError: If there is an error while committing the transaction. """ - if self._is_closed(): - # Cannot commit if the connection is closed - raise Exception("Connection is closed. Cannot commit.") - # Commit the current transaction self._conn.commit() if ENABLE_LOGGING: @@ -280,10 +180,6 @@ def rollback(self) -> None: Raises: DatabaseError: If there is an error while rolling back the transaction. """ - if self._is_closed(): - # Cannot roll back if the connection is closed - raise Exception("Connection is closed. Cannot roll back.") - # Roll back the current transaction self._conn.rollback() if ENABLE_LOGGING: @@ -302,10 +198,7 @@ def close(self) -> None: Raises: DatabaseError: If there is an error while closing the connection. """ - if self._is_closed(): - # Connection is already closed - return + # Close the connection self._conn.close() - if ENABLE_LOGGING: logger.info("Connection closed successfully.") diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 53ac36e8..3939a012 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -415,8 +415,7 @@ def _allocate_statement_handle(self): """ Allocate the DDBC statement handle. """ - self.connection._conn.alloc_statement_handle() - print(f"Statement handle: {self.hstmt}") + self.hstmt = self.connection._conn.alloc_statement_handle() def _reset_cursor(self) -> None: """ diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 2bbed27e..a11a7660 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -7,8 +7,6 @@ #include "connection.h" #include #include - -<<<<<<< HEAD #include //------------------------------------------------------------------------------------------------- @@ -16,15 +14,8 @@ // This class wraps low-level ODBC operations like connect/disconnect, // transaction control, and autocommit configuration. //------------------------------------------------------------------------------------------------- -<<<<<<< HEAD Connection::Connection(const std::wstring& conn_str, bool autocommit) : _conn_str(conn_str) , _autocommit(autocommit) {} -======= -Connection::Connection(const std::wstring& conn_str, bool autocommit) : _conn_str(conn_str) , _autocommit(autocommit) {} -======= -Connection::Connection(const std::wstring& conn_str) : _conn_str(conn_str) {} ->>>>>>> fcd64d4 (working flow with c++ connection class) ->>>>>>> 29a7a65 (working flow with c++ connection class) Connection::~Connection() { close(); // Ensure the connection is closed when the object is destroyed. @@ -175,6 +166,52 @@ SqlHandlePtr Connection::allocStatementHandle() { return std::make_shared(SQL_HANDLE_STMT, stmt); } +SQLRETURN Connection::set_attribute(SQLINTEGER attribute, py::object value) { + LOG("Setting SQL attribute {}"); + + SQLPOINTER ptr = nullptr; + SQLINTEGER length = 0; + + if (py::isinstance(value)) { + int intValue = value.cast(); + ptr = reinterpret_cast(static_cast(intValue)); + length = SQL_IS_INTEGER; + } else if (py::isinstance(value) || py::isinstance(value)) { + static std::vector buffers; + buffers.emplace_back(value.cast()); + ptr = const_cast(buffers.back().c_str()); + length = static_cast(buffers.back().size()); + } else { + LOG("Unsupported attribute value type"); + return SQL_ERROR; + } + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), attribute, ptr, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set attribute {}"); + } + return ret; +} + +void Connection::apply_attrs_before(const py::dict& attrs) { + for (const auto& item : attrs) { + int key; + try { + key = py::cast(item.first); + } catch (...) { + continue; + } + + //do not hard code the key values + if (key == 1256) { + SQLRETURN ret = set_attribute(key, py::reinterpret_borrow(item.second)); + if (!SQL_SUCCEEDED(ret)) { + throw std::runtime_error("Failed to set access token before connect"); + } + } + } +} + SqlHandlePtr Connection::getSharedEnvHandle() { static std::once_flag flag; static SqlHandlePtr env_handle; diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index cc30c7c6..bd195042 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -18,8 +18,7 @@ class Connection { Connection(const std::wstring& conn_str, bool autocommit = false); ~Connection(); - // Establish the connection using the stored connection string. - SQLRETURN connect(); + SQLRETURN connect(const py::dict& attrs_before = py::dict()); // Close the connection and free resources. SQLRETURN close(); @@ -46,8 +45,10 @@ class Connection { std::wstring _conn_str; SqlHandlePtr _dbc_handle; bool _autocommit = false; + std::shared_ptr _conn; - static SqlHandlePtr getSharedEnvHandle(); + SQLRETURN set_attribute(SQLINTEGER attribute, pybind11::object value); + void apply_attrs_before(const pybind11::dict& attrs); }; #endif // CONNECTION_H \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b9aa0ea3..26b14ceb 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3,8 +3,8 @@ // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in beta release - -#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions +#include "ddbc_bindings.h" +#include "connection/connection.h" #include #include // std::setw, std::setfill @@ -15,17 +15,6 @@ #include #pragma comment(lib, "shlwapi.lib") -#include "ddbc_bindings.h" -#include -#include -#include -#include // Add this line for datetime support -#include -#include "connection/connection.h" - -namespace py = pybind11; -using namespace pybind11::literals; - //------------------------------------------------------------------------------------------------- // Macro definitions //------------------------------------------------------------------------------------------------- @@ -673,44 +662,44 @@ void SqlHandle::free() { } // Wrap SQLSetConnectAttr -SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, - py::object ValuePtr) { - LOG("Set SQL Connection Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // Print the type of ValuePtr and attribute value - helpful for debugging - LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); - - SQLPOINTER value = 0; - SQLINTEGER length = 0; - - if (py::isinstance(ValuePtr)) { - // Handle integer values - int intValue = ValuePtr.cast(); - value = reinterpret_cast(intValue); - length = SQL_IS_INTEGER; // Integer values don't require a length - } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // Handle byte or bytearray values (like access tokens) - // Store in static buffer to ensure memory remains valid during connection - static std::vector bytesBuffers; - bytesBuffers.push_back(ValuePtr.cast()); - value = const_cast(bytesBuffers.back().c_str()); - length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) - } else { - LOG("Unsupported ValuePtr type"); - return SQL_ERROR; - } - - SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Connection attribute"); - } - LOG("Set Connection attribute successfully"); - return ret; -} +// SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, +// py::object ValuePtr) { +// LOG("Set SQL Connection Attribute"); +// if (!SQLSetConnectAttr_ptr) { +// LOG("Function pointer not initialized. Loading the driver."); +// DriverLoader::getInstance().loadDriver(); // Load the driver +// } + +// // Print the type of ValuePtr and attribute value - helpful for debugging +// LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); + +// SQLPOINTER value = 0; +// SQLINTEGER length = 0; + +// if (py::isinstance(ValuePtr)) { +// // Handle integer values +// int intValue = ValuePtr.cast(); +// value = reinterpret_cast(intValue); +// length = SQL_IS_INTEGER; // Integer values don't require a length +// } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { +// // Handle byte or bytearray values (like access tokens) +// // Store in static buffer to ensure memory remains valid during connection +// static std::vector bytesBuffers; +// bytesBuffers.push_back(ValuePtr.cast()); +// value = const_cast(bytesBuffers.back().c_str()); +// length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) +// } else { +// LOG("Unsupported ValuePtr type"); +// return SQL_ERROR; +// } + +// SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); +// if (!SQL_SUCCEEDED(ret)) { +// LOG("Failed to set Connection attribute"); +// } +// LOG("Set Connection attribute successfully"); +// return ret; +// } // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { @@ -1977,8 +1966,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "SqlHandle") .def("free", &SqlHandle::free); py::class_(m, "Connection") - .def(py::init(), py::arg("conn_str")) - .def("connect", &Connection::connect, "Establish a connection to the database") + .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) + .def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database") .def("close", &Connection::close, "Close the connection") .def("commit", [](Connection& self) { self.end_transaction(SQL_COMMIT); @@ -1988,9 +1977,6 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("set_autocommit", &Connection::set_autocommit) .def("get_autocommit", &Connection::get_autocommit) .def("alloc_statement_handle", &Connection::alloc_statement_handle); - m.def("DDBCSQLSetConnectAttr", &SQLSetConnectAttr_wrap, - "Set an attribute that governs aspects of connections"); - m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 69086967..a093ec5d 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -6,12 +6,22 @@ #pragma once +#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions + #include #include #include #include #include +#include +#include +#include +#include // Add this line for datetime support +#include +namespace py = pybind11; +using namespace pybind11::literals; + //------------------------------------------------------------------------------------------------- // Function pointer typedefs //------------------------------------------------------------------------------------------------- From be946d01cc90dc13b09b490f219e8755ee4c1b1e Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 17:27:04 +0530 Subject: [PATCH 08/18] cleanup free() in sqlhandle --- mssql_python/pybind/ddbc_bindings.cpp | 23 ++++++++++------------- mssql_python/pybind/ddbc_bindings.h | 1 - 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 26b14ceb..1a9a6b68 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -633,17 +633,7 @@ SqlHandle::SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) // Native ODBC handles must be explicitly released by calling `free()` directly from Python. // This avoids nondeterministic crashes during GC or shutdown during pytest. // Read the documentation for more details (https://aka.ms/CPPvsPythonGC) -SqlHandle::~SqlHandle() {} - -SQLHANDLE SqlHandle::get() const { - return _handle; -} - -SQLSMALLINT SqlHandle::type() const { - return _type; -} - -void SqlHandle::free() { +SqlHandle::~SqlHandle() { if (_handle && SQLFreeHandle_ptr) { const char* type_str = nullptr; switch (_type) { @@ -661,6 +651,14 @@ void SqlHandle::free() { } } +SQLHANDLE SqlHandle::get() const { + return _handle; +} + +SQLSMALLINT SqlHandle::type() const { + return _type; +} + // Wrap SQLSetConnectAttr // SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, // py::object ValuePtr) { @@ -1963,8 +1961,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def_readwrite("sqlState", &ErrorInfo::sqlState) .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); - py::class_(m, "SqlHandle") - .def("free", &SqlHandle::free); + py::class_(m, "SqlHandle"); py::class_(m, "Connection") .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) .def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database") diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index a093ec5d..9b5461a3 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -160,7 +160,6 @@ class SqlHandle { ~SqlHandle(); SQLHANDLE get() const; SQLSMALLINT type() const; - void free(); private: SQLSMALLINT _type; SQLHANDLE _handle; From aad1e2250ba7deb4f16a055984429fc61beaab61 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 18:22:58 +0530 Subject: [PATCH 09/18] removed unnecessary prints --- mssql_python/pybind/connection/connection.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index a11a7660..b40a0046 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -95,7 +95,7 @@ SQLRETURN Connection::setAutocommit(bool enable) { throw std::runtime_error("Connection handle not allocated"); } SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; - LOG("Set SQL Connection Attribute"); + LOG("Set SQL Connection Attribute - Autocommit"); SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0); if (!SQL_SUCCEEDED(ret)) { throw std::runtime_error("Failed to set autocommit mode."); @@ -167,7 +167,7 @@ SqlHandlePtr Connection::allocStatementHandle() { } SQLRETURN Connection::set_attribute(SQLINTEGER attribute, py::object value) { - LOG("Setting SQL attribute {}"); + LOG("Setting SQL attribute"); SQLPOINTER ptr = nullptr; SQLINTEGER length = 0; @@ -188,7 +188,10 @@ SQLRETURN Connection::set_attribute(SQLINTEGER attribute, py::object value) { SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), attribute, ptr, length); if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set attribute {}"); + LOG("Failed to set attribute"); + } + else { + LOG("Set attribute successfully"); } return ret; } From 179ebf2fc9497c066becda340f76112bb6c71943 Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 18:32:04 +0530 Subject: [PATCH 10/18] removing comment --- mssql_python/pybind/ddbc_bindings.cpp | 44 --------------------------- 1 file changed, 44 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1a9a6b68..08d52dd2 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -629,10 +629,6 @@ void DriverLoader::loadDriver() { SqlHandle::SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) : _type(type), _handle(rawHandle) {} -// Note: Destructor is intentionally a no-op. Python owns the lifecycle. -// Native ODBC handles must be explicitly released by calling `free()` directly from Python. -// This avoids nondeterministic crashes during GC or shutdown during pytest. -// Read the documentation for more details (https://aka.ms/CPPvsPythonGC) SqlHandle::~SqlHandle() { if (_handle && SQLFreeHandle_ptr) { const char* type_str = nullptr; @@ -659,46 +655,6 @@ SQLSMALLINT SqlHandle::type() const { return _type; } -// Wrap SQLSetConnectAttr -// SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, -// py::object ValuePtr) { -// LOG("Set SQL Connection Attribute"); -// if (!SQLSetConnectAttr_ptr) { -// LOG("Function pointer not initialized. Loading the driver."); -// DriverLoader::getInstance().loadDriver(); // Load the driver -// } - -// // Print the type of ValuePtr and attribute value - helpful for debugging -// LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); - -// SQLPOINTER value = 0; -// SQLINTEGER length = 0; - -// if (py::isinstance(ValuePtr)) { -// // Handle integer values -// int intValue = ValuePtr.cast(); -// value = reinterpret_cast(intValue); -// length = SQL_IS_INTEGER; // Integer values don't require a length -// } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { -// // Handle byte or bytearray values (like access tokens) -// // Store in static buffer to ensure memory remains valid during connection -// static std::vector bytesBuffers; -// bytesBuffers.push_back(ValuePtr.cast()); -// value = const_cast(bytesBuffers.back().c_str()); -// length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) -// } else { -// LOG("Unsupported ValuePtr type"); -// return SQL_ERROR; -// } - -// SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); -// if (!SQL_SUCCEEDED(ret)) { -// LOG("Failed to set Connection attribute"); -// } -// LOG("Set Connection attribute successfully"); -// return ret; -// } - // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); From 70f6aa3e02910a8516c94b021538d5089695f73d Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Tue, 20 May 2025 19:27:52 +0530 Subject: [PATCH 11/18] resolving conflict --- mssql_python/pybind/connection/connection.cpp | 43 +------------------ mssql_python/pybind/connection/connection.h | 1 + 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index b40a0046..d5eae9f5 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -116,46 +116,6 @@ bool Connection::getAutocommit() const { return value == SQL_AUTOCOMMIT_ON; } -SQLRETURN set_attribute(SQLINTEGER Attribute, py::object ValuePtr) { - LOG("Set SQL Connection Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // Print the type of ValuePtr and attribute value - helpful for debugging - LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); - - SQLPOINTER value = 0; - SQLINTEGER length = 0; - - if (py::isinstance(ValuePtr)) { - // Handle integer values - int intValue = ValuePtr.cast(); - value = reinterpret_cast(intValue); - length = SQL_IS_INTEGER; // Integer values don't require a length - } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // Handle byte or bytearray values (like access tokens) - // Store in static buffer to ensure memory remains valid during connection - static std::vector bytesBuffers; - bytesBuffers.push_back(ValuePtr.cast()); - value = const_cast(bytesBuffers.back().c_str()); - length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) - } else { - LOG("Unsupported ValuePtr type"); - return SQL_ERROR; - } - - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), Attribute, value, length); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Connection attribute"); - } - else { - LOG("Set Connection attribute successfully"); - } - return ret; -} - SqlHandlePtr Connection::allocStatementHandle() { LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; @@ -205,8 +165,7 @@ void Connection::apply_attrs_before(const py::dict& attrs) { continue; } - //do not hard code the key values - if (key == 1256) { + if (key == SQL_COPT_SS_ACCESS_TOKEN) { SQLRETURN ret = set_attribute(key, py::reinterpret_borrow(item.second)); if (!SQL_SUCCEEDED(ret)) { throw std::runtime_error("Failed to set access token before connect"); diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index bd195042..f58c6e86 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -18,6 +18,7 @@ class Connection { Connection(const std::wstring& conn_str, bool autocommit = false); ~Connection(); + // Establish the connection using the stored connection string. SQLRETURN connect(const py::dict& attrs_before = py::dict()); // Close the connection and free resources. From b22b197fb7b37fccfe45d29512bf209acfed57fd Mon Sep 17 00:00:00 2001 From: Saumya Garg Date: Wed, 21 May 2025 17:54:47 +0530 Subject: [PATCH 12/18] working --- mssql_python/pybind/connection/connection.cpp | 26 +++++++++++++------ mssql_python/pybind/connection/connection.h | 8 +++--- mssql_python/pybind/ddbc_bindings.cpp | 13 ++++------ 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index d5eae9f5..5843b60b 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -9,6 +9,8 @@ #include #include +#define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token + //------------------------------------------------------------------------------------------------- // Implements the Connection class declared in connection.h. // This class wraps low-level ODBC operations like connect/disconnect, @@ -21,8 +23,16 @@ Connection::~Connection() { close(); // Ensure the connection is closed when the object is destroyed. } -SQLRETURN Connection::connect() { +SQLRETURN Connection::connect(const py::dict& attrs_before) { allocDbcHandle(); + // Apply access token before connect + if (!attrs_before.is_none() && py::len(attrs_before) > 0) { + LOG("Apply attributes before connect"); + applyAttrsBefore(attrs_before); + if (_autocommit) { + setAutocommit(_autocommit); + } + } return connectToDb(); } @@ -91,11 +101,8 @@ SQLRETURN Connection::rollback() { } SQLRETURN Connection::setAutocommit(bool enable) { - if (!_dbc_handle) { - throw std::runtime_error("Connection handle not allocated"); - } SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; - LOG("Set SQL Connection Attribute - Autocommit"); + LOG("Set SQL Connection Attribute - Autocommit"); SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0); if (!SQL_SUCCEEDED(ret)) { throw std::runtime_error("Failed to set autocommit mode."); @@ -117,6 +124,9 @@ bool Connection::getAutocommit() const { } SqlHandlePtr Connection::allocStatementHandle() { + if (!_dbc_handle) { + throw std::runtime_error("Connection handle not allocated"); + } LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbc_handle->get(), &stmt); @@ -126,7 +136,7 @@ SqlHandlePtr Connection::allocStatementHandle() { return std::make_shared(SQL_HANDLE_STMT, stmt); } -SQLRETURN Connection::set_attribute(SQLINTEGER attribute, py::object value) { +SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { LOG("Setting SQL attribute"); SQLPOINTER ptr = nullptr; @@ -156,7 +166,7 @@ SQLRETURN Connection::set_attribute(SQLINTEGER attribute, py::object value) { return ret; } -void Connection::apply_attrs_before(const py::dict& attrs) { +void Connection::applyAttrsBefore(const py::dict& attrs) { for (const auto& item : attrs) { int key; try { @@ -166,7 +176,7 @@ void Connection::apply_attrs_before(const py::dict& attrs) { } if (key == SQL_COPT_SS_ACCESS_TOKEN) { - SQLRETURN ret = set_attribute(key, py::reinterpret_borrow(item.second)); + SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); if (!SQL_SUCCEEDED(ret)) { throw std::runtime_error("Failed to set access token before connect"); } diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index f58c6e86..8a77640f 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -46,10 +46,10 @@ class Connection { std::wstring _conn_str; SqlHandlePtr _dbc_handle; bool _autocommit = false; - std::shared_ptr _conn; - - SQLRETURN set_attribute(SQLINTEGER attribute, pybind11::object value); - void apply_attrs_before(const pybind11::dict& attrs); + + static SqlHandlePtr getSharedEnvHandle(); + SQLRETURN setAttribute(SQLINTEGER attribute, pybind11::object value); + void applyAttrsBefore(const pybind11::dict& attrs); }; #endif // CONNECTION_H \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 08d52dd2..15e1f03c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1922,14 +1922,11 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) .def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database") .def("close", &Connection::close, "Close the connection") - .def("commit", [](Connection& self) { - self.end_transaction(SQL_COMMIT); - }) - .def("rollback", [](Connection& self) { - self.end_transaction(SQL_ROLLBACK);}) - .def("set_autocommit", &Connection::set_autocommit) - .def("get_autocommit", &Connection::get_autocommit) - .def("alloc_statement_handle", &Connection::alloc_statement_handle); + .def("commit", &Connection::commit, "Commit the current transaction") + .def("rollback", &Connection::rollback, "Rollback the current transaction") + .def("set_autocommit", &Connection::setAutocommit) + .def("get_autocommit", &Connection::getAutocommit) + .def("alloc_statement_handle", &Connection::allocStatementHandle); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, From e302c0c973c1b210c3a479fb88353d05bf66759b Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Thu, 22 May 2025 00:24:22 +0530 Subject: [PATCH 13/18] final working-fix test --- mssql_python/cursor.py | 3 ++- mssql_python/pybind/connection/connection.cpp | 10 +++----- mssql_python/pybind/ddbc_bindings.cpp | 25 ++++++++++++------- mssql_python/pybind/ddbc_bindings.h | 1 + tests/test_005_exceptions.py | 8 +++--- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 3939a012..7ef3ebac 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -422,6 +422,7 @@ def _reset_cursor(self) -> None: Reset the DDBC statement handle. """ if self.hstmt: + self.hstmt.free() self.hstmt = None if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") @@ -439,6 +440,7 @@ def close(self) -> None: raise Exception("Cursor is already closed.") if self.hstmt: + self.hstmt.free() self.hstmt = None if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") @@ -548,7 +550,6 @@ def execute( reset_cursor: Whether to reset the cursor before execution. """ self._check_closed() # Check if the cursor is closed - if reset_cursor: self._reset_cursor() diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 5843b60b..e5f91350 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -5,7 +5,6 @@ // taken up in future #include "connection.h" -#include #include #include @@ -19,9 +18,7 @@ Connection::Connection(const std::wstring& conn_str, bool autocommit) : _conn_str(conn_str) , _autocommit(autocommit) {} -Connection::~Connection() { - close(); // Ensure the connection is closed when the object is destroyed. -} +Connection::~Connection() {} SQLRETURN Connection::connect(const py::dict& attrs_before) { allocDbcHandle(); @@ -54,7 +51,7 @@ SQLRETURN Connection::connectToDb() { (SQLWCHAR*)_conn_str.c_str(), SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to connect to database"); + ThrowStdException("Client unable to establish connection"); } LOG("Connected to database successfully"); return ret; @@ -72,7 +69,7 @@ SQLRETURN Connection::close() { } SQLRETURN ret = SQLDisconnect_ptr(_dbc_handle->get()); - _dbc_handle.reset(); + _dbc_handle->free(); return ret; } @@ -119,7 +116,6 @@ bool Connection::getAutocommit() const { SQLINTEGER value; SQLINTEGER string_length; SQLGetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length); - return value == SQL_AUTOCOMMIT_ON; } diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 15e1f03c..416d8dd5 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -630,6 +630,20 @@ SqlHandle::SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) : _type(type), _handle(rawHandle) {} SqlHandle::~SqlHandle() { + if (_handle) { + free(); + } +} + +SQLHANDLE SqlHandle::get() const { + return _handle; +} + +SQLSMALLINT SqlHandle::type() const { + return _type; +} + +void SqlHandle::free() { if (_handle && SQLFreeHandle_ptr) { const char* type_str = nullptr; switch (_type) { @@ -647,14 +661,6 @@ SqlHandle::~SqlHandle() { } } -SQLHANDLE SqlHandle::get() const { - return _handle; -} - -SQLSMALLINT SqlHandle::type() const { - return _type; -} - // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -1917,7 +1923,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def_readwrite("sqlState", &ErrorInfo::sqlState) .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); - py::class_(m, "SqlHandle"); + py::class_(m, "SqlHandle") + .def("free", &SqlHandle::free, "Free the handle"); py::class_(m, "Connection") .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) .def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database") diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 9b5461a3..a093ec5d 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -160,6 +160,7 @@ class SqlHandle { ~SqlHandle(); SQLHANDLE get() const; SQLSMALLINT type() const; + void free(); private: SQLSMALLINT _type; SQLHANDLE _handle; diff --git a/tests/test_005_exceptions.py b/tests/test_005_exceptions.py index 030c4f16..9406a14d 100644 --- a/tests/test_005_exceptions.py +++ b/tests/test_005_exceptions.py @@ -124,7 +124,7 @@ def test_foreign_key_constraint_error(cursor, db_connection): drop_table_if_exists(cursor, "pytest_parent_table") db_connection.commit() -def test_connection_error(db_connection): - with pytest.raises(OperationalError) as excinfo: - Connection("InvalidConnectionString") - assert "Client unable to establish connection" in str(excinfo.value) +# def test_connection_error(db_connection): +# with pytest.raises(OperationalError) as excinfo: +# Connection("InvalidConnectionString") +# assert "Client unable to establish connection" in str(excinfo.value) From 6a077c6caa469f2a34891166561a066ae5774556 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Thu, 29 May 2025 12:33:12 +0530 Subject: [PATCH 14/18] minor updates --- mssql_python/pybind/connection/connection.cpp | 151 +++++++----------- mssql_python/pybind/connection/connection.h | 31 ++-- mssql_python/pybind/ddbc_bindings.cpp | 4 +- mssql_python/pybind/ddbc_bindings.h | 3 + 4 files changed, 81 insertions(+), 108 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index aa577c60..0c92e114 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -7,148 +7,121 @@ #include "connection.h" #include +SqlHandlePtr Connection::_envHandle = nullptr; //------------------------------------------------------------------------------------------------- // Implements the Connection class declared in connection.h. // This class wraps low-level ODBC operations like connect/disconnect, // transaction control, and autocommit configuration. //------------------------------------------------------------------------------------------------- Connection::Connection(const std::wstring& conn_str, bool autocommit) - : _conn_str(conn_str) , _autocommit(autocommit) {} + : _connStr(conn_str) , _autocommit(autocommit) { + if (!_envHandle) { + LOG("Allocating environment handle"); + SQLHANDLE env = nullptr; + if (!SQLAllocHandle_ptr) { + LOG("Function pointers not initialized, loading driver"); + DriverLoader::getInstance().loadDriver(); + } + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + checkError(ret, "Failed to allocate environment handle"); + _envHandle = std::make_shared(SQL_HANDLE_ENV, env); -Connection::~Connection() { - close(); // Ensure the connection is closed when the object is destroyed. + LOG("Setting environment attributes"); + ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); + checkError(ret, "Failed to set environment attribute"); + } + allocateDbcHandle(); } -SQLRETURN Connection::connect() { - allocDbcHandle(); - return connectToDb(); +Connection::~Connection() { + disconnect(); // fallback if app forgets to disconnect } -// Allocates DBC handle -void Connection::allocDbcHandle() { +// Allocates connection handle +void Connection::allocateDbcHandle() { SQLHANDLE dbc = nullptr; LOG("Allocate SQL Connection Handle"); - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, getSharedEnvHandle()->get(), &dbc); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to allocate connection handle"); - } - _dbc_handle = std::make_shared(SQL_HANDLE_DBC, dbc); + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc); + checkError(ret, "Failed to allocate connection handle"); + _dbcHandle = std::make_shared(SQL_HANDLE_DBC, dbc); } -// Connects to the database -SQLRETURN Connection::connectToDb() { +void Connection::connect() { LOG("Connecting to database"); - SQLRETURN ret = SQLDriverConnect_ptr(_dbc_handle->get(), nullptr, - (SQLWCHAR*)_conn_str.c_str(), SQL_NTS, - nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to connect to database"); - } - LOG("Connected to database successfully"); - return ret; + SQLRETURN ret = SQLDriverConnect_ptr( + _dbcHandle->get(), nullptr, + (SQLWCHAR*)_connStr.c_str(), SQL_NTS, + nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); + checkError(ret, "SQLDriverConnect failed"); + setAutocommit(_autocommit); } -SQLRETURN Connection::close() { - if (!_dbc_handle) { - LOG("No connection handle to close"); - return SQL_SUCCESS; +void Connection::disconnect() { + if (_dbcHandle) { + LOG("Disconnecting from database"); + SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); + checkError(ret, "Failed to disconnect from database"); + _dbcHandle.reset(); // triggers SQLFreeHandle via destructor, if last owner } - LOG("Disconnect from MSSQL"); - if (!SQLDisconnect_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); + else { + LOG("No connection handle to disconnect"); } +} - SQLRETURN ret = SQLDisconnect_ptr(_dbc_handle->get()); - _dbc_handle.reset(); - return ret; +void Connection::checkError(SQLRETURN ret, const std::string& msg) const{ + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + throw std::runtime_error("[ODBC Error] " + msg); + } } -SQLRETURN Connection::commit() { - if (!_dbc_handle) { +void Connection::commit() { + if (!_dbcHandle) { throw std::runtime_error("Connection handle not allocated"); } LOG("Committing transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbc_handle->get(), SQL_COMMIT); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to commit transaction"); - } - return ret; + SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); + checkError(ret, "Failed to commit transaction"); } -SQLRETURN Connection::rollback() { - if (!_dbc_handle) { +void Connection::rollback() { + if (!_dbcHandle) { throw std::runtime_error("Connection handle not allocated"); } LOG("Rolling back transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbc_handle->get(), SQL_ROLLBACK); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to rollback transaction"); - } - return ret; + SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); + checkError(ret, "Failed to rollback transaction"); } -SQLRETURN Connection::setAutocommit(bool enable) { - if (!_dbc_handle) { +void Connection::setAutocommit(bool enable) { + if (!_dbcHandle) { throw std::runtime_error("Connection handle not allocated"); } SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; LOG("Set SQL Connection Attribute"); - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to set autocommit mode."); - } + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0); + checkError(ret, "Failed to set autocommit attribute"); _autocommit = enable; - return ret; } bool Connection::getAutocommit() const { - if (!_dbc_handle) { + if (!_dbcHandle) { throw std::runtime_error("Connection handle not allocated"); } LOG("Get SQL Connection Attribute"); SQLINTEGER value; SQLINTEGER string_length; - SQLGetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length); - + SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length); + checkError(ret, "Failed to get autocommit attribute"); return value == SQL_AUTOCOMMIT_ON; } SqlHandlePtr Connection::allocStatementHandle() { - if (!_dbc_handle) { + if (!_dbcHandle) { throw std::runtime_error("Connection handle not allocated"); } LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbc_handle->get(), &stmt); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to allocate statement handle"); - } + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); + checkError(ret, "Failed to allocate statement handle"); return std::make_shared(SQL_HANDLE_STMT, stmt); } - -SqlHandlePtr Connection::getSharedEnvHandle() { - static std::once_flag flag; - static SqlHandlePtr env_handle; - - std::call_once(flag, []() { - LOG("Allocating environment handle"); - SQLHANDLE env = nullptr; - if (!SQLAllocHandle_ptr) { - LOG("Function pointers not initialized, loading driver"); - DriverLoader::getInstance().loadDriver(); - } - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to allocate environment handle"); - } - env_handle = std::make_shared(SQL_HANDLE_ENV, env); - - LOG("Setting environment attributes"); - ret = SQLSetEnvAttr_ptr(env_handle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to set environment attribute"); - } - }); - return env_handle; -} \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index cc30c7c6..6071ee12 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -4,9 +4,7 @@ // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in future. -#ifndef CONNECTION_H -#define CONNECTION_H - +#pragma once #include "ddbc_bindings.h" // Represents a single ODBC database connection. @@ -19,19 +17,19 @@ class Connection { ~Connection(); // Establish the connection using the stored connection string. - SQLRETURN connect(); + void connect(); - // Close the connection and free resources. - SQLRETURN close(); + // Disconnect and free the connection handle. + void disconnect(); // Commit the current transaction. - SQLRETURN commit(); + void commit(); // Rollback the current transaction. - SQLRETURN rollback(); + void rollback(); // Enable or disable autocommit mode. - SQLRETURN setAutocommit(bool value); + void setAutocommit(bool value); // Check whether autocommit is enabled. bool getAutocommit() const; @@ -40,14 +38,13 @@ class Connection { SqlHandlePtr allocStatementHandle(); private: - void allocDbcHandle(); - SQLRETURN connectToDb(); + void allocateDbcHandle(); + void checkError(SQLRETURN ret, const std::string& msg) const; - std::wstring _conn_str; - SqlHandlePtr _dbc_handle; - bool _autocommit = false; + std::wstring _connStr; + bool _usePool = false; + bool _autocommit = true; + SqlHandlePtr _dbcHandle; - static SqlHandlePtr getSharedEnvHandle(); + static SqlHandlePtr _envHandle; }; - -#endif // CONNECTION_H \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 493b6269..c49754a9 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -629,10 +629,10 @@ DriverLoader& DriverLoader::getInstance() { } void DriverLoader::loadDriver() { - if (!m_driverLoaded) { + std::call_once(m_onceFlag, [this]() { LoadDriverOrThrowException(); m_driverLoaded = true; - } + }); } // SqlHandle definition diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 81801379..cb97b508 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -11,6 +11,7 @@ #include #include #include +#include //------------------------------------------------------------------------------------------------- // Function pointer typedefs @@ -135,7 +136,9 @@ class DriverLoader { DriverLoader(); DriverLoader(const DriverLoader&) = delete; DriverLoader& operator=(const DriverLoader&) = delete; + bool m_driverLoaded; + std::once_flag m_onceFlag; }; //------------------------------------------------------------------------------------------------- From 5181a0e7d963b7682bac09cc69ff86ada09f8e69 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Thu, 29 May 2025 16:00:41 +0530 Subject: [PATCH 15/18] updating file --- mssql_python/pybind/connection/connection.cpp | 103 ++++++++++++------ mssql_python/pybind/connection/connection.h | 6 +- mssql_python/pybind/ddbc_bindings.cpp | 10 +- mssql_python/pybind/ddbc_bindings.h | 8 ++ tests/test_005_exceptions.py | 8 +- 5 files changed, 85 insertions(+), 50 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index f5e8c749..e10d5777 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -26,18 +26,18 @@ Connection::Connection(const std::wstring& conn_str, bool autocommit) DriverLoader::getInstance().loadDriver(); } SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); - checkError(ret, "Failed to allocate environment handle"); + checkError(ret); _envHandle = std::make_shared(SQL_HANDLE_ENV, env); LOG("Setting environment attributes"); ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); - checkError(ret, "Failed to set environment attribute"); + checkError(ret); } allocateDbcHandle(); } Connection::~Connection() { - disconnect(); // fallback if app forgets to disconnect + disconnect(); // fallback if user forgets to disconnect } // Allocates connection handle @@ -45,25 +45,32 @@ void Connection::allocateDbcHandle() { SQLHANDLE dbc = nullptr; LOG("Allocate SQL Connection Handle"); SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc); - checkError(ret, "Failed to allocate connection handle"); + checkError(ret); _dbcHandle = std::make_shared(SQL_HANDLE_DBC, dbc); } -void Connection::connect() { +SQLRETURN Connection::connect(const py::dict& attrs_before) { LOG("Connecting to database"); + // Apply access token before connect + if (!attrs_before.is_none() && py::len(attrs_before) > 0) { + LOG("Apply attributes before connect"); + applyAttrsBefore(attrs_before); + if (_autocommit) { + setAutocommit(_autocommit); + } + } SQLRETURN ret = SQLDriverConnect_ptr( _dbcHandle->get(), nullptr, (SQLWCHAR*)_connStr.c_str(), SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); - checkError(ret, "SQLDriverConnect failed"); - setAutocommit(_autocommit); + checkError(ret); } void Connection::disconnect() { if (_dbcHandle) { LOG("Disconnecting from database"); SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); - checkError(ret, "Failed to disconnect from database"); + checkError(ret); _dbcHandle.reset(); // triggers SQLFreeHandle via destructor, if last owner } else { @@ -71,9 +78,11 @@ void Connection::disconnect() { } } -void Connection::checkError(SQLRETURN ret, const std::string& msg) const{ - if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { - throw std::runtime_error("[ODBC Error] " + msg); +void Connection::checkError(SQLRETURN ret) const{ + if (!SQL_SUCCEEDED(ret)) { + ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret); + std::string errorMsg = std::string(err.ddbcErrorMsg.begin(), err.ddbcErrorMsg.end()); + ThrowStdException(errorMsg); } } @@ -83,7 +92,7 @@ void Connection::commit() { } LOG("Committing transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); - checkError(ret, "Failed to commit transaction"); + checkError(ret); } void Connection::rollback() { @@ -92,7 +101,7 @@ void Connection::rollback() { } LOG("Rolling back transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); - checkError(ret, "Failed to rollback transaction"); + checkError(ret); } void Connection::setAutocommit(bool enable) { @@ -102,7 +111,7 @@ void Connection::setAutocommit(bool enable) { SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; LOG("Set SQL Connection Attribute"); SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0); - checkError(ret, "Failed to set autocommit attribute"); + checkError(ret); _autocommit = enable; } @@ -114,7 +123,7 @@ bool Connection::getAutocommit() const { SQLINTEGER value; SQLINTEGER string_length; SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length); - checkError(ret, "Failed to get autocommit attribute"); + checkError(ret); return value == SQL_AUTOCOMMIT_ON; } @@ -125,32 +134,54 @@ SqlHandlePtr Connection::allocStatementHandle() { LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); - checkError(ret, "Failed to allocate statement handle"); + checkError(ret); return std::make_shared(SQL_HANDLE_STMT, stmt); } -SqlHandlePtr Connection::getSharedEnvHandle() { - static std::once_flag flag; - static SqlHandlePtr env_handle; - std::call_once(flag, []() { - LOG("Allocating environment handle"); - SQLHANDLE env = nullptr; - if (!SQLAllocHandle_ptr) { - LOG("Function pointers not initialized, loading driver"); - DriverLoader::getInstance().loadDriver(); - } - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to allocate environment handle"); +SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { + LOG("Setting SQL attribute"); + SQLPOINTER ptr = nullptr; + SQLINTEGER length = 0; + + if (py::isinstance(value)) { + int intValue = value.cast(); + ptr = reinterpret_cast(static_cast(intValue)); + length = SQL_IS_INTEGER; + } else if (py::isinstance(value) || py::isinstance(value)) { + static std::vector buffers; + buffers.emplace_back(value.cast()); + ptr = const_cast(buffers.back().c_str()); + length = static_cast(buffers.back().size()); + } else { + LOG("Unsupported attribute value type"); + return SQL_ERROR; + } + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set attribute"); + } + else { + LOG("Set attribute successfully"); + } + return ret; +} + +void Connection::applyAttrsBefore(const py::dict& attrs) { + for (const auto& item : attrs) { + int key; + try { + key = py::cast(item.first); + } catch (...) { + continue; } - env_handle = std::make_shared(SQL_HANDLE_ENV, env); - LOG("Setting environment attributes"); - ret = SQLSetEnvAttr_ptr(env_handle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); - if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to set environment attribute"); + if (key == SQL_COPT_SS_ACCESS_TOKEN) { + SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); + if (!SQL_SUCCEEDED(ret)) { + throw std::runtime_error("Failed to set access token before connect"); + } } - }); - return env_handle; + } } \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 6071ee12..206ded8b 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -17,7 +17,7 @@ class Connection { ~Connection(); // Establish the connection using the stored connection string. - void connect(); + SQLRETURN connect(const py::dict& attrs_before = py::dict()); // Disconnect and free the connection handle. void disconnect(); @@ -39,7 +39,9 @@ class Connection { private: void allocateDbcHandle(); - void checkError(SQLRETURN ret, const std::string& msg) const; + void checkError(SQLRETURN ret) const; + SQLRETURN setAttribute(SQLINTEGER attribute, py::object value); + void applyAttrsBefore(const py::dict& attrs_before); std::wstring _connStr; bool _usePool = false; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 2580c994..b6d6b7e7 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -95,12 +95,6 @@ struct ColumnBuffers { indicators(numCols, std::vector(fetchSize)) {} }; -// This struct is used to relay error info obtained from SQLDiagRec API to the Python module -struct ErrorInfo { - std::wstring sqlState; - std::wstring ddbcErrorMsg; -}; - //------------------------------------------------------------------------------------------------- // Function pointer initialization //------------------------------------------------------------------------------------------------- @@ -1927,8 +1921,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("free", &SqlHandle::free, "Free the handle"); py::class_(m, "Connection") .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) - .def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database") - .def("close", &Connection::close, "Close the connection") + .def("connect", &Connection::connect) + .def("close", &Connection::disconnect, "Close the connection") .def("commit", &Connection::commit, "Commit the current transaction") .def("rollback", &Connection::rollback, "Rollback the current transaction") .def("set_autocommit", &Connection::setAutocommit) diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 8a393d43..3d3925aa 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -121,6 +121,7 @@ extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; template void LOG(const std::string& formatString, Args&&... args); + // Throws a std::runtime_error with the given message void ThrowStdException(const std::string& message); @@ -169,3 +170,10 @@ class SqlHandle { SQLHANDLE _handle; }; using SqlHandlePtr = std::shared_ptr; + +// This struct is used to relay error info obtained from SQLDiagRec API to the Python module +struct ErrorInfo { + std::wstring sqlState; + std::wstring ddbcErrorMsg; +}; +ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode); diff --git a/tests/test_005_exceptions.py b/tests/test_005_exceptions.py index 9406a14d..5e35d553 100644 --- a/tests/test_005_exceptions.py +++ b/tests/test_005_exceptions.py @@ -124,7 +124,7 @@ def test_foreign_key_constraint_error(cursor, db_connection): drop_table_if_exists(cursor, "pytest_parent_table") db_connection.commit() -# def test_connection_error(db_connection): -# with pytest.raises(OperationalError) as excinfo: -# Connection("InvalidConnectionString") -# assert "Client unable to establish connection" in str(excinfo.value) +def test_connection_error(db_connection): + with pytest.raises(RuntimeError) as excinfo: + Connection("InvalidConnectionString") + assert "Neither DSN nor SERVER keyword supplied" in str(excinfo.value) From 3240394ced22ba503f9c5c11791c03d224920bf8 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Thu, 29 May 2025 16:08:22 +0530 Subject: [PATCH 16/18] added a TODO comment to address review comments --- mssql_python/connection.py | 2 -- mssql_python/pybind/connection/connection.cpp | 13 +++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 524a970d..7c73a6df 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -59,7 +59,6 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef self._attrs_before = attrs_before or {} self._conn = ddbc_bindings.Connection(self.connection_str, autocommit) self._conn.connect(self._attrs_before) - self._autocommit = autocommit self.setautocommit(autocommit) def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str: @@ -133,7 +132,6 @@ def setautocommit(self, value: bool = True) -> None: DatabaseError: If there is an error while setting the autocommit mode. """ self._conn.set_autocommit(value) - self._autocommit = value def cursor(self) -> Cursor: """ diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index e10d5777..5b03618e 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -78,6 +78,7 @@ void Connection::disconnect() { } } +// TODO: Add an exception class in C++ for error handling, DB spec compliant void Connection::checkError(SQLRETURN ret) const{ if (!SQL_SUCCEEDED(ret)) { ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret); @@ -88,7 +89,7 @@ void Connection::checkError(SQLRETURN ret) const{ void Connection::commit() { if (!_dbcHandle) { - throw std::runtime_error("Connection handle not allocated"); + ThrowStdException("Connection handle not allocated"); } LOG("Committing transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); @@ -97,7 +98,7 @@ void Connection::commit() { void Connection::rollback() { if (!_dbcHandle) { - throw std::runtime_error("Connection handle not allocated"); + ThrowStdException("Connection handle not allocated"); } LOG("Rolling back transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); @@ -106,7 +107,7 @@ void Connection::rollback() { void Connection::setAutocommit(bool enable) { if (!_dbcHandle) { - throw std::runtime_error("Connection handle not allocated"); + ThrowStdException("Connection handle not allocated"); } SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; LOG("Set SQL Connection Attribute"); @@ -117,7 +118,7 @@ void Connection::setAutocommit(bool enable) { bool Connection::getAutocommit() const { if (!_dbcHandle) { - throw std::runtime_error("Connection handle not allocated"); + ThrowStdException("Connection handle not allocated"); } LOG("Get SQL Connection Attribute"); SQLINTEGER value; @@ -129,7 +130,7 @@ bool Connection::getAutocommit() const { SqlHandlePtr Connection::allocStatementHandle() { if (!_dbcHandle) { - throw std::runtime_error("Connection handle not allocated"); + ThrowStdException("Connection handle not allocated"); } LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; @@ -180,7 +181,7 @@ void Connection::applyAttrsBefore(const py::dict& attrs) { if (key == SQL_COPT_SS_ACCESS_TOKEN) { SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); if (!SQL_SUCCEEDED(ret)) { - throw std::runtime_error("Failed to set access token before connect"); + ThrowStdException("Failed to set access token before connect"); } } } From d373e9b3706117df441a1ca6ca85d00f0c0b5a42 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Fri, 30 May 2025 02:03:01 +0530 Subject: [PATCH 17/18] adding c++ support for pooling --- mssql_python/pybind/CMakeLists.txt | 2 +- mssql_python/pybind/connection/connection.cpp | 106 +++++++++++++++++- mssql_python/pybind/connection/connection.h | 33 +++++- .../pybind/connection/connection_pool.cpp | 85 ++++++++++++++ .../pybind/connection/connection_pool.h | 64 +++++++++++ mssql_python/pybind/ddbc_bindings.cpp | 23 ++-- 6 files changed, 296 insertions(+), 17 deletions(-) create mode 100644 mssql_python/pybind/connection/connection_pool.cpp create mode 100644 mssql_python/pybind/connection/connection_pool.h diff --git a/mssql_python/pybind/CMakeLists.txt b/mssql_python/pybind/CMakeLists.txt index dceb2efc..aea9a323 100644 --- a/mssql_python/pybind/CMakeLists.txt +++ b/mssql_python/pybind/CMakeLists.txt @@ -90,7 +90,7 @@ execute_process( ) # Add module library -add_library(ddbc_bindings MODULE ddbc_bindings.cpp connection/connection.cpp) +add_library(ddbc_bindings MODULE ddbc_bindings.cpp connection/connection.cpp connection/connection_pool.cpp) # Add include directories for your project target_include_directories(ddbc_bindings PRIVATE diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index e87e5daf..dba3a9b2 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -5,6 +5,7 @@ // taken up in future #include "connection.h" +#include "connection_pool.h" #include #include @@ -16,8 +17,8 @@ SqlHandlePtr Connection::_envHandle = nullptr; // This class wraps low-level ODBC operations like connect/disconnect, // transaction control, and autocommit configuration. //------------------------------------------------------------------------------------------------- -Connection::Connection(const std::wstring& conn_str, bool autocommit) - : _connStr(conn_str) , _autocommit(autocommit) { +Connection::Connection(const std::wstring& conn_str, bool use_pool) + : _connStr(conn_str), _autocommit(false), _fromPool(use_pool) { if (!_envHandle) { LOG("Allocating environment handle"); SQLHANDLE env = nullptr; @@ -64,6 +65,7 @@ void Connection::connect(const py::dict& attrs_before) { (SQLWCHAR*)_connStr.c_str(), SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); checkError(ret); + updateLastUsed(); } void Connection::disconnect() { @@ -91,6 +93,7 @@ void Connection::commit() { if (!_dbcHandle) { ThrowStdException("Connection handle not allocated"); } + updateLastUsed(); LOG("Committing transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); checkError(ret); @@ -100,6 +103,7 @@ void Connection::rollback() { if (!_dbcHandle) { ThrowStdException("Connection handle not allocated"); } + updateLastUsed(); LOG("Rolling back transaction"); SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); checkError(ret); @@ -132,6 +136,7 @@ SqlHandlePtr Connection::allocStatementHandle() { if (!_dbcHandle) { ThrowStdException("Connection handle not allocated"); } + updateLastUsed(); LOG("Allocating statement handle"); SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); @@ -185,4 +190,99 @@ void Connection::applyAttrsBefore(const py::dict& attrs) { } } } -} \ No newline at end of file +} + +bool Connection::isAlive() const { + if (!_dbcHandle) { + ThrowStdException("Connection handle not allocated"); + } + SQLUINTEGER status; + SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_CONNECTION_DEAD, + &status, 0, nullptr); + return SQL_SUCCEEDED(ret) && status == SQL_CD_FALSE; +} + +bool Connection::reset() { + if (!_dbcHandle) { + ThrowStdException("Connection handle not allocated"); + } + LOG("Resetting connection via SQL_ATTR_RESET_CONNECTION"); + SQLULEN reset = SQL_TRUE; + SQLRETURN ret = SQLSetConnectAttr_ptr( + _dbcHandle->get(), + SQL_ATTR_RESET_CONNECTION, + (SQLPOINTER)SQL_RESET_CONNECTION_YES, + SQL_IS_INTEGER); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to reset connection. Marking as dead."); + disconnect(); + return false; + } + updateLastUsed(); + return true; +} + +void Connection::updateLastUsed() { + _lastUsed = std::chrono::steady_clock::now(); +} + +std::chrono::steady_clock::time_point Connection::lastUsed() const { + return _lastUsed; +} + +ConnectionHandle::ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore) + : _connStr(connStr), _usePool(usePool) { + if (_usePool) { + _conn = ConnectionPoolManager::getInstance().acquireConnection(connStr, attrsBefore); + } else { + _conn = std::make_shared(connStr, false); + _conn->connect(attrsBefore); + } +} + +void ConnectionHandle::close() { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + if (_usePool) { + ConnectionPoolManager::getInstance().returnConnection(_connStr, _conn); + } else { + _conn->disconnect(); + } + _conn = nullptr; +} + +void ConnectionHandle::commit() { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + _conn->commit(); +} + +void ConnectionHandle::rollback() { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + _conn->rollback(); +} + +void ConnectionHandle::setAutocommit(bool enabled) { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + _conn->setAutocommit(enabled); +} + +bool ConnectionHandle::getAutocommit() const { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + return _conn->getAutocommit(); +} + +SqlHandlePtr ConnectionHandle::allocStatementHandle() { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + return _conn->allocStatementHandle(); +} diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index afdeaecb..eb604fd6 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -13,7 +13,7 @@ class Connection { public: - Connection(const std::wstring& conn_str, bool autocommit = false); + Connection(const std::wstring& connStr, bool fromPool); ~Connection(); // Establish the connection using the stored connection string. @@ -34,8 +34,16 @@ class Connection { // Check whether autocommit is enabled. bool getAutocommit() const; + bool isAlive() const; + + bool reset(); + + void updateLastUsed(); + + std::chrono::steady_clock::time_point lastUsed() const; + // Allocate a new statement handle on this connection. - SqlHandlePtr allocStatementHandle(); + SqlHandlePtr allocStatementHandle(); private: void allocateDbcHandle(); @@ -44,9 +52,26 @@ class Connection { void applyAttrsBefore(const py::dict& attrs_before); std::wstring _connStr; - bool _usePool = false; + bool _fromPool = false; bool _autocommit = true; SqlHandlePtr _dbcHandle; - static SqlHandlePtr _envHandle; + std::chrono::steady_clock::time_point _lastUsed; }; + +class ConnectionHandle { +public: + ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore = py::dict()); + + void close(); + void commit(); + void rollback(); + void setAutocommit(bool enabled); + bool getAutocommit() const; + SqlHandlePtr allocStatementHandle(); + +private: + std::shared_ptr _conn; + bool _usePool; + std::wstring _connStr; +}; \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp new file mode 100644 index 00000000..ef006f96 --- /dev/null +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -0,0 +1,85 @@ +#include "connection_pool.h" +#include +#include + +ConnectionPool::ConnectionPool(size_t max_size, int idle_timeout_secs) + : _max_size(max_size), _idle_timeout_secs(idle_timeout_secs), _current_size(0) {} + +std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, const py::dict& attrs_before) { + std::lock_guard lock(_mutex); + auto now = std::chrono::steady_clock::now(); + size_t before = _pool.size(); + _pool.erase(std::remove_if(_pool.begin(), _pool.end(), [&](const std::shared_ptr& conn) { + auto idle_time = std::chrono::duration_cast(now - conn->lastUsed()).count(); + if (idle_time > _idle_timeout_secs) { + conn->disconnect(); + return true; + } + return false; + }), _pool.end()); + size_t pruned = before - _pool.size(); + _current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0; + + while (!_pool.empty()) { + auto conn = _pool.front(); + _pool.pop_front(); + if (conn->isAlive()) { + if (!conn->reset()) { + continue; + } + return conn; + } else { + conn->disconnect(); + --_current_size; + } + } + if (_current_size < _max_size) { + auto conn = std::make_shared(connStr, true); + conn->connect(attrs_before); + return conn; + } else { + LOG("Cannot acquire connection: pool size limit reached"); + return nullptr; + } +} + +void ConnectionPool::release(std::shared_ptr conn) { + std::lock_guard lock(_mutex); + if (_pool.size() < _max_size) { + conn->updateLastUsed(); + _pool.push_back(conn); + } + else { + conn->disconnect(); + if (_current_size > 0) --_current_size; + } +} + +ConnectionPoolManager& ConnectionPoolManager::getInstance() { + static ConnectionPoolManager manager; + return manager; +} + +std::shared_ptr ConnectionPoolManager::acquireConnection(const std::wstring& connStr, const py::dict& attrs_before) { + std::lock_guard lock(_manager_mutex); + + auto& pool = _pools[connStr]; + if (!pool) { + LOG("Creating new connection pool"); + pool = std::make_shared(_default_max_size, _default_idle_secs); + } + return pool->acquire(connStr, attrs_before); +} + +void ConnectionPoolManager::returnConnection(const std::wstring& conn_str, const std::shared_ptr conn) { + std::lock_guard lock(_manager_mutex); + if (_pools.find(conn_str) != _pools.end()) { + _pools[conn_str]->release((conn)); + } +} + +void ConnectionPoolManager::configure(int max_size, int idle_timeout_secs) { + std::lock_guard lock(_manager_mutex); + _default_max_size = max_size; + _default_idle_secs = idle_timeout_secs; +} \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection_pool.h b/mssql_python/pybind/connection/connection_pool.h new file mode 100644 index 00000000..f4a5dbfd --- /dev/null +++ b/mssql_python/pybind/connection/connection_pool.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be +// taken up in future. + +#pragma once +#include +#include +#include +#include +#include +#include +#include "connection.h" + +// Manages a fixed-size pool of reusable database connections for a single connection string +class ConnectionPool { +public: + ConnectionPool(size_t max_size, int idle_timeout_secs); + + // Acquires a connection from the pool or creates a new one if under limit + std::shared_ptr acquire(const std::wstring& connStr, const py::dict& attrs_before = py::dict()); + + // Returns a connection to the pool for reuse + void release(std::shared_ptr conn); + +private: + size_t _max_size; // Maximum number of connections allowed + int _idle_timeout_secs; // Idle time before connections are considered stale + size_t _current_size = 0; + std::deque> _pool; // Available connections + std::mutex _mutex; // Mutex for thread-safe access +}; + +// Singleton manager that handles multiple pools keyed by connection string +class ConnectionPoolManager { +public: + // Returns the singleton instance of the manager + static ConnectionPoolManager& getInstance(); + + void configure(int max_size, int idle_timeout); + + // Gets a connection from the appropriate pool (creates one if none exists) + std::shared_ptr acquireConnection(const std::wstring& conn_str, const py::dict& attrs_before = py::dict()); + + // Returns a connection to its original pool + void returnConnection(const std::wstring& conn_str, std::shared_ptr conn); + +private: + ConnectionPoolManager() = default; + ~ConnectionPoolManager() = default; + + // Map from connection string to connection pool + std::unordered_map> _pools; + + // Protects access to the _pools map + std::mutex _manager_mutex; + size_t _default_max_size = 10; + int _default_idle_secs = 300; + + // Prevent copying + ConnectionPoolManager(const ConnectionPoolManager&) = delete; + ConnectionPoolManager& operator=(const ConnectionPoolManager&) = delete; +}; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b6d6b7e7..65219a9f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -5,6 +5,7 @@ // taken up in beta release #include "ddbc_bindings.h" #include "connection/connection.h" +#include "connection/connection_pool.h" #include #include // std::setw, std::setfill @@ -1876,6 +1877,10 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) { return rowCount; } +void enable_pooling(int maxSize, int idleTimeout) { + ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout); +} + // Architecture-specific defines #ifndef ARCHITECTURE #define ARCHITECTURE "win64" // Default to win64 if not defined during compilation @@ -1919,15 +1924,15 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "SqlHandle") .def("free", &SqlHandle::free, "Free the handle"); - py::class_(m, "Connection") - .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) - .def("connect", &Connection::connect) - .def("close", &Connection::disconnect, "Close the connection") - .def("commit", &Connection::commit, "Commit the current transaction") - .def("rollback", &Connection::rollback, "Rollback the current transaction") - .def("set_autocommit", &Connection::setAutocommit) - .def("get_autocommit", &Connection::getAutocommit) - .def("alloc_statement_handle", &Connection::allocStatementHandle); + py::class_(m, "Connection") + .def(py::init(), py::arg("conn_str"), py::arg("use_pool"), py::arg("attrs_before") = py::dict()) + .def("close", &ConnectionHandle::close, "Close the connection") + .def("commit", &ConnectionHandle::commit, "Commit the current transaction") + .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") + .def("set_autocommit", &ConnectionHandle::setAutocommit) + .def("get_autocommit", &ConnectionHandle::getAutocommit) + .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle); + m.def("enable_pooling", &enable_pooling, "Enable global connection pooling"); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, From f576b08c950ac7177eef92491d91fbb349ff5f3e Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Sun, 1 Jun 2025 10:26:43 +0530 Subject: [PATCH 18/18] addressed review comments --- mssql_python/pybind/connection/connection.cpp | 45 ++++++---- mssql_python/pybind/connection/connection.h | 2 +- .../pybind/connection/connection_pool.cpp | 84 ++++++++++++------- 3 files changed, 83 insertions(+), 48 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index e380e447..58f35ae4 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -11,7 +11,28 @@ #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token -SqlHandlePtr Connection::_envHandle = nullptr; +static SqlHandlePtr getEnvHandle() { + static SqlHandlePtr envHandle = []() -> SqlHandlePtr { + LOG("Allocating ODBC environment handle"); + if (!SQLAllocHandle_ptr) { + LOG("Function pointers not initialized, loading driver"); + DriverLoader::getInstance().loadDriver(); + } + SQLHANDLE env = nullptr; + SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + if (!SQL_SUCCEEDED(ret)) { + ThrowStdException("Failed to allocate environment handle"); + } + ret = SQLSetEnvAttr_ptr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); + if (!SQL_SUCCEEDED(ret)) { + ThrowStdException("Failed to set environment attributes"); + } + return std::make_shared(SQL_HANDLE_ENV, env); + }(); + + return envHandle; +} + //------------------------------------------------------------------------------------------------- // Implements the Connection class declared in connection.h. // This class wraps low-level ODBC operations like connect/disconnect, @@ -19,21 +40,6 @@ SqlHandlePtr Connection::_envHandle = nullptr; //------------------------------------------------------------------------------------------------- Connection::Connection(const std::wstring& conn_str, bool use_pool) : _connStr(conn_str), _autocommit(false), _fromPool(use_pool) { - if (!_envHandle) { - LOG("Allocating environment handle"); - SQLHANDLE env = nullptr; - if (!SQLAllocHandle_ptr) { - LOG("Function pointers not initialized, loading driver"); - DriverLoader::getInstance().loadDriver(); - } - SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); - checkError(ret); - _envHandle = std::make_shared(SQL_HANDLE_ENV, env); - - LOG("Setting environment attributes"); - ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0); - checkError(ret); - } allocateDbcHandle(); } @@ -43,6 +49,7 @@ Connection::~Connection() { // Allocates connection handle void Connection::allocateDbcHandle() { + auto _envHandle = getEnvHandle(); SQLHANDLE dbc = nullptr; LOG("Allocate SQL Connection Handle"); SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc); @@ -240,6 +247,12 @@ ConnectionHandle::ConnectionHandle(const std::wstring& connStr, bool usePool, co } } +ConnectionHandle::~ConnectionHandle() { + if (_conn) { + close(); + } +} + void ConnectionHandle::close() { if (!_conn) { ThrowStdException("Connection object is not initialized"); diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 8e127978..b9cc50b6 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -52,13 +52,13 @@ class Connection { bool _fromPool = false; bool _autocommit = true; SqlHandlePtr _dbcHandle; - static SqlHandlePtr _envHandle; std::chrono::steady_clock::time_point _lastUsed; }; class ConnectionHandle { public: ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore = py::dict()); + ~ConnectionHandle(); void close(); void commit(); diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp index f5510512..dbe2765e 100644 --- a/mssql_python/pybind/connection/connection_pool.cpp +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -12,42 +12,64 @@ ConnectionPool::ConnectionPool(size_t max_size, int idle_timeout_secs) : _max_size(max_size), _idle_timeout_secs(idle_timeout_secs), _current_size(0) {} std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, const py::dict& attrs_before) { - std::lock_guard lock(_mutex); - auto now = std::chrono::steady_clock::now(); - size_t before = _pool.size(); - _pool.erase(std::remove_if(_pool.begin(), _pool.end(), [&](const std::shared_ptr& conn) { - auto idle_time = std::chrono::duration_cast(now - conn->lastUsed()).count(); - if (idle_time > _idle_timeout_secs) { - conn->disconnect(); - return true; - } - return false; - }), _pool.end()); - size_t pruned = before - _pool.size(); - _current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0; + std::vector> to_disconnect; + std::shared_ptr valid_conn = nullptr; + { + std::lock_guard lock(_mutex); + auto now = std::chrono::steady_clock::now(); + size_t before = _pool.size(); - while (!_pool.empty()) { - auto conn = _pool.front(); - _pool.pop_front(); - if (conn->isAlive()) { - if (!conn->reset()) { - continue; + // Phase 1: Remove stale connections, collect for later disconnect + _pool.erase(std::remove_if(_pool.begin(), _pool.end(), + [&](const std::shared_ptr& conn) { + auto idle_time = std::chrono::duration_cast(now - conn->lastUsed()).count(); + if (idle_time > _idle_timeout_secs) { + to_disconnect.push_back(conn); + return true; + } + return false; + }), _pool.end()); + + size_t pruned = before - _pool.size(); + _current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0; + + // Phase 2: Attempt to reuse healthy connections + while (!_pool.empty()) { + auto conn = _pool.front(); + _pool.pop_front(); + if (conn->isAlive()) { + if (!conn->reset()) { + to_disconnect.push_back(conn); + --_current_size; + continue; + } + valid_conn = conn; + break; + } else { + to_disconnect.push_back(conn); + --_current_size; } - return conn; - } else { - conn->disconnect(); - --_current_size; + } + + // Create new connection if none reusable + if (!valid_conn && _current_size < _max_size) { + valid_conn = std::make_shared(connStr, true); + valid_conn->connect(attrs_before); + ++_current_size; + } else if (!valid_conn) { + throw std::runtime_error("ConnectionPool::acquire: pool size limit reached"); } } - if (_current_size < _max_size) { - auto conn = std::make_shared(connStr, true); - conn->connect(attrs_before); - ++_current_size; - return conn; - } else { - LOG("Cannot acquire connection: pool size limit reached"); - return nullptr; + + // Phase 3: Disconnect expired/bad connections outside lock + for (auto& conn : to_disconnect) { + try { + conn->disconnect(); + } catch (const std::exception& ex) { + std::cout << "disconnect() failed: " << ex.what() << std::endl; + } } + return valid_conn; } void ConnectionPool::release(std::shared_ptr conn) {