diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 73c9a1c7afdad..831fb4fe95fe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -239,37 +239,43 @@ trait HiveTypeCoercion { a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) // we should cast all timestamp/date/string compare into string compare - case p: BinaryPredicate if p.left.dataType == StringType - && p.right.dataType == DateType => + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType == DateType => p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == DateType - && p.right.dataType == StringType => + case p: BinaryComparison if p.left.dataType == DateType && + p.right.dataType == StringType => p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryPredicate if p.left.dataType == StringType - && p.right.dataType == TimestampType => + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType == TimestampType => p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == TimestampType - && p.right.dataType == StringType => + case p: BinaryComparison if p.left.dataType == TimestampType && + p.right.dataType == StringType => p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryPredicate if p.left.dataType == TimestampType - && p.right.dataType == DateType => + case p: BinaryComparison if p.left.dataType == TimestampType && + p.right.dataType == DateType => p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == DateType - && p.right.dataType == TimestampType => + case p: BinaryComparison if p.left.dataType == DateType && + p.right.dataType == TimestampType => p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType != StringType => p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) - case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => + case p: BinaryComparison if p.left.dataType != StringType && + p.right.dataType == StringType => p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) => + case i @ In(a, b) if a.dataType == DateType && + b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + case i @ In(a, b) if a.dataType == TimestampType && + b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => + case i @ In(a, b) if a.dataType == DateType && + b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => + case i @ In(a, b) if a.dataType == TimestampType && + b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e) if e.dataType == StringType => @@ -420,19 +426,19 @@ trait HiveTypeCoercion { ) case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) // Promote integers inside a binary expression with fixed-precision decimals to decimals, @@ -481,8 +487,8 @@ trait HiveTypeCoercion { // No need to change the EqualNullSafe operators, too case e: EqualNullSafe => e // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison - if p.left.dataType == BooleanType && p.right.dataType == BooleanType => + case p: BinaryComparison if p.left.dataType == BooleanType && + p.right.dataType == BooleanType => p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) } } @@ -564,10 +570,6 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) - // Compatible with Hive - case Substring(e, start, len) if e.dataType != StringType => - Substring(Cast(e, StringType), start, len) - // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1d71c1b4b0c7c..4fd1bc4dd642d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def foldable: Boolean = left.foldable && right.foldable + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"($left $symbol $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 140ccd8d3796f..c7a37ad966df6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable - override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType: DataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9cb00cb2732ff..26c38c56c04f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -70,16 +70,14 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -abstract class BinaryPredicate extends BinaryExpression with Predicate { - self: Product => - override def nullable: Boolean = left.nullable || right.nullable -} -case class Not(child: Expression) extends UnaryExpression with Predicate { +case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" + override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def eval(input: Row): Any = { child.eval(input) match { case null => null @@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any]) } } -case class And(left: Expression, right: Expression) extends BinaryPredicate { +case class And(left: Expression, right: Expression) + extends BinaryExpression with Predicate with ExpectsInputTypes { + + override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "&&" override def eval(input: Row): Any = { @@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } } -case class Or(left: Expression, right: Expression) extends BinaryPredicate { +case class Or(left: Expression, right: Expression) + extends BinaryExpression with Predicate with ExpectsInputTypes { + + override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "||" override def eval(input: Row): Any = { @@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate { } } -abstract class BinaryComparison extends BinaryPredicate { +abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index d597bf7ce756a..d6f23df30ffb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,7 +22,7 @@ import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types._ -trait StringRegexExpression { +trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => type EvaluatedType = Any @@ -32,6 +32,7 @@ trait StringRegexExpression { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType + override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -57,11 +58,11 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[UTF8String].toString) + val regex = pattern(r.asInstanceOf[UTF8String].toString()) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[UTF8String].toString) + matches(regex, l.asInstanceOf[UTF8String].toString()) } } } @@ -110,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } -trait CaseConversionExpression { +trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => type EvaluatedType = Any @@ -118,8 +119,9 @@ trait CaseConversionExpression { def convert(v: UTF8String): UTF8String override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType + override def nullable: Boolean = child.nullable + override def dataType: DataType = StringType + override def expectedChildTypes: Seq[DataType] = Seq(StringType) override def eval(input: Row): Any = { val evaluated = child.eval(input) @@ -136,7 +138,7 @@ trait CaseConversionExpression { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase + override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" } @@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase + override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryPredicate => + self: BinaryExpression => + + def compare(l: UTF8String, r: UTF8String): Boolean override type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable - def compare(l: UTF8String, r: UTF8String): Boolean - override def eval(input: Row): Any = { val leftEval = left.eval(input) if(leftEval == null) { @@ -181,31 +183,35 @@ trait StringComparison { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryPredicate with StringComparison { + extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) } /** * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryPredicate with StringComparison { + extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) } /** * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryPredicate with StringComparison { + extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) } /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. */ -case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression { +case class Substring(str: Expression, pos: Expression, len: Expression) + extends Expression with ExpectsInputTypes { type EvaluatedType = Any @@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends if (str.dataType == BinaryType) str.dataType else StringType } + override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline @@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val (st, end) = slicePos(start, length, () => ba.length) ba.slice(st, end) case s: UTF8String => - val (st, end) = slicePos(start, length, () => s.length) + val (st, end) = slicePos(start, length, () => s.length()) s.slice(st, end) } }