diff --git a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCase.java b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCase.java
index 87f795b890c..39094c7a763 100644
--- a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCase.java
+++ b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCase.java
@@ -45,6 +45,7 @@ class RestTestCase implements Test {
   private final Optional<Matcher<RestResponse<?>>> expectedError;
   private final Optional<InputConditions> inputConditions;
   private final Optional<OutputConditions> outputConditions;
+  private final boolean testPullWithProtoFormat;
 
   RestTestCase(
       final TestLocation location,
@@ -57,7 +58,8 @@ class RestTestCase implements Test {
       final Collection<Response> responses,
       final Optional<Matcher<RestResponse<?>>> expectedError,
       final Optional<InputConditions> inputConditions,
-      final Optional<OutputConditions> outputConditions
+      final Optional<OutputConditions> outputConditions,
+      final boolean testPullWithProtoFormat
   ) {
     this.name = requireNonNull(name, "name");
     this.location = requireNonNull(location, "testPath");
@@ -70,6 +72,7 @@ class RestTestCase implements Test {
     this.expectedError = requireNonNull(expectedError, "expectedError");
     this.inputConditions = requireNonNull(inputConditions, "inputConditions");
     this.outputConditions = requireNonNull(outputConditions, "outputConditions");
+    this.testPullWithProtoFormat = testPullWithProtoFormat;
  }
 
   @Override
@@ -127,4 +130,8 @@ public Optional<InputConditions> getInputConditions() {
   public Optional<OutputConditions> getOutputConditions() {
     return outputConditions;
   }
+
+  public boolean isTestPullWithProtoFormat() {
+    return testPullWithProtoFormat;
+  }
 }
diff --git a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseBuilder.java b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseBuilder.java
index 2a222771e40..c2f16d8e663 100644
--- a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseBuilder.java
+++ b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseBuilder.java
@@ -105,7 +105,8 @@ private static RestTestCase createTest(
           test.getResponses(),
           ee,
           test.getInputConditions(),
-          test.getOutputConditions()
+          test.getOutputConditions(),
+          test.isTestPullWithProtoFormat()
       );
     } catch (final Exception e) {
       throw new AssertionError(testName + ": Invalid test. " + e.getMessage(), e);
diff --git a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseNode.java b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseNode.java
index 0c1a5ab7b9c..5488a7df43a 100644
--- a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseNode.java
+++ b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestCaseNode.java
@@ -19,8 +19,6 @@
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import io.confluent.ksql.test.model.RecordNode;
 import io.confluent.ksql.test.model.TopicNode;
@@ -50,6 +48,7 @@ public class RestTestCaseNode {
   private final Optional<InputConditions> inputConditions;
   private final Optional<OutputConditions> outputConditions;
   private final boolean enabled;
+  private final boolean testPullWithProtoFormat;
 
   public RestTestCaseNode(
       @JsonProperty("name") final String name,
@@ -63,7 +62,8 @@ public RestTestCaseNode(
       @JsonProperty("responses") final List<Response> responses,
       @JsonProperty("inputConditions") final InputConditions inputConditions,
       @JsonProperty("outputConditions") final OutputConditions outputConditions,
-      @JsonProperty("enabled") final Boolean enabled
+      @JsonProperty("enabled") final Boolean enabled,
+      @JsonProperty("testPullWithProtoFormat") final Boolean testPullWithProtoFormat
   ) {
     this.name = name == null ? "" : name;
     this.formats = immutableCopyOf(formats);
@@ -77,10 +77,15 @@ public RestTestCaseNode(
     this.inputConditions = Optional.ofNullable(inputConditions);
     this.outputConditions = Optional.ofNullable(outputConditions);
     this.enabled = !Boolean.FALSE.equals(enabled);
+    this.testPullWithProtoFormat = Boolean.TRUE.equals(testPullWithProtoFormat);
 
     validate();
   }
 
+  public boolean isTestPullWithProtoFormat() {
+    return testPullWithProtoFormat;
+  }
+
   public boolean isEnabled() {
     return enabled;
   }
@@ -149,5 +154,22 @@ private void validate() {
       throw new InvalidFieldException("inputs and expectedError", "can not both be set");
     }
+
+    if (isTestPullWithProtoFormat()) {
+      final int numQueryResponses = (int) getResponses()
+          .stream()
+          .filter(response -> response.getContent().containsKey("query"))
+          .count();
+      final int numQueryProtoResponses = (int) getResponses()
+          .stream()
+          .filter(response -> response.getContent().containsKey("queryProto"))
+          .count();
+
+      if (numQueryResponses != numQueryProtoResponses) {
+        throw new InvalidFieldException("responses",
+            "Number of query responses must be equal to number of queryProto responses "
+                + "when the `testPullWithProtoFormat` flag is set to `true`");
+      }
+    }
   }
 }
diff --git a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java
index a3fe62ff991..430ba11e4a9 100644
--- a/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java
+++ b/ksqldb-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java
@@ -26,10 +26,12 @@
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.json.JsonMapper;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema;
 import io.confluent.ksql.KsqlExecutionContext;
 import io.confluent.ksql.function.TestFunctionRegistry;
 import io.confluent.ksql.rest.client.KsqlRestClient;
@@ -40,6 +42,7 @@
 import io.confluent.ksql.rest.entity.KsqlStatementErrorMessage;
 import io.confluent.ksql.rest.entity.StreamedRow;
 import io.confluent.ksql.rest.integration.QueryStreamSubscriber;
+import io.confluent.ksql.serde.protobuf.ProtobufNoSRConverter;
 import io.confluent.ksql.services.ServiceContext;
 import io.confluent.ksql.test.rest.model.Response;
 import io.confluent.ksql.test.tools.ExpectedRecordComparator;
@@ -59,6 +62,7 @@
 import io.confluent.ksql.util.RetryUtil;
 import io.confluent.ksql.util.TransientQueryMetadata;
 import java.io.Closeable;
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.net.URL;
 import java.time.Duration;
@@ -90,6 +94,7 @@
 import org.apache.kafka.streams.TopologyDescription.Subtopology;
 import org.hamcrest.Matcher;
 import org.hamcrest.StringDescription;
+import org.json.JSONObject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -135,28 +140,29 @@ public class RestTestExecutor implements Closeable {
   void buildAndExecuteQuery(final RestTestCase testCase) {
     topicInfoCache.clear();
 
-    if (testCase.getStatements().size() < testCase.getExpectedResponses().size()) {
+    final StatementSplit statements = splitStatements(testCase);
+    final int expectedResponseSize = (int) testCase.getExpectedResponses()
+        .stream()
+        .filter(resp -> !resp.getContent().containsKey("queryProto"))
+        .count();
+
+    if (testCase.getStatements().size() < expectedResponseSize) {
       throw new AssertionError("Invalid test case: more expected responses than statements. "
           + System.lineSeparator()
           + "statementCount: " + testCase.getStatements().size()
           + System.lineSeparator()
-          + "responsesCount: " + testCase.getExpectedResponses().size());
+          + "responsesCount: " + expectedResponseSize);
     }
 
     initializeTopics(testCase);
-
-    final StatementSplit statements = splitStatements(testCase);
-
     testCase.getProperties().forEach(restClient::setProperty);
 
     try {
       final Optional<List<RqttResponse>> adminResults =
           sendAdminStatements(testCase, statements.admin);
-
       if (!adminResults.isPresent()) {
         return;
       }
-
       final boolean waitForActivePushQueryToProduceInput = testCase.getInputConditions().isPresent()
           && testCase.getInputConditions().get().getWaitForActivePushQuery();
       final Optional<Runnable> postInputConditionRunnable;
@@ -175,13 +181,26 @@ void buildAndExecuteQuery(final RestTestCase testCase) {
 
       final List<RqttResponse> queryResults =
           sendQueryStatements(testCase, statements.queries, postInputConditionRunnable);
+
       if (!queryResults.isEmpty()) {
         failIfExpectingError(testCase);
       }
 
+      List<RqttResponse> protoResponses = ImmutableList.of();
+
+      if (testCase.isTestPullWithProtoFormat()) {
+        protoResponses = statements.queries.stream()
+            .map(this::sendQueryStreamProtoStatement)
+            .filter(Optional::isPresent)
+            .map(Optional::get)
+            .map(RqttResponse::queryProto)
+            .collect(Collectors.toList());
+      }
+
       final List<RqttResponse> responses = ImmutableList.<RqttResponse>builder()
           .addAll(adminResults.get())
           .addAll(queryResults)
+          .addAll(protoResponses)
           .build();
 
       verifyOutput(testCase);
@@ -192,7 +211,6 @@ void buildAndExecuteQuery(final RestTestCase testCase) {
 
       // Give a few seconds for the transient queries to complete, otherwise, we'll go into teardown
       // and leave the queries stuck.
       waitForTransientQueriesToComplete();
-
     } finally {
       testCase.getProperties().keySet().forEach(restClient::unsetProperty);
     }
   }
@@ -374,6 +392,18 @@ private Optional<List<StreamedRow>> sendQueryStatement(
     return Optional.of(resp.getResponse());
   }
 
+  private Optional<List<StreamedRow>> sendQueryStreamProtoStatement(
+      final String sql
+  ) {
+    final RestResponse<List<StreamedRow>> resp = restClient.makeQueryStreamRequestProto(sql, ImmutableMap.of());
+
+    if (resp.isErroneous()) {
+      return Optional.empty();
+    }
+
+    return Optional.of(resp.getResponse());
+  }
+
   private Optional<List<StreamedRow>> sendQueryStatement(
       final RestTestCase testCase,
       final String sql,
@@ -881,6 +911,9 @@ static RqttResponse query(final List<StreamedRow> rows) {
       return new RqttQueryResponse(rows);
     }
 
+    static RqttResponse queryProto(final List<StreamedRow> rows) {
+      return new RqttQueryProtoResponse(rows);
+    }
+
     void verify(
         String expectedType,
         Object expectedPayload,
@@ -993,6 +1026,100 @@ public void verify(
     }
   }
 
+  @VisibleForTesting
+  static class RqttQueryProtoResponse implements RqttResponse {
+
+    private static final TypeReference<Map<String, Object>> PAYLOAD_TYPE =
+        new TypeReference<Map<String, Object>>() {
+        };
+
+    private static final String INDENT = System.lineSeparator() + "\t";
+    private static final String HEADER_PROTOBUF = "header";
+    private static final String SCHEMA = "protoSchema";
+
+    private final List<StreamedRow> rows;
+
+    RqttQueryProtoResponse(final List<StreamedRow> rows) {
+      this.rows = requireNonNull(rows, "rows");
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void verify(
+        final String expectedType,
+        final Object expectedPayload,
+        final List<String> statements,
+        final int idx,
+        final boolean verifyOrder
+    ) {
+      assertThat("Expected query response", expectedType, is("queryProto"));
+      assertThat("Query response should be an array", expectedPayload, is(instanceOf(List.class)));
+
+      final List<?> expectedRows = (List<?>) expectedPayload;
+
+      assertThat(
+          "row count mismatch."
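+              // List every expected and actual row in the failure message so count mismatches are easy to diagnose.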
+              + System.lineSeparator()
+              + "Expected: "
+              + expectedRows.stream()
+                  .map(Object::toString)
+                  .collect(Collectors.joining(INDENT, INDENT, ""))
+              + System.lineSeparator()
+              + "Got: "
+              + rows.stream()
+                  .map(Object::toString)
+                  .collect(Collectors.joining(INDENT, INDENT, ""))
+              + System.lineSeparator(),
+          rows,
+          hasSize(expectedRows.size())
+      );
+
+      ProtobufSchema schema = null;
+      ProtobufNoSRConverter.Deserializer deserializer = new ProtobufNoSRConverter.Deserializer();
+      final JsonMapper mapper = new JsonMapper();
+      for (int i = 0; i != rows.size(); ++i) {
+        assertThat(
+            "Each row should be JSON object",
+            expectedRows.get(i),
+            is(instanceOf(Map.class))
+        );
+
+        final Map<String, Object> actual = asJson(rows.get(i), PAYLOAD_TYPE);
+        final Map<String, Object> expected = (Map<String, Object>) expectedRows.get(i);
+
+        if (actual.containsKey(HEADER_PROTOBUF)
+            && ((HashMap<String, Object>) actual.get(HEADER_PROTOBUF)).containsKey(SCHEMA)) {
+
+          assertThat(i, is(0));
+
+          schema = new ProtobufSchema((String) ((Map<String, Object>) actual.get(HEADER_PROTOBUF)).get(SCHEMA));
+          final String actualSchema = (String) ((Map<String, Object>) actual.get(HEADER_PROTOBUF)).get(SCHEMA);
+          final String expectedSchema = (String) ((Map<String, Object>) expected.get(HEADER_PROTOBUF)).get(SCHEMA);
+
+          assertThat(actualSchema, is(expectedSchema));
+        } else if (actual.containsKey("finalMessage")) {
+          assertThat(actual, is(expected));
+        } else {
+          JSONObject row = new JSONObject(actual);
+
+          byte[] bytes;
+          try {
+            bytes = mapper.readTree(row.toString()).get("row").get("protobufBytes").binaryValue();
+          } catch (IOException e) {
+            throw new RuntimeException("Failed to deserialize the ProtoBuf bytes from the "
+                + "RQTT JSON response row: " + row);
+          }
+
+          final Object message = deserializer.deserialize(bytes, schema);
+          final String actualMessage = message.toString();
+          final String expectedMessage = expected.get("row").toString();
+
+          assertThat(actualMessage, is(expectedMessage));
+        }
+      }
+    }
+  }
+
   private static final class StatementSplit {
 
     final List<String> admin;
diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-protobuf.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-protobuf.json
new file mode 100644
index 00000000000..d5300f6e037
--- /dev/null
+++ b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-protobuf.json
@@ -0,0 +1,293 @@
+{
+  "comments": [
+    "Tests covering pull queries against materialized tables and CREATE SOURCE TABLE (CST) tables"
+  ],
+  "tests": [
+    {
+      "name": "transform a map with array values",
+      "testPullWithProtoFormat": true,
+      "statements": [
+        "CREATE TABLE TEST (ID BIGINT PRIMARY KEY, VALUE MAP<STRING, ARRAY<INTEGER>>) WITH (kafka_topic='test_topic', value_format='AVRO');",
+        "CREATE TABLE MAT_TABLE AS SELECT ID, VALUE FROM TEST;",
+        "SELECT ID, TRANSFORM(TRANSFORM(VALUE, (x,y) => x, (x,y) => FILTER(y, z => z < 5)), (x,y) => UCASE(x) , (k,v) => ARRAY_MAX(v)) as FILTERED_TRANSFORMED from MAT_TABLE;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "key": 0, "value": {"value": {"a": [2,null,5,4], "b": [-1,-2]}}},
+        {"topic": "test_topic", "key": 1, "value": {"value": {"c": [null,null,-1], "t": [3, 1]}}},
+        {"topic": "test_topic", "key": 2, "value": {"value": {"d": [4], "q": [0, 0]}}}
+      ],
+      "responses": [
+        {"admin": {"@type": "currentStatus"}},
+        {"admin": {"@type": "currentStatus"}},
+        {"query": [
+          {"header":{"schema":"`ID` BIGINT KEY, `FILTERED_TRANSFORMED` MAP<STRING, INTEGER>"}},
+          {"row":{"columns":[1, {"C": -1, "T": 3}]}},
+          {"row":{"columns":[2, {"D": 4, "Q": 0}]}},
+          {"row":{"columns":[0, {"A": 4, "B": -1}]}}
+        ]},
+        {"queryProto": [
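+          // The first element carries the Protobuf schema; each subsequent "row" is a text-format Protobuf message.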
{"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int64 ID = 1;\n repeated ConnectDefault2Entry FILTERED_TRANSFORMED = 2;\n\n message ConnectDefault2Entry {\n string key = 1;\n int32 value = 2;\n }\n}\n"}}, + {"row": "ID: 1\nFILTERED_TRANSFORMED {\n key: \"C\"\n value: -1\n}\nFILTERED_TRANSFORMED {\n key: \"T\"\n value: 3\n}\n"}, + {"row": "ID: 2\nFILTERED_TRANSFORMED {\n key: \"Q\"\n}\nFILTERED_TRANSFORMED {\n key: \"D\"\n value: 4\n}\n"}, + {"row": "FILTERED_TRANSFORMED {\n key: \"A\"\n value: 4\n}\nFILTERED_TRANSFORMED {\n key: \"B\"\n value: -1\n}\n"} + ]} + ] + }, + { + "name": "windowed - select star and ROWTIME", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE STREAM INPUT (ID STRING KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", + "SELECT ID, WINDOWSTART, WINDOWEND, COUNT, ROWTIME FROM AGGREGATE WHERE ID='10';", + "SELECT ID, WINDOWSTART, WINDOWEND, COUNT, ROWTIME FROM AGGREGATE WHERE ID='missing';" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12346, "key": "11", "value": {"val": 1}}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {"val": 2}} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT, `ROWTIME` BIGINT"}}, + {"row":{"columns":["10", 12000, 13000, 1, 12345]}} + ]}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT, `ROWTIME` BIGINT"}} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string ID = 1;\n int64 WINDOWSTART = 2;\n int64 WINDOWEND = 3;\n int64 COUNT = 4;\n int64 ROWTIME = 5;\n}\n"}}, + {"row": "ID: \"10\"\nWINDOWSTART: 12000\nWINDOWEND: 13000\nCOUNT: 1\nROWTIME: 12345\n"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string ID = 1;\n int64 WINDOWSTART = 2;\n int64 WINDOWEND = 3;\n int64 COUNT = 4;\n int64 ROWTIME = 5;\n}\n"}} + ]} + ] + }, + { + "name": "pull query on stream with headers", + "testPullWithProtoFormat": true, + "properties": { + "ksql.query.pull.stream.enabled": true + }, + "statements": [ + "CREATE STREAM S1 (MYKEY INT KEY, MYVALUE INT, MYHEADERS ARRAY> HEADERS) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT * FROM S1;" + ], + "topics": [ + {"name": "test_topic", "partitions": 1} + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12365, "key": 10, "value": {"myvalue": 1}, "headers": []}, + {"topic": "test_topic", "timestamp": 12366, "key": 11, "value": {"myvalue": 2}, "headers": [{"KEY": "abc", "VALUE": "IQ=="}]} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`MYKEY` INTEGER, `MYVALUE` INTEGER, `MYHEADERS` ARRAY>"}}, + {"row":{"columns":[10, 1, []]}}, + {"row":{"columns":[11, 2, [{"KEY": "abc", "VALUE": "IQ=="}]]}}, + {"finalMessage":"Query Completed"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 MYKEY = 1;\n int32 MYVALUE = 2;\n repeated ConnectDefault2 MYHEADERS = 3;\n\n message ConnectDefault2 {\n string KEY = 1;\n bytes VALUE = 2;\n }\n}\n"}}, + {"row": "MYKEY: 10\nMYVALUE: 1\n"}, + {"row": "MYKEY: 11\nMYVALUE: 2\nMYHEADERS {\n KEY: 
\"abc\"\n VALUE: \"!\"\n}\n"}, + {"finalMessage":"Query Completed"} + ]} + ] + }, + { + "name": "select * against materialized table with headers", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE TABLE INPUT (ID STRING PRIMARY KEY, GRADE STRING, RANK INT, HEAD BYTES HEADER('abc')) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE INPUT_QUERYABLE AS SELECT * FROM INPUT;", + "SELECT * FROM INPUT_QUERYABLE;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12346, "key": "11", "value": {"GRADE": "A", "RANK": 1}, "headers": [{"KEY": "abc", "VALUE": "IQ=="}]}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {"GRADE": "B", "RANK": 2}, "headers": [{"KEY": "abc", "VALUE": "IQ=="}]} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `GRADE` STRING, `RANK` INTEGER, `HEAD` BYTES"}}, + {"row":{"columns":["11", "A", 1, "IQ=="]}}, + {"row":{"columns":["10", "B", 2, "IQ=="]}} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string ID = 1;\n string GRADE = 2;\n int32 RANK = 3;\n bytes HEAD = 4;\n}\n"}}, + {"row": "ID: \"11\"\nGRADE: \"A\"\nRANK: 1\nHEAD: \"!\"\n"}, + {"row": "ID: \"10\"\nGRADE: \"B\"\nRANK: 2\nHEAD: \"!\"\n"} + ]} + ] + }, + { + "name": "on stream", + "testPullWithProtoFormat": true, + "format": ["JSON"], + "statements": [ + "CREATE STREAM riderLocations (profileId VARCHAR, latitude DOUBLE, longitude DOUBLE) WITH (kafka_topic='test_topic', value_format='{FORMAT}');", + "SELECT * FROM riderLocations LIMIT 2;", + "SELECT * FROM riderLocations LIMIT 0;", + "SELECT * FROM riderLocations LIMIT 5;" + ], + "topics": [ + {"name": "test_topic", "partitions": 1} // to get a stable ordering + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 1000001, "value": {"profileId": "which", "latitude": 37.7877, "longitude": -122.4205}}, + {"topic": "test_topic", "timestamp": 1000002, "value": {"profileId": "there", "latitude": 37.3903, "longitude": -122.0643}}, + {"topic": "test_topic", "timestamp": 1000003, "value": {"profileId": "their", "latitude": 37.3952, "longitude": -122.0813}} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`PROFILEID` STRING, `LATITUDE` DOUBLE, `LONGITUDE` DOUBLE"}}, + {"row": {"columns": ["which", 37.7877, -122.4205]}}, + {"row": {"columns": ["there", 37.3903, -122.0643]}}, + {"finalMessage":"Limit Reached"} + ]}, + {"query": [ + {"header":{"schema":"`PROFILEID` STRING, `LATITUDE` DOUBLE, `LONGITUDE` DOUBLE"}}, + {"finalMessage":"Limit Reached"} + ]}, + {"query": [ + {"header":{"schema":"`PROFILEID` STRING, `LATITUDE` DOUBLE, `LONGITUDE` DOUBLE"}}, + {"row":{"columns":["which",37.7877,-122.4205]}}, + {"row":{"columns":["there",37.3903,-122.0643]}}, + {"row":{"columns":["their",37.3952,-122.0813]}}, + {"finalMessage":"Query Completed"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string PROFILEID = 1;\n double LATITUDE = 2;\n double LONGITUDE = 3;\n}\n"}}, + {"row": "PROFILEID: \"which\"\nLATITUDE: 37.7877\nLONGITUDE: -122.4205\n"}, + {"row": "PROFILEID: \"there\"\nLATITUDE: 37.3903\nLONGITUDE: -122.0643\n"}, + {"finalMessage":"Limit Reached"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string PROFILEID = 1;\n double LATITUDE = 2;\n double LONGITUDE = 3;\n}\n"}}, + 
{"finalMessage":"Limit Reached"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n string PROFILEID = 1;\n double LATITUDE = 2;\n double LONGITUDE = 3;\n}\n"}}, + {"row": "PROFILEID: \"which\"\nLATITUDE: 37.7877\nLONGITUDE: -122.4205\n"}, + {"row": "PROFILEID: \"there\"\nLATITUDE: 37.3903\nLONGITUDE: -122.0643\n"}, + {"row": "PROFILEID: \"their\"\nLATITUDE: 37.3952\nLONGITUDE: -122.0813\n"}, + {"finalMessage":"Query Completed"} + ]} + ] + }, + { + "name": "empty response on empty stream", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE STREAM S1 (MYKEY INT KEY, MYVALUE INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT * FROM S1;" + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`MYKEY` INTEGER, `MYVALUE` INTEGER"}}, + {"finalMessage":"Query Completed"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 MYKEY = 1;\n int32 MYVALUE = 2;\n}\n"}}, + {"finalMessage":"Query Completed"} + ]} + ] + }, + { + "name": "select * against CST table", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE SOURCE TABLE INPUT (K INT PRIMARY KEY, text STRING) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "SELECT * FROM INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": 1, "value": "a1"}, + {"topic": "test_topic", "timestamp": 12345, "key": 2, "value": "a2"}, + {"topic": "test_topic", "timestamp": 12345, "key": 3, "value": "a3"} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`K` INTEGER KEY, `TEXT` STRING"}}, + {"row":{"columns":[1,"a1"]}}, + {"row":{"columns":[2,"a2"]}}, + {"row":{"columns":[3,"a3"]}} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 K = 1;\n string TEXT = 2;\n}\n"}}, + {"row": "K: 1\nTEXT: \"a1\"\n"}, + {"row": "K: 2\nTEXT: \"a2\"\n"}, + {"row": "K: 3\nTEXT: \"a3\"\n"} + ]} + ] + }, + { + "name": "select * against CST table and filter by key", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE SOURCE TABLE INPUT (K INT PRIMARY KEY, text STRING) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "SELECT * FROM INPUT WHERE K=2;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": 1, "value": "a1"}, + {"topic": "test_topic", "timestamp": 12345, "key": 2, "value": "a2"}, + {"topic": "test_topic", "timestamp": 12345, "key": 3, "value": "a3"} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`K` INTEGER KEY, `TEXT` STRING"}}, + {"row":{"columns":[2,"a2"]}} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 K = 1;\n string TEXT = 2;\n}\n"}}, + {"row": "K: 2\nTEXT: \"a2\"\n"} + ]} + ] + }, + { + "name": "select with projection table scan and key lookup", + "testPullWithProtoFormat": true, + "statements": [ + "CREATE SOURCE TABLE INPUT (K INT PRIMARY KEY, text STRING) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "SELECT K, TEXT FROM INPUT;", + "SELECT K, TEXT FROM INPUT WHERE K=2;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": 1, "value": "a1"}, + {"topic": "test_topic", "timestamp": 12345, "key": 2, "value": "a2"}, + {"topic": "test_topic", "timestamp": 12345, "key": 3, "value": "a3"} + ], + "responses": [ + {"admin": {"@type": 
"currentStatus"}}, + {"query": [ + {"header":{"schema":"`K` INTEGER KEY, `TEXT` STRING"}}, + {"row":{"columns":[1,"a1"]}}, + {"row":{"columns":[2,"a2"]}}, + {"row":{"columns":[3,"a3"]}} + ]}, + {"query": [ + {"header":{"schema":"`K` INTEGER KEY, `TEXT` STRING"}}, + {"row":{"columns":[2,"a2"]}} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 K = 1;\n string TEXT = 2;\n}\n"}}, + {"row": "K: 1\nTEXT: \"a1\"\n"}, + {"row": "K: 2\nTEXT: \"a2\"\n"}, + {"row": "K: 3\nTEXT: \"a3\"\n"} + ]}, + {"queryProto": [ + {"header":{"protoSchema":"syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n int32 K = 1;\n string TEXT = 2;\n}\n"}}, + {"row": "K: 2\nTEXT: \"a2\"\n"} + ]} + ] + } + ] +} \ No newline at end of file diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriter.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriter.java index bca4bc6c643..8010b2cb166 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriter.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriter.java @@ -17,8 +17,13 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.connect.protobuf.ProtobufData; +import io.confluent.connect.protobuf.ProtobufDataConfig; +import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema; import io.confluent.ksql.GenericRow; import io.confluent.ksql.api.spi.QueryPublisher; import io.confluent.ksql.query.QueryId; @@ -30,6 +35,10 @@ import io.confluent.ksql.rest.entity.QueryResponseMetadata; import io.confluent.ksql.rest.entity.StreamedRow; import io.confluent.ksql.rest.server.resources.streaming.TombstoneFactory; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.serde.connect.ConnectSchemas; +import io.confluent.ksql.serde.connect.KsqlConnectSerializer; +import io.confluent.ksql.serde.protobuf.ProtobufNoSRSerdeFactory; import io.confluent.ksql.util.KeyValue; import io.confluent.ksql.util.KeyValueMetadata; import io.confluent.ksql.util.KsqlHostInfo; @@ -40,6 +49,9 @@ import java.time.Clock; import java.util.List; import java.util.Optional; +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Struct; public class JsonStreamedRowResponseWriter implements QueryStreamResponseWriter { @@ -56,6 +68,7 @@ public class JsonStreamedRowResponseWriter implements QueryStreamResponseWriter // output a final message. 
   private final boolean bufferOutput;
   private final Context context;
+  private final RowFormat rowFormat;
   private final WriterState writerState;
   private StreamedRow lastRow;
   private long timerId = -1;
@@ -69,7 +82,8 @@ public JsonStreamedRowResponseWriter(
       final Optional<String> limitMessage,
       final Clock clock,
       final boolean bufferOutput,
-      final Context context
+      final Context context,
+      final RowFormat rowFormat
   ) {
     this.response = response;
     this.tombstoneFactory = queryPublisher.getResultType().map(
@@ -83,12 +97,12 @@ public JsonStreamedRowResponseWriter(
     Preconditions.checkState(bufferOutput || limitMessage.isPresent()
         || completionMessage.isPresent(),
         "If buffering isn't used, a limit/completion message must be set");
+    this.rowFormat = rowFormat;
   }
 
   @Override
   public QueryStreamResponseWriter writeMetadata(final QueryResponseMetadata metaData) {
-    final StreamedRow streamedRow
-        = StreamedRow.header(new QueryId(metaData.queryId), metaData.schema);
+    final StreamedRow streamedRow = rowFormat.metadataRow(metaData);
     final Buffer buff = Buffer.buffer().appendByte((byte) '[');
     if (bufferOutput) {
       writeBuffer(buff, true);
@@ -112,14 +126,8 @@ public QueryStreamResponseWriter writeRow(
       Preconditions.checkState(tombstoneFactory.isPresent(),
           "Should only have null values for query types that support them");
       streamedRow = StreamedRow.tombstone(tombstoneFactory.get().createRow(keyValue));
-    } else if (keyValueMetadata.getRowMetadata().isPresent()
-        && keyValueMetadata.getRowMetadata().get().getSourceNode().isPresent()) {
-      streamedRow = StreamedRow.pullRow(keyValue.value(),
-          toKsqlHostInfoEntity(keyValueMetadata.getRowMetadata().get().getSourceNode()));
     } else {
-      // Technically, this codepath is for both push and pull, but where there's no additional
-      // metadata, as there sometimes is with a pull query.
-      streamedRow = StreamedRow.pushRow(keyValue.value());
+      streamedRow = rowFormat.dataRow(keyValueMetadata);
     }
     maybeCacheRowAndWriteLast(streamedRow);
     return this;
@@ -178,6 +186,72 @@ public void end() {
     response.end();
   }
 
+  public enum RowFormat {
+    PROTOBUF {
+      private transient ConnectSchema connectSchema;
+      private transient KsqlConnectSerializer<Struct> serializer;
+      @Override
+      public StreamedRow metadataRow(final QueryResponseMetadata metaData) {
+        final LogicalSchema schema = metaData.schema;
+        final String queryId = metaData.queryId;
+
+        connectSchema = ConnectSchemas.columnsToConnectSchema(schema.columns());
+        serializer = new ProtobufNoSRSerdeFactory(ImmutableMap.of())
+            .createSerializer(connectSchema, Struct.class, false);
+        return StreamedRow.headerProtobuf(
+            new QueryId(queryId), schema, logicalToProtoSchema(schema));
+      }
+
+      @Override
+      public StreamedRow dataRow(final KeyValueMetadata<List<?>, GenericRow> keyValueMetadata) {
+        final KeyValue<List<?>, GenericRow> keyValue = keyValueMetadata.getKeyValue();
+        final Struct ksqlRecord = new Struct(connectSchema);
+        int i = 0;
+        for (Field field : connectSchema.fields()) {
+          ksqlRecord.put(
+              field,
+              keyValue.value().get(i));
+          i += 1;
+        }
+        final byte[] protoMessage = serializer.serialize("", ksqlRecord);
+        return StreamedRow.pullRowProtobuf(protoMessage);
+      }
+    },
+    JSON {
+      @Override
+      public StreamedRow metadataRow(final QueryResponseMetadata metaData) {
+        return StreamedRow.header(new QueryId(metaData.queryId), metaData.schema);
+      }
+
+      @Override
+      public StreamedRow dataRow(final KeyValueMetadata<List<?>, GenericRow> keyValueMetadata) {
+        final KeyValue<List<?>, GenericRow> keyValue = keyValueMetadata.getKeyValue();
+        if (keyValueMetadata.getRowMetadata().isPresent()
+            && keyValueMetadata.getRowMetadata().get().getSourceNode().isPresent()) {
+          return StreamedRow.pullRow(keyValue.value(),
+              toKsqlHostInfoEntity(keyValueMetadata.getRowMetadata().get().getSourceNode()));
+        } else {
+          // Technically, this codepath is for both push and pull, but where there's no additional
+          // metadata, as there sometimes is with a pull query.
+          return StreamedRow.pushRow(keyValue.value());
+        }
+      }
+    };
+
+    public abstract StreamedRow metadataRow(QueryResponseMetadata metaData);
+
+    public abstract StreamedRow dataRow(KeyValueMetadata<List<?>, GenericRow> keyValueMetadata);
+  }
+
+  @VisibleForTesting
+  static String logicalToProtoSchema(final LogicalSchema schema) {
+    final ConnectSchema connectSchema = ConnectSchemas.columnsToConnectSchema(schema.columns());
+
+    final ProtobufSchema protobufSchema = new ProtobufData(
+        new ProtobufDataConfig(ImmutableMap.of())).fromConnectSchema(connectSchema);
+    return protobufSchema.canonicalString();
+  }
+
   // This does the writing of the rows and possibly caches the current row, writing the last cached
   // value.
   private void maybeCacheRowAndWriteLast(final StreamedRow streamedRow) {
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java
index 487a871cd2c..77708ef96a6 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java
@@ -22,6 +22,7 @@
 
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import io.confluent.ksql.api.auth.DefaultApiSecurityContext;
+import io.confluent.ksql.api.server.JsonStreamedRowResponseWriter.RowFormat;
 import io.confluent.ksql.api.spi.Endpoints;
 import io.confluent.ksql.api.spi.QueryPublisher;
 import io.confluent.ksql.rest.entity.KsqlMediaType;
@@ -124,11 +125,29 @@ private QueryStreamResponseWriter getQueryStreamResponseWriter(
         || (contentType == null && !queryCompatibilityMode)) {
       // Default
       return new DelimitedQueryStreamResponseWriter(routingContext.response());
+    } else if (KsqlMediaType.KSQL_V1_PROTOBUF.mediaType().equals(contentType)) {
+      return new JsonStreamedRowResponseWriter(
+          routingContext.response(),
+          queryPublisher,
+          completionMessage,
+          limitMessage,
+          Clock.systemUTC(),
+          bufferOutput,
+          context,
+          RowFormat.PROTOBUF
+      );
     } else if (KsqlMediaType.KSQL_V1_JSON.mediaType().equals(contentType)
         || ((contentType == null || JSON_CONTENT_TYPE.equals(contentType)
         && queryCompatibilityMode))) {
-      return new JsonStreamedRowResponseWriter(routingContext.response(), queryPublisher,
-          completionMessage, limitMessage, Clock.systemUTC(), bufferOutput, context);
+      return new JsonStreamedRowResponseWriter(
+          routingContext.response(),
+          queryPublisher,
+          completionMessage,
+          limitMessage,
+          Clock.systemUTC(),
+          bufferOutput,
+          context,
+          RowFormat.JSON);
     } else {
       return new JsonQueryStreamResponseWriter(routingContext.response());
     }
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerVerticle.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerVerticle.java
index eea9c4c18f4..7cfae6f67f6 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerVerticle.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerVerticle.java
@@ -142,6 +142,7 @@ private Router setupRouter() {
         .produces(DELIMITED_CONTENT_TYPE)
         .produces(JSON_CONTENT_TYPE)
         .produces(KsqlMediaType.KSQL_V1_JSON.mediaType())
+        .produces(KsqlMediaType.KSQL_V1_PROTOBUF.mediaType())
         .handler(BodyHandler.create(false))
         .handler(new QueryStreamHandler(endpoints, connectionQueryManager, context, server, false));
     router.route(HttpMethod.POST, "/inserts-stream")
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriterTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriterTest.java
index b6510ddc048..2f9abae9ff1 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriterTest.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/JsonStreamedRowResponseWriterTest.java
@@ -18,6 +18,7 @@
 
 import static io.confluent.ksql.api.server.JsonStreamedRowResponseWriter.MAX_FLUSH_MS;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.Mockito.never;
@@ -82,7 +83,8 @@ public class JsonStreamedRowResponseWriterTest {
   private AtomicLong timeMs = new AtomicLong(TIME_NOW_MS);
   private Runnable simulatedVertxTimerCallback;
 
-  private JsonStreamedRowResponseWriter writer;
+  private JsonStreamedRowResponseWriter jsonWriter;
+  private JsonStreamedRowResponseWriter protoWriter;
   private StringBuilder stringBuilder = new StringBuilder();
 
   public JsonStreamedRowResponseWriterTest() {
@@ -104,14 +106,20 @@ public void setUp() {
 
     when(clock.millis()).thenAnswer(a -> timeMs.get());
 
-    writer = new JsonStreamedRowResponseWriter(response, publisher, Optional.empty(),
-        Optional.empty(), clock, true, context);
+    jsonWriter = new JsonStreamedRowResponseWriter(response, publisher, Optional.empty(),
+        Optional.empty(), clock, true, context, JsonStreamedRowResponseWriter.RowFormat.JSON);
+
+    protoWriter = new JsonStreamedRowResponseWriter(response, publisher, Optional.empty(),
+        Optional.empty(), clock, true, context, JsonStreamedRowResponseWriter.RowFormat.PROTOBUF);
   }
 
   private void setupWithMessages(String completionMessage, String limitMessage, boolean buffering) {
     // No buffering for these responses
-    writer = new JsonStreamedRowResponseWriter(response, publisher, Optional.of(completionMessage),
-        Optional.of(limitMessage), clock, buffering, context);
+    jsonWriter = new JsonStreamedRowResponseWriter(response, publisher, Optional.of(completionMessage),
+        Optional.of(limitMessage), clock, buffering, context, JsonStreamedRowResponseWriter.RowFormat.JSON);
+
+    protoWriter = new JsonStreamedRowResponseWriter(response, publisher, Optional.of(completionMessage),
+        Optional.of(limitMessage), clock, buffering, context, JsonStreamedRowResponseWriter.RowFormat.PROTOBUF);
   }
 
   private void expectTimer() {
@@ -131,8 +139,8 @@ private void expectTimer() {
 
   @Test
   public void shouldSucceedWithBuffering_noRows() {
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -146,10 +154,10 @@ public void shouldSucceedWithBuffering_noRows() {
   @Test
   public void shouldSucceedWithBuffering_oneRow() {
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -168,14 +176,14 @@ public void shouldSucceedWithBuffering_twoRows_timeout() {
     expectTimer();
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
     simulatedVertxTimerCallback.run();
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
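     // Fire the simulated flush timer so the second buffered row is written out.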
     simulatedVertxTimerCallback.run();
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -195,12 +203,12 @@ public void shouldSucceedWithBuffering_twoRows_noTimeout() {
     expectTimer();
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -220,14 +228,14 @@ public void shouldSucceedWithBuffering_largeBuffer() {
     expectTimer();
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
 
     for (int i = 0; i < 4000; i++) {
-      writer.writeRow(new KeyValueMetadata<>(
+      jsonWriter.writeRow(new KeyValueMetadata<>(
           KeyValue.keyValue(null,
               GenericRow.genericRow(i, 100.0d + i, ImmutableList.of("hello" + i)))));
     }
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString().split("\n").length, is(4001));
@@ -243,12 +251,12 @@ public void shouldSucceedWithCompletionMessage_noBuffering() {
     setupWithMessages("complete!", "limit hit!", false);
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
        KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -270,12 +278,12 @@ public void shouldSucceedWithCompletionMessage_buffering() {
     expectTimer();
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
-    writer.writeCompletionMessage().end();
+    jsonWriter.writeCompletionMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -296,13 +304,13 @@ public void shouldSucceedWithLimitMessage_noBuffering() {
     setupWithMessages("complete!", "limit hit!", false);
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
 
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
-    writer.writeLimitMessage().end();
+    jsonWriter.writeLimitMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -324,12 +332,12 @@ public void shouldSucceedWithLimitMessage_buffering() {
     expectTimer();
 
     // When:
-    writer.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
-    writer.writeRow(new KeyValueMetadata<>(
+    jsonWriter.writeRow(new KeyValueMetadata<>(
         KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye")))));
-    writer.writeLimitMessage().end();
+    jsonWriter.writeLimitMessage().end();
 
     // Then:
     assertThat(stringBuilder.toString(),
@@ -343,4 +351,448 @@ public void shouldSucceedWithLimitMessage_buffering() {
     verify(vertx, times(1)).setTimer(anyLong(), any());
     verify(vertx).cancelTimer(anyLong());
   }
-}
+
+  @Test
+  public void shouldSucceedWithBuffering_noRows_proto() {
+    // When:
+    protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    protoWriter.writeCompletionMessage().end();
+
+    // Then:
+    assertThat(stringBuilder.toString(),
+        is("[{\"header\":{\"queryId\":\"queryId\","
+            + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>\","
+            + "\"protoSchema\":\"syntax = \\\"proto3\\\";\\n"
+            + "\\n"
+            + "message ConnectDefault1 {\\n"
+            + "  int32 A = 1;\\n"
+            + "  double B = 2;\\n"
+            + "  repeated string C = 3;\\n"
+            + "}\\n"
+            + "\"}}]"));
+    verify(response, times(1)).write((String) any());
+    verify(response, never()).write((Buffer) any());
+    verify(vertx, never()).setTimer(anyLong(), any());
+  }
+
+  @Test
+  public void shouldSucceedWithBuffering_oneRow_proto() {
+    // When:
+    protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    protoWriter.writeRow(new KeyValueMetadata<>(
+        KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
+    protoWriter.writeCompletionMessage().end();
+
+    // Then:
+    assertThat(stringBuilder.toString(),
+        is("[{\"header\":{\"queryId\":\"queryId\","
+            + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>\","
+            + "\"protoSchema\":"
+            + "\"syntax = \\\"proto3\\\";\\n"
+            + "\\n"
+            + "message ConnectDefault1 {\\n"
+            + "  int32 A = 1;\\n"
+            + "  double B = 2;\\n"
+            + "  repeated string C = 3;\\n"
+            + "}\\n"
+            + "\"}},\n"
+            + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}}]"));
+    verify(response, times(1)).write((String) any());
+    verify(response, never()).write((Buffer) any());
+    verify(vertx, times(1)).setTimer(anyLong(), any());
+    verify(vertx, times(1)).cancelTimer(anyLong());
+  }
+
+  @Test
+  public void shouldSucceedWithBuffering_twoRows_timeout_proto() {
+    // Given:
+    expectTimer();
+
+    // When:
+    protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA));
+    protoWriter.writeRow(new KeyValueMetadata<>(
+        KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello")))));
+    simulatedVertxTimerCallback.run();
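+    // After this flush, the next row is buffered until the timer fires again.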
protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + simulatedVertxTimerCallback.run(); + protoWriter.writeCompletionMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}}]")); + verify(response, times(3)).write((String) any()); + verify(response, never()).write((Buffer) any()); + verify(vertx, times(2)).setTimer(anyLong(), any()); + verify(vertx, never()).cancelTimer(anyLong()); + } + + @Test + public void shouldSucceedWithBuffering_twoRows_noTimeout_proto() { + // Given: + expectTimer(); + + // When: + protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA)); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello"))))); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + protoWriter.writeCompletionMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}}]")); + verify(response, times(1)).write((String) any()); + verify(response, never()).write((Buffer) any()); + verify(vertx, times(1)).setTimer(anyLong(), any()); + verify(vertx).cancelTimer(anyLong()); + } + + @Test + public void shouldSucceedWithCompletionMessage_noBuffering_proto() { + // Given: + setupWithMessages("complete!", "limit hit!", false); + + // When: + protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA)); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello"))))); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + protoWriter.writeCompletionMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":" + + "\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}},\n" + + "{\"finalMessage\":\"complete!\"}]")); + verify(response, times(5)).write((Buffer) any()); + verify(response, never()).write((String) any()); + verify(vertx, never()).setTimer(anyLong(), any()); + verify(vertx, never()).cancelTimer(anyLong()); + } + + @Test + public void 
shouldSucceedWithCompletionMessage_buffering_proto() { + // Given: + setupWithMessages("complete!", "limit hit!", true); + expectTimer(); + + // When: + protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA)); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello"))))); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + protoWriter.writeCompletionMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":" + + "\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}},\n" + + "{\"finalMessage\":\"complete!\"}]")); + + verify(response, times(1)).write((String) any()); + verify(response, never()).write((Buffer) any()); + verify(vertx, times(1)).setTimer(anyLong(), any()); + verify(vertx).cancelTimer(anyLong()); + } + + @Test + public void shouldSucceedWithLimitMessage_noBuffering_proto() { + // Given: + setupWithMessages("complete!", "limit hit!", false); + + // When: + protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA)); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello"))))); + + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + protoWriter.writeLimitMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":" + + "\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}},\n" + + "{\"finalMessage\":\"limit hit!\"}]")); + verify(response, times(5)).write((Buffer) any()); + verify(response, never()).write((String) any()); + verify(vertx, never()).setTimer(anyLong(), any()); + verify(vertx, never()).cancelTimer(anyLong()); + } + + @Test + public void shouldSucceedWithLimitMessage_buffering_proto() { + // Given: + setupWithMessages("complete!", "limit hit!", true); + expectTimer(); + + // When: + protoWriter.writeMetadata(new QueryResponseMetadata(QUERY_ID, COL_NAMES, COL_TYPES, SCHEMA)); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(123, 234.0d, ImmutableList.of("hello"))))); + protoWriter.writeRow(new KeyValueMetadata<>( + KeyValue.keyValue(null, GenericRow.genericRow(456, 789.0d, ImmutableList.of("bye"))))); + protoWriter.writeLimitMessage().end(); + + // Then: + assertThat(stringBuilder.toString(), + is("[{\"header\":{\"queryId\":\"queryId\"," + + "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY\"," + + "\"protoSchema\":" + + "\"syntax = \\\"proto3\\\";\\n" + + "\\n" + + "message ConnectDefault1 {\\n" + + " int32 A = 1;\\n" + + " double B = 2;\\n" + + " repeated 
string C = 3;\\n" + + "}\\n" + + "\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" + + "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}},\n" + + "{\"finalMessage\":\"limit hit!\"}]")); + verify(response, times(1)).write((String) any()); + verify(response, never()).write((Buffer) any()); + verify(vertx, times(1)).setTimer(anyLong(), any()); + verify(vertx).cancelTimer(anyLong()); + } + @Test + public void shouldConvertLogicalSchemaToProtobufSchema() { + // Given: + final String expectedProtoSchemaString = "syntax = \"proto3\";\n" + + "\n" + + "message ConnectDefault1 {\n" + + " int32 A = 1;\n" + + " double B = 2;\n" + + " repeated string C = 3;\n" + + "}\n"; + + // When: + final String protoSchema = JsonStreamedRowResponseWriter.logicalToProtoSchema(SCHEMA); + + // Then: + assertThat(protoSchema, is(expectedProtoSchemaString)); + } + + @Test + public void shouldConvertComplexLogicalSchemaToProtobufSchema() { + // Given: + final LogicalSchema schema = LogicalSchema.builder() + .keyColumn(ColumnName.of("K"), SqlTypes.struct() + .field("F1", SqlTypes.array(SqlTypes.STRING)) + .build()) + .valueColumn(ColumnName.of("STR"), SqlTypes.STRING) + .valueColumn(ColumnName.of("LONG"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("DEC"), SqlTypes.decimal(4, 2)) + .valueColumn(ColumnName.of("BYTES_"), SqlTypes.BYTES) + .valueColumn(ColumnName.of("ARRAY"), SqlTypes.array(SqlTypes.STRING)) + .valueColumn(ColumnName.of("MAP"), SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING)) + .valueColumn(ColumnName.of("STRUCT"), SqlTypes.struct().field("F1", SqlTypes.INTEGER).build()) + .valueColumn(ColumnName.of("COMPLEX"), SqlTypes.struct() + .field("DECIMAL", SqlTypes.decimal(2, 1)) + .field("STRUCT", SqlTypes.struct() + .field("F1", SqlTypes.STRING) + .field("F2", SqlTypes.INTEGER) + .build()) + .field("ARRAY_STRUCT", SqlTypes.array(SqlTypes.struct().field("F1", SqlTypes.STRING).build())) + .field("ARRAY_MAP", SqlTypes.array(SqlTypes.map(SqlTypes.STRING, SqlTypes.INTEGER))) + .field("MAP_ARRAY", SqlTypes.map(SqlTypes.STRING, SqlTypes.array(SqlTypes.STRING))) + .field("MAP_MAP", SqlTypes.map(SqlTypes.STRING, + SqlTypes.map(SqlTypes.STRING, SqlTypes.INTEGER) + )) + .field("MAP_STRUCT", SqlTypes.map(SqlTypes.STRING, + SqlTypes.struct().field("F1", SqlTypes.STRING).build() + )) + .build() + ) + .valueColumn(ColumnName.of("TIMESTAMP"), SqlTypes.TIMESTAMP) + .valueColumn(ColumnName.of("DATE"), SqlTypes.DATE) + .valueColumn(ColumnName.of("TIME"), SqlTypes.TIME) + .headerColumn(ColumnName.of("HEAD"), Optional.of("h0")) + .build(); + + final String expectedProtoSchemaString = "syntax = \"proto3\";\n" + + "\n" + + "import \"confluent/type/decimal.proto\";\n" + + "import \"google/protobuf/timestamp.proto\";\n" + + "import \"google/type/date.proto\";\n" + + "import \"google/type/timeofday.proto\";\n" + + "\n" + + "message ConnectDefault1 {\n" + + " ConnectDefault2 K = 1;\n" + + " string STR = 2;\n" + + " int64 LONG = 3;\n" + + " confluent.type.Decimal DEC = 4 [(confluent.field_meta) = {\n" + + " params: [\n" + + " {\n" + + " value: \"4\",\n" + + " key: \"precision\"\n" + + " },\n" + + " {\n" + + " value: \"2\",\n" + + " key: \"scale\"\n" + + " }\n" + + " ]\n" + + " }];\n" + + " bytes BYTES_ = 5;\n" + + " repeated string ARRAY = 6;\n" + + " repeated ConnectDefault3Entry MAP = 7;\n" + + " ConnectDefault4 STRUCT = 8;\n" + + " ConnectDefault5 COMPLEX = 9;\n" + + " google.protobuf.Timestamp TIMESTAMP = 10;\n" + + " google.type.Date DATE = 11;\n" + + " google.type.TimeOfDay TIME = 12;\n" + + " 
+  @Test
+  public void shouldConvertComplexLogicalSchemaToProtobufSchema() {
+    // Given:
+    final LogicalSchema schema = LogicalSchema.builder()
+        .keyColumn(ColumnName.of("K"), SqlTypes.struct()
+            .field("F1", SqlTypes.array(SqlTypes.STRING))
+            .build())
+        .valueColumn(ColumnName.of("STR"), SqlTypes.STRING)
+        .valueColumn(ColumnName.of("LONG"), SqlTypes.BIGINT)
+        .valueColumn(ColumnName.of("DEC"), SqlTypes.decimal(4, 2))
+        .valueColumn(ColumnName.of("BYTES_"), SqlTypes.BYTES)
+        .valueColumn(ColumnName.of("ARRAY"), SqlTypes.array(SqlTypes.STRING))
+        .valueColumn(ColumnName.of("MAP"), SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING))
+        .valueColumn(ColumnName.of("STRUCT"), SqlTypes.struct().field("F1", SqlTypes.INTEGER).build())
+        .valueColumn(ColumnName.of("COMPLEX"), SqlTypes.struct()
+            .field("DECIMAL", SqlTypes.decimal(2, 1))
+            .field("STRUCT", SqlTypes.struct()
+                .field("F1", SqlTypes.STRING)
+                .field("F2", SqlTypes.INTEGER)
+                .build())
+            .field("ARRAY_STRUCT", SqlTypes.array(SqlTypes.struct().field("F1", SqlTypes.STRING).build()))
+            .field("ARRAY_MAP", SqlTypes.array(SqlTypes.map(SqlTypes.STRING, SqlTypes.INTEGER)))
+            .field("MAP_ARRAY", SqlTypes.map(SqlTypes.STRING, SqlTypes.array(SqlTypes.STRING)))
+            .field("MAP_MAP", SqlTypes.map(SqlTypes.STRING,
+                SqlTypes.map(SqlTypes.STRING, SqlTypes.INTEGER)
+            ))
+            .field("MAP_STRUCT", SqlTypes.map(SqlTypes.STRING,
+                SqlTypes.struct().field("F1", SqlTypes.STRING).build()
+            ))
+            .build()
+        )
+        .valueColumn(ColumnName.of("TIMESTAMP"), SqlTypes.TIMESTAMP)
+        .valueColumn(ColumnName.of("DATE"), SqlTypes.DATE)
+        .valueColumn(ColumnName.of("TIME"), SqlTypes.TIME)
+        .headerColumn(ColumnName.of("HEAD"), Optional.of("h0"))
+        .build();
+
+    final String expectedProtoSchemaString = "syntax = \"proto3\";\n" +
+        "\n" +
+        "import \"confluent/type/decimal.proto\";\n" +
+        "import \"google/protobuf/timestamp.proto\";\n" +
+        "import \"google/type/date.proto\";\n" +
+        "import \"google/type/timeofday.proto\";\n" +
+        "\n" +
+        "message ConnectDefault1 {\n" +
+        "  ConnectDefault2 K = 1;\n" +
+        "  string STR = 2;\n" +
+        "  int64 LONG = 3;\n" +
+        "  confluent.type.Decimal DEC = 4 [(confluent.field_meta) = {\n" +
+        "    params: [\n" +
+        "      {\n" +
+        "        value: \"4\",\n" +
+        "        key: \"precision\"\n" +
+        "      },\n" +
+        "      {\n" +
+        "        value: \"2\",\n" +
+        "        key: \"scale\"\n" +
+        "      }\n" +
+        "    ]\n" +
+        "  }];\n" +
+        "  bytes BYTES_ = 5;\n" +
+        "  repeated string ARRAY = 6;\n" +
+        "  repeated ConnectDefault3Entry MAP = 7;\n" +
+        "  ConnectDefault4 STRUCT = 8;\n" +
+        "  ConnectDefault5 COMPLEX = 9;\n" +
+        "  google.protobuf.Timestamp TIMESTAMP = 10;\n" +
+        "  google.type.Date DATE = 11;\n" +
+        "  google.type.TimeOfDay TIME = 12;\n" +
+        "  bytes HEAD = 13;\n" +
+        "\n" +
+        "  message ConnectDefault2 {\n" +
+        "    repeated string F1 = 1;\n" +
+        "  }\n" +
+        "  message ConnectDefault3Entry {\n" +
+        "    string key = 1;\n" +
+        "    string value = 2;\n" +
+        "  }\n" +
+        "  message ConnectDefault4 {\n" +
+        "    int32 F1 = 1;\n" +
+        "  }\n" +
+        "  message ConnectDefault5 {\n" +
+        "    confluent.type.Decimal DECIMAL = 1 [(confluent.field_meta) = {\n" +
+        "      params: [\n" +
+        "        {\n" +
+        "          value: \"2\",\n" +
+        "          key: \"precision\"\n" +
+        "        },\n" +
+        "        {\n" +
+        "          value: \"1\",\n" +
+        "          key: \"scale\"\n" +
+        "        }\n" +
+        "      ]\n" +
+        "    }];\n" +
+        "    ConnectDefault6 STRUCT = 2;\n" +
+        "    repeated ConnectDefault7 ARRAY_STRUCT = 3;\n" +
+        "    repeated ConnectDefault8Entry ARRAY_MAP = 4;\n" +
+        "    repeated ConnectDefault9Entry MAP_ARRAY = 5;\n" +
+        "    repeated ConnectDefault10Entry MAP_MAP = 6;\n" +
+        "    repeated ConnectDefault12Entry MAP_STRUCT = 7;\n" +
+        "    \n" +
+        "    message ConnectDefault6 {\n" +
+        "      string F1 = 1;\n" +
+        "      int32 F2 = 2;\n" +
+        "    }\n" +
+        "    message ConnectDefault7 {\n" +
+        "      string F1 = 1;\n" +
+        "    }\n" +
+        "    message ConnectDefault8Entry {\n" +
+        "      string key = 1;\n" +
+        "      int32 value = 2;\n" +
+        "    }\n" +
+        "    message ConnectDefault9Entry {\n" +
+        "      string key = 1;\n" +
+        "      repeated string value = 2;\n" +
+        "    }\n" +
+        "    message ConnectDefault10Entry {\n" +
+        "      string key = 1;\n" +
+        "      repeated ConnectDefault11Entry value = 2;\n" +
+        "      \n" +
+        "      message ConnectDefault11Entry {\n" +
+        "        string key = 1;\n" +
+        "        int32 value = 2;\n" +
+        "      }\n" +
+        "    }\n" +
+        "    message ConnectDefault12Entry {\n" +
+        "      string key = 1;\n" +
+        "      ConnectDefault13 value = 2;\n" +
+        "      \n" +
+        "      message ConnectDefault13 {\n" +
+        "        string F1 = 1;\n" +
+        "      }\n" +
+        "    }\n" +
+        "  }\n" +
+        "}\n";
+
+    // When:
+    final String protoSchema = JsonStreamedRowResponseWriter.logicalToProtoSchema(schema);
+
+    // Then:
+    assertThat(protoSchema, is(expectedProtoSchemaString));
+  }
+
+  @Test
+  public void shouldFailNestedArraysConvertLogicalSchemaToProtobufSchema() {
+    // Given:
+    final LogicalSchema schema = LogicalSchema.builder()
+        .valueColumn(ColumnName.of("COMPLEX"), SqlTypes.struct()
+            .field("ARRAY_ARRAY", SqlTypes.array(SqlTypes.array(SqlTypes.STRING)))
+            .build()
+        )
+        .build();
+
+    // When:
+    final Exception exception = assertThrows(IllegalArgumentException.class, () -> {
+      JsonStreamedRowResponseWriter.logicalToProtoSchema(schema);
+    });
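+
+    // Protobuf has no wire encoding for a directly nested repeated field, so
+    // the converter rejects ARRAY<ARRAY<...>> rather than inventing a wrapper
+    // message.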
+
+    final String expectedMessage = "Array cannot be nested";
+    final String actualMessage = exception.getMessage();
+
+    // Then:
+    assertThat(actualMessage.contains(expectedMessage), is(true));
+  }
+}
\ No newline at end of file
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/QueryStreamHandlerTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/QueryStreamHandlerTest.java
index d810c917c61..2d86f6fb837 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/QueryStreamHandlerTest.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/server/QueryStreamHandlerTest.java
@@ -33,7 +33,6 @@
 import io.confluent.ksql.schema.ksql.LogicalSchema;
 import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.util.KeyValueMetadata;
-import io.confluent.ksql.util.PushQueryMetadata;
 import io.vertx.core.Context;
 import io.vertx.core.Handler;
 import io.vertx.core.MultiMap;
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java
index 81c4997c34c..570716e429d 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java
@@ -894,7 +894,7 @@ public void shouldExecutePullQueryOverRest() {
     assertThat(messages.get(1), is("{\"row\":{\"columns\":[1,\"USER_1\"]}}]"));
   }
 
-  @Test
+  @Test
   public void shouldExecutePullQueryOverHttp2QueryStream() {
     QueryStreamArgs queryStreamArgs = new QueryStreamArgs(
         "SELECT COUNT, USERID from " + AGG_TABLE + " WHERE USERID='" + AN_AGG_KEY + "';",
@@ -916,6 +916,66 @@ public void shouldExecutePullQueryOverHttp2QueryStream() {
     assertThat(queryResponse[0].rows.get(0).getList(), is(ImmutableList.of(1, "USER_1")));
   }
 
+  @Test
+  public void shouldExecutePullQueryOverQueryStreamProto() {
+    QueryStreamArgs queryStreamArgs = new QueryStreamArgs(
+        "SELECT COUNT, USERID from " + AGG_TABLE + ";",
+        Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
+
+    final String expectedResponse
+        = "[{\"header\":{\"queryId\":\"XYZ\"," +
+        "\"schema\":\"`COUNT` BIGINT, `USERID` STRING KEY\"," +
+        "\"protoSchema\":" +
+        "\"syntax = \\\"proto3\\\";\\n" +
+        "\\n" +
+        "message ConnectDefault1 {\\n" +
+        "  int64 COUNT = 1;\\n" +
+        "  string USERID = 2;\\n" +
+        "}\\n" +
+        "\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CAESBlVTRVJfMA==\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CAESBlVTRVJfMQ==\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CAISBlVTRVJfMg==\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CAISBlVTRVJfMw==\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CAESBlVTRVJfNA==\"}}]";
+
+    final HttpResponse<Buffer>[] resp = new HttpResponse[1];
+    Arrays.stream(HttpVersion.values()).forEach(httpVersion -> {
+      assertThatEventually(() -> {
+        try {
+          resp[0] = RestIntegrationTestUtil.rawRestRequest(REST_APP,
+              httpVersion, POST,
+              "/query-stream", queryStreamArgs, KsqlMediaType.KSQL_V1_PROTOBUF.mediaType(),
+              Optional.empty(), Optional.empty());
+          int respSize = parseRawRestQueryResponse(resp[0].body().toString()).size();
+          return respSize;
+        } catch (Throwable t) {
+          return Integer.MAX_VALUE;
+        }
+      }, is(6));
+
+      assertThat(
+          resp[0].bodyAsString().replaceFirst("queryId\":\"[^\"]*\"", "queryId\":\"XYZ\""),
+          equalTo(expectedResponse));
+    });
+  }
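+
+  // The older /query endpoint does not negotiate the protobuf media type, so
+  // requesting it there should be rejected with 406 Not Acceptable: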
+
+  @Test
+  public void shouldNotExecutePullQueryOverHttp2QueryProto() {
+    final QueryStreamArgs queryStreamArgs = new QueryStreamArgs(
+        "SELECT COUNT, USERID from " + AGG_TABLE + ";",
+        Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
+
+    final HttpResponse<Buffer> bufferHttpResponse = RestIntegrationTestUtil.rawRestRequest(REST_APP,
+        HTTP_2, POST,
+        "/query", queryStreamArgs, KsqlMediaType.KSQL_V1_PROTOBUF.mediaType(),
+        Optional.empty(), Optional.empty());
+
+    // Then:
+    assertThat(bufferHttpResponse.statusCode(), is(406));
+    assertThat(bufferHttpResponse.statusMessage(), is("Not Acceptable"));
+  }
+
   @Test
   public void shouldExecutePullQuery_allTypes() {
     ImmutableList<Format> formats = ImmutableList.of(
diff --git a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlRestClient.java b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlRestClient.java
index 8e3e760d397..efca8a36f4b 100644
--- a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlRestClient.java
+++ b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlRestClient.java
@@ -308,6 +308,14 @@ public RestResponse<List<StreamedRow>> makeQueryRequest(
         ksql, requestPropertiesToSend, Optional.ofNullable(commandSeqNum));
   }
 
+  public RestResponse<List<StreamedRow>> makeQueryStreamRequestProto(
+      final String ksql,
+      final Map<String, Object> requestProperties
+  ) {
+    final KsqlTarget target = target();
+    return target.postQueryStreamRequestProto(ksql, requestProperties);
+  }
+
   public RestResponse<List<StreamedRow>> makePrintTopicRequest(
       final String ksql,
       final Long commandSeqNum
diff --git a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java
index 0c7336cbff0..62a08d4f9a6 100644
--- a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java
+++ b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java
@@ -29,6 +29,7 @@
 import io.confluent.ksql.rest.entity.HeartbeatResponse;
 import io.confluent.ksql.rest.entity.KsqlEntityList;
 import io.confluent.ksql.rest.entity.KsqlHostInfoEntity;
+import io.confluent.ksql.rest.entity.KsqlMediaType;
 import io.confluent.ksql.rest.entity.KsqlRequest;
 import io.confluent.ksql.rest.entity.LagReportingMessage;
 import io.confluent.ksql.rest.entity.LagReportingResponse;
@@ -221,6 +222,19 @@ public RestResponse<List<StreamedRow>> postQueryRequest(
         KsqlTarget::toRows);
   }
 
+  public RestResponse<List<StreamedRow>> postQueryStreamRequestProto(
+      final String ksql,
+      final Map<String, Object> requestProperties
+  ) {
+    final QueryStreamArgs queryStreamArgs = new QueryStreamArgs(ksql, localProperties.toMap(),
+        Collections.emptyMap(), requestProperties);
+    return executeRequestSync(HttpMethod.POST,
+        QUERY_STREAM_PATH,
+        queryStreamArgs,
+        KsqlTarget::toRowsFromProto,
+        Optional.of(KsqlMediaType.KSQL_V1_PROTOBUF.mediaType()));
+  }
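+
+  // Sending KSQL_V1_PROTOBUF as the Accept header is what switches the server
+  // to the protobuf row writer; the response envelope is still JSON, so
+  // toRowsFromProto below can reuse the existing StreamedRow parsing.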
+
   public RestResponse<List<StreamedRow>> postQueryRequestStreamed(
       final String sql,
       final Map<String, ?> requestProperties,
@@ -261,7 +275,11 @@ private KsqlRequest createKsqlRequest(
   }
 
   private <T> RestResponse<T> get(final String path, final Class<T> type) {
-    return executeRequestSync(HttpMethod.GET, path, null, r -> deserialize(r.getBody(), type));
+    return executeRequestSync(HttpMethod.GET,
+        path,
+        null,
+        r -> deserialize(r.getBody(), type),
+        Optional.empty());
   }
 
   private <T> RestResponse<T> post(
@@ -269,7 +287,7 @@ private <T> RestResponse<T> post(
       final Object jsonEntity,
       final Function<ResponseWithBody, T> mapper
   ) {
-    return executeRequestSync(HttpMethod.POST, path, jsonEntity, mapper);
+    return executeRequestSync(HttpMethod.POST, path, jsonEntity, mapper, Optional.empty());
   }
 
   private <T> RestResponse<T> post(
@@ -300,9 +318,10 @@ private <T> RestResponse<T> executeRequestSync(
       final HttpMethod httpMethod,
       final String path,
       final Object requestBody,
-      final Function<ResponseWithBody, T> mapper
+      final Function<ResponseWithBody, T> mapper,
+      final Optional<String> mediaType
   ) {
-    return executeSync(httpMethod, path, Optional.empty(), requestBody, mapper, (resp, vcf) -> {
+    return executeSync(httpMethod, path, mediaType, requestBody, mapper, (resp, vcf) -> {
       resp.bodyHandler(buff -> vcf.complete(new ResponseWithBody(resp, buff)));
     });
   }
@@ -476,4 +495,8 @@ private <T> CompletableFuture<T> execute(
   private static List<StreamedRow> toRows(final ResponseWithBody resp) {
     return KsqlTargetUtil.toRows(resp.getBody());
   }
+
+  private static List<StreamedRow> toRowsFromProto(final ResponseWithBody resp) {
+    return KsqlTargetUtil.toRows(resp.getBody());
+  }
 }
diff --git a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTargetUtil.java b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTargetUtil.java
index e71937dd0e5..7afdd9e5591 100644
--- a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTargetUtil.java
+++ b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTargetUtil.java
@@ -18,6 +18,7 @@
 import static io.confluent.ksql.rest.client.KsqlClientUtil.deserialize;
 import static io.confluent.ksql.util.BytesUtils.toJsonMsg;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.base.Strings;
 import com.google.common.collect.Streams;
 import io.confluent.ksql.GenericRow;
@@ -40,6 +41,7 @@
 import java.util.stream.Collectors;
 
 public final class KsqlTargetUtil {
+  private static final ObjectMapper MAPPER = new ObjectMapper();
 
   private KsqlTargetUtil() {
diff --git a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlTargetUtilTest.java b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlTargetUtilTest.java
index d28948adefe..62954af27e3 100644
--- a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlTargetUtilTest.java
+++ b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlTargetUtilTest.java
@@ -138,4 +138,95 @@ public void toRows_errorParsingNotAtEnd() {
     // Then:
     assertThat(e.getMessage(), is(("Failed to deserialise object")));
   }
+
+  @Test
+  public void shouldParseHeaderProto() {
+    // When:
+    final List<StreamedRow> rows = KsqlTargetUtil.toRows(Buffer.buffer("[{\"header\":{\"queryId\":\"queryId\"," +
+        "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>\"," +
+        "\"protoSchema\":" +
+        "\"syntax = \\\"proto3\\\";\\n" +
+        "\\n" +
+        "message ConnectDefault1 {\\n" +
+        "  int32 A = 1;\\n" +
+        "  double B = 2;\\n" +
+        "  repeated string C = 3;\\n" +
+        "}\\n" +
+        "\"}}]"));
+
+    final StreamedRow row = rows.get(0);
+
+    // Then:
+    assertThat(row.getHeader().isPresent(), is(true));
+    assertThat(row.getHeader().get().getQueryId().toString(), is("queryId"));
+
+    assertThat(row.getHeader().get().getSchema().toString(), is("`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>"));
+    assertThat(row.getHeader().get().getProtoSchema().get(), is("syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n  int32 A = 1;\n  double B = 2;\n  repeated string C = 3;\n}\n"));
+  }
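+
+  // toRows() parses the same JSON envelope for either format; only the row
+  // payload differs ("columns" vs. base64 "protobufBytes"), as the next test
+  // shows.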
+
+  @Test
+  public void toRowsProto() {
+    // When:
+    final List<StreamedRow> rows = KsqlTargetUtil.toRows(Buffer.buffer("[{\"header\":{\"queryId\":\"queryId\"," +
+        "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>\"," +
+        "\"protoSchema\":" +
+        "\"syntax = \\\"proto3\\\";\\n" +
+        "\\n" +
+        "message ConnectDefault1 {\\n" +
+        "  int32 A = 1;\\n" +
+        "  double B = 2;\\n" +
+        "  repeated string C = 3;\\n" +
+        "}\\n" +
+        "\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" +
+        "{\"row\":{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}},\n" +
+        "{\"finalMessage\":\"limit hit!\"}]"));
+
+    // Then:
+    assertThat(rows.size(), is(4));
+    final StreamedRow row = rows.get(0);
+    assertThat(row.getHeader().isPresent(), is(true));
+    assertThat(row.getHeader().get().getQueryId().toString(), is("queryId"));
+    assertThat(row.getHeader().get().getSchema().toString(), is("`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>"));
+    assertThat(row.getHeader().get().getProtoSchema().get(), is("syntax = \"proto3\";\n\nmessage ConnectDefault1 {\n  int32 A = 1;\n  double B = 2;\n  repeated string C = 3;\n}\n"));
+
+    final StreamedRow row2 = rows.get(1);
+    assertThat(row2.getRow().isPresent(), is(true));
+    assertThat(row2.getRow().get().getProtobufBytes().isPresent(), is(true));
+    assertThat(row2.getRow().get().toString(), is("{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}"));
+
+    final StreamedRow row3 = rows.get(2);
+    assertThat(row3.getRow().isPresent(), is(true));
+    assertThat(row3.getRow().get().getProtobufBytes().isPresent(), is(true));
+    assertThat(row3.getRow().get().toString(), is("{\"protobufBytes\":\"CMgDEQAAAAAAqIhAGgNieWU=\"}"));
+
+    final StreamedRow row4 = rows.get(3);
+    assertThat(row4.getRow().isPresent(), is(false));
+    assertThat(row4.getFinalMessage().isPresent(), is(true));
+    assertThat(row4.getFinalMessage().get(), is("limit hit!"));
+  }
+
+  @Test
+  public void toRows_errorParsingNotAtEndProto() {
+    // When:
+    final Exception e = assertThrows(
+        KsqlRestClientException.class,
+        () -> KsqlTargetUtil.toRows(Buffer.buffer("[{\"header\":{\"queryId\":\"queryId\"," +
+            "\"schema\":\"`A` INTEGER KEY, `B` DOUBLE, `C` ARRAY<STRING>\"," +
+            "\"protoSchema\":" +
+            "\"syntax = \\\"proto3\\\";\\n" +
+            "\\n" +
+            "message ConnectDefault1 {\\n" +
+            "  int32 A = 1;\\n" +
+            "  double B = 2;\\n" +
+            "  repeated string C = 3;\\n" +
+            "}\\n" +
+            "\"}},\n" +
+            "{\"row\":{\"protobufBytes\":\"CHsRAAAAAABAbUAaBWhlbGxv\"}},\n" +
+            "{\"row\":{\"protobufBytes\":\"CMgDEQAA"))
+    );
+
+    // Then:
+    assertThat(e.getMessage(), is(("Failed to deserialise object")));
+  }
 }
diff --git a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/entity/StreamedRowTest.java b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/entity/StreamedRowTest.java
index 74d27b2e9e1..0659553e745 100644
--- a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/entity/StreamedRowTest.java
+++ b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/entity/StreamedRowTest.java
@@ -62,6 +62,31 @@ public void shouldRoundTripPullHeader() throws Exception {
     testRoundTrip(row, expectedJson);
   }
 
+  @Test
+  public void shouldRoundTripPullProtoHeader() throws Exception {
+    final String protoSchema =
+        "syntax = \"proto3\";\n" +
+        "\n" +
+        "message ConnectDefault1 {\n" +
+        "  int64 ID = 1;\n" +
+        "  string VAL = 2;\n" +
+        "}\n";
+    final StreamedRow row = StreamedRow.headerProtobuf(QUERY_ID, PULL_SCHEMA, protoSchema);
+
+    final String expectedJson = "{\"header\":" +
+        "{\"queryId\":\"theQueryId\"," +
+        "\"schema\":\"`ID` BIGINT KEY, `VAL` STRING\"," +
+        "\"protoSchema\":\"syntax = \\\"proto3\\\";\\n" +
+        "\\n" +
+        "message ConnectDefault1 {\\n" +
+        "  int64 ID = 1;\\n" +
+        "  string VAL = 2;\\n" +
+        "}\\n" +
+        "\"}}";
+
+    testRoundTrip(row, expectedJson);
+  }
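+
+  // Note: serialization derives the "protoSchema" field name from the getter,
+  // so the @JsonProperty name on the creator must match it or the round trip
+  // above would drop the schema.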
+
   @Test
   public void shouldRoundTripPushHeader() throws Exception {
     final StreamedRow row = StreamedRow.header(
diff --git a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlMediaType.java b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlMediaType.java
index 8754d75eca9..0e9dd50a62c 100644
--- a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlMediaType.java
+++ b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlMediaType.java
@@ -26,9 +26,11 @@
  */
 public enum KsqlMediaType {
 
-  KSQL_V1_JSON("application/vnd.ksql.v1+json");
+  KSQL_V1_JSON("application/vnd.ksql.v1+json"),
+  KSQL_V1_PROTOBUF("application/vnd.ksql.v1+protobuf");
 
   public static final KsqlMediaType LATEST_FORMAT = KSQL_V1_JSON;
+  public static final KsqlMediaType LATEST_FORMAT_PROTOBUF = KSQL_V1_PROTOBUF;
 
   private final int version;
   private final String mediaType;
diff --git a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java
index 8cd4bd20f87..4c04d2659ef 100644
--- a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java
+++ b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java
@@ -72,7 +72,7 @@ public static StreamedRow header(
       final LogicalSchema schema
   ) {
     return new StreamedRow(
-        Optional.of(Header.of(queryId, schema)),
+        Optional.of(Header.of(queryId, schema, Optional.empty())),
         Optional.empty(),
         Optional.empty(),
         Optional.empty(),
@@ -82,6 +82,22 @@ public static StreamedRow header(
     );
   }
 
+  public static StreamedRow headerProtobuf(
+      final QueryId queryId,
+      final LogicalSchema columnsSchema,
+      final String protoSchema
+  ) {
+    return new StreamedRow(
+        Optional.of(Header.of(queryId, columnsSchema, Optional.of(protoSchema))),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty()
+    );
+  }
+
   /**
    * Row returned from a push query.
    */
@@ -130,6 +146,20 @@ public static StreamedRow pullRow(
     );
   }
 
+  public static StreamedRow pullRowProtobuf(
+      final byte[] rowBytes
+  ) {
+    return new StreamedRow(
+        Optional.empty(),
+        Optional.of(DataRow.rowProtobuf(rowBytes)),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty()
+    );
+  }
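+
+  // Like the factories above, these populate exactly one branch of the union
+  // that checkUnion() enforces below.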
+
   public static StreamedRow tombstone(final GenericRow columns) {
     return new StreamedRow(
         Optional.empty(),
@@ -212,7 +242,8 @@ private StreamedRow(
     this.continuationToken = requireNonNull(continuationToken, "continuationToken");
     this.consistencyToken = requireNonNull(consistencyToken, "consistencyToken");
-    checkUnion(header, row, errorMessage, finalMessage, continuationToken, consistencyToken);
+    checkUnion(header, row, errorMessage, finalMessage,
+        continuationToken, consistencyToken);
   }
 
   public Optional<Header> getHeader() {
@@ -248,6 +279,7 @@ public boolean isTerminal() {
     return finalMessage.isPresent() || errorMessage.isPresent();
   }
 
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
   @Override
   public boolean equals(final Object o) {
     if (this == o) {
@@ -268,8 +300,8 @@ public boolean equals(final Object o) {
 
   @Override
   public int hashCode() {
-    return Objects.hash(header, row, errorMessage, finalMessage, sourceHost, continuationToken,
-        consistencyToken);
+    return Objects.hash(header, row, errorMessage, finalMessage,
+        sourceHost, continuationToken, consistencyToken);
   }
 
   @Override
@@ -310,12 +342,14 @@ public static final class Header extends BaseRow {
 
     private final QueryId queryId;
     private final LogicalSchema columnsSchema;
+    private final Optional<String> protoSchema;
 
     public static Header of(
         final QueryId queryId,
-        final LogicalSchema columnsSchema
+        final LogicalSchema columnsSchema,
+        final Optional<String> protoSchema
     ) {
-      return new Header(queryId, columnsSchema);
+      return new Header(queryId, columnsSchema, protoSchema);
     }
 
     public QueryId getQueryId() {
@@ -329,21 +363,28 @@ public LogicalSchema getSchema() {
       return columnsSchema;
     }
 
+    public Optional<String> getProtoSchema() {
+      return protoSchema;
+    }
+
     @JsonCreator
     @SuppressWarnings("unused") // Invoked by reflection by Jackson.
    private static Header jsonCreator(
         @JsonProperty(value = "queryId", required = true) final QueryId queryId,
-        @JsonProperty(value = "schema", required = true) final LogicalSchema columnsSchema
+        @JsonProperty(value = "schema") final LogicalSchema columnsSchema,
+        @JsonProperty(value = "protoSchema") final Optional<String> protoSchema
     ) {
-      return new Header(queryId, columnsSchema);
+      return new Header(queryId, columnsSchema, protoSchema);
     }
 
     private Header(
         final QueryId queryId,
-        final LogicalSchema columnsSchema
+        final LogicalSchema columnsSchema,
+        final Optional<String> protoSchema
     ) {
       this.queryId = requireNonNull(queryId, "queryId");
       this.columnsSchema = requireNonNull(columnsSchema, "columnsSchema");
+      this.protoSchema = protoSchema;
     }
 
     @Override
@@ -356,12 +397,13 @@ public boolean equals(final Object o) {
       }
       final Header header = (Header) o;
       return Objects.equals(queryId, header.queryId)
-          && Objects.equals(columnsSchema, header.columnsSchema);
+          && Objects.equals(columnsSchema, header.columnsSchema)
+          && Objects.equals(protoSchema, header.protoSchema);
     }
 
     @Override
     public int hashCode() {
-      return Objects.hash(queryId, columnsSchema);
+      return Objects.hash(queryId, columnsSchema, protoSchema);
     }
   }
 
@@ -370,24 +412,36 @@ public int hashCode() {
   public static final class DataRow extends BaseRow {
 
     @EffectivelyImmutable
-    private final List<?> columns;
+    private final Optional<List<?>> columns;
+    @EffectivelyImmutable
+    private final Optional<byte[]> protobufBytes;
     private final boolean tombstone;
 
     public static DataRow row(
         final List<?> columns
     ) {
-      return new DataRow(columns, Optional.empty());
+      return new DataRow(Optional.of(columns), Optional.empty(), Optional.empty());
     }
 
+    public static DataRow rowProtobuf(
+        final byte[] bytes
+    ) {
+      return new DataRow(Optional.empty(), Optional.of(bytes), Optional.empty());
+    }
+
     public static DataRow tombstone(
         final List<?> columns
     ) {
-      return new DataRow(columns, Optional.of(true));
+      return new DataRow(Optional.of(columns), Optional.empty(), Optional.of(true));
    }
 
     @SuppressFBWarnings(value = "EI_EXPOSE_REP", justification = "columns is unmodifiableList()")
     public List<?> getColumns() {
-      return columns;
+      return columns.orElse(Collections.emptyList());
+    }
+
+    public Optional<byte[]> getProtobufBytes() {
+      return protobufBytes;
+    }
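+
+    // A data row carries either "columns" (JSON rows) or "protobufBytes"
+    // (base64-encoded on the wire); tombstones always use the columns form.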
 
     public Optional<Boolean> getTombstone() {
@@ -396,14 +450,15 @@ public Optional<Boolean> getTombstone() {
 
     @JsonCreator
     private DataRow(
-        @JsonProperty(value = "columns") final List<?> columns,
+        @JsonProperty(value = "columns") final Optional<List<?>> columns,
+        @JsonProperty(value = "protobufBytes") final Optional<byte[]> protobufBytes,
         @JsonProperty(value = "tombstone") final Optional<Boolean> tombstone
     ) {
       this.tombstone = tombstone.orElse(false);
       // cannot use ImmutableList, as we need to handle `null`
-      this.columns = Collections.unmodifiableList(
-          new ArrayList<>(requireNonNull(columns, "columns"))
-      );
+      this.columns = columns.map(objects -> Collections.unmodifiableList(
+          new ArrayList<>(requireNonNull(objects, "columns"))));
+      this.protobufBytes = protobufBytes;
     }
 
     @Override
@@ -416,12 +471,13 @@ public boolean equals(final Object o) {
       }
       final DataRow row = (DataRow) o;
       return tombstone == row.tombstone
-          && Objects.equals(columns, row.columns);
+          && Objects.equals(columns, row.columns)
+          && Objects.equals(protobufBytes, row.protobufBytes);
     }
 
     @Override
     public int hashCode() {
-      return Objects.hash(tombstone, columns);
+      return Objects.hash(tombstone, columns, protobufBytes);
     }
   }
 }
diff --git a/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlMediaTypeTest.java b/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlMediaTypeTest.java
index d102f1aff08..4bd998b617b 100644
--- a/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlMediaTypeTest.java
+++ b/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlMediaTypeTest.java
@@ -44,7 +44,7 @@ public void shouldHaveUpToDateLatest() {
   public static class PerValue {
 
     @Parameterized.Parameters(name = "{0}")
-    public static KsqlMediaType[] getMetiaTypes() {
+    public static KsqlMediaType[] getMediaTypes() {
       return KsqlMediaType.values();
     }
 
@@ -58,7 +58,8 @@ public void shouldParse() {
 
     @Test
     public void shouldGetValueOf() {
-      assertThat(KsqlMediaType.valueOf("json", mediaType.getVersion()), is(mediaType));
+      final String format = mediaType.mediaType().split("\\+")[1];
+      assertThat(KsqlMediaType.valueOf(format, mediaType.getVersion()), is(mediaType));
     }
   }
 }
\ No newline at end of file
diff --git a/ksqldb-serde/src/main/java/io/confluent/ksql/serde/protobuf/ProtobufNoSRSerdeFactory.java b/ksqldb-serde/src/main/java/io/confluent/ksql/serde/protobuf/ProtobufNoSRSerdeFactory.java
index 2728b258c0d..5b97770a640 100644
--- a/ksqldb-serde/src/main/java/io/confluent/ksql/serde/protobuf/ProtobufNoSRSerdeFactory.java
+++ b/ksqldb-serde/src/main/java/io/confluent/ksql/serde/protobuf/ProtobufNoSRSerdeFactory.java
@@ -40,7 +40,7 @@
 import org.apache.kafka.connect.data.Schema.Type;
 
 @SuppressWarnings("checkstyle:ClassDataAbstractionCoupling")
-final class ProtobufNoSRSerdeFactory implements SerdeFactory {
+public final class ProtobufNoSRSerdeFactory implements SerdeFactory {
 
   private final ProtobufNoSRProperties properties;
 
@@ -48,7 +48,7 @@ final class ProtobufNoSRSerdeFactory implements SerdeFactory {
     this.properties = Objects.requireNonNull(properties, "properties");
   }
 
-  ProtobufNoSRSerdeFactory(final ImmutableMap<String, String> formatProperties) {
+  public ProtobufNoSRSerdeFactory(final ImmutableMap<String, String> formatProperties) {
     this(new ProtobufNoSRProperties(formatProperties));
   }
 
@@ -86,7 +86,7 @@ private static void validate(final Schema schema) {
     SchemaWalker.visit(schema, new SchemaValidator());
   }
 
-  private <T> KsqlConnectSerializer<T> createSerializer(
+  public <T> KsqlConnectSerializer<T> createSerializer(
      final Schema schema,
      final Class<T> targetType,
      final boolean isKey