zero323#394: Use Union[List[Column], List[str]] for select()
As @zero323 pointed out in zero323#395 (review), we can use only List,
not Sequence (as originally suggested by mypy).
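To see why the union of homogeneous list types is the stricter choice, here is a
stand-alone sketch of how mypy treats the two signatures. Column, col, and the
select_old/select_new functions are hypothetical stand-ins for the real
pyspark.sql classes and the two overload variants; ColumnOrName mirrors the
alias defined in these stubs.

from typing import List, Union

class Column: ...
def col(name: str) -> Column: ...

ColumnOrName = Union[Column, str]

def select_old(cols: List[ColumnOrName]) -> None: ...
def select_new(cols: Union[List[Column], List[str]]) -> None: ...

select_old(["name", col("age")])  # accepted: each element is checked
                                  # against the union individually
select_new(["name", "age"])       # accepted as List[str]
select_new([col("name")])         # accepted as List[Column]
select_new(["name", col("age")])  # error: the mixed list is inferred as
                                  # List[object], matching neither member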
jhereth committed Apr 15, 2020
1 parent 306edd3 commit 270a9c3
Showing 2 changed files with 67 additions and 6 deletions.
51 changes: 51 additions & 0 deletions test-data/unit/sql-dataframe.test
@@ -21,6 +21,57 @@ df.sample(withReplacement=False) # E: No overload variant of "sample" of "DataFr

[out]

[case selectColumns]
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

data = [('Alice', 1)]
df = spark.createDataFrame(data, schema="name str, age int")

df.select(["name", "age"])
df.select([col("name"), col("age")])

df.select(["name", col("age")]) # E: Argument 1 to "select" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"

[out]

[case groupBy]
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

data = [('Alice', 1)]
df = spark.createDataFrame(data, schema="name str, age int")

df.groupBy(["name", "age"])
df.groupBy([col("name"), col("age")])


df.groupBy(["name", col("age")]) # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"

[out]

[case rollup]
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

data = [('Alice', 1)]
df = spark.createDataFrame(data, schema="name str, age int")

df.rollup(["name", "age"])
df.rollup([col("name"), col("age")])


df.rollup(["name", col("age")]) # E: Argument 1 to "rollup" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"

[out]



[case dropColumns]
from pyspark.sql import SparkSession
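The cases above follow mypy's data-driven test format: each [case ...] block is
type-checked in isolation, a trailing # E: comment asserts that mypy reports
exactly that error on that line, and the [out] section holds any expected
output not covered by inline comments (empty for these cases).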
22 changes: 16 additions & 6 deletions third_party/3/pyspark/sql/dataframe.pyi
@@ -1,5 +1,15 @@
from typing import overload
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

import pandas.core.frame # type: ignore
from py4j.java_gateway import JavaObject # type: ignore
@@ -135,7 +145,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
    @overload
    def select(self, *cols: ColumnOrName) -> DataFrame: ...
    @overload
-    def select(self, __cols: List[ColumnOrName]) -> DataFrame: ...
+    def select(self, __cols: Union[List[Column], List[str]]) -> DataFrame: ...
    @overload
    def selectExpr(self, *expr: str) -> DataFrame: ...
    @overload
@@ -144,15 +154,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
    @overload
    def groupBy(self, *cols: ColumnOrName) -> GroupedData: ...
    @overload
-    def groupBy(self, __cols: List[ColumnOrName]) -> GroupedData: ...
+    def groupBy(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
    @overload
    def rollup(self, *cols: ColumnOrName) -> GroupedData: ...
    @overload
-    def rollup(self, __cols: List[ColumnOrName]) -> GroupedData: ...
+    def rollup(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
    @overload
    def cube(self, *cols: ColumnOrName) -> GroupedData: ...
    @overload
-    def cube(self, __cols: List[ColumnOrName]) -> GroupedData: ...
+    def cube(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: ...
    def union(self, other: DataFrame) -> DataFrame: ...
    def unionAll(self, other: DataFrame) -> DataFrame: ...
@@ -220,7 +230,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
    @overload
    def groupby(self, *cols: ColumnOrName) -> GroupedData: ...
    @overload
-    def groupby(self, __cols: List[ColumnOrName]) -> GroupedData: ...
+    def groupby(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
    def drop_duplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ...
    def where(self, condition: ColumnOrName) -> DataFrame: ...

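The same signature change is applied to cube and to the lowercase groupby
alias, which the test file does not exercise. A minimal sketch of what mypy
should now accept and reject for cube, mirroring the cases above (the setup
lines and schema string are illustrative, not part of the commit):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 1)], schema="name string, age int")

df.cube(["name", "age"])            # accepted as List[str]
df.cube([col("name"), col("age")])  # accepted as List[Column]
df.cube(["name", col("age")])       # rejected: inferred as List[object],
                                    # as in the test cases above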
