Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,67 @@
from typing import Tuple, Dict, Optional, Union
from mssql_python.constants import AuthType

def validate_access_token_struct(token_struct: bytes) -> None:
"""
Validate ACCESSTOKEN structure to prevent ODBC driver crashes.
The ODBC driver crashes (segfault on macOS/Linux, access violation on Windows)
when given malformed access tokens. This function validates the structure
before passing to the driver.
ACCESSTOKEN structure: typedef struct { DWORD dataSize; BYTE data[]; } ACCESSTOKEN;
Args:
token_struct (bytes): The ACCESSTOKEN structure to validate
Raises:
ValueError: If the token structure is invalid
"""
# Check minimum size (4-byte header + data)
if len(token_struct) < 4:
raise ValueError(
f"Invalid access token: minimum 4 bytes required for ACCESSTOKEN structure, got {len(token_struct)} bytes"
)

# Extract declared size from first 4 bytes
declared_size = struct.unpack('<I', token_struct[:4])[0]

Comment on lines +35 to +36
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

struct is used but not imported in this module, which will raise a NameError when this function is called. Add import struct at the top of the file.

Copilot uses AI. Check for mistakes.

# Validate structure integrity
total_size = len(token_struct)
expected_size = declared_size + 4
if expected_size != total_size:
raise ValueError(
f"Invalid access token: size mismatch in ACCESSTOKEN structure. "
f"Header declares {declared_size} bytes, but structure has {total_size - 4} bytes of data"
)

# Validate token data is not empty/all zeros
token_data = token_struct[4:]
if not any(token_data):
raise ValueError("Invalid access token: token data is empty or all zeros")

# Validate UTF-16LE encoding (ODBC driver requirement)
# JWT tokens in UTF-16LE have null bytes interleaved with ASCII characters
if declared_size % 2 != 0:
raise ValueError(
f"Invalid access token: must be UTF-16LE encoded (got odd byte length {declared_size})"
)

# Check for UTF-16LE pattern: ASCII characters with interleaved null bytes
# Real JWTs start with "eyJ" in UTF-16LE: 65 00 79 00 4A 00
if declared_size >= 6:
has_utf16_pattern = all([
0x20 <= token_data[0] <= 0x7E and token_data[1] == 0, # First char
0x20 <= token_data[2] <= 0x7E and token_data[3] == 0, # Second char
0x20 <= token_data[4] <= 0x7E and token_data[5] == 0 # Third char
])

if not has_utf16_pattern:
raise ValueError(
"Invalid access token: must be UTF-16LE encoded JWT. "
"Expected alternating ASCII and null bytes (e.g., 'e\\x00y\\x00J\\x00' for 'eyJ')"
)

class AADAuth:
"""Handles Azure Active Directory authentication"""

Expand Down
10 changes: 10 additions & 0 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
connection_str, **kwargs
)
self._attrs_before = attrs_before or {}

# Validate access token if provided directly via attrs_before
if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before:
from mssql_python.auth import validate_access_token_struct
token_struct = self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value]
if isinstance(token_struct, (bytes, bytearray)):
try:
validate_access_token_struct(bytes(token_struct))
except ValueError as e:
raise ValueError(f"Invalid access token in attrs_before: {e}") from e

# Initialize encoding settings with defaults for Python 3
# Python 3 only has str (which is Unicode), so we use utf-16le by default
Expand Down
3 changes: 3 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class ConstantsDDBC(Enum):
SQL_QUICK = 0
SQL_ENSURE = 1

# Connection Attributes
SQL_COPT_SS_ACCESS_TOKEN = 1256

class GetInfoConstants(Enum):
"""
These constants are used with various methods like getinfo().
Expand Down
177 changes: 176 additions & 1 deletion tests/test_008_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,179 @@ def test_error_handling():

# Test non-string input
with pytest.raises(ValueError, match="Connection string must be a string"):
process_connection_string(None)
process_connection_string(None)


def test_short_access_token_protection_blocks_short_tokens():
"""
Test protection against ODBC driver crashes with malformed access tokens.

Microsoft ODBC Driver 18 has a bug where it crashes (segfault on macOS/Linux,
access violation on Windows) when given malformed access tokens. This test
verifies that our defensive validation properly rejects invalid tokens before
they reach the ODBC driver.

The validation is implemented in Connection::setAttribute() in connection.cpp
and checks:
1. Minimum size (4 bytes for ACCESSTOKEN header)
2. Structure integrity (declared size matches actual size)
3. Non-empty data (not all zeros)

This test runs in a subprocess to isolate potential crashes.
"""
import os
import subprocess

# Get connection string and remove UID/Pwd to force token-only mode
conn_str = os.getenv("DB_CONNECTION_STRING")
if not conn_str:
pytest.skip("DB_CONNECTION_STRING environment variable not set")

# Remove authentication to force pure token mode
conn_str_no_auth = conn_str
for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]:
if remove_param in conn_str_no_auth:
parts = conn_str_no_auth.split(";")
parts = [p for p in parts if not p.lower().startswith(remove_param.lower())]
conn_str_no_auth = ";".join(parts)

# Escape connection string for embedding in subprocess code
escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"')

# Test cases for problematic tokens
test_cases = [
(b"", "empty token"),
(b"x" * 3, "too small (< 4 bytes)"),
(b"\x00\x00\x00\x00", "header only, no data"),
(b"\x10\x00\x00\x00" + b"\x00" * 16, "size mismatch (declares 16, total 20)"),
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The description claims a size mismatch, but a 16-byte declared size plus the 4-byte header correctly equals 20 bytes; this case actually fails due to all-zero data. Update the description or adjust the test data (e.g., declare 8 but provide 16) for clarity.

Suggested change
(b"\x10\x00\x00\x00" + b"\x00" * 16, "size mismatch (declares 16, total 20)"),
(b"\x08\x00\x00\x00" + b"\x00" * 16, "size mismatch (declares 8, total 20)"),

Copilot uses AI. Check for mistakes.

(b"\x10\x00\x00\x00" + b"\x00" * 12, "size mismatch (declares 16, has 12)"),
(b"\x08\x00\x00\x00" + b"\x00" * 8, "all zeros data"),
]

for token, description in test_cases:
# Convert bytes to hex string for safe embedding in subprocess code
token_hex = token.hex()

code = f"""
import sys
from mssql_python import connect

