Initial Spark Dataset API spec #18

Closed (2 commits)
@@ -113,8 +113,8 @@ private[sql] object DataFrame {
// TODO: Improve documentation.
@Experimental
class DataFrame private[sql](
@transient val sqlContext: SQLContext,
@DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable {
sqlContext: SQLContext,
queryExecution: QueryExecution) extends Dataset[Row](sqlContext, queryExecution)(new RowEncoder(queryExecution.analyzed.schema)) with Serializable {

// Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure
// you wrap it with `withNewExecutionId` if this action doesn't call other actions.
@@ -1358,63 +1358,21 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
def first(): Row = head()

/**
* Returns a new RDD by applying a function to all rows of this DataFrame.
* @group rdd
* @since 1.3.0
*/
def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)

/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
* @group rdd
* @since 1.3.0
*/
def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)

/**
* Returns a new RDD by applying a function to each partition of this DataFrame.
* @group rdd
* @since 1.3.0
*/
def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
rdd.mapPartitions(f)
}

/**
* Applies a function `f` to all rows.
* @group rdd
* @since 1.3.0
*/
def foreach(f: Row => Unit): Unit = withNewExecutionId {
rdd.foreach(f)
}

/**
* Applies a function f to each partition of this [[DataFrame]].
* @group rdd
* @since 1.3.0
*/
def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
rdd.foreachPartition(f)
}
override def first(): Row = head()

/**
* Returns the first `n` rows in the [[DataFrame]].
* @group action
* @since 1.3.0
*/
def take(n: Int): Array[Row] = head(n)
override def take(n: Int): Array[Row] = head(n)

/**
* Returns an array that contains all [[Row]]s in this [[DataFrame]].
* @group action
* @since 1.3.0
*/
def collect(): Array[Row] = withNewExecutionId {
override def collect(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollect()
}

@@ -1461,7 +1419,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
def distinct(): DataFrame = dropDuplicates()
override def distinct(): DataFrame = dropDuplicates()

/**
* @group basic
@@ -0,0 +1,165 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.{Iterator => JIterator}

import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.QueryExecution

import scala.reflect.ClassTag

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}

import scala.reflect.runtime.universe._

/**
* A typed collection of data stored efficiently as a DataFrame.
*/
@Experimental
class Dataset[T] private[sql](
@transient val sqlContext: SQLContext,
@transient val queryExecution: QueryExecution)(
implicit val encoder: Encoder[T]) extends Serializable {

/**
* Returns a new `Dataset` where each record has been mapped onto the specified type.
*/
def as[U : Encoder] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])

/**
* Concise syntax for chaining custom transformations.
* {{{
* dataset
* .transform(featurize)
* .transform(...)
* }}}
*/
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = ???

johnynek (Nov 18, 2015):

how is this implementation to be different from t(this)?

marmbrus (author, Nov 18, 2015):

That's it. It's sugar so you can write:

dataset
  .where(...)
  .transform(featurize)
  .transform(tokenize)

instead of tokenize(featurize(ds.where(...))).
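
For reference, a minimal sketch of what that sugar amounts to (the spec above leaves the body as ???, but per this exchange it is plain function application):

  // ds.transform(f) is just f(ds), spelled so it chains with other Dataset calls.
  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)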


/**
* Returns a new `Dataset` that only contains elements where `func` returns `true`.
*/
def filter(func: T => Boolean): Dataset[T] = ???

// TODO: Create custom function to avoid boxing.
def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = ???

def filterNot(func: T => Boolean): Dataset[T] = ???

/**
* Returns a new Dataset that contains the result of applying `func` to each element.
*/
def map[U : Encoder](func: T => U): Dataset[U] = ???


def map[U](func: JFunction[T, U], uEncoder: Encoder[U]): Dataset[U] = ???

/**
* A version of map for Java that tries to infer the encoder using reflection. Note this needs
* a different name; otherwise it seems to break Scala type inference with the following error:
* "The argument types of an anonymous function must be fully known. (SLS 8.5)"
*/
def mapInfer[U](func: JFunction[T, U]): Dataset[U] = ???

def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = ???

def flatMap[U](func: JFunction[T, JIterator[U]], encoder: Encoder[U]): Dataset[U] = ???

/*****************
* Side effects *
****************/

def foreach(f: T => Unit): Unit = ???

def foreachPartition(f: Iterator[T] => Unit): Unit = ???

/*****************
* aggregation *
****************/

def reduce(f: (T, T) => T): T = ???

marmbrus (author, Oct 6, 2015):

reduce option? (i.e. should this return an Option to handle the empty Dataset case?)
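
A hypothetical alternative signature illustrating that question (not part of the spec):

  // Returns None for an empty Dataset instead of throwing.
  def reduceOption(f: (T, T) => T): Option[T] = ???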


def fold(zeroValue: T)(op: (T, T) => T): T = ???

def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = new GroupedDataset(this, func)

def groupBy(cols: Column*): GroupedDataset[Row, T] = ???

def groupBy(col1: String, cols: String*): GroupedDataset[Row, T] = ???


/*****************
* joins *
****************/
def join[U : Encoder](other: Dataset[U]): Dataset[(T, U)] = ???

def join[U : Encoder, K : Encoder](other: Dataset[U], leftKey: T => K, rightKey: U => K) = ???

johnynek (Nov 18, 2015):

what's the return type here?

marmbrus (author, Nov 18, 2015):

It probably would have been Dataset[(T, U)] but we dropped this.


def join[U : Encoder](other: Dataset[U], condition: Column): Dataset[(T, U)] = ???

/*****************
* Set operations *
****************/

def distinct: Dataset[T] = ???

def intersect(other: Dataset[T]): Dataset[T] = ???

def union(other: Dataset[T]): Dataset[T] = ???

def subtract(other: Dataset[T]): Dataset[T] = ???

/*****************
* actions *
****************/

def first(): T = ???
def collect(): Array[T] = ???
def take(num: Int): Array[T] = ???
}

trait Aggregator[T]

johnynek (Nov 18, 2015):

what is the purpose of this empty trait?


class GroupedDataset[K : Encoder, T](dataset: Dataset[T], keyFunction: T => K) extends Serializable {

/** Specify a new encoder for the key part of this [[GroupedDataset]]. */
def asKey[L : Encoder]: GroupedDataset[L, T] = ???
/** Specify a new encoder for the value part of this [[GroupedDataset]]. */
def asValue[U : Encoder]: GroupedDataset[K, U] = ???

def keys: Dataset[K] = ???

def agg[U1](agg: Aggregator[U1]): Dataset[(K, U1)] = ???
def agg[U1, U2](agg: Aggregator[U1], agg2: Aggregator[U2]): Dataset[(K, U1, U2)] = ???
// ... more agg functions

def join[U](other: GroupedDataset[K, U]): Dataset[Pair[T, U]] = ???

def cogroup[U, R : Encoder](other: GroupedDataset[K, U])(f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = ???

def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = ???

def countByKey: Dataset[(K, Long)] = ???
}
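
To give a feel for the proposed surface, here is a hypothetical end-to-end usage sketch against the signatures above. Person, the sample data, and the implicit encoders (for Person, String, Int, and (Int, Long)) are assumptions for illustration; construction uses the .ds conversion proposed in the SQLImplicits change further down.

  import sqlContext.implicits._   // assumed to provide the .ds conversion and encoders

  case class Person(name: String, age: Int)

  val people: Dataset[Person] = Seq(Person("Ann", 30), Person("Bob", 25), Person("Cal", 12)).ds

  val adults: Dataset[Person]            = people.filter(_.age >= 18)
  val names: Dataset[String]             = people.map(_.name)
  val byAge: GroupedDataset[Int, Person] = people.groupBy(_.age)
  val counts: Dataset[(Int, Long)]       = byAge.countByKey
  val oldest: Person                     = people.reduce((a, b) => if (a.age >= b.age) a else b)
  val firstTwo: Array[Person]            = people.take(2)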
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.lang.reflect.{ParameterizedType, Constructor}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeRow}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.UnsafeRowWriters.UTF8StringWriter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.language.implicitConversions
import scala.reflect._
import scala.reflect.runtime.universe._

object Encoder {
def fromType(genericTypes: Seq[java.lang.reflect.Type]): Encoder[_] = ???
}

/**
* Captures how to encode JVM objects as Spark SQL rows.
* TODO: Make unsafe row?

johnynek (Nov 18, 2015):

seems like you might want macros to generate these for case classes and primitives.

marmbrus (author, Nov 18, 2015):

In the implementation we are doing this using Janino at runtime. Compared with macros, this makes it easier for us to keep binary compatibility across different versions of Spark.

*/
@Experimental
trait Encoder[T] extends Serializable {
def schema: StructType

def fromRow(row: InternalRow): T

def toRow(value: T): InternalRow

def classTag: ClassTag[T]

johnynek (Nov 18, 2015):

why ClassTag[T] here rather than Class[T]?

marmbrus (author, Nov 18, 2015):

No strong reason (and we plan to keep this internal). It's slightly easier to create arrays this way because Scala does the magic for you, and RDDs also understand ClassTags.


// TODO: Use attribute references
def bind(ordinals: Seq[Int]): Encoder[T]
}
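
To make the contract concrete, here is a hypothetical hand-written encoder for a single Int column. It is a sketch only, not part of the spec: real encoders are generated at runtime (see the Janino discussion above), and it assumes catalyst's GenericInternalRow is available for building internal rows.

  // Illustration only; needs: import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
  class IntEncoder extends Encoder[Int] {
    override def schema: StructType = StructType(StructField("value", IntegerType) :: Nil)

    override def toRow(value: Int): InternalRow = new GenericInternalRow(Array[Any](value))

    override def fromRow(row: InternalRow): Int = row.getInt(0)

    override def classTag: ClassTag[Int] = scala.reflect.classTag[Int]

    // A single-column encoder can ignore rebinding; multi-column encoders would remap ordinals.
    override def bind(ordinals: Seq[Int]): Encoder[Int] = this
  }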

object ProductEncoder {
def apply[T <: Product : TypeTag] = ???

def tuple[T1, T2](t1: Class[T1], t2: Class[T2]): ProductEncoder[(T1, T2)] = ???
def tuple[T1, T2, T3](t1: Class[T1], t2: Class[T2], t3: Class[T3]): ProductEncoder[(T1, T2, T3)] = ???

}

class ProductEncoder[T <: Product]

/**
* Represents a pair of objects that are encoded as a flat row. Pairs are created to facilitate
* operations that calculate a grouping key, such as joins or aggregations.
*/
class Pair[L, R](val left: L, val right: R)

johnynek (Nov 18, 2015):

why this over Tuple2?

marmbrus (author, Nov 18, 2015):

We were trying to have something that, when used as a DataFrame, would have a "flat" relational representation of the join, unlike tuples, which create structs when they have a class in a given field. We decided against this, however, and are just using tuples. See https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L497


class PairEncoder[T, U](left: Encoder[T], right: Encoder[U]) extends Encoder[Pair[T, U]] {
override def schema: StructType = ???

override def fromRow(row: InternalRow): Pair[T, U] = ???

override def classTag: ClassTag[Pair[T, U]] = ???

override def bind(ordinals: Seq[Int]): Encoder[Pair[T, U]] = ???

override def toRow(value: Pair[T, U]): InternalRow = ???
}
@@ -68,7 +68,7 @@ private[sql] object GroupedData {
class GroupedData protected[sql](
df: DataFrame,
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {
private val groupType: GroupedData.GroupType) extends GroupedDataset(df, identity[Row])(new RowEncoder(df.schema)) {

private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
@@ -33,6 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[sql] abstract class SQLImplicits {
protected def _sqlContext: SQLContext


/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
@@ -56,6 +57,16 @@ private[sql] abstract class SQLImplicits {
DataFrameHolder(_sqlContext.createDataFrame(data))
}

implicit class DataSeq[A : Encoder](data: Seq[A]) {
  def ds: Dataset[A] = {
    val enc = implicitly[Encoder[A]]
    val encoded = data.map { d => enc.toRow(d).copy() }
    val df = _sqlContext.internalCreateDataFrame(
      _sqlContext.sparkContext.parallelize(encoded), enc.schema)
    // Wrap the resulting plan in a Dataset, reusing the implicit Encoder[A] from the context bound.
    new Dataset[A](_sqlContext, df.queryExecution)
  }
}
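
A hypothetical use of the DataSeq conversion above (the case class and its implicit Encoder[Word] are assumptions for illustration):

  import sqlContext.implicits._

  case class Word(text: String)
  // Requires an implicit Encoder[Word] in scope.
  val words: Dataset[Word] = Seq(Word("spark"), Word("dataset")).ds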

// Do NOT add more implicit conversions. They are likely to break source compatibility by
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
// because of [[DoubleRDDFunctions]].