diff --git a/src/main/scala/org/tensorframes/ColumnInformation.scala b/src/main/scala/org/tensorframes/ColumnInformation.scala index d06a76c..e4c5187 100644 --- a/src/main/scala/org/tensorframes/ColumnInformation.scala +++ b/src/main/scala/org/tensorframes/ColumnInformation.scala @@ -2,6 +2,8 @@ package org.tensorframes import org.apache.spark.sql.types._ +import org.tensorframes.impl.{ScalarType, SupportedOperations} + class ColumnInformation private ( val field: StructField, @@ -15,7 +17,9 @@ class ColumnInformation private ( val b = new MetadataBuilder().withMetadata(field.metadata) for (info <- stf) { b.putLongArray(shapeKey, info.shape.dims.toArray) - b.putString(tensorStructType, info.dataType.toString) + // Keep the SQL name, so that we do not leak internal details. + val dt = SupportedOperations.opsFor(info.dataType).sqlType + b.putString(tensorStructType, dt.toString) } val meta = b.build() field.copy(metadata = meta) @@ -73,15 +77,15 @@ object ColumnInformation extends Logging { * @param scalarType the data type * @param blockShape the shape of the block */ - def structField(name: String, scalarType: NumericType, blockShape: Shape): StructField = { + def structField(name: String, scalarType: ScalarType, blockShape: Shape): StructField = { val i = SparkTFColInfo(blockShape, scalarType) val f = StructField(name, sqlType(scalarType, blockShape.tail), nullable = false) ColumnInformation(f, i).merged } - private def sqlType(scalarType: NumericType, shape: Shape): DataType = { + private def sqlType(scalarType: ScalarType, shape: Shape): DataType = { if (shape.dims.isEmpty) { - scalarType + SupportedOperations.opsFor(scalarType).sqlType } else { ArrayType(sqlType(scalarType, shape.tail), containsNull = false) } @@ -102,11 +106,14 @@ object ColumnInformation extends Logging { for { s <- shape t <- tpe - } yield SparkTFColInfo(s, t) + ops <- SupportedOperations.getOps(t) + } yield SparkTFColInfo(s, ops.scalarType) } - private def getType(s: String): Option[NumericType] = { - supportedTypes.find(_.toString == s) + private def getType(s: String): Option[DataType] = { + val res = supportedTypes.find(_.toString == s) + logInfo(s"getType: $s -> $res") + res } /** @@ -115,19 +122,18 @@ object ColumnInformation extends Logging { * @return */ private def extractFromRow(dt: DataType): Option[SparkTFColInfo] = dt match { - case x: NumericType if MetadataConstants.supportedTypes.contains(dt) => - logTrace("numerictype: " + x) - // It is a basic type that we understand - Some(SparkTFColInfo(Shape(Unknown), x)) case x: ArrayType => logTrace("arraytype: " + x) // Look into the array to figure out the type. extractFromRow(x.elementType).map { info => SparkTFColInfo(info.shape.prepend(Unknown), info.dataType) } - case _ => - logTrace("not understood: " + dt) - // Not understood. - None + case _ => SupportedOperations.getOps(dt) match { + case Some(ops) => + logTrace("numerictype: " + ops.scalarType) + // It is a basic type that we understand + Some(SparkTFColInfo(Shape(Unknown), ops.scalarType)) + case None => None + } } } diff --git a/src/main/scala/org/tensorframes/ExperimentalOperations.scala b/src/main/scala/org/tensorframes/ExperimentalOperations.scala index a622104..57ae8d1 100644 --- a/src/main/scala/org/tensorframes/ExperimentalOperations.scala +++ b/src/main/scala/org/tensorframes/ExperimentalOperations.scala @@ -3,7 +3,8 @@ package org.tensorframes import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{ArrayType, DataType, NumericType} -import org.tensorframes.impl.SupportedOperations + +import org.tensorframes.impl.{ScalarType, SupportedOperations} /** * Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime). @@ -109,8 +110,8 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with DataFrameInfo(allInfo) } - private def extractBasicType(dt: DataType): Option[NumericType] = dt match { - case x: NumericType => Some(x) + private def extractBasicType(dt: DataType): Option[ScalarType] = dt match { + case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType) case x: ArrayType => extractBasicType(x.elementType) case _ => None } diff --git a/src/main/scala/org/tensorframes/MetadataConstants.scala b/src/main/scala/org/tensorframes/MetadataConstants.scala index affec0d..d0aea61 100644 --- a/src/main/scala/org/tensorframes/MetadataConstants.scala +++ b/src/main/scala/org/tensorframes/MetadataConstants.scala @@ -1,7 +1,7 @@ package org.tensorframes -import org.apache.spark.sql.types.NumericType -import org.tensorframes.impl.SupportedOperations +import org.apache.spark.sql.types.{DataType, NumericType} +import org.tensorframes.impl.{ScalarType, SupportedOperations} /** * Metadata annotations that get embedded in dataframes to express tensor information. @@ -29,5 +29,5 @@ object MetadataConstants { /** * All the SQL types supported by SparkTF. */ - val supportedTypes: Seq[NumericType] = SupportedOperations.sqlTypes + val supportedTypes: Seq[DataType] = SupportedOperations.sqlTypes } \ No newline at end of file diff --git a/src/main/scala/org/tensorframes/Shape.scala b/src/main/scala/org/tensorframes/Shape.scala index b7d9859..091e51b 100644 --- a/src/main/scala/org/tensorframes/Shape.scala +++ b/src/main/scala/org/tensorframes/Shape.scala @@ -1,9 +1,11 @@ package org.tensorframes -import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.{BinaryType, DataType, NumericType} import org.tensorflow.framework.TensorShapeProto + import scala.collection.JavaConverters._ import org.tensorframes.Shape.DimType +import org.tensorframes.impl.ScalarType import org.{tensorflow => tf} @@ -36,6 +38,11 @@ class Shape private (private val ds: Array[DimType]) extends Serializable { def prepend(x: Int): Shape = Shape(x.toLong +: ds) + /** + * Drops the most inner dimension of the shape. + */ + def dropInner: Shape = Shape(ds.dropRight(1)) + /** * A shape with the first dimension dropped. */ @@ -104,15 +111,16 @@ object Shape { /** * SparkTF information. This is the information generally required to work on a tensor. - * @param shape - * @param dataType + * @param shape the shape of the column (including the number of rows). May contain some unknowns. + * @param dataType the datatype of the scalar. Note that it is either NumericType or BinaryType. */ // TODO(tjh) the types supported by TF are much richer (uint8, etc.) but it is not clear // if they all map to a Catalyst memory representation // TODO(tjh) support later basic structures for sparse types? case class SparkTFColInfo( shape: Shape, - dataType: NumericType) extends Serializable + dataType: ScalarType) extends Serializable { +} /** * Exception thrown when the user requests tensors of high order. diff --git a/src/main/scala/org/tensorframes/dsl/DslImpl.scala b/src/main/scala/org/tensorframes/dsl/DslImpl.scala index 3795c9f..c2c7d41 100644 --- a/src/main/scala/org/tensorframes/dsl/DslImpl.scala +++ b/src/main/scala/org/tensorframes/dsl/DslImpl.scala @@ -1,13 +1,14 @@ package org.tensorframes.dsl import javax.annotation.Nullable + import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.NumericType -import org.tensorframes.{Logging, ColumnInformation, Shape} -import org.tensorframes.impl.DenseTensor +import org.tensorframes.{ColumnInformation, Logging, Shape} +import org.tensorframes.impl.{DenseTensor, SupportedOperations} /** @@ -75,8 +76,9 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) + val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, dtype = dt.dtype, + shape = dt.shape, dtype = dt2, extraAttrs = Map("value" -> a.build())) } @@ -100,7 +102,8 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { s"tensorframes: $schema") } val shape = if (block) { stf.shape } else { stf.shape.tail } - DslImpl.placeholder(stf.dataType, shape).named(tfName) + val dt = SupportedOperations.opsFor(stf.dataType).sqlType.asInstanceOf[NumericType] + DslImpl.placeholder(dt, shape).named(tfName) } private def commonShape(shapes: Seq[Shape]): Shape = { diff --git a/src/main/scala/org/tensorframes/dsl/package.scala b/src/main/scala/org/tensorframes/dsl/package.scala index adfa39a..2a787d6 100644 --- a/src/main/scala/org/tensorframes/dsl/package.scala +++ b/src/main/scala/org/tensorframes/dsl/package.scala @@ -1,10 +1,8 @@ package org.tensorframes import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.IntegerType - +import org.apache.spark.sql.types.{IntegerType, NumericType} import org.tensorframes.impl.SupportedOperations /** @@ -45,7 +43,7 @@ package object dsl { def placeholder[T : Numeric : TypeTag](shape: Int*): Operation = { val ops = SupportedOperations.getOps[T]() - DslImpl.placeholder(ops.sqlType, Shape(shape: _*)) + DslImpl.placeholder(ops.sqlType.asInstanceOf[NumericType], Shape(shape: _*)) } def constant[T : ConvertibleToDenseTensor](x: T): Operation = { diff --git a/src/main/scala/org/tensorframes/impl/DataOps.scala b/src/main/scala/org/tensorframes/impl/DataOps.scala index 6e60f1a..d349c64 100644 --- a/src/main/scala/org/tensorframes/impl/DataOps.scala +++ b/src/main/scala/org/tensorframes/impl/DataOps.scala @@ -5,7 +5,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.StructType import org.tensorframes.{Logging, Shape} import org.tensorframes.Shape.DimType @@ -145,7 +145,7 @@ object DataOps extends Logging { def getColumnFast0( reshapeShape: Shape, - scalaType: NumericType, + scalaType: ScalarType, allDataBuffer: mutable.WrappedArray[_]): Iterable[Any] = { reshapeShape.dims match { case Seq() => diff --git a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala index e9dbd32..af4c2dd 100644 --- a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala +++ b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala @@ -1,20 +1,23 @@ package org.tensorframes.impl +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.commons.lang3.SerializationUtils +import org.tensorflow.framework.GraphDef +import org.tensorflow.{Session, Tensor} + import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row} -import org.tensorflow.framework.GraphDef -import org.tensorflow.{Session, Tensor} + import org.tensorframes._ import org.tensorframes.test.DslOperations -import scala.collection.mutable -import scala.collection.JavaConverters._ -import scala.util.{Failure, Success, Try} /** * The different schemas required for the block reduction. @@ -322,17 +325,17 @@ class DebugRowOps throw new Exception( s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") } - if (! stf.shape.checkMorePreciseThan(in.shape)) { - throw new Exception( - s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + - s" ${in.shape} requested by the TF graph") - } // We do not support autocasting for now. if (stf.dataType != in.scalarType) { throw new Exception( s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + s"of the column (${in.scalarType})") } + if (! stf.shape.checkMorePreciseThan(in.shape)) { + throw new Exception( + s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + + s" ${in.shape} requested by the TF graph") + } // The input has to be either a constant or a placeholder if (! in.isPlaceholder) { throw new Exception( @@ -414,16 +417,16 @@ class DebugRowOps val stf = get(ColumnInformation(f).stf, s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") + check(stf.dataType == in.scalarType, + s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + + s"of the column (${in.scalarType})") + val cellShape = stf.shape.tail // No check for unknowns: we allow unknowns in the first dimension of the cell shape. check(cellShape.checkMorePreciseThan(in.shape), s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + s" ${in.shape} requested by the TF graph") - check(stf.dataType == in.scalarType, - s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + - s"of the column (${in.scalarType})") - check(in.isPlaceholder, s"Invalid type for input node ${in.name}. It has to be a placeholder") } @@ -532,7 +535,8 @@ class DebugRowOps val f = col.field builder.append(s"$prefix-- ${f.name}: ${f.dataType.typeName} (nullable = ${f.nullable})") val stf = col.stf.map { s => - s" ${s.dataType.typeName}${s.shape}" + val dt = SupportedOperations.opsFor(s.dataType).sqlType + s" ${dt.typeName}${s.shape}" } .getOrElse(" ") builder.append(stf) builder.append("\n") @@ -725,9 +729,6 @@ object DebugRowOpsImpl extends Logging { } } - // Trying to get around some frequent crashes within TF. - private[this] val tfLock = new Object - private[impl] def reducePair( schema: StructType, gbc: Broadcast[SerializedGraph]): (Row, Row) => Row = { diff --git a/src/main/scala/org/tensorframes/impl/DenseTensor.scala b/src/main/scala/org/tensorframes/impl/DenseTensor.scala index 7414f73..d9e30e2 100644 --- a/src/main/scala/org/tensorframes/impl/DenseTensor.scala +++ b/src/main/scala/org/tensorframes/impl/DenseTensor.scala @@ -17,27 +17,31 @@ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, NumericTy */ private[tensorframes] class DenseTensor private( val shape: Shape, - val dtype: NumericType, + val dtype: ScalarType, private val data: Array[Byte]) { override def toString(): String = s"DenseTensor($shape, $dtype, " + - s"${data.length / dtype.defaultSize} elements)" + s"${data.length} bytes)" } private[tensorframes] object DenseTensor { def apply[T](x: T)(implicit ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape.empty, ops.sqlType, convert(x)) + apply(Shape.empty, ops.sqlType.asInstanceOf[NumericType], convert(x)) } def apply[T](xs: Seq[T])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape(xs.size), ops.sqlType, convert1(xs)) + apply(Shape(xs.size), ops.sqlType.asInstanceOf[NumericType], convert1(xs)) + } + + def apply(shape: Shape, dtype: NumericType, data: Array[Byte]): DenseTensor = { + new DenseTensor(shape, SupportedOperations.opsFor(dtype).scalarType, data) } def matrix[T](xs: Seq[Seq[T]])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape(xs.size, xs.head.size), ops.sqlType, convert2(xs)) + apply(Shape(xs.size, xs.head.size), ops.sqlType.asInstanceOf[NumericType], convert2(xs)) } private def convert[T](x: T)(implicit ev2: TypeTag[T]): Array[Byte] = { @@ -98,15 +102,15 @@ private[tensorframes] object DenseTensor { val shape = Shape.from(proto.getTensorShape) val data = ops.sqlType match { case DoubleType => - val coll = proto.getDoubleValList.asScala.toSeq.map(_.doubleValue()) + val coll = proto.getDoubleValList.asScala.map(_.doubleValue()) convert(coll) case IntegerType => - val coll = proto.getIntValList.asScala.toSeq.map(_.intValue()) + val coll = proto.getIntValList.asScala.map(_.intValue()) convert(coll) case _ => throw new IllegalArgumentException( s"Cannot convert type ${ops.sqlType}") } - new DenseTensor(shape, ops.sqlType, data) + new DenseTensor(shape, ops.scalarType, data) } } diff --git a/src/main/scala/org/tensorframes/impl/TFDataOps.scala b/src/main/scala/org/tensorframes/impl/TFDataOps.scala index 5ed609e..f3c4c6b 100644 --- a/src/main/scala/org/tensorframes/impl/TFDataOps.scala +++ b/src/main/scala/org/tensorframes/impl/TFDataOps.scala @@ -184,7 +184,7 @@ object TFDataOps extends Logging { */ private def getColumn( t: tf.Tensor, - scalaType: NumericType, + scalaType: ScalarType, cellShape: Shape, expectedNumRows: Option[Int], fastPath: Boolean = true): (Int, Iterable[Any]) = { diff --git a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala index 449db0e..f0aca5c 100644 --- a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala +++ b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala @@ -140,10 +140,10 @@ object TensorFlowOps extends Logging { } } - private def getSummaryDefault(op: tf.Operation): Seq[(NumericType, Shape)] = { + private def getSummaryDefault(op: tf.Operation): Seq[(ScalarType, Shape)] = { (0 until op.numOutputs()).map { idx => val n = op.output(idx) - val dt = SupportedOperations.opsFor(n.dataType()).sqlType + val dt = SupportedOperations.opsFor(n.dataType()).scalarType val shape = Shape.from(n.shape()) dt -> shape } @@ -164,6 +164,6 @@ case class GraphNodeSummary( isPlaceholder: Boolean, isInput: Boolean, isOutput: Boolean, - scalarType: NumericType, + scalarType: ScalarType, shape: Shape, name: String) extends Serializable diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 9be2bfa..a6966f8 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -5,7 +5,7 @@ import java.nio._ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.{tensorflow => tf} -import org.tensorflow.framework.DataType +import org.tensorflow.framework.{DataType => ProtoDataType} import org.tensorframes.{Logging, Shape} import scala.collection.mutable.{WrappedArray => MWrappedArray} @@ -18,6 +18,39 @@ import scala.reflect.runtime.universe.TypeTag // - jvm: ??? // - protobuf: ??? +/** + * All the types of scalars supported by TensorFrames. + * + * It can be argued that the Binary type is not really a scalar, + * but it is considered as such by both Spark and TensorFlow. + */ +trait ScalarType + +/** + * Int32 + */ +case object ScalarIntType extends ScalarType + +/** + * INT64 + */ +case object ScalarLongType extends ScalarType + +/** + * FLOAT64 + */ +case object ScalarDoubleType extends ScalarType + +/** + * FLOAT32 + */ +case object ScalarFloatType extends ScalarType + +/** + * STRING / BINARY + */ +case object ScalarBinaryType extends ScalarType + /** * @param shape the shape of the element in the row (not the overall shape of the block) * @param numCells the number of cells that are going to be allocated with the given shape. @@ -27,7 +60,7 @@ import scala.reflect.runtime.universe.TypeTag private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, Float, Int, Long) T] ( val shape: Shape, val numCells: Int) - (implicit ev1: TypeTag[T], ev2: ClassTag[T]) extends Logging { + (implicit ev2: ClassTag[T]) extends Logging { final val empty = Array.empty[T] /** * Creates memory space for a given number of units of the given shape. @@ -79,6 +112,7 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, // The return element is just here so that the method gets specialized (otherwise it would not). final def append(row: Row, position: Int): Array[T] = { + logger.debug(s"append: position=$position row=$row") val d = shape.numDims if (d == 0) { appendRaw(row.getAs[T](position)) @@ -125,17 +159,16 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, * It does not support TF's rich type collection (uint16, float128, etc.). These have to be handled * internally through casting. */ -private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] - (implicit ev1: TypeTag[T], ev2: ClassTag[T]) { +private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] { /** * The SQL type associated with the given type. */ - val sqlType: NumericType + val sqlType: DataType /** * The TF type */ - val tfType: DataType + val tfType: ProtoDataType /** * The TF type (new style). @@ -143,6 +176,11 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int */ val tfType2: tf.DataType + /** + * The type of the scalar value. + */ + val scalarType: ScalarType + /** * A zero element for this type */ @@ -217,25 +255,42 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int res.map { arr => conv(arr.map(conv)) } } - def tag: TypeTag[_] = implicitly[TypeTag[T]] + implicit def classTag: ClassTag[T] = ev + + def tag: Option[TypeTag[_]] + + def ev: ClassTag[T] = null } private[tensorframes] object SupportedOperations { private val ops: Seq[ScalarTypeOperation[_]] = - Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations) + Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations, StringOperations) val sqlTypes = ops.map(_.sqlType) + val scalarTypes = ops.map(_.scalarType) + private val tfTypes = ops.map(_.tfType) - def opsFor(t: NumericType): ScalarTypeOperation[_] = { + def getOps(t: DataType): Option[ScalarTypeOperation[_]] = { + ops.find(_.sqlType == t) + } + + def opsFor(t: DataType): ScalarTypeOperation[_] = { ops.find(_.sqlType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${sqlTypes.mkString(", ")}") } } - def opsFor(t: DataType): ScalarTypeOperation[_] = { + def opsFor(t: ScalarType): ScalarTypeOperation[_] = { + ops.find(_.scalarType == t).getOrElse { + throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + + s"supported: ${sqlTypes.mkString(", ")}") + } + } + + def opsFor(t: ProtoDataType): ScalarTypeOperation[_] = { ops.find(_.tfType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${tfTypes.mkString(", ")}") @@ -252,7 +307,7 @@ private[tensorframes] object SupportedOperations { def getOps[T : TypeTag](): ScalarTypeOperation[T] = { val ev: TypeTag[_] = implicitly[TypeTag[T]] - ops.find(_.tag.tpe =:= ev.tpe).getOrElse { + ops.find(_.tag.map(_.tpe =:= ev.tpe) == Some(true)).getOrElse { val tags = ops.map(_.tag.toString()).mkString(", ") throw new IllegalArgumentException(s"Type ${ev} is not supported. Only the following types " + s"are supported: ${tags}") @@ -299,8 +354,9 @@ private[impl] class DoubleTensorConverter(s: Shape, numCells: Int) private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with Logging { override val sqlType = DoubleType - override val tfType = DataType.DT_DOUBLE + override val tfType = ProtoDataType.DT_DOUBLE override val tfType2 = tf.DataType.DOUBLE + override val scalarType = ScalarDoubleType final override val zero = 0.0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Double] = new DoubleTensorConverter(cellShape, numCells) @@ -325,6 +381,9 @@ private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with L res } + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Double]]) + + override def ev = ClassTag.Double } // ********** FLOAT ************ @@ -358,8 +417,9 @@ private[impl] class FloatTensorConverter(s: Shape, numCells: Int) private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Logging { override val sqlType = FloatType - override val tfType = DataType.DT_FLOAT + override val tfType = ProtoDataType.DT_FLOAT override val tfType2 = tf.DataType.FLOAT + override val scalarType = ScalarFloatType final override val zero = 0.0f override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Float] = new FloatTensorConverter(cellShape, numCells) @@ -381,6 +441,10 @@ private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Log t.writeTo(b) res } + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Float]]) + + override def ev = ClassTag.Float } // ********** INT32 ************ @@ -414,8 +478,9 @@ private[impl] class IntTensorConverter(s: Shape, numCells: Int) private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging { override val sqlType = IntegerType - override val tfType = DataType.DT_INT32 + override val tfType = ProtoDataType.DT_INT32 override val tfType2 = tf.DataType.INT32 + override val scalarType = ScalarIntType final override val zero = 0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Int] = new IntTensorConverter(cellShape, numCells) @@ -434,6 +499,10 @@ private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging dbuff.get(res) res } + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Int]]) + + override def ev = ClassTag.Int } // ****** INT64 (LONG) ****** @@ -467,8 +536,9 @@ private[impl] class LongTensorConverter(s: Shape, numCells: Int) private[impl] object LongOperations extends ScalarTypeOperation[Long] with Logging { override val sqlType = LongType - override val tfType = DataType.DT_INT64 + override val tfType = ProtoDataType.DT_INT64 override val tfType2 = tf.DataType.INT64 + override val scalarType = ScalarLongType final override val zero = 0L override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Long] = new LongTensorConverter(cellShape, numCells) @@ -488,4 +558,67 @@ private[impl] object LongOperations extends ScalarTypeOperation[Long] with Loggi logTrace(s"Extracted from buffer: ${res.toSeq}") res } -} \ No newline at end of file + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Long]]) + + override def ev = ClassTag.Long +} + +// ********** STRING ********* +// This is actually byte arrays, which corresponds to the 'binary' type in Spark. + +// The string converter can only deal with one row at a time (the most common case). +private[impl] class StringTensorConverter(s: Shape, numCells: Int) + extends TensorConverter[Array[Byte]](s, numCells) with Logging { + private var buffer: Array[Byte] = null + + override val elementSize: Int = 1 + + { + logger.debug(s"Creating string buffer for shape $s and $numCells cells") + assert(s == Shape() && numCells == 1, s"The string buffer does not accept more than one" + + s" scalar of type binary. shape=$s numCells=$numCells") + } + + + override def reserve(): Unit = {} + + override def appendRaw(d: Array[Byte]): Unit = { + assert(buffer == null, s"The buffer has only been set with ${buffer.length} values," + + s" but ${d.length} are trying to get inserted") + buffer = d.clone() + } + + override def tensor2(): tf.Tensor = { + tf.Tensor.create(buffer) + } + + override def fillBuffer(buff: ByteBuffer): Unit = { + buff.put(buffer) + } +} + +private[impl] object StringOperations extends ScalarTypeOperation[Array[Byte]] with Logging { + override val sqlType = BinaryType + override val tfType = ProtoDataType.DT_STRING + override val tfType2 = tf.DataType.STRING + override val scalarType = ScalarBinaryType + final override val zero = Array.empty[Byte] + + override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Array[Byte]] = + new StringTensorConverter(cellShape, numCells) + + override def convertTensor(t: tf.Tensor): MWrappedArray[Array[Byte]] = { + throw new Exception(s"convertTensor is not implemented for strings") + } + + override def convertBuffer(buff: ByteBuffer, numElements: Int): Iterable[Any] = { + throw new Exception(s"convertBuffer is not implemented for strings") + } + + override def tag: Option[TypeTag[_]] = None + + override def ev = throw new Exception(s"ev is not implemented for strings") +} + + diff --git a/src/main/scala/org/tensorframes/test/dsl.scala b/src/main/scala/org/tensorframes/test/dsl.scala index 1d14503..fbe4f7a 100644 --- a/src/main/scala/org/tensorframes/test/dsl.scala +++ b/src/main/scala/org/tensorframes/test/dsl.scala @@ -2,10 +2,10 @@ package org.tensorframes.test import java.nio.file.{Files, Paths} -import org.apache.spark.sql.types.{DoubleType, NumericType} -import org.tensorflow.framework._ +import org.apache.spark.sql.types.{DataType, NumericType} +import org.tensorflow.framework.{AttrValue, GraphDef, NodeDef, TensorShapeProto, DataType => ProtoDataType} import org.tensorframes.{Logging, Shape} -import org.tensorframes.impl.{DenseTensor, SupportedOperations} +import org.tensorframes.impl.{DenseTensor, ScalarType, SupportedOperations} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe._ @@ -25,7 +25,7 @@ object dsl extends Logging { def toAttr: AttrValue = buildType(s) } - private implicit class DataTypeToAttr(dt: DataType) { + private implicit class DataTypeToAttr(dt: ProtoDataType) { def toAttr: AttrValue = dataTypeToAttrValue(dt) } @@ -66,8 +66,8 @@ object dsl extends Logging { def +(other: Node): Node = op_add(this, other) } - private[tensorframes] def placeholder(dtype: NumericType, shape: Shape): Node = { - build("Placeholder", shape=shape, dtype=dtype, isOp = false, + private[tensorframes] def placeholder(dtype: DataType, shape: Shape): Node = { + build("Placeholder", shape=shape, dtype=dtype.asInstanceOf[NumericType], isOp = false, extraAttrs = Map("shape" -> shape.toAttr)) } @@ -165,8 +165,9 @@ object dsl extends Logging { private def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) + val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, dtype = dt.dtype, + shape = dt.shape, dtype = dt2, extraAttrs = Map("value" -> a.build())) } @@ -196,7 +197,7 @@ object dsl extends Logging { dtype = parent.scalarType, shape = reduce_shape(parent.shape, Option(reduction_indices).getOrElse(Nil)), extraAttrs = Map( - "Tidx" -> AttrValue.newBuilder().setType(DataType.DT_INT32).build(), + "Tidx" -> AttrValue.newBuilder().setType(ProtoDataType.DT_INT32).build(), "keep_dims" -> AttrValue.newBuilder().setB(false).build())) } @@ -218,13 +219,16 @@ object dsl extends Logging { * Utilities to convert data back and forth between the proto descriptions and the dataframe descriptions. */ object ProtoConversions { - def getDType(nodeDef: NodeDef): DataType = { + def getDType(nodeDef: NodeDef): ProtoDataType = { val opt = Option(nodeDef.getAttr.get("T")).orElse(Option(nodeDef.getAttr.get("dtype"))) val v = opt.getOrElse(throw new Exception(s"Neither 'T' no 'dtype' was found in $nodeDef")) v.getType } - def getDType(sqlType: NumericType): DataType = { + def getDType(sqlType: NumericType): ProtoDataType = { + SupportedOperations.opsFor(sqlType).tfType + } + def getDType(sqlType: ScalarType): ProtoDataType = { SupportedOperations.opsFor(sqlType).tfType } @@ -232,7 +236,7 @@ object ProtoConversions { AttrValue.newBuilder().setType(getDType(sqlType)).build() } - def dataTypeToAttrValue(dataType: DataType): AttrValue = { + def dataTypeToAttrValue(dataType: ProtoDataType): AttrValue = { AttrValue.newBuilder().setType(dataType).build() } diff --git a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala index d92e58e..cd7d50c 100644 --- a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala +++ b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala @@ -3,7 +3,7 @@ package org.tensorframes import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StructType} import org.scalatest.FunSuite -import org.tensorframes.impl.DebugRowOpsImpl +import org.tensorframes.impl.{DebugRowOpsImpl, ScalarDoubleType} import org.tensorframes.dsl._ class DebugRowOpsSuite @@ -14,10 +14,10 @@ class DebugRowOpsSuite testGraph("Simple identity") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = identity(p2) named "y" - val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) @@ -26,10 +26,10 @@ class DebugRowOpsSuite testGraph("Simple add") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = p2 + p2 named "y" - val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) diff --git a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala index 2a0a1a0..b197df7 100644 --- a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala @@ -2,6 +2,7 @@ package org.tensorframes import org.apache.spark.sql.types.{DoubleType, IntegerType} import org.scalatest.FunSuite +import org.tensorframes.impl.{ScalarDoubleType, ScalarIntType} class ExtraOperationsSuite @@ -16,7 +17,7 @@ class ExtraOperationsSuite val di = ExtraOperations.explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + "->" + di.toString) } @@ -26,7 +27,7 @@ class ExtraOperationsSuite val di = explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === IntegerType) + assert(s.dataType === ScalarIntType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + "->" + di.toString) } @@ -37,13 +38,13 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2, c3) = di.cols val Some(s1) = c1.stf - assert(s1.dataType === DoubleType) + assert(s1.dataType === ScalarDoubleType) assert(s1.shape === Shape(Unknown)) val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(Unknown, Unknown)) val Some(s3) = c3.stf - assert(s3.dataType === DoubleType) + assert(s3.dataType === ScalarDoubleType) assert(s3.shape === Shape(Unknown, Unknown, Unknown)) } @@ -54,7 +55,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(1)) // There is only one partition } @@ -65,7 +66,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(Unknown)) // There is only one partition } @@ -78,7 +79,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(2, Unknown)) // There is only one partition } @@ -92,7 +93,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(3, 2)) // There is only one partition } } diff --git a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala index a1680f9..3624e72 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala @@ -2,7 +2,7 @@ package org.tensorframes.perf import org.scalatest.FunSuite import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{SupportedOperations, TFDataOps} +import org.tensorframes.impl.{ScalarIntType, SupportedOperations, TFDataOps} import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -48,9 +48,9 @@ class ConvertBackPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, + val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, Shape(numCells, numVals)))) - val tfSchema = StructType(Seq(ColumnInformation.structField("f2", IntegerType, + val tfSchema = StructType(Seq(ColumnInformation.structField("f2", ScalarIntType, Shape(numCells, numVals)))) val tensor = getTFTensor(IntegerType, Row(Seq.fill(numVals)(1)), Shape(numVals), numCells) println("generated data") diff --git a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala index 556c8f1..b3112e8 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala @@ -5,7 +5,7 @@ import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{DataOps, SupportedOperations, TFDataOps} +import org.tensorframes.impl.{DataOps, ScalarIntType, SupportedOperations, TFDataOps} class ConvertPerformanceSuite extends FunSuite with TensorFramesTestSparkContext with Logging { @@ -44,7 +44,7 @@ class ConvertPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, + val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, Shape(numCells, numVals)))) println("generated data") logInfo("generated data")