Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,22 +1529,19 @@ def _compute_column_type(self, column):
sample_value = v

return sample_value, None, None

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Prepare a database operation and execute it against all parameter sequences.
This version uses column-wise parameter binding and a single batched SQLExecute().
Args:
operation: SQL query or command.
seq_of_parameters: Sequence of sequences or mappings of parameters.

Raises:
Error: If the operation fails.
"""
self._check_closed()
self._reset_cursor()

# Clear any previous messages
self.messages = []

if not seq_of_parameters:
Expand All @@ -1570,6 +1567,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
param_count = len(sample_row)
param_info = ddbc_bindings.ParamInfo
parameters_type = []
any_dae = False

# Check if we have explicit input sizes set
if self._inputsizes:
Expand Down Expand Up @@ -1673,6 +1671,14 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
paraminfo.columnSize = max(max_binary_size, 1)

parameters_type.append(paraminfo)
if paraminfo.isDAE:
any_dae = True

if any_dae:
log('debug', "DAE parameters detected. Falling back to row-by-row execution with streaming.")
for row in seq_of_parameters:
self.execute(operation, row)
return

# Process parameters into column-wise format with possible type conversions
# First, convert any Decimal types as needed for NUMERIC/DECIMAL columns
Expand Down Expand Up @@ -1705,8 +1711,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
log('debug', "Executing batch query with %d parameter sets:\n%s",
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters[:5])) # Limit to first 5 rows for large batches
)

# Execute batched statement

ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
Expand Down
69 changes: 59 additions & 10 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,7 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle,
size_t paramSetSize) {
SQLHANDLE hStmt = statementHandle->get();
SQLWCHAR* queryPtr;

#if defined(__APPLE__) || defined(__linux__)
std::vector<SQLWCHAR> queryBuffer = WStringToSQLWCHAR(query);
queryPtr = queryBuffer.data();
Expand All @@ -2008,15 +2009,63 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle,
#endif
RETCODE rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS);
if (!SQL_SUCCEEDED(rc)) return rc;
std::vector<std::shared_ptr<void>> paramBuffers;
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER)paramSetSize, 0);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLExecute_ptr(hStmt);
return rc;

bool hasDAE = false;
for (const auto& p : paramInfos) {
if (p.isDAE) {
hasDAE = true;
break;
}
}
if (!hasDAE) {
std::vector<std::shared_ptr<void>> paramBuffers;
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers);
if (!SQL_SUCCEEDED(rc)) return rc;

rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER)paramSetSize, 0);
if (!SQL_SUCCEEDED(rc)) return rc;

rc = SQLExecute_ptr(hStmt);
return rc;
} else {
size_t rowCount = columnwise_params.size();
for (size_t rowIndex = 0; rowIndex < rowCount; ++rowIndex) {
py::list rowParams = columnwise_params[rowIndex];

std::vector<std::shared_ptr<void>> paramBuffers;
rc = BindParameters(hStmt, rowParams, const_cast<std::vector<ParamInfo>&>(paramInfos), paramBuffers);
if (!SQL_SUCCEEDED(rc)) return rc;

rc = SQLExecute_ptr(hStmt);
while (rc == SQL_NEED_DATA) {
SQLPOINTER token;
rc = SQLParamData_ptr(hStmt, &token);
if (!SQL_SUCCEEDED(rc) && rc != SQL_NEED_DATA) return rc;

py::object* py_obj_ptr = reinterpret_cast<py::object*>(token);
if (!py_obj_ptr) return SQL_ERROR;

if (py::isinstance<py::str>(*py_obj_ptr)) {
std::string data = py_obj_ptr->cast<std::string>();
SQLLEN data_len = static_cast<SQLLEN>(data.size());
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len);
} else if (py::isinstance<py::bytes>(*py_obj_ptr) || py::isinstance<py::bytearray>(*py_obj_ptr)) {
std::string data = py_obj_ptr->cast<std::string>();
SQLLEN data_len = static_cast<SQLLEN>(data.size());
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len);
} else {
LOG("Unsupported DAE parameter type in row {}", rowIndex);
return SQL_ERROR;
}
}

if (!SQL_SUCCEEDED(rc)) return rc;
}
return SQL_SUCCESS;
}
}


// Wrap SQLNumResultCols
SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) {
LOG("Get number of columns in result set");
Expand Down Expand Up @@ -2213,7 +2262,7 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt,
LOG("Loop {}: Appended {} bytes", loopCount, bytesRead);
}
if (ret == SQL_SUCCESS) {
LOG("Loop {}: SQL_SUCCESS no more data", loopCount);
LOG("Loop {}: SQL_SUCCESS, no more data", loopCount);
break;
}
}
Expand Down Expand Up @@ -3270,7 +3319,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch

// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("LOB columns detected using per-row SQLGetData path");
LOG("LOB columns detected, using per-row SQLGetData path");
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA) break;
Expand Down Expand Up @@ -3392,7 +3441,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {

// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("LOB columns detected using per-row SQLGetData path");
LOG("LOB columns detected, using per-row SQLGetData path");
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA) break;
Expand Down
93 changes: 93 additions & 0 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10661,6 +10661,99 @@ def test_decimal_separator_calculations(cursor, db_connection):
# Cleanup
cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test")

def test_nvarcharmax_executemany_streaming(cursor, db_connection):
    """Streaming insert + fetch > 4k NVARCHAR(MAX) using executemany with all fetch modes.

    Inserts two wide-character values longer than 4000 chars (forcing the
    DAE/streaming path in executemany) and verifies they round-trip intact
    via fetchall, fetchone, and fetchmany.
    """
    try:
        # Both values exceed the 4000-char NVARCHAR inline limit, so they
        # must be streamed as data-at-execution parameters.
        values = ["Ω" * 4100, "漢" * 5000]
        cursor.execute("CREATE TABLE #pytest_nvarcharmax (col NVARCHAR(MAX))")
        db_connection.commit()

        # --- executemany insert ---
        cursor.executemany("INSERT INTO #pytest_nvarcharmax VALUES (?)", [(v,) for v in values])
        db_connection.commit()

        # --- fetchall ---
        cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)")
        rows = [r[0] for r in cursor.fetchall()]
        assert rows == sorted(values, key=len)

        # --- fetchone ---
        cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)")
        r1 = cursor.fetchone()[0]
        r2 = cursor.fetchone()[0]
        assert {r1, r2} == set(values)
        assert cursor.fetchone() is None

        # --- fetchmany ---
        cursor.execute("SELECT col FROM #pytest_nvarcharmax ORDER BY LEN(col)")
        batch = [r[0] for r in cursor.fetchmany(1)]
        assert batch[0] in values
    finally:
        # IF EXISTS: if CREATE TABLE failed above, a bare DROP would raise
        # here and mask the original test failure.
        cursor.execute("DROP TABLE IF EXISTS #pytest_nvarcharmax")
        db_connection.commit()

def test_varcharmax_executemany_streaming(cursor, db_connection):
    """Streaming insert + fetch > 4k VARCHAR(MAX) using executemany with all fetch modes.

    Inserts two ASCII values longer than the inline VARCHAR limit (forcing
    the DAE/streaming path in executemany) and verifies they round-trip
    intact via fetchall, fetchone, and fetchmany.
    """
    try:
        # Both values exceed the inline VARCHAR limit, so they must be
        # streamed as data-at-execution parameters.
        values = ["A" * 4100, "B" * 5000]
        cursor.execute("CREATE TABLE #pytest_varcharmax (col VARCHAR(MAX))")
        db_connection.commit()

        # --- executemany insert ---
        cursor.executemany("INSERT INTO #pytest_varcharmax VALUES (?)", [(v,) for v in values])
        db_connection.commit()

        # --- fetchall ---
        cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)")
        rows = [r[0] for r in cursor.fetchall()]
        assert rows == sorted(values, key=len)

        # --- fetchone ---
        cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)")
        r1 = cursor.fetchone()[0]
        r2 = cursor.fetchone()[0]
        assert {r1, r2} == set(values)
        assert cursor.fetchone() is None

        # --- fetchmany ---
        cursor.execute("SELECT col FROM #pytest_varcharmax ORDER BY LEN(col)")
        batch = [r[0] for r in cursor.fetchmany(1)]
        assert batch[0] in values
    finally:
        # IF EXISTS: if CREATE TABLE failed above, a bare DROP would raise
        # here and mask the original test failure.
        cursor.execute("DROP TABLE IF EXISTS #pytest_varcharmax")
        db_connection.commit()

def test_varbinarymax_executemany_streaming(cursor, db_connection):
    """Streaming insert + fetch > 4k VARBINARY(MAX) using executemany with all fetch modes.

    Inserts two byte strings longer than the inline VARBINARY limit (forcing
    the DAE/streaming path in executemany) and verifies they round-trip
    intact via fetchall, fetchone, and fetchmany.
    """
    try:
        # Both values exceed the inline VARBINARY limit, so they must be
        # streamed as data-at-execution parameters.
        values = [b"\x01" * 4100, b"\x02" * 5000]
        cursor.execute("CREATE TABLE #pytest_varbinarymax (col VARBINARY(MAX))")
        db_connection.commit()

        # --- executemany insert ---
        cursor.executemany("INSERT INTO #pytest_varbinarymax VALUES (?)", [(v,) for v in values])
        db_connection.commit()

        # --- fetchall ---
        # DATALENGTH (byte length) on the server matches len() on bytes here.
        cursor.execute("SELECT col FROM #pytest_varbinarymax ORDER BY DATALENGTH(col)")
        rows = [r[0] for r in cursor.fetchall()]
        assert rows == sorted(values, key=len)

        # --- fetchone ---
        cursor.execute("SELECT col FROM #pytest_varbinarymax ORDER BY DATALENGTH(col)")
        r1 = cursor.fetchone()[0]
        r2 = cursor.fetchone()[0]
        assert {r1, r2} == set(values)
        assert cursor.fetchone() is None

        # --- fetchmany ---
        cursor.execute("SELECT col FROM #pytest_varbinarymax ORDER BY DATALENGTH(col)")
        batch = [r[0] for r in cursor.fetchmany(1)]
        assert batch[0] in values
    finally:
        # IF EXISTS: if CREATE TABLE failed above, a bare DROP would raise
        # here and mask the original test failure.
        cursor.execute("DROP TABLE IF EXISTS #pytest_varbinarymax")
        db_connection.commit()

def test_date_string_parameter_binding(cursor, db_connection):
"""Verify that date-like strings are treated as strings in parameter binding"""
table_name = "#pytest_date_string"
Expand Down