diff --git a/mssql_python/auth.py b/mssql_python/auth.py
index c7e6683a..b6275e44 100644
--- a/mssql_python/auth.py
+++ b/mssql_python/auth.py
@@ -9,6 +9,65 @@ from typing import Tuple, Dict, Optional, Union
 from mssql_python.constants import AuthType
 
 
+def validate_access_token_struct(token_struct: bytes) -> None:
+    """
+    Validate ACCESSTOKEN structure to prevent ODBC driver crashes.
+
+    The ODBC driver crashes (segfault on macOS/Linux, access violation on Windows)
+    when given malformed access tokens. This function validates the structure
+    before passing to the driver.
+
+    ACCESSTOKEN structure: typedef struct { DWORD dataSize; BYTE data[]; } ACCESSTOKEN;
+
+    Checks, in order: the 4-byte header is present, the declared dataSize equals
+    the number of bytes that follow the header, the data is neither empty nor all
+    zero bytes, and data of 6+ bytes starts with three printable-ASCII UTF-16LE
+    code units (as a JWT such as "eyJ..." does).
+
+    Args:
+        token_struct (bytes): The ACCESSTOKEN structure to validate
+
+    Raises:
+        ValueError: If the token structure is invalid
+    """
+    # Check minimum size (4-byte header + data)
+    if len(token_struct) < 4:
+        raise ValueError(
+            f"Invalid access token: minimum 4 bytes required for ACCESSTOKEN structure, got {len(token_struct)} bytes"
+        )
+
+    # Extract declared size from first 4 bytes (little-endian DWORD)
+    # NOTE(review): the statements between this unpack and the UTF-16LE check were
+    # lost from the extracted patch; the size and zero-data checks below are
+    # reconstructed from the checks the accompanying tests exercise - confirm
+    # against the upstream commit.
+    declared_size = struct.unpack('<I', token_struct[:4])[0]
+    token_data = token_struct[4:]
+
+    # Structure integrity: dataSize must describe exactly the bytes that follow
+    if declared_size != len(token_data):
+        raise ValueError(
+            f"Invalid access token: declared data size is {declared_size} bytes, "
+            f"but {len(token_data)} bytes follow the header"
+        )
+
+    # Reject empty or all-zero data (any() is False for both)
+    if not any(token_data):
+        raise ValueError("Invalid access token: token data is empty or all zeros")
+
+    # Check for the UTF-16LE pattern of a JWT: printable ASCII byte, then a null byte
+    if len(token_data) >= 6:
+        has_utf16_pattern = all(
+            0x20 <= token_data[i] <= 0x7E and token_data[i + 1] == 0
+            for i in (0, 2, 4)  # first three UTF-16LE code units
+        )
+
+        if not has_utf16_pattern:
+            raise ValueError(
+                "Invalid access token: must be UTF-16LE encoded JWT. "
+                "Expected alternating ASCII and null bytes (e.g., 'e\\x00y\\x00J\\x00' for 'eyJ')"
+            )
+
 class AADAuth:
     """Handles Azure Active Directory authentication"""
 
diff --git a/mssql_python/connection.py b/mssql_python/connection.py
index 832d2aac..198f0036 100644
--- a/mssql_python/connection.py
+++ b/mssql_python/connection.py
@@ -148,6 +148,16 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
             connection_str, **kwargs
         )
         self._attrs_before = attrs_before or {}
+
+        # Validate access token if provided directly via attrs_before
+        if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before:
+            from mssql_python.auth import validate_access_token_struct
+            token_struct = self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value]
+            if isinstance(token_struct, (bytes, bytearray)):
+                try:
+                    validate_access_token_struct(bytes(token_struct))
+                except ValueError as e:
+                    raise ValueError(f"Invalid access token in attrs_before: {e}") from e
 
         # Initialize encoding settings with defaults for Python 3
         # Python 3 only has str (which is Unicode), so we use utf-16le by default
diff --git a/mssql_python/constants.py b/mssql_python/constants.py
index 05df3e14..b64b0ee4 100644
--- a/mssql_python/constants.py
+++ b/mssql_python/constants.py
@@ -136,6 +136,9 @@ class ConstantsDDBC(Enum):
     SQL_QUICK = 0
     SQL_ENSURE = 1
 
