Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 4, 2015
1 parent 593d617 commit 3ff25f8
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)});
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,33 +439,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

case Cast(child @ BinaryType(), StringType) =>
castOrNull (ctx, ev, c =>
s"new org.apache.spark.sql.types.UTF8String().set($c)",
StringType)
s"new org.apache.spark.sql.types.UTF8String().set($c)")

case Cast(child @ DateType(), StringType) =>
castOrNull(ctx, ev, c =>
s"""new org.apache.spark.sql.types.UTF8String().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")

case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt)
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")

case Cast(child @ DecimalType(), IntegerType) =>
castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType)
castOrNull(ctx, ev, c => s"($c).toInt()")

case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt)
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")

case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt)
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
castOrNull(ctx, ev, c =>
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))",
StringType)
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))")

case other =>
super.genSource(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ abstract class Expression extends TreeNode[Expression] {
/* expression: ${this} */
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(e.dataType)};
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
(${ctx.termForType(e.dataType)})${ev.objectTerm};
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
"""
}

Expand Down Expand Up @@ -173,12 +173,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
*/
def evaluate(ctx: CodeGenContext,
ev: EvaluatedExpression,
f: (String, String) => String): String =
evaluateAs(left.dataType)(ctx, ev, f)

def evaluateAs(resultType: DataType)(ctx: CodeGenContext,
ev: EvaluatedExpression,
f: (String, String) => String): String = {
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand All @@ -188,14 +183,19 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)

eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm};
${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)};
if(!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode);
}
"""
s"""
${eval1.code}
boolean ${ev.nullTerm} = ${eval1.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${eval2.code}
if(!${eval2.nullTerm}) {
${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
} else {
${ev.nullTerm} = true;
}
}
"""
}
}

Expand All @@ -207,16 +207,15 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
self: Product =>
def castOrNull(ctx: CodeGenContext,
ev: EvaluatedExpression,
f: String => String, dataType: DataType): String = {
f: String => String): String = {
val eval = child.gen(ctx)
eval.code +
s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
}
"""
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
}
"""
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -279,8 +279,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -412,8 +412,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
Expand Down Expand Up @@ -468,8 +468,8 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {

eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultPrimitive(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
dataType match {
case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)"
case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)"
case _ => s"(${termForType(dataType)})i.apply($ordinal)"
case _ => s"(${boxedType(dataType)})i.apply($ordinal)"
}
}

Expand All @@ -86,12 +86,12 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {

def accessorForType(dt: DataType): String = dt match {
case IntegerType => "getInt"
case other => s"get${termForType(dt)}"
case other => s"get${boxedType(dt)}"
}

def mutatorForType(dt: DataType): String = dt match {
case IntegerType => "setInt"
case other => s"set${termForType(dt)}"
case other => s"set${boxedType(dt)}"
}

def hashSetForType(dt: DataType): String = dt match {
Expand All @@ -101,7 +101,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
sys.error(s"Code generation not support for hashset of type $unsupportedType")
}

def primitiveForType(dt: DataType): String = dt match {
/**
* Return the primitive type for a DataType
*/
def primitiveType(dt: DataType): String = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
Expand All @@ -117,7 +120,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
case _ => "Object"
}

def defaultPrimitive(dt: DataType): String = dt match {
/**
* Return the representation of default value for given DataType
*/
def defaultValue(dt: DataType): String = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "-1"
Expand All @@ -131,7 +137,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
case _ => "null"
}

def termForType(dt: DataType): String = dt match {
/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): String = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
Expand All @@ -147,6 +156,15 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
case _ => "Object"
}

/**
* Returns a function to generate equal expression in Java
*/
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" }
case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" }
case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" }
}

/**
* List of data types that have special accessors and setters in [[Row]].
*/
Expand All @@ -166,7 +184,6 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {

protected val rowType = classOf[Row].getName
protected val exprType = classOf[Expression].getName
protected val mutableRowType = classOf[MutableRow].getName
protected val genericMutableRowType = classOf[GenericMutableRow].getName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val ctx = newCodeGenContext()
val columns = expressions.zipWithIndex.map {
case (e, i) =>
s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n"
s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
}.mkString("\n ")

val initColumns = expressions.zipWithIndex.map {
Expand All @@ -68,7 +68,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")

val updateCases = expressions.zipWithIndex.map { case (e, i) =>
s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}"
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
}.mkString("\n ")

val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
Expand All @@ -80,14 +80,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
if (isNullAt(i)) {
return ${ctx.defaultPrimitive(dataType)};
return ${ctx.defaultValue(dataType)};
}
switch (i) {
$cases
}
return ${ctx.defaultPrimitive(dataType)};
return ${ctx.defaultValue(dataType)};
}"""
} else {
""
Expand All @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) {
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} =
${ctx.defaultPrimitive(DecimalType())};
${ctx.defaultValue(DecimalType())};

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,33 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
if (value == null) {
s"""
final boolean ${ev.nullTerm} = true;
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
"""
} else {
dataType match {
case StringType =>
val v = value.asInstanceOf[UTF8String]
val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}"
s"""
final boolean ${ev.nullTerm} = false;
org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} =
new org.apache.spark.sql.types.UTF8String().set(${arr});
"""
final boolean ${ev.nullTerm} = false;
org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} =
new org.apache.spark.sql.types.UTF8String().set(${arr});
"""
case FloatType =>
s"""
final boolean ${ev.nullTerm} = false;
float ${ev.primitiveTerm} = ${value}f;
"""
final boolean ${ev.nullTerm} = false;
float ${ev.primitiveTerm} = ${value}f;
"""
case dt: DecimalType =>
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value);
"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dt)}().set($value);
"""
case dt: NumericType =>
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value;
"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
"""
case other =>
super.genSource(ctx, ev)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
s"""
boolean ${ev.nullTerm} = true;
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
""" +
children.map { e =>
val eval = e.gen(ctx)
Expand Down Expand Up @@ -131,4 +131,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}
numNonNulls >= n
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
val nonnull = ctx.freshName("nonnull")
val code = children.map { e =>
val eval = e.gen(ctx)
s"""
if($nonnull < $n) {
${eval.code}
if(!${eval.nullTerm}) {
$nonnull += 1;
}
}
"""
}.mkString("\n")
s"""
int $nonnull = 0;
$code
boolean ${ev.nullTerm} = false;
boolean ${ev.primitiveTerm} = $nonnull >= $n;
"""
}
}
Loading

0 comments on commit 3ff25f8

Please sign in to comment.