Skip to content

Commit

Permalink
Merge pull request #14 from marmbrus/castingAndTypes
Browse files Browse the repository at this point in the history
Make casting semantics more like Hive's
  • Loading branch information
marmbrus committed Jan 17, 2014
2 parents b21f803 + b2a1ec5 commit b4adb0f
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 94 deletions.
2 changes: 2 additions & 0 deletions src/main/scala/catalyst/analysis/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
Batch("Aggregation", Once,
GlobalAggregates),
Batch("Type Coersion", fixedPoint,
StringToIntegralCasts,
BooleanCasts,
PromoteNumericTypes,
PromoteStrings,
ConvertNaNs,
Expand Down
30 changes: 27 additions & 3 deletions src/main/scala/catalyst/analysis/typeCoercion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,30 @@ object BooleanComparisons extends Rule[LogicalPlan] {
}
}

/**
* Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since
* the JVM does not consider Booleans to be numeric types.
*/
object BooleanCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
case Cast(e, dataType) if e.dataType == BooleanType =>
Cast(If(e, Literal(1), Literal(0)), dataType)
}
}

/**
* When encountering a cast from a string representing a valid fractional number to an integral type
* the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the
* truncated version of this number.
*/
object StringToIntegralCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Cast(e @ StringType(), t: IntegralType) =>
Cast(Cast(e, DecimalType), t)
}
}

