Skip to content

Commit

Permalink
Fix dtype errors in StringArrays (rapidsai#16111)
Browse files Browse the repository at this point in the history
This PR adds proxy classes for `ArrowStringArray` and `ArrowStringArrayNumpySemantics` that will increase the pandas test pass rate by 1%.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: rapidsai#16111
  • Loading branch information
galipremsagar committed Jun 27, 2024
1 parent 5d49fe6 commit a71c249
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
16 changes: 16 additions & 0 deletions python/cudf/cudf/pandas/_wrappers/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,22 @@ def Index__new__(cls, *args, **kwargs):
},
)

ArrowStringArrayNumpySemantics = make_final_proxy_type(
"ArrowStringArrayNumpySemantics",
_Unusable,
pd.core.arrays.string_arrow.ArrowStringArrayNumpySemantics,
fast_to_slow=_Unusable(),
slow_to_fast=_Unusable(),
)

ArrowStringArray = make_final_proxy_type(
"ArrowStringArray",
_Unusable,
pd.core.arrays.string_arrow.ArrowStringArray,
fast_to_slow=_Unusable(),
slow_to_fast=_Unusable(),
)

StringDtype = make_final_proxy_type(
"StringDtype",
_Unusable,
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/pandas/scripts/run-pandas-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ and not test_s3_roundtrip"
TEST_THAT_CRASH_PYTEST_WORKERS="not test_bitmasks_pyarrow \
and not test_large_string_pyarrow \
and not test_interchange_from_corrected_buffer_dtypes \
and not test_eof_states"
and not test_eof_states \
and not test_array_tz"

# TODO: Remove "not db" once a postgres & mysql container is set up on the CI
PANDAS_CI="1" timeout 30m python -m pytest -p cudf.pandas \
Expand Down
23 changes: 23 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,3 +1533,26 @@ def test_is_proxy_object():
assert is_proxy_object(np_arr_proxy)
assert is_proxy_object(s1)
assert not is_proxy_object(s2)


def test_arrow_string_arrays():
cu_s = xpd.Series(["a", "b", "c"])
pd_s = pd.Series(["a", "b", "c"])

cu_arr = xpd.arrays.ArrowStringArray._from_sequence(
cu_s, dtype=xpd.StringDtype("pyarrow")
)
pd_arr = pd.arrays.ArrowStringArray._from_sequence(
pd_s, dtype=pd.StringDtype("pyarrow")
)

tm.assert_equal(cu_arr, pd_arr)

cu_arr = xpd.core.arrays.string_arrow.ArrowStringArray._from_sequence(
cu_s, dtype=xpd.StringDtype("pyarrow_numpy")
)
pd_arr = pd.core.arrays.string_arrow.ArrowStringArray._from_sequence(
pd_s, dtype=pd.StringDtype("pyarrow_numpy")
)

tm.assert_equal(cu_arr, pd_arr)

0 comments on commit a71c249

Please sign in to comment.