diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index 464eca4b..84f5e70e 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -2155,45 +2155,40 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
             case SQL_BINARY:
             case SQL_VARBINARY:
             case SQL_LONGVARBINARY: {
-                // TODO: revisit
-                HandleZeroColumnSizeAtFetch(columnSize);
-                std::unique_ptr<SQLCHAR[]> dataBuffer(new SQLCHAR[columnSize]);
-                SQLLEN dataLen;
-                ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.get(), columnSize, &dataLen);
+                // Use streaming for large VARBINARY (columnSize unknown or > 8000)
+                if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 8000) {
+                    LOG("Streaming LOB for column {} (VARBINARY)", i);
+                    row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true));
+                } else {
+                    // Small VARBINARY, fetch directly
+                    std::vector<SQLCHAR> dataBuffer(columnSize);
+                    SQLLEN dataLen;
+                    ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.data(), columnSize, &dataLen);
 
-                if (SQL_SUCCEEDED(ret)) {
-                    // TODO: Refactor these if's across other switches to avoid code duplication
-                    if (dataLen > 0) {
-                        if (static_cast<SQLULEN>(dataLen) <= columnSize) {
-                            row.append(py::bytes(reinterpret_cast<const char*>(
-                                dataBuffer.get()), dataLen));
-                        } else {
-                            // In this case, buffer size is smaller, and data to be retrieved is longer
-                            // TODO: Revisit
+                    if (SQL_SUCCEEDED(ret)) {
+                        if (dataLen > 0) {
+                            if (static_cast<SQLULEN>(dataLen) <= columnSize) {
+                                row.append(py::bytes(reinterpret_cast<const char*>(dataBuffer.data()), dataLen));
+                            } else {
+                                LOG("VARBINARY column {} data truncated, using streaming LOB", i);
+                                row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true));
+                            }
+                        } else if (dataLen == SQL_NULL_DATA) {
+                            row.append(py::none());
+                        } else if (dataLen == 0) {
+                            row.append(py::bytes(""));
+                        } else {
                             std::ostringstream oss;
-                            oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-                                << "to be retrieved is longer (" << dataLen << "). ColumnID - "
-                                << i << ", datatype - " << dataType;
+                            oss << "Unexpected negative length (" << dataLen << ") returned by SQLGetData. ColumnID="
+                                << i << ", dataType=" << dataType << ", bufferSize=" << columnSize;
+                            LOG("Error: {}", oss.str());
                             ThrowStdException(oss.str());
                         }
-                    } else if (dataLen == SQL_NULL_DATA) {
-                        row.append(py::none());
-                    } else if (dataLen == 0) {
-                        // Empty bytes
-                        row.append(py::bytes(""));
-                    } else if (dataLen < 0) {
-                        // This is unexpected
-                        LOG("SQLGetData returned an unexpected negative data length. "
-                            "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}",
-                            i, dataType, dataLen);
-                        ThrowStdException("SQLGetData returned an unexpected negative data length");
+                    } else {
+                        LOG("Error retrieving VARBINARY data for column {}. SQLGetData rc = {}", i, ret);
+                        row.append(py::none());
                     }
-                } else {
-                    LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return "
-                        "code - {}. Returning NULL value instead",
-                        i, dataType, ret);
-                    row.append(py::none());
-                }
+                }
                 break;
             }
             case SQL_TINYINT: {
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index 74150910..473807e9 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -6113,34 +6113,71 @@ def test_binary_data_over_8000_bytes(cursor, db_connection):
         drop_table_if_exists(cursor, "#pytest_small_binary")
         db_connection.commit()
 
-def test_binary_data_large(cursor, db_connection):
-    """Test insertion of binary data larger than 8000 bytes with streaming support."""
+def test_varbinarymax_insert_fetch(cursor, db_connection):
+    """Test for VARBINARY(MAX) insert and fetch (streaming support) using execute per row"""
     try:
-        drop_table_if_exists(cursor, "#pytest_large_binary")
+        # Create test table
+        drop_table_if_exists(cursor, "#pytest_varbinarymax")
         cursor.execute("""
-            CREATE TABLE #pytest_large_binary (
-                id INT PRIMARY KEY,
-                large_binary VARBINARY(MAX)
+            CREATE TABLE #pytest_varbinarymax (
+                id INT,
+                binary_data VARBINARY(MAX)
             )
         """)
-
-        # Large binary data > 8000 bytes
-        large_data = b'A' * 10000  # 10 KB
-        cursor.execute("INSERT INTO #pytest_large_binary (id, large_binary) VALUES (?, ?)", (1, large_data))
+
+        # Prepare test data
+        test_data = [
+            (2, b''),            # Empty bytes
+            (3, b'1234567890'),  # Small binary
+            (4, b'A' * 9000),    # Large binary > 8000 (streaming)
+            (5, b'B' * 20000),   # Large binary > 8000 (streaming)
+            (6, b'C' * 8000),    # Edge case: exactly 8000 bytes
+            (7, b'D' * 8001),    # Edge case: just over 8000 bytes
+        ]
+
+        # Insert each row using execute
+        for row_id, binary in test_data:
+            cursor.execute("INSERT INTO #pytest_varbinarymax VALUES (?, ?)", (row_id, binary))
         db_connection.commit()
-        print("Inserted large binary data (>8000 bytes) successfully.")
-
-        # commented out for now
-        # cursor.execute("SELECT large_binary FROM #pytest_large_binary WHERE id=1")
-        # result = cursor.fetchone()
-        # assert result[0] == large_data, f"Large binary data mismatch, got {len(result[0])} bytes"
-
-        # print("Large binary data (>8000 bytes) inserted and verified successfully.")
-
+
+        # ---------- FETCHONE TEST (multi-column) ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        rows = []
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+            rows.append(row)
+
+        assert len(rows) == len(test_data), f"Expected {len(test_data)} rows, got {len(rows)}"
+
+        # Validate each row
+        for i, (expected_id, expected_data) in enumerate(test_data):
+            fetched_id, fetched_data = rows[i]
+            assert fetched_id == expected_id, f"Row {i+1} ID mismatch: expected {expected_id}, got {fetched_id}"
+            assert isinstance(fetched_data, bytes), f"Row {i+1} expected bytes, got {type(fetched_data)}"
+            assert fetched_data == expected_data, f"Row {i+1} data mismatch"
+
+        # ---------- FETCHALL TEST ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        all_rows = cursor.fetchall()
+        assert len(all_rows) == len(test_data)
+
+        # ---------- FETCHMANY TEST ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        batch_size = 2
+        batches = []
+        while True:
+            batch = cursor.fetchmany(batch_size)
+            if not batch:
+                break
+            batches.extend(batch)
+        assert len(batches) == len(test_data)
+
     except Exception as e:
-        pytest.fail(f"Large binary data insertion test failed: {e}")
+        pytest.fail(f"VARBINARY(MAX) insert/fetch test failed: {e}")
     finally:
-        drop_table_if_exists(cursor, "#pytest_large_binary")
+        drop_table_if_exists(cursor, "#pytest_varbinarymax")
         db_connection.commit()
@@ -6303,6 +6340,40 @@ def test_binary_mostly_small_one_large(cursor, db_connection):
         drop_table_if_exists(cursor, "#pytest_mixed_size_binary")
         db_connection.commit()
 
+def test_varbinarymax_insert_fetch_null(cursor, db_connection):
+    """Test insertion and retrieval of NULL value in VARBINARY(MAX) column."""
+    try:
+        drop_table_if_exists(cursor, "#pytest_varbinarymax_null")
+        cursor.execute("""
+            CREATE TABLE #pytest_varbinarymax_null (
+                id INT,
+                binary_data VARBINARY(MAX)
+            )
+        """)
+
+        # Insert a row with NULL for binary_data
+        cursor.execute(
+            "INSERT INTO #pytest_varbinarymax_null VALUES (?, CAST(NULL AS VARBINARY(MAX)))",
+            (1,)
+        )
+        db_connection.commit()
+
+        # Fetch the row
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax_null")
+        row = cursor.fetchone()
+
+        assert row is not None, "No row fetched"
+        fetched_id, fetched_data = row
+        assert fetched_id == 1, "ID mismatch"
+        assert fetched_data is None, "Expected NULL for binary_data"
+
+    except Exception as e:
+        pytest.fail(f"VARBINARY(MAX) NULL insert/fetch test failed: {e}")
+
+    finally:
+        drop_table_if_exists(cursor, "#pytest_varbinarymax_null")
+        db_connection.commit()
+
 def test_only_null_and_empty_binary(cursor, db_connection):
     """Test table with only NULL and empty binary values to ensure fallback doesn't produce size=0"""
     try:
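
Note for reviewers: FetchLobColumnData, which the new streaming path calls, is defined elsewhere in ddbc_bindings.cpp and is not part of this diff. The sketch below shows the standard chunked SQLGetData loop that such a helper typically implements for SQL_C_BINARY data; the ReadBinaryLob name, the 8 KB chunk size, and the exception handling are illustrative assumptions, not the project's actual code.

// Illustrative only: minimal chunked read of one binary LOB column via ODBC.
// Not the real FetchLobColumnData implementation.
#include <sql.h>
#include <sqlext.h>
#include <stdexcept>
#include <vector>

// Appends the column's bytes to `out`; returns false if the value is SQL NULL.
// Assumes the statement handle is positioned on a fetched row.
static bool ReadBinaryLob(SQLHSTMT hStmt, SQLUSMALLINT col, std::vector<unsigned char>& out) {
    unsigned char chunk[8192];
    SQLLEN indicator = 0;
    SQLRETURN rc = SQLGetData(hStmt, col, SQL_C_BINARY, chunk, sizeof(chunk), &indicator);
    while (rc != SQL_NO_DATA) {
        if (!SQL_SUCCEEDED(rc)) {
            throw std::runtime_error("SQLGetData failed while streaming LOB");
        }
        if (indicator == SQL_NULL_DATA) {
            return false;  // NULL column: the bindings append py::none() in this case
        }
        // For SQL_C_BINARY, a truncated call (SQL_SUCCESS_WITH_INFO) fills the whole
        // buffer; the final SQL_SUCCESS call reports only the remaining byte count.
        size_t got = (indicator == SQL_NO_TOTAL || static_cast<size_t>(indicator) > sizeof(chunk))
                         ? sizeof(chunk)
                         : static_cast<size_t>(indicator);
        out.insert(out.end(), chunk, chunk + got);
        if (rc == SQL_SUCCESS) {
            break;  // everything retrieved
        }
        rc = SQLGetData(hStmt, col, SQL_C_BINARY, chunk, sizeof(chunk), &indicator);
    }
    return true;
}

The new fetch path takes this streaming route whenever columnSize is reported as 0, SQL_NO_TOTAL, or greater than 8000, which is exactly what the added tests exercise with the 8001-, 9000-, and 20000-byte payloads.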