Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 101 additions & 18 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Connection:
close() -> None:
"""

def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> None:
def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None:
"""
Initialize the connection object with the specified connection string and parameters.

Expand All @@ -58,11 +58,12 @@ def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> N
self.connection_str = self._construct_connection_string(
connection_str, **kwargs
)
self._attrs_before = attrs_before
self._autocommit = autocommit # Initialize _autocommit before calling _initializer
self._initializer()
self._autocommit = autocommit
self.setautocommit(autocommit)

def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str:
"""
Construct the connection string by concatenating the connection string
with key/value pairs from kwargs.
Expand All @@ -76,13 +77,14 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
"""
# Add the driver attribute to the connection string
conn_str = add_driver_to_connection_str(connection_str)

# Add additional key-value pairs to the connection string
for key, value in kwargs.items():
if key.lower() == "host":
if key.lower() == "host" or key.lower() == "server":
key = "Server"
elif key.lower() == "user":
elif key.lower() == "user" or key.lower() == "uid":
key = "Uid"
elif key.lower() == "password":
elif key.lower() == "password" or key.lower() == "pwd":
key = "Pwd"
elif key.lower() == "database":
key = "Database"
Expand All @@ -93,6 +95,11 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
else:
continue
conn_str += f"{key}={value};"
print(f"Connection string after adding driver: {conn_str}")

if ENABLE_LOGGING:
logger.info("Final connection string: %s", conn_str)

return conn_str

def _is_closed(self) -> bool:
Expand All @@ -103,7 +110,7 @@ def _is_closed(self) -> bool:
bool: True if the connection is closed, False otherwise.
"""
return self.hdbc is None

def _initializer(self) -> None:
"""
Initialize the environment and connection handles.
Expand All @@ -115,9 +122,79 @@ def _initializer(self) -> None:
self._allocate_environment_handle()
self._set_environment_attributes()
self._allocate_connection_handle()
self._set_connection_attributes()
if self._attrs_before != {}:
self._apply_attrs_before() # Apply pre-connection attributes
if self._autocommit:
self._set_connection_attributes(
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
)
self._connect_to_db()

def _apply_attrs_before(self):
"""
Apply specific pre-connection attributes.
Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN)
if present in `self._attrs_before`. Other attributes are ignored.

Returns:
bool: True.
"""

if ENABLE_LOGGING:
logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before)

if not isinstance(self._attrs_before, dict):
if self._attrs_before is not None and ENABLE_LOGGING:
logger.warning(
f"_attrs_before is of type {type(self._attrs_before).__name__}, "
f"expected dict. Skipping attribute application."
)
elif self._attrs_before is None and ENABLE_LOGGING:
logger.debug("_attrs_before is None. No pre-connection attributes to apply.")
return True # Exit if _attrs_before is not a dictionary or is None

for key, value in self._attrs_before.items():
ikey = None
if isinstance(key, int):
ikey = key
elif isinstance(key, str) and key.isdigit():
try:
ikey = int(key)
except ValueError:
if ENABLE_LOGGING:
logger.debug(
f"Skipping attribute with key '{key}' in attrs_before: "
f"could not convert string to int."
)
continue # Skip if string key is not a valid integer
else:
if ENABLE_LOGGING:
logger.debug(
f"Skipping attribute with key '{key}' in attrs_before due to "
f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int."
)
continue # Skip keys that are not int or string representation of an int

if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value:
if ENABLE_LOGGING:
logger.info(
f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it."
)
self._set_connection_attributes(ikey, value)
if ENABLE_LOGGING:
logger.info(
f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed."
)
# If you expect only one such key, you could add 'break' here.
else:
if ENABLE_LOGGING:
logger.debug(
f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before "
f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})."
)
return True

def _allocate_environment_handle(self):
"""
Allocate the environment handle.
Expand Down Expand Up @@ -152,18 +229,25 @@ def _allocate_connection_handle(self):
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, handle, ret)
self.hdbc = handle

def _set_connection_attributes(self):
def _set_connection_attributes(self, ikey: int, ivalue: any) -> None:
"""
Set the connection attributes before connecting.

Args:
ikey (int): The attribute key to set.
ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode.
vallen (int): The length of the value.

Raises:
DatabaseError: If there is an error while setting the connection attribute.
"""
if self.autocommit:
ret = ddbc_bindings.DDBCSQLSetConnectAttr(
self.hdbc, # Using the wrapper class
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
0
)
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)

ret = ddbc_bindings.DDBCSQLSetConnectAttr(
self.hdbc, # Connection handle
ikey, # Attribute
ivalue, # Value
)
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)

def _connect_to_db(self) -> None:
"""
Expand Down Expand Up @@ -224,7 +308,6 @@ def autocommit(self, value: bool) -> None:
if value
else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value
), # Value
0, # String length
)
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
self._autocommit = value
Expand Down
2 changes: 2 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,5 @@ class ConstantsDDBC(Enum):
SQL_C_WCHAR = -8
SQL_NULLABLE = 1
SQL_MAX_NUMERIC_LEN = 16
SQL_IS_POINTER = -4
SQL_COPT_SS_ACCESS_TOKEN = 1256
4 changes: 2 additions & 2 deletions mssql_python/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mssql_python.connection import Connection


