From 2173cec0d92c994616a59fe9515f3e0d715d1d1f Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Mon, 6 Oct 2025 16:47:08 +0530 Subject: [PATCH 1/3] FIX: Validate access tokens to prevent crashes --- mssql_python/pybind/connection/connection.cpp | 11 +- tests/test_008_auth.py | 152 +++++++++++++++++- 2 files changed, 161 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 3311c697..c838ea87 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -180,7 +180,16 @@ SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { ptr = reinterpret_cast(static_cast(intValue)); length = SQL_IS_INTEGER; } else if (py::isinstance(value) || py::isinstance(value)) { - buffer = value.cast(); // stack buffer + buffer = value.cast(); // local string object (data is heap-allocated) + + // DEFENSIVE FIX: Protect against ODBC driver bug with short access tokens + // Microsoft ODBC Driver 18 crashes when given access tokens shorter than 32 bytes + // Real access tokens are typically 100+ bytes, so reject anything under 32 bytes + if (attribute == SQL_COPT_SS_ACCESS_TOKEN && buffer.size() < 32) { + LOG("Access token too short (< 32 bytes) - protecting against ODBC driver crash"); + ThrowStdException("Failed to set access token: Access token must be at least 32 bytes long"); + } + ptr = buffer.data(); length = static_cast(buffer.size()); } else { diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 6bf6c410..be697bff 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -219,4 +219,154 @@ def test_error_handling(): # Test non-string input with pytest.raises(ValueError, match="Connection string must be a string"): - process_connection_string(None) \ No newline at end of file + process_connection_string(None) + + +def test_short_access_token_protection_blocks_short_tokens(): + """ + Test protection against ODBC driver segfault with short access tokens. + + Microsoft ODBC Driver 18 has a bug where it crashes (segfaults) when given + access tokens shorter than 32 bytes. This test verifies that our defensive + fix properly rejects such tokens before they reach the ODBC driver. + + The fix is implemented in Connection::setAttribute() in connection.cpp. + + This test runs in a subprocess to isolate potential segfaults. + """ + import os + import subprocess + + # Get connection string and remove UID/Pwd to force token-only mode + conn_str = os.getenv("DB_CONNECTION_STRING") + if not conn_str: + pytest.skip("DB_CONNECTION_STRING environment variable not set") + + # Remove authentication to force pure token mode + conn_str_no_auth = conn_str + for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]: + if remove_param in conn_str_no_auth: + parts = conn_str_no_auth.split(";") + parts = [p for p in parts if not p.lower().startswith(remove_param.lower())] + conn_str_no_auth = ";".join(parts) + + # Escape connection string for embedding in subprocess code + escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"') + + # Test cases for problematic token lengths (0-31 bytes) + problematic_lengths = [0, 1, 4, 8, 16, 31] + + for length in problematic_lengths: + code = f""" +import sys +from mssql_python import connect + +conn_str = "{escaped_conn_str}" +fake_token = b"x" * {length} +attrs_before = {{1256: fake_token}} # SQL_COPT_SS_ACCESS_TOKEN = 1256 + +try: + connect(conn_str, attrs_before=attrs_before) + print("ERROR: Should have raised exception for length {length}") + sys.exit(1) +except Exception as e: + error_msg = str(e) + if "Access token must be at least 32 bytes" in error_msg: + print(f"PASS: Got expected protective error for length {length}") + sys.exit(0) + else: + print(f"ERROR: Got unexpected error for length {length}: {{error_msg}}") + sys.exit(1) +""" + + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True + ) + + # Should not segfault (exit code 139 on Linux, 134 on macOS, -11 on some systems) + assert result.returncode not in [134, 139, -11], \ + f"Segfault detected for token length {length}! STDERR: {result.stderr}" + + # Should exit cleanly with our protective error + assert result.returncode == 0, \ + f"Expected protective error for length {length}. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}" + + assert "PASS" in result.stdout, \ + f"Expected PASS message for length {length}, got: {result.stdout}" + + +def test_short_access_token_protection_allows_valid_tokens(): + """ + Test that legitimate-sized access tokens (== 32 bytes) are NOT blocked by protection. + + This verifies that our defensive fix only blocks dangerously short tokens, + and allows legitimate tokens to proceed (even though they may fail authentication + if they're invalid, which is expected and proper behavior). + + Runs in separate subprocess to avoid ODBC driver state pollution from earlier tests. + """ + import os + import subprocess + + # Get connection string and remove UID/Pwd to force token-only mode + conn_str = os.getenv("DB_CONNECTION_STRING") + if not conn_str: + pytest.skip("DB_CONNECTION_STRING environment variable not set") + + # Remove authentication to force pure token mode + conn_str_no_auth = conn_str + for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]: + if remove_param in conn_str_no_auth: + parts = conn_str_no_auth.split(";") + parts = [p for p in parts if not p.lower().startswith(remove_param.lower())] + conn_str_no_auth = ";".join(parts) + + # Escape connection string for embedding in subprocess code + escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"') + + # Test that legitimate-sized tokens don't get blocked (but will fail auth) + code = f""" +import sys +from mssql_python import connect + +conn_str = "{escaped_conn_str}" +legitimate_token = b"x" * 32 # 32 bytes - exactly the minimum +attrs_before = {{1256: legitimate_token}} + +try: + connect(conn_str, attrs_before=attrs_before) + print("ERROR: Should have failed authentication") + sys.exit(1) +except Exception as e: + error_msg = str(e) + # Should NOT get our protective error + if "Access token must be at least 32 bytes" in error_msg: + print(f"ERROR: Legitimate token was incorrectly blocked: {{error_msg}}") + sys.exit(1) + # Should get an authentication/connection error instead + elif any(keyword in error_msg.lower() for keyword in ["login", "auth", "tcp", "connect"]): + print(f"PASS: Legitimate token not blocked, got expected auth error") + sys.exit(0) + else: + print(f"ERROR: Unexpected error for legitimate token: {{error_msg}}") + sys.exit(1) +""" + + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True + ) + + # Should not segfault + assert result.returncode not in [134, 139, -11], \ + f"Segfault detected for legitimate token! STDERR: {result.stderr}" + + # Should pass the test + assert result.returncode == 0, \ + f"Legitimate token test failed. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}" + + assert "PASS" in result.stdout, \ + f"Expected PASS message for legitimate token, got: {result.stdout}" From 96e60b41a5c1bf0b54a62f99c42e4b8f3cb01584 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Mon, 6 Oct 2025 17:36:54 +0530 Subject: [PATCH 2/3] Fixed approach altogether --- mssql_python/auth.py | 68 +++++++++++++- mssql_python/connection.py | 10 ++ mssql_python/constants.py | 3 + mssql_python/pybind/connection/connection.cpp | 10 +- tests/test_008_auth.py | 93 ++++++++++++------- 5 files changed, 140 insertions(+), 44 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index c7e6683a..57ea3f99 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -9,6 +9,67 @@ from typing import Tuple, Dict, Optional, Union from mssql_python.constants import AuthType +def validate_access_token_struct(token_struct: bytes) -> None: + """ + Validate ACCESSTOKEN structure to prevent ODBC driver crashes. + + The ODBC driver crashes (segfault on macOS/Linux, access violation on Windows) + when given malformed access tokens. This function validates the structure + before passing to the driver. + + ACCESSTOKEN structure: typedef struct { DWORD dataSize; BYTE data[]; } ACCESSTOKEN; + + Args: + token_struct (bytes): The ACCESSTOKEN structure to validate + + Raises: + ValueError: If the token structure is invalid + """ + # Check minimum size (4-byte header + data) + if len(token_struct) < 4: + raise ValueError( + f"Invalid access token: minimum 4 bytes required for ACCESSTOKEN structure, got {len(token_struct)} bytes" + ) + + # Extract declared size from first 4 bytes + declared_size = struct.unpack('= 6: + has_utf16_pattern = all([ + 0x20 <= token_data[0] <= 0x7E and token_data[1] == 0, # First char + 0x20 <= token_data[2] <= 0x7E and token_data[3] == 0, # Second char + 0x20 <= token_data[4] <= 0x7E and token_data[5] == 0 # Third char + ]) + + if not has_utf16_pattern: + raise ValueError( + "Invalid access token: must be UTF-16LE encoded JWT. " + "Expected alternating ASCII and null bytes (e.g., 'e\\x00y\\x00J\\x00' for 'eyJ')" + ) + class AADAuth: """Handles Azure Active Directory authentication""" @@ -16,7 +77,12 @@ class AADAuth: def get_token_struct(token: str) -> bytes: """Convert token to SQL Server compatible format""" token_bytes = token.encode("UTF-16-LE") - return struct.pack(f" bytes: diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 832d2aac..198f0036 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -148,6 +148,16 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef connection_str, **kwargs ) self._attrs_before = attrs_before or {} + + # Validate access token if provided directly via attrs_before + if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + from mssql_python.auth import validate_access_token_struct + token_struct = self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] + if isinstance(token_struct, (bytes, bytearray)): + try: + validate_access_token_struct(bytes(token_struct)) + except ValueError as e: + raise ValueError(f"Invalid access token in attrs_before: {e}") from e # Initialize encoding settings with defaults for Python 3 # Python 3 only has str (which is Unicode), so we use utf-16le by default diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 05df3e14..b64b0ee4 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -136,6 +136,9 @@ class ConstantsDDBC(Enum): SQL_QUICK = 0 SQL_ENSURE = 1 + # Connection Attributes + SQL_COPT_SS_ACCESS_TOKEN = 1256 + class GetInfoConstants(Enum): """ These constants are used with various methods like getinfo(). diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index c838ea87..b6322bc4 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -7,6 +7,7 @@ #include "connection.h" #include "connection_pool.h" #include +#include #include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token @@ -181,15 +182,6 @@ SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { length = SQL_IS_INTEGER; } else if (py::isinstance(value) || py::isinstance(value)) { buffer = value.cast(); // local string object (data is heap-allocated) - - // DEFENSIVE FIX: Protect against ODBC driver bug with short access tokens - // Microsoft ODBC Driver 18 crashes when given access tokens shorter than 32 bytes - // Real access tokens are typically 100+ bytes, so reject anything under 32 bytes - if (attribute == SQL_COPT_SS_ACCESS_TOKEN && buffer.size() < 32) { - LOG("Access token too short (< 32 bytes) - protecting against ODBC driver crash"); - ThrowStdException("Failed to set access token: Access token must be at least 32 bytes long"); - } - ptr = buffer.data(); length = static_cast(buffer.size()); } else { diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index be697bff..633906d0 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -224,15 +224,20 @@ def test_error_handling(): def test_short_access_token_protection_blocks_short_tokens(): """ - Test protection against ODBC driver segfault with short access tokens. + Test protection against ODBC driver crashes with malformed access tokens. - Microsoft ODBC Driver 18 has a bug where it crashes (segfaults) when given - access tokens shorter than 32 bytes. This test verifies that our defensive - fix properly rejects such tokens before they reach the ODBC driver. + Microsoft ODBC Driver 18 has a bug where it crashes (segfault on macOS/Linux, + access violation on Windows) when given malformed access tokens. This test + verifies that our defensive validation properly rejects invalid tokens before + they reach the ODBC driver. - The fix is implemented in Connection::setAttribute() in connection.cpp. + The validation is implemented in Connection::setAttribute() in connection.cpp + and checks: + 1. Minimum size (4 bytes for ACCESSTOKEN header) + 2. Structure integrity (declared size matches actual size) + 3. Non-empty data (not all zeros) - This test runs in a subprocess to isolate potential segfaults. + This test runs in a subprocess to isolate potential crashes. """ import os import subprocess @@ -253,29 +258,40 @@ def test_short_access_token_protection_blocks_short_tokens(): # Escape connection string for embedding in subprocess code escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"') - # Test cases for problematic token lengths (0-31 bytes) - problematic_lengths = [0, 1, 4, 8, 16, 31] + # Test cases for problematic tokens + test_cases = [ + (b"", "empty token"), + (b"x" * 3, "too small (< 4 bytes)"), + (b"\x00\x00\x00\x00", "header only, no data"), + (b"\x10\x00\x00\x00" + b"\x00" * 16, "size mismatch (declares 16, total 20)"), + (b"\x10\x00\x00\x00" + b"\x00" * 12, "size mismatch (declares 16, has 12)"), + (b"\x08\x00\x00\x00" + b"\x00" * 8, "all zeros data"), + ] - for length in problematic_lengths: + for token, description in test_cases: + # Convert bytes to hex string for safe embedding in subprocess code + token_hex = token.hex() + code = f""" import sys from mssql_python import connect conn_str = "{escaped_conn_str}" -fake_token = b"x" * {length} +fake_token = bytes.fromhex("{token_hex}") attrs_before = {{1256: fake_token}} # SQL_COPT_SS_ACCESS_TOKEN = 1256 try: connect(conn_str, attrs_before=attrs_before) - print("ERROR: Should have raised exception for length {length}") + print("ERROR: Should have raised exception for {description}") sys.exit(1) except Exception as e: error_msg = str(e) - if "Access token must be at least 32 bytes" in error_msg: - print(f"PASS: Got expected protective error for length {length}") + # Check for our validation error messages + if "Invalid access token" in error_msg: + print(f"PASS: Got expected validation error for {description}") sys.exit(0) else: - print(f"ERROR: Got unexpected error for length {length}: {{error_msg}}") + print(f"ERROR: Got unexpected error for {description}: {{error_msg}}") sys.exit(1) """ @@ -285,30 +301,31 @@ def test_short_access_token_protection_blocks_short_tokens(): text=True ) - # Should not segfault (exit code 139 on Linux, 134 on macOS, -11 on some systems) + # Should not crash (exit code 139 on Linux, 134 on macOS, -11 on some systems) assert result.returncode not in [134, 139, -11], \ - f"Segfault detected for token length {length}! STDERR: {result.stderr}" + f"Crash detected for {description}! STDERR: {result.stderr}" - # Should exit cleanly with our protective error + # Should exit cleanly with our validation error assert result.returncode == 0, \ - f"Expected protective error for length {length}. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}" + f"Expected validation error for {description}. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}" assert "PASS" in result.stdout, \ - f"Expected PASS message for length {length}, got: {result.stdout}" + f"Expected PASS message for {description}, got: {result.stdout}" def test_short_access_token_protection_allows_valid_tokens(): """ - Test that legitimate-sized access tokens (== 32 bytes) are NOT blocked by protection. + Test that properly formatted access tokens are NOT blocked by validation. - This verifies that our defensive fix only blocks dangerously short tokens, - and allows legitimate tokens to proceed (even though they may fail authentication - if they're invalid, which is expected and proper behavior). + This verifies that our defensive validation only blocks malformed tokens, + and allows properly structured tokens to proceed (even though they may fail + authentication if the token is invalid, which is expected behavior). Runs in separate subprocess to avoid ODBC driver state pollution from earlier tests. """ import os import subprocess + import struct # Get connection string and remove UID/Pwd to force token-only mode conn_str = os.getenv("DB_CONNECTION_STRING") @@ -326,14 +343,22 @@ def test_short_access_token_protection_allows_valid_tokens(): # Escape connection string for embedding in subprocess code escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"') - # Test that legitimate-sized tokens don't get blocked (but will fail auth) + # Test that properly formatted tokens don't get blocked (but will fail auth) + # Create a properly formatted UTF-16LE encoded ACCESSTOKEN structure code = f""" import sys +import struct from mssql_python import connect conn_str = "{escaped_conn_str}" -legitimate_token = b"x" * 32 # 32 bytes - exactly the minimum -attrs_before = {{1256: legitimate_token}} + +# Create properly formatted ACCESSTOKEN with UTF-16LE encoded data +# Use a fake JWT-like string that encodes properly +fake_jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" # Base64-like JWT header +token_data = fake_jwt.encode('utf-16-le') # Properly encode as UTF-16LE +token_struct = struct.pack(f' Date: Mon, 6 Oct 2025 17:40:42 +0530 Subject: [PATCH 3/3] cleanup --- mssql_python/auth.py | 7 +------ mssql_python/pybind/connection/connection.cpp | 3 +-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 57ea3f99..b6275e44 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -77,12 +77,7 @@ class AADAuth: def get_token_struct(token: str) -> bytes: """Convert token to SQL Server compatible format""" token_bytes = token.encode("UTF-16-LE") - token_struct = struct.pack(f" bytes: diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index b6322bc4..3311c697 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -7,7 +7,6 @@ #include "connection.h" #include "connection_pool.h" #include -#include #include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token @@ -181,7 +180,7 @@ SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { ptr = reinterpret_cast(static_cast(intValue)); length = SQL_IS_INTEGER; } else if (py::isinstance(value) || py::isinstance(value)) { - buffer = value.cast(); // local string object (data is heap-allocated) + buffer = value.cast(); // stack buffer ptr = buffer.data(); length = static_cast(buffer.size()); } else {