Skip to content

Commit

Permalink
Code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
vcfgv committed Aug 6, 2021
1 parent 22a630e commit b463133
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
11 changes: 4 additions & 7 deletions mars/dataframe/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
import pandas as pd
from ..utils import lazy_import
from typing import Iterable, List, Optional
from typing import Iterable, List

ray = lazy_import('ray')
parallel_it = lazy_import('ray.util.iter')
Expand All @@ -25,17 +25,13 @@
class RayObjectPiece:
def __init__(self,
addr,
obj_ref,
row_ids: Optional[List[int]]):
obj_ref):
"""Represents a single entity holding the object ref."""
self.row_ids = row_ids
self.addr = addr
self.obj_ref = obj_ref

def read(self, shuffle: bool) -> pd.DataFrame:
df: pd.DataFrame = ray.get(self.obj_ref)
if self.row_ids:
df = df.loc[self.row_ids]

if shuffle:
df = df.sample(frac=1.0)
Expand All @@ -56,6 +52,7 @@ def __init__(self,
self.shuffle = shuffle
self.shuffle_seed = shuffle_seed

@property
def prefix(self) -> str:
return self._prefix

Expand Down Expand Up @@ -111,5 +108,5 @@ def from_mars(df,
df.execute()
# it's ensured that df has been executed
chunk_addr_refs = df.fetch(only_refs=True)
record_pieces = [RayObjectPiece(addr, obj_ref, None) for addr, obj_ref in chunk_addr_refs]
record_pieces = [RayObjectPiece(addr, obj_ref) for addr, obj_ref in chunk_addr_refs]
return _create_ml_dataset("from_mars", record_pieces, num_shards, shuffle, shuffle_seed)
42 changes: 41 additions & 1 deletion mars/dataframe/tests/test_mldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,46 @@ async def create_cluster(request):
yield client


@require_ray
@pytest.mark.asyncio
async def test_mldataset_related_classes(ray_large_cluster):
from mars.dataframe.dataset import RecordBatch, RayObjectPiece
# in order to pass checks
value = np.random.rand(10, 10)
df = pd.DataFrame(value)
if ray:
obj_ref = ray.put(df)
piece = RayObjectPiece(addr='address0', obj_ref=obj_ref)
data = piece.read(shuffle=False)
shuffle_data = piece.read(shuffle=True)
pd.testing.assert_frame_equal(data, df)
assert not shuffle_data.equals(df)

batch = RecordBatch(shard_id=0,
prefix='test_batch',
record_pieces=[piece],
shuffle=False,
shuffle_seed=None
)
assert batch.shard_id == 0
assert batch.prefix == 'test_batch'
# only one data in batch
data = list(batch.__iter__())[0]
pd.testing.assert_frame_equal(data, df)

shuffle_batch = RecordBatch(shard_id=1,
prefix='shuffle_batch',
record_pieces=[piece],
shuffle=True,
shuffle_seed=0
)
assert shuffle_batch.shard_id == 1
assert shuffle_batch.prefix == 'shuffle_batch'
# only one data in batch
shuffle_data = list(shuffle_batch.__iter__())[0]
assert not shuffle_data.equals(df)


@require_ray
@pytest.mark.asyncio
async def test_convert_to_mldataset(ray_large_cluster, create_cluster):
Expand All @@ -72,7 +112,7 @@ async def test_convert_to_mldataset(ray_large_cluster, create_cluster):
df: md.DataFrame = md.DataFrame(value, chunk_size=5)
df.execute()

ds = RayMLDataset.from_mars(df, num_shards=4)
ds = RayMLDataset.from_mars(df, num_shards=4, shuffle=True, shuffle_seed=0)
if ml_dataset:
assert isinstance(ds, ml_dataset.MLDataset)

Expand Down

0 comments on commit b463133

Please sign in to comment.