From 3ff25f81a8fc6840b5c6dc75377fc89e41454586 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Jun 2015 17:45:19 -0700 Subject: [PATCH] refactor --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 19 +++-- .../sql/catalyst/expressions/Expression.scala | 51 +++++++------ .../sql/catalyst/expressions/arithmetic.scala | 16 ++--- .../expressions/codegen/CodeGenerator.scala | 31 ++++++-- .../codegen/GenerateProjection.scala | 12 ++-- .../expressions/decimalFunctions.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 28 ++++---- .../catalyst/expressions/nullFunctions.scala | 23 +++++- .../sql/catalyst/expressions/predicates.scala | 37 ++++------ .../spark/sql/catalyst/expressions/sets.scala | 71 +++++++++++-------- 11 files changed, 163 insertions(+), 131 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1055be6e9d273..1d7f3b766a160 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -46,8 +46,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { s""" final boolean ${ev.nullTerm} = i.isNullAt($ordinal); - final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? - ${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? + ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a986844d18e8f..bf8642cdde535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -439,33 +439,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case Cast(child @ BinaryType(), StringType) => castOrNull (ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set($c)", - StringType) + s"new org.apache.spark.sql.types.UTF8String().set($c)") case Cast(child @ DateType(), StringType) => castOrNull(ctx, ev, c => s"""new org.apache.spark.sql.types.UTF8String().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt) + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)") case Cast(child @ DecimalType(), IntegerType) => - castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType) + castOrNull(ctx, ev, c => s"($c).toInt()") case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt) + castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt) + castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => castOrNull(ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))", - StringType) + s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))") case other => super.genSource(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f66f8f9ff105e..9b89a4bc744c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -90,10 +90,10 @@ abstract class Expression extends TreeNode[Expression] { /* expression: ${this} */ Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); boolean ${ev.nullTerm} = ${ev.objectTerm} == null; - ${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(e.dataType)}; + ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(e.dataType)}; if (!${ev.nullTerm}) ${ev.primitiveTerm} = - (${ctx.termForType(e.dataType)})${ev.objectTerm}; + (${ctx.boxedType(e.dataType)})${ev.objectTerm}; """ } @@ -173,12 +173,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express */ def evaluate(ctx: CodeGenContext, ev: EvaluatedExpression, - f: (String, String) => String): String = - evaluateAs(left.dataType)(ctx, ev, f) - - def evaluateAs(resultType: DataType)(ctx: CodeGenContext, - ev: EvaluatedExpression, - f: (String, String) => String): String = { + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -188,14 +183,19 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express val eval2 = right.gen(ctx) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code + eval2.code + - s""" - boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)}; - if(!${ev.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode); - } - """ + s""" + ${eval1.code} + boolean ${ev.nullTerm} = ${eval1.nullTerm}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + if (!${ev.nullTerm}) { + ${eval2.code} + if(!${eval2.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode); + } else { + ${ev.nullTerm} = true; + } + } + """ } } @@ -207,16 +207,15 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio self: Product => def castOrNull(ctx: CodeGenContext, ev: EvaluatedExpression, - f: String => String, dataType: DataType): String = { + f: String => String): String = { val eval = child.gen(ctx) - eval.code + - s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; - } - """ + eval.code + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + } + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4320fbf51bd6d..79350dd3d65f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -221,8 +221,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -279,8 +279,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -412,8 +412,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm}) { ${ev.nullTerm} = ${eval2.nullTerm}; @@ -468,8 +468,8 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm}) { ${ev.nullTerm} = ${eval2.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bec1899a3aad2..4f21a1892df25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -71,7 +71,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { dataType match { case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})i.apply($ordinal)" + case _ => s"(${boxedType(dataType)})i.apply($ordinal)" } } @@ -86,12 +86,12 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" - case other => s"get${termForType(dt)}" + case other => s"get${boxedType(dt)}" } def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" - case other => s"set${termForType(dt)}" + case other => s"set${boxedType(dt)}" } def hashSetForType(dt: DataType): String = dt match { @@ -101,7 +101,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { sys.error(s"Code generation not support for hashset of type $unsupportedType") } - def primitiveForType(dt: DataType): String = dt match { + /** + * Return the primitive type for a DataType + */ + def primitiveType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -117,7 +120,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "Object" } - def defaultPrimitive(dt: DataType): String = dt match { + /** + * Return the representation of default value for given DataType + */ + def defaultValue(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -131,7 +137,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "null" } - def termForType(dt: DataType): String = dt match { + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -147,6 +156,15 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "Object" } + /** + * Returns a function to generate equal expression in Java + */ + def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { + case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } + case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } + } + /** * List of data types that have special accessors and setters in [[Row]]. */ @@ -166,7 +184,6 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val rowType = classOf[Row].getName protected val exprType = classOf[Expression].getName protected val mutableRowType = classOf[MutableRow].getName protected val genericMutableRowType = classOf[GenericMutableRow].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 0e8ad76f65bad..00c856dc02ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n" + s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { @@ -68,7 +68,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") val specificAccessorFunctions = ctx.nativeTypes.map { dataType => @@ -80,14 +80,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${ctx.defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; } switch (i) { $cases } - return ${ctx.defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; }""" } else { "" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) { nullBits[i] = false; switch (i) { $cases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 76273a5b7ee68..68daea725cd40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -68,7 +68,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(DecimalType())}; + ${ctx.defaultValue(DecimalType())}; if (!${ev.nullTerm}) { ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d9fbda9511a5e..366e1083eb687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -85,7 +85,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres if (value == null) { s""" final boolean ${ev.nullTerm} = true; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ } else { dataType match { @@ -93,25 +93,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres val v = value.asInstanceOf[UTF8String] val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" s""" - final boolean ${ev.nullTerm} = false; - org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = - new org.apache.spark.sql.types.UTF8String().set(${arr}); - """ + final boolean ${ev.nullTerm} = false; + org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = + new org.apache.spark.sql.types.UTF8String().set(${arr}); + """ case FloatType => s""" - final boolean ${ev.nullTerm} = false; - float ${ev.primitiveTerm} = ${value}f; - """ + final boolean ${ev.nullTerm} = false; + float ${ev.primitiveTerm} = ${value}f; + """ case dt: DecimalType => s""" - final boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value); - """ + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dt)}().set($value); + """ case dt: NumericType => s""" - final boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value; - """ + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; + """ case other => super.genSource(ctx, ev) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 2af0f96146c1f..79c97f651f540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { s""" boolean ${ev.nullTerm} = true; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ + children.map { e => val eval = e.gen(ctx) @@ -131,4 +131,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if($nonnull < $n) { + ${eval.code} + if(!${eval.nullTerm}) { + $nonnull += 1; + } + } + """ + }.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b6b2c7db28960..3c1eeb07a91a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,8 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - // Uh, bad function name... - castOrNull(ctx, ev, c => s"!($c)", BooleanType) + castOrNull(ctx, ev, c => s"!($c)") } } @@ -221,13 +220,14 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { left.dataType match { - case dt: NumericType => evaluateAs(BooleanType) (ctx, ev, { - (eval1, eval2) => s"$eval1 $symbol $eval2" + case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" }) - case dt: TimestampType => + case TimestampType => + // java.sql.Timestamp does not have compare() super.genSource(ctx, ev) - case other => evaluateAs(BooleanType) (ctx, ev, { - (eval1, eval2) => s"$eval1.compare($eval2) $symbol 0" + case other => evaluate (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" }) } } @@ -277,15 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { - left.dataType match { - case BinaryType() => - evaluateAs (BooleanType) (ctx, ev, { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - }) - case other => - evaluateAs (BooleanType) (ctx, ev, { case (eval1, eval2) => s"$eval1 == $eval2" }) - } + evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -311,16 +303,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val cmpCode = if (left.dataType.isInstanceOf[BinaryType]) { - s"java.util.Arrays.equals((byte[])${eval1.primitiveTerm}, (byte[])${eval2.primitiveTerm})" - } else { - s"${eval1.primitiveTerm} == ${eval2.primitiveTerm}" - } - eval1.code + eval2.code + - s""" + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) + eval1.code + eval2.code + s""" final boolean ${ev.nullTerm} = false; final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || - (!${eval1.nullTerm} && $cmpCode); + (!${eval1.nullTerm} && $equalCode); """ } } @@ -403,7 +390,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; ${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index e6ae81c2aad52..22755b6ecb7e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -62,11 +62,15 @@ case class NewSet(elementType: DataType) extends LeafExpression { } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - s""" - boolean ${ev.nullTerm} = false; - ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = - new ${ctx.hashSetForType(elementType)}(); - """ + elementType match { + case IntegerType | LongType => + s""" + boolean ${ev.nullTerm} = false; + ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = + new ${ctx.hashSetForType(elementType)}(); + """ + case _ => super.genSource(ctx, ev) + } } override def toString: String = s"new Set($dataType)" @@ -101,20 +105,22 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = ctx.hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; - """ + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.hashSetForType(elementType) + + itemEval.code + setEval.code + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + } + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; + """ + case _ => super.genSource(ctx, ev) + } } override def toString: String = s"$set += $item" @@ -152,19 +158,20 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = ctx.hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = - (${htype})${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); - """ + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.hashSetForType(elementType) + + leftEval.code + rightEval.code + s""" + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); + """ + case _ => super.genSource(ctx, ev) + } } } @@ -184,5 +191,9 @@ case class CountSet(child: Expression) extends UnaryExpression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + castOrNull(ctx, ev, c => s"$c.size().toLong()") + } + override def toString: String = s"$child.count()" }