Updated Python tests.
rxin committed Jul 2, 2015
1 parent 2727789 commit d8518cf
Showing 4 changed files with 21 additions and 20 deletions.
10 changes: 5 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -802,11 +802,11 @@ def groupBy(self, *cols):
Each element should be a column name (string) or an expression (:class:`Column`).
>>> df.groupBy().avg().collect()
-[Row(AVG(age)=3.5)]
+[Row(avg(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
-[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
-[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> df.groupBy(['name', df.age]).count().collect()
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
@@ -864,10 +864,10 @@ def agg(self, *exprs):
(shorthand for ``df.groupBy.agg()``).
>>> df.agg({"age": "max"}).collect()
-[Row(MAX(age)=5)]
+[Row(max(age)=5)]
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
-[Row(MIN(age)=2)]
+[Row(min(age)=2)]
"""
return self.groupBy().agg(*exprs)

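The doctest updates above track the new lowercase rendering of aggregate result columns (avg(age) rather than AVG(age)). Below is a minimal sketch, not part of the diff, of the updated behaviour and of using alias() so callers do not depend on the generated name; the local SparkContext setup, the app name, and the avg_age alias are illustrative assumptions, and the sample data mirrors the doctest's df.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F

# Illustrative local setup; the doctests above get df from the test module's setup.
sc = SparkContext("local", "lowercase_agg_names_sketch")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])

# The generated column name is now lowercase, e.g. avg(age).
print(df.groupBy().avg().collect())          # expected: [Row(avg(age)=3.5)]

# Aliasing the expression keeps code independent of how Spark renders it.
print(df.groupBy("name").agg(F.avg(df.age).alias("avg_age")).collect())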
4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
@@ -266,7 +266,7 @@ def coalesce(*cols):
>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
+-------------+
-|Coalesce(a,b)|
+|coalesce(a,b)|
+-------------+
|         null|
|            1|
@@ -275,7 +275,7 @@ def coalesce(*cols):
>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
+----+----+---------------+
-|   a|   b|Coalesce(a,0.0)|
+|   a|   b|coalesce(a,0.0)|
+----+----+---------------+
|null|null|            0.0|
|   1|null|            1.0|
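As with the aggregates, the header that show() prints for coalesce is now lowercase. A short sketch, not part of the diff, assuming the sqlContext from the earlier sketch and rebuilding the doctests' cDf with an explicit schema; the a_or_zero alias is illustrative only.

from pyspark.sql.functions import coalesce, lit
from pyspark.sql.types import IntegerType, StructField, StructType

# Rebuild the doctests' cDf; an explicit schema avoids inferring types from None.
schema = StructType([StructField("a", IntegerType(), True),
                     StructField("b", IntegerType(), True)])
cDf = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], schema)

# The header is rendered as coalesce(a,b) rather than Coalesce(a,b).
cDf.select(coalesce(cDf["a"], cDf["b"])).show()

# An explicit alias keeps downstream column references stable.
cDf.select("*", coalesce(cDf["a"], lit(0.0)).alias("a_or_zero")).show()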
24 changes: 12 additions & 12 deletions python/pyspark/sql/group.py
@@ -75,11 +75,11 @@ def agg(self, *exprs):
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
-[Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
+[Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
-[Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
+[Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -110,9 +110,9 @@ def mean(self, *cols):
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().mean('age').collect()
-[Row(AVG(age)=3.5)]
+[Row(avg(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
-[Row(AVG(age)=3.5, AVG(height)=82.5)]
+[Row(avg(age)=3.5, avg(height)=82.5)]
"""

@df_varargs_api
@@ -125,9 +125,9 @@ def avg(self, *cols):
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().avg('age').collect()
-[Row(AVG(age)=3.5)]
+[Row(avg(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
-[Row(AVG(age)=3.5, AVG(height)=82.5)]
+[Row(avg(age)=3.5, avg(height)=82.5)]
"""

@df_varargs_api
@@ -136,9 +136,9 @@ def max(self, *cols):
"""Computes the max value for each numeric columns for each group.
>>> df.groupBy().max('age').collect()
-[Row(MAX(age)=5)]
+[Row(max(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
-[Row(MAX(age)=5, MAX(height)=85)]
+[Row(max(age)=5, max(height)=85)]
"""

@df_varargs_api
@@ -149,9 +149,9 @@ def min(self, *cols):
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().min('age').collect()
-[Row(MIN(age)=2)]
+[Row(min(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
-[Row(MIN(age)=2, MIN(height)=80)]
+[Row(min(age)=2, min(height)=80)]
"""

@df_varargs_api
@@ -162,9 +162,9 @@ def sum(self, *cols):
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().sum('age').collect()
-[Row(SUM(age)=7)]
+[Row(sum(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
-[Row(SUM(age)=7, SUM(height)=165)]
+[Row(sum(age)=7, sum(height)=165)]
"""


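The GroupedData doctests above all change the same way: the varargs helpers (mean, avg, max, min, sum), the dict form of agg, and the functions-based form now produce result columns with lowercase names. A small sketch, not part of the diff, reusing the df and sqlContext from the first sketch to show that the three spellings agree on the generated min(age) name.

from pyspark.sql import functions as F

# df and sqlContext as defined in the first sketch above.
gdf = df.groupBy(df.name)

# All three forms now yield a result column named min(age).
print(gdf.min("age").collect())
print(gdf.agg({"age": "min"}).collect())
print(gdf.agg(F.min(df.age)).collect())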
3 changes: 2 additions & 1 deletion (Scala aggregate expressions)
@@ -94,7 +94,6 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[

override def nullable: Boolean = true
override def dataType: DataType = child.dataType
-override def toString: String = s"MIN($child)"

override def asPartial: SplitEvaluation = {
val partialMin = Alias(Min(child), "PartialMin")()
@@ -388,6 +387,8 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {

+override def prettyName: String = "avg"

override def nullable: Boolean = true

override def dataType: DataType = child.dataType match {
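The Scala change is what drives the doctest updates: Min no longer hard-codes an uppercase MIN(...) string, and Average now reports a prettyName of "avg", so generated column names come out lowercase. A hedged way to observe the effect from Python, not part of the diff, again reusing the df from the first sketch; the exact rendering of expression names can differ between Spark versions.

from pyspark.sql import functions as F

# df as defined in the first sketch above.
# Expected in this version: ['min(age)', 'avg(age)']
print(df.agg(F.min(df.age), F.avg(df.age)).columns)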
