SPARK-7836 and SPARK-7822: Python API of window functions
Davies Liu committed May 22, 2015
1 parent 3b68cb0 commit 778e2c0
Showing 7 changed files with 288 additions and 30 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/sql/__init__.py
@@ -66,8 +66,9 @@ def deco(f):
from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec

__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions'
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
]
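
With this change, the window API is importable from the package root. A quick sketch of the new top-level imports this exposes (constructing a spec still goes through the JVM, so a SparkContext must already be running; the "name"/"age" column names are only illustrative):

    from pyspark.sql import Window, WindowSpec

    # Building a spec requires an active SparkContext.
    spec = Window.partitionBy("name").orderBy("age")
    assert isinstance(spec, WindowSpec)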
25 changes: 21 additions & 4 deletions python/pyspark/sql/column.py
@@ -304,15 +304,13 @@ def cast(self, dataType):

astype = cast

@ignore_unicode_prefix
@since(1.3)
def between(self, lowerBound, upperBound):
""" A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
"""
return (self >= lowerBound) & (self <= upperBound)

@ignore_unicode_prefix
@since(1.4)
def when(self, condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
@@ -322,7 +320,6 @@ def when(self, condition, value):
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
"""
sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
@@ -331,7 +328,6 @@ def when(self, condition, value):
jc = sc._jvm.functions.when(condition._jc, v)
return Column(jc)

@ignore_unicode_prefix
@since(1.4)
def otherwise(self, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
@@ -345,6 +341,27 @@ def otherwise(self, value):
jc = self._jc.otherwise(value)
return Column(jc)

@since(1.4)
def over(self, window):
"""
Define a windowing column.
:param window: a :class:`WindowSpec`
:return: a Column
>>> from pyspark.sql import Window
>>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1)
>>> from pyspark.sql.functions import rank, min
>>> # df.select(rank().over(window), min('age').over(window))
.. note:: Window functions are only supported with HiveContext in 1.4
"""
from pyspark.sql.window import WindowSpec
if not isinstance(window, WindowSpec):
raise TypeError("window should be WindowSpec")
jc = self._jc.over(window._jspec)
return Column(jc)

def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')
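
To make the new Column.over method concrete, here is a minimal usage sketch. It assumes a HiveContext-backed DataFrame df with "name" and "age" columns, matching the doctest above; it is an illustration, not part of the commit.

    from pyspark.sql import Window
    from pyspark.sql.functions import rank, min

    window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1)
    # over() wraps each function call in a windowing Column expression.
    df.select(df.name, df.age,
              rank().over(window).alias("rank"),
              min("age").over(window).alias("min_age")).show()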

83 changes: 83 additions & 0 deletions python/pyspark/sql/functions.py
@@ -67,6 +67,17 @@ def _(col1, col2):
return _


def _create_window_function(name, doc=''):
""" Create a window function by name """
def _():
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)()
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


_functions = {
'lit': 'Creates a :class:`Column` of literal value.',
'col': 'Returns a :class:`Column` based on the given column name.',
@@ -130,12 +141,40 @@ def _(col1, col2):
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}

_window_functions = {
    'lag': 'Window function: returns the lag value of the expression for the current row, ' +
           'or null when the current row extends before the beginning of the window.',
    'rowNumber': """Window function: Assigns a unique number (sequentially, starting from 1,
        as defined by ORDER BY) to each row within the partition.""",
    'denseRank': 'Window function: The difference between RANK and DENSE_RANK is that DENSE_RANK ' +
                 'leaves no gaps in the ranking sequence when there are ties. That is, if you were ' +
                 'ranking a competition using DENSE_RANK and had three people tie for second ' +
                 'place, you would say that all three were in second place and that the next ' +
                 'person came in third.',
    'rank': 'Window function: The difference between RANK and DENSE_RANK is that DENSE_RANK ' +
            'leaves no gaps in the ranking sequence when there are ties. That is, if you were ' +
            'ranking a competition using RANK and had three people tie for second place, ' +
            'you would say that all three were in second place and that the next person came in ' +
            'fifth.',
    'cumeDist': 'Window function: CUME_DIST (defined as the inverse of percentile in some ' +
                'statistical books) computes the position of a specified value relative to ' +
                'a set of values. To compute the CUME_DIST of a value x in a set S of size N, ' +
                'you use the formula: CUME_DIST(x) = number of values in S coming before and ' +
                'including x in the specified order / N',
    'percentRank': 'Window function: PERCENT_RANK is similar to CUME_DIST, but it uses rank values ' +
                   'rather than row counts in its numerator. The formula: ' +
                   '(rank of row in its partition - 1) / (number of rows in the partition - 1)',

}

for _name, _doc in _functions.items():
globals()[_name] = since(1.3)(_create_function(_name, _doc))
for _name, _doc in _functions_1_4.items():
globals()[_name] = since(1.4)(_create_function(_name, _doc))
for _name, _doc in _binary_mathfunctions.items():
globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
for _name, _doc in _window_functions.items():
globals()[_name] = since(1.4)(_create_window_function(_name, _doc))
del _name, _doc
__all__ += _functions.keys()
__all__ += _binary_mathfunctions.keys()
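
For readers unfamiliar with this registration pattern: each entry in the dicts above is turned into a real module-level function at import time. As a rough sketch (not the literal generated code), the 'rowNumber' entry expands to something like:

    @since(1.4)
    def rowNumber():
        """Window function: Assigns a unique number (sequentially, starting from 1,
        as defined by ORDER BY) to each row within the partition."""
        sc = SparkContext._active_spark_context
        # _create_window_function looks the name up on sc._jvm.functions and wraps the result.
        return Column(sc._jvm.functions.rowNumber())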
@@ -349,6 +388,50 @@ def when(condition, value):
return Column(jc)


@since(1.4)
def lag(col, count=1, default=None):
"""
Window function: returns the lag value of the expression for the current row,
or the given default value when the current row extends before the beginning
of the window.
:param col: name of column or expression
:param count: number of rows to extend
:param default: default value
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.lag(_to_java_column(col), count, default))


@since(1.4)
def lead(col, count=1, default=None):
"""
Window function: returns the lead value of the expression for the current row,
or the given default value when the current row extends beyond the end of the window.
:param col: name of column or expression
:param count: number of rows to extend
:param default: default value
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.lead(_to_java_column(col), count, default))


@since(1.4)
def ntile(n):
"""
Window function: NTILE divides an ordered window partition into a specified
number of groups, called buckets, and assigns a bucket number to each row
in the partition. This allows easy calculation of tertiles, quartiles,
deciles and other common summary statistics.
:param n: an integer, the number of buckets
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.ntile(int(n)))
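
A brief usage sketch of the explicit wrappers above (lag, lead, ntile). It assumes a HiveContext-backed DataFrame df with "key" and "value" columns, as in the tests further down; the aliases are only illustrative.

    from pyspark.sql import Window
    from pyspark.sql import functions as F

    w = Window.partitionBy("value").orderBy("key")
    df.select(df.key, df.value,
              F.lag("key", 1).over(w).alias("prev_key"),    # previous row's key, or null
              F.lead("key", 1).over(w).alias("next_key"),   # next row's key, or null
              F.ntile(2).over(w).alias("half")).show()      # bucket 1 or 2 within the partition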


class UserDefinedFunction(object):
"""
User defined function in Python
31 changes: 24 additions & 7 deletions python/pyspark/sql/tests.py
@@ -44,6 +44,7 @@
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.window import Window


class ExamplePointUDT(UserDefinedType):
@@ -743,11 +744,9 @@ def setUpClass(cls):
try:
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
except py4j.protocol.Py4JError:
cls.sqlCtx = None
return
raise unittest.SkipTest("Hive is not available")
except TypeError:
cls.sqlCtx = None
return
raise unittest.SkipTest("Hive is not available")
os.unlink(cls.tempdir.name)
_scala_HiveContext =\
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
@@ -761,9 +760,6 @@ def tearDownClass(cls):
shutil.rmtree(cls.tempdir.name, ignore_errors=True)

def test_save_and_load_table(self):
if self.sqlCtx is None:
return # no hive available, skipped

df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -805,6 +801,27 @@ def test_save_and_load_table(self):

shutil.rmtree(tmpPath)

def test_window_functions(self):
df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
w = Window.partitionBy("value").orderBy("key")
from pyspark.sql import functions as F
sel = df.select(df.value, df.key,
F.max("key").over(w.rowsBetween(0, 1)),
F.min("key").over(w.rowsBetween(0, 1)),
F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
F.rowNumber().over(w),
F.rank().over(w),
F.denseRank().over(w),
F.ntile(2).over(w))
rs = sorted(sel.collect())
expected = [
("1", 1, 1, 1, 1, 1, 1, 1, 1),
("2", 1, 1, 1, 3, 1, 1, 1, 1),
("2", 1, 2, 1, 3, 2, 1, 1, 1),
("2", 2, 2, 2, 3, 3, 3, 2, 2)
]
for r, ex in zip(rs, expected):
self.assertEqual(tuple(r), ex[:len(r)])
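
One detail worth calling out in the test above: the unbounded frame is written with float('-inf') and float('inf'), which rowsBetween (defined in window.py below) clamps to Java's Long range just like -sys.maxint and sys.maxint. A small sketch of that equivalence, assuming a WindowSpec w as in the test:

    import sys

    # Both spellings satisfy the <= -sys.maxint / >= sys.maxint checks in
    # WindowSpec.rowsBetween, so they describe the same unbounded frame.
    unbounded_a = w.rowsBetween(float('-inf'), float('inf'))
    unbounded_b = w.rowsBetween(-sys.maxint, sys.maxint)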

if __name__ == "__main__":
unittest.main()
154 changes: 154 additions & 0 deletions python/pyspark/sql/window.py
@@ -0,0 +1,154 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark import SparkContext
from pyspark.sql import since
from pyspark.sql.column import _to_seq, _to_java_column

__all__ = ["Window", "WindowSpec"]


def _to_java_cols(cols):
sc = SparkContext._active_spark_context
return _to_seq(sc, cols, _to_java_column)


class Window(object):

"""
Utility functions for defining windows in DataFrames.
For example:
PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
>>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxint, 0)
PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
>>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
.. versionadded:: 1.4
"""
@staticmethod
@since(1.4)
def partitionBy(*cols):
"""
Creates a :class:`WindowSpec` with the partitioning defined.
"""
sc = SparkContext._active_spark_context
jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
return WindowSpec(jspec)

@staticmethod
@since(1.4)
def orderBy(*cols):
"""
Creates a :class:`WindowSpec` with the ordering defined.
"""
sc = SparkContext._active_spark_context
jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
return WindowSpec(jspec)


class WindowSpec(object):
"""
A window specification that defines the partitioning, ordering,
and frame boundaries.
Use the static methods in :class:`Window` to create a :class:`WindowSpec`.
.. versionadded:: 1.4
"""

JAVA_MAX_LONG = (1 << 63) - 1
JAVA_MIN_LONG = - (1 << 63)

def __init__(self, jspec):
self._jspec = jspec

@since(1.4)
def partitionBy(self, *cols):
"""
Defines the partitioning columns in a :class:`WindowSpec`.
:param cols: names of columns or expressions
"""
return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))

@since(1.4)
def orderBy(self, *cols):
"""
Defines the ordering columns in a :class:`WindowSpec`.
:param cols: names of columns or expressions
"""
return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))

@since(1.4)
def rowsBetween(self, start=-sys.maxint, end=0):
"""
Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
Both `start` and `end` are relative positions from the current row.
For example, "0" means "current row", while "-1" means the row before
the current row, and "5" means the fifth row after the current row.
:param start: boundary start, inclusive.
The frame is unbounded if this is -sys.maxint (or lower).
:param end: boundary end, inclusive.
The frame is unbounded if this is sys.maxint (or higher).
"""
if start <= -sys.maxint:
start = self.JAVA_MIN_LONG
if end >= sys.maxint:
end = self.JAVA_MAX_LONG
return WindowSpec(self._jspec.rowsBetween(start, end))

@since(1.4)
def rangeBetween(self, start=-sys.maxint, end=0):
"""
Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
Both `start` and `end` are relative to the current row. For example,
"0" means "current row", while "-1" means one off before the current row,
and "5" means five off after the current row.
:param start: boundary start, inclusive.
The frame is unbounded if this is -sys.maxint (or lower).
:param end: boundary end, inclusive.
The frame is unbounded if this is sys.maxint (or higher).
"""
if start <= -sys.maxint:
start = self.JAVA_MIN_LONG
if end >= sys.maxint:
end = self.JAVA_MAX_LONG
return WindowSpec(self._jspec.rangeBetween(start, end))
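
As a usage sketch contrasting the two frame types above: ROWS frames count physical rows from the current row, while RANGE frames are based on the ORDER BY value. This mirrors the SQL examples in the Window docstring; the DataFrame df and its numeric "amount" column are hypothetical additions for illustration.

    import sys
    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # Running total: everything from the start of the partition up to the current row.
    running = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxint, 0)
    # Rows whose ORDER BY key is within 3 of the current row's key.
    nearby = Window.partitionBy("country").orderBy("date").rangeBetween(-3, 3)

    df.select(df.country, df.date,
              F.sum("amount").over(running).alias("running_total"),
              F.avg("amount").over(nearby).alias("nearby_avg"))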


def _test():
import doctest
SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod()
if failure_count:
exit(-1)


if __name__ == "__main__":
_test()
