From e03edaaf9c235c2acd86e97e8f6c4efb21487437 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Jun 2015 14:30:41 -0700 Subject: [PATCH] consts fold --- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 11 +++++--- .../expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 26 ++++++++++--------- .../catalyst/expressions/nullFunctions.scala | 14 +++++----- .../sql/catalyst/expressions/predicates.scala | 2 +- .../spark/sql/catalyst/expressions/sets.scala | 6 ++--- 7 files changed, 33 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2df6737adb42b..6866b1182e0da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -184,7 +184,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.nullTerm}) { ${eval2.code} if(!${eval2.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode); + ${ev.primitiveTerm} = $resultCode; } else { ${ev.nullTerm} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a049f8878ed32..0923ab6f59564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -118,12 +118,15 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - if (left.dataType.isInstanceOf[DecimalType]) { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") - } else { + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)") + case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } protected def evalInternal(evalE1: Any, evalE2: Any): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c963971d28cb1..f6a2a2be1c89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not * valid if `nullTerm` is set to `true`. */ -case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term) +case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 4f00cb6bec586..e121d39e1d9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -82,23 +82,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // change the nullTerm and primitiveTerm to consts, to inline them if (value == null) { - s""" - final boolean ${ev.nullTerm} = true; - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - """ + ev.nullTerm = "true" + ev.primitiveTerm = ctx.defaultValue(dataType) + "" } else { dataType match { + case BooleanType => + ev.nullTerm = "false" + ev.primitiveTerm = value.toString + "" case FloatType => // This must go before NumericType - s""" - final boolean ${ev.nullTerm} = false; - final float ${ev.primitiveTerm} = ${value}f; - """ + ev.nullTerm = "false" + ev.primitiveTerm = s"${value}f" + "" case dt: NumericType if !dt.isInstanceOf[DecimalType] => - s""" - final boolean ${ev.nullTerm} = false; - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; - """ + ev.nullTerm = "false" + ev.primitiveTerm = value.toString + "" // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e380eafc3fc2a..e3c3489d11aea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -83,10 +83,9 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - eval.code + s""" - final boolean ${ev.nullTerm} = false; - final boolean ${ev.primitiveTerm} = ${eval.nullTerm}; - """ + ev.nullTerm = "false" + ev.primitiveTerm = eval.nullTerm + eval.code } override def toString: String = s"IS NULL $child" @@ -103,10 +102,9 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = !${eval.nullTerm}; - """ + ev.nullTerm = "false" + ev.primitiveTerm = s"(!(${eval.nullTerm}))" + eval.code } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 67cac26fd0d55..846fc9d90a86b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -304,8 +304,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) + ev.nullTerm = "false" eval1.code + eval2.code + s""" - final boolean ${ev.nullTerm} = false; final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || (!${eval1.nullTerm} && $equalCode); """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index ef1c2bc5836e0..a0c81473ec050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -64,8 +64,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { elementType match { case IntegerType | LongType => + ev.nullTerm = "false" s""" - boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}(); """ case _ => super.genCode(ctx, ev) @@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { val setEval = set.gen(ctx) val htype = ctx.primitiveType(dataType) + ev.nullTerm = "false" itemEval.code + setEval.code + s""" if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - boolean ${ev.nullTerm} = false; ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ case _ => super.genCode(ctx, ev) @@ -164,8 +164,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightEval = right.gen(ctx) val htype = ctx.primitiveType(dataType) + ev.nullTerm = "false" leftEval.code + rightEval.code + s""" - boolean ${ev.nullTerm} = false; ${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); """