Skip to content

Commit

Permalink
[SPARK-7266] Add ExpectsInputTypes to expressions when possible.
Browse files Browse the repository at this point in the history
This should gives us better analysis time error messages (rather than runtime) and automatic type casting.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#5796 from rxin/expected-input-types and squashes the following commits:

c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when possible.
  • Loading branch information
rxin authored and jeanlyn committed Jun 12, 2015
1 parent f0dac74 commit 3b64cda
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,37 +239,43 @@ trait HiveTypeCoercion {
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))

// we should cast all timestamp/date/string compare into string compare
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))

case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType != StringType &&
p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Expand Down Expand Up @@ -420,19 +426,19 @@ trait HiveTypeCoercion {
)

case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
Expand Down Expand Up @@ -481,8 +487,8 @@ trait HiveTypeCoercion {
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
case p: BinaryComparison if p.left.dataType == BooleanType &&
p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
}
}
Expand Down Expand Up @@ -564,10 +570,6 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))

// Compatible with Hive
case Substring(e, start, len) if e.dataType != StringType =>
Substring(Cast(e, StringType), start, len)

// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

override def foldable: Boolean = left.foldable && right.foldable

override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {

type EvaluatedType = Any

def nullable: Boolean = left.nullable || right.nullable

override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType &&
!DecimalType.isFixed(left.dataType)

def dataType: DataType = {
override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,14 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}

abstract class BinaryPredicate extends BinaryExpression with Predicate {
self: Product =>
override def nullable: Boolean = left.nullable || right.nullable
}

case class Not(child: Expression) extends UnaryExpression with Predicate {
case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"NOT $child"

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)

override def eval(input: Row): Any = {
child.eval(input) match {
case null => null
Expand Down Expand Up @@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
}
}

case class And(left: Expression, right: Expression) extends BinaryPredicate {
case class And(left: Expression, right: Expression)
extends BinaryExpression with Predicate with ExpectsInputTypes {

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

override def symbol: String = "&&"

override def eval(input: Row): Any = {
Expand All @@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
}
}

case class Or(left: Expression, right: Expression) extends BinaryPredicate {
case class Or(left: Expression, right: Expression)
extends BinaryExpression with Predicate with ExpectsInputTypes {

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

override def symbol: String = "||"

override def eval(input: Row): Any = {
Expand All @@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
}
}

abstract class BinaryComparison extends BinaryPredicate {
abstract class BinaryComparison extends BinaryExpression with Predicate {
self: Product =>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types._

trait StringRegexExpression {
trait StringRegexExpression extends ExpectsInputTypes {
self: BinaryExpression =>

type EvaluatedType = Any
Expand All @@ -32,6 +32,7 @@ trait StringRegexExpression {

override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)

// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
Expand All @@ -57,11 +58,11 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
val regex = pattern(r.asInstanceOf[UTF8String].toString)
val regex = pattern(r.asInstanceOf[UTF8String].toString())
if(regex == null) {
null
} else {
matches(regex, l.asInstanceOf[UTF8String].toString)
matches(regex, l.asInstanceOf[UTF8String].toString())
}
}
}
Expand Down Expand Up @@ -110,16 +111,17 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}

trait CaseConversionExpression {
trait CaseConversionExpression extends ExpectsInputTypes {
self: UnaryExpression =>

type EvaluatedType = Any

def convert(v: UTF8String): UTF8String

override def foldable: Boolean = child.foldable
def nullable: Boolean = child.nullable
def dataType: DataType = StringType
override def nullable: Boolean = child.nullable
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)

override def eval(input: Row): Any = {
val evaluated = child.eval(input)
Expand All @@ -136,7 +138,7 @@ trait CaseConversionExpression {
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {

override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def convert(v: UTF8String): UTF8String = v.toUpperCase()

override def toString: String = s"Upper($child)"
}
Expand All @@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {

override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def convert(v: UTF8String): UTF8String = v.toLowerCase()

override def toString: String = s"Lower($child)"
}

/** A base trait for functions that compare two strings, returning a boolean. */
trait StringComparison {
self: BinaryPredicate =>
self: BinaryExpression =>

def compare(l: UTF8String, r: UTF8String): Boolean

override type EvaluatedType = Any

override def nullable: Boolean = left.nullable || right.nullable

def compare(l: UTF8String, r: UTF8String): Boolean

override def eval(input: Row): Any = {
val leftEval = left.eval(input)
if(leftEval == null) {
Expand All @@ -181,31 +183,35 @@ trait StringComparison {
* A function that returns true if the string `left` contains the string `right`.
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}

/**
* A function that returns true if the string `left` starts with the string `right`.
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}

/**
* A function that returns true if the string `left` ends with the string `right`.
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}

/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
case class Substring(str: Expression, pos: Expression, len: Expression)
extends Expression with ExpectsInputTypes {

type EvaluatedType = Any

Expand All @@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
if (str.dataType == BinaryType) str.dataType else StringType
}

override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)

override def children: Seq[Expression] = str :: pos :: len :: Nil

@inline
Expand Down Expand Up @@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
val (st, end) = slicePos(start, length, () => s.length)
val (st, end) = slicePos(start, length, () => s.length())
s.slice(st, end)
}
}
Expand Down

0 comments on commit 3b64cda

Please sign in to comment.