Skip to content

Commit

Permalink
Add misc function: sha2.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 22, 2015
1 parent 47c1d56 commit 59e41aa
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 4 deletions.
18 changes: 18 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
'monotonicallyIncreasingId',
'rand',
'randn',
'sha2',
'sparkPartitionId',
'struct',
'udf',
Expand Down Expand Up @@ -363,6 +364,23 @@ def randn(seed=None):
return Column(jc)


@ignore_unicode_prefix
@since(1.5)
def sha2(col, length):
"""Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
and SHA-512).
>>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
>>> digests[0]
Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
>>> digests[1]
Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.sha2(_to_java_column(col), length)
return Column(jc)


@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ object FunctionRegistry {

// misc functions
expression[Md5]("md5"),
expression[Sha2]("sha2"),

// aggregate functions
expression[Average]("avg"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import java.security.MessageDigest
import java.security.NoSuchAlgorithmException

import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{BinaryType, StringType, DataType}
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
import org.apache.spark.unsafe.types.UTF8String

/**
Expand All @@ -44,7 +47,96 @@ case class Md5(child: Expression)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
"org.apache.spark.unsafe.types.UTF8String.fromString" +
s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
}
}

/**
* A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
* and returns it as a hex string. The first argument is the string or binary to be hashed. The
* second argument indicates the desired bit length of the result, which must have a value of 224,
* 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If
* asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or
* the hash length is not one of the permitted values, the return value is NULL.
*/
case class Sha2(left: Expression, right: Expression)
extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>

override def dataType: DataType = StringType

override def toString: String = s"SHA2($left, $right)"

override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

override def eval(input: InternalRow): Any = {
val evalE1 = left.eval(input)
if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
if (evalE2 == null) {
null
} else {
val bitLength = evalE2.asInstanceOf[Int]
val input = evalE1.asInstanceOf[Array[Byte]]
bitLength match {
case 224 =>
// DigestUtils doesn't support SHA-224 now
try {
val md = MessageDigest.getInstance("SHA-224")
md.update(input)
UTF8String.fromBytes(md.digest())
} catch {
// SHA-224 is not supported on the system, return null
case noa: NoSuchAlgorithmException => null
}
case 256 | 0 =>
UTF8String.fromString(DigestUtils.sha256Hex(input))
case 384 =>
UTF8String.fromString(DigestUtils.sha384Hex(input))
case 512 =>
UTF8String.fromString(DigestUtils.sha512Hex(input))
case _ => null
}
}
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val digestUtils = "org.apache.commons.codec.digest.DigestUtils"

s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
if (${eval2.primitive} == 224) {
try {
java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
md.update(${eval1.primitive});
${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
} catch (java.security.NoSuchAlgorithmException e) {
${ev.isNull} = true;
}
} else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
${ev.primitive} =
${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
} else if (${eval2.primitive} == 384) {
${ev.primitive} =
${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
} else if (${eval2.primitive} == 512) {
${ev.primitive} =
${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
} else {
${ev.isNull} = true;
}
} else {
${ev.isNull} = true;
}
}
"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{StringType, BinaryType}
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand All @@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
}

test("sha2") {
checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
// unsupported bit length
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
}
}
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,22 @@ object functions {
*/
def md5(columnName: String): Column = md5(Column(columnName))

/**
* Calculates the SHA-2 family of hash functions and returns the value as a hex string.
*
* @group misc_funcs
* @since 1.5.0
*/
def sha2(e: Column, bit: Int): Column = Sha2(e.expr, lit(bit).expr)

/**
* Calculates the SHA-2 family of hash functions and returns the value as a hex string.
*
* @group misc_funcs
* @since 1.5.0
*/
def sha2(columnName: String, bit: Int): Column = sha2(Column(columnName), bit)

//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ class DataFrameFunctionsSuite extends QueryTest {
Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
}

test("misc sha2 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
df.select(sha2($"a", 256), sha2("b", 256)),
Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
"7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))

checkAnswer(
df.selectExpr("sha2(a, 256)", "sha2(b, 256)"),
Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
"7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
}

test("string length function") {
checkAnswer(
nullStrings.select(strlen($"s"), strlen("s")),
Expand Down

0 comments on commit 59e41aa

Please sign in to comment.