/**
* This ensure that the types for various functions are as expected. Most of these rules are
* actually Hive specific.
Expand All @@ -162,9 +186,9 @@ object FunctionArgumentConversion extends Rule[LogicalPlan] {
case e if !e.childrenResolved => e

// Promote SUM to largest types to prevent overflows.
// TODO: This is enough to make most of the tests pass, but we really need a full set of our own
// to really ensure compatibility.
case Sum(e) if e.dataType == IntegerType => Sum(Cast(e, LongType))
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))

}
}
2 changes: 2 additions & 0 deletions src/main/scala/catalyst/execution/FunctionRegistry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging {
case l: LongWritable => l.get
case d: DoubleWritable => d.get()
case d: org.apache.hadoop.hive.serde2.io.DoubleWritable => d.get
case s: org.apache.hadoop.hive.serde2.io.ShortWritable => s.get
case b: BooleanWritable => b.get()
case b: org.apache.hadoop.hive.serde2.io.ByteWritable => b.get
case list: java.util.List[_] => list.map(unwrap)
case p: java.lang.Short => p
case p: java.lang.Long => p
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/execution/SharkInstance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ abstract class SharkInstance extends Logging {
s"""== Logical Plan ==
|${stringOrError(analyzed)}
|== Physical Plan ==
|${stringOrError(sharkPlan)}
|${stringOrError(executedPlan)}
""".stripMargin.trim
}

Expand Down
112 changes: 40 additions & 72 deletions src/main/scala/catalyst/expressions/Evaluate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,109 +25,73 @@ object Evaluate extends Logging {
null
else
e.dataType match {
case IntegerType =>
f.asInstanceOf[(Numeric[Int], Int) => Int](
implicitly[Numeric[Int]], eval(e).asInstanceOf[Int])
case DoubleType =>
f.asInstanceOf[(Numeric[Double], Double) => Double](
implicitly[Numeric[Double]], eval(e).asInstanceOf[Double])
case LongType =>
f.asInstanceOf[(Numeric[Long], Long) => Long](
implicitly[Numeric[Long]], eval(e).asInstanceOf[Long])
case FloatType =>
f.asInstanceOf[(Numeric[Float], Float) => Float](
implicitly[Numeric[Float]], eval(e).asInstanceOf[Float])
case ByteType =>
f.asInstanceOf[(Numeric[Byte], Byte) => Byte](
implicitly[Numeric[Byte]], eval(e).asInstanceOf[Byte])
case ShortType =>
f.asInstanceOf[(Numeric[Short], Short) => Short](
implicitly[Numeric[Short]], eval(e).asInstanceOf[Short])
case n: NumericType =>
val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType]
castedFunction(n.numeric, eval(e).asInstanceOf[n.JvmType])
case other => sys.error(s"Type $other does not support numeric operations")
}
}

@inline
def n2(e1: Expression, e2: Expression, f: ((Numeric[Any], Any, Any) => Any)): Any = {
if (e1.dataType != e2.dataType)
throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")

val evalE1 = eval(e1)
val evalE2 = eval(e2)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
case IntegerType =>
f.asInstanceOf[(Numeric[Int], Int, Int) => Int](
implicitly[Numeric[Int]], evalE1.asInstanceOf[Int], evalE2.asInstanceOf[Int])
case DoubleType =>
f.asInstanceOf[(Numeric[Double], Double, Double) => Double](
implicitly[Numeric[Double]], evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
case LongType =>
f.asInstanceOf[(Numeric[Long], Long, Long) => Long](
implicitly[Numeric[Long]], evalE1.asInstanceOf[Long], evalE2.asInstanceOf[Long])
case FloatType =>
f.asInstanceOf[(Numeric[Float], Float, Float) => Float](
implicitly[Numeric[Float]], evalE1.asInstanceOf[Float], evalE2.asInstanceOf[Float])
case ByteType =>
f.asInstanceOf[(Numeric[Byte], Byte, Byte) => Byte](
implicitly[Numeric[Byte]], evalE1.asInstanceOf[Byte], evalE2.asInstanceOf[Byte])
case ShortType =>
f.asInstanceOf[(Numeric[Short], Short, Short) => Short](
implicitly[Numeric[Short]], evalE1.asInstanceOf[Short], evalE2.asInstanceOf[Short])
case n: NumericType =>
f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => Int](
n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
case other => sys.error(s"Type $other does not support numeric operations")
}
}

@inline
def f2(e1: Expression, e2: Expression, f: ((Fractional[Any], Any, Any) => Any)): Any = {
if (e1.dataType != e2.dataType)
throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")

val evalE1 = eval(e1)
val evalE2 = eval(e2)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
case DoubleType =>
f.asInstanceOf[(Fractional[Double], Double, Double) => Double](
implicitly[Fractional[Double]], evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
case FloatType =>
f.asInstanceOf[(Fractional[Float], Float, Float) => Float](
implicitly[Fractional[Float]], evalE1.asInstanceOf[Float], evalE2.asInstanceOf[Float])
case f: FractionalType =>
f.asInstanceOf[(Fractional[f.JvmType], f.JvmType, f.JvmType) => f.JvmType](
f.fractional, evalE1.asInstanceOf[f.JvmType], evalE2.asInstanceOf[f.JvmType])
case other => sys.error(s"Type $other does not support fractional operations")
}
}

@inline
def i2(e1: Expression, e2: Expression, f: ((Integral[Any], Any, Any) => Any)): Any = {
if (e1.dataType != e2.dataType) throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
if (e1.dataType != e2.dataType)
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
val evalE1 = eval(e1)
val evalE2 = eval(e2)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
case IntegerType =>
f.asInstanceOf[(Integral[Int], Int, Int) => Int](
implicitly[Integral[Int]], evalE1.asInstanceOf[Int], evalE2.asInstanceOf[Int])
case LongType =>
f.asInstanceOf[(Integral[Long], Long, Long) => Long](
implicitly[Integral[Long]], evalE1.asInstanceOf[Long], evalE2.asInstanceOf[Long])
case ByteType =>
f.asInstanceOf[(Integral[Byte], Byte, Byte) => Byte](
implicitly[Integral[Byte]], evalE1.asInstanceOf[Byte], evalE2.asInstanceOf[Byte])
case ShortType =>
f.asInstanceOf[(Integral[Short], Short, Short) => Short](
implicitly[Integral[Short]], evalE1.asInstanceOf[Short], evalE2.asInstanceOf[Short])
case i: IntegralType =>
f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
case other => sys.error(s"Type $other does not support numeric operations")
}
}

@inline def castOrNull[A](f: => A) =
try f catch { case _: java.lang.NumberFormatException => null }
@inline def castOrNull[A](e: Expression, f: String => A) =
try {
eval(e) match {
case null => null
case s: String => f(s)
}
} catch { case _: java.lang.NumberFormatException => null }

val result = e match {
case Literal(v, _) => v
Expand All @@ -142,13 +106,16 @@ object Evaluate extends Logging {
case Add(l, r) => n2(l,r, _.plus(_, _))
case Subtract(l, r) => n2(l,r, _.minus(_, _))
case Multiply(l, r) => n2(l,r, _.times(_, _))
// Divide & remainder implementation are different for fractional and integral dataTypes.
case Divide(l, r) if (l.dataType == DoubleType || l.dataType == FloatType) => f2(l,r, _.div(_, _))
case Divide(l, r) => i2(l,r, _.quot(_, _))
// Divide implementation are different for fractional and integral dataTypes.
case Divide(l @ FractionalType(), r) => f2(l,r, _.div(_, _))
case Divide(l @ IntegralType(), r) => i2(l,r, _.quot(_, _))
// Remainder is only allowed on Integral types.
case Remainder(l, r) => i2(l,r, _.rem(_, _))
case UnaryMinus(child) => n1(child, _.negate(_))

/* Control Flow */
case If(e, t, f) => if (eval(e).asInstanceOf[Boolean]) eval(t) else eval(f)

