Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
)
self._attrs_before = attrs_before or {}
self._pooling = PoolingManager.is_enabled()
self._conn = ddbc_bindings.Connection(self.connection_str, autocommit, self._pooling)
self._conn.connect(self._attrs_before)
self._conn = ddbc_bindings.Connection(self.connection_str, self._pooling, self._attrs_before)
self.setautocommit(autocommit)

def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str:
Expand Down
2 changes: 1 addition & 1 deletion mssql_python/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ execute_process(
)

# Add module library
add_library(ddbc_bindings MODULE ddbc_bindings.cpp connection/connection.cpp)
add_library(ddbc_bindings MODULE ddbc_bindings.cpp connection/connection.cpp connection/connection_pool.cpp)

# Add include directories for your project
target_include_directories(ddbc_bindings PRIVATE
Expand Down
147 changes: 130 additions & 17 deletions mssql_python/pybind/connection/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,41 @@
// taken up in future

#include "connection.h"
#include "connection_pool.h"
#include <vector>
#include <pybind11/pybind11.h>

#define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token

SqlHandlePtr Connection::_envHandle = nullptr;
//-------------------------------------------------------------------------------------------------
// Implements the Connection class declared in connection.h.
// This class wraps low-level ODBC operations like connect/disconnect,
// transaction control, and autocommit configuration.
//-------------------------------------------------------------------------------------------------
Connection::Connection(const std::wstring& conn_str, bool autocommit, bool use_pooling)
: _connStr(conn_str) , _autocommit(autocommit), _usePool(use_pooling) {
if (!_envHandle) {
LOG("Allocating environment handle");
SQLHANDLE env = nullptr;
static SqlHandlePtr getEnvHandle() {
static SqlHandlePtr envHandle = []() -> SqlHandlePtr {
LOG("Allocating ODBC environment handle");
if (!SQLAllocHandle_ptr) {
LOG("Function pointers not initialized, loading driver");
DriverLoader::getInstance().loadDriver();
}
SQLHANDLE env = nullptr;
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env);
checkError(ret);
_envHandle = std::make_shared<SqlHandle>(SQL_HANDLE_ENV, env);
if (!SQL_SUCCEEDED(ret)) {
ThrowStdException("Failed to allocate environment handle");
}
ret = SQLSetEnvAttr_ptr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0);
if (!SQL_SUCCEEDED(ret)) {
ThrowStdException("Failed to set environment attributes");
}
return std::make_shared<SqlHandle>(SQL_HANDLE_ENV, env);
}();

LOG("Setting environment attributes");
ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0);
checkError(ret);
}
return envHandle;
}

//-------------------------------------------------------------------------------------------------
// Implements the Connection class declared in connection.h.
// This class wraps low-level ODBC operations like connect/disconnect,
// transaction control, and autocommit configuration.
//-------------------------------------------------------------------------------------------------
Connection::Connection(const std::wstring& conn_str, bool use_pool)
: _connStr(conn_str), _autocommit(false), _fromPool(use_pool) {
allocateDbcHandle();
}

Expand All @@ -42,6 +49,7 @@ Connection::~Connection() {

// Allocates connection handle
void Connection::allocateDbcHandle() {
auto _envHandle = getEnvHandle();
SQLHANDLE dbc = nullptr;
LOG("Allocate SQL Connection Handle");
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc);
Expand All @@ -64,6 +72,7 @@ void Connection::connect(const py::dict& attrs_before) {
(SQLWCHAR*)_connStr.c_str(), SQL_NTS,
nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT);
checkError(ret);
updateLastUsed();
}

