diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9a828011..12b806cf 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3870,12 +3870,13 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch } } + SQLULEN numRowsFetched = 0; // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap if (!lobColumns.empty()) { LOG("FetchMany_wrap: LOB columns detected (%zu columns), using per-row " "SQLGetData path", lobColumns.size()); - while (true) { + while (numRowsFetched < (SQLULEN)fetchSize) { ret = SQLFetch_ptr(hStmt); if (ret == SQL_NO_DATA) break; @@ -3883,9 +3884,9 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; py::list row; - SQLGetData_wrap(StatementHandle, numCols, - row); // <-- streams LOBs correctly + SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly rows.append(row); + numRowsFetched++; } return SQL_SUCCESS; } @@ -3899,8 +3900,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; } - - SQLULEN numRowsFetched; + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index cfc4ccf4..d209b2dc 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -30,6 +30,7 @@ integer_column INTEGER, float_column FLOAT, wvarchar_column NVARCHAR(255), + lob_wvarchar_column NVARCHAR(MAX), time_column TIME, datetime_column DATETIME, date_column DATE, @@ -47,6 +48,7 @@ 2147483647, 1.23456789, "nvarchar data", + "nvarchar data", time(12, 34, 56), datetime(2024, 5, 20, 12, 34, 56, 123000), date(2024, 5, 20), @@ -65,6 +67,7 @@ 0, 0.0, "test1", + "nvarchar data", time(0, 0, 0), datetime(2024, 1, 1, 0, 0, 0), date(2024, 1, 1), @@ -79,6 +82,7 @@ 1, 1.1, "test2", + "test2", time(1, 1, 1), datetime(2024, 2, 2, 1, 1, 1), date(2024, 2, 2), @@ -93,6 +97,7 @@ 2147483647, 1.23456789, "test3", + "test3", time(12, 34, 56), datetime(2024, 5, 20, 12, 34, 56, 123000), date(2024, 5, 20), @@ -821,7 +826,7 @@ def test_insert_args(cursor, db_connection): cursor.execute( """ INSERT INTO #pytest_all_data_types VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) """, TEST_DATA[0], @@ -836,6 +841,7 @@ def test_insert_args(cursor, db_connection): TEST_DATA[9], TEST_DATA[10], TEST_DATA[11], + TEST_DATA[12], ) db_connection.commit() cursor.execute("SELECT * FROM #pytest_all_data_types WHERE id = 1") @@ -855,7 +861,7 @@ def test_parametrized_insert(cursor, db_connection, data): cursor.execute( """ INSERT INTO #pytest_all_data_types VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) """, [None if v is None else v for v in data], @@ -930,14 +936,34 @@ def test_rowcount_executemany(cursor, db_connection): def test_fetchone(cursor): """Test fetching a single row""" - cursor.execute("SELECT * FROM #pytest_all_data_types WHERE id = 1") + cursor.execute( + "SELECT id, bit_column, tinyint_column, smallint_column, bigint_column, integer_column, float_column, wvarchar_column, time_column, datetime_column, date_column, real_column FROM #pytest_all_data_types" + ) row = cursor.fetchone() assert row is not None, "No row returned" assert len(row) == 12, "Incorrect number of columns" +def test_fetchone_lob(cursor): + """Test fetching a single row with LOB columns""" + cursor.execute("SELECT * FROM #pytest_all_data_types") + row = cursor.fetchone() + assert row is not None, "No row returned" + assert len(row) == 13, "Incorrect number of columns" + + def test_fetchmany(cursor): """Test fetching multiple rows""" + cursor.execute( + "SELECT id, bit_column, tinyint_column, smallint_column, bigint_column, integer_column, float_column, wvarchar_column, time_column, datetime_column, date_column, real_column FROM #pytest_all_data_types" + ) + rows = cursor.fetchmany(2) + assert isinstance(rows, list), "fetchmany should return a list" + assert len(rows) == 2, "Incorrect number of rows returned" + + +def test_fetchmany_lob(cursor): + """Test fetching multiple rows with LOB columns""" cursor.execute("SELECT * FROM #pytest_all_data_types") rows = cursor.fetchmany(2) assert isinstance(rows, list), "fetchmany should return a list" @@ -947,12 +973,32 @@ def test_fetchmany(cursor): def test_fetchmany_with_arraysize(cursor, db_connection): """Test fetchmany with arraysize""" cursor.arraysize = 3 - cursor.execute("SELECT * FROM #pytest_all_data_types") + cursor.execute( + "SELECT id, bit_column, tinyint_column, smallint_column, bigint_column, integer_column, float_column, wvarchar_column, time_column, datetime_column, date_column, real_column FROM #pytest_all_data_types" + ) rows = cursor.fetchmany() assert len(rows) == 3, "fetchmany with arraysize returned incorrect number of rows" +def test_fetchmany_lob_with_arraysize(cursor, db_connection): + """Test fetchmany with arraysize with LOB columns""" + cursor.arraysize = 3 + cursor.execute("SELECT * FROM #pytest_all_data_types") + rows = cursor.fetchmany() + assert len(rows) == 3, "fetchmany_lob with arraysize returned incorrect number of rows" + + def test_fetchall(cursor): + """Test fetching all rows""" + cursor.execute( + "SELECT id, bit_column, tinyint_column, smallint_column, bigint_column, integer_column, float_column, wvarchar_column, time_column, datetime_column, date_column, real_column FROM #pytest_all_data_types" + ) + rows = cursor.fetchall() + assert isinstance(rows, list), "fetchall should return a list" + assert len(rows) == len(PARAM_TEST_DATA), "Incorrect number of rows returned" + + +def test_fetchall_lob(cursor): """Test fetching all rows""" cursor.execute("SELECT * FROM #pytest_all_data_types") rows = cursor.fetchall() @@ -980,10 +1026,11 @@ def test_execute_invalid_query(cursor): # assert row[5] == TEST_DATA[5], "Integer mismatch" # assert round(row[6], 5) == round(TEST_DATA[6], 5), "Float mismatch" # assert row[7] == TEST_DATA[7], "Nvarchar mismatch" -# assert row[8] == TEST_DATA[8], "Time mismatch" -# assert row[9] == TEST_DATA[9], "Datetime mismatch" -# assert row[10] == TEST_DATA[10], "Date mismatch" -# assert round(row[11], 5) == round(TEST_DATA[11], 5), "Real mismatch" +# assert row[8] == TEST_DATA[8], "Nvarchar max mismatch" +# assert row[9] == TEST_DATA[9], "Time mismatch" +# assert row[10] == TEST_DATA[10], "Datetime mismatch" +# assert row[11] == TEST_DATA[11], "Date mismatch" +# assert round(row[12], 5) == round(TEST_DATA[12], 5), "Real mismatch" def test_arraysize(cursor): @@ -998,7 +1045,7 @@ def test_description(cursor): """Test description""" cursor.execute("SELECT * FROM #pytest_all_data_types WHERE id = 1") desc = cursor.description - assert len(desc) == 12, "Description length mismatch" + assert len(desc) == 13, "Description length mismatch" assert desc[0][0] == "id", "Description column name mismatch"