Skip to content

Commit

Permalink
Revert "[SPARK-8770][SQL] Create BinaryOperator abstract class."
Browse files Browse the repository at this point in the history
This reverts commit 2727789.
  • Loading branch information
rxin committed Jul 1, 2015
1 parent 2727789 commit 3a342de
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ object HiveTypeCoercion {
* Converts string "NaN"s that are in binary operators with NaN-able types (Float / Double) to
* the appropriate numeric equivalent.
*/
// TODO: remove this rule and make Cast handle NaN.
object ConvertNaNs extends Rule[LogicalPlan] {
private val StringNaN = Literal("NaN")

Expand All @@ -160,19 +159,19 @@ object HiveTypeCoercion {
case e if !e.childrenResolved => e

/* Double Conversions */
case b @ BinaryOperator(StringNaN, right @ DoubleType()) =>
case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
b.makeCopy(Array(Literal(Double.NaN), right))
case b @ BinaryOperator(left @ DoubleType(), StringNaN) =>
case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
b.makeCopy(Array(left, Literal(Double.NaN)))

/* Float Conversions */
case b @ BinaryOperator(StringNaN, right @ FloatType()) =>
case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
b.makeCopy(Array(Literal(Float.NaN), right))
case b @ BinaryOperator(left @ FloatType(), StringNaN) =>
case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
b.makeCopy(Array(left, Literal(Float.NaN)))

/* Use float NaN by default to avoid unnecessary type widening */
case b @ BinaryOperator(left @ StringNaN, StringNaN) =>
case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
b.makeCopy(Array(left, Literal(Float.NaN)))
}
}
Expand Down Expand Up @@ -246,12 +245,12 @@ object HiveTypeCoercion {

Union(newLeft, newRight)

// Also widen types for BinaryOperator.
// Also widen types for BinaryExpressions.
case q: LogicalPlan => q transformExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
Expand Down Expand Up @@ -479,7 +478,7 @@ object HiveTypeCoercion {

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
(left.dataType, right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ abstract class Expression extends TreeNode[Expression] {
*/
def childrenResolved: Boolean = children.forall(_.resolved)

/**
* Returns a string representation of this expression that does not have developer centric
* debugging information like the expression id.
*/
def prettyString: String = {
transform {
case a: AttributeReference => PrettyAttribute(a.name)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}

/**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (i.e. capitalization of names in attributes may be different).
Expand All @@ -143,40 +154,71 @@ abstract class Expression extends TreeNode[Expression] {
* Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

override def foldable: Boolean = left.foldable && right.foldable

override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"

/**
* Returns a user-facing string representation of this expression's name.
* This should usually match the name of the function in SQL.
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
*/
def prettyName: String = getClass.getSimpleName.toLowerCase
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
s"$result = ${f(eval1, eval2)};"
})
}

/**
* Returns a user-facing string representation of this expression, i.e. does not have developer
* centric debugging information like the expression id.
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*/
def prettyString: String = {
transform {
case a: AttributeReference => PrettyAttribute(a.name)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
$resultCode
} else {
${ev.isNull} = true;
}
}
"""
}

override def toString: String = prettyName + children.mkString("(", ",", ")")
}

/**
 * Extractor for [[BinaryExpression]] that exposes its `left` and `right` children as a pair,
 * so rules can pattern match with `case BinaryExpression(left, right) => ...`.
 */
private[sql] object BinaryExpression {
// Always matches: every BinaryExpression yields Some((left, right)).
def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
}

/**
* A leaf expression, i.e. one without any child expressions.
*/
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
self: Product =>
}


/**
* An expression with one input and one output. The output is by default evaluated to null
* if the input is evaluated to null.
*/
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>

Expand Down Expand Up @@ -223,76 +265,39 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
}


/**
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
* A trait that gets mixed in to define the expected input types of an expression.
*/
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

override def foldable: Boolean = left.foldable && right.foldable

override def nullable: Boolean = left.nullable || right.nullable
trait ExpectsInputTypes { self: Expression =>

/**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
* Expected input types from child expressions. The i-th position in the returned seq indicates
* the type requirement for the i-th child.
*
* @param f accepts two variable names and returns Java code to compute the output.
* The possible values at each position are:
* 1. a specific data type, e.g. LongType, StringType.
* 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
* 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
s"$result = ${f(eval1, eval2)};"
})
}
def inputTypes: Seq[Any]

/**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*/
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
$resultCode
} else {
${ev.isNull} = true;
}
}
"""
override def checkInputDataTypes(): TypeCheckResult = {
// We will do the type checking in `HiveTypeCoercion`, so always returning success here.
TypeCheckResult.TypeCheckSuccess
}
}


/**
* An expression that has two inputs that are expected to be the same type. If the two inputs have
* different types, the analyzer will find the tightest common type and do the proper type casting.
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
*/
abstract class BinaryOperator extends BinaryExpression {
self: Product =>
trait AutoCastInputTypes { self: Expression =>

def symbol: String
def inputTypes: Seq[DataType]

override def toString: String = s"($left $symbol $right)"
}


private[sql] object BinaryOperator {
def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))
override def checkInputDataTypes(): TypeCheckResult = {
// We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
// so type mismatch errors won't be reported here, but for the underlying `Cast`s.
TypeCheckResult.TypeCheckSuccess
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi

override def nullable: Boolean = true

override def toString: String = s"UDF(${children.mkString(",")})"
override def toString: String = s"scalaUDF(${children.mkString(",")})"

// scalastyle:off

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[

override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"MAX($child)"

override def asPartial: SplitEvaluation = {
val partialMax = Alias(Max(child), "PartialMax")()
Expand Down Expand Up @@ -161,6 +162,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod

override def nullable: Boolean = false
override def dataType: LongType.type = LongType
override def toString: String = s"COUNT($child)"

override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
Expand Down Expand Up @@ -399,6 +401,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
DoubleType
}

override def toString: String = s"AVG($child)"

override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
Expand Down Expand Up @@ -490,6 +494,8 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
child.dataType
}

override def toString: String = s"SUM($child)"

override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) =>
Expand Down
Loading

0 comments on commit 3a342de

Please sign in to comment.