void Connection::disconnect() {
Expand Down Expand Up @@ -91,6 +100,7 @@ void Connection::commit() {
if (!_dbcHandle) {
ThrowStdException("Connection handle not allocated");
}
updateLastUsed();
LOG("Committing transaction");
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT);
checkError(ret);
Expand All @@ -100,6 +110,7 @@ void Connection::rollback() {
if (!_dbcHandle) {
ThrowStdException("Connection handle not allocated");
}
updateLastUsed();
LOG("Rolling back transaction");
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK);
checkError(ret);
Expand Down Expand Up @@ -132,6 +143,7 @@ SqlHandlePtr Connection::allocStatementHandle() {
if (!_dbcHandle) {
ThrowStdException("Connection handle not allocated");
}
updateLastUsed();
LOG("Allocating statement handle");
SQLHANDLE stmt = nullptr;
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt);
Expand Down Expand Up @@ -185,4 +197,105 @@ void Connection::applyAttrsBefore(const py::dict& attrs) {
}
}
}
}

bool Connection::isAlive() const {
if (!_dbcHandle) {
ThrowStdException("Connection handle not allocated");
}
SQLUINTEGER status;
SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_CONNECTION_DEAD,
&status, 0, nullptr);
return SQL_SUCCEEDED(ret) && status == SQL_CD_FALSE;
}

bool Connection::reset() {
if (!_dbcHandle) {
ThrowStdException("Connection handle not allocated");
}
LOG("Resetting connection via SQL_ATTR_RESET_CONNECTION");
SQLULEN reset = SQL_TRUE;
SQLRETURN ret = SQLSetConnectAttr_ptr(
_dbcHandle->get(),
SQL_ATTR_RESET_CONNECTION,
(SQLPOINTER)SQL_RESET_CONNECTION_YES,
SQL_IS_INTEGER);
if (!SQL_SUCCEEDED(ret)) {
LOG("Failed to reset connection. Marking as dead.");
disconnect();
return false;
}
updateLastUsed();
return true;
}

void Connection::updateLastUsed() {
_lastUsed = std::chrono::steady_clock::now();
}

std::chrono::steady_clock::time_point Connection::lastUsed() const {
return _lastUsed;
}

ConnectionHandle::ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore)
: _connStr(connStr), _usePool(usePool) {
if (_usePool) {
_conn = ConnectionPoolManager::getInstance().acquireConnection(connStr, attrsBefore);
} else {
_conn = std::make_shared<Connection>(connStr, false);
_conn->connect(attrsBefore);
}
}

ConnectionHandle::~ConnectionHandle() {
if (_conn) {
close();
}
}

void ConnectionHandle::close() {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
if (_usePool) {
ConnectionPoolManager::getInstance().returnConnection(_connStr, _conn);
} else {
_conn->disconnect();
}
_conn = nullptr;
}

void ConnectionHandle::commit() {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
_conn->commit();
}

void ConnectionHandle::rollback() {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
_conn->rollback();
}

void ConnectionHandle::setAutocommit(bool enabled) {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
_conn->setAutocommit(enabled);
}

bool ConnectionHandle::getAutocommit() const {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
return _conn->getAutocommit();
}