def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connection:
def connect(connection_str: str = "", autocommit: bool = True, attrs_before: dict = None, **kwargs) -> Connection:
"""
Constructor for creating a connection to the database.

Expand Down Expand Up @@ -34,5 +34,5 @@ def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connectio
be used to perform database operations such as executing queries, committing
transactions, and closing the connection.
"""
conn = Connection(connection_str, autocommit=autocommit, **kwargs)
conn = Connection(connection_str, autocommit=autocommit, attrs_before=attrs_before, **kwargs)
return conn
1 change: 1 addition & 0 deletions mssql_python/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def add_driver_to_connection_str(connection_str):
# Insert the driver attribute at the beginning of the connection string
final_connection_attributes.insert(0, driver_name)
connection_str = ";".join(final_connection_attributes)
print(f"Connection string after adding driver: {connection_str}")
except Exception as e:
raise Exception(
"Invalid connection string, Please follow the format: "
Expand Down
48 changes: 44 additions & 4 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,18 +692,58 @@ SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intpt
}

// Wrap SQLSetConnectAttr
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, intptr_t ValuePtr,
SQLINTEGER StringLength) {
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute,
py::object ValuePtr) {
LOG("Set SQL Connection Attribute");
if (!SQLSetConnectAttr_ptr) {
LoadDriverOrThrowException();
}

// TODO: Does ValuePtr need to be converted from Python to C++ object?
SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, reinterpret_cast<SQLPOINTER>(ValuePtr), StringLength);
// Print the type of ValuePtr and attribute value - helpful for debugging
LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast<std::string>(), Attribute);

SQLPOINTER value = 0;
SQLINTEGER length = 0;

if (py::isinstance<py::int_>(ValuePtr)) {
// Handle integer values
int intValue = ValuePtr.cast<int>();
value = reinterpret_cast<SQLPOINTER>(intValue);
length = SQL_IS_INTEGER; // Integer values don't require a length
// } else if (py::isinstance<py::str>(ValuePtr)) {
// // Handle Unicode string values
// static std::wstring unicodeValueBuffer;
// unicodeValueBuffer = ValuePtr.cast<std::wstring>();
// value = const_cast<SQLWCHAR*>(unicodeValueBuffer.c_str());
// length = SQL_NTS; // Indicates null-terminated string
} else if (py::isinstance<py::bytes>(ValuePtr) || py::isinstance<py::bytearray>(ValuePtr)) {
// Handle byte or bytearray values (like access tokens)
// Store in static buffer to ensure memory remains valid during connection
static std::vector<std::string> bytesBuffers;
bytesBuffers.push_back(ValuePtr.cast<std::string>());
value = const_cast<char*>(bytesBuffers.back().c_str());
length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token)
// } else if (py::isinstance<py::list>(ValuePtr) || py::isinstance<py::tuple>(ValuePtr)) {
// // Handle list or tuple values
// LOG("ValuePtr is a sequence (list or tuple)");
// for (py::handle item : ValuePtr) {
// LOG("Processing item in sequence");
// SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow<py::object>(item));
// if (!SQL_SUCCEEDED(ret)) {
// LOG("Failed to set attribute for item in sequence");
// return ret;
// }
// }
} else {
LOG("Unsupported ValuePtr type");
return SQL_ERROR;
}

SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length);
if (!SQL_SUCCEEDED(ret)) {
LOG("Failed to set Connection attribute");
}
LOG("Set Connection attribute successfully");
return ret;
}

Expand Down
28 changes: 27 additions & 1 deletion tests/test_003_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,33 @@ def test_connection(db_connection):

def test_construct_connection_string(db_connection):
# Check if the connection string is constructed correctly with kwargs
conn_str = db_connection._construct_connection_string("",host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
assert "Database=mydb;" in conn_str, "Connection string should contain 'Database=mydb;'"
assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"

def test_connection_string_with_attrs_before(db_connection):
# Check if the connection string is constructed correctly with attrs_before
conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes", attrs_before={1256: "token"})
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
assert "Database=mydb;" in conn_str, "Connection string should contain 'Database=mydb;'"
assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "{1256: token}" not in conn_str, "Connection string should not contain '{1256: token}'"

def test_connection_string_with_odbc_param(db_connection):
# Check if the connection string is constructed correctly with ODBC parameters
conn_str = db_connection._construct_connection_string(server="localhost", uid="me", pwd="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
Expand Down