diff --git a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformation.java b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformation.java index d67412907..90d4feed9 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformation.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformation.java @@ -71,7 +71,7 @@ public interface CmdInformation { void addErrorStat(); - void clearErrorStat(); + void reset(); void addResultSetStat(); diff --git a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationBatch.java b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationBatch.java index 770c2667e..dcb890a06 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationBatch.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationBatch.java @@ -64,8 +64,8 @@ public class CmdInformationBatch implements CmdInformation { - private final Queue insertIds; - private final Queue updateCounts; + private final Queue insertIds = new ConcurrentLinkedQueue<>(); + private final Queue updateCounts = new ConcurrentLinkedQueue<>(); private final int expectedSize; private final int autoIncrement; private int insertIdNumber = 0; @@ -84,15 +84,13 @@ public class CmdInformationBatch implements CmdInformation { */ public CmdInformationBatch(int expectedSize, int autoIncrement) { this.expectedSize = expectedSize; - this.insertIds = new ConcurrentLinkedQueue<>(); - this.updateCounts = new ConcurrentLinkedQueue<>(); this.autoIncrement = autoIncrement; } @Override public void addErrorStat() { hasException = true; - this.updateCounts.add((long) Statement.EXECUTE_FAILED); + updateCounts.add((long) Statement.EXECUTE_FAILED); } /** @@ -100,9 +98,12 @@ public void addErrorStat() { * */ @Override - public void clearErrorStat() { + public void reset() { + insertIds.clear(); + updateCounts.clear(); + insertIdNumber = 0; hasException = false; - this.updateCounts.remove((long) Statement.EXECUTE_FAILED); + rewritten = false; } public void addResultSetStat() { @@ -111,9 +112,9 @@ public void addResultSetStat() { @Override public void addSuccessStat(long updateCount, long insertId) { - this.insertIds.add(insertId); + insertIds.add(insertId); insertIdNumber += updateCount; - this.updateCounts.add(updateCount); + updateCounts.add(updateCount); } @Override diff --git a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationMultiple.java b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationMultiple.java index 5c3abe41a..cc6794694 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationMultiple.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationMultiple.java @@ -80,16 +80,16 @@ public class CmdInformationMultiple implements CmdInformation { * @param autoIncrement connection auto increment value. */ public CmdInformationMultiple(int expectedSize, int autoIncrement) { + insertIds = new ArrayList<>(expectedSize); + updateCounts = new ArrayList<>(expectedSize); this.expectedSize = expectedSize; - this.insertIds = new ArrayList<>(expectedSize); - this.updateCounts = new ArrayList<>(expectedSize); this.autoIncrement = autoIncrement; } @Override public void addErrorStat() { hasException = true; - this.updateCounts.add((long) Statement.EXECUTE_FAILED); + updateCounts.add((long) Statement.EXECUTE_FAILED); } /** @@ -97,21 +97,25 @@ public void addErrorStat() { * */ @Override - public void clearErrorStat() { + public void reset() { + insertIds.clear(); + updateCounts.clear(); + insertIdNumber = 0; + moreResults = 0; hasException = false; - this.updateCounts.remove((long) Statement.EXECUTE_FAILED); + rewritten = false; } public void addResultSetStat() { - this.updateCounts.add((long) RESULT_SET_VALUE); + updateCounts.add((long) RESULT_SET_VALUE); } @Override public void addSuccessStat(long updateCount, long insertId) { - this.insertIds.add(insertId); + insertIds.add(insertId); insertIdNumber += updateCount; - this.updateCounts.add(updateCount); + updateCounts.add(updateCount); } @Override diff --git a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationSingle.java b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationSingle.java index b90e292bf..ff992049e 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationSingle.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/read/dao/CmdInformationSingle.java @@ -103,7 +103,7 @@ public void addErrorStat() { } @Override - public void clearErrorStat() { + public void reset() { //not expected } diff --git a/src/main/java/org/mariadb/jdbc/internal/protocol/AbstractQueryProtocol.java b/src/main/java/org/mariadb/jdbc/internal/protocol/AbstractQueryProtocol.java index ab462fd8f..832dafb55 100644 --- a/src/main/java/org/mariadb/jdbc/internal/protocol/AbstractQueryProtocol.java +++ b/src/main/java/org/mariadb/jdbc/internal/protocol/AbstractQueryProtocol.java @@ -430,10 +430,10 @@ private boolean executeBulkBatch(Results results, String sql, ServerPrepareResul getResult(results); } catch (SQLException sqle) { if ("HY000".equals(sqle.getSQLState()) && sqle.getErrorCode() == 1295) { - //query contain SELECT or DELETE. cannot be handle by BULK protocol + //query contain commands that cannot be handled by BULK protocol // clear error and special error code, so it won't leak anywhere // and wouldn't be misinterpreted as an additional update count - results.getCmdInformation().clearErrorStat(); + results.getCmdInformation().reset(); return false; } if (exception == null) { diff --git a/src/test/java/org/mariadb/jdbc/StatementTest.java b/src/test/java/org/mariadb/jdbc/StatementTest.java index 7cdf6974c..dd3e5e257 100644 --- a/src/test/java/org/mariadb/jdbc/StatementTest.java +++ b/src/test/java/org/mariadb/jdbc/StatementTest.java @@ -367,17 +367,34 @@ public void testFallbackBatchUpdate() throws SQLException { Assume.assumeTrue(doPrecisionTest); createTable("testFallbackBatchUpdate", "col int"); - int[] results; - int queriesInBatch = 2; + Statement statement = sharedConnection.createStatement(); + + //add 100 data + StringBuilder sb = new StringBuilder("INSERT INTO testFallbackBatchUpdate(col) VALUES (0)"); + for (int i = 1; i < 100; i++) sb.append(",(").append(i).append(")"); + statement.execute(sb.toString()); + try (PreparedStatement preparedStatement = sharedConnection.prepareStatement( - "DELETE FROM testFallbackBatchUpdate WHERE col = ? ")) { - for (int i = 0; i < queriesInBatch; i++) { - preparedStatement.setInt(1, 0); - preparedStatement.addBatch(); + "DELETE FROM testFallbackBatchUpdate WHERE col = ?")) { + preparedStatement.setInt(1, 10); + preparedStatement.addBatch(); + + preparedStatement.setInt(1, 15); + preparedStatement.addBatch(); + + int[] results = preparedStatement.executeBatch(); + assertEquals(2, results.length); + } + + //check results + try (ResultSet rs = statement.executeQuery("SELECT * FROM testFallbackBatchUpdate")) { + for (int i = 0; i < 100; i++) { + if (i == 10 || i == 15) continue; + assertTrue(rs.next()); + assertEquals(i, rs.getInt(1)); } - results = preparedStatement.executeBatch(); + assertFalse(rs.next()); } - assertEquals(results.length, queriesInBatch); } @Test @@ -385,17 +402,35 @@ public void testProperBatchUpdate() throws SQLException { Assume.assumeTrue(doPrecisionTest); createTable("testProperBatchUpdate", "col int, col2 int"); - int[] results; - int queriesInBatch = 3; + Statement statement = sharedConnection.createStatement(); + + //add 100 data + StringBuilder sb = new StringBuilder("INSERT INTO testProperBatchUpdate(col, col2) VALUES (0,0)"); + for (int i = 1; i < 100; i++) sb.append(",(").append(i).append(",0)"); + statement.execute(sb.toString()); + try (PreparedStatement preparedStatement = sharedConnection.prepareStatement( "UPDATE testProperBatchUpdate set col2 = ? WHERE col = ? ")) { - for (int i = 0; i < queriesInBatch; i++) { - preparedStatement.setInt(1, i); - preparedStatement.setInt(2, i); - preparedStatement.addBatch(); + preparedStatement.setInt(1, 10); + preparedStatement.setInt(2, 10); + preparedStatement.addBatch(); + + preparedStatement.setInt(1, 15); + preparedStatement.setInt(2, 15); + preparedStatement.addBatch(); + + int[] results = preparedStatement.executeBatch(); + assertEquals(2, results.length); + } + + //check results + try (ResultSet rs = statement.executeQuery("SELECT * FROM testProperBatchUpdate")) { + for (int i = 0; i < 100; i++) { + assertTrue(rs.next()); + assertEquals(i, rs.getInt(1)); + assertEquals((i == 10 || i == 15) ? i : 0, rs.getInt(2)); } - results = preparedStatement.executeBatch(); + assertFalse(rs.next()); } - assertEquals(results.length, queriesInBatch); } } \ No newline at end of file