diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfa87aeea193a..b9771ec4ba89a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', 'sparkPartitionId', 'struct', 'udf', @@ -363,6 +364,26 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col, length): + """Returns the hex string result of SHA-1. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col), length) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha(col, length): + """Returns the hex string result of SHA-1. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha(_to_java_column(col), length) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5fb3369f85d12..0c9d24de18626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -135,6 +135,8 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4bee8cb728e5c..6f4292402b65a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, StringType, DataType} +import org.apache.spark.sql.types.{LongType, BinaryType, StringType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -48,3 +48,38 @@ case class Md5(child: Expression) s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } } + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression { + + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + value match { + case b: Array[Byte] => + UTF8String.fromString(DigestUtils.sha1Hex(value.asInstanceOf[Array[Byte]])) + case s: UTF8String => + UTF8String.fromString(DigestUtils.sha1Hex(s.getBytes)) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + child.dataType match { + case BinaryType => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" + case StringType => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c.getBytes()))" + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 48b84130b4556..7fc511e42a400 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,4 +29,13 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha1") { + checkEvaluation(Sha1(Literal("ABC")), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(null, StringType)), null) + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("")), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 38d9085a505fb..4d310400a711a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,38 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha(columnName: String): Column = sha1(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8b53b384a22fd..f051aebda783d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,30 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + + test("misc sha function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha($"a"), sha("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha(a)", "sha(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")),