Skip to content

Commit

Permalink
[SPARK-8117] [SQL] Push codegen implementation into each Expression
Browse files Browse the repository at this point in the history
This PR move codegen implementation of expressions into Expression class itself, make it easy to manage.

It introduces two APIs in Expression:
```
def gen(ctx: CodeGenContext): GeneratedExpressionCode
def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code
```

gen(ctx) will call genSource(ctx, ev) to generate Java source code for the current expression. A expression needs to override genSource().

Here are the types:
```
type Term String
type Code String

/**
 * Java source for evaluating an [[Expression]] given a [[Row]] of input.
 */
case class GeneratedExpressionCode(var code: Code,
                               nullTerm: Term,
                               primitiveTerm: Term,
                               objectTerm: Term)
/**
 * A context for codegen, which is used to bookkeeping the expressions those are not supported
 * by codegen, then they are evaluated directly. The unsupported expression is appended at the
 * end of `references`, the position of it is kept in the code, used to access and evaluate it.
 */
class CodeGenContext {
  /**
   * Holding all the expressions those do not support codegen, will be evaluated directly.
   */
  val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]()
}
```

This is basically apache#6660, but fixed style violation and compilation failure.

Author: Davies Liu <davies@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes apache#6690 from rxin/codegen and squashes the following commits:

e1368c2 [Reynold Xin] Fixed tests.
73db80e [Reynold Xin] Fixed compilation failure.
19d6435 [Reynold Xin] Fixed style violation.
9adaeaf [Davies Liu] address comments
f42c732 [Davies Liu] improve coverage and tests
bad6828 [Davies Liu] address comments
e03edaa [Davies Liu] consts fold
86fac2c [Davies Liu] fix style
02262c9 [Davies Liu] address comments
b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen
48c454f [Reynold Xin] Some code gen update.
2344bc0 [Davies Liu] fix test
12ff88a [Davies Liu] fix build
c5fb514 [Davies Liu] rename
8c6d82d [Davies Liu] update docs
b145047 [Davies Liu] fix style
e57959d [Davies Liu] add type alias
3ff25f8 [Davies Liu] refactor
593d617 [Davies Liu] pushing codegen into Expression
  • Loading branch information
Davies Liu authored and rxin committed Jun 7, 2015
1 parent b127ff8 commit 5e7b6b6
Show file tree
Hide file tree
Showing 23 changed files with 1,036 additions and 718 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException

override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
}

object BindReferences extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)

case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")

case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")

case other =>
super.genCode(ctx, ev)
}
}
}

object Cast {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): Any

/**
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
*
* @param ctx a [[CodeGenContext]]
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
}

/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodeGenContext]]
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
}
"""
}

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
Expand Down Expand Up @@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"

/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (Term, Term) => Code): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
}

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitive, eval2.primitive)

s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
}
}
"""
}
}

private[sql] object BinaryExpression {
Expand All @@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>

/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
"""
}
}

// TODO Semantically we probably not need GroupExpression
Expand Down
Loading

0 comments on commit 5e7b6b6

Please sign in to comment.