Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[misc] Batch result correction
  • Loading branch information
diego Dupin committed Oct 27, 2021
1 parent fa7635e commit a2f2523
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Expand Up @@ -7,7 +7,7 @@
<artifactId>mariadb-java-client</artifactId>
<packaging>jar</packaging>
<name>mariadb-java-client</name>
<version>3.0.2-rc</version>
<version>3.0.3-SNAPSHOT</version>
<description>JDBC driver for MariaDB and MySQL</description>
<url>https://mariadb.com/kb/en/mariadb/about-mariadb-connector-j/</url>

Expand Down
Expand Up @@ -416,7 +416,7 @@ public int[] executeBatch() throws SQLException {
updates[i] = Statement.SUCCESS_NO_INFO;
}
} else {
for (int i = 0; i < Math.min(results.size(), batchParameters.size()); i++) {
for (int i = 0; i < updates.length; i++) {
if (results.get(i) instanceof OkPacket) {
updates[i] = (int) ((OkPacket) results.get(i)).getAffectedRows();
} else {
Expand Down
18 changes: 11 additions & 7 deletions src/main/java/org/mariadb/jdbc/ServerPreparedStatement.java
Expand Up @@ -550,12 +550,12 @@ public int[] executeBatch() throws SQLException {
executeInternalPreparedBatch();

int[] updates = new int[batchParameters.size()];
if (results.size() != batchParameters.size()) {
for (int i = 0; i < batchParameters.size(); i++) {
if (results.size() != updates.length) {
for (int i = 0; i < updates.length; i++) {
updates[i] = Statement.SUCCESS_NO_INFO;
}
} else {
for (int i = 0; i < Math.min(results.size(), batchParameters.size()); i++) {
for (int i = 0; i < updates.length; i++) {
if (results.get(i) instanceof OkPacket) {
updates[i] = (int) ((OkPacket) results.get(i)).getAffectedRows();
} else {
Expand All @@ -581,13 +581,17 @@ public long[] executeLargeBatch() throws SQLException {
executeInternalPreparedBatch();

long[] updates = new long[batchParameters.size()];
if (results.size() != batchParameters.size()) {
for (int i = 0; i < batchParameters.size(); i++) {
if (results.size() != updates.length) {
for (int i = 0; i < updates.length; i++) {
updates[i] = Statement.SUCCESS_NO_INFO;
}
} else {
for (int i = 0; i < results.size(); i++) {
updates[i] = ((OkPacket) results.get(i)).getAffectedRows();
for (int i = 0; i < updates.length; i++) {
if (results.get(i) instanceof OkPacket) {
updates[i] = ((OkPacket) results.get(i)).getAffectedRows();
} else {
updates[i] = org.mariadb.jdbc.Statement.SUCCESS_NO_INFO;
}
}
}

Expand Down
40 changes: 33 additions & 7 deletions src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java
Expand Up @@ -481,25 +481,25 @@ public List<Completion> executePipeline(
responseMsg[i] = sendQuery(messages[i]);
}
while (readCounter < messages.length) {
readCounter++;
for (int j = 0; j < responseMsg[readCounter - 1]; j++) {
for (int j = 0; j < responseMsg[readCounter]; j++) {
results.addAll(
readResponse(
stmt,
messages[readCounter - 1],
messages[readCounter],
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
}
readCounter++;
}
}
return results;
} catch (SQLException sqlException) {

// read remaining results
for (int i = readCounter; i < messages.length; i++) {
for (int i = readCounter + 1; i < messages.length; i++) {
for (int j = 0; j < responseMsg[i]; j++) {
try {
results.addAll(
Expand Down Expand Up @@ -546,9 +546,35 @@ public List<Completion> execute(
int resultSetType,
boolean closeOnCompletion)
throws SQLException {
sendQuery(message);
return readResponse(
stmt, message, fetchSize, maxRows, resultSetConcurrency, resultSetType, closeOnCompletion);
int nbResp = sendQuery(message);
if (nbResp == 1) {
return readResponse(
stmt,
message,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
} else {
if (streamStmt != null) {
streamStmt.fetchRemaining();
streamStmt = null;
}
List<Completion> completions = new ArrayList<>();
while (nbResp-- > 0) {
readResults(
stmt,
message,
completions,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
}
return completions;
}
}

public List<Completion> readResponse(
Expand Down
Expand Up @@ -128,7 +128,7 @@ public boolean binaryProtocol() {
}

public String description() {
return command;
return "EXECUTE " + command;
}

public void setPrepareResult(PrepareResultPacket prepareResult) {
Expand Down
Expand Up @@ -18,6 +18,6 @@ public int encode(Writer writer, Context context) throws IOException {
writer.initPacket();
writer.writeByte(0x0e);
writer.flush();
return 0;
return 1;
}
}
Expand Up @@ -95,6 +95,6 @@ public Completion readPacket(

@Override
public String description() {
return sql;
return "PREPARE " + sql;
}
}
79 changes: 70 additions & 9 deletions src/test/java/org/mariadb/jdbc/integration/BatchTest.java
Expand Up @@ -77,31 +77,39 @@ public void wrongParameter(Connection con) throws SQLException {

@Test
public void differentParameterType() throws SQLException {
try (Connection con = createCon("&useServerPrepStmts=false")) {
differentParameterType(con);
try (Connection con = createCon("&useServerPrepStmts=false&useBulkStmts=false")) {
differentParameterType(con, false);
}
try (Connection con = createCon("&useServerPrepStmts=false&useBulkStmts=true")) {
differentParameterType(con, isMariaDBServer());
}
try (Connection con =
createCon("&useServerPrepStmts=false&useBulkStmts=true&disablePipeline")) {
differentParameterType(con, isMariaDBServer());
}
try (Connection con = createCon("&useServerPrepStmts&useBulkStmts=false")) {
differentParameterType(con);
differentParameterType(con, false);
}
try (Connection con = createCon("&useServerPrepStmts&useBulkStmts")) {
differentParameterType(con);
differentParameterType(con, isMariaDBServer());
}
try (Connection con = createCon("&useServerPrepStmts=false&allowLocalInfile")) {
differentParameterType(con);
differentParameterType(con, isMariaDBServer());
}
try (Connection con = createCon("&useServerPrepStmts&useBulkStmts=false&allowLocalInfile")) {
differentParameterType(con);
differentParameterType(con, false);
}
try (Connection con = createCon("&useServerPrepStmts&useBulkStmts&allowLocalInfile")) {
differentParameterType(con);
differentParameterType(con, false);
}
try (Connection con =
createCon("&useServerPrepStmts&useBulkStmts=false&disablePipeline=true")) {
differentParameterType(con);
differentParameterType(con, false);
}
}

public void differentParameterType(Connection con) throws SQLException {
public void differentParameterType(Connection con, boolean expectSuccessUnknown)
throws SQLException {
Statement stmt = con.createStatement();
stmt.execute("TRUNCATE BatchTest");
try (PreparedStatement prep =
Expand All @@ -126,6 +134,59 @@ public void differentParameterType(Connection con) throws SQLException {
assertEquals(2, rs.getInt(1));
assertEquals("2", rs.getString(2));
assertFalse(rs.next());

stmt.execute("TRUNCATE BatchTest");
try (PreparedStatement prep =
con.prepareStatement("INSERT INTO BatchTest(t1, t2) VALUES (?,?)")) {
prep.setInt(1, 1);
prep.setInt(2, 1);
prep.addBatch();

prep.setInt(1, 2);
prep.setInt(2, 2);
prep.addBatch();
int[] res = prep.executeBatch();
assertEquals(2, res.length);
if (expectSuccessUnknown) {
assertEquals(Statement.SUCCESS_NO_INFO, res[0]);
assertEquals(Statement.SUCCESS_NO_INFO, res[1]);
} else {
assertEquals(1, res[0]);
assertEquals(1, res[1]);
}
}
rs = stmt.executeQuery("SELECT * FROM BatchTest");
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertEquals("1", rs.getString(2));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertEquals("2", rs.getString(2));
assertFalse(rs.next());

stmt.execute("TRUNCATE BatchTest");
try (PreparedStatement prep =
con.prepareStatement("INSERT INTO BatchTest(t1, t2) VALUES (?,?)")) {
prep.setInt(1, 1);
prep.setString(2, "1");
prep.addBatch();

prep.setInt(1, 2);
prep.setInt(2, 2);
prep.addBatch();
int[] res = prep.executeBatch();
assertEquals(2, res.length);
assertEquals(1, res[0]);
assertEquals(1, res[1]);
}
rs = stmt.executeQuery("SELECT * FROM BatchTest");
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertEquals("1", rs.getString(2));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertEquals("2", rs.getString(2));
assertFalse(rs.next());
}

@Test
Expand Down

0 comments on commit a2f2523

Please sign in to comment.