From 2f27b68e97bebefdf5e9aad5b320e1f84dbce3b7 Mon Sep 17 00:00:00 2001 From: Andy Coates <8012398+big-andy-coates@users.noreply.github.com> Date: Fri, 27 Sep 2019 17:42:27 +0100 Subject: [PATCH] feat(static): fail on ROWTIME in projection (#3430) * feat(static): fail on ROWTIME in projection At the moment static queries do not support returning ROWTIME as this information is not available in the response for KS IQ. In the future, we _may_ choose to support this by always including ROWTIME in the value of the changelog topic, but this is out of scope for this initial MVP. --- .../io/confluent/ksql/analyzer/Analysis.java | 13 +++- .../io/confluent/ksql/analyzer/Analyzer.java | 32 ++++++++- .../ksql/analyzer/StaticQueryValidator.java | 8 +++ .../ksql/analyzer/AnalyzerFunctionalTest.java | 68 +++++++++++++++++-- .../analyzer/StaticQueryValidatorTest.java | 25 +++++-- ...materialized-aggregate-static-queries.json | 26 +++++++ 6 files changed, 157 insertions(+), 15 deletions(-) diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java index f264ae4dfa31..dff64f97c9a6 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java @@ -40,7 +40,9 @@ import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.util.SchemaUtil; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -56,6 +58,7 @@ public class Analysis { private Optional joinInfo = Optional.empty(); private Optional whereExpression = Optional.empty(); private final List selectExpressions = new ArrayList<>(); + private final Set selectColumnRefs = new HashSet<>(); private final List groupByExpressions = new ArrayList<>(); private Optional windowExpression = Optional.empty(); private Optional partitionBy = Optional.empty(); @@ -76,6 +79,10 @@ void addSelectItem(final Expression expression, final ColumnName alias) { selectExpressions.add(SelectExpression.of(alias, expression)); } + void addSelectColumnRefs(final Collection columnRefs) { + selectColumnRefs.addAll(columnRefs); + } + public Optional getInto() { return into; } @@ -96,6 +103,10 @@ public List getSelectExpressions() { return Collections.unmodifiableList(selectExpressions); } + Set getSelectColumnRefs() { + return Collections.unmodifiableSet(selectColumnRefs); + } + public List getGroupByExpressions() { return ImmutableList.copyOf(groupByExpressions); } @@ -156,7 +167,7 @@ public List getFromDataSources() { return ImmutableList.copyOf(fromDataSources); } - public SourceSchemas getFromSourceSchemas() { + SourceSchemas getFromSourceSchemas() { final Map schemaBySource = fromDataSources.stream() .collect(Collectors.toMap( AliasedDataSource::getAlias, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java index 84fea949d639..e3fa1dd90833 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java @@ -26,6 +26,7 @@ import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.metastore.MetaStore; @@ -62,6 +63,7 @@ import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -511,7 +513,7 @@ protected AstNode visitSelect(final Select node, final Void context) { visitSelectStar((AllColumns) selectItem); } else if (selectItem instanceof SingleColumn) { final SingleColumn column = (SingleColumn) selectItem; - analysis.addSelectItem(column.getExpression(), column.getAlias()); + addSelectItem(column.getExpression(), column.getAlias()); } else { throw new IllegalArgumentException( "Unsupported SelectItem type: " + selectItem.getClass().getName()); @@ -562,14 +564,19 @@ private void visitSelectStar(final AllColumns allColumns) { ? source.getAlias().name() + "_" : ""; - for (final Column column : source.getDataSource().getSchema().columns()) { + final LogicalSchema schema = source.getDataSource().getSchema(); + for (final Column column : schema.columns()) { + + if (staticQuery && schema.isMetaColumn(column.name())) { + continue; + } final ColumnReferenceExp selectItem = new ColumnReferenceExp(location, ColumnRef.of(source.getAlias(), column.name())); final String alias = aliasPrefix + column.name().name(); - analysis.addSelectItem(selectItem, ColumnName.of(alias)); + addSelectItem(selectItem, ColumnName.of(alias)); } } } @@ -598,6 +605,25 @@ public void validate() { + System.lineSeparator() + KAFKA_VALUE_FORMAT_LIMITATION_DETAILS); } } + + private void addSelectItem(final Expression exp, final ColumnName columnName) { + final Set columnRefs = new HashSet<>(); + final TraversalExpressionVisitor visitor = new TraversalExpressionVisitor() { + @Override + public Void visitColumnReference( + final ColumnReferenceExp node, + final Void context + ) { + columnRefs.add(node.getReference()); + return null; + } + }; + + visitor.process(exp, null); + + analysis.addSelectItem(exp, columnName); + analysis.addSelectColumnRefs(columnRefs); + } } @FunctionalInterface diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java index a32e6e536336..3be969b6a9bf 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java @@ -17,7 +17,9 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.parser.tree.ResultMaterialization; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.List; import java.util.Objects; import java.util.function.Predicate; @@ -89,6 +91,12 @@ public class StaticQueryValidator implements QueryValidator { Rule.of( analysis -> !analysis.getLimitClause().isPresent(), "Static queries don't support LIMIT clauses." + ), + Rule.of( + analysis -> analysis.getSelectColumnRefs().stream() + .map(ColumnRef::name) + .noneMatch(n -> n.equals(SchemaUtil.ROWTIME_NAME)), + "Static queries don't support ROWTIME in select columns." ) ); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java index 30180cd35135..984934c4bf64 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java @@ -16,8 +16,11 @@ package io.confluent.ksql.analyzer; import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery; +import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -35,6 +38,7 @@ import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; import io.confluent.ksql.execution.expression.tree.BooleanLiteral; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Literal; import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.plan.SelectExpression; @@ -53,6 +57,7 @@ import io.confluent.ksql.parser.tree.Sink; import io.confluent.ksql.parser.tree.Statement; import io.confluent.ksql.planner.plan.JoinNode.JoinType; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; @@ -90,6 +95,11 @@ public class AnalyzerFunctionalTest { private static final Set DEFAULT_SERDE_OPTIONS = SerdeOption.none(); + private static final SourceName TEST1 = SourceName.of("TEST1"); + private static final ColumnName COL0 = ColumnName.of("COL0"); + private static final ColumnName COL1 = ColumnName.of("COL1"); + private static final ColumnName COL2 = ColumnName.of("COL2"); + private static final ColumnName COL3 = ColumnName.of("COL3"); private MutableMetaStore jsonMetaStore; private MutableMetaStore avroMetaStore; @@ -136,7 +146,7 @@ public void testSimpleQueryAnalysis() { final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore); assertEquals("FROM was not analyzed correctly.", analysis.getFromDataSources().get(0).getDataSource().getName(), - SourceName.of("TEST1")); + TEST1); assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)")); final List selects = analysis.getSelectExpressions(); @@ -144,9 +154,9 @@ public void testSimpleQueryAnalysis() { assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2")); assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3")); - assertThat(selects.get(0).getName(), is(ColumnName.of("COL0"))); - assertThat(selects.get(1).getName(), is(ColumnName.of("COL2"))); - assertThat(selects.get(2).getName(), is(ColumnName.of("COL3"))); + assertThat(selects.get(0).getName(), is(COL0)); + assertThat(selects.get(1).getName(), is(COL2)); + assertThat(selects.get(2).getName(), is(COL3)); } @Test @@ -202,7 +212,7 @@ public void testBooleanExpressionAnalysis() { final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); assertEquals("FROM was not analyzed correctly.", - analysis.getFromDataSources().get(0).getDataSource().getName(), SourceName.of("TEST1")); + analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1); final List selects = analysis.getSelectExpressions(); assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)")); @@ -215,7 +225,7 @@ public void testFilterAnalysis() { final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;"; final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); - assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(SourceName.of("TEST1"))); + assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1)); final List selects = analysis.getSelectExpressions(); assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)")); @@ -450,6 +460,50 @@ public void shouldThrowOnJoinIfKafkaFormat() { analyzer.analyze(query, Optional.of(sink)); } + @Test + public void shouldCaptureProjectionColumnRefs() { + // Given: + query = parseSingle("Select COL0, COL0 + COL1, SUBSTRING(COL2, 1) from TEST1;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectColumnRefs(), containsInAnyOrder( + ColumnRef.of(TEST1, COL0), + ColumnRef.of(TEST1, COL1), + ColumnRef.of(TEST1, COL2) + )); + } + + @Test + public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() { + // Given: + query = parseSingle("Select * from TEST1 EMIT CHANGES;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectExpressions(), hasItem( + SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME))) + )); + } + + @Test + public void shouldNotIncludeMetaColumnsForSelectStartOnStaticQueries() { + // Given: + query = parseSingle("Select * from TEST1;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectExpressions(), not(hasItem( + SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME))) + ))); + } + @SuppressWarnings("unchecked") private T parseSingle(final String simpleQuery) { return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore)); @@ -478,7 +532,7 @@ private void buildProps() { private void registerKafkaSource() { final LogicalSchema schema = LogicalSchema.builder() - .valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT) + .valueColumn(COL0, SqlTypes.BIGINT) .build(); final KsqlTopic topic = new KsqlTopic( diff --git a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java index cba43c81bafc..da64949e4673 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java @@ -19,12 +19,15 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.confluent.ksql.analyzer.Analysis.Into; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.parser.tree.ResultMaterialization; import io.confluent.ksql.parser.tree.WindowExpression; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.Optional; import java.util.OptionalInt; import org.junit.Before; @@ -109,7 +112,7 @@ public void shouldThrowOnStaticQueryThatIsWindowed() { } @Test - public void shouldThrowOnStaticQueryThatHasGroupBy() { + public void shouldThrowOnGroupBy() { // Given: when(analysis.getGroupByExpressions()).thenReturn(ImmutableList.of(AN_EXPRESSION)); @@ -122,7 +125,7 @@ public void shouldThrowOnStaticQueryThatHasGroupBy() { } @Test - public void shouldThrowOnStaticQueryThatHasPartitionBy() { + public void shouldThrowOnPartitionBy() { // Given: when(analysis.getPartitionBy()).thenReturn(Optional.of(ColumnName.of("Something"))); @@ -135,7 +138,7 @@ public void shouldThrowOnStaticQueryThatHasPartitionBy() { } @Test - public void shouldThrowOnStaticQueryThatHasHavingClause() { + public void shouldThrowOnHavingClause() { // Given: when(analysis.getHavingExpression()).thenReturn(Optional.of(AN_EXPRESSION)); @@ -148,7 +151,7 @@ public void shouldThrowOnStaticQueryThatHasHavingClause() { } @Test - public void shouldThrowOnStaticQueryThatHasLimitClause() { + public void shouldThrowOnLimitClause() { // Given: when(analysis.getLimitClause()).thenReturn(OptionalInt.of(1)); @@ -159,4 +162,18 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() { // When: validator.validate(analysis); } + + @Test + public void shouldThrowOnRowTimeInProjection() { + // Given: + when(analysis.getSelectColumnRefs()) + .thenReturn(ImmutableSet.of(ColumnRef.of(SchemaUtil.ROWTIME_NAME))); + + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Static queries don't support ROWTIME in select columns."); + + // When: + validator.validate(analysis); + } } \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json index b67bd82ca788..d67d50641934 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json @@ -235,6 +235,32 @@ } ] }, + { + "name": "non-windowed projection WITH ROWTIME", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", + "SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", + "message": "Static queries don't support ROWTIME in select columns.", + "status": 400 + } + }, + { + "name": "windowed with projection with ROWTIME", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", + "SELECT COUNT, ROWTIME + 10 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", + "message": "Static queries don't support ROWTIME in select columns.", + "status": 400 + } + }, { "name": "text datetime window bounds", "enabled": false,