Skip to content

Commit

Permalink
feat: add ProtoBuf as a content type for pull queries over /query-str…
Browse files Browse the repository at this point in the history
…eam endpoint (#9103)

* save changes

* save changes rqtt functional

* save changes

* Delete cp.json

* fix: mediaType test

* add back logic to rqtt

* more unit tests

* pass all tests

* spacing nits

* nit test

* remove headerProtobuf from StreamedRow?

* change stremedrow matchers

* address comments

* Update RestQueryTranslationTest.java

* refactor rqtt

* unit test schemas

* unit test

* more unit tests

* address comments

* make schema non-optional

* nit

* address comments on KsqlTargetUtil

* refactor before fixing rqtt

* fix rqtt

* nit immutable bytes

* nit run api test on all http versions

* nit tests
  • Loading branch information
cprasad1 committed May 18, 2022
1 parent 9cb13c7 commit e64e284
Show file tree
Hide file tree
Showing 20 changed files with 1,359 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class RestTestCase implements Test {
private final Optional<Matcher<RestResponse<?>>> expectedError;
private final Optional<InputConditions> inputConditions;
private final Optional<OutputConditions> outputConditions;
private final boolean testPullWithProtoFormat;

RestTestCase(
final TestLocation location,
Expand All @@ -57,7 +58,8 @@ class RestTestCase implements Test {
final Collection<Response> responses,
final Optional<Matcher<RestResponse<?>>> expectedError,
final Optional<InputConditions> inputConditions,
final Optional<OutputConditions> outputConditions
final Optional<OutputConditions> outputConditions,
final boolean testPullWithProtoFormat
) {
this.name = requireNonNull(name, "name");
this.location = requireNonNull(location, "testPath");
Expand All @@ -70,6 +72,7 @@ class RestTestCase implements Test {
this.expectedError = requireNonNull(expectedError, "expectedError");
this.inputConditions = requireNonNull(inputConditions, "inputConditions");
this.outputConditions = requireNonNull(outputConditions, "outputConditions");
this.testPullWithProtoFormat = testPullWithProtoFormat;
}

@Override
Expand Down Expand Up @@ -127,4 +130,8 @@ public Optional<InputConditions> getInputConditions() {
public Optional<OutputConditions> getOutputConditions() {
return outputConditions;
}

public boolean isTestPullWithProtoFormat() {
return testPullWithProtoFormat;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ private static RestTestCase createTest(
test.getResponses(),
ee,
test.getInputConditions(),
test.getOutputConditions()
test.getOutputConditions(),
test.isTestPullWithProtoFormat()
);
} catch (final Exception e) {
throw new AssertionError(testName + ": Invalid test. " + e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.confluent.ksql.test.model.RecordNode;
import io.confluent.ksql.test.model.TopicNode;
Expand Down Expand Up @@ -50,6 +48,7 @@ public class RestTestCaseNode {
private final Optional<InputConditions> inputConditions;
private final Optional<OutputConditions> outputConditions;
private final boolean enabled;
private final boolean testPullWithProtoFormat;

public RestTestCaseNode(
@JsonProperty("name") final String name,
Expand All @@ -63,7 +62,8 @@ public RestTestCaseNode(
@JsonProperty("responses") final List<Response> responses,
@JsonProperty("inputConditions") final InputConditions inputConditions,
@JsonProperty("outputConditions") final OutputConditions outputConditions,
@JsonProperty("enabled") final Boolean enabled
@JsonProperty("enabled") final Boolean enabled,
@JsonProperty("testPullWithProtoFormat") final Boolean testPullWithProtoFormat
) {
this.name = name == null ? "" : name;
this.formats = immutableCopyOf(formats);
Expand All @@ -77,10 +77,15 @@ public RestTestCaseNode(
this.inputConditions = Optional.ofNullable(inputConditions);
this.outputConditions = Optional.ofNullable(outputConditions);
this.enabled = !Boolean.FALSE.equals(enabled);
this.testPullWithProtoFormat = Boolean.TRUE.equals(testPullWithProtoFormat);

validate();
}

public boolean isTestPullWithProtoFormat() {
return testPullWithProtoFormat;
}

public boolean isEnabled() {
return enabled;
}
Expand Down Expand Up @@ -149,5 +154,22 @@ private void validate() {
throw new InvalidFieldException("inputs and expectedError",
"can not both be set");
}

if (isTestPullWithProtoFormat()) {
final int numQueryResponses = (int) getResponses()
.stream()
.filter(response -> response.getContent().containsKey("query"))
.count();
final int numQueryProtoResponses = (int) getResponses()
.stream()
.filter(response -> response.getContent().containsKey("queryProto"))
.count();

if (numQueryResponses != numQueryProtoResponses) {
throw new InvalidFieldException("responses",
"Number of query responses must be equal to number of queryProto responses " +
"when `testPullWithProtoFormat` flag is set to `True`");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.function.TestFunctionRegistry;
import io.confluent.ksql.rest.client.KsqlRestClient;
Expand All @@ -40,6 +42,7 @@
import io.confluent.ksql.rest.entity.KsqlStatementErrorMessage;
import io.confluent.ksql.rest.entity.StreamedRow;
import io.confluent.ksql.rest.integration.QueryStreamSubscriber;
import io.confluent.ksql.serde.protobuf.ProtobufNoSRConverter;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.test.rest.model.Response;
import io.confluent.ksql.test.tools.ExpectedRecordComparator;
Expand All @@ -59,6 +62,7 @@
import io.confluent.ksql.util.RetryUtil;
import io.confluent.ksql.util.TransientQueryMetadata;
import java.io.Closeable;
import java.io.IOException;
import java.math.BigDecimal;
import java.net.URL;
import java.time.Duration;
Expand Down Expand Up @@ -90,6 +94,7 @@
import org.apache.kafka.streams.TopologyDescription.Subtopology;
import org.hamcrest.Matcher;
import org.hamcrest.StringDescription;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -135,28 +140,29 @@ public class RestTestExecutor implements Closeable {
void buildAndExecuteQuery(final RestTestCase testCase) {
topicInfoCache.clear();

if (testCase.getStatements().size() < testCase.getExpectedResponses().size()) {
final StatementSplit statements = splitStatements(testCase);
final int expectedResponseSize = (int) testCase.getExpectedResponses()
.stream()
.filter(resp -> !resp.getContent().containsKey("queryProto"))
.count();

if (testCase.getStatements().size() < expectedResponseSize) {
throw new AssertionError("Invalid test case: more expected responses than statements. "
+ System.lineSeparator()
+ "statementCount: " + testCase.getStatements().size()
+ System.lineSeparator()
+ "responsesCount: " + testCase.getExpectedResponses().size());
+ "responsesCount: " + expectedResponseSize);
}

initializeTopics(testCase);

final StatementSplit statements = splitStatements(testCase);

testCase.getProperties().forEach(restClient::setProperty);

try {
final Optional<List<RqttResponse>> adminResults =
sendAdminStatements(testCase, statements.admin);

if (!adminResults.isPresent()) {
return;
}

final boolean waitForActivePushQueryToProduceInput = testCase.getInputConditions().isPresent()
&& testCase.getInputConditions().get().getWaitForActivePushQuery();
final Optional<InputConditionsParameters> postInputConditionRunnable;
Expand All @@ -175,13 +181,26 @@ void buildAndExecuteQuery(final RestTestCase testCase) {

final List<RqttResponse> queryResults = sendQueryStatements(testCase, statements.queries,
postInputConditionRunnable);

if (!queryResults.isEmpty()) {
failIfExpectingError(testCase);
}

List<RqttResponse> protoResponses = ImmutableList.of();

if (testCase.isTestPullWithProtoFormat()) {
protoResponses = statements.queries.stream()
.map(this::sendQueryStreamProtoStatement)
.filter(Optional::isPresent)
.map(Optional::get)
.map(RqttResponse::queryProto)
.collect(Collectors.toList());
}

final List<RqttResponse> responses = ImmutableList.<RqttResponse>builder()
.addAll(adminResults.get())
.addAll(queryResults)
.addAll(protoResponses)
.build();

verifyOutput(testCase);
Expand All @@ -192,7 +211,6 @@ void buildAndExecuteQuery(final RestTestCase testCase) {
// Give a few seconds for the transient queries to complete, otherwise, we'll go into teardown
// and leave the queries stuck.
waitForTransientQueriesToComplete();

} finally {
testCase.getProperties().keySet().forEach(restClient::unsetProperty);
}
Expand Down Expand Up @@ -374,6 +392,18 @@ private Optional<List<StreamedRow>> sendQueryStatement(
return Optional.of(resp.getResponse());
}

private Optional<List<StreamedRow>> sendQueryStreamProtoStatement(
final String sql
) {
final RestResponse<List<StreamedRow>> resp = restClient.makeQueryStreamRequestProto(sql, ImmutableMap.of());

if (resp.isErroneous()) {
return Optional.empty();
}

return Optional.of(resp.getResponse());
}

private Optional<List<StreamedRow>> sendQueryStatement(
final RestTestCase testCase,
final String sql,
Expand Down Expand Up @@ -881,6 +911,9 @@ static RqttResponse query(final List<StreamedRow> rows) {
return new RqttQueryResponse(rows);
}

static RqttResponse queryProto(final List<StreamedRow> rows) {
return new RqttQueryProtoResponse(rows);
}
void verify(
String expectedType,
Object expectedPayload,
Expand Down Expand Up @@ -993,6 +1026,100 @@ public void verify(
}
}

@VisibleForTesting
static class RqttQueryProtoResponse implements RqttResponse {

private static final TypeReference<Map<String, Object>> PAYLOAD_TYPE =
new TypeReference<Map<String, Object>>() {
};

private static final String INDENT = System.lineSeparator() + "\t";
private static final String HEADER_PROTOBUF = "header";
private static final String SCHEMA = "protoSchema";

private final List<StreamedRow> rows;

RqttQueryProtoResponse(final List<StreamedRow> rows) {
this.rows = requireNonNull(rows, "rows");
}

@SuppressWarnings("unchecked")
@Override
public void verify(
final String expectedType,
final Object expectedPayload,
final List<String> statements,
final int idx,
final boolean verifyOrder
) {
assertThat("Expected query response", expectedType, is("queryProto"));
assertThat("Query response should be an array", expectedPayload, is(instanceOf(List.class)));

final List<?> expectedRows = (List<?>) expectedPayload;

assertThat(
"row count mismatch."
+ System.lineSeparator()
+ "Expected: "
+ expectedRows.stream()
.map(Object::toString)
.collect(Collectors.joining(INDENT, INDENT, ""))
+ System.lineSeparator()
+ "Got: "
+ rows.stream()
.map(Object::toString)
.collect(Collectors.joining(INDENT, INDENT, ""))
+ System.lineSeparator(),
rows,
hasSize(expectedRows.size())
);

ProtobufSchema schema = null;
ProtobufNoSRConverter.Deserializer deserializer = new ProtobufNoSRConverter.Deserializer();
final JsonMapper mapper = new JsonMapper();
for (int i = 0; i != rows.size(); ++i) {
assertThat(
"Each row should be JSON object",
expectedRows.get(i),
is(instanceOf(Map.class))
);

final Map<String, Object> actual = asJson(rows.get(i), PAYLOAD_TYPE);
final Map<String, Object> expected = (Map<String, Object>) expectedRows.get(i);

if (actual.containsKey(HEADER_PROTOBUF)
&& ((HashMap<String, Object>) actual.get(HEADER_PROTOBUF)).containsKey(SCHEMA)) {

assertThat(i, is(0));

schema = new ProtobufSchema((String) ((Map<?, ?>)actual.get(HEADER_PROTOBUF)).get(SCHEMA));
final String actualSchema = (String) ((Map<?, ?>)actual.get(HEADER_PROTOBUF)).get(SCHEMA);
final String expectedSchema = (String) ((Map<?, ?>)expected.get(HEADER_PROTOBUF)).get(SCHEMA);

assertThat(actualSchema, is(expectedSchema));
} else if (actual.containsKey("finalMessage")) {
assertThat(actual, is(expected));
} else {
JSONObject row = new JSONObject(actual);

byte[] bytes;
try {
bytes = mapper.readTree(row.toString()).get("row").get("protobufBytes").binaryValue();
} catch (IOException e) {
throw new RuntimeException("Failed to deserialize the ProtoBuf bytes from the " +
"RQTT JSON response row: " + row);
}

final Object message = deserializer.deserialize(bytes, schema);
final String actualMessage = message.toString();
final String expectedMessage = expected.get("row").toString();

assertThat(actualMessage, is(expectedMessage));
}
}
}
}

private static final class StatementSplit {

final List<String> admin;
Expand Down
Loading

0 comments on commit e64e284

Please sign in to comment.