Skip to content

Commit

Permalink
pushing codegen into Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 4, 2015
1 parent 2bcdf8c commit 593d617
Show file tree
Hide file tree
Showing 16 changed files with 650 additions and 587 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException

override def exprId: ExprId = throw new UnsupportedOperationException

/**
 * Emits Java source that loads the column at `ordinal` into the terms named by `ev`.
 *
 * The generated code first records `i.isNullAt(ordinal)` in `ev.nullTerm`, then assigns
 * `ev.primitiveTerm` either the default primitive for `dataType` (when null) or the
 * column accessor expression produced by `ctx.getColumn`.
 * NOTE(review): `i` is presumably the input-row variable of the generated class — confirm.
 */
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
  // Resolve the type-dependent fragments up front so the template below stays readable.
  val javaType = ctx.primitiveForType(dataType)
  val fallbackValue = ctx.defaultPrimitive(dataType)
  val columnAccess = ctx.getColumn(dataType, ordinal)
  s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final $javaType ${ev.primitiveTerm} = ${ev.nullTerm} ?
$fallbackValue : ($columnAccess);
"""
}
}

object BindReferences extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -433,6 +434,42 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}

/**
 * Generates Java source for the casts that have a dedicated code path; every other
 * combination is delegated to `super.genSource`.
 *
 * The cases are tried top to bottom, so the more specific patterns (for example
 * decimal-to-integer) must stay ahead of the more general ones they overlap with.
 */
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
  // Fully qualified class name instantiated by the generated code for string results.
  val utf8String = "org.apache.spark.sql.types.UTF8String"

  // Guard shared by the numeric targets: decimal targets are not handled here.
  def nonDecimal(target: NumericType): Boolean = !target.isInstanceOf[DecimalType]

  this match {
    case Cast(BinaryType(), StringType) =>
      castOrNull(ctx, ev, bytes => s"new $utf8String().set($bytes)", StringType)

    case Cast(DateType(), StringType) =>
      castOrNull(
        ctx,
        ev,
        date =>
          s"""new $utf8String().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($date))""",
        StringType)

    case Cast(BooleanType(), dt: NumericType) if nonDecimal(dt) =>
      // true -> 1, false -> 0, then narrowed/widened to the target primitive.
      castOrNull(ctx, ev, flag => s"(${ctx.primitiveForType(dt)})($flag?1:0)", dt)

    case Cast(DecimalType(), IntegerType) =>
      castOrNull(ctx, ev, dec => s"($dec).toInt()", IntegerType)

    case Cast(DecimalType(), dt: NumericType) if nonDecimal(dt) =>
      castOrNull(ctx, ev, dec => s"($dec).to${ctx.termForType(dt)}()", dt)

    case Cast(NumericType(), dt: NumericType) if nonDecimal(dt) =>
      castOrNull(ctx, ev, num => s"(${ctx.primitiveForType(dt)})($num)", dt)

    // Special handling required for timestamps in hive test cases since the toString function
    // does not match the expected output.
    case Cast(source, StringType) if source.dataType != TimestampType =>
      castOrNull(ctx, ev, value => s"new $utf8String().set(String.valueOf($value))", StringType)

    case _ =>
      super.genSource(ctx, ev)
  }
}
}

object Cast {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -51,6 +52,51 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): Any

/**
* Returns an [[EvaluatedExpression]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
* @param ctx a [[CodeGenContext]]
*/
def gen(ctx: CodeGenContext): EvaluatedExpression = {
  // Ask the context for fresh names for the null flag, the primitive value and the
  // boxed object value that the generated code will declare.
  val isNullName = ctx.freshName("nullTerm")
  val primitiveName = ctx.freshName("primitiveTerm")
  val objectName = ctx.freshName("objectTerm")

  // The result is created with empty code so that genSource can refer to the term names;
  // the generated source is filled in afterwards.
  val evaluated = EvaluatedExpression("", isNullName, primitiveName, objectName)
  evaluated.code = genSource(ctx, evaluated)

  // Only inject debugging code if debugging is turned on.
  // NOTE(review): the snippet below is disabled and refers to `ev`, which is not a name
  // in this method; it needs reworking before it can be re-enabled.
  // val debugCode =
  // if (debugLogging) {
  // val localLogger = log
  // val localLoggerTree = reify { localLogger }
  // s"""
  // $localLoggerTree.debug(
  // ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString))
  // """
  // } else {
  // ""
  // }

  evaluated
}

