Skip to content

Commit

Permalink
SQL: add support for passing query parameters in REST API calls (#51029)
Browse files Browse the repository at this point in the history
* REST PreparedStatement-like query parameters are now supported in the form of an array of non-object, non-array values where ES SQL parser will try to infer the data type of the value being passed as parameter.
  • Loading branch information
astefan committed Jan 20, 2020
1 parent 38eb485 commit 45b8bf6
Show file tree
Hide file tree
Showing 8 changed files with 427 additions and 51 deletions.
37 changes: 36 additions & 1 deletion docs/reference/sql/endpoints/rest.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* <<sql-pagination>>
* <<sql-rest-filtering>>
* <<sql-rest-columnar>>
* <<sql-rest-params>>
* <<sql-rest-fields>>

[[sql-rest-overview]]
Expand Down Expand Up @@ -337,7 +338,7 @@ Which will like return the
[[sql-rest-filtering]]
=== Filtering using {es} query DSL

You can filter the results that SQL will run on using a standard
One can filter the results that SQL will run on using a standard
{es} query DSL by specifying the query in the filter
parameter.

Expand Down Expand Up @@ -442,6 +443,36 @@ Which looks like:
--------------------------------------------------
// TESTRESPONSE[s/46ToAwFzQERYRjFaWEo1UVc1a1JtVjBZMmdCQUFBQUFBQUFBQUVXWjBaNlFXbzNOV0pVY21Wa1NUZDJhV2t3V2xwblp3PT3\/\/\/\/\/DwQBZgZhdXRob3IBBHRleHQAAAFmBG5hbWUBBHRleHQAAAFmCnBhZ2VfY291bnQBBGxvbmcBAAFmDHJlbGVhc2VfZGF0ZQEIZGF0ZXRpbWUBAAEP/$body.cursor/]

[[sql-rest-params]]
=== Passing parameters to a query

Using values in a query condition, for example, or in a `HAVING` statement can be done "inline",
by integrating the value in the query string itself:

[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > 300 AND author = 'Frank Herbert' GROUP BY year HAVING COUNT(*) > 0"
}
--------------------------------------------------
// TEST[setup:library]

or it can be done by extracting the values in a separate list of parameters and using question mark placeholders (`?`) in the query string:

[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > ? AND author = ? GROUP BY year HAVING COUNT(*) > ?",
"params": [300, "Frank Herbert", 0]
}
--------------------------------------------------
// TEST[setup:library]

[IMPORTANT]
The recommended way of passing values to a query is with question mark placeholders, to avoid any attempts of hacking or SQL injection.

[[sql-rest-fields]]
=== Supported REST parameters

Expand Down Expand Up @@ -495,6 +526,10 @@ More information available https://docs.oracle.com/javase/8/docs/api/java/time/Z
|false
|Whether to include <<frozen-indices, frozen-indices>> in the query execution or not (default).

|params
|none
|Optional list of parameters to replace question mark (`?`) placeholders inside the query.

|===

Do note that most parameters (outside the timeout and `columnar` ones) make sense only during the initial query - any follow-up pagination request only requires the `cursor` parameter as explained in the <<sql-pagination, pagination>> chapter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,11 @@ public void testBasicQueryWithParameters() throws IOException {
} else {
expected.put("rows", Arrays.asList(Arrays.asList("foo", 10)));
}

String params = mode.equals("jdbc") ? "{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}" :
"10, \"foo\"";
assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT test, ? param FROM test WHERE test = ?\", " +
"\"params\":[{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}]"
"\"params\":[" + params + "]"
+ mode(mode) + columnarParameter(columnar) + "}", ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentLocation;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.sql.proto.Mode;
Expand All @@ -22,6 +27,7 @@

import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -75,11 +81,11 @@ protected static <R extends AbstractSqlQueryRequest> ObjectParser<R, Void> objec
parser.declareString(AbstractSqlQueryRequest::query, QUERY);
parser.declareString((request, mode) -> request.mode(Mode.fromString(mode)), MODE);
parser.declareString((request, clientId) -> request.clientId(clientId), CLIENT_ID);
parser.declareObjectArray(AbstractSqlQueryRequest::params, (p, c) -> SqlTypedParamValue.fromXContent(p), PARAMS);
parser.declareField(AbstractSqlQueryRequest::params, AbstractSqlQueryRequest::parseParams, PARAMS, ValueType.VALUE_ARRAY);
parser.declareString((request, zoneId) -> request.zoneId(ZoneId.of(zoneId)), TIME_ZONE);
parser.declareInt(AbstractSqlQueryRequest::fetchSize, FETCH_SIZE);
parser.declareString((request, timeout) -> request.requestTimeout(TimeValue.parseTimeValue(timeout, Protocol.REQUEST_TIMEOUT,
"request_timeout")), REQUEST_TIMEOUT);
"request_timeout")), REQUEST_TIMEOUT);
parser.declareString(
(request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, "page_timeout")),
PAGE_TIMEOUT);
Expand Down Expand Up @@ -117,6 +123,87 @@ public AbstractSqlQueryRequest params(List<SqlTypedParamValue> params) {
this.params = params;
return this;
}

