diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index f7df2340edb6c..21f8c812c9ce5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -38,11 +38,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - long ${ev.primitiveTerm} = ${ev.nullTerm} ? -1 : ${eval.primitiveTerm}.toUnscaledLong(); - """ + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 846fc9d90a86b..d69324acf0e5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -146,9 +146,12 @@ case class And(left: Expression, right: Expression) } } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. s""" ${eval1.code} boolean ${ev.nullTerm} = false; @@ -192,20 +195,21 @@ case class Or(left: Expression, right: Expression) } } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) + + // The result should be `true`, if any of them is `true` whenever the other is null or not. s""" ${eval1.code} boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = false; + boolean ${ev.primitiveTerm} = true; if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - ${ev.primitiveTerm} = true; } else { ${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = true; } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { ${ev.primitiveTerm} = false; } else { @@ -218,19 +222,6 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { - (c1, c3) => s"$c1 $symbol $c3" - }) - case TimestampType => - // java.sql.Timestamp does not have compare() - super.genCode(ctx, ev) - case other => defineCodeGen (ctx, ev, { - (c1, c2) => s"$c1.compare($c2) $symbol 0" - }) - } - } override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { @@ -258,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + left.dataType match { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) + case TimestampType => + // java.sql.Timestamp does not have compare() + super.genCode(ctx, ev) + case other => defineCodeGen (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" + }) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryComparisons must override either eval or evalInternal") } @@ -389,9 +394,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.gen(ctx) s""" + ${condEval.code} boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - ${condEval.code} if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} ${ev.nullTerm} = ${trueEval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index a0c81473ec050..40107c5985481 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -112,11 +112,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { val htype = ctx.primitiveType(dataType) ev.nullTerm = "false" + ev.primitiveTerm = setEval.primitiveTerm itemEval.code + setEval.code + s""" if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ case _ => super.genCode(ctx, ev) } @@ -147,10 +147,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } @@ -164,10 +162,12 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightEval = right.gen(ctx) val htype = ctx.primitiveType(dataType) - ev.nullTerm = "false" + ev.nullTerm = leftEval.nullTerm + ev.primitiveTerm = leftEval.primitiveTerm leftEval.code + rightEval.code + s""" - ${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + if (!${leftEval.nullTerm} && !${rightEval.nullTerm}) { + ${leftEval.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + } """ case _ => super.genCode(ctx, ev) }