/* Comparisons */
case Equals(l, r) =>
val left = eval(l)
Expand Down Expand Up @@ -197,16 +164,14 @@ object Evaluate extends Logging {
}

// String => Numeric Types
case Cast(e, IntegerType) if e.dataType == StringType =>
eval(e) match {
case null => null
case s: String => castOrNull(s.toInt)
}
case Cast(e, DoubleType) if e.dataType == StringType =>
eval(e) match {
case null => null
case s: String => castOrNull(s.toDouble)
}
case Cast(e @ StringType(), IntegerType) => castOrNull(e, _.toInt)
case Cast(e @ StringType(), DoubleType) => castOrNull(e, _.toDouble)
case Cast(e @ StringType(), FloatType) => castOrNull(e, _.toFloat)
case Cast(e @ StringType(), LongType) => castOrNull(e, _.toLong)
case Cast(e @ StringType(), ShortType) => castOrNull(e, _.toShort)
case Cast(e @ StringType(), ByteType) => castOrNull(e, _.toByte)
case Cast(e @ StringType(), DecimalType) => castOrNull(e, BigDecimal(_))

// Boolean conversions
case Cast(e, ByteType) if e.dataType == BooleanType =>
eval(e) match {
Expand Down Expand Up @@ -263,6 +228,9 @@ object Evaluate extends Logging {
case implementedFunction: ImplementedUdf =>
implementedFunction.evaluate(implementedFunction.children.map(eval))

case a: Attribute =>
throw new OptimizationException(a,
"Unable to evaluate unbound reference without access to the input schema.")
case other => throw new OptimizationException(other, "evaluation not implemented")
}

Expand Down
17 changes: 17 additions & 0 deletions src/main/scala/catalyst/expressions/predicates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package catalyst
package expressions

import types._
import catalyst.analysis.UnresolvedException

trait Predicate extends Expression {
self: Product =>
Expand Down Expand Up @@ -74,3 +75,19 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
override def foldable = child.foldable
def nullable = false
}

case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
extends Expression {

def children = predicate :: trueValue :: falseValue :: Nil
def nullable = trueValue.nullable || falseValue.nullable
def references = children.flatMap(_.references).toSet
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
def dataType = {
if (!resolved) {
throw new UnresolvedException(
this, s"Invalid types: ${trueValue.dataType}, ${falseValue.dataType}")
}
trueValue.dataType
}
}
13 changes: 11 additions & 2 deletions src/main/scala/catalyst/optimizer/Optimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ object Optimize extends RuleExecutor[LogicalPlan] {
EliminateSubqueries) ::
Batch("ConstantFolding", Once,
ConstantFolding,
BooleanSimplification
) :: Nil
BooleanSimplification,
SimplifyCasts) :: Nil
}

/**
Expand Down Expand Up @@ -68,4 +68,13 @@ object BooleanSimplification extends Rule[LogicalPlan] {
}
}
}
}

/**
* Removes casts that are unnecessary because the input is already the correct type.
*/
object SimplifyCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Cast(e, dataType) if e.dataType == dataType => e
}
}
Loading

0 comments on commit b4adb0f

Please sign in to comment.