Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: add support for passing query parameters in REST API calls #51029

Merged
merged 8 commits into from
Jan 20, 2020
37 changes: 36 additions & 1 deletion docs/reference/sql/endpoints/rest.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* <<sql-pagination>>
* <<sql-rest-filtering>>
* <<sql-rest-columnar>>
* <<sql-rest-params>>
* <<sql-rest-fields>>

[[sql-rest-overview]]
Expand Down Expand Up @@ -337,7 +338,7 @@ Which will like return the
[[sql-rest-filtering]]
=== Filtering using {es} query DSL

You can filter the results that SQL will run on using a standard
One can filter the results that SQL will run on using a standard
{es} query DSL by specifying the query in the filter
parameter.

Expand Down Expand Up @@ -442,6 +443,36 @@ Which looks like:
--------------------------------------------------
// TESTRESPONSE[s/46ToAwFzQERYRjFaWEo1UVc1a1JtVjBZMmdCQUFBQUFBQUFBQUVXWjBaNlFXbzNOV0pVY21Wa1NUZDJhV2t3V2xwblp3PT3\/\/\/\/\/DwQBZgZhdXRob3IBBHRleHQAAAFmBG5hbWUBBHRleHQAAAFmCnBhZ2VfY291bnQBBGxvbmcBAAFmDHJlbGVhc2VfZGF0ZQEIZGF0ZXRpbWUBAAEP/$body.cursor/]

[[sql-rest-params]]
=== Passing parameters to a query

Using values in a query condition, for example, or in a `HAVING` statement can be done "inline",
by integrating the value in the query string itself:

[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > 300 AND author = 'Frank Herbert' GROUP BY year HAVING COUNT(*) > 0"
}
--------------------------------------------------
// TEST[setup:library]

or it can be done by extracting the values in a separate list of parameters and using question mark placeholders (`?`) in the query string:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a statement to mention that this is the recommended way of passing parameters to avoid hacking or SQL injection (which does not apply to us but nevertheless).


[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > ? AND author = ? GROUP BY year HAVING COUNT(*) > ?",
"params": [300, "Frank Herbert", 0]
}
--------------------------------------------------
// TEST[setup:library]

[IMPORTANT]
The recommended way of passing values to a query is with question mark placeholders, to avoid any attempts of hacking or SQL injection.

[[sql-rest-fields]]
=== Supported REST parameters

Expand Down Expand Up @@ -495,6 +526,10 @@ More information available https://docs.oracle.com/javase/8/docs/api/java/time/Z
|false
|Whether to include <<frozen-indices, frozen-indices>> in the query execution or not (default).

|params
|none
|Optional list of parameters to replace question mark (`?`) placeholders inside the query.

|===

Do note that most parameters (outside the timeout and `columnar` ones) make sense only during the initial query - any follow-up pagination request only requires the `cursor` parameter as explained in the <<sql-pagination, pagination>> chapter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,11 @@ public void testBasicQueryWithParameters() throws IOException {
} else {
expected.put("rows", Arrays.asList(Arrays.asList("foo", 10)));
}

String params = mode.equals("jdbc") ? "{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}" :
"10, \"foo\"";
assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT test, ? param FROM test WHERE test = ?\", " +
"\"params\":[{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}]"
"\"params\":[" + params + "]"
+ mode(mode) + columnarParameter(columnar) + "}", ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentLocation;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.sql.proto.Mode;
Expand All @@ -22,6 +27,7 @@

import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -75,11 +81,11 @@ protected static <R extends AbstractSqlQueryRequest> ObjectParser<R, Void> objec
parser.declareString(AbstractSqlQueryRequest::query, QUERY);
parser.declareString((request, mode) -> request.mode(Mode.fromString(mode)), MODE);
parser.declareString((request, clientId) -> request.clientId(clientId), CLIENT_ID);
parser.declareObjectArray(AbstractSqlQueryRequest::params, (p, c) -> SqlTypedParamValue.fromXContent(p), PARAMS);
parser.declareField(AbstractSqlQueryRequest::params, AbstractSqlQueryRequest::parseParams, PARAMS, ValueType.VALUE_ARRAY);
parser.declareString((request, zoneId) -> request.zoneId(ZoneId.of(zoneId)), TIME_ZONE);
parser.declareInt(AbstractSqlQueryRequest::fetchSize, FETCH_SIZE);
parser.declareString((request, timeout) -> request.requestTimeout(TimeValue.parseTimeValue(timeout, Protocol.REQUEST_TIMEOUT,
"request_timeout")), REQUEST_TIMEOUT);
"request_timeout")), REQUEST_TIMEOUT);
Copy link
Contributor

