diff --git a/core/src/main/scala/org/apache/spark/h2o/backends/external/ExternalWriteConverterCtx.scala b/core/src/main/scala/org/apache/spark/h2o/backends/external/ExternalWriteConverterCtx.scala
index 9482efd31f..ec0ac52755 100644
--- a/core/src/main/scala/org/apache/spark/h2o/backends/external/ExternalWriteConverterCtx.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/backends/external/ExternalWriteConverterCtx.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.h2o.backends.external
 
-import org.apache.spark.h2o._
 import org.apache.spark.h2o.converters.WriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
-import org.apache.spark.h2o.utils.NodeDesc
+import org.apache.spark.h2o.utils.{NodeDesc, ReflectionUtils}
+import org.apache.spark.h2o.utils.SupportedTypes._
+import org.apache.spark.sql.types._
 import water.{ExternalFrameUtils, ExternalFrameWriterClient}
 
 class ExternalWriteConverterCtx(nodeDesc: NodeDesc, totalNumOfRows: Int) extends WriteConverterCtx {
@@ -67,9 +68,31 @@ class ExternalWriteConverterCtx(nodeDesc: NodeDesc, totalNumOfRows: Int) extends
 
 object ExternalWriteConverterCtx extends ExternalBackendUtils {
+  import scala.language.postfixOps
+  import scala.reflect.runtime.universe._
+
   def scheduleUpload(numPartitions: Int): UploadPlan = {
     val nodes = cloudMembers
     val uploadPlan = (0 until numPartitions).zip(Stream.continually(nodes).flatten).toMap
     uploadPlan
   }
+
+  // In the external cluster we need to use just the basic types allowed for the conversion. Supported types have
+  // associated Java classes, but those are not the types as which the data are actually transferred. The following
+  // methods override this behaviour so that the correct internal type for the transfer is returned.
+  def internalJavaClassOf[T](implicit ttag: TypeTag[T]) = ReflectionUtils.javaClassOf[T]
+
+  def internalJavaClassOf(supportedType: SupportedType) = {
+    supportedType match {
+      case Date => Long.javaClass
+      case _ => supportedType.javaClass
+    }
+  }
+  def internalJavaClassOf(dt: DataType) : Class[_] = {
+    dt match {
+      case n if n.isInstanceOf[DecimalType] & n.getClass.getSuperclass != classOf[DecimalType] => Double.javaClass
+      case _ : DateType => Long.javaClass
+      case _ : DataType => ReflectionUtils.supportedTypeOf(dt).javaClass
+    }
+  }
 
 }
diff --git a/core/src/main/scala/org/apache/spark/h2o/converters/H2OFrameFromRDDProductBuilder.scala b/core/src/main/scala/org/apache/spark/h2o/converters/H2OFrameFromRDDProductBuilder.scala
index 965fb540ec..07314183d5 100644
--- a/core/src/main/scala/org/apache/spark/h2o/converters/H2OFrameFromRDDProductBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/converters/H2OFrameFromRDDProductBuilder.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.h2o.converters
 
 import org.apache.spark.TaskContext
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
 import org.apache.spark.h2o.utils.NodeDesc
 import org.apache.spark.h2o.utils.SupportedTypes.SupportedType
@@ -71,7 +72,7 @@ case class H2OFrameFromRDDProductBuilder(hc: H2OContext, rdd: RDD[Product], fram
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       meta.vecTypes
     }else{
-      val javaClasses = meta.types.map(_.javaClass)
+      val javaClasses = meta.types.map(ExternalWriteConverterCtx.internalJavaClassOf(_))
       ExternalFrameUtils.prepareExpectedTypes(javaClasses)
     }
 
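For illustration only (not part of the patch): a minimal sketch of how the new internalJavaClassOf overloads could be used when preparing the expected transfer types for the external backend. The schema and names below are made up; the mapping itself (DateType transferred as Long, fixed-precision DecimalType as Double, everything else via its supported type's Java class) comes from the methods added above.

    import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
    import org.apache.spark.sql.types._
    import water.ExternalFrameUtils

    // Hypothetical flattened schema, for the sake of the example.
    val schema = StructType(Seq(
      StructField("name", StringType),          // transferred via its supported type's Java class
      StructField("paid_on", DateType),         // transferred as Long
      StructField("amount", DecimalType(10, 2)) // transferred as Double
    ))

    // Same pattern as the external-backend branch in SparkDataFrameConverter below:
    val internalJavaClasses: Array[Class[_]] =
      schema.fields.map(f => ExternalWriteConverterCtx.internalJavaClassOf(f.dataType))
    val expectedTypes = ExternalFrameUtils.prepareExpectedTypes(internalJavaClasses)
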
diff --git a/core/src/main/scala/org/apache/spark/h2o/converters/PrimitiveRDDConverter.scala b/core/src/main/scala/org/apache/spark/h2o/converters/PrimitiveRDDConverter.scala
index 67f4120169..a1a0c9ae15 100644
--- a/core/src/main/scala/org/apache/spark/h2o/converters/PrimitiveRDDConverter.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/converters/PrimitiveRDDConverter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.h2o.converters
 
 import org.apache.spark.TaskContext
 import org.apache.spark.h2o._
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
 import org.apache.spark.h2o.utils.{NodeDesc, ReflectionUtils}
 import org.apache.spark.internal.Logging
@@ -43,7 +44,7 @@ private[converters] object PrimitiveRDDConverter extends Logging{
 
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       Array[Byte](vecTypeOf[T])
     }else{
-      val clazz = ReflectionUtils.javaClassOf[T]
+      val clazz = ExternalWriteConverterCtx.internalJavaClassOf[T]
       ExternalFrameUtils.prepareExpectedTypes(Array[Class[_]](clazz))
     }
diff --git a/core/src/main/scala/org/apache/spark/h2o/converters/ReadConverterCtx.scala b/core/src/main/scala/org/apache/spark/h2o/converters/ReadConverterCtx.scala
index 9e66d5e378..4d98a08dc8 100644
--- a/core/src/main/scala/org/apache/spark/h2o/converters/ReadConverterCtx.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/converters/ReadConverterCtx.scala
@@ -46,6 +46,7 @@ trait ReadConverterCtx {
   var rowIdx: Int = 0
 
   def numRows: Int
+
   def increaseRowIdx() = rowIdx += 1
 
   def hasNext = rowIdx < numRows
diff --git a/core/src/main/scala/org/apache/spark/h2o/converters/SparkDataFrameConverter.scala b/core/src/main/scala/org/apache/spark/h2o/converters/SparkDataFrameConverter.scala
index f7211ced06..2bb4d45bfe 100644
--- a/core/src/main/scala/org/apache/spark/h2o/converters/SparkDataFrameConverter.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/converters/SparkDataFrameConverter.scala
@@ -19,16 +19,15 @@ package org.apache.spark.h2o.converters
 
 import org.apache.spark._
 import org.apache.spark.h2o.H2OContext
+import org.apache.spark.h2o.backends.external.ExternalWriteConverterCtx
 import org.apache.spark.h2o.converters.WriteConverterCtxUtils.UploadPlan
-import org.apache.spark.h2o.utils.ReflectionUtils._
-import org.apache.spark.h2o.utils.{H2OSchemaUtils, NodeDesc}
+import org.apache.spark.h2o.utils.{H2OSchemaUtils, ReflectionUtils}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, H2OFrameRelation, Row, SQLContext}
-import water.{ExternalFrameUtils, Key}
 import water.fvec.{Frame, H2OFrame}
+import water.{ExternalFrameUtils, Key}
 
-import scala.collection.immutable
 
 private[h2o] object SparkDataFrameConverter extends Logging {
 
@@ -65,10 +64,12 @@ private[h2o] object SparkDataFrameConverter extends Logging {
     // otherwise for external backend store expected types
     val expectedTypes = if(hc.getConf.runsInInternalClusterMode){
       // Transform datatype into h2o types
-      flatRddSchema.map(f => vecTypeFor(f._2.dataType)).toArray
+      flatRddSchema.map(f => ReflectionUtils.vecTypeFor(f._2.dataType)).toArray
     }else{
-      val javaClasses = flatRddSchema.map(f => supportedTypeOf(f._2.dataType).javaClass).toArray
-      ExternalFrameUtils.prepareExpectedTypes(javaClasses)
+      val internalJavaClasses = flatRddSchema.map{f =>
+        ExternalWriteConverterCtx.internalJavaClassOf(f._2.dataType)
+      }.toArray
+      ExternalFrameUtils.prepareExpectedTypes(internalJavaClasses)
     }
     WriteConverterCtxUtils.convert[Row](hc, dfRdd, keyName, fnames, expectedTypes, perSQLPartition(flatRddSchema))
   }
diff --git a/core/src/main/scala/org/apache/spark/h2o/converters/WriteConverterCtx.scala b/core/src/main/scala/org/apache/spark/h2o/converters/WriteConverterCtx.scala
index 8af7d55c95..665d6fce90 100644
--- a/core/src/main/scala/org/apache/spark/h2o/converters/WriteConverterCtx.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/converters/WriteConverterCtx.scala
@@ -64,6 +64,7 @@ trait WriteConverterCtx {
       case n: Double => put(colIdx, n)
       case n: String => put(colIdx, n)
      case n: java.sql.Timestamp => put(colIdx, n)
+      case n: java.sql.Date => put(colIdx, n)
       case _ => putNA(colIdx)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/h2o/utils/ReflectionUtils.scala b/core/src/main/scala/org/apache/spark/h2o/utils/ReflectionUtils.scala
index 6a4c8e5537..d3b9d138a5 100644
--- a/core/src/main/scala/org/apache/spark/h2o/utils/ReflectionUtils.scala
+++ b/core/src/main/scala/org/apache/spark/h2o/utils/ReflectionUtils.scala
@@ -98,16 +98,16 @@ object ReflectionUtils {
 
   def supportedTypeOf(value : Any): SupportedType = {
     value match {
-      case n: Byte => Byte
-      case n: Short => Short
-      case n: Int => Integer
-      case n: Long => Long
-      case n: Float => Float
-      case n: Double => Double
-      case n: Boolean => Boolean
-      case n: String => String
-      case n: java.sql.Timestamp => Timestamp
-      case n: java.sql.Date => Date
+      case _: Byte => Byte
+      case _: Short => Short
+      case _: Int => Integer
+      case _: Long => Long
+      case _: Float => Float
+      case _: Double => Double
+      case _: Boolean => Boolean
+      case _: String => String
+      case _: java.sql.Timestamp => Timestamp
+      case _: java.sql.Date => Date
       case n: DataType => bySparkType(n)
       case q => throw new IllegalArgumentException(s"Do not understand type $q")
     }
@@ -115,6 +115,14 @@ object ReflectionUtils {
 
   def javaClassOf[T](implicit ttag: TypeTag[T]) = supportedTypeFor(typeOf[T]).javaClass
 
+
+  def javaClassOf(dt: DataType) : Class[_] = {
+    dt match {
+      case n if n.isInstanceOf[DecimalType] & n.getClass.getSuperclass != classOf[DecimalType] => Double.javaClass
+      case _ => bySparkType(dt).javaClass
+    }
+  }
+
   def supportedTypeFor(tpe: Type): SupportedType = SupportedTypes.byType(tpe)
 
   def classFor(tpe: Type): Class[_] = supportedTypeFor(tpe).javaClass
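Taken together, these changes let a Spark DataFrame with DateType (and fixed-precision DecimalType) columns be converted for the external backend: the writer gains a java.sql.Date case, and date values travel using the Long transfer type. A rough usage sketch, not part of the patch, assuming an already running SparkSession (spark) and H2OContext (hc) started in external cluster mode; the data and column names are illustrative only:

    import java.sql.Date

    // Illustrative DataFrame with a date and a decimal column.
    val df = spark.createDataFrame(Seq(
      (Date.valueOf("2017-01-01"), BigDecimal("12.50")),
      (Date.valueOf("2017-01-02"), BigDecimal("99.99"))
    )).toDF("day", "price")

    // With this patch, the date column is written through the new java.sql.Date case
    // in WriteConverterCtx instead of falling through to putNA.
    val hf = hc.asH2OFrame(df)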