[SW-376] Fix DateType and DecimalType conversions on external backend. (#235)

DateType is internally handled as Long, so the corresponding internal Java class for the external backend has to be Long as well.
DecimalType is a special type that is not listed among the supported types; its corresponding internal type needs to be double.
(cherry picked from commit 90adde0)
jakubhava committed Apr 7, 2017
1 parent f815050 commit 2e031c1
Showing 7 changed files with 57 additions and 21 deletions.
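
In short: the external backend used to derive transfer classes straight from the supported-type mapping, which breaks for DATE and DECIMAL columns. A minimal sketch of the intended behaviour (the "before" value is presumed from the replaced code below, not verified against this commit):

import org.apache.spark.sql.types.DateType
import org.apache.spark.h2o.utils.ReflectionUtils
import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx

// Before: the transfer class came straight from the supported-type mapping,
// presumably classOf[java.sql.Date] for DATE columns -- not transferable.
ReflectionUtils.supportedTypeOf(DateType).javaClass
// After: the external backend asks for the internal transfer class instead,
// so a DATE column travels as a Long.
ExternalWriteConverterCtx.internalJavaClassOf(DateType)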
ExternalWriteConverterCtx.scala
@@ -17,10 +17,11 @@

 package org.apache.spark.h2o.backends.external
 
 import org.apache.spark.h2o._
 import org.apache.spark.h2o.converters.WriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
-import org.apache.spark.h2o.utils.NodeDesc
+import org.apache.spark.h2o.utils.{NodeDesc, ReflectionUtils}
 import org.apache.spark.h2o.utils.SupportedTypes._
+import org.apache.spark.sql.types._
 import water.{ExternalFrameUtils, ExternalFrameWriterClient}
 
 class ExternalWriteConverterCtx(nodeDesc: NodeDesc, totalNumOfRows: Int) extends WriteConverterCtx {
@@ -67,9 +68,31 @@

 object ExternalWriteConverterCtx extends ExternalBackendUtils {
 
   import scala.language.postfixOps
+  import scala.reflect.runtime.universe._
 
   def scheduleUpload(numPartitions: Int): UploadPlan = {
     val nodes = cloudMembers
     val uploadPlan = (0 until numPartitions).zip(Stream.continually(nodes).flatten).toMap
     uploadPlan
   }
 
+  // In the external cluster we need to use just the basic types allowed for the conversion. Supported types have
+  // associated Java classes, but those are not the types the data are actually transferred as. The following
+  // methods override this behaviour so that the correct internal type for the transfer is returned.
+  def internalJavaClassOf[T](implicit ttag: TypeTag[T]) = ReflectionUtils.javaClassOf[T]
+
+  def internalJavaClassOf(supportedType: SupportedType) = {
+    supportedType match {
+      case Date => Long.javaClass
+      case _ => supportedType.javaClass
+    }
+  }
+
+  def internalJavaClassOf(dt: DataType): Class[_] = {
+    dt match {
+      case n if n.isInstanceOf[DecimalType] && n.getClass.getSuperclass != classOf[DecimalType] => Double.javaClass
+      case _: DateType => Long.javaClass
+      case _: DataType => ReflectionUtils.supportedTypeOf(dt).javaClass
+    }
+  }
 }
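
A quick usage sketch of the new overloads (illustrative only; it assumes the SupportedTypes members' javaClass values are the boxed Java classes, which the diff implies but does not show):

import org.apache.spark.h2o.utils.SupportedTypes._
import org.apache.spark.sql.types.DecimalType

// SupportedType overload: only Date is special-cased, everything else keeps its own class.
ExternalWriteConverterCtx.internalJavaClassOf(Double)             // Double.javaClass, e.g. java.lang.Double
ExternalWriteConverterCtx.internalJavaClassOf(Date)               // Long.javaClass -- not java.sql.Date

// DataType overload: a plain DecimalType goes over the wire as a double.
ExternalWriteConverterCtx.internalJavaClassOf(DecimalType(10, 2)) // Double.javaClass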
H2OFrameFromRDDProductBuilder.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.h2o.converters
 
 import org.apache.spark.TaskContext
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
 import org.apache.spark.h2o.utils.NodeDesc
 import org.apache.spark.h2o.utils.SupportedTypes.SupportedType
@@ -71,7 +72,7 @@
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       meta.vecTypes
     }else{
-      val javaClasses = meta.types.map(_.javaClass)
+      val javaClasses = meta.types.map(ExternalWriteConverterCtx.internalJavaClassOf(_))
       ExternalFrameUtils.prepareExpectedTypes(javaClasses)
     }

PrimitiveRDDConverter.scala
@@ -19,6 +19,7 @@

 import org.apache.spark.TaskContext
 import org.apache.spark.h2o._
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
 import org.apache.spark.h2o.utils.{NodeDesc, ReflectionUtils}
 import org.apache.spark.internal.Logging
@@ -43,7 +44,7 @@
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       Array[Byte](vecTypeOf[T])
     }else{
-      val clazz = ReflectionUtils.javaClassOf[T]
+      val clazz = ExternalWriteConverterCtx.internalJavaClassOf[T]
       ExternalFrameUtils.prepareExpectedTypes(Array[Class[_]](clazz))
     }

ReadConverterCtx.scala
@@ -46,6 +46,7 @@
   var rowIdx: Int = 0
 
   def numRows: Int
+
   def increaseRowIdx() = rowIdx += 1
 
   def hasNext = rowIdx < numRows
SparkDataFrameConverter.scala
@@ -19,16 +19,15 @@

 import org.apache.spark._
 import org.apache.spark.h2o.H2OContext
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
-import org.apache.spark.h2o.utils.ReflectionUtils._
-import org.apache.spark.h2o.utils.{H2OSchemaUtils, NodeDesc}
+import org.apache.spark.h2o.utils.{H2OSchemaUtils, ReflectionUtils}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, H2OFrameRelation, Row, SQLContext}
-import water.{ExternalFrameUtils, Key}
 import water.fvec.{Frame, H2OFrame}
+import water.{ExternalFrameUtils, Key}
 
-import scala.collection.immutable
 
 private[h2o] object SparkDataFrameConverter extends Logging {

@@ -65,10 +64,12 @@
     // otherwise for external backend store expected types
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       // Transform datatype into h2o types
-      flatRddSchema.map(f => vecTypeFor(f._2.dataType)).toArray
+      flatRddSchema.map(f => ReflectionUtils.vecTypeFor(f._2.dataType)).toArray
     }else{
-      val javaClasses = flatRddSchema.map(f => supportedTypeOf(f._2.dataType).javaClass).toArray
-      ExternalFrameUtils.prepareExpectedTypes(javaClasses)
+      val internalJavaClasses = flatRddSchema.map { f =>
+        ExternalWriteConverterCtx.internalJavaClassOf(f._2.dataType)
+      }.toArray
+      ExternalFrameUtils.prepareExpectedTypes(internalJavaClasses)
     }
     WriteConverterCtxUtils.convert[Row](hc, dfRdd, keyName, fnames, expectedTypes, perSQLPartition(flatRddSchema))
   }
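
As a concrete illustration of the external branch above (hypothetical column names; the transfer classes follow from the internalJavaClassOf cases in this commit):

import org.apache.spark.sql.types._
import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx

// A flat schema mixing the two problematic types with an ordinary one.
val fields = Seq(
  StructField("day", DateType),              // transfers as java.lang.Long
  StructField("price", DecimalType(10, 2)),  // transfers as java.lang.Double
  StructField("name", StringType)            // keeps its usual supported-type class
)
// The same mapping the converter applies to flatRddSchema before calling
// ExternalFrameUtils.prepareExpectedTypes.
val internalJavaClasses = fields.map(f => ExternalWriteConverterCtx.internalJavaClassOf(f.dataType)).toArray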
WriteConverterCtx.scala
@@ -64,6 +64,7 @@
       case n: Double => put(colIdx, n)
       case n: String => put(colIdx, n)
       case n: java.sql.Timestamp => put(colIdx, n)
+      case n: java.sql.Date => put(colIdx, n)
       case _ => putNA(colIdx)
     }
   }
ReflectionUtils.scala
@@ -98,23 +98,31 @@

   def supportedTypeOf(value: Any): SupportedType = {
     value match {
-      case n: Byte => Byte
-      case n: Short => Short
-      case n: Int => Integer
-      case n: Long => Long
-      case n: Float => Float
-      case n: Double => Double
-      case n: Boolean => Boolean
-      case n: String => String
-      case n: java.sql.Timestamp => Timestamp
-      case n: java.sql.Date => Date
+      case _: Byte => Byte
+      case _: Short => Short
+      case _: Int => Integer
+      case _: Long => Long
+      case _: Float => Float
+      case _: Double => Double
+      case _: Boolean => Boolean
+      case _: String => String
+      case _: java.sql.Timestamp => Timestamp
+      case _: java.sql.Date => Date
       case n: DataType => bySparkType(n)
       case q => throw new IllegalArgumentException(s"Do not understand type $q")
     }
   }
 
   def javaClassOf[T](implicit ttag: TypeTag[T]) = supportedTypeFor(typeOf[T]).javaClass
 
+  def javaClassOf(dt: DataType): Class[_] = {
+    dt match {
+      case n if n.isInstanceOf[DecimalType] && n.getClass.getSuperclass != classOf[DecimalType] => Double.javaClass
+      case _ => bySparkType(dt).javaClass
+    }
+  }
+
   def supportedTypeFor(tpe: Type): SupportedType = SupportedTypes.byType(tpe)
 
   def classFor(tpe: Type): Class[_] = supportedTypeFor(tpe).javaClass
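One subtlety in the guard above, as I read it (the commit does not spell this out): n.isInstanceOf[DecimalType] && n.getClass.getSuperclass != classOf[DecimalType] fires for objects whose runtime class is DecimalType itself, while an instance of a DecimalType subclass would fall through to the bySparkType lookup. A hedged sketch:

import org.apache.spark.sql.types._
import org.apache.spark.h2o.utils.ReflectionUtils

// Plain DecimalType: runtime class is DecimalType, its superclass is not DecimalType,
// so the guard fires and the decimal is mapped to Double's Java class.
ReflectionUtils.javaClassOf(DecimalType(38, 10))
// Any other DataType takes the ordinary bySparkType route, e.g. (assuming that mapping):
ReflectionUtils.javaClassOf(BooleanType)          // Boolean.javaClass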