/**
* Returns Java source code for this expression
*/
def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
  // Fallback code path: register this expression with the context so the generated
  // code can invoke its interpreted `eval` through the `expressions` array.
  ctx.references += this
  val referenceIndex = ctx.references.size - 1

  // The expression's toString is embedded in a Java block comment. Java comments do not
  // nest and unicode escapes are translated before comments are recognised, so text
  // containing "*/" or a backslash-u sequence could end the comment early or be rejected
  // by the compiler. Doubling every backslash leaves no backslash eligible to start a
  // unicode escape; "*/" is then broken up so it cannot close the comment.
  val commentSafe = this.toString.replace("\\", "\\\\").replace("*/", "*\\/")

  s"""
/* expression: $commentSafe */
Object ${ev.objectTerm} = expressions[$referenceIndex].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(dataType)};
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
(${ctx.termForType(dataType)})${ev.objectTerm};
"""
}

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
Expand Down Expand Up @@ -116,6 +162,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"


/**
* Shorthand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function from two primitive term names to a Java source fragment that evaluates them.
*/
def evaluate(
    ctx: CodeGenContext,
    ev: EvaluatedExpression,
    f: (String, String) => String): String = {
  // The result is typed after the left operand; delegate the actual generation.
  val resultType = left.dataType
  evaluateAs(resultType)(ctx, ev, f)
}

/**
 * Generates null-propagating Java source for a binary operation whose result has
 * type `resultType`.
 *
 * The code of both children is emitted first. The result is flagged null when either
 * child is null; otherwise `f` is applied to the two children's primitive term names
 * and its output, cast to the primitive type of `resultType`, is stored in
 * `ev.primitiveTerm`.
 *
 * @param resultType the data type of the generated result
 * @param ctx the code generation context
 * @param ev holder of the term names the generated code must declare
 * @param f builds the Java expression combining the two primitive term names
 */
def evaluateAs(resultType: DataType)(
    ctx: CodeGenContext,
    ev: EvaluatedExpression,
    f: (String, String) => String): String = {
  // TODO: Right now some timestamp tests fail if we enforce this...
  if (left.dataType != right.dataType) {
    // Intentionally a no-op for now; the warning below is disabled.
    // log.warn(s"${left.dataType} != ${right.dataType}")
  }

  val leftEval = left.gen(ctx)
  val rightEval = right.gen(ctx)
  val combined = f(leftEval.primitiveTerm, rightEval.primitiveTerm)

  val javaType = ctx.primitiveForType(resultType)
  val fallbackValue = ctx.defaultPrimitive(resultType)
  val nullSafeAssignment =
    s"""
boolean ${ev.nullTerm} = ${leftEval.nullTerm} || ${rightEval.nullTerm};
$javaType ${ev.primitiveTerm} = $fallbackValue;
if(!${ev.nullTerm}) {
${ev.primitiveTerm} = ($javaType)($combined);
}
"""

  leftEval.code + rightEval.code + nullSafeAssignment
}
}

abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
Expand All @@ -124,6 +205,19 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
/**
 * Generates null-propagating Java source for a value derived from `child`.
 *
 * The child's code is emitted first. The result's null flag mirrors the child's, and
 * only when the child is non-null is `ev.primitiveTerm` assigned the expression built
 * by `f`; otherwise it keeps the default primitive for `dataType`.
 *
 * @param ctx the code generation context
 * @param ev holder of the term names the generated code must declare
 * @param f builds the Java expression from the child's primitive term name
 * @param dataType the type of the generated result (this parameter shadows the
 *                 expression's own `dataType` inside this method)
 */
def castOrNull(
    ctx: CodeGenContext,
    ev: EvaluatedExpression,
    f: String => String,
    dataType: DataType): String = {
  val input = child.gen(ctx)
  val nullSafeConversion =
    s"""
boolean ${ev.nullTerm} = ${input.nullTerm};
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(input.primitiveTerm)};
}
"""
  input.code + nullSafeConversion
}
}

// TODO: Semantically, we probably do not need GroupExpression
Expand Down
Loading

0 comments on commit 593d617

Please sign in to comment.