Skip to content

Commit

Permalink
[CONJ-1010] improve client side prepared parameter parameter substitu…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
rusher committed Sep 7, 2022
1 parent be083fe commit 0e4dbec
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 79 deletions.
6 changes: 5 additions & 1 deletion src/main/java/org/mariadb/jdbc/codec/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ public Parameter(Codec<T> codec, T value, Long length) {
}

public void encodeText(Writer encoder, Context context) throws IOException, SQLException {
codec.encodeText(encoder, context, this.value, null, length);
if (value == null) {
encoder.writeAscii("null");
} else {
codec.encodeText(encoder, context, this.value, null, length);
}
}

public void encodeBinary(Writer encoder) throws IOException, SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,18 @@ public int encode(Writer encoder, Context context) throws IOException, SQLExcept
encoder.initPacket();
encoder.writeByte(0x03);
if (preSqlCmd != null) encoder.writeAscii(preSqlCmd);
if (parser.getParamCount() == 0) {
encoder.writeBytes(parser.getQueryParts().get(0));
if (parser.getParamPositions().size() == 0) {
encoder.writeBytes(parser.getQuery());
} else {
encoder.writeBytes(parser.getQueryParts().get(0));
for (int i = 0; i < parser.getParamCount(); i++) {
if (parameters.get(i).isNull()) {
encoder.writeAscii("null");
} else {
parameters.get(i).encodeText(encoder, context);
}
encoder.writeBytes(parser.getQueryParts().get(i + 1));
int pos = 0;
int paramPos;
for (int i = 0; i < parser.getParamPositions().size(); i++) {
paramPos = parser.getParamPositions().get(i);
encoder.writeBytes(parser.getQuery(), pos, paramPos - pos);
pos = paramPos + 1;
parameters.get(i).encodeText(encoder, context);
}
encoder.writeBytes(parser.getQuery(), pos, parser.getQuery().length - pos);
}
encoder.flush();
return 1;
Expand Down
84 changes: 33 additions & 51 deletions src/main/java/org/mariadb/jdbc/util/ClientParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
public final class ClientParser implements PrepareResult {

private final String sql;
private final List<byte[]> queryParts;
private final int paramCount;
private final byte[] query;
private List<Integer> paramPositions;
private int paramCount;

private ClientParser(String sql, List<byte[]> queryParts) {
private ClientParser(String sql, byte[] query, List<Integer> paramPositions) {
this.sql = sql;
this.queryParts = queryParts;
this.paramCount = queryParts.size() - 1;
this.query = query;
this.paramPositions = paramPositions;
this.paramCount = paramPositions.size();
}

/**
Expand All @@ -33,59 +35,56 @@ private ClientParser(String sql, List<byte[]> queryParts) {
*/
public static ClientParser parameterParts(String queryString, boolean noBackslashEscapes) {

List<byte[]> partList = new ArrayList<>();
List<Integer> paramPositions = new ArrayList<>();
LexState state = LexState.Normal;
char lastChar = '\0';
boolean endingSemicolon = false;
byte lastChar = 0x00;

boolean singleQuotes = false;
int lastParameterPosition = 0;

char[] query = queryString.toCharArray();
byte[] query = queryString.getBytes(StandardCharsets.UTF_8);
int queryLength = query.length;
for (int i = 0; i < queryLength; i++) {

char car = query[i];
byte car = query[i];
if (state == LexState.Escape
&& !((car == '\'' && singleQuotes) || (car == '"' && !singleQuotes))) {
state = LexState.String;
lastChar = car;
continue;
}
switch (car) {
case '*':
if (state == LexState.Normal && lastChar == '/') {
case (byte) '*':
if (state == LexState.Normal && lastChar == (byte) '/') {
state = LexState.SlashStarComment;
}
break;

case '/':
if (state == LexState.SlashStarComment && lastChar == '*') {
case (byte) '/':
if (state == LexState.SlashStarComment && lastChar == (byte) '*') {
state = LexState.Normal;
} else if (state == LexState.Normal && lastChar == '/') {
} else if (state == LexState.Normal && lastChar == (byte) '/') {
state = LexState.EOLComment;
}
break;

case '#':
case (byte) '#':
if (state == LexState.Normal) {
state = LexState.EOLComment;
}
break;

case '-':
if (state == LexState.Normal && lastChar == '-') {
case (byte) '-':
if (state == LexState.Normal && lastChar == (byte) '-') {
state = LexState.EOLComment;
}
break;

case '\n':
case (byte) '\n':
if (state == LexState.EOLComment) {
state = LexState.Normal;
}
break;

case '"':
case (byte) '"':
if (state == LexState.Normal) {
state = LexState.String;
singleQuotes = false;
Expand All @@ -96,7 +95,7 @@ public static ClientParser parameterParts(String queryString, boolean noBackslas
}
break;

case '\'':
case (byte) '\'':
if (state == LexState.Normal) {
state = LexState.String;
singleQuotes = true;
Expand All @@ -107,60 +106,43 @@ public static ClientParser parameterParts(String queryString, boolean noBackslas
}
break;

case '\\':
case (byte) '\\':
if (noBackslashEscapes) {
break;
}
if (state == LexState.String) {
state = LexState.Escape;
}
break;
case ';':
case (byte) '?':
if (state == LexState.Normal) {
endingSemicolon = true;
paramPositions.add(i);
}
break;
case '?':
if (state == LexState.Normal) {
partList.add(
queryString.substring(lastParameterPosition, i).getBytes(StandardCharsets.UTF_8));
lastParameterPosition = i + 1;
}
break;
case '`':
case (byte) '`':
if (state == LexState.Backtick) {
state = LexState.Normal;
} else if (state == LexState.Normal) {
state = LexState.Backtick;
}
break;
default:
// multiple queries
if (state == LexState.Normal && endingSemicolon && ((byte) car >= 40)) {
endingSemicolon = false;
}
break;
}
lastChar = car;
}
if (lastParameterPosition == 0) {
partList.add(queryString.getBytes(StandardCharsets.UTF_8));
} else {
partList.add(
queryString
.substring(lastParameterPosition, queryLength)
.getBytes(StandardCharsets.UTF_8));
}

return new ClientParser(queryString, partList);
return new ClientParser(queryString, query, paramPositions);
}

public String getSql() {
return sql;
}

public List<byte[]> getQueryParts() {
return queryParts;
public byte[] getQuery() {
return query;
}

public List<Integer> getParamPositions() {
return paramPositions;
}

public int getParamCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@ private void checkParsing(String sql, int paramNumber, String[] partsMulti) thro
(sharedConn.getContext().getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0;
ClientParser parser = ClientParser.parameterParts(sql, noBackslashEscapes);
assertEquals(paramNumber, parser.getParamCount());

for (int i = 0; i < partsMulti.length; i++) {
assertEquals(partsMulti[i], new String(parser.getQueryParts().get(i)));
int pos = 0;
int paramPos = parser.getQuery().length;
for (int i = 0; i < parser.getParamPositions().size(); i++) {
paramPos = parser.getParamPositions().get(i);
assertEquals(partsMulti[i], new String(parser.getQuery(), pos, paramPos - pos));
pos = paramPos + 1;
}
assertEquals(
partsMulti[partsMulti.length - 1],
new String(parser.getQuery(), pos, parser.getQuery().length - pos));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,25 @@ void testObject(ResultSet rs, Class<?> objClass, Object exp, int idx) throws Exc
}
}

void testErrObject(ResultSet rs, Class<?> objClass, int index) throws SQLException {
Assertions.assertThrows(SQLException.class, () -> rs.getObject(index, objClass));
Assertions.assertThrows(
SQLException.class, () -> rs.getObject("t" + index + "alias", objClass));
}

void testErrObject(ResultSet rs, Class<?> objClass) throws SQLException {
Assertions.assertThrows(SQLException.class, () -> rs.getObject(1, objClass));
Assertions.assertThrows(SQLException.class, () -> rs.getObject("t1alias", objClass));
testErrObject(rs, objClass, 1);
assertNull(rs.getObject(4, objClass));
assertNull(rs.getObject("t4alias", objClass));
}

void testArrObject(ResultSet rs, byte[] exp, int index) throws SQLException {
assertArrayEquals(exp, (byte[]) rs.getObject(index, (Class<?>) byte[].class));
assertArrayEquals(exp, (byte[]) rs.getObject("t" + index + "alias", (Class<?>) byte[].class));
}

void testArrObject(ResultSet rs, byte[] exp) throws SQLException {
assertArrayEquals(exp, (byte[]) rs.getObject(1, (Class<?>) byte[].class));
assertArrayEquals(exp, (byte[]) rs.getObject("t1alias", (Class<?>) byte[].class));
testArrObject(rs, exp, 1);
assertNull(rs.getObject(4, (Class<?>) byte[].class));
assertNull(rs.getObject("t4alias", (Class<?>) byte[].class));
}
Expand Down
39 changes: 29 additions & 10 deletions src/test/java/org/mariadb/jdbc/unit/util/ClientParserTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,46 @@ public class ClientParserTest {

private void parse(String sql, String[] expected, String[] expectedNoBackSlash) {
ClientParser parser = ClientParser.parameterParts(sql, false);
assertEquals(expected.length, parser.getQueryParts().size(), displayErr(parser, expected));
for (int i = 0; i < parser.getQueryParts().size(); i++) {
byte[] b = parser.getQueryParts().get(i);
assertEquals(expected[i], new String(b, StandardCharsets.UTF_8));
assertEquals(expected.length, parser.getParamCount() + 1, displayErr(parser, expected));

int pos = 0;
int paramPos = parser.getQuery().length;
for (int i = 0; i < parser.getParamCount(); i++) {
paramPos = parser.getParamPositions().get(i);
assertEquals(expected[i], new String(parser.getQuery(), pos, paramPos - pos));
pos = paramPos + 1;
}
assertEquals(expected[expected.length - 1], new String(parser.getQuery(), pos, paramPos - pos));

parser = ClientParser.parameterParts(sql, true);
assertEquals(
expectedNoBackSlash.length, parser.getQueryParts().size(), displayErr(parser, expected));
for (int i = 0; i < parser.getQueryParts().size(); i++) {
byte[] b = parser.getQueryParts().get(i);
assertEquals(expectedNoBackSlash[i], new String(b, StandardCharsets.UTF_8));
expectedNoBackSlash.length, parser.getParamCount() + 1, displayErr(parser, expected));
pos = 0;
paramPos = parser.getQuery().length;
for (int i = 0; i < parser.getParamCount(); i++) {
paramPos = parser.getParamPositions().get(i);
assertEquals(expectedNoBackSlash[i], new String(parser.getQuery(), pos, paramPos - pos));
pos = paramPos + 1;
}
assertEquals(
expectedNoBackSlash[expectedNoBackSlash.length - 1],
new String(parser.getQuery(), pos, paramPos - pos));
}

private String displayErr(ClientParser parser, String[] exp) {
StringBuilder sb = new StringBuilder();
sb.append("is:\n");
for (byte[] b : parser.getQueryParts()) {
sb.append(new String(b, StandardCharsets.UTF_8)).append("\n");

int pos = 0;
int paramPos = parser.getQuery().length;
for (int i = 0; i < parser.getParamCount(); i++) {
paramPos = parser.getParamPositions().get(i);
sb.append(new String(parser.getQuery(), pos, paramPos - pos, StandardCharsets.UTF_8))
.append("\n");
pos = paramPos + 1;
}
sb.append(new String(parser.getQuery(), pos, paramPos - pos));

sb.append("but was:\n");
for (String s : exp) {
sb.append(s).append("\n");
Expand Down

0 comments on commit 0e4dbec

Please sign in to comment.