Merge pull request #11935 from mariotaddeucci/feature/pyspark-functions
[Python] Add pyspark hash and organize unit tests
Mytherin committed May 15, 2024
2 parents bd40daf + c1f4c15 commit 18fe304
Showing 4 changed files with 130 additions and 17 deletions.
109 changes: 96 additions & 13 deletions tools/pythonpkg/duckdb/experimental/spark/sql/functions.py
@@ -1,9 +1,18 @@
-from .column import Column, _get_expr
-from typing import Any, Callable, overload, Union
-
-from duckdb import CaseExpression, ConstantExpression, ColumnExpression, FunctionExpression, Expression
+from typing import Any, Callable, Union, overload
+
+from duckdb import (
+    CaseExpression,
+    ColumnExpression,
+    ConstantExpression,
+    Expression,
+    FunctionExpression,
+)
+
+from ..exception import ContributionsAcceptedError
from ._typing import ColumnOrName
+from .column import Column, _get_expr


def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
@@ -66,7 +75,9 @@ def _inner_expr_or_val(val):


def struct(*cols: Column) -> Column:
-    return Column(FunctionExpression('struct_pack', *[_inner_expr_or_val(x) for x in cols]))
+    return Column(
+        FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])
+    )


def lit(col: Any) -> Column:
@@ -93,7 +104,11 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column:
[Row(d='-----')]
"""
return _invoke_function(
"regexp_replace", _to_column(str), ConstantExpression(pattern), ConstantExpression(replacement), ConstantExpression('g')
"regexp_replace",
_to_column(str),
ConstantExpression(pattern),
ConstantExpression(replacement),
ConstantExpression("g"),
)


@@ -322,13 +337,11 @@ def count(col: "ColumnOrName") -> Column:


@overload
-def transform(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column:
-    ...
+def transform(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: ...


@overload
-def transform(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column:
-    ...
+def transform(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: ...


def transform(
@@ -417,9 +430,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column":
[Row(s='abcd-123')]
"""
cols = [_to_column(expr) for expr in cols]
-    return _invoke_function(
-        "concat_ws", ConstantExpression(sep), *cols
-    )
+    return _invoke_function("concat_ws", ConstantExpression(sep), *cols)


def lower(col: "ColumnOrName") -> Column:
@@ -847,3 +858,75 @@ def length(col: "ColumnOrName") -> Column:
[Row(length=4)]
"""
return _invoke_function_over_columns("length", col)


def md5(col: "ColumnOrName") -> Column:
"""Calculates the MD5 digest and returns the value as a 32 character hex string.
.. versionadded:: 1.5.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
target column to compute on.
Returns
-------
:class:`~pyspark.sql.Column`
the column for computed results.
Examples
--------
>>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
[Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')]
"""
return _invoke_function_over_columns("md5", col)


def sha2(col: "ColumnOrName", numBits: int) -> Column:
"""Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
and SHA-512). The numBits indicates the desired bit length of the result, which must have a
value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
.. versionadded:: 1.5.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
target column to compute on.
numBits : int
the desired bit length of the result, which must have a
value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
Returns
-------
:class:`~pyspark.sql.Column`
the column for computed results.
Examples
--------
>>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"])
>>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False)
+-----+----------------------------------------------------------------+
|name |sha2 |
+-----+----------------------------------------------------------------+
|Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043|
|Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961|
+-----+----------------------------------------------------------------+
"""

if numBits not in {224, 256, 384, 512, 0}:
raise ValueError("numBits should be one of {224, 256, 384, 512, 0}")

if numBits == 256:
return _invoke_function_over_columns("sha256", col)

raise ContributionsAcceptedError(
"SHA-224, SHA-384, and SHA-512 are not supported yet."
)
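
For context, here is a minimal usage sketch of the two new functions through DuckDB's experimental Spark API. It assumes the SparkSession entry point from duckdb.experimental.spark.sql (the same API the test fixture below relies on); treat it as an illustration, not part of the commit.

from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("quack",)], ["word"])

# md5 maps directly to DuckDB's md5 function; sha2(col, 256) is routed to
# DuckDB's sha256. Every other accepted numBits currently raises
# ContributionsAcceptedError.
df.select(
    F.md5(F.col("word")).alias("md5"),
    F.sha2(F.col("word"), 256).alias("sha256"),
).show()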
30 changes: 30 additions & 0 deletions tools/pythonpkg/tests/fast/spark/test_spark_functions_hash.py
@@ -0,0 +1,30 @@
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")
from duckdb.experimental.spark.sql import functions as F


class TestSparkFunctionsHash(object):
def test_md5(self, spark):
data = [
("quack",),
]
res = (
spark.createDataFrame(data, ["firstColumn"])
.withColumn("hashed_value", F.md5(F.col("firstColumn")))
.select("hashed_value")
.collect()
)
assert res[0].hashed_value == "cfaf278e8f522c72644cee2a753d2845"

def test_sha256(self, spark):
data = [
("quack",),
]
res = (
spark.createDataFrame(data, ["firstColumn"])
.withColumn("hashed_value", F.sha2(F.col("firstColumn"), 256))
.select("hashed_value")
.collect()
)
assert res[0].hashed_value == "82d928273d067d774889d5df4249aaf73c0b04c64f04d6ed001441ce87a0853c"
4 changes: 2 additions & 2 deletions tools/pythonpkg/tests/fast/spark/test_spark_functions_numeric.py
@@ -1,11 +1,11 @@
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")
-from duckdb.experimental.spark.sql.types import Row
from duckdb.experimental.spark.sql import functions as F
+from duckdb.experimental.spark.sql.types import Row


-class TestNumericFunctions(object):
+class TestSparkFunctionsNumeric(object):
def test_greatest(self, spark):
data = [
(1, 2),
4 changes: 2 additions & 2 deletions tools/pythonpkg/tests/fast/spark/test_spark_functions_string.py
@@ -1,11 +1,11 @@
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")
-from duckdb.experimental.spark.sql.types import Row
from duckdb.experimental.spark.sql import functions as F
+from duckdb.experimental.spark.sql.types import Row


-class TestStringFunctions(object):
+class TestSparkFunctionsString(object):
def test_length(self, spark):
data = [
("firstRowFirstColumn",),
