Skip to content

Commit

Permalink
[CONJ-1135] bulk insert when useBulkStmtsForInserts only used for I…
Browse files Browse the repository at this point in the history
…NSERT NOT using "ON DUPLICATE KEY UPDATE"
  • Loading branch information
rusher committed Dec 13, 2023
1 parent f8b835c commit 6deda54
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 13 deletions.
15 changes: 10 additions & 5 deletions src/main/java/org/mariadb/jdbc/BasePreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,19 @@
import java.sql.ParameterMetaData;
import java.util.*;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Pattern;
import org.mariadb.jdbc.client.ColumnDecoder;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.codec.*;
import org.mariadb.jdbc.export.ExceptionFactory;
import org.mariadb.jdbc.export.Prepare;
import org.mariadb.jdbc.plugin.Codec;
import org.mariadb.jdbc.plugin.codec.*;
import org.mariadb.jdbc.util.ClientParser;
import org.mariadb.jdbc.util.ParameterList;
import org.mariadb.jdbc.util.constants.ServerStatus;

/** Common methods for prepare statement, for client and server prepare statement. */
public abstract class BasePreparedStatement extends Statement implements PreparedStatement {
private static final Pattern INSERT_STATEMENT_PATTERN =
Pattern.compile("^(\\s*/\\*([^*]|\\*[^/])*\\*/)*\\s*(INSERT)", Pattern.CASE_INSENSITIVE);

/** prepare statement sql command */
protected final String sql;
Expand Down Expand Up @@ -96,8 +95,14 @@ public String toString() {
}

protected void checkIfInsertCommand() {
if (isCommandInsert == null)
isCommandInsert = sql != null && INSERT_STATEMENT_PATTERN.matcher(sql).find();
if (isCommandInsert == null) {
if (sql == null) {
isCommandInsert = false;
} else {
ClientParser parser = ClientParser.parameterParts(sql, (con.getContext().getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0);
isCommandInsert = parser.isInsert() && !parser.isInsertDuplicate();
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public ClientPreparedStatement(
boolean noBackslashEscapes =
(con.getContext().getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0;
parser = ClientParser.parameterParts(sql, noBackslashEscapes);
isCommandInsert = parser.isInsert() && !parser.isInsertDuplicate();
parameters = new ParameterList(parser.getParamCount());
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mariadb/jdbc/Statement.java
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ public void closeOnCompletion() throws SQLException {
}
} else {
if (currResult != null && currResult instanceof ResultSet) {
Result res = (Result)currResult;
Result res = (Result) currResult;
if (res.streaming() || res.loaded()) {
res.closeOnCompletion();
}
Expand Down
69 changes: 66 additions & 3 deletions src/main/java/org/mariadb/jdbc/util/ClientParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@ public final class ClientParser implements PrepareResult {
private final byte[] query;
private final List<Integer> paramPositions;
private final int paramCount;

private ClientParser(String sql, byte[] query, List<Integer> paramPositions) {
private final boolean isInsert;
private final boolean isInsertDuplicate;

private ClientParser(
String sql,
byte[] query,
List<Integer> paramPositions,
boolean isInsert,
boolean isInsertDuplicate) {
this.sql = sql;
this.query = query;
this.paramPositions = paramPositions;
this.paramCount = paramPositions.size();
this.isInsert = isInsert;
this.isInsertDuplicate = isInsertDuplicate;
}

/**
Expand All @@ -39,6 +48,8 @@ public static ClientParser parameterParts(String queryString, boolean noBackslas
byte lastChar = 0x00;

boolean singleQuotes = false;
boolean isInsert = false;
boolean isInsertDupplicate = false;
byte[] query = queryString.getBytes(StandardCharsets.UTF_8);
int queryLength = query.length;
for (int i = 0; i < queryLength; i++) {
Expand Down Expand Up @@ -105,6 +116,50 @@ public static ClientParser parameterParts(String queryString, boolean noBackslas
}
break;

case (byte) 'I':
case (byte) 'i':
if (state == LexState.Normal && !isInsert) {
if (i + 6 < queryLength
&& (query[i + 1] == (byte) 'n' || query[i + 1] == (byte) 'N')
&& (query[i + 2] == (byte) 's' || query[i + 2] == (byte) 'S')
&& (query[i + 3] == (byte) 'e' || query[i + 3] == (byte) 'E')
&& (query[i + 4] == (byte) 'r' || query[i + 4] == (byte) 'R')
&& (query[i + 5] == (byte) 't' || query[i + 5] == (byte) 'T')) {
if (i > 0 && (query[i - 1] > ' ' && "();><=-+,".indexOf(query[i - 1]) == -1)) {
break;
}
if (query[i + 6] > ' ' && "();><=-+,".indexOf(query[i + 6]) == -1) {
break;
}
i += 5;
isInsert = true;
}
}
break;
case (byte) 'D':
case (byte) 'd':
if (isInsert && state == LexState.Normal) {
if (i + 9 < queryLength
&& (query[i + 1] == (byte) 'u' || query[i + 1] == (byte) 'U')
&& (query[i + 2] == (byte) 'p' || query[i + 2] == (byte) 'P')
&& (query[i + 3] == (byte) 'l' || query[i + 3] == (byte) 'L')
&& (query[i + 4] == (byte) 'i' || query[i + 4] == (byte) 'I')
&& (query[i + 5] == (byte) 'c' || query[i + 5] == (byte) 'C')
&& (query[i + 6] == (byte) 'a' || query[i + 6] == (byte) 'A')
&& (query[i + 7] == (byte) 't' || query[i + 7] == (byte) 'T')
&& (query[i + 8] == (byte) 'e' || query[i + 8] == (byte) 'E')) {
if (i > 0 && (query[i - 1] > ' ' && "();><=-+,".indexOf(query[i - 1]) == -1)) {
break;
}
if (query[i + 9] > ' ' && "();><=-+,".indexOf(query[i + 9]) == -1) {
break;
}
i += 9;
isInsertDupplicate = true;
}
}
break;

case (byte) '\\':
if (noBackslashEscapes) {
break;
Expand All @@ -129,7 +184,7 @@ public static ClientParser parameterParts(String queryString, boolean noBackslas
lastChar = car;
}

return new ClientParser(queryString, query, paramPositions);
return new ClientParser(queryString, query, paramPositions, isInsert, isInsertDupplicate);
}

public String getSql() {
Expand All @@ -148,6 +203,14 @@ public int getParamCount() {
return paramCount;
}

public boolean isInsert() {
return isInsert;
}

public boolean isInsertDuplicate() {
return isInsertDuplicate;
}

enum LexState {
Normal, /* inside query */
String, /* inside string */
Expand Down
36 changes: 35 additions & 1 deletion src/test/java/org/mariadb/jdbc/integration/BatchTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ public void differentParameterType() throws SQLException {
try (Connection con = createCon("&useServerPrepStmts=false&useBulkStmtsForInserts=false")) {
differentParameterType(con, false);
}
try (Connection con = createCon("&useServerPrepStmts=false&useBulkStmtsForInserts")) {
differentParameterType(con, false);
}
try (Connection con =
createCon("&useServerPrepStmts=false&useBulkStmts&useBulkStmtsForInserts")) {
differentParameterType(con, isMariaDBServer() && !isXpand());
Expand All @@ -263,6 +266,9 @@ public void differentParameterType() throws SQLException {
try (Connection con = createCon("&useServerPrepStmts&useBulkStmtsForInserts=false")) {
differentParameterType(con, false);
}
try (Connection con = createCon("&useServerPrepStmts&useBulkStmtsForInserts")) {
differentParameterType(con, false);
}
try (Connection con =
createCon("&useServerPrepStmts&useBulkStmtsForInserts&allowLocalInfile=false")) {
differentParameterType(con, false);
Expand Down Expand Up @@ -336,8 +342,8 @@ public void differentParameterType(Connection con, boolean expectSuccessUnknown)
assertEquals(3, rs.getInt(1));
assertNull(rs.getString(2));
assertFalse(rs.next());
stmt.execute("TRUNCATE BatchTest");

stmt.execute("TRUNCATE BatchTest");
try (PreparedStatement prep =
con.prepareStatement("INSERT INTO BatchTest(t1, t2) VALUES (?,?)")) {
prep.setInt(1, 1);
Expand Down Expand Up @@ -436,6 +442,34 @@ public void differentParameterType(Connection con, boolean expectSuccessUnknown)
assertEquals(1, res[1]);
}
}

stmt.execute("TRUNCATE BatchTest");
try (PreparedStatement prep =
con.prepareStatement(
"INSERT INTO BatchTest(t1, t2) VALUES (?,?) ON DUPLICATE KEY UPDATE t2='changed'")) {
prep.setInt(1, 5);
prep.setInt(2, 5);
prep.addBatch();

prep.setInt(1, 5);
prep.setInt(2, 6);
prep.addBatch();
int[] res = prep.executeBatch();
assertEquals(2, res.length);
if (expectSuccessUnknown) {
assertEquals(Statement.SUCCESS_NO_INFO, res[0]);
assertEquals(Statement.SUCCESS_NO_INFO, res[1]);
} else {
assertEquals(1, res[0]);
assertEquals(2, res[1]);
}
}
rs = stmt.executeQuery("SELECT * FROM BatchTest");
assertTrue(rs.next());
assertEquals(5, rs.getInt(1));
assertEquals("changed", rs.getString(2));
assertFalse(rs.next());

con.rollback();
}

Expand Down
31 changes: 28 additions & 3 deletions src/test/java/org/mariadb/jdbc/unit/util/ClientParserTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
@SuppressWarnings("ConstantConditions")
public class ClientParserTest {

private void parse(String sql, String[] expected, String[] expectedNoBackSlash) {
private void parse(
String sql, String[] expected, String[] expectedNoBackSlash, boolean isInsertDuplicate) {
ClientParser parser = ClientParser.parameterParts(sql, false);
assertEquals(expected.length, parser.getParamCount() + 1, displayErr(parser, expected));

Expand All @@ -38,6 +39,8 @@ private void parse(String sql, String[] expected, String[] expectedNoBackSlash)
assertEquals(
expectedNoBackSlash[expectedNoBackSlash.length - 1],
new String(parser.getQuery(), pos, paramPos - pos));

assertEquals(isInsertDuplicate, parser.isInsertDuplicate());
}

private String displayErr(ClientParser parser, String[] exp) {
Expand Down Expand Up @@ -66,10 +69,32 @@ public void ClientParser() {
parse(
"SELECT '\\\\test' /*test* #/ ;`*/",
new String[] {"SELECT '\\\\test' /*test* #/ ;`*/"},
new String[] {"SELECT '\\\\test' /*test* #/ ;`*/"});
new String[] {"SELECT '\\\\test' /*test* #/ ;`*/"},
false);
parse(
"DO '\\\"', \"\\'\"",
new String[] {"DO '\\\"', \"\\'\""},
new String[] {"DO '\\\"', \"\\'\""});
new String[] {"DO '\\\"', \"\\'\""},
false);
parse(
"INSERT INTO TABLE(id,val) VALUES (1,2)",
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2)"},
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2)"},
false);
parse(
"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE KEY UPDATE",
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE KEY UPDATE"},
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE KEY UPDATE"},
true);
parse(
"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE",
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE"},
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ON DUPLICATE"},
false);
parse(
"INSERT INTO TABLE(id,val) VALUES (1,2) ONDUPLICATE",
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ONDUPLICATE"},
new String[] {"INSERT INTO TABLE(id,val) VALUES (1,2) ONDUPLICATE"},
false);
}
}

0 comments on commit 6deda54

Please sign in to comment.