diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java index f264ae4dfa31..dff64f97c9a6 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java @@ -40,7 +40,9 @@ import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.util.SchemaUtil; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -56,6 +58,7 @@ public class Analysis { private Optional joinInfo = Optional.empty(); private Optional whereExpression = Optional.empty(); private final List selectExpressions = new ArrayList<>(); + private final Set selectColumnRefs = new HashSet<>(); private final List groupByExpressions = new ArrayList<>(); private Optional windowExpression = Optional.empty(); private Optional partitionBy = Optional.empty(); @@ -76,6 +79,10 @@ void addSelectItem(final Expression expression, final ColumnName alias) { selectExpressions.add(SelectExpression.of(alias, expression)); } + void addSelectColumnRefs(final Collection columnRefs) { + selectColumnRefs.addAll(columnRefs); + } + public Optional getInto() { return into; } @@ -96,6 +103,10 @@ public List getSelectExpressions() { return Collections.unmodifiableList(selectExpressions); } + Set getSelectColumnRefs() { + return Collections.unmodifiableSet(selectColumnRefs); + } + public List getGroupByExpressions() { return ImmutableList.copyOf(groupByExpressions); } @@ -156,7 +167,7 @@ public List getFromDataSources() { return ImmutableList.copyOf(fromDataSources); } - public SourceSchemas getFromSourceSchemas() { + SourceSchemas getFromSourceSchemas() { final Map schemaBySource = fromDataSources.stream() .collect(Collectors.toMap( AliasedDataSource::getAlias, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java index 84fea949d639..e3fa1dd90833 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java @@ -26,6 +26,7 @@ import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.metastore.MetaStore; @@ -62,6 +63,7 @@ import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -511,7 +513,7 @@ protected AstNode visitSelect(final Select node, final Void context) { visitSelectStar((AllColumns) selectItem); } else if (selectItem instanceof SingleColumn) { final SingleColumn column = (SingleColumn) selectItem; - analysis.addSelectItem(column.getExpression(), column.getAlias()); + addSelectItem(column.getExpression(), column.getAlias()); } else { throw new IllegalArgumentException( "Unsupported SelectItem type: " + selectItem.getClass().getName()); @@ -562,14 +564,19 @@ private void visitSelectStar(final AllColumns allColumns) { ? source.getAlias().name() + "_" : ""; - for (final Column column : source.getDataSource().getSchema().columns()) { + final LogicalSchema schema = source.getDataSource().getSchema(); + for (final Column column : schema.columns()) { + + if (staticQuery && schema.isMetaColumn(column.name())) { + continue; + } final ColumnReferenceExp selectItem = new ColumnReferenceExp(location, ColumnRef.of(source.getAlias(), column.name())); final String alias = aliasPrefix + column.name().name(); - analysis.addSelectItem(selectItem, ColumnName.of(alias)); + addSelectItem(selectItem, ColumnName.of(alias)); } } } @@ -598,6 +605,25 @@ public void validate() { + System.lineSeparator() + KAFKA_VALUE_FORMAT_LIMITATION_DETAILS); } } + + private void addSelectItem(final Expression exp, final ColumnName columnName) { + final Set columnRefs = new HashSet<>(); + final TraversalExpressionVisitor visitor = new TraversalExpressionVisitor() { + @Override + public Void visitColumnReference( + final ColumnReferenceExp node, + final Void context + ) { + columnRefs.add(node.getReference()); + return null; + } + }; + + visitor.process(exp, null); + + analysis.addSelectItem(exp, columnName); + analysis.addSelectColumnRefs(columnRefs); + } } @FunctionalInterface diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java index a32e6e536336..3be969b6a9bf 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/StaticQueryValidator.java @@ -17,7 +17,9 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.parser.tree.ResultMaterialization; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.List; import java.util.Objects; import java.util.function.Predicate; @@ -89,6 +91,12 @@ public class StaticQueryValidator implements QueryValidator { Rule.of( analysis -> !analysis.getLimitClause().isPresent(), "Static queries don't support LIMIT clauses." + ), + Rule.of( + analysis -> analysis.getSelectColumnRefs().stream() + .map(ColumnRef::name) + .noneMatch(n -> n.equals(SchemaUtil.ROWTIME_NAME)), + "Static queries don't support ROWTIME in select columns." ) ); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java index 30180cd35135..984934c4bf64 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java @@ -16,8 +16,11 @@ package io.confluent.ksql.analyzer; import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery; +import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -35,6 +38,7 @@ import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; import io.confluent.ksql.execution.expression.tree.BooleanLiteral; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Literal; import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.plan.SelectExpression; @@ -53,6 +57,7 @@ import io.confluent.ksql.parser.tree.Sink; import io.confluent.ksql.parser.tree.Statement; import io.confluent.ksql.planner.plan.JoinNode.JoinType; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; @@ -90,6 +95,11 @@ public class AnalyzerFunctionalTest { private static final Set DEFAULT_SERDE_OPTIONS = SerdeOption.none(); + private static final SourceName TEST1 = SourceName.of("TEST1"); + private static final ColumnName COL0 = ColumnName.of("COL0"); + private static final ColumnName COL1 = ColumnName.of("COL1"); + private static final ColumnName COL2 = ColumnName.of("COL2"); + private static final ColumnName COL3 = ColumnName.of("COL3"); private MutableMetaStore jsonMetaStore; private MutableMetaStore avroMetaStore; @@ -136,7 +146,7 @@ public void testSimpleQueryAnalysis() { final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore); assertEquals("FROM was not analyzed correctly.", analysis.getFromDataSources().get(0).getDataSource().getName(), - SourceName.of("TEST1")); + TEST1); assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)")); final List selects = analysis.getSelectExpressions(); @@ -144,9 +154,9 @@ public void testSimpleQueryAnalysis() { assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2")); assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3")); - assertThat(selects.get(0).getName(), is(ColumnName.of("COL0"))); - assertThat(selects.get(1).getName(), is(ColumnName.of("COL2"))); - assertThat(selects.get(2).getName(), is(ColumnName.of("COL3"))); + assertThat(selects.get(0).getName(), is(COL0)); + assertThat(selects.get(1).getName(), is(COL2)); + assertThat(selects.get(2).getName(), is(COL3)); } @Test @@ -202,7 +212,7 @@ public void testBooleanExpressionAnalysis() { final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); assertEquals("FROM was not analyzed correctly.", - analysis.getFromDataSources().get(0).getDataSource().getName(), SourceName.of("TEST1")); + analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1); final List selects = analysis.getSelectExpressions(); assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)")); @@ -215,7 +225,7 @@ public void testFilterAnalysis() { final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;"; final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); - assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(SourceName.of("TEST1"))); + assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1)); final List selects = analysis.getSelectExpressions(); assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)")); @@ -450,6 +460,50 @@ public void shouldThrowOnJoinIfKafkaFormat() { analyzer.analyze(query, Optional.of(sink)); } + @Test + public void shouldCaptureProjectionColumnRefs() { + // Given: + query = parseSingle("Select COL0, COL0 + COL1, SUBSTRING(COL2, 1) from TEST1;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectColumnRefs(), containsInAnyOrder( + ColumnRef.of(TEST1, COL0), + ColumnRef.of(TEST1, COL1), + ColumnRef.of(TEST1, COL2) + )); + } + + @Test + public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() { + // Given: + query = parseSingle("Select * from TEST1 EMIT CHANGES;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectExpressions(), hasItem( + SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME))) + )); + } + + @Test + public void shouldNotIncludeMetaColumnsForSelectStartOnStaticQueries() { + // Given: + query = parseSingle("Select * from TEST1;"); + + // When: + final Analysis analysis = analyzer.analyze(query, Optional.empty()); + + // Then: + assertThat(analysis.getSelectExpressions(), not(hasItem( + SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME))) + ))); + } + @SuppressWarnings("unchecked") private T parseSingle(final String simpleQuery) { return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore)); @@ -478,7 +532,7 @@ private void buildProps() { private void registerKafkaSource() { final LogicalSchema schema = LogicalSchema.builder() - .valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT) + .valueColumn(COL0, SqlTypes.BIGINT) .build(); final KsqlTopic topic = new KsqlTopic( diff --git a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java index cba43c81bafc..da64949e4673 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/StaticQueryValidatorTest.java @@ -19,12 +19,15 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.confluent.ksql.analyzer.Analysis.Into; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.parser.tree.ResultMaterialization; import io.confluent.ksql.parser.tree.WindowExpression; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.Optional; import java.util.OptionalInt; import org.junit.Before; @@ -109,7 +112,7 @@ public void shouldThrowOnStaticQueryThatIsWindowed() { } @Test - public void shouldThrowOnStaticQueryThatHasGroupBy() { + public void shouldThrowOnGroupBy() { // Given: when(analysis.getGroupByExpressions()).thenReturn(ImmutableList.of(AN_EXPRESSION)); @@ -122,7 +125,7 @@ public void shouldThrowOnStaticQueryThatHasGroupBy() { } @Test - public void shouldThrowOnStaticQueryThatHasPartitionBy() { + public void shouldThrowOnPartitionBy() { // Given: when(analysis.getPartitionBy()).thenReturn(Optional.of(ColumnName.of("Something"))); @@ -135,7 +138,7 @@ public void shouldThrowOnStaticQueryThatHasPartitionBy() { } @Test - public void shouldThrowOnStaticQueryThatHasHavingClause() { + public void shouldThrowOnHavingClause() { // Given: when(analysis.getHavingExpression()).thenReturn(Optional.of(AN_EXPRESSION)); @@ -148,7 +151,7 @@ public void shouldThrowOnStaticQueryThatHasHavingClause() { } @Test - public void shouldThrowOnStaticQueryThatHasLimitClause() { + public void shouldThrowOnLimitClause() { // Given: when(analysis.getLimitClause()).thenReturn(OptionalInt.of(1)); @@ -159,4 +162,18 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() { // When: validator.validate(analysis); } + + @Test + public void shouldThrowOnRowTimeInProjection() { + // Given: + when(analysis.getSelectColumnRefs()) + .thenReturn(ImmutableSet.of(ColumnRef.of(SchemaUtil.ROWTIME_NAME))); + + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Static queries don't support ROWTIME in select columns."); + + // When: + validator.validate(analysis); + } } \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json index b67bd82ca788..d67d50641934 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json @@ -235,6 +235,32 @@ } ] }, + { + "name": "non-windowed projection WITH ROWTIME", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", + "SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", + "message": "Static queries don't support ROWTIME in select columns.", + "status": 400 + } + }, + { + "name": "windowed with projection with ROWTIME", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", + "SELECT COUNT, ROWTIME + 10 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", + "message": "Static queries don't support ROWTIME in select columns.", + "status": 400 + } + }, { "name": "text datetime window bounds", "enabled": false,