Skip to content

Commit

Permalink
depr: deprecate and warn on legacy udf usage (#8617)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ncclementi authored Mar 12, 2024
1 parent 710f8ac commit e561889
Show file tree
Hide file tree
Showing 12 changed files with 402 additions and 446 deletions.
179 changes: 89 additions & 90 deletions ibis/backends/dask/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,69 +85,60 @@ def t_timestamp(con):
# UDF Functions
# -------------

with pytest.warns(FutureWarning, match="v9.0"):

@elementwise(input_type=["string"], output_type="int64")
def my_string_length(series, **kwargs):
return series.str.len() * 2
@elementwise(input_type=["string"], output_type="int64")
def my_string_length(series, **kwargs):
return series.str.len() * 2

@elementwise(input_type=[dt.double, dt.double], output_type=dt.double)
def my_add(series1, series2, **kwargs):
return series1 + series2

@elementwise(input_type=[dt.double, dt.double], output_type=dt.double)
def my_add(series1, series2, **kwargs):
return series1 + series2


@reduction(["double"], "double")
def my_mean(series):
return series.mean()


@reduction(
input_type=[dt.Timestamp(timezone="UTC")],
output_type=dt.Timestamp(timezone="UTC"),
)
def my_tz_min(series):
return series.min()


@elementwise(
input_type=[dt.Timestamp(timezone="UTC")],
output_type=dt.Timestamp(timezone="UTC"),
)
def my_tz_add_one(series):
return series + pd.Timedelta(1, unit="D")


@reduction(input_type=[dt.string], output_type=dt.int64)
def my_string_length_sum(series, **kwargs):
return (series.str.len() * 2).sum()


@reduction(input_type=[dt.double, dt.double], output_type=dt.double)
def my_corr(lhs, rhs, **kwargs):
return lhs.corr(rhs)
@reduction(["double"], "double")
def my_mean(series):
return series.mean()

@reduction(
input_type=[dt.Timestamp(timezone="UTC")],
output_type=dt.Timestamp(timezone="UTC"),
)
def my_tz_min(series):
return series.min()

@elementwise([dt.double], dt.double)
def add_one(x):
return x + 1.0
@elementwise(
input_type=[dt.Timestamp(timezone="UTC")],
output_type=dt.Timestamp(timezone="UTC"),
)
def my_tz_add_one(series):
return series + pd.Timedelta(1, unit="D")

@reduction(input_type=[dt.string], output_type=dt.int64)
def my_string_length_sum(series, **kwargs):
return (series.str.len() * 2).sum()

@elementwise([dt.double], dt.double)
def times_two(x):
return x * 2.0
@reduction(input_type=[dt.double, dt.double], output_type=dt.double)
def my_corr(lhs, rhs, **kwargs):
return lhs.corr(rhs)

@elementwise([dt.double], dt.double)
def add_one(x):
return x + 1.0

@analytic(input_type=["double"], output_type="double")
def zscore(series):
return (series - series.mean()) / series.std()
@elementwise([dt.double], dt.double)
def times_two(x):
return x * 2.0

@analytic(input_type=["double"], output_type="double")
def zscore(series):
return (series - series.mean()) / series.std()

@reduction(
input_type=[dt.double],
output_type=dt.Array(dt.double),
)
def collect(series):
return list(series)
@reduction(
input_type=[dt.double],
output_type=dt.Array(dt.double),
)
def collect(series):
return list(series)


# -----
Expand Down Expand Up @@ -371,34 +362,38 @@ def test_array_return_type_reduction_group_by(t, df):


def test_elementwise_udf_with_many_args(t2):
@elementwise(input_type=[dt.double] * 16 + [dt.int32] * 8, output_type=dt.double)
def my_udf(
c1,
c2,
c3,
c4,
c5,
c6,
c7,
c8,
c9,
c10,
c11,
c12,
c13,
c14,
c15,
c16,
c17,
c18,
c19,
c20,
c21,
c22,
c23,
c24,
):
return c1
with pytest.warns(FutureWarning, match="v9.0"):

@elementwise(
input_type=[dt.double] * 16 + [dt.int32] * 8, output_type=dt.double
)
def my_udf(
c1,
c2,
c3,
c4,
c5,
c6,
c7,
c8,
c9,
c10,
c11,
c12,
c13,
c14,
c15,
c16,
c17,
c18,
c19,
c20,
c21,
c22,
c23,
c24,
):
return c1

expr = my_udf(*([t2.a] * 8 + [t2.b] * 8 + [t2.c] * 8))
result = expr.execute()
Expand All @@ -408,30 +403,34 @@ def my_udf(


# -----------------
# Test raied errors
# Test raised errors
# -----------------


def test_udaf_parameter_mismatch():
with pytest.raises(TypeError):
with pytest.warns(FutureWarning, match="v9.0"):

@reduction(input_type=[dt.double], output_type=dt.double)
def my_corr(lhs, rhs, **kwargs):
pass
@reduction(input_type=[dt.double], output_type=dt.double)
def my_corr(lhs, rhs, **kwargs):
pass


def test_udf_parameter_mismatch():
with pytest.raises(TypeError):
with pytest.warns(FutureWarning, match="v9.0"):

@reduction(input_type=[], output_type=dt.double)
def my_corr2(lhs, **kwargs):
pass
@reduction(input_type=[], output_type=dt.double)
def my_corr2(lhs, **kwargs):
pass


def test_udf_error(t):
@elementwise(input_type=[dt.double], output_type=dt.double)
def error_udf(s):
raise ValueError("xxx")
with pytest.warns(FutureWarning, match="v9.0"):

@elementwise(input_type=[dt.double], output_type=dt.double)
def error_udf(s):
raise ValueError("xxx")

with pytest.raises(ValueError):
error_udf(t.c).execute()
13 changes: 7 additions & 6 deletions ibis/backends/dask/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,14 +473,15 @@ def test_window_on_and_by_key_as_window_input(t, df):
)

# Test UDF
with pytest.warns(FutureWarning, match="v9.0"):

@reduction(input_type=[dt.int64], output_type=dt.int64)
def count(v):
return len(v)
@reduction(input_type=[dt.int64], output_type=dt.int64)
def count(v):
return len(v)

@reduction(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def count_both(v1, v2):
return len(v1)
@reduction(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def count_both(v1, v2):
return len(v1)

tm.assert_series_equal(
count(t[order_by]).over(row_window).execute(),
Expand Down
23 changes: 11 additions & 12 deletions ibis/backends/datafusion/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,20 @@
pytest.importorskip("datafusion")
pc = pytest.importorskip("pyarrow.compute")

with pytest.warns(FutureWarning, match="v9.0"):

@elementwise(input_type=["string"], output_type="int64")
def my_string_length(arr, **kwargs):
# arr is a pyarrow.StringArray
return pc.cast(pc.multiply(pc.utf8_length(arr), 2), target_type="int64")
@elementwise(input_type=["string"], output_type="int64")
def my_string_length(arr, **kwargs):
# arr is a pyarrow.StringArray
return pc.cast(pc.multiply(pc.utf8_length(arr), 2), target_type="int64")

@elementwise(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def my_add(arr1, arr2, **kwargs):
return pc.add(arr1, arr2)

@elementwise(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def my_add(arr1, arr2, **kwargs):
return pc.add(arr1, arr2)


@reduction(input_type=[dt.float64], output_type=dt.float64)
def my_mean(arr):
return pc.mean(arr)
@reduction(input_type=[dt.float64], output_type=dt.float64)
def my_mean(arr):
return pc.mean(arr)


def test_udf(alltypes):
Expand Down
16 changes: 10 additions & 6 deletions ibis/backends/pandas/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ def test_quantile_multi_array_access(client, t, df):
def test_execute_with_same_hash_value_in_scope(
left, right, expected_value, expected_type, left_dtype, right_dtype
):
@udf.elementwise([left_dtype, right_dtype], left_dtype)
def my_func(x, _):
return x
with pytest.warns(FutureWarning, match="v9.0"):

@udf.elementwise([left_dtype, right_dtype], left_dtype)
def my_func(x, _):
return x

df = pd.DataFrame({"left": [left], "right": [right]})
con = ibis.pandas.connect()
Expand Down Expand Up @@ -255,9 +257,11 @@ def test_ifelse_returning_bool():
],
)
def test_signature_does_not_match_input_type(dtype, value):
@udf.elementwise([dtype], dtype)
def func(x):
return x
with pytest.warns(FutureWarning, match="v9.0"):

@udf.elementwise([dtype], dtype)
def func(x):
return x

df = pd.DataFrame({"col": [value]})
table = ibis.pandas.connect().from_dataframe(df)
Expand Down
Loading

0 comments on commit e561889

Please sign in to comment.