36 changes: 21 additions & 15 deletions src/main/scala/org/tensorframes/ColumnInformation.scala
@@ -2,6 +2,8 @@ package org.tensorframes

import org.apache.spark.sql.types._

import org.tensorframes.impl.{ScalarType, SupportedOperations}


class ColumnInformation private (
val field: StructField,
@@ -15,7 +17,9 @@ class ColumnInformation private (
val b = new MetadataBuilder().withMetadata(field.metadata)
for (info <- stf) {
b.putLongArray(shapeKey, info.shape.dims.toArray)
b.putString(tensorStructType, info.dataType.toString)
// Keep the SQL name, so that we do not leak internal details.
val dt = SupportedOperations.opsFor(info.dataType).sqlType
b.putString(tensorStructType, dt.toString)
}
val meta = b.build()
field.copy(metadata = meta)
@@ -73,15 +77,15 @@ object ColumnInformation extends Logging {
* @param scalarType the data type
* @param blockShape the shape of the block
*/
def structField(name: String, scalarType: NumericType, blockShape: Shape): StructField = {
def structField(name: String, scalarType: ScalarType, blockShape: Shape): StructField = {
val i = SparkTFColInfo(blockShape, scalarType)
val f = StructField(name, sqlType(scalarType, blockShape.tail), nullable = false)
ColumnInformation(f, i).merged
}

private def sqlType(scalarType: NumericType, shape: Shape): DataType = {
private def sqlType(scalarType: ScalarType, shape: Shape): DataType = {
if (shape.dims.isEmpty) {
scalarType
SupportedOperations.opsFor(scalarType).sqlType
} else {
ArrayType(sqlType(scalarType, shape.tail), containsNull = false)
}
@@ -102,11 +106,14 @@ object ColumnInformation extends Logging {
for {
s <- shape
t <- tpe
} yield SparkTFColInfo(s, t)
ops <- SupportedOperations.getOps(t)
} yield SparkTFColInfo(s, ops.scalarType)
}

private def getType(s: String): Option[NumericType] = {
supportedTypes.find(_.toString == s)
private def getType(s: String): Option[DataType] = {
val res = supportedTypes.find(_.toString == s)
logInfo(s"getType: $s -> $res")
res
}

/**
@@ -115,19 +122,18 @@
* @return
*/
private def extractFromRow(dt: DataType): Option[SparkTFColInfo] = dt match {
case x: NumericType if MetadataConstants.supportedTypes.contains(dt) =>
logTrace("numerictype: " + x)
// It is a basic type that we understand
Some(SparkTFColInfo(Shape(Unknown), x))
case x: ArrayType =>
logTrace("arraytype: " + x)
// Look into the array to figure out the type.
extractFromRow(x.elementType).map { info =>
SparkTFColInfo(info.shape.prepend(Unknown), info.dataType)
}
case _ =>
logTrace("not understood: " + dt)
// Not understood.
None
case _ => SupportedOperations.getOps(dt) match {
case Some(ops) =>
logTrace("numerictype: " + ops.scalarType)
// It is a basic type that we understand
Some(SparkTFColInfo(Shape(Unknown), ops.scalarType))
case None => None
}
}
}
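The change above writes the public SQL type name, obtained through SupportedOperations.opsFor(...).sqlType, into the column metadata instead of the internal ScalarType, and resolves it back through getType/extractFromRow. A minimal round-trip sketch, assuming the structField and stf signatures shown in this diff; the column name and dimensions are illustrative:

```scala
import org.apache.spark.sql.types.{DoubleType, StructField}
import org.tensorframes.{ColumnInformation, Shape}
import org.tensorframes.impl.SupportedOperations

// Resolve the internal scalar type from the public SQL type.
val doubleScalar = SupportedOperations.opsFor(DoubleType).scalarType

// Declare a block column of 3-element double vectors; the leading
// dimension is the number of rows in the block ("features" is made up).
val field: StructField =
  ColumnInformation.structField("features", doubleScalar, Shape(10, 3))

// Reading the field back recovers the shape and scalar type, while the
// metadata string itself carries the SQL name ("DoubleType"), so no
// internal type name leaks into the schema.
val stf = ColumnInformation(field).stf   // Option[SparkTFColInfo]
stf.foreach(info => println(s"dtype = ${info.dataType}, shape = ${info.shape}"))
```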
7 changes: 4 additions & 3 deletions src/main/scala/org/tensorframes/ExperimentalOperations.scala
@@ -3,7 +3,8 @@ package org.tensorframes
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, DataType, NumericType}
import org.tensorframes.impl.SupportedOperations

import org.tensorframes.impl.{ScalarType, SupportedOperations}

/**
* Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime).
@@ -109,8 +110,8 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with
DataFrameInfo(allInfo)
}

private def extractBasicType(dt: DataType): Option[NumericType] = dt match {
case x: NumericType => Some(x)
private def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
case x: ArrayType => extractBasicType(x.elementType)
case _ => None
}
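The recursion in extractBasicType peels ArrayType wrappers until it finds a numeric element type and then maps it to the internal scalar type. A self-contained sketch of the same pattern (not the private method itself), with illustrative inputs:

```scala
import org.apache.spark.sql.types._
import org.tensorframes.impl.{ScalarType, SupportedOperations}

// Peel ArrayType layers until a numeric element type is found, then map
// it to the internal ScalarType through SupportedOperations.
def basicType(dt: DataType): Option[ScalarType] = dt match {
  case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
  case x: ArrayType   => basicType(x.elementType)
  case _              => None
}

basicType(ArrayType(ArrayType(DoubleType)))  // Some(<double scalar type>)
basicType(StringType)                        // None: not a supported element type
```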
6 changes: 3 additions & 3 deletions src/main/scala/org/tensorframes/MetadataConstants.scala
@@ -1,7 +1,7 @@
package org.tensorframes

import org.apache.spark.sql.types.NumericType
import org.tensorframes.impl.SupportedOperations
import org.apache.spark.sql.types.{DataType, NumericType}
import org.tensorframes.impl.{ScalarType, SupportedOperations}

/**
* Metadata annotations that get embedded in dataframes to express tensor information.
@@ -29,5 +29,5 @@ object MetadataConstants {
/**
* All the SQL types supported by SparkTF.
*/
val supportedTypes: Seq[NumericType] = SupportedOperations.sqlTypes
val supportedTypes: Seq[DataType] = SupportedOperations.sqlTypes
}
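Widening supportedTypes from Seq[NumericType] to Seq[DataType] lets SupportedOperations report non-numeric SQL types as well; judging from the BinaryType import and the SparkTFColInfo comment later in this diff, binary columns are the case being prepared for. A small sketch of inspecting the list, assuming supportedTypes stays public:

```scala
import org.apache.spark.sql.types.DataType
import org.tensorframes.MetadataConstants

// The supported SQL types are now plain DataType values, so the list can
// contain BinaryType alongside the numeric types.
val types: Seq[DataType] = MetadataConstants.supportedTypes
types.foreach(t => println(t.typeName))
```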
16 changes: 12 additions & 4 deletions src/main/scala/org/tensorframes/Shape.scala
@@ -1,9 +1,11 @@
package org.tensorframes

import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.types.{BinaryType, DataType, NumericType}
import org.tensorflow.framework.TensorShapeProto

import scala.collection.JavaConverters._
import org.tensorframes.Shape.DimType
import org.tensorframes.impl.ScalarType
import org.{tensorflow => tf}


@@ -36,6 +38,11 @@ class Shape private (private val ds: Array[DimType]) extends Serializable {

def prepend(x: Int): Shape = Shape(x.toLong +: ds)

/**
* Drops the most inner dimension of the shape.
*/
def dropInner: Shape = Shape(ds.dropRight(1))

/**
* A shape with the first dimension dropped.
*/
@@ -104,15 +111,16 @@ object Shape {

/**
* SparkTF information. This is the information generally required to work on a tensor.
* @param shape
* @param dataType
* @param shape the shape of the column (including the number of rows). May contain some unknowns.
* @param dataType the datatype of the scalar. Note that it is either NumericType or BinaryType.
*/
// TODO(tjh) the types supported by TF are much richer (uint8, etc.) but it is not clear
// if they all map to a Catalyst memory representation
// TODO(tjh) support later basic structures for sparse types?
case class SparkTFColInfo(
shape: Shape,
dataType: NumericType) extends Serializable
dataType: ScalarType) extends Serializable {
}

/**
* Exception thrown when the user requests tensors of high order.
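The new dropInner complements the existing tail and prepend operations: tail removes the leading (row) dimension, while dropInner removes the innermost one. A small sketch, assuming the Shape API shown above; the expected results are given in comments:

```scala
import org.tensorframes.Shape

val cell = Shape(20, 5)   // 20 rows of 5-element vectors

cell.tail         // Shape(5):        the leading (row) dimension is dropped
cell.dropInner    // Shape(20):       the innermost dimension is dropped
cell.prepend(3)   // Shape(3, 20, 5): a new outer dimension is added
```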
11 changes: 7 additions & 4 deletions src/main/scala/org/tensorframes/dsl/DslImpl.scala
@@ -1,13 +1,14 @@
package org.tensorframes.dsl

import javax.annotation.Nullable

import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.NumericType

import org.tensorframes.{Logging, ColumnInformation, Shape}
import org.tensorframes.impl.DenseTensor
import org.tensorframes.{ColumnInformation, Logging, Shape}
import org.tensorframes.impl.{DenseTensor, SupportedOperations}


/**
@@ -75,8 +76,9 @@ private[dsl] object DslImpl extends Logging with DefaultConversions {

def build_constant(dt: DenseTensor): Node = {
val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt))
val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType]
build("Const", isOp = false,
shape = dt.shape, dtype = dt.dtype,
shape = dt.shape, dtype = dt2,
extraAttrs = Map("value" -> a.build()))
}

@@ -100,7 +102,8 @@ private[dsl] object DslImpl extends Logging with DefaultConversions {
s"tensorframes: $schema")
}
val shape = if (block) { stf.shape } else { stf.shape.tail }
DslImpl.placeholder(stf.dataType, shape).named(tfName)
val dt = SupportedOperations.opsFor(stf.dataType).sqlType.asInstanceOf[NumericType]
DslImpl.placeholder(dt, shape).named(tfName)
}

private def commonShape(shapes: Seq[Shape]): Shape = {
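Both build_constant and the analyzed-column placeholder now carry an internal ScalarType but must hand the DSL a NumericType, hence the opsFor(...).sqlType.asInstanceOf[NumericType] pattern above. A one-line sketch of that conversion, factored into a helper whose name is made up for illustration:

```scala
import org.apache.spark.sql.types.NumericType
import org.tensorframes.impl.{ScalarType, SupportedOperations}

// Map the internal scalar type back to its public SQL type; the DSL still
// assumes the result is numeric, hence the cast.
def toNumericSqlType(st: ScalarType): NumericType =
  SupportedOperations.opsFor(st).sqlType.asInstanceOf[NumericType]
```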
6 changes: 2 additions & 4 deletions src/main/scala/org/tensorframes/dsl/package.scala
@@ -1,10 +1,8 @@
package org.tensorframes

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.IntegerType

import org.apache.spark.sql.types.{IntegerType, NumericType}
import org.tensorframes.impl.SupportedOperations

/**
@@ -45,7 +43,7 @@ package object dsl {

def placeholder[T : Numeric : TypeTag](shape: Int*): Operation = {
val ops = SupportedOperations.getOps[T]()
DslImpl.placeholder(ops.sqlType, Shape(shape: _*))
DslImpl.placeholder(ops.sqlType.asInstanceOf[NumericType], Shape(shape: _*))
}

def constant[T : ConvertibleToDenseTensor](x: T): Operation = {
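On the user-facing side, a placeholder's element type is still derived from the Scala type parameter; SupportedOperations.getOps[T]() resolves it, and the resulting SQL type is now cast down to NumericType. A usage sketch with arbitrary dimensions:

```scala
import org.tensorframes.dsl._

// The element type comes from the Scala type parameter, resolved through
// SupportedOperations.getOps[T](); the dimensions below are arbitrary.
val x = placeholder[Double](2, 3)   // 2x3 placeholder of doubles
val n = placeholder[Int](5)         // length-5 placeholder of ints
```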
4 changes: 2 additions & 2 deletions src/main/scala/org/tensorframes/impl/DataOps.scala
@@ -5,7 +5,7 @@ import scala.reflect.ClassTag

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.{NumericType, StructType}
import org.apache.spark.sql.types.StructType

import org.tensorframes.{Logging, Shape}
import org.tensorframes.Shape.DimType
@@ -145,7 +145,7 @@ object DataOps extends Logging {

def getColumnFast0(
reshapeShape: Shape,
scalaType: NumericType,
scalaType: ScalarType,
allDataBuffer: mutable.WrappedArray[_]): Iterable[Any] = {
reshapeShape.dims match {
case Seq() =>
37 changes: 19 additions & 18 deletions src/main/scala/org/tensorframes/impl/DebugRowOps.scala
@@ -1,20 +1,23 @@
package org.tensorframes.impl

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

import org.apache.commons.lang3.SerializationUtils
import org.tensorflow.framework.GraphDef
import org.tensorflow.{Session, Tensor}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row}
import org.tensorflow.framework.GraphDef
import org.tensorflow.{Session, Tensor}

import org.tensorframes._
import org.tensorframes.test.DslOperations

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

/**
* The different schemas required for the block reduction.
@@ -322,17 +325,17 @@ class DebugRowOps
throw new Exception(
s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe")
}
if (! stf.shape.checkMorePreciseThan(in.shape)) {
throw new Exception(
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")
}
// We do not support autocasting for now.
if (stf.dataType != in.scalarType) {
throw new Exception(
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")
}
if (! stf.shape.checkMorePreciseThan(in.shape)) {
throw new Exception(
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")
}
// The input has to be either a constant or a placeholder
if (! in.isPlaceholder) {
throw new Exception(
@@ -414,16 +417,16 @@ class DebugRowOps
val stf = get(ColumnInformation(f).stf,
s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe")

check(stf.dataType == in.scalarType,
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")

val cellShape = stf.shape.tail
// No check for unknowns: we allow unknowns in the first dimension of the cell shape.
check(cellShape.checkMorePreciseThan(in.shape),
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")

check(stf.dataType == in.scalarType,
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")

check(in.isPlaceholder,
s"Invalid type for input node ${in.name}. It has to be a placeholder")
}
@@ -532,7 +535,8 @@ class DebugRowOps
val f = col.field
builder.append(s"$prefix-- ${f.name}: ${f.dataType.typeName} (nullable = ${f.nullable})")
val stf = col.stf.map { s =>
s" ${s.dataType.typeName}${s.shape}"
val dt = SupportedOperations.opsFor(s.dataType).sqlType
s" ${dt.typeName}${s.shape}"
} .getOrElse(" <no tensor info>")
builder.append(stf)
builder.append("\n")
@@ -725,9 +729,6 @@ object DebugRowOpsImpl extends Logging {
}
}

// Trying to get around some frequent crashes within TF.
private[this] val tfLock = new Object

private[impl] def reducePair(
schema: StructType,
gbc: Broadcast[SerializedGraph]): (Row, Row) => Row = {
20 changes: 12 additions & 8 deletions src/main/scala/org/tensorframes/impl/DenseTensor.scala
@@ -17,27 +17,31 @@ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, NumericTy
*/
private[tensorframes] class DenseTensor private(
val shape: Shape,
val dtype: NumericType,
val dtype: ScalarType,
private val data: Array[Byte]) {

override def toString(): String = s"DenseTensor($shape, $dtype, " +
s"${data.length / dtype.defaultSize} elements)"
s"${data.length} bytes)"
}

private[tensorframes] object DenseTensor {
def apply[T](x: T)(implicit ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape.empty, ops.sqlType, convert(x))
apply(Shape.empty, ops.sqlType.asInstanceOf[NumericType], convert(x))
}

def apply[T](xs: Seq[T])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape(xs.size), ops.sqlType, convert1(xs))
apply(Shape(xs.size), ops.sqlType.asInstanceOf[NumericType], convert1(xs))
}

def apply(shape: Shape, dtype: NumericType, data: Array[Byte]): DenseTensor = {
new DenseTensor(shape, SupportedOperations.opsFor(dtype).scalarType, data)
}

def matrix[T](xs: Seq[Seq[T]])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape(xs.size, xs.head.size), ops.sqlType, convert2(xs))
apply(Shape(xs.size, xs.head.size), ops.sqlType.asInstanceOf[NumericType], convert2(xs))
}

private def convert[T](x: T)(implicit ev2: TypeTag[T]): Array[Byte] = {
@@ -98,15 +102,15 @@ private[tensorframes] object DenseTensor {
val shape = Shape.from(proto.getTensorShape)
val data = ops.sqlType match {
case DoubleType =>
val coll = proto.getDoubleValList.asScala.toSeq.map(_.doubleValue())
val coll = proto.getDoubleValList.asScala.map(_.doubleValue())
convert(coll)
case IntegerType =>
val coll = proto.getIntValList.asScala.toSeq.map(_.intValue())
val coll = proto.getIntValList.asScala.map(_.intValue())
convert(coll)
case _ =>
throw new IllegalArgumentException(
s"Cannot convert type ${ops.sqlType}")
}
new DenseTensor(shape, ops.sqlType, data)
new DenseTensor(shape, ops.scalarType, data)
}
}
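With the private constructor now storing a ScalarType, the public builders keep their NumericType-based signatures by going through the new apply(shape, dtype, data) overload, and toString reports the raw payload size instead of an element count. A construction sketch; note that DenseTensor is private[tensorframes], so this only compiles inside that package:

```scala
import org.tensorframes.impl.DenseTensor

// Element types are resolved from the Scala types via SupportedOperations.
val scalar = DenseTensor(3.14)                       // rank-0 tensor of doubles
val matrix = DenseTensor.matrix(Seq(Seq(1.0, 2.0),
                                    Seq(3.0, 4.0)))  // 2x2 tensor of doubles

// Prints something like: DenseTensor(<2x2 shape>, <double type>, 32 bytes)
println(matrix)
```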