Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 5, 2015
1 parent b5d3617 commit 02262c9
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -442,47 +442,35 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")

case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")

case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")

case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")

case (_: DecimalType, ByteType) =>
defineCodeGen(ctx, ev, c => s"($c).toByte()")

case (_: DecimalType, ShortType) =>
defineCodeGen(ctx, ev, c => s"($c).toShort()")

case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")

case (_: DecimalType, LongType) =>
defineCodeGen(ctx, ev, c => s"($c).toLong()")

case (_: DecimalType, FloatType) =>
defineCodeGen(ctx, ev, c => s"($c).toFloat()")

case (_: DecimalType, DoubleType) =>
defineCodeGen(ctx, ev, c => s"($c).toDouble()")

case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)

case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)

case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")

case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")

case other =>
super.genCode(ctx, ev)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -62,8 +62,7 @@ abstract class Expression extends TreeNode[Expression] {
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val nullTerm = ctx.freshName("nullTerm")
val primitiveTerm = ctx.freshName("primitiveTerm")
val objectTerm = ctx.freshName("objectTerm")
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm, objectTerm)
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm)
ve.code = genCode(ctx, ve)
ve
}
Expand All @@ -77,17 +76,18 @@ abstract class Expression extends TreeNode[Expression] {
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val e = this.asInstanceOf[Expression]
ctx.references += e
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm};
}
/* expression: ${this} */
final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
final boolean ${ev.nullTerm} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm};
}
"""
}

Expand Down Expand Up @@ -167,7 +167,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
f: (Term, Term) => Code): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand Down Expand Up @@ -214,10 +214,11 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous nullTerm
ev.nullTerm = eval.nullTerm
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* to null.
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
* valid if `nullTerm` is set to `true`.
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
*/
case class GeneratedExpressionCode(var code: Code,
nullTerm: Term,
primitiveTerm: Term,
objectTerm: Term)
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term)

/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
Expand Down Expand Up @@ -73,40 +69,44 @@ class CodeGenContext {
s"$prefix${curId.getAndIncrement}"
}

/**
* Return the code to access a column for given DataType
*/
def getColumn(dataType: DataType, ordinal: Int): Code = {
dataType match {
case StringType => s"($stringType)i.apply($ordinal)"
case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)"
case _ => s"(${boxedType(dataType)})i.apply($ordinal)"
if (isNativeType(dataType)) {
s"i.${accessorForType(dataType)}($ordinal)"
} else {
s"(${boxedType(dataType)})i.apply($ordinal)"
}
}

def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = {
dataType match {
case StringType => s"$destinationRow.update($ordinal, $value)"
case dt: DataType if isNativeType(dt) =>
s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => s"$destinationRow.update($ordinal, $value)"
/**
* Return the code to update a column in Row for given DataType
*/
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
if (isNativeType(dataType)) {
s"${mutatorForType(dataType)}($ordinal, $value)"
} else {
s"update($ordinal, $value)"
}
}

/**
* Return the name of accessor in Row for a DataType
*/
def accessorForType(dt: DataType): Term = dt match {
case IntegerType => "getInt"
case other => s"get${boxedType(dt)}"
}

/**
* Return the name of mutator in Row for a DataType
*/
def mutatorForType(dt: DataType): Term = dt match {
case IntegerType => "setInt"
case other => s"set${boxedType(dt)}"
}

def hashSetForType(dt: DataType): Term = dt match {
case IntegerType => classOf[IntegerHashSet].getName
case LongType => classOf[LongHashSet].getName
case unsupportedType =>
sys.error(s"Code generation not support for hashset of type $unsupportedType")
}

/**
* Return the primitive type for a DataType
*/
Expand All @@ -123,9 +123,26 @@ class CodeGenContext {
case StringType => stringType
case DateType => "int"
case TimestampType => "java.sql.Timestamp"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
}

/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): Term = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case DateType => "Integer"
case _ => primitiveType(dt)
}

/**
* Return the representation of default value for given DataType
*/
Expand All @@ -138,30 +155,9 @@ class CodeGenContext {
case DoubleType => "-1.0"
case IntegerType => "-1"
case DateType => "-1"
case dt: DecimalType => "null"
case StringType => "null"
case _ => "null"
}

/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): Term = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case DateType => "Integer"
case TimestampType => "java.sql.Timestamp"
case _ => "Object"
}

/**
* Returns a function to generate equal expression in Java
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
if(${evaluationCode.nullTerm})
mutableRow.setNullAt($i);
else
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)};
"""
}.mkString("\n")
val code = s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,11 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())};
${ctx.decimalType} ${ev.primitiveTerm} = null;

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
${ev.primitiveTerm} =
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,34 +85,21 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
if (value == null) {
s"""
final boolean ${ev.nullTerm} = true;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
"""
} else {
// TODO(cg): Add support for more data types.
dataType match {
case StringType =>
val v = value.asInstanceOf[UTF8String]
val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}"
s"""
final boolean ${ev.nullTerm} = false;
${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr});
"""
case FloatType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
float ${ev.primitiveTerm} = ${value}f;
"""
case dt: DecimalType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dt)} ${ev.primitiveTerm} =
new ${ctx.primitiveType(dt)}().set($value);
final float ${ev.primitiveTerm} = ${value}f;
"""
case dt: NumericType =>
case dt: NumericType if !dt.isInstanceOf[DecimalType]=>
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
"""
// eval() version may be faster for non-primitive types
case other =>
super.genCode(ctx, ev)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
case IntegerType | LongType =>
s"""
boolean ${ev.nullTerm} = false;
${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} =
new ${ctx.hashSetForType(elementType)}();
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}();
"""
case _ => super.genCode(ctx, ev)
}
Expand Down Expand Up @@ -110,14 +109,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
case IntegerType | LongType =>
val itemEval = item.gen(ctx)
val setEval = set.gen(ctx)
val htype = ctx.hashSetForType(elementType)
val htype = ctx.primitiveType(dataType)

itemEval.code + setEval.code + s"""
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
boolean ${ev.nullTerm} = false;
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
boolean ${ev.nullTerm} = false;
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
"""
case _ => super.genCode(ctx, ev)
}
Expand Down Expand Up @@ -163,7 +162,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
case IntegerType | LongType =>
val leftEval = left.gen(ctx)
val rightEval = right.gen(ctx)
val htype = ctx.hashSetForType(elementType)
val htype = ctx.primitiveType(dataType)

leftEval.code + rightEval.code + s"""
boolean ${ev.nullTerm} = false;
Expand Down

0 comments on commit 02262c9

Please sign in to comment.