SqlHandlePtr ConnectionHandle::allocStatementHandle() {
if (!_conn) {
ThrowStdException("Connection object is not initialized");
}
return _conn->allocStatementHandle();
}
30 changes: 26 additions & 4 deletions mssql_python/pybind/connection/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class Connection {
public:
Connection(const std::wstring& conn_str, bool autocommit = false, bool use_pooling = false);
Connection(const std::wstring& connStr, bool fromPool);

~Connection();

// Establish the connection using the stored connection string.
Expand All @@ -33,6 +34,10 @@ class Connection {

// Check whether autocommit is enabled.
bool getAutocommit() const;
bool isAlive() const;
bool reset();
void updateLastUsed();
std::chrono::steady_clock::time_point lastUsed() const;

// Allocate a new statement handle on this connection.
SqlHandlePtr allocStatementHandle();
Expand All @@ -44,9 +49,26 @@ class Connection {
void applyAttrsBefore(const py::dict& attrs_before);

std::wstring _connStr;
bool _usePool = false;
bool _fromPool = false;
bool _autocommit = true;
SqlHandlePtr _dbcHandle;

static SqlHandlePtr _envHandle;
std::chrono::steady_clock::time_point _lastUsed;
};

class ConnectionHandle {
public:
ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore = py::dict());
~ConnectionHandle();

void close();
void commit();
void rollback();
void setAutocommit(bool enabled);
bool getAutocommit() const;
SqlHandlePtr allocStatementHandle();

private:
std::shared_ptr<Connection> _conn;
bool _usePool;
std::wstring _connStr;
};
114 changes: 114 additions & 0 deletions mssql_python/pybind/connection/connection_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be
// taken up in future.

#include "connection_pool.h"
#include <iostream>
#include <exception>

ConnectionPool::ConnectionPool(size_t max_size, int idle_timeout_secs)
: _max_size(max_size), _idle_timeout_secs(idle_timeout_secs), _current_size(0) {}

std::shared_ptr<Connection> ConnectionPool::acquire(const std::wstring& connStr, const py::dict& attrs_before) {
std::vector<std::shared_ptr<Connection>> to_disconnect;
std::shared_ptr<Connection> valid_conn = nullptr;
{
std::lock_guard<std::mutex> lock(_mutex);
auto now = std::chrono::steady_clock::now();
size_t before = _pool.size();

// Phase 1: Remove stale connections, collect for later disconnect
_pool.erase(std::remove_if(_pool.begin(), _pool.end(),
[&](const std::shared_ptr<Connection>& conn) {
auto idle_time = std::chrono::duration_cast<std::chrono::seconds>(now - conn->lastUsed()).count();
if (idle_time > _idle_timeout_secs) {
to_disconnect.push_back(conn);
return true;
}
return false;
}), _pool.end());

size_t pruned = before - _pool.size();
_current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0;

// Phase 2: Attempt to reuse healthy connections
while (!_pool.empty()) {
auto conn = _pool.front();
_pool.pop_front();
if (conn->isAlive()) {
if (!conn->reset()) {
to_disconnect.push_back(conn);
--_current_size;
continue;
}
valid_conn = conn;
break;
} else {
to_disconnect.push_back(conn);
--_current_size;
}
}

// Create new connection if none reusable
if (!valid_conn && _current_size < _max_size) {
valid_conn = std::make_shared<Connection>(connStr, true);
valid_conn->connect(attrs_before);
++_current_size;
} else if (!valid_conn) {
throw std::runtime_error("ConnectionPool::acquire: pool size limit reached");
}
}

// Phase 3: Disconnect expired/bad connections outside lock
for (auto& conn : to_disconnect) {
try {
conn->disconnect();
} catch (const std::exception& ex) {
std::cout << "disconnect() failed: " << ex.what() << std::endl;
}
}
return valid_conn;
}

void ConnectionPool::release(std::shared_ptr<Connection> conn) {
std::lock_guard<std::mutex> lock(_mutex);
if (_pool.size() < _max_size) {
conn->updateLastUsed();
_pool.push_back(conn);
}
else {
conn->disconnect();
if (_current_size > 0) --_current_size;
}
}

ConnectionPoolManager& ConnectionPoolManager::getInstance() {
static ConnectionPoolManager manager;
return manager;
}

std::shared_ptr<Connection> ConnectionPoolManager::acquireConnection(const std::wstring& connStr, const py::dict& attrs_before) {
std::lock_guard<std::mutex> lock(_manager_mutex);

auto& pool = _pools[connStr];
if (!pool) {
LOG("Creating new connection pool");
pool = std::make_shared<ConnectionPool>(_default_max_size, _default_idle_secs);
}
return pool->acquire(connStr, attrs_before);
}

void ConnectionPoolManager::returnConnection(const std::wstring& conn_str, const std::shared_ptr<Connection> conn) {
std::lock_guard<std::mutex> lock(_manager_mutex);
if (_pools.find(conn_str) != _pools.end()) {
_pools[conn_str]->release((conn));
}
}

void ConnectionPoolManager::configure(int max_size, int idle_timeout_secs) {
std::lock_guard<std::mutex> lock(_manager_mutex);
_default_max_size = max_size;
_default_idle_secs = idle_timeout_secs;
}
Loading