diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 8031a26a..7c73a6df 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -53,14 +53,12 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef preparing it for further operations such as connecting to the database, executing queries, etc. """ - self.henv = None - self.hdbc = None self.connection_str = self._construct_connection_string( connection_str, **kwargs ) - self._attrs_before = attrs_before - self._autocommit = autocommit # Initialize _autocommit before calling _initializer - self._initializer() + self._attrs_before = attrs_before or {} + self._conn = ddbc_bindings.Connection(self.connection_str, autocommit) + self._conn.connect(self._attrs_before) self.setautocommit(autocommit) def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str: @@ -100,178 +98,7 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st logger.info("Final connection string: %s", conn_str) return conn_str - - def _is_closed(self) -> bool: - """ - Check if the connection is closed. - - Returns: - bool: True if the connection is closed, False otherwise. - """ - return self.hdbc is None - def _initializer(self) -> None: - """ - Initialize the environment and connection handles. - - This method is responsible for setting up the environment and connection - handles, allocating memory for them, and setting the necessary attributes. - It should be called before establishing a connection to the database. - """ - self._allocate_environment_handle() - self._set_environment_attributes() - self._allocate_connection_handle() - if self._attrs_before != {}: - self._apply_attrs_before() # Apply pre-connection attributes - if self._autocommit: - self._set_connection_attributes( - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value, - ) - self._connect_to_db() - - def _apply_attrs_before(self): - """ - Apply specific pre-connection attributes. - Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN) - if present in `self._attrs_before`. Other attributes are ignored. - - Returns: - bool: True. - """ - - if ENABLE_LOGGING: - logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before) - - if not isinstance(self._attrs_before, dict): - if self._attrs_before is not None and ENABLE_LOGGING: - logger.warning( - f"_attrs_before is of type {type(self._attrs_before).__name__}, " - f"expected dict. Skipping attribute application." - ) - elif self._attrs_before is None and ENABLE_LOGGING: - logger.debug("_attrs_before is None. No pre-connection attributes to apply.") - return True # Exit if _attrs_before is not a dictionary or is None - - for key, value in self._attrs_before.items(): - ikey = None - if isinstance(key, int): - ikey = key - elif isinstance(key, str) and key.isdigit(): - try: - ikey = int(key) - except ValueError: - if ENABLE_LOGGING: - logger.debug( - f"Skipping attribute with key '{key}' in attrs_before: " - f"could not convert string to int." - ) - continue # Skip if string key is not a valid integer - else: - if ENABLE_LOGGING: - logger.debug( - f"Skipping attribute with key '{key}' in attrs_before due to " - f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int." - ) - continue # Skip keys that are not int or string representation of an int - - if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value: - if ENABLE_LOGGING: - logger.info( - f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it." - ) - self._set_connection_attributes(ikey, value) - if ENABLE_LOGGING: - logger.info( - f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed." - ) - # If you expect only one such key, you could add 'break' here. - else: - if ENABLE_LOGGING: - logger.debug( - f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before " - f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})." - ) - return True - - def _allocate_environment_handle(self): - """ - Allocate the environment handle. - """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_ENV.value, # SQL environment handle type - None - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, handle, ret) - self.henv = handle - - def _set_environment_attributes(self): - """ - Set the environment attributes. - """ - ret = ddbc_bindings.DDBCSQLSetEnvAttr( - self.henv, # Use the wrapper class - ddbc_sql_const.SQL_ATTR_DDBC_VERSION.value, # Attribute - ddbc_sql_const.SQL_OV_DDBC3_80.value, # String Length - 0, # Null-terminated string - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, self.henv, ret) - - def _allocate_connection_handle(self): - """ - Allocate the connection handle. - """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_DBC.value, # SQL connection handle type - self.henv - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, handle, ret) - self.hdbc = handle - - def _set_connection_attributes(self, ikey: int, ivalue: any) -> None: - """ - Set the connection attributes before connecting. - - Args: - ikey (int): The attribute key to set. - ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode. - vallen (int): The length of the value. - - Raises: - DatabaseError: If there is an error while setting the connection attribute. - """ - - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc, # Connection handle - ikey, # Attribute - ivalue, # Value - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - - def _connect_to_db(self) -> None: - """ - Establish a connection to the database. - - This method is responsible for creating a connection to the specified database. - It does not take any arguments and does not return any value. The connection - details such as database name, user credentials, host, and port should be - configured within the class or passed during the class instantiation. - - Raises: - DatabaseError: If there is an error while trying to connect to the database. - InterfaceError: If there is an error related to the database interface. - """ - if ENABLE_LOGGING: - logger.info("Connecting to the database") - ret = ddbc_bindings.DDBCSQLDriverConnect( - self.hdbc, # Connection handle (wrapper) - 0, # Window handle - self.connection_str, # Connection string - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - if ENABLE_LOGGING: - logger.info("Connection established successfully.") - @property def autocommit(self) -> bool: """ @@ -279,14 +106,7 @@ def autocommit(self) -> bool: Returns: bool: True if autocommit is enabled, False otherwise. """ - autocommit_mode = ddbc_bindings.DDBCSQLGetConnectionAttr( - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ) - check_error( - ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, autocommit_mode - ) - return autocommit_mode == ddbc_sql_const.SQL_AUTOCOMMIT_ON.value + return self._conn.get_autocommit() @autocommit.setter def autocommit(self, value: bool) -> None: @@ -296,20 +116,8 @@ def autocommit(self, value: bool) -> None: value (bool): True to enable autocommit, False to disable it. Returns: None - Raises: - DatabaseError: If there is an error while setting the autocommit mode. """ - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc, # Connection handle - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ( - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value - if value - else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value - ), # Value - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - self._autocommit = value + self.setautocommit(value) if ENABLE_LOGGING: logger.info("Autocommit mode set to %s.", value) @@ -323,7 +131,7 @@ def setautocommit(self, value: bool = True) -> None: Raises: DatabaseError: If there is an error while setting the autocommit mode. """ - self.autocommit = value + self._conn.set_autocommit(value) def cursor(self) -> Cursor: """ @@ -340,9 +148,6 @@ def cursor(self) -> Cursor: DatabaseError: If there is an error while creating the cursor. InterfaceError: If there is an error related to the database interface. """ - if self._is_closed(): - # Cannot create a cursor if the connection is closed - raise Exception("Connection is closed. Cannot create cursor.") return Cursor(self) def commit(self) -> None: @@ -357,17 +162,8 @@ def commit(self) -> None: Raises: DatabaseError: If there is an error while committing the transaction. """ - if self._is_closed(): - # Cannot commit if the connection is closed - raise Exception("Connection is closed. Cannot commit.") - # Commit the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_COMMIT.value, # Commit the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) + self._conn.commit() if ENABLE_LOGGING: logger.info("Transaction committed successfully.") @@ -382,17 +178,8 @@ def rollback(self) -> None: Raises: DatabaseError: If there is an error while rolling back the transaction. """ - if self._is_closed(): - # Cannot roll back if the connection is closed - raise Exception("Connection is closed. Cannot roll back.") - # Roll back the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc, # Connection handle (wrapper) - ddbc_sql_const.SQL_ROLLBACK.value, # Roll back the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) + self._conn.rollback() if ENABLE_LOGGING: logger.info("Transaction rolled back successfully.") @@ -409,16 +196,7 @@ def close(self) -> None: Raises: DatabaseError: If there is an error while closing the connection. """ - if self._is_closed(): - # Connection is already closed - return - # Disconnect from the database - ret = ddbc_bindings.DDBCSQLDisconnect(self.hdbc) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret) - - # Set the reference to None to trigger destructor - self.hdbc.free() - self.hdbc = None - + # Close the connection + self._conn.close() if ENABLE_LOGGING: logger.info("Connection closed successfully.") diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c038ea7e..7ef3ebac 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -48,8 +48,6 @@ def __init__(self, connection) -> None: Args: connection: Database connection object. """ - if connection.hdbc is None: - raise Exception("Connection is closed. Cannot create a cursor.") self.connection = connection # self.connection.autocommit = False self.hstmt = None @@ -417,19 +415,14 @@ def _allocate_statement_handle(self): """ Allocate the DDBC statement handle. """ - ret, handle = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_STMT.value, - self.connection.hdbc - ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, handle, ret) - self.hstmt = handle + self.hstmt = self.connection._conn.alloc_statement_handle() def _reset_cursor(self) -> None: """ Reset the DDBC statement handle. """ if self.hstmt: - self.hstmt.free() # Free the existing statement handle + self.hstmt.free() self.hstmt = None if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") @@ -557,7 +550,6 @@ def execute( reset_cursor: Whether to reset the cursor before execution. """ self._check_closed() # Check if the cursor is closed - if reset_cursor: self._reset_cursor() diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 7c90a587..e87e5daf 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -5,7 +5,10 @@ // taken up in future #include "connection.h" -#include +#include +#include + +#define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token SqlHandlePtr Connection::_envHandle = nullptr; //------------------------------------------------------------------------------------------------- @@ -46,14 +49,21 @@ void Connection::allocateDbcHandle() { _dbcHandle = std::make_shared(SQL_HANDLE_DBC, dbc); } -void Connection::connect() { +void Connection::connect(const py::dict& attrs_before) { LOG("Connecting to database"); + // Apply access token before connect + if (!attrs_before.is_none() && py::len(attrs_before) > 0) { + LOG("Apply attributes before connect"); + applyAttrsBefore(attrs_before); + if (_autocommit) { + setAutocommit(_autocommit); + } + } SQLRETURN ret = SQLDriverConnect_ptr( _dbcHandle->get(), nullptr, (SQLWCHAR*)_connStr.c_str(), SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); checkError(ret); - setAutocommit(_autocommit); } void Connection::disconnect() { @@ -128,3 +138,51 @@ SqlHandlePtr Connection::allocStatementHandle() { checkError(ret); return std::make_shared(SQL_HANDLE_STMT, stmt); } + + +SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { + LOG("Setting SQL attribute"); + SQLPOINTER ptr = nullptr; + SQLINTEGER length = 0; + + if (py::isinstance(value)) { + int intValue = value.cast(); + ptr = reinterpret_cast(static_cast(intValue)); + length = SQL_IS_INTEGER; + } else if (py::isinstance(value) || py::isinstance(value)) { + static std::vector buffers; + buffers.emplace_back(value.cast()); + ptr = const_cast(buffers.back().c_str()); + length = static_cast(buffers.back().size()); + } else { + LOG("Unsupported attribute value type"); + return SQL_ERROR; + } + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set attribute"); + } + else { + LOG("Set attribute successfully"); + } + return ret; +} + +void Connection::applyAttrsBefore(const py::dict& attrs) { + for (const auto& item : attrs) { + int key; + try { + key = py::cast(item.first); + } catch (...) { + continue; + } + + if (key == SQL_COPT_SS_ACCESS_TOKEN) { + SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); + if (!SQL_SUCCEEDED(ret)) { + ThrowStdException("Failed to set access token before connect"); + } + } + } +} \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 9f9b4d0a..afdeaecb 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -17,7 +17,7 @@ class Connection { ~Connection(); // Establish the connection using the stored connection string. - void connect(); + void connect(const py::dict& attrs_before = py::dict()); // Disconnect and free the connection handle. void disconnect(); @@ -40,6 +40,8 @@ class Connection { private: void allocateDbcHandle(); void checkError(SQLRETURN ret) const; + SQLRETURN setAttribute(SQLINTEGER attribute, py::object value); + void applyAttrsBefore(const py::dict& attrs_before); std::wstring _connStr; bool _usePool = false; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 33804a94..b6d6b7e7 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3,8 +3,8 @@ // INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be // taken up in beta release - -#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions +#include "ddbc_bindings.h" +#include "connection/connection.h" #include #include // std::setw, std::setfill @@ -15,16 +15,6 @@ #include #pragma comment(lib, "shlwapi.lib") -#include "ddbc_bindings.h" -#include -#include -#include -#include // Add this line for datetime support -#include - -namespace py = pybind11; -using namespace pybind11::literals; - //------------------------------------------------------------------------------------------------- // Macro definitions //------------------------------------------------------------------------------------------------- @@ -633,11 +623,11 @@ void DriverLoader::loadDriver() { SqlHandle::SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle) : _type(type), _handle(rawHandle) {} -// Note: Destructor is intentionally a no-op. Python owns the lifecycle. -// Native ODBC handles must be explicitly released by calling `free()` directly from Python. -// This avoids nondeterministic crashes during GC or shutdown during pytest. -// Read the documentation for more details (https://aka.ms/CPPvsPythonGC) -SqlHandle::~SqlHandle() {} +SqlHandle::~SqlHandle() { + if (_handle) { + free(); + } +} SQLHANDLE SqlHandle::get() const { return _handle; @@ -665,134 +655,6 @@ void SqlHandle::free() { } } -// Wrap SQLAllocHandle -SQLRETURN SQLAllocHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr InputHandle, SqlHandlePtr& OutputHandle) { - LOG("Allocate SQL Handle"); - if (!SQLAllocHandle_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - SQLHANDLE rawOutputHandle = nullptr; - SQLRETURN ret = SQLAllocHandle_ptr(HandleType, InputHandle ? InputHandle->get() : nullptr, &rawOutputHandle); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to allocate handle"); - return ret; - } - OutputHandle = std::make_shared(HandleType, rawOutputHandle); - return ret; -} - -// Wrap SQLSetEnvAttr -SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL environment Attribute"); - if (!SQLSetEnvAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - SQLRETURN ret = SQLSetEnvAttr_ptr(EnvHandle->get(), Attribute, reinterpret_cast(ValuePtr), StringLength); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set environment attribute"); - } - return ret; -} - -// Wrap SQLSetConnectAttr -SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, - py::object ValuePtr) { - LOG("Set SQL Connection Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // Print the type of ValuePtr and attribute value - helpful for debugging - LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast(), Attribute); - - SQLPOINTER value = 0; - SQLINTEGER length = 0; - - if (py::isinstance(ValuePtr)) { - // Handle integer values - int intValue = ValuePtr.cast(); - value = reinterpret_cast(intValue); - length = SQL_IS_INTEGER; // Integer values don't require a length - // } else if (py::isinstance(ValuePtr)) { - // // Handle Unicode string values - // static std::wstring unicodeValueBuffer; - // unicodeValueBuffer = ValuePtr.cast(); - // value = const_cast(unicodeValueBuffer.c_str()); - // length = SQL_NTS; // Indicates null-terminated string - } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // Handle byte or bytearray values (like access tokens) - // Store in static buffer to ensure memory remains valid during connection - static std::vector bytesBuffers; - bytesBuffers.push_back(ValuePtr.cast()); - value = const_cast(bytesBuffers.back().c_str()); - length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token) - // } else if (py::isinstance(ValuePtr) || py::isinstance(ValuePtr)) { - // // Handle list or tuple values - // LOG("ValuePtr is a sequence (list or tuple)"); - // for (py::handle item : ValuePtr) { - // LOG("Processing item in sequence"); - // SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow(item)); - // if (!SQL_SUCCEEDED(ret)) { - // LOG("Failed to set attribute for item in sequence"); - // return ret; - // } - // } - } else { - LOG("Unsupported ValuePtr type"); - return SQL_ERROR; - } - - SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Connection attribute"); - } - LOG("Set Connection attribute successfully"); - return ret; -} - -// Wrap SQLSetStmtAttr -SQLRETURN SQLSetStmtAttr_wrap(SqlHandlePtr StatementHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL Statement Attribute"); - if (!SQLSetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - SQLRETURN ret = SQLSetStmtAttr_ptr(StatementHandle->get(), Attribute, reinterpret_cast(ValuePtr), StringLength); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set Statement attribute"); - } - return ret; -} - -// Wrap SQLGetConnectionAttrA -// Currently only supports retrieval of int-valued attributes -// TODO: add support to retrieve all types of attributes -SQLINTEGER SQLGetConnectionAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER attribute) { - LOG("Get SQL COnnection Attribute"); - if (!SQLGetConnectAttr_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - SQLINTEGER stringLength; - SQLINTEGER intValue; - - // Try to get the attribute as an integer - SQLGetConnectAttr_ptr(ConnectionHandle->get(), attribute, &intValue, - sizeof(SQLINTEGER), &stringLength); - return intValue; -} - // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -826,23 +688,6 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET return errorInfo; } -// Wrap SQLDriverConnect -SQLRETURN SQLDriverConnect_wrap(SqlHandlePtr ConnectionHandle, intptr_t WindowHandle, const std::wstring& ConnectionString) { - LOG("Driver Connect to MSSQL"); - if (!SQLDriverConnect_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - SQLRETURN ret = SQLDriverConnect_ptr(ConnectionHandle->get(), - reinterpret_cast(WindowHandle), - const_cast(ConnectionString.c_str()), SQL_NTS, nullptr, - 0, nullptr, SQL_DRIVER_NOPROMPT); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to connect to DB"); - } - return ret; -} - // Wrap SQLExecDirect SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); @@ -1998,17 +1843,6 @@ SQLRETURN SQLMoreResults_wrap(SqlHandlePtr StatementHandle) { return SQLMoreResults_ptr(StatementHandle->get()); } -// Wrap SQLEndTran -SQLRETURN SQLEndTran_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle, SQLSMALLINT CompletionType) { - LOG("End SQL Transaction"); - if (!SQLEndTran_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - return SQLEndTran_ptr(HandleType, Handle->get(), CompletionType); -} - // Wrap SQLFreeHandle SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { LOG("Free SQL handle"); @@ -2024,17 +1858,6 @@ SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, SqlHandlePtr Handle) { return ret; } -// Wrap SQLDisconnect -SQLRETURN SQLDisconnect_wrap(SqlHandlePtr ConnectionHandle) { - LOG("Disconnect from MSSQL"); - if (!SQLDisconnect_ptr) { - LOG("Function pointer not initialized. Loading the driver."); - DriverLoader::getInstance().loadDriver(); // Load the driver - } - - return SQLDisconnect_ptr(ConnectionHandle->get()); -} - // Wrap SQLRowCount SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) { LOG("Get number of row affected by last execute"); @@ -2095,23 +1918,16 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); py::class_(m, "SqlHandle") - .def("free", &SqlHandle::free); - - m.def("DDBCSQLAllocHandle", [](SQLSMALLINT HandleType, SqlHandlePtr InputHandle = nullptr) { - SqlHandlePtr OutputHandle; - SQLRETURN rc = SQLAllocHandle_wrap(HandleType, InputHandle, OutputHandle); - return py::make_tuple(rc, OutputHandle); - }, "Allocate an environment, connection, statement, or descriptor handle"); - m.def("DDBCSQLSetEnvAttr", &SQLSetEnvAttr_wrap, - "Set an attribute that governs aspects of environments"); - m.def("DDBCSQLSetConnectAttr", &SQLSetConnectAttr_wrap, - "Set an attribute that governs aspects of connections"); - m.def("DDBCSQLSetStmtAttr", &SQLSetStmtAttr_wrap, - "Set an attribute that governs aspects of statements"); - m.def("DDBCSQLGetConnectionAttr", &SQLGetConnectionAttr_wrap, - "Get an attribute that governs aspects of connections"); - m.def("DDBCSQLDriverConnect", &SQLDriverConnect_wrap, - "Connect to a data source with a connection string"); + .def("free", &SqlHandle::free, "Free the handle"); + py::class_(m, "Connection") + .def(py::init(), py::arg("conn_str"), py::arg("autocommit") = false) + .def("connect", &Connection::connect) + .def("close", &Connection::disconnect, "Close the connection") + .def("commit", &Connection::commit, "Commit the current transaction") + .def("rollback", &Connection::rollback, "Rollback the current transaction") + .def("set_autocommit", &Connection::setAutocommit) + .def("get_autocommit", &Connection::getAutocommit) + .def("alloc_statement_handle", &Connection::allocStatementHandle); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, @@ -2127,9 +1943,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); - m.def("DDBCSQLEndTran", &SQLEndTran_wrap, "End a transaction"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); - m.def("DDBCSQLDisconnect", &SQLDisconnect_wrap, "Disconnect from a data source"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); // Add a version attribute diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index cb85eaf5..3d3925aa 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -6,6 +6,8 @@ #pragma once +#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions + #include #include #include @@ -13,6 +15,14 @@ #include #include +#include +#include +#include +#include // Add this line for datetime support +#include +namespace py = pybind11; +using namespace pybind11::literals; + //------------------------------------------------------------------------------------------------- // Function pointer typedefs //------------------------------------------------------------------------------------------------- @@ -107,11 +117,12 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr; extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; -// -- Logging utility -- +// Logging utility template void LOG(const std::string& formatString, Args&&... args); -// -- Exception helper -- + +// Throws a std::runtime_error with the given message void ThrowStdException(const std::string& message); //------------------------------------------------------------------------------------------------- diff --git a/tests/test_005_exceptions.py b/tests/test_005_exceptions.py index 030c4f16..5e35d553 100644 --- a/tests/test_005_exceptions.py +++ b/tests/test_005_exceptions.py @@ -125,6 +125,6 @@ def test_foreign_key_constraint_error(cursor, db_connection): db_connection.commit() def test_connection_error(db_connection): - with pytest.raises(OperationalError) as excinfo: + with pytest.raises(RuntimeError) as excinfo: Connection("InvalidConnectionString") - assert "Client unable to establish connection" in str(excinfo.value) + assert "Neither DSN nor SERVER keyword supplied" in str(excinfo.value)