From 66fa6bd6d48b08625ecedfcb5a976678141300bd Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 29 Jul 2017 10:11:31 -0700 Subject: [PATCH] [SPARK-19451][SQL] rangeBetween method should accept Long value as boundary ## What changes were proposed in this pull request? Long values can be passed to `rangeBetween` as range frame boundaries, but they are silently converted to Int values, which can produce wrong results; we should fix this. Furthermore, we should accept any legal literal value as a range frame boundary. This PR adds support for Long values and makes it straightforward to accept other DataTypes later. This PR is mostly based on Herman's previous amazing work: https://github.com/hvanhovell/spark/commit/596f53c339b1b4629f5651070e56a8836a397768 After this is merged, we can close #16818. ## How was this patch tested? Added new tests in `DataFrameWindowFunctionsSuite` and `TypeCoercionSuite`. Author: Xingbo Jiang Closes #18540 from jiangxb1987/rangeFrame. (cherry picked from commit 92d85637e7f382aae61c0f26eb1524d2b4c93516) Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 +- .../sql/catalyst/analysis/TypeCoercion.scala | 23 ++ .../sql/catalyst/expressions/package.scala | 7 + .../expressions/windowExpressions.scala | 328 +++++++++--------- .../sql/catalyst/parser/AstBuilder.scala | 20 +- .../spark/sql/catalyst/trees/TreeNode.scala | 2 - .../analysis/AnalysisErrorSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 36 ++ .../parser/ExpressionParserSuite.scala | 17 +- .../sql/catalyst/parser/PlanParserSuite.scala | 2 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 27 +- .../sql/execution/window/WindowExec.scala | 103 +++--- .../spark/sql/expressions/WindowSpec.scala | 33 +- .../resources/sql-tests/inputs/window.sql | 24 +- .../sql-tests/results/window.sql.out | 146 ++++++-- .../sql/DataFrameWindowFunctionsSuite.scala | 42 +++ .../catalyst/ExpressionSQLBuilderSuite.scala | 10 +- 17 files changed, 533 insertions(+), 304 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2e3ac3e474866..2e300c538ebd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -106,11 +106,9 @@ trait CheckAnalysis extends PredicateHelper { case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") - case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, - SpecifiedWindowFrame(frame, - FrameBoundary(l), - FrameBoundary(h)))) - if order.isEmpty || frame != RowFrame || l != h => + case w @ WindowExpression(_: OffsetWindowFunction, + WindowSpecDefinition(_, order, frame: SpecifiedWindowFrame)) + if order.isEmpty || !frame.isOffset => failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") @@ -119,15 +117,10 @@ trait CheckAnalysis extends PredicateHelper { // function. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } - // Make sure the window specification is valid. 
- s.validate match { - case Some(m) => - failAnalysis(s"Window specification $s is not valid because $m") - case None => w - } case s @ ScalarSubquery(query, conditions, _) => checkAnalysis(query) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010d37a95..a2fcad9b2f5c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -58,6 +58,7 @@ object TypeCoercion { PropagateTypes :: ImplicitTypeCasts :: DateTimeOperations :: + WindowFrameCoercion :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -785,4 +786,26 @@ object TypeCoercion { Option(ret) } } + + /** + * Cast WindowFrame boundaries to the type they operate upon. + */ + object WindowFrameCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) + if order.resolved => + s.copy(frameSpecification = SpecifiedWindowFrame( + RangeFrame, + createBoundaryCast(lower, order.dataType), + createBoundaryCast(upper, order.dataType))) + } + + private def createBoundaryCast(boundary: Expression, dt: DataType): Expression = { + boundary match { + case e: SpecialFrameBoundary => e + case e: Expression if e.dataType != dt && Cast.canCast(e.dataType, dt) => Cast(e, dt) + case _ => boundary + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 4c8b177237d23..1a48995358af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -74,6 +74,13 @@ package object expressions { def initialize(partitionIndex: Int): Unit = {} } + /** + * An identity projection. This returns the input row. + */ + object IdentityProjection extends Projection { + override def apply(row: InternalRow): InternalRow = row + } + /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each * column of the new row. 
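As an aside on the `WindowFrameCoercion` rule added above: it only rewrites a value boundary when the cast to the ORDER BY type is provably safe, and it passes a `SpecialFrameBoundary` through unchanged. A minimal REPL sketch of the same decision, assuming `spark-catalyst` is on the classpath; the `boundaryCast` name here is ours for illustration (the rule's private method is `createBoundaryCast`):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.types.{DataType, LongType}

// Mirrors WindowFrameCoercion.createBoundaryCast for plain literal
// boundaries: cast only when the types differ and the cast is legal.
def boundaryCast(boundary: Expression, orderType: DataType): Expression =
  boundary match {
    case e if e.dataType != orderType && Cast.canCast(e.dataType, orderType) =>
      Cast(e, orderType)
    case e => e
  }

// An Int literal boundary against a Long ORDER BY column is widened:
assert(boundaryCast(Literal(3), LongType) == Cast(Literal(3), LongType))
// A boundary that already matches the ORDER BY type is left untouched:
assert(boundaryCast(Literal(3L), LongType) == Literal(3L))
```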
If the schema of the input row is specified, then the given expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 37190429fc423..62c4832827c4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} @@ -43,34 +42,7 @@ case class WindowSpecDefinition( orderSpec: Seq[SortOrder], frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { - def validate: Option[String] = frameSpecification match { - case UnspecifiedFrame => - Some("Found a UnspecifiedFrame. It should be converted to a SpecifiedWindowFrame " + - "during analysis. Please file a bug report.") - case frame: SpecifiedWindowFrame => frame.validate.orElse { - def checkValueBasedBoundaryForRangeFrame(): Option[String] = { - if (orderSpec.length > 1) { - // It is not allowed to have a value-based PRECEDING and FOLLOWING - // as the boundary of a Range Window Frame. - Some("This Range Window Frame only accepts at most one ORDER BY expression.") - } else if (orderSpec.nonEmpty && !orderSpec.head.dataType.isInstanceOf[NumericType]) { - Some("The data type of the expression in the ORDER BY clause should be a numeric type.") - } else { - None - } - } - - (frame.frameType, frame.frameStart, frame.frameEnd) match { - case (RangeFrame, vp: ValuePreceding, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, vf: ValueFollowing, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vp: ValuePreceding) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vf: ValueFollowing) => checkValueBasedBoundaryForRangeFrame() - case (_, _, _) => None - } - } - } - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && @@ -78,23 +50,46 @@ case class WindowSpecDefinition( override def nullable: Boolean = true override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException + override def dataType: DataType = throw new UnsupportedOperationException("dataType") - override def sql: String = { - val partition = if (partitionSpec.isEmpty) { - "" - } else { - "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + " " + override def checkInputDataTypes(): TypeCheckResult = { + frameSpecification match { + case UnspecifiedFrame => + TypeCheckFailure( + "Cannot use an UnspecifiedFrame. This should have been converted during analysis. 
" + + "Please file a bug report.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && !f.isUnbounded && + orderSpec.isEmpty => + TypeCheckFailure( + "A range window frame cannot be used in an unordered window specification.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + orderSpec.size > 1 => + TypeCheckFailure( + s"A range window frame with value boundaries cannot be used in a window specification " + + s"with multiple order by expressions: ${orderSpec.mkString(",")}") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + !isValidFrameType(f.valueBoundary.head.dataType) => + TypeCheckFailure( + s"The data type '${orderSpec.head.dataType}' used in the order specification does " + + s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + + "range frame.") + case _ => TypeCheckSuccess } + } - val order = if (orderSpec.isEmpty) { - "" - } else { - "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + " " + override def sql: String = { + def toSql(exprs: Seq[Expression], prefix: String): Seq[String] = { + Seq(exprs).filter(_.nonEmpty).map(_.map(_.sql).mkString(prefix, ", ", "")) } - s"($partition$order${frameSpecification.toString})" + val elements = + toSql(partitionSpec, "PARTITION BY ") ++ + toSql(orderSpec, "ORDER BY ") ++ + Seq(frameSpecification.sql) + elements.mkString("(", " ", ")") } + + private def isValidFrameType(ft: DataType): Boolean = orderSpec.head.dataType == ft } /** @@ -106,22 +101,26 @@ case class WindowSpecReference(name: String) extends WindowSpec /** * The trait used to represent the type of a Window Frame. */ -sealed trait FrameType +sealed trait FrameType { + def inputType: AbstractDataType + def sql: String +} /** - * RowFrame treats rows in a partition individually. When a [[ValuePreceding]] - * or a [[ValueFollowing]] is used as its [[FrameBoundary]], the value is considered - * as a physical offset. + * RowFrame treats rows in a partition individually. Values used in a row frame are considered + * to be physical offsets. * For example, `ROW BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a 3-row frame, * from the row precedes the current row to the row follows the current row. */ -case object RowFrame extends FrameType +case object RowFrame extends FrameType { + override def inputType: AbstractDataType = IntegerType + override def sql: String = "ROWS" +} /** - * RangeFrame treats rows in a partition as groups of peers. - * All rows having the same `ORDER BY` ordering are considered as peers. - * When a [[ValuePreceding]] or a [[ValueFollowing]] is used as its [[FrameBoundary]], - * the value is considered as a logical offset. + * RangeFrame treats rows in a partition as groups of peers. All rows having the same `ORDER BY` + * ordering are considered as peers. Values used in a range frame are considered to be logical + * offsets. * For example, assuming the value of the current row's `ORDER BY` expression `expr` is `v`, * `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a frame containing rows whose values * `expr` are in the range of [v-1, v+1]. @@ -129,138 +128,144 @@ case object RowFrame extends FrameType * If `ORDER BY` clause is not defined, all rows in the partition is considered as peers * of the current row. */ -case object RangeFrame extends FrameType - -/** - * The trait used to represent the type of a Window Frame Boundary. 
- */ -sealed trait FrameBoundary { - def notFollows(other: FrameBoundary): Boolean +case object RangeFrame extends FrameType { + override def inputType: AbstractDataType = NumericType + override def sql: String = "RANGE" } /** - * Extractor for making working with frame boundaries easier. + * The trait used to represent special boundaries used in a window frame. */ -object FrameBoundary { - def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None - } +sealed trait SpecialFrameBoundary extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = NullType + override def foldable: Boolean = false + override def nullable: Boolean = false } -/** UNBOUNDED PRECEDING boundary. */ -case object UnboundedPreceding extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => true - case vp: ValuePreceding => true - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED PRECEDING" +/** UNBOUNDED boundary. */ +case object UnboundedPreceding extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED PRECEDING" } -/** PRECEDING boundary. */ -case class ValuePreceding(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case ValuePreceding(anotherValue) => value >= anotherValue - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = s"$value PRECEDING" +case object UnboundedFollowing extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED FOLLOWING" } /** CURRENT ROW boundary. */ -case object CurrentRow extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "CURRENT ROW" -} - -/** FOLLOWING boundary. */ -case class ValueFollowing(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case ValueFollowing(anotherValue) => value <= anotherValue - case UnboundedFollowing => true - } - - override def toString: String = s"$value FOLLOWING" -} - -/** UNBOUNDED FOLLOWING boundary. */ -case object UnboundedFollowing extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case vf: ValueFollowing => false - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED FOLLOWING" +case object CurrentRow extends SpecialFrameBoundary { + override def sql: String = "CURRENT ROW" } /** * The trait used to represent the a Window Frame. 
*/ -sealed trait WindowFrame +sealed trait WindowFrame extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = throw new UnsupportedOperationException("dataType") + override def foldable: Boolean = false + override def nullable: Boolean = false +} /** Used as a place holder when a frame specification is not defined. */ case object UnspecifiedFrame extends WindowFrame -/** A specified Window Frame. */ +/** + * A specified Window Frame. The val lower/upper can be either a foldable [[Expression]] or a + * [[SpecialFrameBoundary]]. + */ case class SpecifiedWindowFrame( frameType: FrameType, - frameStart: FrameBoundary, - frameEnd: FrameBoundary) extends WindowFrame { - - /** If this WindowFrame is valid or not. */ - def validate: Option[String] = (frameType, frameStart, frameEnd) match { - case (_, UnboundedFollowing, _) => - Some(s"$UnboundedFollowing is not allowed as the start of a Window Frame.") - case (_, _, UnboundedPreceding) => - Some(s"$UnboundedPreceding is not allowed as the end of a Window Frame.") - // case (RowFrame, start, end) => ??? RowFrame specific rule - // case (RangeFrame, start, end) => ??? RangeFrame specific rule - case (_, start, end) => - if (start.notFollows(end)) { - None - } else { - val reason = - s"The end of this Window Frame $end is smaller than the start of " + - s"this Window Frame $start." - Some(reason) - } + lower: Expression, + upper: Expression) + extends WindowFrame { + + override def children: Seq[Expression] = lower :: upper :: Nil + + lazy val valueBoundary: Seq[Expression] = + children.filterNot(_.isInstanceOf[SpecialFrameBoundary]) + + override def checkInputDataTypes(): TypeCheckResult = { + // Check lower value. + val lowerCheck = checkBoundary(lower, "lower") + if (lowerCheck.isFailure) { + return lowerCheck + } + + // Check upper value. + val upperCheck = checkBoundary(upper, "upper") + if (upperCheck.isFailure) { + return upperCheck + } + + // Check combination (of expressions). 
+ (lower, upper) match { + case (l: Expression, u: Expression) if !isValidFrameBoundary(l, u) => + TypeCheckFailure(s"Window frame upper bound '$upper' does not follow the lower bound " + + s"'$lower'.") + case (l: SpecialFrameBoundary, _) => TypeCheckSuccess + case (_, u: SpecialFrameBoundary) => TypeCheckSuccess + case (l: Expression, u: Expression) if l.dataType != u.dataType => + TypeCheckFailure( + s"Window frame bounds '$lower' and '$upper' do not have the same data type: " + + s"'${l.dataType.catalogString}' <> '${u.dataType.catalogString}'") + case (l: Expression, u: Expression) if isGreaterThan(l, u) => + TypeCheckFailure( + "The lower bound of a window frame must be less than or equal to the upper bound") + case _ => TypeCheckSuccess + } + } + + override def sql: String = { + val lowerSql = boundarySql(lower) + val upperSql = boundarySql(upper) + s"${frameType.sql} BETWEEN $lowerSql AND $upperSql" } - override def toString: String = frameType match { - case RowFrame => s"ROWS BETWEEN $frameStart AND $frameEnd" - case RangeFrame => s"RANGE BETWEEN $frameStart AND $frameEnd" + def isUnbounded: Boolean = lower == UnboundedPreceding && upper == UnboundedFollowing + + def isValueBound: Boolean = valueBoundary.nonEmpty + + def isOffset: Boolean = (lower, upper) match { + case (l: Expression, u: Expression) => frameType == RowFrame && l == u + case _ => false + } + + private def boundarySql(expr: Expression): String = expr match { + case e: SpecialFrameBoundary => e.sql + case UnaryMinus(n) => n.sql + " PRECEDING" + case e: Expression => e.sql + " FOLLOWING" + } + + private def isGreaterThan(l: Expression, r: Expression): Boolean = { + GreaterThan(l, r).eval().asInstanceOf[Boolean] + } + + private def checkBoundary(b: Expression, location: String): TypeCheckResult = b match { + case _: SpecialFrameBoundary => TypeCheckSuccess + case e: Expression if !e.foldable => + TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") + case e: Expression if !frameType.inputType.acceptsType(e.dataType) => + TypeCheckFailure( + s"The data type of the $location bound '${e.dataType}' does not match " + + s"the expected data type '${frameType.inputType}'.") + case _ => TypeCheckSuccess + } + + private def isValidFrameBoundary(l: Expression, u: Expression): Boolean = { + (l, u) match { + case (UnboundedFollowing, _) => false + case (_, UnboundedPreceding) => false + case _ => true + } } } object SpecifiedWindowFrame { /** - * * @param hasOrderSpecification If the window spec has order by expressions. * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return + * @return the default window frame. */ def defaultWindowFrame( hasOrderSpecification: Boolean, @@ -351,20 +356,25 @@ abstract class OffsetWindowFunction override def nullable: Boolean = default == null || default.nullable || input.nullable - override lazy val frame = { - // This will be triggered by the Analyzer. 
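Since `SpecifiedWindowFrame` is now an ordinary expression, its boundaries and SQL rendering can be exercised directly. A small sketch, assuming a REPL with catalyst available; the expected results in the comments follow the rendering logic above:

```scala
import org.apache.spark.sql.catalyst.expressions.{
  CurrentRow, Literal, RangeFrame, SpecifiedWindowFrame, UnaryMinus, UnboundedPreceding}

// A Long literal is now representable as a boundary without truncation.
val frame = SpecifiedWindowFrame(RangeFrame, CurrentRow, Literal(2147483648L))
frame.sql           // "RANGE BETWEEN CURRENT ROW AND 2147483648L FOLLOWING"
frame.isValueBound  // true: one boundary is a value, not a special boundary

// boundarySql renders a negated boundary as "n PRECEDING".
val preceding = SpecifiedWindowFrame(RangeFrame, UnaryMinus(Literal(10)), CurrentRow)
preceding.sql       // "RANGE BETWEEN 10 PRECEDING AND CURRENT ROW"

// Only a frame that is unbounded on both sides counts as isUnbounded.
SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow).isUnbounded  // false
```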
- val offsetValue = offset.eval() match { - case o: Int => o - case x => throw new AnalysisException( - s"Offset expression must be a foldable integer expression: $x") - } + override lazy val frame: WindowFrame = { val boundary = direction match { - case Ascending => ValueFollowing(offsetValue) - case Descending => ValuePreceding(offsetValue) + case Ascending => offset + case Descending => UnaryMinus(offset) } SpecifiedWindowFrame(RowFrame, boundary, boundary) } + override def checkInputDataTypes(): TypeCheckResult = { + val check = super.checkInputDataTypes() + if (check.isFailure) { + check + } else if (!offset.foldable) { + TypeCheckFailure(s"Offset expression '$offset' must be a literal.") + } else { + TypeCheckSuccess + } + } + override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d1c9332bee18b..8ea18fa718e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1142,32 +1142,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value - * Preceding/Following boundaries. These expressions must be constant (foldable) and return an - * integer value. + * Create or resolve frame boundary expressions. */ - override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { - // We currently only allow foldable integers. - def value: Int = { + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { val e = expression(ctx.expression) - validate(e.resolved && e.foldable && e.dataType == IntegerType, - "Frame bound value must be a constant integer.", - ctx) - e.eval().asInstanceOf[Int] + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e } - // Create the FrameBoundary ctx.boundType.getType match { case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => UnboundedPreceding case SqlBaseParser.PRECEDING => - ValuePreceding(value) + UnaryMinus(value) case SqlBaseParser.CURRENT => CurrentRow case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => UnboundedFollowing case SqlBaseParser.FOLLOWING => - ValueFollowing(value) + value } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ae5c513eb040b..6fb5a9976c215 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -688,8 +688,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case id: FunctionIdentifier => true case spec: BucketSpec => true case catalog: CatalogTable => true - case boundary: FrameBoundary => true - case frame: WindowFrame => true case partition: Partitioning => true case resource: FunctionResource => true case broadcast: BroadcastMode => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 
5050318d96358..adc50321c3d39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -191,7 +191,7 @@ class AnalysisErrorSuite extends AnalysisTest { WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + SpecifiedWindowFrame(RangeFrame, Literal(1), Literal(2)))).as('window)), "window frame" :: "must match the required frame" :: Nil) errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d..345ec0fc31301 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -998,6 +998,42 @@ class TypeCoercionSuite extends PlanTest { EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) } + + test("cast WindowFrame boundaries to the type they operate upon") { + // Can cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(3), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Cast(3, LongType), Literal(2147483648L))) + ) + // Cannot cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))) + ) + // Should not cast SpecialFrameBoundary. 
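The `ruleTest` calls above and below rely on a `windowSpec` helper whose definition falls outside this hunk. Its presumed shape, not shown in the diff, would simply assemble a `WindowSpecDefinition` from its three parts:

```scala
import org.apache.spark.sql.catalyst.expressions.{
  Expression, SortOrder, WindowFrame, WindowSpecDefinition}

// Assumed test helper: build a window spec definition from its three
// components so each ruleTest invocation stays compact.
private def windowSpec(
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    frame: WindowFrame): WindowSpecDefinition =
  WindowSpecDefinition(partitionSpec, orderSpec, frame)
```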
+ ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index f06219198bb58..e770cc35974e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -262,16 +262,17 @@ class ExpressionParserSuite extends PlanTest { // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("10 preceding", -Literal(10), CurrentRow), + ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), ("unbounded preceding", UnboundedPreceding, CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), ("between unbounded preceding and unbounded following", UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between current row and 5 following", CurrentRow, Literal(5)), + ("between 10 preceding and 5 following", -Literal(10), Literal(5)) ) frameTypes.foreach { case (frameTypeSql, frameType) => @@ -283,13 +284,9 @@ class ExpressionParserSuite extends PlanTest { } } - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - // We cannot use an arbitrary expression. intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") + "Frame bound value must be a literal.") } test("row constructor") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 950f152b94b4d..1e43654c5c68e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -243,7 +243,7 @@ class PlanParserSuite extends PlanTest { val sql = "select * from t" val plan = table("t").select(star()) val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + SpecifiedWindowFrame(RowFrame, -Literal(1), Literal(1))) // Test window resolution. 
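A note on the `-Literal(10)` spelling used in the parser tests above: with the catalyst expression DSL in scope, unary minus on an expression builds a `UnaryMinus` node rather than evaluating anything, which is exactly what `visitFrameBound` now produces for a `PRECEDING` bound. A quick sketch, assuming a catalyst REPL:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Literal, UnaryMinus}

// Unary minus constructs a tree node; no evaluation happens here.
val bound = -Literal(10)
assert(bound == UnaryMinus(Literal(10)))
```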
val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 06ef7bcee0d84..363c1e21a2957 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -436,21 +436,22 @@ class TreeNodeSuite extends SparkFunSuite { "bucketColumnNames" -> "[bucket]", "sortColumnNames" -> "[sort]")) - // Converts FrameBoundary to JSON - assertJSON( - ValueFollowing(3), - JObject( - "product-class" -> classOf[ValueFollowing].getName, - "value" -> 3)) - // Converts WindowFrame to JSON assertJSON( - SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow), - JObject( - "product-class" -> classOf[SpecifiedWindowFrame].getName, - "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), - "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)), - "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName)))) + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow), + List( + JObject( + "class" -> classOf[SpecifiedWindowFrame].getName, + "num-children" -> 2, + "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), + "lower" -> 0, + "upper" -> 1), + JObject( + "class" -> UnboundedPreceding.getClass.getName, + "num-children" -> 0), + JObject( + "class" -> CurrentRow.getClass.getName, + "num-children" -> 0))) // Converts Partitioning to JSON assertJSON( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 950a6794a74a3..9efbeeef35ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.IntegerType /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -109,46 +108,50 @@ case class WindowExec( * * This method uses Code Generation. It can only be used on the executor side. * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. - val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. 
- val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") + private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(boundOffset, expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output) + // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") } } @@ -157,13 +160,13 @@ case class WindowExec( * WindowExpressions and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type FrameKey = (String, FrameType, Expression, Expression) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] // Add a function and its function to the map for a given frame. 
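To see why the rewritten `createBoundOrdering` above no longer truncates, here is a sketch of the bound expression it builds for a range frame over a Long ORDER BY column. The values are chosen to overflow an Int, and the `orderValue`/`offset` names are ours for illustration (a catalyst REPL is assumed):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Literal, UnaryMinus}
import org.apache.spark.sql.types.LongType

// Bound value = ORDER BY value + boundary, cast to the ORDER BY type.
val orderValue = Literal(2147483650L)   // stand-in for the ORDER BY child
val offset = Literal(3)                 // an Int literal frame boundary
val ascendingBound = Add(orderValue, Cast(offset, LongType))
assert(ascendingBound.eval() == 2147483653L)

// For a descending sort the offset's sign is flipped before the cast.
val descendingBound = Add(orderValue, Cast(UnaryMinus(offset), LongType))
assert(descendingBound.eval() == 2147483647L)
```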
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val key = (tpe, fr.frameType, fr.lower, fr.upper) val (es, fns) = framedFunctions.getOrElseUpdate( key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) es += e @@ -203,7 +206,7 @@ case class WindowExec( // Create the factory val factory = key match { // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + case ("OFFSET", _, IntegerLiteral(offset), _) => target: InternalRow => new OffsetWindowFunctionFrame( target, @@ -215,38 +218,38 @@ case class WindowExec( newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) + // Entire Partition Frame. + case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + // Growing Frame. - case ("AGGREGATE", frameType, None, Some(high)) => + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, high)) + createBoundOrdering(frameType, upper)) } // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low)) + createBoundOrdering(frameType, lower)) } // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => + case ("AGGREGATE", frameType, lower, upper) => target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. - case ("AGGREGATE", frameType, None, None) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) + createBoundOrdering(frameType, lower), + createBoundOrdering(frameType, upper)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 6279d48c94de5..e6cb889b338c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.Column +import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ /** @@ -123,7 +123,24 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rowsBetween. 
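The factory dispatch in `WindowExec` above now keys on the boundary expressions themselves; `IntegerLiteral` is the stock catalyst extractor that pulls the `Int` out of an integer literal, which is how the offset frame case reads its physical offset. A sketch of that match shape, assuming a catalyst REPL (the `describe` helper is illustrative only):

```scala
import org.apache.spark.sql.catalyst.expressions.{
  CurrentRow, Expression, IntegerLiteral, Literal}

// Pattern-match boundary expressions the way the factory above does;
// IntegerLiteral extracts the Int payload of an integer literal boundary.
def describe(lower: Expression, upper: Expression): String = (lower, upper) match {
  case (CurrentRow, IntegerLiteral(n)) => s"frame from current row to $n following"
  case _ => "some other frame shape"
}

describe(CurrentRow, Literal(5))  // "frame from current row to 5 following"
```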
def rowsBetween(start: Long, end: Long): WindowSpec = { - between(RowFrame, start, end) + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary start is not a valid integer: $x") + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary end is not a valid integer: $x") + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd)) } /** @@ -174,28 +191,22 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rangeBetween. def rangeBetween(start: Long, end: Long): WindowSpec = { - between(RangeFrame, start, end) - } - - private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => CurrentRow case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case x => Literal(x) } val boundaryEnd = end match { case 0 => CurrentRow case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end.toInt) - case x if x > 0 => ValueFollowing(end.toInt) + case x => Literal(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c800fc3d49891..342e5719e9a60 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -1,24 +1,44 @@ -- Test data. 
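The user-visible effect of the `WindowSpec` changes above, as a spark-shell sketch (assumes a running `SparkSession` named `spark`; the data mirrors the new suite tests):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.count
import spark.implicits._

val df = Seq((1L, "a"), (2L, "a"), (2147483650L, "a")).toDF("key", "value")

// Long range boundaries now survive intact instead of being truncated to Int.
val byRange = Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L)
df.select($"key", count("key").over(byRange)).show()

// Row frames remain physical offsets, so an out-of-Int-range bound now fails
// fast with "Boundary end is not a valid integer: 2147483648":
// Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)
```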
CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate); +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate); -- RowsBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween with reverse OrderBy SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +-- Invalid window frame +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val; + + -- Window functions SELECT val, cate, max(val) OVER w AS max, diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index aa5856138ed81..97511068b323c 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,11 +1,12 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 19 -- !query 0 CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate) +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate) -- !query 0 schema struct<> -- !query 0 output @@ -47,11 +48,21 @@ NULL a 1 -- !query 3 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 3 schema +struct<> +-- !query 3 output 
+org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType' does not match the expected data type 'IntegerType'.; line 1 pos 41 + + +-- !query 4 SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output NULL NULL 0 3 NULL 1 NULL a 0 1 a 2 1 a 2 2 a 3 1 b 1 2 b 2 3 b 2 --- !query 4 +-- !query 5 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output NULL NULL NULL 3 NULL 3 NULL a NULL 1 a 4 1 a 4 2 a 2 1 b 3 2 b 5 3 b 3 --- !query 5 +-- !query 6 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL +1 NULL 1 +1 a 4 +1 a 4 +2 a 2147483652 +2147483650 a 2147483650 +NULL b NULL +3 b 2147483653 +2147483650 b 2147483650 + + +-- !query 7 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 5 schema +-- !query 7 schema struct --- !query 5 output +-- !query 7 output NULL NULL NULL 3 NULL 3 NULL a NULL 1 a 1 1 a 1 2 a 4 1 b 1 2 b 5 3 b 5 --- !query 6 +-- !query 8 +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not follow the lower bound 'unboundedfollowing$()'.; line 1 pos 33 + + +-- !query 9 +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame cannot be used in an unordered window specification.; line 1 pos 33 + + +-- !query 10 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY testdata.`val` ASC NULLS FIRST, testdata.`cate` ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame with value boundaries cannot be used in a window specification with multiple order by expressions: val#x ASC NULLS FIRST,cate#x ASC NULLS FIRST; line 1 pos 33 + + +-- !query 11 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_date() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'DateType' used in the order 
specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 + + +-- !query 12 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING' due to data type mismatch: The lower bound of a window frame must be less than or equal to the upper bound; line 1 pos 33 + + +-- !query 13 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +Frame bound value must be a literal.(line 2, pos 30) + +== SQL == +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +------------------------------^^^ + + +-- !query 14 SELECT val, cate, max(val) OVER w AS max, min(val) OVER w AS min, @@ -124,9 +218,9 @@ approx_count_distinct(val) OVER w AS approx_count_distinct FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val --- !query 6 schema +-- !query 14 schema struct --- !query 6 output +-- !query 14 output NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 @@ -138,11 +232,11 @@ NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0. 3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 --- !query 7 +-- !query 15 SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val --- !query 7 schema +-- !query 15 schema struct --- !query 7 output +-- !query 15 output NULL NULL NULL 3 NULL NULL NULL a NULL @@ -154,20 +248,20 @@ NULL a NULL 3 b NULL --- !query 8 +-- !query 16 SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val --- !query 8 schema +-- !query 16 schema struct<> --- !query 8 output +-- !query 16 output org.apache.spark.sql.AnalysisException Window function row_number() requires window to be ordered, please add ORDER BY clause. 
For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table; --- !query 9 +-- !query 17 SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val --- !query 9 schema +-- !query 17 schema struct --- !query 9 output +-- !query 17 output NULL NULL 13 1.8571428571428572 3 NULL 13 1.8571428571428572 NULL a 13 1.8571428571428572 @@ -179,7 +273,7 @@ NULL a 13 1.8571428571428572 3 b 13 1.8571428571428572 --- !query 10 +-- !query 18 SELECT val, cate, first_value(false) OVER w AS first_value, first_value(true, true) OVER w AS first_value_ignore_null, @@ -190,9 +284,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val --- !query 10 schema +-- !query 18 schema struct --- !query 10 output +-- !query 18 output NULL NULL false true false false true false 3 NULL false true false false true false NULL a false true false false true false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 204858fa29787..9806e57f08744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -151,6 +151,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(2.0d), Row(2.0d))) } + test("row between should accept integer values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept integer/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + } + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.createOrReplaceTempView("window_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index 149ce1e195111..90f90599d5bf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -98,27 +98,27 @@ class 
ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), - s"(PARTITION BY `a` $frame)" + s"(PARTITION BY `a` ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), - s"(PARTITION BY `a`, `b` $frame)" + s"(PARTITION BY `a`, `b` ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), - s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST ${frame.sql})" ) }
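Finally, the reason these SQLBuilder assertions switched from `$frame` to `${frame.sql}`: a window frame is now an `Expression`, so interpolating it directly yields the tree-node rendering rather than SQL text. A sketch of the difference, assuming a catalyst REPL (the exact `toString` form shown is indicative, not normative):

```scala
import org.apache.spark.sql.catalyst.expressions.{
  CurrentRow, RowFrame, SpecifiedWindowFrame, UnboundedPreceding}

val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
frame.sql       // "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
frame.toString  // tree form, roughly: specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())
```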