private static List<SqlTypedParamValue> parseParams(XContentParser p) throws IOException {
List<SqlTypedParamValue> result = new ArrayList<>();
Token token = p.currentToken();

if (token == Token.START_ARRAY) {
Object value = null;
String type = null;
SqlTypedParamValue previousParam = null;
SqlTypedParamValue currentParam = null;

while ((token = p.nextToken()) != Token.END_ARRAY) {
XContentLocation loc = p.getTokenLocation();

if (token == Token.START_OBJECT) {
// we are at the start of a value/type pair... hopefully
currentParam = SqlTypedParamValue.fromXContent(p);
/*
* Always set the xcontentlocation for the first param just in case the first one happens to not meet the parsing rules
* that are checked later in validateParams method.
* Also, set the xcontentlocation of the param that is different from the previous param in list when it comes to
* its type being explicitly set or inferred.
*/
if ((previousParam != null && previousParam.hasExplicitType() == false) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
} else {
if (token == Token.VALUE_STRING) {
value = p.text();
type = "keyword";
} else if (token == Token.VALUE_NUMBER) {
XContentParser.NumberType numberType = p.numberType();
if (numberType == XContentParser.NumberType.INT) {
value = p.intValue();
type = "integer";
} else if (numberType == XContentParser.NumberType.LONG) {
value = p.longValue();
type = "long";
} else if (numberType == XContentParser.NumberType.FLOAT) {
value = p.floatValue();
type = "float";
} else if (numberType == XContentParser.NumberType.DOUBLE) {
value = p.doubleValue();
type = "double";
}
} else if (token == Token.VALUE_BOOLEAN) {
value = p.booleanValue();
type = "boolean";
} else if (token == Token.VALUE_NULL) {
value = null;
type = "null";
} else {
throw new XContentParseException(loc, "Failed to parse object: unexpected token [" + token + "] found");
}

currentParam = new SqlTypedParamValue(type, value, false);
if ((previousParam != null && previousParam.hasExplicitType() == true) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
}

result.add(currentParam);
previousParam = currentParam;
}
}

return result;
}

protected static void validateParams(List<SqlTypedParamValue> params, Mode mode) {
for(SqlTypedParamValue param : params) {
if (Mode.isDriver(mode) && param.hasExplicitType() == false) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is an object with a "
+ "value/type pair");
}
if (Mode.isDriver(mode) == false && param.hasExplicitType() == true) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is a single field (no "
+ "objects supported)");
}
}
}

/**
* The client's time zone
Expand Down Expand Up @@ -204,10 +291,11 @@ public AbstractSqlQueryRequest(StreamInput in) throws IOException {
public static void writeSqlTypedParamValue(StreamOutput out, SqlTypedParamValue value) throws IOException {
out.writeString(value.type);
out.writeGenericValue(value.value);
out.writeBoolean(value.hasExplicitType());
}

public static SqlTypedParamValue readSqlTypedParamValue(StreamInput in) throws IOException {
return new SqlTypedParamValue(in.readString(), in.readGenericValue());
return new SqlTypedParamValue(in.readString(), in.readGenericValue(), in.readBoolean());
}

@Override
Expand Down Expand Up @@ -248,6 +336,6 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), query, zoneId, fetchSize, requestTimeout, pageTimeout, filter);
return Objects.hash(super.hashCode(), query, params, zoneId, fetchSize, requestTimeout, pageTimeout, filter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static SqlQueryRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
SqlQueryRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public String getDescription() {

public static SqlTranslateRequest fromXContent(XContentParser parser) {
SqlTranslateRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request;
}

Expand Down
Loading

0 comments on commit 45b8bf6

Please sign in to comment.