123 changes: 123 additions & 0 deletions duckdb/experimental/spark/sql/functions.py
@@ -1851,6 +1851,30 @@ def isnotnull(col: "ColumnOrName") -> Column:
return Column(_to_column_expr(col).isnotnull())


def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
Returns the same result as the EQUAL(=) operator for non-null operands,
but returns true if both are null and false if only one of them is null.

.. versionadded:: 3.5.0

Parameters
----------
col1 : :class:`~pyspark.sql.Column` or str
col2 : :class:`~pyspark.sql.Column` or str

Examples
--------
>>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"])
>>> df.select(equal_null(df.a, df.b).alias('r')).collect()
[Row(r=True), Row(r=False)]
"""
if isinstance(col1, str):
col1 = col(col1)

if isinstance(col2, str):
col2 = col(col2)

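# (col1 == col2) evaluates to NULL when exactly one operand is NULL; nvl(..., lit(False)) maps that case to False, while isNull() & isNull() covers the both-NULL case.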
return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False))


def flatten(col: "ColumnOrName") -> Column:
"""
Collection function: creates a single array from an array of arrays.
@@ -2157,6 +2181,33 @@ def e() -> Column:
return lit(2.718281828459045)


def negative(col: "ColumnOrName") -> Column:
"""
Returns the negative value.

.. versionadded:: 3.5.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
column to calculate negative value for.

Returns
-------
:class:`~pyspark.sql.Column`
negative value.

Examples
--------
>>> import pyspark.sql.functions as sf
>>> spark.range(3).select(sf.negative("id")).show()
+------------+
|negative(id)|
+------------+
| 0|
| -1|
| -2|
+------------+
"""
return Column(_to_column_expr(col)) * -1


def pi() -> Column:
"""Returns Pi.

@@ -3774,6 +3825,53 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column:
return date_part(field, source)


def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column:
"""
Returns the number of days from `start` to `end`.

.. versionadded:: 3.5.0

Parameters
----------
end : :class:`~pyspark.sql.Column` or column name
to date column to work on.
start : :class:`~pyspark.sql.Column` or column name
from date column to work on.

Returns
-------
:class:`~pyspark.sql.Column`
difference in days between two dates.

See Also
--------
:meth:`pyspark.sql.functions.dateadd`
:meth:`pyspark.sql.functions.date_add`
:meth:`pyspark.sql.functions.date_sub`
:meth:`pyspark.sql.functions.datediff`
:meth:`pyspark.sql.functions.timestamp_diff`

Examples
--------
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
>>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show()
+----------+----------+-----------------+
| d1| d2|date_diff(d1, d2)|
+----------+----------+-----------------+
|2015-04-08|2015-05-10| -32|
+----------+----------+-----------------+

>>> df.select('*', sf.date_diff(sf.col('d2').cast('DATE'), sf.col('d1').cast('DATE'))).show()
+----------+----------+-----------------+
| d1| d2|date_diff(d2, d1)|
+----------+----------+-----------------+
|2015-04-08|2015-05-10| 32|
+----------+----------+-----------------+
"""
return _invoke_function_over_columns("date_diff", lit("day"), end, start)


def year(col: "ColumnOrName") -> Column:
"""
Extract the year of a given date/timestamp as integer.
@@ -5685,6 +5783,31 @@ def to_timestamp_ntz(
return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format)


def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column:
"""
Parses `col` with `format` into a timestamp. The function always returns
null on invalid input, with or without ANSI SQL mode enabled. The result data type is
consistent with the value of the configuration `spark.sql.timestampType`.

.. versionadded:: 3.5.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
column values to convert.
format: str, optional
format to use to convert timestamp values.

Examples
--------
>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(try_to_timestamp(df.t).alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
>>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
"""
if format is None:
format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S'])

return _invoke_function_over_columns("try_strptime", col, format)

def substr(
str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None
) -> Column:
64 changes: 32 additions & 32 deletions tests/fast/spark/test_spark_functions_array.py
@@ -2,7 +2,7 @@
import platform

_ = pytest.importorskip("duckdb.experimental.spark")
from spark_namespace.sql import functions as F
from spark_namespace.sql import functions as sf
from spark_namespace.sql.types import Row
from spark_namespace import USE_ACTUAL_SPARK

@@ -19,7 +19,7 @@ def test_array_distinct(self, spark):
([2, 4, 5], 3),
]
df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
df = df.withColumn("distinct_values", F.array_distinct(F.col("firstColumn")))
df = df.withColumn("distinct_values", sf.array_distinct(sf.col("firstColumn")))
res = df.select("distinct_values").collect()
# Output order can vary across platforms which is why we sort it first
assert len(res) == 2
@@ -31,7 +31,7 @@ def test_array_intersect(self, spark):
(["b", "a", "c"], ["c", "d", "a", "f"]),
]
df = spark.createDataFrame(data, ["c1", "c2"])
df = df.withColumn("intersect_values", F.array_intersect(F.col("c1"), F.col("c2")))
df = df.withColumn("intersect_values", sf.array_intersect(sf.col("c1"), sf.col("c2")))
res = df.select("intersect_values").collect()
# Output order can vary across platforms which is why we sort it first
assert len(res) == 1
@@ -42,7 +42,7 @@ def test_array_union(self, spark):
(["b", "a", "c"], ["c", "d", "a", "f"]),
]
df = spark.createDataFrame(data, ["c1", "c2"])
df = df.withColumn("union_values", F.array_union(F.col("c1"), F.col("c2")))
df = df.withColumn("union_values", sf.array_union(sf.col("c1"), sf.col("c2")))
res = df.select("union_values").collect()
# Output order can vary across platforms which is why we sort it first
assert len(res) == 1
@@ -54,7 +54,7 @@ def test_array_max(self, spark):
([4, 2, 5], 5),
]
df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
df = df.withColumn("max_value", F.array_max(F.col("firstColumn")))
df = df.withColumn("max_value", sf.array_max(sf.col("firstColumn")))
res = df.select("max_value").collect()
assert res == [
Row(max_value=3),
@@ -67,7 +67,7 @@ def test_array_min(self, spark):
([2, 4, 5], 5),
]
df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
df = df.withColumn("min_value", F.array_min(F.col("firstColumn")))
df = df.withColumn("min_value", sf.array_min(sf.col("firstColumn")))
res = df.select("min_value").collect()
assert res == [
Row(max_value=1),
@@ -77,58 +77,58 @@ def test_array_min(self, spark):
def test_get(self, spark):
df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])

res = df.select(F.get(df.data, 1).alias("r")).collect()
res = df.select(sf.get(df.data, 1).alias("r")).collect()
assert res == [Row(r="b")]

res = df.select(F.get(df.data, -1).alias("r")).collect()
res = df.select(sf.get(df.data, -1).alias("r")).collect()
assert res == [Row(r=None)]

res = df.select(F.get(df.data, 3).alias("r")).collect()
res = df.select(sf.get(df.data, 3).alias("r")).collect()
assert res == [Row(r=None)]

res = df.select(F.get(df.data, "index").alias("r")).collect()
res = df.select(sf.get(df.data, "index").alias("r")).collect()
assert res == [Row(r='b')]

res = df.select(F.get(df.data, F.col("index") - 1).alias("r")).collect()
res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect()
assert res == [Row(r='a')]

def test_flatten(self, spark):
df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])

res = df.select(F.flatten(df.data).alias("r")).collect()
res = df.select(sf.flatten(df.data).alias("r")).collect()
assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]

def test_array_compact(self, spark):
df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data'])

res = df.select(F.array_compact(df.data).alias("v")).collect()
res = df.select(sf.array_compact(df.data).alias("v")).collect()
assert res == [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])]

def test_array_remove(self, spark):
df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])

res = df.select(F.array_remove(df.data, 1).alias("v")).collect()
res = df.select(sf.array_remove(df.data, 1).alias("v")).collect()
assert res == [Row(v=[2, 3]), Row(v=[])]

def test_array_agg(self, spark):
df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])

res = df.groupBy("group").agg(F.array_agg("c").alias("r")).collect()
res = df.groupBy("group").agg(sf.array_agg("c").alias("r")).collect()
assert res[0] == Row(group="A", r=[1, 1, 2])

def test_collect_list(self, spark):
df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])

res = df.groupBy("group").agg(F.collect_list("c").alias("r")).collect()
res = df.groupBy("group").agg(sf.collect_list("c").alias("r")).collect()
assert res[0] == Row(group="A", r=[1, 1, 2])

def test_array_append(self, spark):
df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"])

res = df.select(F.array_append(df.c1, df.c2).alias("r")).collect()
res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect()
assert res == [Row(r=['b', 'a', 'c', 'c'])]

res = df.select(F.array_append(df.c1, 'x')).collect()
res = df.select(sf.array_append(df.c1, 'x')).collect()
assert res == [Row(r=['b', 'a', 'c', 'x'])]

def test_array_insert(self, spark):
@@ -137,21 +137,21 @@ def test_array_insert(self, spark):
['data', 'pos', 'val'],
)

res = df.select(F.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
assert res == [
Row(data=['a', 'd', 'b', 'c']),
Row(data=['a', 'd', 'b', 'c', 'e']),
Row(data=['c', 'b', 'd', 'a']),
]

res = df.select(F.array_insert(df.data, 5, 'hello').alias('data')).collect()
res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect()
assert res == [
Row(data=['a', 'b', 'c', None, 'hello']),
Row(data=['a', 'b', 'c', 'e', 'hello']),
Row(data=['c', 'b', 'a', None, 'hello']),
]

res = df.select(F.array_insert(df.data, -5, 'hello').alias('data')).collect()
res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect()
assert res == [
Row(data=['hello', None, 'a', 'b', 'c']),
Row(data=['hello', 'a', 'b', 'c', 'e']),
Expand All @@ -160,67 +160,67 @@ def test_array_insert(self, spark):

def test_slice(self, spark):
df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
res = df.select(F.slice(df.x, 2, 2).alias("sliced")).collect()
res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect()
assert res == [Row(sliced=[2, 3]), Row(sliced=[5])]

def test_sort_array(self, spark):
df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])

res = df.select(F.sort_array(df.data).alias('r')).collect()
res = df.select(sf.sort_array(df.data).alias('r')).collect()
assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]

res = df.select(F.sort_array(df.data, asc=False).alias('r')).collect()
res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect()
assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]

@pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")])
def test_array_join(self, spark, null_replacement, expected_joined_2):
df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])

res = df.select(F.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)]

def test_array_position(self, spark):
df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])

res = df.select(F.array_position(df.data, "a").alias("pos")).collect()
res = df.select(sf.array_position(df.data, "a").alias("pos")).collect()
assert res == [Row(pos=3), Row(pos=0)]

def test_array_prepend(self, spark):
df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])

res = df.select(F.array_prepend(df.data, 1).alias("pre")).collect()
res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect()
assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])]

def test_array_repeat(self, spark):
df = spark.createDataFrame([('ab',)], ['data'])

res = df.select(F.array_repeat(df.data, 3).alias('r')).collect()
res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect()
assert res == [Row(r=['ab', 'ab', 'ab'])]

def test_array_size(self, spark):
df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data'])

res = df.select(F.array_size(df.data).alias('r')).collect()
res = df.select(sf.array_size(df.data).alias('r')).collect()
assert res == [Row(r=3), Row(r=None)]

def test_array_sort(self, spark):
df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])

res = df.select(F.array_sort(df.data).alias('r')).collect()
res = df.select(sf.array_sort(df.data).alias('r')).collect()
assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]

def test_arrays_overlap(self, spark):
df = spark.createDataFrame(
[(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y']
)

res = df.select(F.arrays_overlap(df.x, df.y).alias("overlap")).collect()
res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect()
assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)]

def test_arrays_zip(self, spark):
df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3'])

res = df.select(F.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
# FIXME: The structure of the results should be the same
if USE_ACTUAL_SPARK:
assert res == [