Skip to content

Commit

Permalink
Fix read_ray_dataset incompatibility with the latest Ray (#3318)
Browse files: browse the repository at this point in the history
* Fix read_ray_dataset

* Fix

Co-authored-by: 刘宝 <po.lb@antgroup.com>
  • Loading branch information
fyrestone and 刘宝 committed Jan 12, 2023
1 parent dcc090d commit 7181876
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
22 changes: 13 additions & 9 deletions mars/dataframe/datasource/read_raydataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,20 @@ def read_ray_dataset(ds, columns=None, incremental_index=False, **kwargs):
import pyarrow as pa

try:
from ray.data.impl.pandas_block import PandasBlockSchema

if isinstance(schema, PandasBlockSchema):
dtypes = pd.Series(schema.types, index=schema.names)
elif isinstance(schema, pa.Schema):
dtypes = schema.empty_table().to_pandas().dtypes
else:
raise NotImplementedError(f"Unsupported format of schema {schema}")
except ImportError: # pragma: no cover
from ray.data._internal.pandas_block import PandasBlockSchema
except ImportError:
try:
from ray.data.impl.pandas_block import PandasBlockSchema
except ImportError: # pragma: no cover
PandasBlockSchema = type(None)

if isinstance(schema, PandasBlockSchema):
dtypes = pd.Series(schema.types, index=schema.names)
elif isinstance(schema, pa.Schema):
dtypes = schema.empty_table().to_pandas().dtypes
else:
raise NotImplementedError(f"Unsupported format of schema {schema}")

index_value = parse_index(pd.RangeIndex(-1))
columns_value = parse_index(dtypes.index, store_data=True)
op = DataFrameReadRayDataset(
Expand Down
6 changes: 3 additions & 3 deletions mars/dataframe/datasource/tests/test_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from ..read_csv import read_csv, DataFrameReadCSV
from ..read_sql import read_sql_table, read_sql_query, DataFrameReadSQL
from ..read_raydataset import (
read_raydataset,
read_ray_dataset,
DataFrameReadRayDataset,
read_ray_mldataset,
DataFrameReadMLDataset,
Expand Down Expand Up @@ -518,7 +518,7 @@ def test_read_sql():


@require_ray
def test_read_raydataset(ray_start_regular):
def test_read_ray_dataset(ray_start_regular):
test_df1 = pd.DataFrame(
{
"a": np.arange(10).astype(np.int64, copy=False),
Expand All @@ -533,7 +533,7 @@ def test_read_raydataset(ray_start_regular):
)
df = pd.concat([test_df1, test_df2])
ds = ray.data.from_pandas_refs([ray.put(test_df1), ray.put(test_df2)])
mdf = read_raydataset(ds)
mdf = read_ray_dataset(ds)

assert mdf.shape[1] == 2
pd.testing.assert_index_equal(df.columns, mdf.columns_value.to_pandas())
Expand Down

0 comments on commit 7181876

Please sign in to comment.