[SPARK-5445][SQL] Made DataFrame dsl usable in Java
Also removed the implicit conversion from values to Literal expressions, since it is pretty scary for API design. Instead, created a new lit method for creating literals. This doesn't break anything from a compatibility perspective, because Literal was only added two days ago.
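
As a rough sketch of the resulting style (assuming a DataFrame df with a numeric column named "age"; the import matches the Scala package introduced in this commit):

  import org.apache.spark.sql.api.scala.dsl._

  // The removed implicit used to turn 10 into a Literal automatically;
  // literals are now wrapped explicitly with lit(...).
  df.select(df("age") + lit(10))

  // Java code can call the equivalent static method instead, e.g.
  // org.apache.spark.sql.api.java.dsl.lit(10).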

Author: Reynold Xin <rxin@databricks.com>

Closes apache#4241 from rxin/df-docupdate and squashes the following commits:

c0f4810 [Reynold Xin] Fix Python merge conflict.
094c7d7 [Reynold Xin] Minor style fix. Reset Python tests.
3c89f4a [Reynold Xin] Package.
dfe6962 [Reynold Xin] Updated Python aggregate.
5dd4265 [Reynold Xin] Made dsl Java callable.
14b3c27 [Reynold Xin] Fix literal expression for symbols.
68b31cb [Reynold Xin] Literal.
4cfeb78 [Reynold Xin] [SPARK-5097][SQL] Address DataFrame code review feedback.
rxin committed Jan 29, 2015
1 parent 4ee79c7 commit 5b9760d
Showing 28 changed files with 325 additions and 299 deletions.
@@ -19,8 +19,7 @@ package org.apache.spark.examples.sql

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.dsl._
-import org.apache.spark.sql.dsl.literals._
+import org.apache.spark.sql.api.scala.dsl._

 // One method for defining the schema of an RDD is to make a case class with the desired column
 // names and types.
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql._
-import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._

 /**
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
-import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
-import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}

@@ -30,7 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
38 changes: 22 additions & 16 deletions python/pyspark/sql.py
@@ -931,7 +931,7 @@ def _parse_schema_abstract(s):

 def _infer_schema_type(obj, dataType):
     """
-    Fill the dataType with types infered from obj
+    Fill the dataType with types inferred from obj

     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
@@ -2140,7 +2140,7 @@ def __getattr__(self, name):
             return Column(self._jdf.apply(name))
         raise AttributeError

-    def As(self, name):
+    def alias(self, name):
         """ Alias the current DataFrame """
         return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)

@@ -2216,7 +2216,7 @@ def intersect(self, other):
         """
         return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)

-    def Except(self, other):
+    def subtract(self, other):
         """ Return a new [[DataFrame]] containing rows in this frame
             but not in another frame.
@@ -2234,7 +2234,7 @@ def sample(self, withReplacement, fraction, seed=None):

     def addColumn(self, colName, col):
         """ Return a new [[DataFrame]] by adding a column. """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))

     def removeColumn(self, colName):
         raise NotImplemented
@@ -2342,7 +2342,7 @@ def sum(self):

 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Literal.apply(literal)
+    return sc._jvm.org.apache.spark.sql.api.java.dsl.lit(literal)


 def _create_column_from_name(name):
@@ -2371,13 +2371,20 @@ def _(self):
     return _


-def _bin_op(name):
-    """ Create a method for given binary operator """
+def _bin_op(name, pass_literal_through=False):
+    """ Create a method for given binary operator
+
+    Keyword arguments:
+    pass_literal_through -- whether to pass literal value directly through to the JVM.
+    """
     def _(self, other):
         if isinstance(other, Column):
             jc = other._jc
         else:
-            jc = _create_column_from_literal(other)
+            if pass_literal_through:
+                jc = other
+            else:
+                jc = _create_column_from_literal(other)
         return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
     return _

@@ -2458,10 +2465,10 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
     # __getattr__ = _bin_op("getField")

     # string methods
-    rlike = _bin_op("rlike")
-    like = _bin_op("like")
-    startswith = _bin_op("startsWith")
-    endswith = _bin_op("endsWith")
+    rlike = _bin_op("rlike", pass_literal_through=True)
+    like = _bin_op("like", pass_literal_through=True)
+    startswith = _bin_op("startsWith", pass_literal_through=True)
+    endswith = _bin_op("endsWith", pass_literal_through=True)
     upper = _unary_op("upper")
     lower = _unary_op("lower")

@@ -2487,7 +2494,7 @@ def substr(self, startPos, pos):
     isNotNull = _unary_op("isNotNull")

     # `as` is keyword
-    def As(self, alias):
+    def alias(self, alias):
         return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx)

     def cast(self, dataType):
@@ -2501,15 +2508,14 @@ def cast(self, dataType):


 def _aggregate_func(name):
-    """ Creat a function for aggregator by name"""
+    """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
         if isinstance(col, Column):
             jcol = col._jc
         else:
             jcol = _create_column_from_name(col)
-        # FIXME: can not access dsl.min/max ...
-        jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+        jc = getattr(sc._jvm.org.apache.spark.sql.api.java.dsl, name)(jcol)
         return Column(jc)
     return staticmethod(_)