conn_str = "{escaped_conn_str}"
fake_token = bytes.fromhex("{token_hex}")
attrs_before = {{1256: fake_token}} # SQL_COPT_SS_ACCESS_TOKEN = 1256

try:
connect(conn_str, attrs_before=attrs_before)
print("ERROR: Should have raised exception for {description}")
sys.exit(1)
except Exception as e:
error_msg = str(e)
# Check for our validation error messages
if "Invalid access token" in error_msg:
print(f"PASS: Got expected validation error for {description}")
sys.exit(0)
else:
print(f"ERROR: Got unexpected error for {description}: {{error_msg}}")
sys.exit(1)
"""

result = subprocess.run(
[sys.executable, "-c", code],
capture_output=True,
text=True
)
Comment on lines +298 to +302
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys is used but never imported in this test function, which will raise a NameError before the subprocess runs. Add import sys near the other imports at the start of the function.

Copilot uses AI. Check for mistakes.


# Should not crash (exit code 139 on Linux, 134 on macOS, -11 on some systems)
assert result.returncode not in [134, 139, -11], \
f"Crash detected for {description}! STDERR: {result.stderr}"

# Should exit cleanly with our validation error
assert result.returncode == 0, \
f"Expected validation error for {description}. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}"

assert "PASS" in result.stdout, \
f"Expected PASS message for {description}, got: {result.stdout}"


def test_short_access_token_protection_allows_valid_tokens():
"""
Test that properly formatted access tokens are NOT blocked by validation.

This verifies that our defensive validation only blocks malformed tokens,
and allows properly structured tokens to proceed (even though they may fail
authentication if the token is invalid, which is expected behavior).

Runs in separate subprocess to avoid ODBC driver state pollution from earlier tests.
"""
import os
import subprocess
import struct

# Get connection string and remove UID/Pwd to force token-only mode
conn_str = os.getenv("DB_CONNECTION_STRING")
if not conn_str:
pytest.skip("DB_CONNECTION_STRING environment variable not set")

# Remove authentication to force pure token mode
conn_str_no_auth = conn_str
for remove_param in ["UID=", "Pwd=", "uid=", "pwd="]:
if remove_param in conn_str_no_auth:
parts = conn_str_no_auth.split(";")
parts = [p for p in parts if not p.lower().startswith(remove_param.lower())]
conn_str_no_auth = ";".join(parts)

# Escape connection string for embedding in subprocess code
escaped_conn_str = conn_str_no_auth.replace('\\', '\\\\').replace('"', '\\"')

# Test that properly formatted tokens don't get blocked (but will fail auth)
# Create a properly formatted UTF-16LE encoded ACCESSTOKEN structure
code = f"""
import sys
import struct
from mssql_python import connect

conn_str = "{escaped_conn_str}"

# Create properly formatted ACCESSTOKEN with UTF-16LE encoded data
# Use a fake JWT-like string that encodes properly
fake_jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" # Base64-like JWT header
token_data = fake_jwt.encode('utf-16-le') # Properly encode as UTF-16LE
token_struct = struct.pack(f'<I{{len(token_data)}}s', len(token_data), token_data)

attrs_before = {{1256: token_struct}}

try:
connect(conn_str, attrs_before=attrs_before)
print("ERROR: Should have failed authentication")
sys.exit(1)
except Exception as e:
error_msg = str(e)
# Should NOT get our validation errors
if "Invalid access token" in error_msg:
print(f"ERROR: Valid token structure was incorrectly blocked: {{error_msg}}")
sys.exit(1)
# Should get an authentication/connection error instead
elif any(keyword in error_msg.lower() for keyword in ["login", "auth", "tcp", "connect", "token"]):
print(f"PASS: Valid token structure not blocked, got expected connection/auth error")
sys.exit(0)
else:
print(f"WARN: Got unexpected error (but structure passed validation): {{error_msg}}")
sys.exit(0) # Still pass - structure validation worked
"""

result = subprocess.run(
[sys.executable, "-c", code],
capture_output=True,
text=True
)
Comment on lines +382 to +386
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys is referenced here but not imported in this function, causing a NameError. Add import sys alongside the other local imports.

Copilot uses AI. Check for mistakes.


# Should not crash
assert result.returncode not in [134, 139, -11], \
f"Segfault detected for legitimate token! STDERR: {result.stderr}"

# Should pass the test
assert result.returncode == 0, \
f"Legitimate token test failed. Exit code: {result.returncode}\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}"

assert "PASS" in result.stdout, \
f"Expected PASS message for legitimate token, got: {result.stdout}"