@bpintea bpintea Jan 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was the extra space intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, to have a uniform formatting for the field parsers' declaration. See here the initial formatting.

parser.declareString(
(request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, "page_timeout")),
PAGE_TIMEOUT);
Expand Down Expand Up @@ -117,6 +123,87 @@ public AbstractSqlQueryRequest params(List<SqlTypedParamValue> params) {
this.params = params;
return this;
}

private static List<SqlTypedParamValue> parseParams(XContentParser p) throws IOException {
List<SqlTypedParamValue> result = new ArrayList<>();
Token token = p.currentToken();

if (token == Token.START_ARRAY) {
Object value = null;
String type = null;
SqlTypedParamValue previousParam = null;
SqlTypedParamValue currentParam = null;

while ((token = p.nextToken()) != Token.END_ARRAY) {
XContentLocation loc = p.getTokenLocation();

if (token == Token.START_OBJECT) {
// we are at the start of a value/type pair... hopefully
currentParam = SqlTypedParamValue.fromXContent(p);
/*
* Always set the xcontentlocation for the first param just in case the first one happens to not meet the parsing rules
* that are checked later in validateParams method.
* Also, set the xcontentlocation of the param that is different from the previous param in list when it comes to
* its type being explicitly set or inferred.
*/
if ((previousParam != null && previousParam.hasExplicitType() == false) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
} else {
if (token == Token.VALUE_STRING) {
value = p.text();
type = "keyword";
} else if (token == Token.VALUE_NUMBER) {
XContentParser.NumberType numberType = p.numberType();
if (numberType == XContentParser.NumberType.INT) {
value = p.intValue();
type = "integer";
} else if (numberType == XContentParser.NumberType.LONG) {
value = p.longValue();
type = "long";
} else if (numberType == XContentParser.NumberType.FLOAT) {
value = p.floatValue();
type = "float";
} else if (numberType == XContentParser.NumberType.DOUBLE) {
value = p.doubleValue();
type = "double";
}
} else if (token == Token.VALUE_BOOLEAN) {
value = p.booleanValue();
type = "boolean";
} else if (token == Token.VALUE_NULL) {
value = null;
type = "null";
} else {
throw new XContentParseException(loc, "Failed to parse object: unexpected token [" + token + "] found");
}

currentParam = new SqlTypedParamValue(type, value, false);
if ((previousParam != null && previousParam.hasExplicitType() == true) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
}

result.add(currentParam);
previousParam = currentParam;
}
}

return result;
}

protected static void validateParams(List<SqlTypedParamValue> params, Mode mode) {
for(SqlTypedParamValue param : params) {
if (Mode.isDriver(mode) && param.hasExplicitType() == false) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is an object with a "
+ "value/type pair");
}
if (Mode.isDriver(mode) == false && param.hasExplicitType() == true) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is a single field (no "
+ "objects supported)");
}
}
}

/**
* The client's time zone
Expand Down Expand Up @@ -204,10 +291,11 @@ public AbstractSqlQueryRequest(StreamInput in) throws IOException {
public static void writeSqlTypedParamValue(StreamOutput out, SqlTypedParamValue value) throws IOException {
out.writeString(value.type);
out.writeGenericValue(value.value);
out.writeBoolean(value.hasExplicitType());
}

public static SqlTypedParamValue readSqlTypedParamValue(StreamInput in) throws IOException {
return new SqlTypedParamValue(in.readString(), in.readGenericValue());
return new SqlTypedParamValue(in.readString(), in.readGenericValue(), in.readBoolean());
}

@Override
Expand Down Expand Up @@ -248,6 +336,6 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), query, zoneId, fetchSize, requestTimeout, pageTimeout, filter);
return Objects.hash(super.hashCode(), query, params, zoneId, fetchSize, requestTimeout, pageTimeout, filter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static SqlQueryRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
SqlQueryRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public String getDescription() {

public static SqlTranslateRequest fromXContent(XContentParser parser) {
SqlTranslateRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request;
}

Expand Down
Loading