Skip to content

Commit

Permalink
add factorial
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-li committed Jul 3, 2015
1 parent d983819 commit 26edf4f
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ object FunctionRegistry {
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expression[Factorial]("factorial"),
expression[Hypot]("hypot"),
expression[Hex]("hex"),
expression[Logarithm]("log"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import java.lang.{Long => JLong}
import java.util.Arrays

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{StringType}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType}
import org.apache.spark.unsafe.types.UTF8String

/**
Expand Down Expand Up @@ -159,6 +161,82 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP

// Unary math expression built from `math.floor` and the label "FLOOR".
// NOTE(review): how UnaryMathExpression applies the function and uses the label is not visible here.
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")

/**
 * Table-backed factorial lookup for small non-negative integers.
 */
object Factorial {

  // n! for every n in [0, 20], built incrementally: entry k is entry (k - 1) multiplied by k.
  // 20! is the last entry of the table.
  private val factorials: Array[Long] = (1 to 20).scanLeft(1L)(_ * _).toArray

  /**
   * Looks up n! in the precomputed table.
   *
   * @param n the value whose factorial is requested
   * @return n! when n is within the table, otherwise `Long.MaxValue` for any n past its end.
   *         Negative n is not guarded here and indexes outside the table.
   */
  def factorial(n: Int): Long =
    if (n >= factorials.length) Long.MaxValue else factorials(n)
}

/**
 * Expression that computes the factorial of its integer child as a long.
 *
 * The result is null when the child evaluates to null or to a value outside [0, 20].
 */
case class Factorial(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[DataType] = Seq(IntegerType)

  override def dataType: DataType = LongType

  override def foldable: Boolean = child.foldable

  // Inputs outside [0, 20] yield null even when the child itself is non-null,
  // so this expression is always reported as nullable.
  override def nullable: Boolean = true

  override def toString: String = s"factorial($child)"

  /** Interpreted path: null for a null or out-of-range child value, otherwise the table lookup. */
  override def eval(input: InternalRow): Any = {
    val childValue = child.eval(input)
    if (childValue == null) {
      null
    } else {
      val n = childValue.asInstanceOf[Integer].intValue()
      if (n >= 0 && n <= 20) Factorial.factorial(n) else null
    }
  }

  /**
   * Code-generated path: emits Java that mirrors `eval`, i.e. propagates the child's null flag,
   * marks out-of-range inputs as null, and otherwise calls the static factorial lookup.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val childGen = child.gen(ctx)
    val resultType = ctx.javaType(dataType)
    val resultDefault = ctx.defaultValue(dataType)
    val lookupCall =
      s"org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${childGen.primitive})"
    childGen.code + s"""
      boolean ${ev.isNull} = ${childGen.isNull};
      $resultType ${ev.primitive} = $resultDefault;
      if (!${ev.isNull}) {
        if (${childGen.primitive} > 20 || ${childGen.primitive} < 0) {
          ${ev.isNull} = true;
        } else {
          ${ev.primitive} =
            $lookupCall;
        }
      }
    """
  }
}

// Unary math expression built from `math.log` and the label "LOG".
// NOTE(review): how UnaryMathExpression applies the function and uses the label is not visible here.
case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")

case class Log2(child: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
import org.apache.spark.sql.types.{DataType, LongType}
import org.apache.spark.sql.types.{IntegerType, DoubleType}

class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -157,6 +160,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Floor, math.floor)
}

test("factorial") {
  // Every input in the supported range [0, 20] must agree with Guava's reference implementation.
  (0 to 20).foreach { value =>
    checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
  }
  // A null input must produce a null result. The literal has to be wrapped in Factorial,
  // otherwise only the Literal itself is evaluated and the expression is never exercised.
  checkEvaluation(Factorial(Literal.create(null, IntegerType)), null, create_row(null))
  // Upper boundary: 20! is the largest supported value.
  checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
  // Inputs on either side of [0, 20] evaluate to null.
  checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
  checkEvaluation(Factorial(Literal(-1)), null, EmptyRow)
}

// Passes the Rint expression and math.rint to the testUnary helper (defined outside this view).
test("rint") {
testUnary(Rint, math.rint)
}
Expand Down
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,22 @@ object functions {
*/
def expm1(columnName: String): Column = expm1(Column(columnName))

/**
 * Computes the factorial of the given integer-valued column.
 *
 * The result is a long. It is null when the input is null or lies outside the
 * range [0, 20], the only inputs the underlying expression computes a value for.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def factorial(e: Column): Column = Factorial(e.expr)

/**
 * Computes the factorial of the column with the given name.
 *
 * Resolves the name to a Column and delegates to `factorial(Column)`.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def factorial(columnName: String): Column = factorial(Column(columnName))

/**
* Computes the floor of the given value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}


private object MathExpressionsTestData {
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
case class NullDoubles(a: java.lang.Double)
Expand Down Expand Up @@ -183,6 +182,18 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(floor, math.floor)
}

test("factorial") {
  val df = (0 to 5).map(i => (i, i)).toDF("a", "b")
  // 0! through 5!, shared by the DataFrame-API and the SQL-expression variants below.
  val expected = Seq(1, 1, 2, 6, 24, 120).map(n => Row(n))

  // Column-function entry point.
  checkAnswer(df.select(factorial('a)), expected)
  // Same computation resolved by name through the function registry.
  checkAnswer(df.selectExpr("factorial(a)"), expected)
}

// Passes the rint column function and math.rint to testOneToOneMathFunction
// (helper defined outside this view).
test("rint") {
testOneToOneMathFunction(rint, math.rint)
}
Expand Down

0 comments on commit 26edf4f

Please sign in to comment.