From 7649e2aa6bb7f6511664176ac98ab494458e4ef9 Mon Sep 17 00:00:00 2001
From: Evert Lammerts
Date: Thu, 28 Aug 2025 14:01:25 +0200
Subject: [PATCH] Forward port of https://github.com/duckdb/duckdb/pull/15462
 and https://github.com/duckdb/duckdb/pull/15036

---
 duckdb/experimental/spark/sql/functions.py    | 123 ++++++++++++++++++
 .../fast/spark/test_spark_functions_array.py  |  64 ++++-----
 tests/fast/spark/test_spark_functions_date.py |  25 +++-
 tests/fast/spark/test_spark_functions_null.py |   5 +
 .../spark/test_spark_functions_numeric.py     |  82 ++++++------
 5 files changed, 229 insertions(+), 70 deletions(-)

diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py
index a6d67aeb..fecada95 100644
--- a/duckdb/experimental/spark/sql/functions.py
+++ b/duckdb/experimental/spark/sql/functions.py
@@ -1851,6 +1851,30 @@ def isnotnull(col: "ColumnOrName") -> Column:
     return Column(_to_column_expr(col).isnotnull())
 
 
+def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+    """
+    Returns same result as the EQUAL(=) operator for non-null operands,
+    but returns true if both are null, false if one of them is null.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col1 : :class:`~pyspark.sql.Column` or str
+    col2 : :class:`~pyspark.sql.Column` or str
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"])
+    >>> df.select(equal_null(df.a, df.b).alias('r')).collect()
+    [Row(r=True), Row(r=False)]
+    """
+    if isinstance(col1, str):
+        col1 = col(col1)
+
+    if isinstance(col2, str):
+        col2 = col(col2)
+
+    return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False))
+
+
 def flatten(col: "ColumnOrName") -> Column:
     """
     Collection function: creates a single array from an array of arrays.
@@ -2157,6 +2181,33 @@ def e() -> Column:
     return lit(2.718281828459045)
 
 
+def negative(col: "ColumnOrName") -> Column:
+    """
+    Returns the negative value.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column to calculate negative value for.
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        negative value.
+    Examples
+    --------
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(3).select(sf.negative("id")).show()
+    +------------+
+    |negative(id)|
+    +------------+
+    |           0|
+    |          -1|
+    |          -2|
+    +------------+
+    """
+    return Column(_to_column_expr(col)) * -1
+
+
 def pi() -> Column:
     """Returns Pi.
 
@@ -3774,6 +3825,53 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column:
     return date_part(field, source)
 
 
+def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column:
+    """
+    Returns the number of days from `start` to `end`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    end : :class:`~pyspark.sql.Column` or column name
+        to date column to work on.
+    start : :class:`~pyspark.sql.Column` or column name
+        from date column to work on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        difference in days between two dates.
+
+    See Also
+    --------
+    :meth:`pyspark.sql.functions.dateadd`
+    :meth:`pyspark.sql.functions.date_add`
+    :meth:`pyspark.sql.functions.date_sub`
+    :meth:`pyspark.sql.functions.datediff`
+    :meth:`pyspark.sql.functions.timestamp_diff`
+
+    Examples
+    --------
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
+    >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show()
+    +----------+----------+-----------------+
+    |        d1|        d2|date_diff(d1, d2)|
+    +----------+----------+-----------------+
+    |2015-04-08|2015-05-10|              -32|
+    +----------+----------+-----------------+
+
+    >>> df.select('*', sf.date_diff(sf.col('d2').cast('DATE'), sf.col('d1').cast('DATE'))).show()
+    +----------+----------+-----------------+
+    |        d1|        d2|date_diff(d2, d1)|
+    +----------+----------+-----------------+
+    |2015-04-08|2015-05-10|               32|
+    +----------+----------+-----------------+
+    """
+    return _invoke_function_over_columns("date_diff", lit("day"), start, end)
+
+
 def year(col: "ColumnOrName") -> Column:
     """
     Extract the year of a given date/timestamp as integer.
@@ -5685,6 +5783,31 @@ def to_timestamp_ntz(
     return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format)
 
 
+def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column:
+    """
+    Parses the `col` with the `format` to a timestamp. The function always
+    returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is
+    consistent with the value of configuration `spark.sql.timestampType`.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column values to convert.
+    format: str, optional
+        format to use to convert timestamp values.
+    Examples
+    --------
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df.select(try_to_timestamp(df.t).alias('dt')).collect()
+    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+    >>> df.select(try_to_timestamp(df.t, lit('%Y-%m-%d %H:%M:%S')).alias('dt')).collect()
+    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+    """
+    if format is None:
+        format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S'])
+
+    return _invoke_function_over_columns("try_strptime", col, format)
+
 def substr(
     str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None
 ) -> Column:
diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py
index 77c4c21a..f83e0ef2 100644
--- a/tests/fast/spark/test_spark_functions_array.py
+++ b/tests/fast/spark/test_spark_functions_array.py
@@ -2,7 +2,7 @@
 import platform
 
 _ = pytest.importorskip("duckdb.experimental.spark")
-from spark_namespace.sql import functions as F
+from spark_namespace.sql import functions as sf
 from spark_namespace.sql.types import Row
 from spark_namespace import USE_ACTUAL_SPARK
 
@@ -19,7 +19,7 @@ def test_array_distinct(self, spark):
             ([2, 4, 5], 3),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("distinct_values", F.array_distinct(F.col("firstColumn")))
+        df = df.withColumn("distinct_values", sf.array_distinct(sf.col("firstColumn")))
         res = df.select("distinct_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 2
@@ -31,7 +31,7 @@ def test_array_intersect(self, spark):
             (["b", "a", "c"], ["c", "d", "a", "f"]),
         ]
         df = spark.createDataFrame(data, ["c1", "c2"])
-        df = df.withColumn("intersect_values", F.array_intersect(F.col("c1"), F.col("c2")))
+        df = df.withColumn("intersect_values", sf.array_intersect(sf.col("c1"), sf.col("c2")))
         res = df.select("intersect_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 1
@@ -42,7 +42,7 @@ def test_array_union(self, spark):
             (["b", "a", "c"], ["c", "d", "a", "f"]),
         ]
         df = spark.createDataFrame(data, ["c1", "c2"])
-        df = df.withColumn("union_values", F.array_union(F.col("c1"), F.col("c2")))
+        df = df.withColumn("union_values", sf.array_union(sf.col("c1"), sf.col("c2")))
         res = df.select("union_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 1
@@ -54,7 +54,7 @@ def test_array_max(self, spark):
             ([4, 2, 5], 5),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("max_value", F.array_max(F.col("firstColumn")))
+        df = df.withColumn("max_value", sf.array_max(sf.col("firstColumn")))
         res = df.select("max_value").collect()
         assert res == [
             Row(max_value=3),
@@ -67,7 +67,7 @@ def test_array_min(self, spark):
             ([2, 4, 5], 5),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("min_value", F.array_min(F.col("firstColumn")))
+        df = df.withColumn("min_value", sf.array_min(sf.col("firstColumn")))
         res = df.select("min_value").collect()
         assert res == [
             Row(max_value=1),
@@ -77,58 +77,58 @@ def test_get(self, spark):
         df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])
 
-        res = df.select(F.get(df.data, 1).alias("r")).collect()
+        res = df.select(sf.get(df.data, 1).alias("r")).collect()
         assert res == [Row(r="b")]
 
-        res = df.select(F.get(df.data, -1).alias("r")).collect()
+        res = df.select(sf.get(df.data, -1).alias("r")).collect()
         assert res == [Row(r=None)]
 
-        res = df.select(F.get(df.data, 3).alias("r")).collect()
+        res = df.select(sf.get(df.data, 3).alias("r")).collect()
         assert res == [Row(r=None)]
 
-        res = df.select(F.get(df.data, "index").alias("r")).collect()
+        res = df.select(sf.get(df.data, "index").alias("r")).collect()
         assert res == [Row(r='b')]
 
-        res = df.select(F.get(df.data, F.col("index") - 1).alias("r")).collect()
+        res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect()
         assert res == [Row(r='a')]
 
     def test_flatten(self, spark):
         df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
 
-        res = df.select(F.flatten(df.data).alias("r")).collect()
+        res = df.select(sf.flatten(df.data).alias("r")).collect()
         assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
 
     def test_array_compact(self, spark):
         df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data'])
 
-        res = df.select(F.array_compact(df.data).alias("v")).collect()
+        res = df.select(sf.array_compact(df.data).alias("v")).collect()
         assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])]
 
     def test_array_remove(self, spark):
         df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
 
-        res = df.select(F.array_remove(df.data, 1).alias("v")).collect()
+        res = df.select(sf.array_remove(df.data, 1).alias("v")).collect()
         assert res == [Row(v=[2, 3]), Row(v=[])]
 
     def test_array_agg(self, spark):
         df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])
 
-        res = df.groupBy("group").agg(F.array_agg("c").alias("r")).collect()
+        res = df.groupBy("group").agg(sf.array_agg("c").alias("r")).collect()
         assert res[0] == Row(group="A", r=[1, 1, 2])
 
     def test_collect_list(self, spark):
         df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])
 
-        res = df.groupBy("group").agg(F.collect_list("c").alias("r")).collect()
+        res = df.groupBy("group").agg(sf.collect_list("c").alias("r")).collect()
         assert res[0] == Row(group="A", r=[1, 1, 2])
 
     def test_array_append(self, spark):
         df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"])
 
-        res = df.select(F.array_append(df.c1, df.c2).alias("r")).collect()
+        res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect()
         assert res == [Row(r=['b', 'a', 'c', 'c'])]
 
-        res = df.select(F.array_append(df.c1, 'x')).collect()
+        res = df.select(sf.array_append(df.c1, 'x')).collect()
         assert res == [Row(r=['b', 'a', 'c', 'x'])]
 
     def test_array_insert(self, spark):
         df = spark.createDataFrame(
@@ -137,21 +137,21 @@ def test_array_insert(self, spark):
             ['data', 'pos', 'val'],
         )
 
-        res = df.select(F.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
         assert res == [
             Row(data=['a', 'd', 'b', 'c']),
             Row(data=['a', 'd', 'b', 'c', 'e']),
             Row(data=['c', 'b', 'd', 'a']),
         ]
 
-        res = df.select(F.array_insert(df.data, 5, 'hello').alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect()
         assert res == [
             Row(data=['a', 'b', 'c', None, 'hello']),
             Row(data=['a', 'b', 'c', 'e', 'hello']),
             Row(data=['c', 'b', 'a', None, 'hello']),
         ]
 
-        res = df.select(F.array_insert(df.data, -5, 'hello').alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect()
         assert res == [
             Row(data=['hello', None, 'a', 'b', 'c']),
             Row(data=['hello', 'a', 'b', 'c', 'e']),
@@ -160,53 +160,53 @@ def test_array_insert(self, spark):
     def test_slice(self, spark):
         df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
 
-        res = df.select(F.slice(df.x, 2, 2).alias("sliced")).collect()
+        res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect()
         assert res == [Row(sliced=[2, 3]), Row(sliced=[5])]
 
     def test_sort_array(self, spark):
         df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
 
-        res = df.select(F.sort_array(df.data).alias('r')).collect()
+        res = df.select(sf.sort_array(df.data).alias('r')).collect()
         assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
 
-        res = df.select(F.sort_array(df.data, asc=False).alias('r')).collect()
+        res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect()
         assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
 
     @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")])
     def test_array_join(self, spark, null_replacement, expected_joined_2):
         df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
 
-        res = df.select(F.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
+        res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
         assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)]
 
     def test_array_position(self, spark):
         df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
 
-        res = df.select(F.array_position(df.data, "a").alias("pos")).collect()
+        res = df.select(sf.array_position(df.data, "a").alias("pos")).collect()
         assert res == [Row(pos=3), Row(pos=0)]
 
     def test_array_preprend(self, spark):
         df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
 
-        res = df.select(F.array_prepend(df.data, 1).alias("pre")).collect()
+        res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect()
         assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])]
 
     def test_array_repeat(self, spark):
         df = spark.createDataFrame([('ab',)], ['data'])
 
-        res = df.select(F.array_repeat(df.data, 3).alias('r')).collect()
+        res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect()
         assert res == [Row(r=['ab', 'ab', 'ab'])]
 
     def test_array_size(self, spark):
         df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data'])
 
-        res = df.select(F.array_size(df.data).alias('r')).collect()
+        res = df.select(sf.array_size(df.data).alias('r')).collect()
         assert res == [Row(r=3), Row(r=None)]
 
     def test_array_sort(self, spark):
         df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
 
-        res = df.select(F.array_sort(df.data).alias('r')).collect()
+        res = df.select(sf.array_sort(df.data).alias('r')).collect()
         assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
 
     def test_arrays_overlap(self, spark):
@@ -214,13 +214,13 @@ def test_arrays_overlap(self, spark):
             [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y']
         )
 
-        res = df.select(F.arrays_overlap(df.x, df.y).alias("overlap")).collect()
+        res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect()
         assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)]
 
     def test_arrays_zip(self, spark):
         df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3'])
 
-        res = df.select(F.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
+        res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
         # FIXME: The structure of the results should be the same
         if USE_ACTUAL_SPARK:
             assert res == [
diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py
index 8a03fd68..2a51d9b8 100644
--- a/tests/fast/spark/test_spark_functions_date.py
+++ b/tests/fast/spark/test_spark_functions_date.py
@@ -2,7 +2,7 @@
 import pytest
 
 _ = pytest.importorskip("duckdb.experimental.spark")
-from datetime import date, datetime, timezone
+from datetime import date, datetime
 
 from spark_namespace import USE_ACTUAL_SPARK
 from spark_namespace.sql import functions as F
@@ -217,3 +217,26 @@ def test_add_months(self, spark):
         assert result[0].with_literal == date(2024, 6, 12)
         assert result[0].with_str == date(2024, 7, 12)
         assert result[0].with_col == date(2024, 7, 12)
+
+    def test_date_diff(self, spark):
+        df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ["d1", "d2"])
+
+        result_data = df.select(F.date_diff(col("d2").cast('DATE'), col("d1").cast('DATE')).alias("diff")).collect()
+        assert result_data[0]["diff"] == 32
+
+        result_data = df.select(F.date_diff(col("d1").cast('DATE'), col("d2").cast('DATE')).alias("diff")).collect()
+        assert result_data[0]["diff"] == -32
+
+    def test_try_to_timestamp(self, spark):
+        df = spark.createDataFrame([("1997-02-28 10:30:00",), ("2024-01-01",), ("invalid",)], ["t"])
+        res = df.select(F.try_to_timestamp(df.t).alias("dt")).collect()
+        assert res[0].dt == datetime(1997, 2, 28, 10, 30)
+        assert res[1].dt == datetime(2024, 1, 1, 0, 0)
+        assert res[2].dt is None
+
+    def test_try_to_timestamp_with_format(self, spark):
+        df = spark.createDataFrame([("1997-02-28 10:30:00",), ("2024-01-01",), ("invalid",)], ["t"])
+        res = df.select(F.try_to_timestamp(df.t, format=F.lit("%Y-%m-%d %H:%M:%S")).alias("dt")).collect()
+        assert res[0].dt == datetime(1997, 2, 28, 10, 30)
+        assert res[1].dt is None
+        assert res[2].dt is None
\ No newline at end of file
diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py
index 39ca4ce2..3f5ee31b 100644
--- a/tests/fast/spark/test_spark_functions_null.py
+++ b/tests/fast/spark/test_spark_functions_null.py
@@ -112,3 +112,8 @@ def test_isnotnull(self, spark):
             Row(a=1, b=None, r1=True, r2=False),
             Row(a=None, b=2, r1=False, r2=True),
         ]
+
+    def test_equal_null(self, spark):
+        df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ("a", "b"))
+        res = df.select(F.equal_null("a", F.col("b")).alias("r")).collect()
+        assert res == [Row(r=True), Row(r=False), Row(r=True)]
\ No newline at end of file
diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py
index 3d7b5c3b..9c4bafb9 100644
--- a/tests/fast/spark/test_spark_functions_numeric.py
+++ b/tests/fast/spark/test_spark_functions_numeric.py
@@ -5,7 +5,7 @@
 import math
 import numpy as np
 from spark_namespace import USE_ACTUAL_SPARK
-from spark_namespace.sql import functions as F
+from spark_namespace.sql import functions as sf
 from spark_namespace.sql.types import Row
 
 
@@ -16,7 +16,7 @@ def test_greatest(self, spark):
             (4, 3),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("greatest_value", F.greatest(F.col("firstColumn"), F.col("secondColumn")))
+        df = df.withColumn("greatest_value", sf.greatest(sf.col("firstColumn"), sf.col("secondColumn")))
         res = df.select("greatest_value").collect()
         assert res == [
             Row(greatest_value=2),
@@ -29,7 +29,7 @@ def test_least(self, spark):
             (4, 3),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("least_value", F.least(F.col("firstColumn"), F.col("secondColumn")))
+        df = df.withColumn("least_value", sf.least(sf.col("firstColumn"), sf.col("secondColumn")))
         res = df.select("least_value").collect()
         assert res == [
             Row(least_value=1),
@@ -42,7 +42,7 @@ def test_ceil(self, spark):
             (2.9,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("ceil_value", F.ceil(F.col("firstColumn")))
+        df = df.withColumn("ceil_value", sf.ceil(sf.col("firstColumn")))
         res = df.select("ceil_value").collect()
         assert res == [
             Row(ceil_value=2),
@@ -55,7 +55,7 @@ def test_floor(self, spark):
             (2.9,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("floor_value", F.floor(F.col("firstColumn")))
+        df = df.withColumn("floor_value", sf.floor(sf.col("firstColumn")))
         res = df.select("floor_value").collect()
         assert res == [
             Row(floor_value=1),
@@ -68,7 +68,7 @@ def test_abs(self, spark):
             (-2.9,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("abs_value", F.abs(F.col("firstColumn")))
+        df = df.withColumn("abs_value", sf.abs(sf.col("firstColumn")))
        res = df.select("abs_value").collect()
         assert res == [
             Row(abs_value=1.1),
@@ -81,7 +81,7 @@ def test_sqrt(self, spark):
             (9,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("sqrt_value", F.sqrt(F.col("firstColumn")))
+        df = df.withColumn("sqrt_value", sf.sqrt(sf.col("firstColumn")))
         res = df.select("sqrt_value").collect()
         assert res == [
             Row(sqrt_value=2.0),
@@ -94,7 +94,7 @@ def test_cbrt(self, spark):
             (27,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("cbrt_value", F.cbrt(F.col("firstColumn")))
+        df = df.withColumn("cbrt_value", sf.cbrt(sf.col("firstColumn")))
         res = df.select("cbrt_value").collect()
         assert pytest.approx(res[0].cbrt_value) == 2.0
         assert pytest.approx(res[1].cbrt_value) == 3.0
@@ -105,7 +105,7 @@ def test_cos(self, spark):
             (3.14159,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("cos_value", F.cos(F.col("firstColumn")))
+        df = df.withColumn("cos_value", sf.cos(sf.col("firstColumn")))
         res = df.select("cos_value").collect()
         assert len(res) == 2
         assert res[0].cos_value == pytest.approx(1.0)
@@ -117,7 +117,7 @@ def test_acos(self, spark):
             (-1,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("acos_value", F.acos(F.col("firstColumn")))
+        df = df.withColumn("acos_value", sf.acos(sf.col("firstColumn")))
         res = df.select("acos_value").collect()
         assert len(res) == 2
         assert res[0].acos_value == pytest.approx(0.0)
@@ -129,7 +129,7 @@ def test_exp(self, spark):
             (0.0,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("exp_value", F.exp(F.col("firstColumn")))
+        df = df.withColumn("exp_value", sf.exp(sf.col("firstColumn")))
         res = df.select("exp_value").collect()
         round(res[0].exp_value, 2) == 2
         res[1].exp_value == 1
@@ -140,7 +140,7 @@ def test_factorial(self, spark):
             (5,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("factorial_value", F.factorial(F.col("firstColumn")))
+        df = df.withColumn("factorial_value", sf.factorial(sf.col("firstColumn")))
         res = df.select("factorial_value").collect()
         assert res == [
             Row(factorial_value=24),
@@ -153,7 +153,7 @@ def test_log2(self, spark):
             (8,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("log2_value", F.log2(F.col("firstColumn")))
+        df = df.withColumn("log2_value", sf.log2(sf.col("firstColumn")))
         res = df.select("log2_value").collect()
         assert res == [
             Row(log2_value=2.0),
@@ -166,7 +166,7 @@ def test_ln(self, spark):
             (1.0,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("ln_value", F.ln(F.col("firstColumn")))
+        df = df.withColumn("ln_value", sf.ln(sf.col("firstColumn")))
         res = df.select("ln_value").collect()
         round(res[0].ln_value, 2) == 1
         res[1].ln_value == 0
@@ -177,7 +177,7 @@ def test_degrees(self, spark):
             (0.0,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("degrees_value", F.degrees(F.col("firstColumn")))
+        df = df.withColumn("degrees_value", sf.degrees(sf.col("firstColumn")))
         res = df.select("degrees_value").collect()
         round(res[0].degrees_value, 2) == 180
         res[1].degrees_value == 0
@@ -188,7 +188,7 @@ def test_radians(self, spark):
             (0,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("radians_value", F.radians(F.col("firstColumn")))
+        df = df.withColumn("radians_value", sf.radians(sf.col("firstColumn")))
         res = df.select("radians_value").collect()
         round(res[0].radians_value, 2) == 3.14
         res[1].radians_value == 0
@@ -199,7 +199,7 @@ def test_atan(self, spark):
             (0,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("atan_value", F.atan(F.col("firstColumn")))
+        df = df.withColumn("atan_value", sf.atan(sf.col("firstColumn")))
         res = df.select("atan_value").collect()
         round(res[0].atan_value, 2) == 0.79
         res[1].atan_value == 0
@@ -212,19 +212,19 @@ def test_atan2(self, spark):
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
 
         # Both columns
-        df2 = df.withColumn("atan2_value", F.atan2(F.col("firstColumn"), "secondColumn"))
+        df2 = df.withColumn("atan2_value", sf.atan2(sf.col("firstColumn"), "secondColumn"))
         res = df2.select("atan2_value").collect()
         round(res[0].atan2_value, 2) == 0.79
         res[1].atan2_value == 0
 
         # Both literals
-        df2 = df.withColumn("atan2_value_lit", F.atan2(1, 1))
+        df2 = df.withColumn("atan2_value_lit", sf.atan2(1, 1))
         res = df2.select("atan2_value_lit").collect()
         round(res[0].atan2_value_lit, 2) == 0.79
         round(res[1].atan2_value_lit, 2) == 0.79
 
         # One literal, one column
-        df2 = df.withColumn("atan2_value_lit_col", F.atan2(1.0, F.col("secondColumn")))
+        df2 = df.withColumn("atan2_value_lit_col", sf.atan2(1.0, sf.col("secondColumn")))
         res = df2.select("atan2_value_lit_col").collect()
         round(res[0].atan2_value_lit_col, 2) == 0.79
         res[1].atan2_value_lit_col == 0
@@ -235,7 +235,7 @@ def test_tan(self, spark):
             (1,),
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
-        df = df.withColumn("tan_value", F.tan(F.col("firstColumn")))
+        df = df.withColumn("tan_value", sf.tan(sf.col("firstColumn")))
         res = df.select("tan_value").collect()
         res[0].tan_value == 0
         round(res[1].tan_value, 2) == 1.56
@@ -249,9 +249,9 @@ def test_round(self, spark):
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
         df = (
-            df.withColumn("round_value", F.round("firstColumn"))
-            .withColumn("round_value_1", F.round(F.col("firstColumn"), 1))
-            .withColumn("round_value_minus_1", F.round("firstColumn", -1))
+            df.withColumn("round_value", sf.round("firstColumn"))
+            .withColumn("round_value_1", sf.round(sf.col("firstColumn"), 1))
+            .withColumn("round_value_minus_1", sf.round("firstColumn", -1))
         )
         res = df.select("round_value", "round_value_1", "round_value_minus_1").collect()
         assert res == [
@@ -269,9 +269,9 @@ def test_bround(self, spark):
         ]
         df = spark.createDataFrame(data, ["firstColumn"])
         df = (
-            df.withColumn("round_value", F.bround(F.col("firstColumn")))
-            .withColumn("round_value_1", F.bround(F.col("firstColumn"), 1))
-            .withColumn("round_value_minus_1", F.bround(F.col("firstColumn"), -1))
+            df.withColumn("round_value", sf.bround(sf.col("firstColumn")))
+            .withColumn("round_value_1", sf.bround(sf.col("firstColumn"), 1))
+            .withColumn("round_value_minus_1", sf.bround(sf.col("firstColumn"), -1))
         )
         res = df.select("round_value", "round_value_1", "round_value_minus_1").collect()
         assert res == [
@@ -283,7 +283,7 @@ def test_bround(self, spark):
 
     def test_asin(self, spark):
         df = spark.createDataFrame([(0,), (2,)], ["value"])
-        df = df.withColumn("asin_value", F.asin("value"))
+        df = df.withColumn("asin_value", sf.asin("value"))
         res = df.select("asin_value").collect()
 
         assert res[0].asin_value == 0
@@ -301,36 +301,36 @@ def test_corr(self, spark):
         # Have to use a groupby to test this as agg is not yet implemented without
         df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"])
-        res = df.groupBy("g").agg(F.corr("a", "b").alias('c')).collect()
+        res = df.groupBy("g").agg(sf.corr("a", "b").alias('c')).collect()
 
         assert pytest.approx(res[0].c) == 1
 
     def test_cot(self, spark):
         df = spark.createDataFrame([(math.radians(45),)], ["value"])
-        res = df.select(F.cot(df["value"]).alias("cot")).collect()
+        res = df.select(sf.cot(df["value"]).alias("cot")).collect()
 
         assert pytest.approx(res[0].cot) == 1
 
     def test_e(self, spark):
         df = spark.createDataFrame([("value",)], ["value"])
-        res = df.select(F.e().alias("e")).collect()
+        res = df.select(sf.e().alias("e")).collect()
 
         assert pytest.approx(res[0].e) == math.e
 
     def test_pi(self, spark):
         df = spark.createDataFrame([("value",)], ["value"])
-        res = df.select(F.pi().alias("pi")).collect()
+        res = df.select(sf.pi().alias("pi")).collect()
 
         assert pytest.approx(res[0].pi) == math.pi
 
     def test_pow(self, spark):
         df = spark.createDataFrame([(2, 3)], ["a", "b"])
-        res = df.select(F.pow(df["a"], df["b"]).alias("pow")).collect()
+        res = df.select(sf.pow(df["a"], df["b"]).alias("pow")).collect()
 
         assert res[0].pow == 8
 
     def test_random(self, spark):
         df = spark.range(0, 2, 1)
-        res = df.withColumn('rand', F.rand()).collect()
+        res = df.withColumn('rand', sf.rand()).collect()
 
         assert isinstance(res[0].rand, float)
         assert res[0].rand >= 0 and res[0].rand < 1
@@ -338,13 +338,21 @@ def test_random(self, spark):
         assert isinstance(res[1].rand, float)
         assert res[1].rand >= 0 and res[1].rand < 1
 
-    @pytest.mark.parametrize("sign_func", [F.sign, F.signum])
+    @pytest.mark.parametrize("sign_func", [sf.sign, sf.signum])
     def test_sign(self, spark, sign_func):
-        df = spark.range(1).select(sign_func(F.lit(-5).alias("v1")), sign_func(F.lit(6).alias("v2")))
+        df = spark.range(1).select(sign_func(sf.lit(-5).alias("v1")), sign_func(sf.lit(6).alias("v2")))
         res = df.collect()
         assert res == [Row(v1=-1.0, v2=1.0)]
 
     def test_sin(self, spark):
         df = spark.range(1)
-        res = df.select(F.sin(F.lit(math.radians(90))).alias("v")).collect()
+        res = df.select(sf.sin(sf.lit(math.radians(90))).alias("v")).collect()
         assert res == [Row(v=1.0)]
+
+    def test_negative(self, spark):
+        df = spark.createDataFrame([(0,), (2,), (-3,)], ["value"])
+        df = df.withColumn("value", sf.negative(sf.col("value")))
+        res = df.collect()
+        assert res[0].value == 0
+        assert res[1].value == -2
+        assert res[2].value == 3
\ No newline at end of file
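
A minimal usage sketch of the forward-ported functions against the experimental Spark API; the SparkSession import path and builder call mirror pyspark and are assumed here, and the expected values follow the docstrings and tests above:

    # assumed entry point -- mirrors pyspark's package layout
    from duckdb.experimental.spark.sql import SparkSession
    from duckdb.experimental.spark.sql import functions as sf

    spark = SparkSession.builder.getOrCreate()

    # equal_null: null-safe equality, true when both sides are NULL
    df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ["a", "b"])
    df.select(sf.equal_null("a", sf.col("b")).alias("r")).show()  # True, False, True

    # negative: unary minus (not -abs)
    spark.createDataFrame([(2,), (-3,)], ["v"]).select(sf.negative("v").alias("n")).show()  # -2, 3

    # date_diff: days from `start` to `end`, i.e. end - start
    d = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"])
    d.select(sf.date_diff(sf.col("d2").cast("DATE"), sf.col("d1").cast("DATE")).alias("days")).show()  # 32

    # try_to_timestamp: invalid input yields NULL instead of raising; explicit
    # formats use strptime syntax because try_strptime backs the implementation
    t = spark.createDataFrame([("1997-02-28 10:30:00",), ("invalid",)], ["t"])
    t.select(sf.try_to_timestamp(t.t).alias("dt")).show()  # 1997-02-28 10:30:00, NULL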