23 changes: 22 additions & 1 deletion src/main/python/tensorframes/core.py
@@ -9,7 +9,8 @@
from pyspark.sql.types import DoubleType, IntegerType, LongType, FloatType, ArrayType

__all__ = ['reduce_rows', 'map_rows', 'reduce_blocks', 'map_blocks',
'analyze', 'print_schema', 'aggregate', 'block', 'row']
'analyze', 'print_schema', 'aggregate', 'block', 'row',
'append_shape']

_sc = None
_sql = None
@@ -377,6 +378,26 @@ def analyze(dframe):
"""
return DataFrame(_java_api().analyze(dframe._jdf), _sql)

def append_shape(dframe, col, shape):
"""Append to a dataframe extra metadata that
describes the numerical shape of the content.

This method is useful when a dataframe contains non-scalar tensors, for which the shape must be known
beforehand. The user is responsible for providing the correct shape; any mismatch will eventually trigger
an exception in Spark.

Note: nullable fields are not accepted.

The function [print_schema] lets users introspect the information added to the DataFrame.

:param dframe: a Spark DataFrame
:param col: a Column expression
:param shape: the shape of the tensor in each cell; unknown dimensions may be given as None or -1
(see https://www.tensorflow.org/programmers_guide/tensors#shape for details)
:return: a Spark DataFrame with the metadata information embedded.
"""
shape = [i or -1 for i in shape]  # normalize None (unknown dimension) to -1 for the JVM side
return DataFrame(_java_api().appendShape(dframe._jdf, col._jc, shape), _sql)

def aggregate(fetches, grouped_data, initial_variables=_initial_variables_default):
"""
Performs an algebraic aggregation on the grouped data.
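As a usage illustration (a minimal sketch, not part of the diff: it assumes an active SQLContext bound to `sql` and `import tensorframes as tfs`), `append_shape` declares the shape of a vector column up front, and `print_schema` shows the metadata that was attached:

from pyspark.sql import Row
from pyspark.sql.functions import col

df = sql.createDataFrame([Row(y=[float(i), float(i)]) for i in range(5)])
# Each cell of 'y' holds a length-2 vector; None (mapped to -1) leaves the
# leading row dimension unconstrained.
ddf = tfs.append_shape(df, col('y'), [None, 2])
tfs.print_schema(ddf)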
15 changes: 15 additions & 0 deletions src/main/python/tensorframes/core_test.py
@@ -4,6 +4,7 @@
from pyspark import SparkContext
from pyspark.sql import DataFrame, SQLContext
from pyspark.sql import Row
from pyspark.sql.functions import col
import tensorflow as tf
import pandas as pd

@@ -198,6 +199,20 @@ def test_reduce_rows_1(self):
res = tfs.reduce_rows(x, df)
assert res == sum([r.x for r in data])

def test_append_shape(self):
data = [Row(x=float(x)) for x in range(5)]
df = self.sql.createDataFrame(data)
ddf = tfs.append_shape(df, col('x'), [-1])
with tf.Graph().as_default():
# The two placeholders that correspond to pairs of elements from column 'x'
x_1 = tf.placeholder(tf.double, shape=[], name="x_1")
x_2 = tf.placeholder(tf.double, shape=[], name="x_2")
# The output that adds the two placeholders
x = tf.add(x_1, x_2, name='x')
# The resulting number
res = tfs.reduce_rows(x, ddf)
assert res == sum([r.x for r in data])

# This test fails
def test_reduce_blocks_1(self):
data = [Row(x=float(x)) for x in range(5)]
27 changes: 24 additions & 3 deletions src/main/scala/org/tensorframes/ExperimentalOperations.scala
@@ -1,10 +1,14 @@
package org.tensorframes

import java.util
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, DataType, NumericType}

import org.apache.spark.sql.types.{ArrayType, DataType, MetadataBuilder, NumericType}
import org.tensorframes.impl.{ScalarType, SupportedOperations}
import scala.collection.JavaConverters._

/**
* Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime).
@@ -45,6 +49,23 @@ trait ExperimentalOperations {
}
df.select(cols: _*)
}

def appendShape(df: DataFrame, col: Column, shape: Array[Int]): DataFrame = {
val meta = new MetadataBuilder
// The data type of the target column (arrays are unwrapped to their scalar element type).
val colDtype = df.select(col).schema.fields.head.dataType
val basicDatatype = ExtraOperations.extractBasicType(colDtype).getOrElse(
throw new Exception(s"'$colDtype' is not a supported data type"))
meta.putString(MetadataConstants.tensorStructType,
SupportedOperations.opsFor(basicDatatype).sqlType.toString)
// A dimension of -1 denotes an unknown size.
meta.putLongArray(MetadataConstants.shapeKey, shape.map(_.toLong))
df.withColumn(col.toString(), col.as("", meta.build()))
}

def appendShape(df: DataFrame, col: Column, shape: util.ArrayList[Int]): DataFrame =
appendShape(df, col, shape.asScala.toArray[Int])
}

private[tensorframes] object ExtraOperations extends ExperimentalOperations with Logging {
@@ -110,7 +131,7 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with
DataFrameInfo(allInfo)
}

private def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
case x: ArrayType => extractBasicType(x.elementType)
case _ => None
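To see what `appendShape` actually writes: the metadata lands on the column's `StructField` and can also be inspected from PySpark (a one-line sketch reusing the `ddf` built in the Python example above; the exact key names come from `MetadataConstants` and should be treated as an implementation detail):

# Shows the recorded SQL type and the shape array, with unknown
# dimensions normalized to -1 ([-1, 2] here); key names may differ.
print(ddf.schema['y'].metadata)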
29 changes: 29 additions & 0 deletions src/test/scala/org/tensorframes/BasicOperationsSuite.scala
@@ -3,6 +3,7 @@ package org.tensorframes
import org.scalatest.FunSuite

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

import org.tensorframes.dsl.Implicits._
import org.tensorframes.dsl._
@@ -42,6 +43,15 @@
compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
}

testGraph("Identity - 1 dim, Manually") {
val df = make1(Seq(Seq(1.0), Seq(2.0)), "in")
val adf = ops.appendShape(df, col("in"), Array(-1, 1))
val p1 = placeholder[Double](Unknown, 1) named "in"
val out = identity(p1) named "out"
val df2 = adf.mapBlocks(out).select("in", "out")
compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
}

testGraph("Simple add - 1 dim") {
val a = placeholder[Double](Unknown, 1) named "a"
val b = placeholder[Double](Unknown, 1) named "b"
@@ -57,6 +67,25 @@
Row(Seq(2.0), Seq(2.2), Seq(4.2))))
}

testGraph("Simple add - 1 dim, Manually") {
val a = placeholder[Double](Unknown, 1) named "a"
val b = placeholder[Double](Unknown, 1) named "b"
val out = a + b named "out"

val df = sql.createDataFrame(Seq(
Seq(1.0) -> Seq(1.1),
Seq(2.0) -> Seq(2.2))).toDF("a", "b")
val adf = {
ops.appendShape(
ops.appendShape(df, col("a"), Array(-1, 1)),
col("b"), Array(-1, 1))
}
val df2 = adf.mapBlocks(out).select("a", "b", "out")
compareRows(df2.collect(), Array(
Row(Seq(1.0), Seq(1.1), Seq(2.1)),
Row(Seq(2.0), Seq(2.2), Seq(4.2))))
}

testGraph("Reduce - sum double") {
val df = make1(Seq(1.0, 2.0), "x")
val x1 = placeholder[Double]() named "x_1"