From 3a342dedc04799948bf6da69843bd1a91202ffe5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 16:59:39 -0700 Subject: [PATCH] Revert "[SPARK-8770][SQL] Create BinaryOperator abstract class." This reverts commit 272778999823ed79af92280350c5869a87a21f29. --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 ------- .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 6 + .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 + .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 + .../expressions/stringOperations.scala | 26 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 12 files changed, 135 insertions(+), 170 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8420c54f7c335..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,7 +150,6 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ - // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -160,19 +159,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryOperator(StringNaN, right @ DoubleType()) => + case b @ BinaryExpression(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryOperator(left @ DoubleType(), StringNaN) => + case b @ BinaryExpression(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryOperator(StringNaN, right @ FloatType()) => + case b @ BinaryExpression(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryOperator(left @ FloatType(), StringNaN) => + case b @ BinaryExpression(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryOperator(left @ StringNaN, StringNaN) => + case b @ BinaryExpression(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -246,12 +245,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryOperator. + // Also widen types for BinaryExpressions. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -479,7 +478,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala deleted file mode 100644 index 450fc4165f93b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types.DataType - - -/** - * An trait that gets mixin to define the expected input types of an expression. - */ -trait ExpectsInputTypes { self: Expression => - - /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. - * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). - */ - def inputTypes: Seq[Any] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess - } -} - -/** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. - */ -trait AutoCastInputTypes { self: Expression => - - def inputTypes: Seq[DataType] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cafbbafdca207..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,6 +119,17 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns a string representation of this expression that does not have developer centric + * debugging information like the expression id. + */ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString + } + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -143,40 +154,71 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess +} + +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable + + override def toString: String = s"($left $symbol $right)" /** - * Returns a user-facing string representation of this expression's name. - * This should usually match the name of the function in SQL. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. */ - def prettyName: String = getClass.getSimpleName.toLowerCase + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } - - override def toString: String = prettyName + children.mkString("(", ",", ")") } +private[sql] object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} -/** - * A leaf expression, i.e. one without any child expressions. - */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } - -/** - * An expression with one input and one output. The output is by default evaluated to null - * if the input is evaluated to null. - */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -223,76 +265,39 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } - /** - * An expression with two inputs and one output. The output is by default evaluated to null - * if any input is evaluated to null. + * An trait that gets mixin to define the expected input types of an expression. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable +trait ExpectsInputTypes { self: Expression => /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. * - * @param f accepts two variable names and returns Java code to compute the output. + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def inputTypes: Seq[Any] - /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess } } - /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. */ -abstract class BinaryOperator extends BinaryExpression { - self: Product => +trait AutoCastInputTypes { self: Expression => - def symbol: String + def inputTypes: Seq[DataType] - override def toString: String = s"($left $symbol $right)" -} - - -private[sql] object BinaryOperator { - def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index caf021b016a41..ebabb6f117851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index da520f56b430e..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,6 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -161,6 +162,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -399,6 +401,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } + override def toString: String = s"AVG($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -490,6 +494,8 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } + override def toString: String = s"SUM($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4fbf4c87009c2..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def prettyName: String = "positive" + override def toString: String = s"positive($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -77,9 +79,10 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryExpression { self: Product => + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -357,9 +360,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - - override def symbol: String = "max" - override def prettyName: String = symbol + override def toString: String = s"MaxOf($left, $right)" } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -412,6 +413,5 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def symbol: String = "min" - override def prettyName: String = symbol + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8b14..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def prettyName: String = "array" + override def toString: String = s"Array(${children.mkString(",")})" } /** @@ -71,6 +71,4 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } - - override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 145d323a9f0bb..78be2824347d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,6 +38,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + override def toString: String = s"Coalesce(${children.mkString(",")})" + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 34df89a163895..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryOperator with Predicate { +abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 5d51a4ca65332..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,6 +137,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType + override def symbol: String = "++=" + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b020f2bbc5818..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,6 +75,8 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "LIKE" + // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -99,16 +101,14 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,7 +134,9 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase + override def convert(v: UTF8String): UTF8String = v.toUpperCase() + + override def toString: String = s"Upper($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -146,7 +148,9 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase + override def convert(v: UTF8String): UTF8String = v.toLowerCase() + + override def toString: String = s"Lower($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -174,6 +178,8 @@ trait StringComparison extends AutoCastInputTypes { } } + override def symbol: String = nodeName + override def toString: String = s"$nodeName($left, $right)" } @@ -278,6 +284,12 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } + + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. + case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" + case _ => s"SUBSTR($str, $pos, $len)" + } } /** @@ -292,9 +304,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } + override def toString: String = s"length($child)" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } - - override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 86792f0217572..bda217935cb05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryOperator => actual.append(b.symbol); + case b: BinaryExpression => actual.append(b.symbol); case l: Literal => actual.append(l.toString); }