Skip to content

Commit

Permalink
feat: parse query parameters in PostgreSQL query (#1732)
Browse files Browse the repository at this point in the history
* fix: PostgreSQL supports newline in quoted literals and identifiers

PostgreSQL supports newline characters in string literals and quoted
identifiers. Trying to execute a statement with a string literal or
quoted identifier that contained a newline character would cause an
'Unclosed string literal' error.

Fixes #1730

* feat: parse query parameters in PostgreSQL query

Adds a helper method to get the parameters from a PostgreSQL query. This
is needed for DESCRIBE statement messages in PGAdapter, as it must
return the data types of all query parameters in a query string. Even
though this parser is not able to determine the parameter types, it is
able to determine the number of parameters. This again makes it possible
to PGAdapter to return Oid.UNSPECIFIED for each parameter in the query
string, which is enough for most clients.
  • Loading branch information
olavloite committed Mar 8, 2022
1 parent f403d99 commit 7357ac6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 27 deletions.
Expand Up @@ -21,6 +21,9 @@
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.common.base.Preconditions;
import java.util.HashSet;
import java.util.Set;
import javax.annotation.Nullable;

@InternalApi
public class PostgreSQLStatementParser extends AbstractStatementParser {
Expand Down Expand Up @@ -149,30 +152,76 @@ ParametersInfo convertPositionalParametersToNamedParametersInternal(char paramCh
return new ParametersInfo(paramIndex - 1, named.toString());
}

private int skip(String sql, int currentIndex, StringBuilder result) {
/**
* Note: This is an internal API and breaking changes can be made without prior notice.
*
* <p>Returns the PostgreSQL-style query parameters ($1, $2, ...) in the given SQL string. The
* SQL-string is assumed to not contain any comments. Use {@link #removeCommentsAndTrim(String)}
* to remove all comments before calling this method. Occurrences of query-parameter like strings
* inside quoted identifiers or string literals are ignored.
*
* <p>The following example will return a set containing ("$1", "$2"). <code>
* select col1, col2, "col$4"
* from some_table
* where col1=$1 and col2=$2
* and not col3=$1 and col4='$3'
* </code>
*
* @param sql the SQL-string to check for parameters. Must not contain comments.
* @return A set containing all the parameters in the SQL-string.
*/
@InternalApi
public Set<String> getQueryParameters(String sql) {
Preconditions.checkNotNull(sql);
int maxCount = countOccurrencesOf('$', sql);
Set<String> parameters = new HashSet<>(maxCount);
int currentIndex = 0;
while (currentIndex < sql.length() - 1) {
char c = sql.charAt(currentIndex);
if (c == '$' && Character.isDigit(sql.charAt(currentIndex + 1))) {
// Look ahead for the first non-digit. That is the end of the query parameter.
int endIndex = currentIndex + 2;
while (endIndex < sql.length() && Character.isDigit(sql.charAt(endIndex))) {
endIndex++;
}
parameters.add(sql.substring(currentIndex, endIndex));
currentIndex = endIndex;
} else {
currentIndex = skip(sql, currentIndex, null);
}
}
return parameters;
}

private int skip(String sql, int currentIndex, @Nullable StringBuilder result) {
char currentChar = sql.charAt(currentIndex);
if (currentChar == SINGLE_QUOTE || currentChar == DOUBLE_QUOTE) {
result.append(currentChar);
appendIfNotNull(result, currentChar);
return skipQuoted(sql, currentIndex, currentChar, result);
} else if (currentChar == DOLLAR) {
String dollarTag = parseDollarQuotedString(sql, currentIndex + 1);
if (dollarTag != null) {
result.append(currentChar).append(dollarTag).append(currentChar);
appendIfNotNull(result, currentChar, dollarTag, currentChar);
return skipQuoted(
sql, currentIndex + dollarTag.length() + 1, currentChar, dollarTag, result);
}
}

result.append(currentChar);
appendIfNotNull(result, currentChar);
return currentIndex + 1;
}

private int skipQuoted(String sql, int startIndex, char startQuote, StringBuilder result) {
private int skipQuoted(
String sql, int startIndex, char startQuote, @Nullable StringBuilder result) {
return skipQuoted(sql, startIndex, startQuote, null, result);
}

private int skipQuoted(
String sql, int startIndex, char startQuote, String dollarTag, StringBuilder result) {
String sql,
int startIndex,
char startQuote,
String dollarTag,
@Nullable StringBuilder result) {
boolean lastCharWasEscapeChar = false;
int currentIndex = startIndex + 1;
while (currentIndex < sql.length()) {
Expand All @@ -182,29 +231,41 @@ private int skipQuoted(
// Check if this is the end of the current dollar quoted string.
String tag = parseDollarQuotedString(sql, currentIndex + 1);
if (tag != null && tag.equals(dollarTag)) {
result.append(currentChar).append(tag).append(currentChar);
appendIfNotNull(result, currentChar, dollarTag, currentChar);
return currentIndex + tag.length() + 2;
}
} else if (lastCharWasEscapeChar) {
lastCharWasEscapeChar = false;
} else if (sql.length() > currentIndex + 1 && sql.charAt(currentIndex + 1) == startQuote) {
// This is an escaped quote (e.g. 'foo''bar')
result.append(currentChar).append(currentChar);
appendIfNotNull(result, currentChar);
appendIfNotNull(result, currentChar);
currentIndex += 2;
continue;
} else {
result.append(currentChar);
appendIfNotNull(result, currentChar);
return currentIndex + 1;
}
} else if (currentChar == '\\') {
lastCharWasEscapeChar = true;
} else {
lastCharWasEscapeChar = false;
lastCharWasEscapeChar = currentChar == '\\';
}
currentIndex++;
result.append(currentChar);
appendIfNotNull(result, currentChar);
}
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT, "SQL statement contains an unclosed literal: " + sql);
}

private void appendIfNotNull(@Nullable StringBuilder result, char currentChar) {
if (result != null) {
result.append(currentChar);
}
}

private void appendIfNotNull(
@Nullable StringBuilder result, char prefix, String tag, char suffix) {
if (result != null) {
result.append(prefix).append(tag).append(suffix);
}
}
}
Expand Up @@ -24,6 +24,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
Expand All @@ -33,6 +34,7 @@
import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.truth.Truth;
import java.io.File;
import java.io.FileNotFoundException;
Expand All @@ -42,7 +44,6 @@
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -158,7 +159,7 @@ public void testRemoveComments() {

@Test
public void testGoogleStandardSQLRemoveCommentsGsql() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

assertThat(parser.removeCommentsAndTrim("/*GSQL*/")).isEqualTo("");
assertThat(parser.removeCommentsAndTrim("/*GSQL*/SELECT * FROM FOO"))
Expand All @@ -183,7 +184,7 @@ public void testGoogleStandardSQLRemoveCommentsGsql() {

@Test
public void testPostgreSQLDialectRemoveCommentsGsql() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

assertThat(parser.removeCommentsAndTrim("/*GSQL*/")).isEqualTo("/*GSQL*/");
assertThat(parser.removeCommentsAndTrim("/*GSQL*/SELECT * FROM FOO"))
Expand Down Expand Up @@ -273,7 +274,7 @@ public void testStatementWithCommentContainingSlashAndNoAsteriskOnNewLine() {

@Test
public void testPostgresSQLDialectDollarQuoted() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

assertThat(parser.removeCommentsAndTrim("$$foo$$")).isEqualTo("$$foo$$");
assertThat(parser.removeCommentsAndTrim("$$--foo$$")).isEqualTo("$$--foo$$");
Expand All @@ -296,7 +297,7 @@ public void testPostgresSQLDialectDollarQuoted() {

@Test
public void testPostgreSQLDialectSupportsEmbeddedComments() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

final String sql =
"/* This is a comment /* This is an embedded comment */ This is after the embedded comment */ SELECT 1";
Expand All @@ -305,7 +306,7 @@ public void testPostgreSQLDialectSupportsEmbeddedComments() {

@Test
public void testGoogleStandardSQLDialectDoesNotSupportEmbeddedComments() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

final String sql =
"/* This is a comment /* This is an embedded comment */ This is after the embedded comment */ SELECT 1";
Expand All @@ -315,7 +316,7 @@ public void testGoogleStandardSQLDialectDoesNotSupportEmbeddedComments() {

@Test
public void testPostgreSQLDialectUnterminatedComment() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

final String sql =
"/* This is a comment /* This is still a comment */ this is unterminated SELECT 1";
Expand All @@ -334,7 +335,7 @@ public void testPostgreSQLDialectUnterminatedComment() {

@Test
public void testGoogleStandardSqlDialectDialectUnterminatedComment() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

final String sql =
"/* This is a comment /* This is still a comment */ this is unterminated SELECT 1";
Expand All @@ -360,7 +361,7 @@ public void testShowStatements() {

@Test
public void testGoogleStandardSQLDialectStatementWithHashTagSingleLineComment() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

// Supports # based comments
assertThat(
Expand All @@ -382,7 +383,7 @@ public void testGoogleStandardSQLDialectStatementWithHashTagSingleLineComment()

@Test
public void testPostgreSQLDialectStatementWithHashTagSingleLineComment() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

// Does not support # based comments
assertThat(
Expand Down Expand Up @@ -615,7 +616,7 @@ public void testIsQuery() {

@Test
public void testGoogleStandardSQLDialectIsQuery_QueryHints() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

// Supports query hints, PostgreSQL dialect does NOT
// Valid query hints.
Expand Down Expand Up @@ -663,7 +664,7 @@ public void testGoogleStandardSQLDialectIsQuery_QueryHints() {

@Test
public void testIsUpdate_QueryHints() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

// Supports query hints, PostgreSQL dialect does NOT
// Valid query hints.
Expand Down Expand Up @@ -1093,7 +1094,7 @@ public void testConvertPositionalParametersToNamedParametersWithGsqlException()

@Test
public void testGoogleStandardSQLDialectConvertPositionalParametersToNamedParameters() {
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);

assertThat(
parser.convertPositionalParametersToNamedParameters(
Expand Down Expand Up @@ -1203,7 +1204,7 @@ public void testGoogleStandardSQLDialectConvertPositionalParametersToNamedParame

@Test
public void testPostgreSQLDialectDialectConvertPositionalParametersToNamedParameters() {
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
assumeTrue(dialect == Dialect.POSTGRESQL);

assertThat(
parser.convertPositionalParametersToNamedParameters(
Expand Down Expand Up @@ -1318,6 +1319,25 @@ public void testPostgreSQLDialectDialectConvertPositionalParametersToNamedParame
+ "and col8 between $12 and $13")));
}

@Test
public void testPostgreSQLGetQueryParameters() {
assumeTrue(dialect == Dialect.POSTGRESQL);

PostgreSQLStatementParser parser = (PostgreSQLStatementParser) this.parser;
assertEquals(ImmutableSet.of(), parser.getQueryParameters("select * from foo"));
assertEquals(
ImmutableSet.of("$1"), parser.getQueryParameters("select * from foo where bar=$1"));
assertEquals(
ImmutableSet.of("$1", "$2", "$3"),
parser.getQueryParameters("select $2 from foo where bar=$1 and baz=$3"));
assertEquals(
ImmutableSet.of("$1", "$3"),
parser.getQueryParameters("select '$2' from foo where bar=$1 and baz in ($1, $3)"));
assertEquals(
ImmutableSet.of("$1"),
parser.getQueryParameters("select '$2' from foo where bar=$1 and baz=$foo"));
}

private void assertUnclosedLiteral(String sql) {
try {
parser.convertPositionalParametersToNamedParameters('?', sql);
Expand Down

0 comments on commit 7357ac6

Please sign in to comment.