diff --git a/eng/pipelines/build-whl-pipeline.yml b/eng/pipelines/build-whl-pipeline.yml index a22edd08..a6540c8a 100644 --- a/eng/pipelines/build-whl-pipeline.yml +++ b/eng/pipelines/build-whl-pipeline.yml @@ -340,7 +340,7 @@ jobs: python -m pytest -v displayName: 'Run Pytest to validate bindings' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' # Build wheel package for universal2 - script: | @@ -801,7 +801,7 @@ jobs: displayName: 'Test wheel installation and basic functionality on $(BASE_IMAGE)' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' # Run pytest with source code while testing installed wheel - script: | @@ -856,7 +856,7 @@ jobs: " displayName: 'Run pytest suite on $(BASE_IMAGE) $(ARCH)' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' continueOnError: true # Don't fail pipeline if tests fail # Cleanup diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index b1bd7e3b..b0ae7fdd 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -25,6 +25,9 @@ NotSupportedError, ) +# Connection string parser exceptions +from .exceptions import ConnectionStringParseError + # Type Objects from .type import ( Date, diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f0663d72..bbce8a94 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ 
def _construct_connection_string(
    self, connection_str: str = "", **kwargs: Any
) -> str:
    """
    Produce the final ODBC connection string from a base string and keyword overrides.

    The base ``connection_str`` is parsed with allow-list validation, its keys
    are mapped to canonical names, and every recognised keyword argument is
    layered on top (a kwarg replaces a value with the same canonical name that
    came from the string). ``Driver`` and ``APP`` are always set here; callers
    cannot supply them.

    Args:
        connection_str (str): The base connection string.
        **kwargs: Additional key/value pairs for the connection string.
            Keys that are not in the allow-list are logged and ignored.

    Returns:
        str: The constructed and validated connection string.

    Raises:
        ValueError: If a keyword argument maps to a reserved parameter
            (Driver or APP).
    """
    from mssql_python.connection_string_parser import _ConnectionStringParser, RESERVED_PARAMETERS
    from mssql_python.constants import _ConnectionStringAllowList
    from mssql_python.connection_string_builder import _ConnectionStringBuilder

    # Parse + validate the base string, then map its keys to canonical names.
    string_parser = _ConnectionStringParser(allowlist=_ConnectionStringAllowList())
    merged = _ConnectionStringAllowList._normalize_params(
        string_parser._parse(connection_str), warn_rejected=False
    )

    # Keyword arguments are applied last so they win over the base string.
    for raw_name, raw_value in kwargs.items():
        canonical = _ConnectionStringAllowList.normalize_key(raw_name)
        if not canonical:
            log('warning', f"Ignoring unknown connection parameter from kwargs: {raw_name}")
            continue
        if canonical in RESERVED_PARAMETERS:
            # Driver and APP are owned by this package, never by the caller.
            raise ValueError(
                f"Connection parameter '{raw_name}' is reserved and controlled by the driver. "
                f"It cannot be set by the user."
            )
        merged[canonical] = str(raw_value)

    # Driver and APP are appended by the package itself before serialising.
    final_string = (
        _ConnectionStringBuilder(merged)
        .add_param('Driver', 'ODBC Driver 18 for SQL Server')
        .add_param('APP', 'MSSQL-Python')
        .build()
    )

    log('info', "Final connection string: %s", sanitize_connection_string(final_string))

    return final_string
+Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Connection string builder for mssql-python. + +Reconstructs ODBC connection strings from parameter dictionaries +with proper escaping and formatting per MS-ODBCSTR specification. +""" + +from typing import Dict, Optional + + +class _ConnectionStringBuilder: + """ + Internal builder for ODBC connection strings. Not part of public API. + + Handles proper escaping of special characters and reconstructs + connection strings in ODBC format. + """ + + def __init__(self, initial_params: Optional[Dict[str, str]] = None): + """ + Initialize the builder with optional initial parameters. + + Args: + initial_params: Dictionary of initial connection parameters + """ + self._params: Dict[str, str] = initial_params.copy() if initial_params else {} + + def add_param(self, key: str, value: str) -> '_ConnectionStringBuilder': + """ + Add or update a connection parameter. + + Args: + key: Parameter name (should be normalized canonical name) + value: Parameter value + + Returns: + Self for method chaining + """ + self._params[key] = str(value) + return self + + def build(self) -> str: + """ + Build the final connection string. 
+ + Returns: + ODBC-formatted connection string with proper escaping + + Note: + - Driver parameter is placed first + - Other parameters are sorted for consistency + - Values are escaped if they contain special characters + """ + parts = [] + + # Build in specific order: Driver first, then others + if 'Driver' in self._params: + parts.append(f"Driver={self._escape_value(self._params['Driver'])}") + + # Add other parameters (sorted for consistency) + for key in sorted(self._params.keys()): + if key == 'Driver': + continue # Already added + + value = self._params[key] + escaped_value = self._escape_value(value) + parts.append(f"{key}={escaped_value}") + + # Join with semicolons + return ';'.join(parts) + + def _escape_value(self, value: str) -> str: + """ + Escape a parameter value if it contains special characters. + + Per MS-ODBCSTR specification: + - Values containing ';', '{', '}', '=', or spaces should be braced for safety + - '}' inside braced values is escaped as '}}' + - '{' inside braced values is escaped as '{{' + + Args: + value: Parameter value to escape + + Returns: + Escaped value (possibly wrapped in braces) + + Examples: + >>> builder = _ConnectionStringBuilder() + >>> builder._escape_value("localhost") + 'localhost' + >>> builder._escape_value("local;host") + '{local;host}' + >>> builder._escape_value("p}w{d") + '{p}}w{{d}' + >>> builder._escape_value("ODBC Driver 18 for SQL Server") + '{ODBC Driver 18 for SQL Server}' + """ + if not value: + return value + + # Check if value contains special characters that require bracing + # Include spaces and = for safety, even though technically not always required + needs_braces = any(ch in value for ch in ';{}= ') + + if needs_braces: + # Escape existing braces by doubling them + escaped = value.replace('}', '}}').replace('{', '{{') + return f'{{{escaped}}}' + else: + return value diff --git a/mssql_python/connection_string_parser.py b/mssql_python/connection_string_parser.py new file mode 100644 index 
00000000..2fde2b13 --- /dev/null +++ b/mssql_python/connection_string_parser.py @@ -0,0 +1,306 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +ODBC connection string parser for mssql-python. + +Handles ODBC-specific syntax per MS-ODBCSTR specification: +- Semicolon-separated key=value pairs +- Braced values: {value} +- Escaped braces: }} → }, {{ → { + +Parser behavior: +- Validates all key=value pairs +- Raises exceptions for malformed syntax (missing values, unknown keywords, duplicates) +- Collects all errors and reports them together +""" + +from typing import Dict, Tuple +from mssql_python.exceptions import ConnectionStringParseError + + +# Reserved connection string parameters that are controlled by the driver +# and cannot be set by users +RESERVED_PARAMETERS = ('Driver', 'APP') + + +class _ConnectionStringParser: + """ + Internal parser for ODBC connection strings. Not part of public API. + + Implements the ODBC Connection String format as specified in MS-ODBCSTR. + Handles braced values, escaped characters, and proper tokenization. + + Validates connection strings and raises errors for: + - Unknown/unrecognized keywords + - Duplicate keywords + - Incomplete specifications (keyword with no value) + + Reference: https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/55953f0e-2d30-4ad4-8e56-b4207e491409 + """ + + def __init__(self, allowlist=None): + """ + Initialize the parser. + + Args: + allowlist: Optional _ConnectionStringAllowList instance for keyword validation. + If None, no keyword validation is performed. + """ + self._allowlist = allowlist + + def _parse(self, connection_str: str) -> Dict[str, str]: + """ + Parse a connection string into a dictionary of parameters. + + Validates the connection string and raises ConnectionStringParseError + if any issues are found (unknown keywords, duplicates, missing values). 
+ + Args: + connection_str: ODBC-format connection string + + Returns: + Dictionary mapping parameter names (lowercase) to values + + Raises: + ConnectionStringParseError: If validation errors are found + + Examples: + >>> parser = _ConnectionStringParser() + >>> result = parser._parse("Server=localhost;Database=mydb") + {'server': 'localhost', 'database': 'mydb'} + + >>> parser._parse("Server={;local;};PWD={p}}w{{d}") + {'server': ';local;', 'pwd': 'p}w{d'} + + >>> parser._parse("Server=localhost;Server=other") + ConnectionStringParseError: Duplicate keyword 'server' + """ + if not connection_str: + return {} + + connection_str = connection_str.strip() + if not connection_str: + return {} + + # Collect all errors for batch reporting + errors = [] + + # Dictionary to store parsed key=value pairs + params = {} + + # Track which keys we've seen to detect duplicates + seen_keys = {} # Maps normalized key -> first occurrence position + + # Track current position in the string + current_pos = 0 + str_len = len(connection_str) + + # Main parsing loop + while current_pos < str_len: + # Skip leading whitespace and semicolons + while current_pos < str_len and connection_str[current_pos] in ' \t;': + current_pos += 1 + + if current_pos >= str_len: + break + + # Parse the key + key_start = current_pos + + # Advance until we hit '=', ';', or end of string + while current_pos < str_len and connection_str[current_pos] not in '=;': + current_pos += 1 + + # Check if we found a valid '=' separator + if current_pos >= str_len or connection_str[current_pos] != '=': + # ERROR: No '=' found - incomplete specification + incomplete_text = connection_str[key_start:current_pos].strip() + if incomplete_text: + errors.append(f"Incomplete specification: keyword '{incomplete_text}' has no value (missing '=')") + # Skip to next semicolon + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + continue + + # Extract and normalize the key + key = 
connection_str[key_start:current_pos].strip().lower() + + # ERROR: Empty key + if not key: + errors.append("Empty keyword found (format: =value)") + current_pos += 1 # Skip the '=' + # Skip to next semicolon + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + continue + + # Move past the '=' + current_pos += 1 + + # Parse the value + try: + value, current_pos = self._parse_value(connection_str, current_pos) + + # ERROR: Empty value + if not value: + errors.append(f"Empty value for keyword '{key}' (all connection string parameters must have non-empty values)") + + # Check for duplicates + if key in seen_keys: + errors.append(f"Duplicate keyword '{key}' found") + else: + seen_keys[key] = True + params[key] = value + + except ValueError as e: + errors.append(f"Error parsing value for keyword '{key}': {e}") + # Skip to next semicolon + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + + # Validate keywords against allowlist if provided + if self._allowlist: + unknown_keys = [] + reserved_keys = [] + + for key in params.keys(): + # Check if this key can be normalized (i.e., it's known) + normalized_key = self._allowlist.normalize_key(key) + + if normalized_key is None: + # Unknown keyword + unknown_keys.append(key) + elif normalized_key in RESERVED_PARAMETERS: + # Reserved keyword - user cannot set these + reserved_keys.append(key) + + if reserved_keys: + for key in reserved_keys: + errors.append( + f"Reserved keyword '{key}' is controlled by the driver and cannot be specified by the user" + ) + + if unknown_keys: + for key in unknown_keys: + errors.append(f"Unknown keyword '{key}' is not recognized") + + # If we collected any errors, raise them all together + if errors: + raise ConnectionStringParseError(errors) + + return params + + def _parse_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a parameter value from the connection string. 
+ + Handles both simple values and braced values with escaping. + + Args: + connection_str: The connection string + start_pos: Starting position of the value + + Returns: + Tuple of (parsed_value, new_position) + + Raises: + ValueError: If braced value is not properly closed + """ + str_len = len(connection_str) + + # Skip leading whitespace before the value + while start_pos < str_len and connection_str[start_pos] in ' \t': + start_pos += 1 + + # If we've consumed the entire string or reached a semicolon, return empty value + if start_pos >= str_len: + return '', start_pos + + # Determine if this is a braced value or simple value + if connection_str[start_pos] == '{': + return self._parse_braced_value(connection_str, start_pos) + else: + return self._parse_simple_value(connection_str, start_pos) + + def _parse_simple_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a simple (non-braced) value up to the next semicolon. + + Args: + connection_str: The connection string + start_pos: Starting position of the value + + Returns: + Tuple of (parsed_value, new_position) + """ + str_len = len(connection_str) + value_start = start_pos + + # Read characters until we hit a semicolon or end of string + while start_pos < str_len and connection_str[start_pos] != ';': + start_pos += 1 + + # Extract the value and strip trailing whitespace + value = connection_str[value_start:start_pos].rstrip() + return value, start_pos + + def _parse_braced_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a braced value with proper handling of escaped braces. 
+ + Braced values: + - Start with '{' and end with '}' + - '}' inside the value is escaped as '}}' + - '{' inside the value is escaped as '{{' + - Can contain semicolons and other special characters + + Args: + connection_str: The connection string + start_pos: Starting position (should point to opening '{') + + Returns: + Tuple of (parsed_value, new_position) + + Raises: + ValueError: If the braced value is not closed (missing '}') + """ + str_len = len(connection_str) + brace_start_pos = start_pos + + # Skip the opening '{' + start_pos += 1 + + # Build the value character by character + value = [] + + while start_pos < str_len: + ch = connection_str[start_pos] + + if ch == '}': + # Check if next character is also '}' (escaped brace) + if start_pos + 1 < str_len and connection_str[start_pos + 1] == '}': + # Escaped right brace: '}}' → '}' + value.append('}') + start_pos += 2 + else: + # Single '}' means end of braced value + start_pos += 1 + return ''.join(value), start_pos + elif ch == '{': + # Check if it's an escaped left brace + if start_pos + 1 < str_len and connection_str[start_pos + 1] == '{': + # Escaped left brace: '{{' → '{' + value.append('{') + start_pos += 2 + else: + # Single '{' inside braced value - keep it as is + value.append(ch) + start_pos += 1 + else: + # Regular character + value.append(ch) + start_pos += 1 + + # Reached end without finding closing '}' + raise ValueError(f"Unclosed braced value starting at position {brace_start_pos}") diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 785d75e6..e8beee5b 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -5,6 +5,7 @@ """ from enum import Enum +from typing import Dict, Optional class ConstantsDDBC(Enum): @@ -443,3 +444,164 @@ def get_attribute_set_timing(attribute): AttributeSetTime: When the attribute can be set """ return ATTRIBUTE_SET_TIMING.get(attribute, AttributeSetTime.AFTER_ONLY) + + +# Import RESERVED_PARAMETERS from parser module to 
maintain single source of truth +def _get_reserved_parameters(): + """Lazy import to avoid circular dependency.""" + from mssql_python.connection_string_parser import RESERVED_PARAMETERS + return RESERVED_PARAMETERS + + +class _ConnectionStringAllowList: + """ + Manages the allow-list of permitted connection string parameters. + + This class implements a deliberate allow-list approach to exposing + connection string parameters, enabling: + - Incremental ODBC parity while maintaining backward compatibility + - Forward compatibility with future driver enhancements + - Simplified API by normalizing parameter synonyms + """ + + # Core connection parameters with synonym mapping + # Maps lowercase parameter names to their canonical form + # Based on ODBC Driver 18 for SQL Server supported parameters + # A new connection string key to be supported in Python, should be added + # to the dictionary below. the value is the canonical name used in the + # final connection string sent to ODBC driver. + # The left side is what Python connection string supports, the right side + # is the canonical ODBC key name. 
+ ALLOWED_PARAMS = { + # Server identification - addr, address, and server are synonyms + 'server': 'Server', + 'address': 'Server', + 'addr': 'Server', + + # Authentication + 'uid': 'UID', + 'pwd': 'PWD', + 'authentication': 'Authentication', + 'trusted_connection': 'Trusted_Connection', + + # Database + 'database': 'Database', + + # Driver (always controlled by mssql-python) + 'driver': 'Driver', + + # Application name (always controlled by mssql-python) + 'app': 'APP', + + # Encryption and Security + 'encrypt': 'Encrypt', + 'trustservercertificate': 'TrustServerCertificate', + 'trust_server_certificate': 'TrustServerCertificate', # Snake_case synonym + 'hostnameincertificate': 'HostnameInCertificate', # v18.0+ + 'servercertificate': 'ServerCertificate', # v18.1+ + 'serverspn': 'ServerSPN', + + # Connection behavior + 'multisubnetfailover': 'MultiSubnetFailover', + 'applicationintent': 'ApplicationIntent', + 'connectretrycount': 'ConnectRetryCount', + 'connectretryinterval': 'ConnectRetryInterval', + + # Keep-Alive (v17.4+) + 'keepalive': 'KeepAlive', + 'keepaliveinterval': 'KeepAliveInterval', + + # IP Address Preference (v18.1+) + 'ipaddresspreference': 'IpAddressPreference', + + 'packet size': 'PacketSize', # From the tests it looks like pyodbc users use Packet Size + # (with spaces) ODBC only honors "PacketSize" without spaces + # internally. + 'packetsize': 'PacketSize', + } + + @classmethod + def normalize_key(cls, key: str) -> Optional[str]: + """ + Normalize a parameter key to its canonical form. 
+ + Args: + key: Parameter key from connection string (case-insensitive) + + Returns: + Canonical parameter name if allowed, None otherwise + + Examples: + >>> _ConnectionStringAllowList.normalize_key('SERVER') + 'Server' + >>> _ConnectionStringAllowList.normalize_key('user') + 'Uid' + >>> _ConnectionStringAllowList.normalize_key('UnsupportedParam') + None + """ + key_lower = key.lower().strip() + return cls.ALLOWED_PARAMS.get(key_lower) + + @staticmethod + def _normalize_params(params: Dict[str, str], warn_rejected: bool = True) -> Dict[str, str]: + """ + Normalize and filter parameters against the allow-list (internal use only). + + This method performs several operations: + - Normalizes parameter names (e.g., addr/address → Server, uid → UID) + - Filters out parameters not in the allow-list + - Removes reserved parameters (Driver, APP) + - Deduplicates via normalized keys + + Args: + params: Dictionary of connection string parameters (keys should be lowercase) + warn_rejected: Whether to log warnings for rejected parameters + + Returns: + Dictionary containing only allowed parameters with normalized keys + + Note: + Driver and APP parameters are filtered here but will be set by + the driver in _construct_connection_string to maintain control. + """ + # Import here to avoid circular dependency issues + try: + from mssql_python.logging_config import get_logger + from mssql_python.helpers import sanitize_user_input + logger = get_logger() + except ImportError: + logger = None + sanitize_user_input = lambda x: str(x)[:50] # Simple fallback + + filtered = {} + + # The rejected list should ideally be empty when used in the normal connection + # flow, since the parser validates against the allowlist first and raises + # errors for unknown parameters. This filtering is primarily a safety net. 
+ rejected = [] + + reserved_params = _get_reserved_parameters() + + for key, value in params.items(): + normalized_key = _ConnectionStringAllowList.normalize_key(key) + + if normalized_key: + # Skip Driver and APP - these are controlled by the driver + if normalized_key in reserved_params: + continue + + # Parameter is allowed + filtered[normalized_key] = value + else: + # Parameter is not in allow-list + # Note: In normal flow, this should be empty since parser validates first + rejected.append(key) + + # Log all rejected parameters together if any were found + if rejected and warn_rejected and logger: + safe_keys = [sanitize_user_input(key) for key in rejected] + logger.warning( + f"Connection string parameters not in allow-list and will be ignored: {', '.join(safe_keys)}" + ) + + return filtered diff --git a/mssql_python/exceptions.py b/mssql_python/exceptions.py index ff2283f4..b7d2950d 100644 --- a/mssql_python/exceptions.py +++ b/mssql_python/exceptions.py @@ -7,11 +7,33 @@ from typing import Optional from mssql_python.logging_config import get_logger +import builtins logger = get_logger() -class Exception(Exception): +class ConnectionStringParseError(builtins.Exception): + """ + Exception raised when connection string parsing fails. + + This exception is raised when the connection string parser encounters + syntax errors, unknown keywords, duplicate keywords, or other validation + failures. It collects all errors and reports them together. + """ + + def __init__(self, errors: list) -> None: + """ + Initialize the error with a list of validation errors. + + Args: + errors: List of error messages describing what went wrong + """ + self.errors = errors + message = "Connection string parsing failed:\n " + "\n ".join(errors) + super().__init__(message) + + +class Exception(builtins.Exception): """ Base class for all DB API 2.0 exceptions. 
""" diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index d631ea36..ef3f88fd 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -21,11 +21,9 @@ from mssql_python.exceptions import InterfaceError, ProgrammingError, DatabaseError import mssql_python -import sys import pytest import time -from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR -import threading +from mssql_python import connect, Connection, SQL_CHAR, SQL_WCHAR # Import all exception classes for testing from mssql_python.exceptions import ( @@ -125,104 +123,46 @@ def test_connection(db_connection): def test_construct_connection_string(db_connection): # Check if the connection string is constructed correctly with kwargs - conn_str = db_connection._construct_connection_string( - host="localhost", - user="me", - password="mypwd", - database="mydb", - encrypt="yes", - trust_server_certificate="yes", - ) - assert ( - "Server=localhost;" in conn_str - ), "Connection string should contain 'Server=localhost;'" - assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'" - assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'" - assert ( - "Database=mydb;" in conn_str - ), "Connection string should contain 'Database=mydb;'" - assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'" - assert ( - "TrustServerCertificate=yes;" in conn_str - ), "Connection string should contain 'TrustServerCertificate=yes;'" - assert ( - "APP=MSSQL-Python" in conn_str - ), "Connection string should contain 'APP=MSSQL-Python'" - assert ( - "Driver={ODBC Driver 18 for SQL Server}" in conn_str - ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" - assert ( - "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" - == conn_str - ), "Connection string is incorrect" - + # Using 
official ODBC parameter names + conn_str = db_connection._construct_connection_string(Server="localhost", UID="me", PWD="mypwd", Database="mydb", Encrypt="yes", TrustServerCertificate="yes") + # With the new allow-list implementation, parameters are normalized and validated + assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'" + assert "UID=me" in conn_str, "Connection string should contain 'UID=me'" + assert "PWD=mypwd" in conn_str, "Connection string should contain 'PWD=mypwd'" + assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'" + assert "Encrypt=yes" in conn_str, "Connection string should contain 'Encrypt=yes'" + assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'" + assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'" + assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" def test_connection_string_with_attrs_before(db_connection): # Check if the connection string is constructed correctly with attrs_before - conn_str = db_connection._construct_connection_string( - host="localhost", - user="me", - password="mypwd", - database="mydb", - encrypt="yes", - trust_server_certificate="yes", - attrs_before={1256: "token"}, - ) - assert ( - "Server=localhost;" in conn_str - ), "Connection string should contain 'Server=localhost;'" - assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'" - assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'" - assert ( - "Database=mydb;" in conn_str - ), "Connection string should contain 'Database=mydb;'" - assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'" - assert ( - "TrustServerCertificate=yes;" in conn_str - ), "Connection string should contain 'TrustServerCertificate=yes;'" - assert ( - 
"APP=MSSQL-Python" in conn_str - ), "Connection string should contain 'APP=MSSQL-Python'" - assert ( - "Driver={ODBC Driver 18 for SQL Server}" in conn_str - ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" - assert ( - "{1256: token}" not in conn_str - ), "Connection string should not contain '{1256: token}'" - + # Using official ODBC parameter names + conn_str = db_connection._construct_connection_string(Server="localhost", UID="me", PWD="mypwd", Database="mydb", Encrypt="yes", TrustServerCertificate="yes", attrs_before={1256: "token"}) + # With the new allow-list implementation, parameters are normalized and validated + assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'" + assert "UID=me" in conn_str, "Connection string should contain 'UID=me'" + assert "PWD=mypwd" in conn_str, "Connection string should contain 'PWD=mypwd'" + assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'" + assert "Encrypt=yes" in conn_str, "Connection string should contain 'Encrypt=yes'" + assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'" + assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'" + assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" + assert "{1256: token}" not in conn_str, "Connection string should not contain '{1256: token}'" def test_connection_string_with_odbc_param(db_connection): # Check if the connection string is constructed correctly with ODBC parameters - conn_str = db_connection._construct_connection_string( - server="localhost", - uid="me", - pwd="mypwd", - database="mydb", - encrypt="yes", - trust_server_certificate="yes", - ) - assert ( - "Server=localhost;" in conn_str - ), "Connection string should contain 'Server=localhost;'" - assert "Uid=me;" in conn_str, "Connection 
string should contain 'Uid=me;'" - assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'" - assert ( - "Database=mydb;" in conn_str - ), "Connection string should contain 'Database=mydb;'" - assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'" - assert ( - "TrustServerCertificate=yes;" in conn_str - ), "Connection string should contain 'TrustServerCertificate=yes;'" - assert ( - "APP=MSSQL-Python" in conn_str - ), "Connection string should contain 'APP=MSSQL-Python'" - assert ( - "Driver={ODBC Driver 18 for SQL Server}" in conn_str - ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" - assert ( - "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" - == conn_str - ), "Connection string is incorrect" + # Using lowercase synonyms that normalize to uppercase (uid->UID, pwd->PWD) + conn_str = db_connection._construct_connection_string(server="localhost", uid="me", pwd="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes") + # With the new allow-list implementation, parameters are normalized and validated + assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'" + assert "UID=me" in conn_str, "Connection string should contain 'UID=me'" + assert "PWD=mypwd" in conn_str, "Connection string should contain 'PWD=mypwd'" + assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'" + assert "Encrypt=yes" in conn_str, "Connection string should contain 'Encrypt=yes'" + assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'" + assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'" + assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" def 
test_autocommit_default(db_connection): @@ -3661,7 +3601,6 @@ def test_execute_multiple_simultaneous_cursors(db_connection, conn_str): if is_azure_sql_connection(conn_str): pytest.skip("Skipping for Azure SQL - connection limits cause this test to hang") import gc - import sys # Start with a clean connection state cursor = db_connection.execute("SELECT 1") @@ -7690,7 +7629,7 @@ def test_set_attr_login_timeout_effect(conn_str): conn = connect(invalid_conn_str) # Don't use the login_timeout parameter conn.close() pytest.fail("Connection to invalid server should have failed") - except Exception as e: + except Exception: end_time = time.time() elapsed = end_time - start_time diff --git a/tests/test_006_exceptions.py b/tests/test_006_exceptions.py index ef705669..1b368838 100644 --- a/tests/test_006_exceptions.py +++ b/tests/test_006_exceptions.py @@ -15,6 +15,7 @@ raise_exception, truncate_error_message, ) +from mssql_python import ConnectionStringParseError def drop_table_if_exists(cursor, table_name): @@ -193,15 +194,11 @@ def test_foreign_key_constraint_error(cursor, db_connection): def test_connection_error(): - # RuntimeError is raised on Windows, while on MacOS it raises OperationalError - # In MacOS the error goes by "Client unable to establish connection" - # In Windows it goes by "Neither DSN nor SERVER keyword supplied" - # TODO: Make this test platform independent - with pytest.raises((RuntimeError, OperationalError)) as excinfo: + # The new connection string parser now validates the connection string before passing to ODBC + # Invalid strings like "InvalidConnectionString" (missing key=value format) will raise ConnectionStringParseError + with pytest.raises(ConnectionStringParseError) as excinfo: connect("InvalidConnectionString") - assert "Client unable to establish connection" in str( - excinfo.value - ) or "Neither DSN nor SERVER keyword supplied" in str(excinfo.value) + assert "Incomplete specification" in str(excinfo.value) or "has no value" in 
str(excinfo.value) def test_truncate_error_message_successful_cases(): diff --git a/tests/test_010_connection_string_parser.py b/tests/test_010_connection_string_parser.py new file mode 100644 index 00000000..6bbcdb12 --- /dev/null +++ b/tests/test_010_connection_string_parser.py @@ -0,0 +1,448 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Unit tests for _ConnectionStringParser (internal). +""" + +import pytest +from mssql_python.connection_string_parser import _ConnectionStringParser, ConnectionStringParseError +from mssql_python.constants import _ConnectionStringAllowList + + +class TestConnectionStringParser: + """Unit tests for _ConnectionStringParser.""" + + def test_parse_empty_string(self): + """Test parsing an empty string returns empty dict.""" + parser = _ConnectionStringParser() + result = parser._parse("") + assert result == {} + + def test_parse_whitespace_only(self): + """Test parsing whitespace-only connection string.""" + parser = _ConnectionStringParser() + result = parser._parse(" \t ") + assert result == {} + + def test_parse_simple_params(self): + """Test parsing simple key=value pairs.""" + parser = _ConnectionStringParser() + result = parser._parse("Server=localhost;Database=mydb") + assert result == { + 'server': 'localhost', + 'database': 'mydb' + } + + def test_parse_single_param(self): + """Test parsing a single parameter.""" + parser = _ConnectionStringParser() + result = parser._parse("Server=localhost") + assert result == {'server': 'localhost'} + + def test_parse_trailing_semicolon(self): + """Test parsing with trailing semicolon.""" + parser = _ConnectionStringParser() + result = parser._parse("Server=localhost;") + assert result == {'server': 'localhost'} + + def test_parse_multiple_semicolons(self): + """Test parsing with multiple consecutive semicolons.""" + parser = _ConnectionStringParser() + result = parser._parse("Server=localhost;;Database=mydb") + assert result == {'server': 'localhost', 
'database': 'mydb'} + + def test_parse_braced_value_with_semicolon(self): + """Test parsing braced values containing semicolons.""" + parser = _ConnectionStringParser() + result = parser._parse("Server={;local;host};Database=mydb") + assert result == { + 'server': ';local;host', + 'database': 'mydb' + } + + def test_parse_braced_value_with_escaped_right_brace(self): + """Test parsing braced values with escaped }}.""" + parser = _ConnectionStringParser() + result = parser._parse("PWD={p}}w{{d}") + assert result == {'pwd': 'p}w{d'} + + def test_parse_braced_value_with_all_escapes(self): + """Test parsing braced values with both {{ and }} escapes.""" + parser = _ConnectionStringParser() + result = parser._parse("Value={test}}{{escape}") + assert result == {'value': 'test}{escape'} + + def test_parse_empty_value(self): + """Test that empty value raises error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=;Database=mydb") + assert "Empty value for keyword 'server'" in str(exc_info.value) + + def test_parse_empty_braced_value(self): + """Test that empty braced value raises error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server={};Database=mydb") + assert "Empty value for keyword 'server'" in str(exc_info.value) + + def test_parse_whitespace_around_key(self): + """Test parsing with whitespace around keys.""" + parser = _ConnectionStringParser() + result = parser._parse(" Server =localhost; Database =mydb") + assert result == {'server': 'localhost', 'database': 'mydb'} + + def test_parse_whitespace_in_simple_value(self): + """Test parsing simple value with trailing whitespace.""" + parser = _ConnectionStringParser() + result = parser._parse("Server=localhost ;Database=mydb") + assert result == {'server': 'localhost', 'database': 'mydb'} + + def test_parse_case_insensitive_keys(self): + """Test that keys are normalized 
to lowercase.""" + parser = _ConnectionStringParser() + result = parser._parse("SERVER=localhost;DatABase=mydb") + assert result == {'server': 'localhost', 'database': 'mydb'} + + def test_parse_special_chars_in_simple_value(self): + """Test parsing simple values with special characters (not ; { }).""" + parser = _ConnectionStringParser() + result = parser._parse("Server=server:1433;User=domain\\user") + assert result == {'server': 'server:1433', 'user': 'domain\\user'} + + def test_parse_complex_connection_string(self): + """Test parsing a complex realistic connection string.""" + parser = _ConnectionStringParser() + conn_str = "Server=tcp:server.database.windows.net,1433;Database=mydb;UID=user@server;PWD={p@ss;w}}rd};Encrypt=yes" + result = parser._parse(conn_str) + assert result == { + 'server': 'tcp:server.database.windows.net,1433', + 'database': 'mydb', + 'uid': 'user@server', + 'pwd': 'p@ss;w}rd', # }} escapes to single } + 'encrypt': 'yes' + } + + def test_parse_driver_parameter(self): + """Test parsing Driver parameter with braced value.""" + parser = _ConnectionStringParser() + result = parser._parse("Driver={ODBC Driver 18 for SQL Server};Server=localhost") + assert result == { + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'localhost' + } + + def test_parse_braced_value_with_left_brace(self): + """Test parsing braced value containing unescaped single {.""" + parser = _ConnectionStringParser() + result = parser._parse("Value={test{value}") + assert result == {'value': 'test{value'} + + def test_parse_braced_value_double_left_brace(self): + """Test parsing braced value with escaped {{ (left brace).""" + parser = _ConnectionStringParser() + result = parser._parse("Value={test{{value}") + assert result == {'value': 'test{value'} + + def test_parse_unicode_characters(self): + """Test parsing values with unicode characters.""" + parser = _ConnectionStringParser() + result = parser._parse("Database=数据库;Server=сервер") + assert result == {'database': 
'数据库', 'server': 'сервер'} + + def test_parse_equals_in_braced_value(self): + """Test parsing braced value containing equals sign.""" + parser = _ConnectionStringParser() + result = parser._parse("Value={key=value}") + assert result == {'value': 'key=value'} + + def test_parse_special_characters_in_values(self): + """Test parsing values with various special characters.""" + parser = _ConnectionStringParser() + + # Numbers, hyphens, underscores in values + result = parser._parse("Server=server-123_test;Port=1433") + assert result == {'server': 'server-123_test', 'port': '1433'} + + # Dots, colons, commas in values + result = parser._parse("Server=server.domain.com:1433,1434") + assert result == {'server': 'server.domain.com:1433,1434'} + + # At signs, slashes in values + result = parser._parse("UID=user@domain.com;Path=/var/data") + assert result == {'uid': 'user@domain.com', 'path': '/var/data'} + + # Backslashes (common in Windows paths and domain users) + result = parser._parse("User=DOMAIN\\username;Path=C:\\temp") + assert result == {'user': 'DOMAIN\\username', 'path': 'C:\\temp'} + + def test_parse_special_characters_in_braced_values(self): + """Test parsing braced values with special characters that would otherwise be delimiters.""" + parser = _ConnectionStringParser() + + # Semicolons in braced values + result = parser._parse("PWD={pass;word;123};Server=localhost") + assert result == {'pwd': 'pass;word;123', 'server': 'localhost'} + + # Equals signs in braced values + result = parser._parse("ConnectString={Key1=Value1;Key2=Value2}") + assert result == {'connectstring': 'Key1=Value1;Key2=Value2'} + + # Multiple special chars including braces + result = parser._parse("Token={Bearer: abc123; Expires={{2024-01-01}}}") + assert result == {'token': 'Bearer: abc123; Expires={2024-01-01}'} + + def test_parse_numbers_and_symbols_in_passwords(self): + """Test parsing passwords with various numbers and symbols.""" + parser = _ConnectionStringParser() + + # Common 
password characters without braces + result = parser._parse("Server=localhost;PWD=Pass123!@#") + assert result == {'server': 'localhost', 'pwd': 'Pass123!@#'} + + # Special symbols that require bracing + result = parser._parse("PWD={P@ss;w0rd!};Server=srv") + assert result == {'pwd': 'P@ss;w0rd!', 'server': 'srv'} + + # Complex password with multiple special chars + result = parser._parse("PWD={P@$$w0rd!#123%;^&*()}") + assert result == {'pwd': 'P@$$w0rd!#123%;^&*()'} + + def test_parse_emoji_and_extended_unicode(self): + """Test parsing values with emoji and extended unicode characters.""" + parser = _ConnectionStringParser() + + # Emoji in values + result = parser._parse("Description={Test 🚀 Database};Status=✓") + assert result == {'description': 'Test 🚀 Database', 'status': '✓'} + + # Various unicode scripts + result = parser._parse("Name=مرحبا;Title=こんにちは;Info=안녕하세요") + assert result == {'name': 'مرحبا', 'title': 'こんにちは', 'info': '안녕하세요'} + + def test_parse_whitespace_characters(self): + """Test parsing values with various whitespace characters.""" + parser = _ConnectionStringParser() + + # Spaces in braced values (preserved) + result = parser._parse("Name={John Doe};Title={Senior Engineer}") + assert result == {'name': 'John Doe', 'title': 'Senior Engineer'} + + # Tabs in braced values + result = parser._parse("Data={value1\tvalue2\tvalue3}") + assert result == {'data': 'value1\tvalue2\tvalue3'} + + def test_parse_url_encoded_characters(self): + """Test parsing values that look like URL encoding.""" + parser = _ConnectionStringParser() + + # Values with percent signs and hex-like patterns + result = parser._parse("Value=test%20value;Percent=100%") + assert result == {'value': 'test%20value', 'percent': '100%'} + + # URL-like connection strings + result = parser._parse("Server=https://api.example.com/v1;Key=abc-123-def") + assert result == {'server': 'https://api.example.com/v1', 'key': 'abc-123-def'} + + +class TestConnectionStringParserErrors: + """Test error 
handling in ConnectionStringParser.""" + + def test_error_duplicate_keys(self): + """Test that duplicate keys raise an error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=first;Server=second;Server=third") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + assert len(exc_info.value.errors) == 2 # Two duplicates (second and third) + + def test_error_incomplete_specification_no_equals(self): + """Test that keyword without '=' raises an error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server;Database=mydb") + + assert "Incomplete specification" in str(exc_info.value) + assert "'server'" in str(exc_info.value).lower() + + def test_error_incomplete_specification_trailing(self): + """Test that trailing keyword without value raises an error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Database") + + assert "Incomplete specification" in str(exc_info.value) + assert "'database'" in str(exc_info.value).lower() + + def test_error_empty_key(self): + """Test that empty keyword raises an error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("=value;Server=localhost") + + assert "Empty keyword" in str(exc_info.value) + + def test_error_unclosed_braced_value(self): + """Test that unclosed braces raise an error.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("PWD={unclosed;Server=localhost") + + assert "Unclosed braced value" in str(exc_info.value) + + def test_error_multiple_empty_values(self): + """Test that multiple empty values are all collected as errors.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + 
parser._parse("Server=;Database=;UID=user;PWD=") + + # Should have 3 errors for empty values + errors = exc_info.value.errors + assert len(errors) >= 3 + assert any("Empty value for keyword 'server'" in err for err in errors) + assert any("Empty value for keyword 'database'" in err for err in errors) + assert any("Empty value for keyword 'pwd'" in err for err in errors) + + def test_error_multiple_issues_collected(self): + """Test that multiple different types of errors are collected and reported together.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + # Multiple error types: incomplete spec, duplicate, empty value, empty key + parser._parse("Server=first;InvalidEntry;Server=second;Database=;=value;WhatIsThis") + + # Should have: incomplete spec for InvalidEntry, duplicate Server, empty Database value, empty key + errors = exc_info.value.errors + assert len(errors) >= 4 + + errors_str = str(exc_info.value) + assert "Incomplete specification" in errors_str + assert "Duplicate keyword" in errors_str + assert "Empty value for keyword 'database'" in errors_str + assert "Empty keyword" in errors_str + + def test_error_unknown_keyword_with_allowlist(self): + """Test that unknown keywords are flagged when allowlist is provided.""" + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;UnknownParam=value") + + assert "Unknown keyword 'unknownparam'" in str(exc_info.value) + + def test_error_multiple_unknown_keywords(self): + """Test that multiple unknown keywords are all flagged.""" + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Unknown1=val1;Database=mydb;Unknown2=val2") + + errors_str = str(exc_info.value) + assert "Unknown 
keyword 'unknown1'" in errors_str + assert "Unknown keyword 'unknown2'" in errors_str + + def test_error_combined_unknown_and_duplicate(self): + """Test that unknown keywords and duplicates are both flagged.""" + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=first;UnknownParam=value;Server=second") + + errors_str = str(exc_info.value) + assert "Unknown keyword 'unknownparam'" in errors_str + assert "Duplicate keyword 'server'" in errors_str + + def test_valid_with_allowlist(self): + """Test that valid keywords pass when allowlist is provided.""" + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + # These are all valid keywords in the allowlist + result = parser._parse("Server=localhost;Database=mydb;UID=user;PWD=pass") + assert result == { + 'server': 'localhost', + 'database': 'mydb', + 'uid': 'user', + 'pwd': 'pass' + } + + def test_no_validation_without_allowlist(self): + """Test that unknown keywords are allowed when no allowlist is provided.""" + parser = _ConnectionStringParser() # No allowlist + + # Should parse successfully even with unknown keywords + result = parser._parse("Server=localhost;MadeUpKeyword=value") + assert result == { + 'server': 'localhost', + 'madeupkeyword': 'value' + } + + +class TestConnectionStringParserEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_error_all_duplicates(self): + """Test string with only duplicates.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=a;Server=b;Server=c") + + # First occurrence is kept, other two are duplicates + assert len(exc_info.value.errors) == 2 + + def test_error_mixed_valid_and_errors(self): + """Test that valid params are parsed even when errors exist.""" + parser = _ConnectionStringParser() + with 
pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;BadEntry;Database=mydb;Server=dup") + + # Should detect incomplete and duplicate + assert len(exc_info.value.errors) >= 2 + + def test_normalization_still_works(self): + """Test that key normalization to lowercase still works.""" + parser = _ConnectionStringParser() + result = parser._parse("SERVER=srv;DaTaBaSe=db") + assert result == {'server': 'srv', 'database': 'db'} + + def test_error_duplicate_after_normalization(self): + """Test that duplicates are detected after normalization.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=first;SERVER=second") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + + def test_empty_value_edge_cases(self): + """Test that empty values are treated as errors.""" + parser = _ConnectionStringParser() + + # Empty value after = with trailing semicolon + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Database=") + assert "Empty value for keyword 'database'" in str(exc_info.value) + + # Empty value at end of string (no trailing semicolon) + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Database=") + assert "Empty value for keyword 'database'" in str(exc_info.value) + + # Value with only whitespace is treated as empty after strip + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Database= ") + assert "Empty value for keyword 'database'" in str(exc_info.value) + + def test_incomplete_entry_recovery(self): + """Test that parser can recover from incomplete entries and continue parsing.""" + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + # Incomplete entry followed by valid entry + parser._parse("Server;Database=mydb;UID=user") + + # Should have error about incomplete 
'Server' + errors = exc_info.value.errors + assert any('Server' in err and 'Incomplete specification' in err for err in errors) diff --git a/tests/test_011_connection_string_allowlist.py b/tests/test_011_connection_string_allowlist.py new file mode 100644 index 00000000..bdd9a674 --- /dev/null +++ b/tests/test_011_connection_string_allowlist.py @@ -0,0 +1,244 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Unit tests for _ConnectionStringAllowList. +""" + +from mssql_python.constants import _ConnectionStringAllowList + + +class Test_ConnectionStringAllowList: + """Unit tests for _ConnectionStringAllowList.""" + + def test_normalize_key_server(self): + """Test normalization of 'server' and related address parameters.""" + # server, address, and addr are all synonyms that map to 'Server' + assert _ConnectionStringAllowList.normalize_key('server') == 'Server' + assert _ConnectionStringAllowList.normalize_key('SERVER') == 'Server' + assert _ConnectionStringAllowList.normalize_key('Server') == 'Server' + assert _ConnectionStringAllowList.normalize_key('address') == 'Server' + assert _ConnectionStringAllowList.normalize_key('ADDRESS') == 'Server' + assert _ConnectionStringAllowList.normalize_key('addr') == 'Server' + assert _ConnectionStringAllowList.normalize_key('ADDR') == 'Server' + + def test_normalize_key_authentication(self): + """Test normalization of authentication parameters.""" + assert _ConnectionStringAllowList.normalize_key('uid') == 'UID' + assert _ConnectionStringAllowList.normalize_key('UID') == 'UID' + assert _ConnectionStringAllowList.normalize_key('pwd') == 'PWD' + assert _ConnectionStringAllowList.normalize_key('PWD') == 'PWD' + assert _ConnectionStringAllowList.normalize_key('authentication') == 'Authentication' + assert _ConnectionStringAllowList.normalize_key('trusted_connection') == 'Trusted_Connection' + + def test_normalize_key_database(self): + """Test normalization of database parameter.""" + assert 
_ConnectionStringAllowList.normalize_key('database') == 'Database' + assert _ConnectionStringAllowList.normalize_key('DATABASE') == 'Database' + # 'initial catalog' is not in the restricted allowlist + assert _ConnectionStringAllowList.normalize_key('initial catalog') is None + + def test_normalize_key_encryption(self): + """Test normalization of encryption parameters.""" + assert _ConnectionStringAllowList.normalize_key('encrypt') == 'Encrypt' + assert _ConnectionStringAllowList.normalize_key('trustservercertificate') == 'TrustServerCertificate' + assert _ConnectionStringAllowList.normalize_key('hostnameincertificate') == 'HostnameInCertificate' + assert _ConnectionStringAllowList.normalize_key('servercertificate') == 'ServerCertificate' + def test_normalize_key_connection_params(self): + """Test normalization of connection behavior parameters.""" + assert _ConnectionStringAllowList.normalize_key('connectretrycount') == 'ConnectRetryCount' + assert _ConnectionStringAllowList.normalize_key('connectretryinterval') == 'ConnectRetryInterval' + assert _ConnectionStringAllowList.normalize_key('multisubnetfailover') == 'MultiSubnetFailover' + assert _ConnectionStringAllowList.normalize_key('applicationintent') == 'ApplicationIntent' + assert _ConnectionStringAllowList.normalize_key('keepalive') == 'KeepAlive' + assert _ConnectionStringAllowList.normalize_key('keepaliveinterval') == 'KeepAliveInterval' + assert _ConnectionStringAllowList.normalize_key('ipaddresspreference') == 'IpAddressPreference' + # Timeout parameters not in restricted allowlist + assert _ConnectionStringAllowList.normalize_key('connection timeout') is None + assert _ConnectionStringAllowList.normalize_key('login timeout') is None + assert _ConnectionStringAllowList.normalize_key('connect timeout') is None + assert _ConnectionStringAllowList.normalize_key('timeout') is None + + def test_normalize_key_mars(self): + """Test that MARS parameters are not in the allowlist.""" + assert 
_ConnectionStringAllowList.normalize_key('mars_connection') is None + assert _ConnectionStringAllowList.normalize_key('mars connection') is None + assert _ConnectionStringAllowList.normalize_key('multipleactiveresultsets') is None + + def test_normalize_key_app(self): + """Test normalization of APP parameter.""" + assert _ConnectionStringAllowList.normalize_key('app') == 'APP' + assert _ConnectionStringAllowList.normalize_key('APP') == 'APP' + # 'application name' is not in restricted allowlist + assert _ConnectionStringAllowList.normalize_key('application name') is None + + def test_normalize_key_driver(self): + """Test normalization of Driver parameter.""" + assert _ConnectionStringAllowList.normalize_key('driver') == 'Driver' + assert _ConnectionStringAllowList.normalize_key('DRIVER') == 'Driver' + + def test_normalize_key_not_allowed(self): + """Test normalization of disallowed keys returns None.""" + assert _ConnectionStringAllowList.normalize_key('BadParam') is None + assert _ConnectionStringAllowList.normalize_key('UnsupportedParameter') is None + assert _ConnectionStringAllowList.normalize_key('RandomKey') is None + + def test_normalize_key_whitespace(self): + """Test normalization handles whitespace.""" + assert _ConnectionStringAllowList.normalize_key(' server ') == 'Server' + assert _ConnectionStringAllowList.normalize_key(' uid ') == 'UID' + assert _ConnectionStringAllowList.normalize_key(' database ') == 'Database' + + def test__normalize_params_allows_good_params(self): + """Test filtering allows known parameters.""" + params = {'server': 'localhost', 'database': 'mydb', 'encrypt': 'yes'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert 'Server' in filtered + assert 'Database' in filtered + assert 'Encrypt' in filtered + assert filtered['Server'] == 'localhost' + assert filtered['Database'] == 'mydb' + assert filtered['Encrypt'] == 'yes' + + def test__normalize_params_rejects_bad_params(self): + """Test 
filtering rejects unknown parameters.""" + params = {'server': 'localhost', 'badparam': 'value', 'anotherbad': 'test'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert 'Server' in filtered + assert 'badparam' not in filtered + assert 'anotherbad' not in filtered + + def test__normalize_params_normalizes_keys(self): + """Test filtering normalizes parameter keys.""" + params = {'server': 'localhost', 'uid': 'user', 'pwd': 'pass'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert 'Server' in filtered + assert 'UID' in filtered + assert 'PWD' in filtered + assert 'server' not in filtered # Original key should not be present + + def test__normalize_params_handles_address_variants(self): + """Test filtering handles address/addr/server as synonyms.""" + params = { + 'address': 'addr1', + 'addr': 'addr2', + 'server': 'server1' + } + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + # All three are synonyms that map to 'Server', last one wins + assert filtered['Server'] == 'server1' + assert 'Address' not in filtered + assert 'Addr' not in filtered + + def test__normalize_params_empty_dict(self): + """Test filtering empty parameter dictionary.""" + filtered = _ConnectionStringAllowList._normalize_params({}, warn_rejected=False) + assert filtered == {} + + def test__normalize_params_removes_driver(self): + """Test that Driver parameter is filtered out (controlled by driver).""" + params = {'driver': '{Some Driver}', 'server': 'localhost'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert 'Driver' not in filtered + assert 'Server' in filtered + + def test__normalize_params_removes_app(self): + """Test that APP parameter is filtered out (controlled by driver).""" + params = {'app': 'MyApp', 'server': 'localhost'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + 
assert 'APP' not in filtered + assert 'Server' in filtered + + def test__normalize_params_mixed_case_keys(self): + """Test filtering with mixed case keys.""" + params = {'SERVER': 'localhost', 'DataBase': 'mydb', 'EncRypt': 'yes'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert 'Server' in filtered + assert 'Database' in filtered + assert 'Encrypt' in filtered + + def test__normalize_params_preserves_values(self): + """Test that filtering preserves original values unchanged.""" + params = { + 'server': 'localhost:1433', + 'database': 'MyDatabase', + 'pwd': 'P@ssw0rd!123' + } + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + assert filtered['Server'] == 'localhost:1433' + assert filtered['Database'] == 'MyDatabase' + assert filtered['PWD'] == 'P@ssw0rd!123' + + def test__normalize_params_application_intent(self): + """Test filtering application intent parameters.""" + # Only 'applicationintent' (no spaces) is in the allowlist + params = {'applicationintent': 'ReadOnly', 'application intent': 'ReadWrite'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + # 'application intent' with space is rejected, only compact form accepted + assert filtered['ApplicationIntent'] == 'ReadOnly' + assert len(filtered) == 1 + + def test__normalize_params_failover_partner(self): + """Test that failover partner is not in the restricted allowlist.""" + params = {'failover partner': 'backup.server.com', 'failoverpartner': 'backup2.com'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + # Failover_Partner is not in the restricted allowlist + assert 'Failover_Partner' not in filtered + assert 'FailoverPartner' not in filtered + assert len(filtered) == 0 + + def test__normalize_params_column_encryption(self): + """Test that column encryption parameter is not in the allowlist.""" + params = {'columnencryption': 'Enabled', 'column 
encryption': 'Disabled'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + # Column encryption is not in the allowlist, so it should be filtered out + assert 'ColumnEncryption' not in filtered + assert len(filtered) == 0 + + def test__normalize_params_multisubnetfailover(self): + """Test filtering multi-subnet failover parameters.""" + # Only 'multisubnetfailover' (no spaces) is in the allowlist + params = {'multisubnetfailover': 'yes', 'multi subnet failover': 'no'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=False) + # 'multi subnet failover' with spaces is rejected + assert filtered['MultiSubnetFailover'] == 'yes' + assert len(filtered) == 1 + + def test__normalize_params_with_warnings(self): + """Test that rejected parameters are logged when warn_rejected=True.""" + import logging + + # Create a custom logger for this test + logger = logging.getLogger('test_normalize_params_warnings') + logger.setLevel(logging.WARNING) + + # Add a handler to capture log messages + import io + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.WARNING) + logger.addHandler(handler) + + # Temporarily replace the get_logger function + import mssql_python.logging_config as logging_config + original_get_logger = logging_config.get_logger + logging_config.get_logger = lambda: logger + + try: + # Test with unknown parameters and warn_rejected=True + params = {'server': 'localhost', 'badparam1': 'value1', 'badparam2': 'value2'} + filtered = _ConnectionStringAllowList._normalize_params(params, warn_rejected=True) + + # Check that good param was kept + assert 'Server' in filtered + assert len(filtered) == 1 + + # Check that warning was logged with all rejected keys + log_output = log_stream.getvalue() + assert 'badparam1' in log_output + assert 'badparam2' in log_output + assert 'not in allow-list' in log_output + finally: + # Restore original get_logger + 
logging_config.get_logger = original_get_logger + logger.removeHandler(handler) diff --git a/tests/test_012_connection_string_integration.py b/tests/test_012_connection_string_integration.py new file mode 100644 index 00000000..d2bb02b9 --- /dev/null +++ b/tests/test_012_connection_string_integration.py @@ -0,0 +1,646 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Integration tests for connection string allow-list feature. + +These tests verify end-to-end behavior of the parser, filter, and builder pipeline. +""" + +import pytest +import os +from unittest.mock import patch, MagicMock +from mssql_python.connection_string_parser import _ConnectionStringParser, ConnectionStringParseError +from mssql_python.constants import _ConnectionStringAllowList +from mssql_python.connection_string_builder import _ConnectionStringBuilder +from mssql_python import connect + + +class TestConnectionStringIntegration: + """Integration tests for the complete connection string flow.""" + + def test_parse_filter_build_simple(self): + """Test complete flow with simple parameters.""" + # Parse + parser = _ConnectionStringParser() + parsed = parser._parse("Server=localhost;Database=mydb;Encrypt=yes") + + # Filter + filtered = _ConnectionStringAllowList._normalize_params(parsed, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + builder.add_param('APP', 'MSSQL-Python') + result = builder.build() + + # Verify + assert 'Driver={ODBC Driver 18 for SQL Server}' in result + assert 'Server=localhost' in result + assert 'Database=mydb' in result + assert 'Encrypt=yes' in result + assert 'APP=MSSQL-Python' in result + + def test_parse_filter_build_with_unsupported_param(self): + """Test that unsupported parameters are flagged as errors with allowlist.""" + # Parse with allowlist + allowlist = _ConnectionStringAllowList() + parser = 
_ConnectionStringParser(allowlist=allowlist) + + # Should raise error for unknown keyword + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;Database=mydb;UnsupportedParam=value") + + assert "Unknown keyword 'unsupportedparam'" in str(exc_info.value) + + def test_parse_filter_build_with_braced_values(self): + """Test complete flow with braced values and special characters.""" + # Parse + parser = _ConnectionStringParser() + parsed = parser._parse("Server={local;host};PWD={p@ss;w}}rd}") + + # Filter + filtered = _ConnectionStringAllowList._normalize_params(parsed, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + result = builder.build() + + # Verify - values with special chars should be re-escaped + assert 'Driver={ODBC Driver 18 for SQL Server}' in result + assert 'Server={local;host}' in result + assert 'Pwd={p@ss;w}}rd}' in result or 'PWD={p@ss;w}}rd}' in result + + def test_parse_filter_build_synonym_normalization(self): + """Test that parameter synonyms are normalized.""" + # Parse + parser = _ConnectionStringParser() + # Use parameters that are in the restricted allowlist + parsed = parser._parse("address=server1;uid=testuser;database=testdb") + + # Filter (normalizes synonyms) + filtered = _ConnectionStringAllowList._normalize_params(parsed, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + result = builder.build() + + # Verify - should use canonical names + assert 'Server=server1' in result # address -> Server + assert 'UID=testuser' in result # uid -> UID + assert 'Database=testdb' in result + # Original names should not appear + assert 'address' not in result.lower() + # uid appears in UID, so check for the exact pattern + assert result.count('UID=') == 1 + + def test_parse_filter_build_driver_and_app_reserved(self): + 
"""Test that Driver and APP in connection string raise errors.""" + # Parser should reject Driver and APP as reserved keywords + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + # Test with APP + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("APP=UserApp;Server=localhost") + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'app'" in error_lower + + # Test with Driver + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Driver={Some Other Driver};Server=localhost") + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'driver'" in error_lower + + # Test with both + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Driver={Some Other Driver};APP=UserApp;Server=localhost") + error_str = str(exc_info.value).lower() + assert "reserved keyword" in error_str + # Should have errors for both + assert len(exc_info.value.errors) == 2 + + def test_parse_filter_build_empty_input(self): + """Test complete flow with empty input.""" + # Parse + parser = _ConnectionStringParser() + parsed = parser._parse("") + + # Filter + filtered = _ConnectionStringAllowList._normalize_params(parsed, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + result = builder.build() + + # Verify - should only have Driver + assert result == 'Driver={ODBC Driver 18 for SQL Server}' + + def test_parse_filter_build_complex_realistic(self): + """Test complete flow with complex realistic connection string.""" + # Parse + parser = _ConnectionStringParser() + # Note: Connection Timeout is not in the restricted allowlist + conn_str = "Server=tcp:server.database.windows.net,1433;Database=mydb;UID=user@server;PWD={P@ss;w}}rd};Encrypt=yes;TrustServerCertificate=no" + parsed = parser._parse(conn_str) + + 
# Filter + filtered = _ConnectionStringAllowList._normalize_params(parsed, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + builder.add_param('APP', 'MSSQL-Python') + result = builder.build() + + # Verify key parameters are present + assert 'Driver={ODBC Driver 18 for SQL Server}' in result + assert 'Server=tcp:server.database.windows.net,1433' in result + assert 'Database=mydb' in result + assert 'UID=user@server' in result # UID not Uid (canonical form) + assert 'PWD={P@ss;w}}rd}' in result + assert 'Encrypt=yes' in result + assert 'TrustServerCertificate=no' in result + # Connection Timeout was never supplied in the input, so it must not appear in the result + assert 'Connection Timeout' not in result + assert 'APP=MSSQL-Python' in result + + def test_parse_error_incomplete_specification(self): + """Test that incomplete specifications raise errors.""" + parser = _ConnectionStringParser() + + # Incomplete specification raises error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server localhost;Database=mydb") + + assert "Incomplete specification" in str(exc_info.value) + assert "'server localhost'" in str(exc_info.value).lower() + + def test_parse_error_unclosed_brace(self): + """Test that unclosed braces raise errors.""" + parser = _ConnectionStringParser() + + # Unclosed brace raises error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("PWD={unclosed;Server=localhost") + + assert "Unclosed braced value" in str(exc_info.value) + + def test_parse_error_duplicate_keywords(self): + """Test that duplicate keywords raise errors.""" + parser = _ConnectionStringParser() + + # Duplicate keywords raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=first;Server=second") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + + def test_round_trip_preserves_values(self): + """Test that parsing and 
rebuilding preserves parameter values.""" + original_params = { + 'server': 'localhost:1433', + 'database': 'TestDB', + 'uid': 'testuser', + 'pwd': 'Test@123', + 'encrypt': 'yes' + } + + # Filter + filtered = _ConnectionStringAllowList._normalize_params(original_params, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + result = builder.build() + + # Parse back + parser = _ConnectionStringParser() + parsed = parser._parse(result) + + # Verify values are preserved (keys are normalized to lowercase in parsing) + assert parsed['server'] == 'localhost:1433' + assert parsed['database'] == 'TestDB' + assert parsed['uid'] == 'testuser' + assert parsed['pwd'] == 'Test@123' + assert parsed['encrypt'] == 'yes' + assert parsed['driver'] == 'ODBC Driver 18 for SQL Server' + + def test_builder_escaping_is_correct(self): + """Test that builder correctly escapes special characters.""" + builder = _ConnectionStringBuilder() + builder.add_param('Server', 'local;host') + builder.add_param('PWD', 'p}w{d') + builder.add_param('Value', 'test;{value}') + result = builder.build() + + # Parse back to verify escaping worked + parser = _ConnectionStringParser() + parsed = parser._parse(result) + + assert parsed['server'] == 'local;host' + assert parsed['pwd'] == 'p}w{d' + assert parsed['value'] == 'test;{value}' + + def test_builder_empty_value(self): + """Test that parser rejects empty values built by builder.""" + builder = _ConnectionStringBuilder() + builder.add_param('Server', 'localhost') + builder.add_param('Database', '') # Empty value + builder.add_param('UID', 'user') + result = builder.build() + + # Parser should reject empty value + parser = _ConnectionStringParser() + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse(result) + + assert "Empty value for keyword 'database'" in str(exc_info.value) + + def test_multiple_errors_collected(self): + """Test that 
multiple errors are collected and reported together.""" + parser = _ConnectionStringParser() + + # Multiple errors: incomplete spec, duplicate + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=first;InvalidEntry;Server=second;Database") + + # Should have multiple errors + assert len(exc_info.value.errors) >= 3 + assert "Incomplete specification" in str(exc_info.value) + assert "Duplicate keyword" in str(exc_info.value) + + def test_parser_without_allowlist_accepts_unknown(self): + """Test that parser without allowlist accepts unknown keywords.""" + parser = _ConnectionStringParser() # No allowlist + + # Should parse successfully even with unknown keywords + result = parser._parse("Server=localhost;MadeUpKeyword=value") + assert result == { + 'server': 'localhost', + 'madeupkeyword': 'value' + } + + def test_parser_with_allowlist_rejects_unknown(self): + """Test that parser with allowlist rejects unknown keywords.""" + allowlist = _ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + # Should raise error for unknown keyword + with pytest.raises(ConnectionStringParseError) as exc_info: + parser._parse("Server=localhost;MadeUpKeyword=value") + + assert "Unknown keyword 'madeupkeyword'" in str(exc_info.value) + + +class TestConnectAPIIntegration: + """Integration tests for the connect() API with connection string validation.""" + + def test_connect_with_unknown_keyword_raises_error(self): + """Test that connect() raises error for unknown keywords.""" + # connect() uses allowlist validation internally + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;Database=test;UnknownKeyword=value") + + assert "Unknown keyword 'unknownkeyword'" in str(exc_info.value) + + def test_connect_with_duplicate_keywords_raises_error(self): + """Test that connect() raises error for duplicate keywords.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + 
connect("Server=first;Server=second;Database=test") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + + def test_connect_with_incomplete_specification_raises_error(self): + """Test that connect() raises error for incomplete specifications.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server localhost;Database=test") + + assert "Incomplete specification" in str(exc_info.value) + + def test_connect_with_unclosed_brace_raises_error(self): + """Test that connect() raises error for unclosed braces.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("PWD={unclosed;Server=localhost") + + assert "Unclosed braced value" in str(exc_info.value) + + def test_connect_with_multiple_errors_collected(self): + """Test that connect() collects multiple errors.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=first;InvalidEntry;Server=second;Database") + + # Should have multiple errors + assert len(exc_info.value.errors) >= 3 + error_str = str(exc_info.value) + assert "Incomplete specification" in error_str + assert "Duplicate keyword" in error_str + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_kwargs_override_connection_string(self, mock_ddbc_conn): + """Test that kwargs override connection string parameters.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("Server=original;Database=originaldb", + Server="overridden", + Database="overriddendb") + + # Verify the override worked + assert "overridden" in conn.connection_str.lower() + assert "overriddendb" in conn.connection_str.lower() + # Original values should not be in the final connection string + assert "original" not in conn.connection_str.lower() or "originaldb" not in conn.connection_str.lower() + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def 
test_connect_app_parameter_in_connection_string_raises_error(self, mock_ddbc_conn): + """Test that APP parameter in connection string raises ConnectionStringParseError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set APP in connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;APP=UserApp;Database=test") + + # Verify error message + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'app'" in error_lower + assert "controlled by the driver" in error_lower + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_app_parameter_in_kwargs_raises_error(self, mock_ddbc_conn): + """Test that APP parameter in kwargs raises ValueError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set APP via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect("Server=localhost;Database=test", APP="UserApp") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "APP" in str(exc_info.value) or "app" in str(exc_info.value).lower() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_driver_parameter_in_connection_string_raises_error(self, mock_ddbc_conn): + """Test that Driver parameter in connection string raises ConnectionStringParseError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set Driver in connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;Driver={Some Other Driver};Database=test") + + # Verify error message + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'driver'" in error_lower + assert "controlled by the driver" in error_lower + + 
@patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_driver_parameter_in_kwargs_raises_error(self, mock_ddbc_conn): + """Test that Driver parameter in kwargs raises ValueError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set Driver via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect("Server=localhost;Database=test", Driver="Some Other Driver") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "Driver" in str(exc_info.value) + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_synonym_normalization(self, mock_ddbc_conn): + """Test that connect() normalizes parameter synonyms.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # Use parameters that are in the restricted allowlist + conn = connect("address=server1;uid=testuser;database=testdb") + + # Synonyms should be normalized to canonical names + assert "Server=server1" in conn.connection_str # address -> Server + assert "UID=testuser" in conn.connection_str # uid -> UID + assert "Database=testdb" in conn.connection_str + # Verify address was normalized (not present in output) + assert "Address=" not in conn.connection_str + assert "Addr=" not in conn.connection_str + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_kwargs_unknown_parameter_warned(self, mock_ddbc_conn): + """Test that unknown kwargs are warned about but don't raise errors during parsing.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # Unknown kwargs are filtered out with a warning, but don't cause parse errors + # because kwargs bypass the parser's allowlist validation + conn = connect("Server=localhost", Database="test", UnknownParam="value") + + # UnknownParam should be filtered out (warned but not included) + conn_str_lower = 
conn.connection_str.lower() + assert "database=test" in conn_str_lower + assert "unknownparam" not in conn_str_lower + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_empty_connection_string(self, mock_ddbc_conn): + """Test that connect() works with empty connection string and kwargs.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("", Server="localhost", Database="test") + + # Should have Server and Database from kwargs + conn_str_lower = conn.connection_str.lower() + assert "server=localhost" in conn_str_lower + assert "database=test" in conn_str_lower + assert "driver=" in conn_str_lower # Driver is always added + assert "app=mssql-python" in conn_str_lower # APP is always added + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_special_characters_in_values(self, mock_ddbc_conn): + """Test that connect() properly handles special characters in parameter values.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("Server={local;host};PWD={p@ss;w}}rd};Database=test") + + # Special characters should be preserved through parsing and building + # The connection string should properly escape them + assert "local;host" in conn.connection_str or "{local;host}" in conn.connection_str + assert "p@ss;w}rd" in conn.connection_str or "{p@ss;w}}rd}" in conn.connection_str + + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_with_real_database(self, conn_str): + """Test that connect() works with a real database connection.""" + # This test only runs if DB_CONNECTION_STRING is set + conn = connect(conn_str) + assert conn is not None + + # Verify connection string has required parameters + assert "Driver=" in conn.connection_str or "driver=" in conn.connection_str + assert 
"APP=MSSQL-Python" in conn.connection_str or "app=mssql-python" in conn.connection_str.lower() + + # Test basic query execution + cursor = conn.cursor() + cursor.execute("SELECT 1 AS test") + row = cursor.fetchone() + assert row[0] == 1 + cursor.close() + + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_kwargs_override_with_real_database(self, conn_str): + """Test that kwargs override works with a real database connection.""" + + # Create connection with overridden autocommit + conn = connect(conn_str, autocommit=True) + + # Verify connection works and autocommit is set + assert conn.autocommit == True + + # Verify connection string still has all required params + assert "Driver=" in conn.connection_str or "driver=" in conn.connection_str + assert "APP=MSSQL-Python" in conn.connection_str or "app=mssql-python" in conn.connection_str.lower() + + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_reserved_params_in_connection_string_raise_error(self, conn_str): + """Test that reserved params (Driver, APP) in connection string raise error.""" + # Try to add Driver to connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";Driver={User Driver}" + connect(test_conn_str) + assert "reserved keyword" in str(exc_info.value).lower() + assert "driver" in str(exc_info.value).lower() + + # Try to add APP to connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";APP=UserApp" + connect(test_conn_str) + assert "reserved keyword" in str(exc_info.value).lower() + assert "app" in str(exc_info.value).lower() + + # Application Name is not in the restricted allowlist (not a synonym for APP) + # It should be rejected as an unknown parameter + with 
pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";Application Name=UserApp" + connect(test_conn_str) + assert "unknown keyword" in str(exc_info.value).lower() + assert "application name" in str(exc_info.value).lower() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_reserved_params_in_kwargs_raise_error(self, conn_str): + """Test that reserved params (Driver, APP) in kwargs raise ValueError.""" + # Try to override Driver via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect(conn_str, Driver="User Driver") + assert "reserved and controlled by the driver" in str(exc_info.value) + + # Try to override APP via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect(conn_str, APP="UserApp") + assert "reserved and controlled by the driver" in str(exc_info.value) + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_app_name_received_by_sql_server(self, conn_str): + """Test that SQL Server receives the driver-controlled APP name 'MSSQL-Python'.""" + # Connect to SQL Server + with connect(conn_str) as conn: + # Query SQL Server to get the application name it received + cursor = conn.cursor() + cursor.execute("SELECT APP_NAME() AS app_name") + row = cursor.fetchone() + cursor.close() + + # Verify SQL Server received the driver-controlled application name + assert row is not None, "Failed to get APP_NAME() from SQL Server" + app_name_received = row[0] + + # SQL Server should have received 'MSSQL-Python', not any user-provided value + assert app_name_received == 'MSSQL-Python', \ + f"Expected SQL Server to receive 'MSSQL-Python', but got '{app_name_received}'" + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def 
test_app_name_in_connection_string_raises_error(self, conn_str): + """Test that APP in connection string raises ConnectionStringParseError.""" + # Connection strings with APP parameter should now raise an error (not silently filter) + + # Try to add APP to connection string + test_conn_str = conn_str + ";APP=UserDefinedApp" + + # Should raise ConnectionStringParseError + with pytest.raises(ConnectionStringParseError) as exc_info: + connect(test_conn_str) + + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'app'" in error_lower + assert "controlled by the driver" in error_lower + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_app_name_in_kwargs_rejected_before_sql_server(self, conn_str): + """Test that APP in kwargs raises ValueError before even attempting to connect to SQL Server.""" + # Like connection strings (which raise ConnectionStringParseError), kwargs with APP should raise an error (ValueError) + # This prevents the connection attempt entirely + + with pytest.raises(ValueError) as exc_info: + connect(conn_str, APP="UserDefinedApp") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "APP" in str(exc_info.value) or "app" in str(exc_info.value).lower() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_empty_value_raises_error(self, mock_ddbc_conn): + """Test that empty values in connection string raise ConnectionStringParseError.""" + mock_ddbc_conn.return_value = MagicMock() + + # Empty value should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;Database=;UID=user") + + assert "Empty value for keyword 'database'" in str(exc_info.value) + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_multiple_empty_values_raises_error(self, mock_ddbc_conn): + """Test that multiple empty values are all collected in error.""" +
mock_ddbc_conn.return_value = MagicMock() + + # Multiple empty values + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=;Database=mydb;PWD=") + + errors = exc_info.value.errors + assert len(errors) >= 2 + assert any("Empty value for keyword 'server'" in err for err in errors) + assert any("Empty value for keyword 'pwd'" in err for err in errors) + + + + + +