+    # Connection Attributes
+    SQL_COPT_SS_ACCESS_TOKEN = 1256
+
 class GetInfoConstants(Enum):
     """
     These constants are used with various methods like getinfo().
diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py
index 6bf6c410..633906d0 100644
--- a/tests/test_008_auth.py
+++ b/tests/test_008_auth.py
@@ -219,4 +219,179 @@ def test_error_handling():
 
     # Test non-string input
     with pytest.raises(ValueError, match="Connection string must be a string"):
-        process_connection_string(None)
\ No newline at end of file
+        process_connection_string(None)
+
+
+def test_short_access_token_protection_blocks_short_tokens():
+    """
+    Test protection against ODBC driver crashes with malformed access tokens.
+
+    Microsoft ODBC Driver 18 has a bug where it crashes (segfault on macOS/Linux,
+    access violation on Windows) when given malformed access tokens. This test
+    verifies that our defensive validation properly rejects invalid tokens before
+    they reach the ODBC driver.
+
+    The validation is implemented by validate_access_token_struct() in
+    mssql_python/auth.py, which Connection.__init__ runs on attrs_before; it checks:
+    1. Minimum size (4 bytes for ACCESSTOKEN header)
+    2. Structure integrity (declared size matches actual size)
+    3. Non-empty data (not all zeros)
+
+    This test runs in a subprocess to isolate potential crashes.
+    """
+    import os
+    import subprocess
+
+    # Get connection string and remove UID/Pwd to force token-only mode
+    conn_str = os.getenv("DB_CONNECTION_STRING")
+    if not conn_str:
+        pytest.skip("DB_CONNECTION_STRING environment variable not set")
+
+    # Remove authentication to force pure token mode
+    conn_str_no_auth = conn_str
+    for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]:
+        if remove_param in conn_str_no_auth:
+            parts = conn_str_no_auth.split(";")
+            parts = [p for p in parts if not p.lower().startswith(remove_param.lower())]
+            conn_str_no_auth = ";".join(parts)
+
+    # Escape connection string for embedding in subprocess code
+    escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"')
+
+    # Test cases for problematic tokens
+    test_cases = [
+        (b"", "empty token"),
+        (b"x" * 3, "too small (< 4 bytes)"),
+        (b"\x00\x00\x00\x00", "header only, no data"),
+        (b"\x10\x00\x00\x00" + b"\x00" * 16, "consistent size but all zeros (declares 16, has 16)"),
+        (b"\x10\x00\x00\x00" + b"\x00" * 12, "size mismatch (declares 16, has 12)"),
+        (b"\x08\x00\x00\x00" + b"\x00" * 8, "all zeros data"),
+    ]
+
+    for token, description in test_cases:
+        # Convert bytes to hex string for safe embedding in subprocess code
+        token_hex = token.hex()
+
+        code = f"""
+import sys
+from mssql_python import connect
+
+conn_str = "{escaped_conn_str}"
+fake_token = bytes.fromhex("{token_hex}")
+attrs_before = {{1256: fake_token}}  # SQL_COPT_SS_ACCESS_TOKEN = 1256
+
+try:
+    connect(conn_str, attrs_before=attrs_before)
+    print("ERROR: Should have raised exception for {description}")
+    sys.exit(1)
+except Exception as e:
+    error_msg = str(e)
+    # Check for our validation error messages
+    if "Invalid access token" in error_msg:
+        print(f"PASS: Got expected validation error for {description}")
+        sys.exit(0)
+    else:
+        print(f"ERROR: Got unexpected error for {description}: {{error_msg}}")
+        sys.exit(1)
+"""
+
+        result = subprocess.run(
+            [sys.executable, "-c", code],
+            capture_output=True,
+            text=True
+        )
+
+        # Should not crash (negative = killed by a signal on POSIX; 134/139 = shell-style SIGABRT/SIGSEGV; 0xC0000005 = Windows access violation)
+        assert result.returncode >= 0 and result.returncode not in (134, 139, 0xC0000005), \
+            f"Crash detected for {description}! STDERR: {result.stderr}"
+
+        # Should exit cleanly with our validation error
+        assert result.returncode == 0, \
+            f"Expected validation error for {description}. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}"
+
+        assert "PASS" in result.stdout, \
+            f"Expected PASS message for {description}, got: {result.stdout}"
+
+
+def test_short_access_token_protection_allows_valid_tokens():
+    """
+    Test that properly formatted access tokens are NOT blocked by validation.
+
+    This verifies that our defensive validation only blocks malformed tokens,
+    and allows properly structured tokens to proceed (even though they may fail
+    authentication if the token is invalid, which is expected behavior).
+
+    Runs in separate subprocess to avoid ODBC driver state pollution from earlier tests.
+    """
+    import os
+    import subprocess
+    import struct
+
+    # Get connection string and remove UID/Pwd to force token-only mode
+    conn_str = os.getenv("DB_CONNECTION_STRING")
+    if not conn_str:
+        pytest.skip("DB_CONNECTION_STRING environment variable not set")
+
+    # Remove authentication to force pure token mode
+    conn_str_no_auth = conn_str
+    for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]:
+        if remove_param in conn_str_no_auth:
+            parts = conn_str_no_auth.split(";")
+            parts = [p for p in parts if not p.lower().startswith(remove_param.lower())]
+            conn_str_no_auth = ";".join(parts)
+
+    # Escape connection string for embedding in subprocess code
+    escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"')
+
+    # Test that properly formatted tokens don't get blocked (but will fail auth)
+    # Create a properly formatted UTF-16LE encoded ACCESSTOKEN structure
+    code = f"""
+import sys
+import struct
+from mssql_python import connect
+
+conn_str = "{escaped_conn_str}"
+
+# Create properly formatted ACCESSTOKEN with UTF-16LE encoded data
+# Use a fake JWT-like string that encodes properly
+fake_jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"  # Base64-like JWT header
+token_data = fake_jwt.encode('utf-16-le')  # Properly encode as UTF-16LE
+token_struct = struct.pack(f'