From 607d7a3a3aa51795965df5bd97f1603c3f3d668a Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 23:45:48 -0700 Subject: [PATCH] use checkInputTypes --- .../spark/sql/catalyst/expressions/math.scala | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 8fdd37f80e17c..80d30aa2a2e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} -import java.nio.charset.{StandardCharsets, Charset} import java.util.Arrays -import org.apache.commons.codec.DecoderException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -361,27 +359,43 @@ case class Pow(left: Expression, right: Expression) * Resulting characters are returned as a byte array. 
*/
case class UnHex(child: Expression)
  extends UnaryExpression with Serializable {

  /** The decoded value is produced as a byte array. */
  override def dataType: DataType = BinaryType

  /**
   * Succeeds when the child's type is a string type or `NullType`; otherwise returns a
   * `TypeCheckFailure` whose message names the offending type.
   */
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case _: StringType | NullType =>
      TypeCheckResult.TypeCheckSuccess
    case other =>
      // The function name is written out literally: the previous message interpolated
      // `$unHex`, which is not an identifier defined in this class.
      TypeCheckResult.TypeCheckFailure(s"unhex accepts String type, not $other")
  }

  /**
   * Evaluates the child against `input`. A null child value yields null; any other value is
   * cast to `UTF8String`, converted to a `String` and decoded by `unhex`.
   */
  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) {
      null
    } else {
      unhex(value.asInstanceOf[UTF8String].toString)
    }
  }

  /**
   * Decodes a string of hexadecimal digits into bytes, two digits per byte.
   *
   * @param s the text to decode; an odd-length input is treated as if it had a leading "0"
   * @return the decoded bytes, or null if any character is not a radix-16 digit
   */
  private def unhex(s: String): Array[Byte] = {
    // append a leading 0 if needed
    val str = if (s.length % 2 == 1) "0" + s else s
    val result = new Array[Byte](str.length / 2)
    var valid = true
    var i = 0
    while (valid && i < str.length) {
      // Character.digit recognises the same digits as Integer.parseInt(_, 16) but, unlike
      // parseInt, does not accept a leading '+' or '-', so pairs such as "-1" are rejected
      // instead of being silently decoded.
      val hi = Character.digit(str.charAt(i), 16)
      val lo = Character.digit(str.charAt(i + 1), 16)
      if (hi < 0 || lo < 0) {
        // invalid character present, return null
        valid = false
      } else {
        result(i / 2) = ((hi << 4) | lo).toByte
        i += 2
      }
    }
    if (valid) result else null
  }
}