
Commit

optimize type converter
cloud-fan committed Jun 25, 2015
1 parent c337844 commit 326c82c
Showing 5 changed files with 50 additions and 32 deletions.
@@ -52,6 +52,13 @@ object CatalystTypeConverters {
}
}

private def isWholePrimitive(dt: DataType): Boolean = dt match {
case dt if isPrimitive(dt) => true
case ArrayType(elementType, _) => isWholePrimitive(elementType)
case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
case _ => false
}

private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
val converter = dataType match {
case udt: UserDefinedType[_] => UDTConverter(udt)
@@ -148,6 +155,8 @@ object CatalystTypeConverters {

private[this] val elementConverter = getConverterForType(elementType)

private[this] val isNoChange = isWholePrimitive(elementType)

override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
scalaValue match {
case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
@@ -166,8 +175,10 @@ object CatalystTypeConverters {
override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
catalystValue
} else {
catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
catalystValue.map(elementConverter.toScala)
}
}

@@ -183,6 +194,8 @@ object CatalystTypeConverters {
private[this] val keyConverter = getConverterForType(keyType)
private[this] val valueConverter = getConverterForType(valueType)

private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)

override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
case m: Map[_, _] =>
m.map { case (k, v) =>
@@ -203,6 +216,8 @@ object CatalystTypeConverters {
override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
catalystValue
} else {
catalystValue.map { case (k, v) =>
keyConverter.toScala(k) -> valueConverter.toScala(v)
@@ -258,24 +273,22 @@ object CatalystTypeConverters {
toScala(row(column).asInstanceOf[InternalRow])
}

private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
case str: String => UTF8String.fromString(str)
case utf8: UTF8String => utf8
}
override def toScala(catalystValue: Any): String = catalystValue match {
case null => null
case str: String => str
case utf8: UTF8String => utf8.toString()
}
override def toScala(catalystValue: UTF8String): String =
if (catalystValue == null) null else catalystValue.toString
override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString
}

private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
override def toScala(catalystValue: Any): Date =
if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
override def toScalaImpl(row: InternalRow, column: Int): Date =
DateTimeUtils.toJavaDate(row.getInt(column))
}

private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
@@ -285,7 +298,7 @@ object CatalystTypeConverters {
if (catalystValue == null) null
else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
toScala(row.getLong(column))
DateTimeUtils.toJavaTimestamp(row.getLong(column))
}

private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@@ -296,10 +309,7 @@ object CatalystTypeConverters {
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
row.get(column) match {
case d: JavaBigDecimal => d
case d: Decimal => d.toJavaBigDecimal
}
row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
}

private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
@@ -362,6 +372,19 @@ object CatalystTypeConverters {
}
}

/**
* Creates a converter function that will convert Catalyst types to Scala type.
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
if (isPrimitive(dataType)) {
identity
} else {
getConverterForType(dataType).toScala
}
}

/**
* Converts Scala objects to Catalyst rows / types.
*
@@ -389,15 +412,6 @@ object CatalystTypeConverters {
* produced by createToScalaConverter.
*/
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
getConverterForType(dataType).toScala(catalystValue)
}

/**
* Creates a converter function that will convert Catalyst types to Scala type.
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
getConverterForType(dataType).toScala
createToScalaConverter(dataType)(catalystValue)
}
}
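
The doc comment on createToScalaConverter above describes the intended pattern: build a converter once per data type, then apply it to every row or value. Below is a minimal sketch of that pattern. It is illustrative only and not part of this commit; the package and object names (sketch, ToScalaConverterSketch) are hypothetical, and it sits under org.apache.spark.sql only because these helpers are private[sql].

// Illustrative sketch, not part of this commit: the "build once, apply per value"
// pattern described in the createToScalaConverter doc comment above.
package org.apache.spark.sql.sketch

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object ToScalaConverterSketch {
  def main(args: Array[String]): Unit = {
    // Built once per data type, reused for every value of that type.
    val intToScala = CatalystTypeConverters.createToScalaConverter(IntegerType)
    val intSeqToScala = CatalystTypeConverters.createToScalaConverter(ArrayType(IntegerType))

    // IntegerType is primitive, so the converter is just `identity`.
    println(intToScala(42)) // 42

    // ArrayType(IntegerType) is "whole primitive": with the isNoChange fast path
    // added in this commit, the Seq is returned as-is rather than rebuilt
    // element by element.
    println(intSeqToScala(Seq(1, 2, 3))) // List(1, 2, 3)
  }
}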
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DataType

@@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
(1 to 22).map { x =>
val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _)
lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _)
s"""case $x =>
@@ -22,6 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}

@@ -110,13 +111,17 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
val resultRow = InternalRow(justItems : _*)

// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
val schema = StructType(outputCols)

val converter = CatalystTypeConverters.createToCatalystConverter(schema)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
val resultRow = converter(InternalRow(justItems : _*)).asInstanceOf[InternalRow]

new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, Seq(resultRow)))
}
}
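
The FrequentItems change above applies the same idea in the opposite direction: build one Catalyst converter for the result schema up front and run the result row through it. A rough sketch of that direction follows; it is illustrative only and not part of this commit, the names are hypothetical, and the exact external row representation accepted may vary by Spark version.

// Illustrative sketch, not part of this commit: one createToCatalystConverter
// per schema, applied to an external row to get its Catalyst representation.
package org.apache.spark.sql.sketch

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

object ToCatalystConverterSketch {
  def main(args: Array[String]): Unit = {
    // Result schema: one column holding an array of strings, as in FrequentItems.
    val schema = StructType(Seq(
      StructField("items_freqItems", ArrayType(StringType, containsNull = false))))

    // Built once for the schema, then applied to every external row.
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)

    // Scala Strings become UTF8Strings inside the converted (Catalyst) row.
    val catalystRow = toCatalyst(Row(Seq("a", "b", "c")))
    println(catalystRow)
  }
}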
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.relation.needConversion) {
execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
} else {
rdd.map(_.asInstanceOf[InternalRow])
}
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
UTF8String.fromString(s"varchar_$i"),
Seq(i, i + 1),
Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
Map(i -> i.toString),
Map(i -> UTF8String.fromString(i.toString)),
Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
Row(i, i.toString),
Row(i, UTF8String.fromString(i.toString)),
Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}
