diff --git a/src/main/python/tensorframes/core.py b/src/main/python/tensorframes/core.py
index 251e74e..a9da016 100644
--- a/src/main/python/tensorframes/core.py
+++ b/src/main/python/tensorframes/core.py
@@ -9,7 +9,8 @@
 from pyspark.sql.types import DoubleType, IntegerType, LongType, FloatType, ArrayType
 
 __all__ = ['reduce_rows', 'map_rows', 'reduce_blocks', 'map_blocks',
-           'analyze', 'print_schema', 'aggregate', 'block', 'row']
+           'analyze', 'print_schema', 'aggregate', 'block', 'row',
+           'append_shape']
 
 _sc = None
 _sql = None
@@ -377,6 +378,26 @@ def analyze(dframe):
     """
     return DataFrame(_java_api().analyze(dframe._jdf), _sql)
 
+def append_shape(dframe, col, shape):
+    """Appends extra metadata to a dataframe column that
+    describes the numerical shape of its content.
+
+    This method is useful when a dataframe contains non-scalar tensors, whose shapes must be known in advance.
+    The user is responsible for providing the right shape: any mismatch will eventually trigger an exception in Spark.
+
+    Note: nullable fields are not accepted.
+
+    The function [print_schema] lets users inspect the information added to the DataFrame.
+
+    :param dframe: a Spark DataFrame
+    :param col: a Column expression
+    :param shape: the shape of the tensor, where None or -1 denotes an unknown
+        dimension (see https://www.tensorflow.org/programmers_guide/tensors#shape)
+    :return: a Spark DataFrame with the metadata information embedded.
+    """
+    shape = [-1 if i is None else i for i in shape]
+    return DataFrame(_java_api().appendShape(dframe._jdf, col._jc, shape), _sql)
+
 def aggregate(fetches, grouped_data, initial_variables=_initial_variables_default):
     """
     Performs an algebraic aggregation on the grouped data.
diff --git a/src/main/python/tensorframes/core_test.py b/src/main/python/tensorframes/core_test.py
index a5ea321..21608a9 100644
--- a/src/main/python/tensorframes/core_test.py
+++ b/src/main/python/tensorframes/core_test.py
@@ -4,6 +4,7 @@
 from pyspark import SparkContext
 from pyspark.sql import DataFrame, SQLContext
 from pyspark.sql import Row
+from pyspark.sql.functions import col
 
 import tensorflow as tf
 import pandas as pd
@@ -198,6 +199,20 @@ def test_reduce_rows_1(self):
         res = tfs.reduce_rows(x, df)
         assert res == sum([r.x for r in data])
 
+    def test_append_shape(self):
+        data = [Row(x=float(x)) for x in range(5)]
+        df = self.sql.createDataFrame(data)
+        ddf = tfs.append_shape(df, col('x'), [-1])
+        with tf.Graph().as_default():
+            # The placeholders that correspond to column 'x'
+            x_1 = tf.placeholder(tf.double, shape=[], name="x_1")
+            x_2 = tf.placeholder(tf.double, shape=[], name="x_2")
+            # The output that adds the two placeholders
+            x = tf.add(x_1, x_2, name='x')
+            # The resulting number
+            res = tfs.reduce_rows(x, ddf)
+        assert res == sum([r.x for r in data])
+
     # This test fails
     def test_reduce_blocks_1(self):
         data = [Row(x=float(x)) for x in range(5)]
diff --git a/src/main/scala/org/tensorframes/ExperimentalOperations.scala b/src/main/scala/org/tensorframes/ExperimentalOperations.scala
index 57ae8d1..347c093 100644
--- a/src/main/scala/org/tensorframes/ExperimentalOperations.scala
+++ b/src/main/scala/org/tensorframes/ExperimentalOperations.scala
@@ -1,10 +1,14 @@
 package org.tensorframes
 
+import java.util
+
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{ArrayType, DataType, NumericType}
-
+import org.apache.spark.sql.types.{ArrayType, DataType, MetadataBuilder, NumericType}
 import org.tensorframes.impl.{ScalarType, SupportedOperations}
+import scala.collection.JavaConverters._
+
 
 /**
  * Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime).
@@ -45,6 +49,23 @@
     }
     df.select(cols: _*)
   }
+
+  /** Embeds the given tensor shape in the metadata of a dataframe column. */
+  def appendShape(df: DataFrame, col: Column, shape: Array[Int]): DataFrame = {
+    val meta = new MetadataBuilder
+    val colDataType = df.select(col).schema.fields.head.dataType
+    val basicDataType = ExtraOperations.extractBasicType(colDataType).getOrElse(
+      throw new Exception(s"Column type '$colDataType' is not supported"))
+    // Record both the SQL type and the tensor shape in the column metadata.
+    meta.putString(MetadataConstants.tensorStructType,
+      SupportedOperations.opsFor(basicDataType).sqlType.toString)
+    meta.putLongArray(MetadataConstants.shapeKey, shape.map(_.toLong))
+    df.withColumn(col.toString(), col.as("", meta.build()))
+  }
+
+  // Overload for Python callers, which pass the shape as a java.util.ArrayList.
+  def appendShape(df: DataFrame, col: Column, shape: util.ArrayList[Int]): DataFrame =
+    appendShape(df, col, shape.asScala.toArray[Int])
 }
 
 private[tensorframes] object ExtraOperations extends ExperimentalOperations with Logging {
@@ -110,7 +131,7 @@
     DataFrameInfo(allInfo)
   }
 
-  private def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
+  def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
     case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
     case x: ArrayType => extractBasicType(x.elementType)
     case _ => None
diff --git a/src/test/scala/org/tensorframes/BasicOperationsSuite.scala b/src/test/scala/org/tensorframes/BasicOperationsSuite.scala
index 5b1a218..e5e68eb 100644
--- a/src/test/scala/org/tensorframes/BasicOperationsSuite.scala
+++ b/src/test/scala/org/tensorframes/BasicOperationsSuite.scala
@@ -3,6 +3,7 @@ package org.tensorframes
 import org.scalatest.FunSuite
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
 
 import org.tensorframes.dsl.Implicits._
 import org.tensorframes.dsl._
 
@@ -42,6 +43,15 @@ class BasicOperationsSuite
     compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
   }
 
+  testGraph("Identity - 1 dim, manual shape") {
+    val df = make1(Seq(Seq(1.0), Seq(2.0)), "in")
+    val adf = ops.appendShape(df, col("in"), Array(-1, 1))
+    val p1 = placeholder[Double](Unknown, 1) named "in"
+    val out = identity(p1) named "out"
+    val df2 = adf.mapBlocks(out).select("in", "out")
+    compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
+  }
+
   testGraph("Simple add - 1 dim") {
     val a = placeholder[Double](Unknown, 1) named "a"
     val b = placeholder[Double](Unknown, 1) named "b"
@@ -57,6 +67,23 @@ class BasicOperationsSuite
     Row(Seq(2.0), Seq(2.2), Seq(4.2))))
   }
 
+  testGraph("Simple add - 1 dim, manual shape") {
+    val a = placeholder[Double](Unknown, 1) named "a"
+    val b = placeholder[Double](Unknown, 1) named "b"
+    val out = a + b named "out"
+
+    val df = sql.createDataFrame(Seq(
+      Seq(1.0) -> Seq(1.1),
+      Seq(2.0) -> Seq(2.2))).toDF("a", "b")
+    val adf = ops.appendShape(
+      ops.appendShape(df, col("a"), Array(-1, 1)),
+      col("b"), Array(-1, 1))
+    val df2 = adf.mapBlocks(out).select("a", "b", "out")
+    compareRows(df2.collect(), Array(
+      Row(Seq(1.0), Seq(1.1), Seq(2.1)),
+      Row(Seq(2.0), Seq(2.2), Seq(4.2))))
+  }
+
   testGraph("Reduce - sum double") {
     val df = make1(Seq(1.0, 2.0), "x")
     val x1 = placeholder[Double]() named "x_1"
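
Usage sketch (not part of the patch): a minimal example of the new append_shape
API, assuming an already-initialized SQLContext bound to a variable named sql, as
in core_test.py above; the column name 'y' and the shapes are illustrative only.

    from pyspark.sql import Row
    from pyspark.sql.functions import col
    import tensorframes as tfs

    # A column of length-2 vectors; the leading (row) dimension is unknown.
    df = sql.createDataFrame([Row(y=[float(i), float(i)]) for i in range(5)])
    # Embed the shape in the column metadata; None (or -1) marks an unknown dimension.
    ddf = tfs.append_shape(df, col('y'), [None, 2])
    # The shape metadata attached to column 'y' is now visible in the schema.
    tfs.print_schema(ddf)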