Commit
Trying to fix bugs
vcfgv committed Aug 5, 2021
1 parent 68ffe42 commit f7f6d4e
Showing 2 changed files with 15 additions and 16 deletions.
21 changes: 8 additions & 13 deletions mars/dataframe/dataset.py
@@ -15,7 +15,7 @@
 import numpy as np
 import pandas as pd
 from ..utils import lazy_import
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional
 
 ray = lazy_import('ray')
 parallel_it = lazy_import('ray.util.iter')
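
Why the import change matters: `ray` is only lazily imported in this module, so it may be absent at runtime, and `Tuple` loses its only use once the annotation below is deleted. A simplified sketch of the lazy-import pattern (not Mars's actual helper, which is more elaborate):

import importlib

def lazy_import(name):
    # Return the module if available, else None; callers must check
    # before touching attributes such as ray.ObjectRef.
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

ray = lazy_import('ray')  # may be None when Ray is not installed
# ray.ObjectRef would then raise AttributeError as soon as it is evaluated.
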
@@ -24,12 +24,11 @@
 
 class RayObjectPiece:
     def __init__(self,
-                 addr: str,
-                 obj_ref: ray.ObjectRef,
+                 addr,
+                 obj_ref,
                  row_ids: Optional[List[int]]):
         """
-        RayObjectPiece is a single entity holding the object ref
-        """
+        RayObjectPiece is a single entity holding the object ref."""
         self.row_ids = row_ids
         self.addr = addr
         self.obj_ref = obj_ref
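
Dropping the `addr: str` and `obj_ref: ray.ObjectRef` annotations fixes a definition-time failure: function annotations are evaluated when the `def` statement executes, so `ray.ObjectRef` is resolved as soon as dataset.py is imported, even if Ray is never used. One alternative that keeps the hints without the eager lookup (a sketch using PEP 563, not what this commit does):

from __future__ import annotations  # annotations become strings, evaluated lazily

def make_piece(addr: str, obj_ref: ray.ObjectRef):
    # 'ray.ObjectRef' is never evaluated at definition time here, so the
    # module imports cleanly even when ray = lazy_import('ray') yields None.
    ...
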
@@ -51,9 +50,7 @@ def __init__(self,
                  record_pieces: List[RayObjectPiece],
                  shuffle: bool,
                  shuffle_seed: int):
-        """
-        RecordBatch holds a list of RayObjectPieces
-        """
+        """RecordBatch holds a list of RayObjectPieces."""
         self._shard_id = shard_id
         self._prefix = prefix
         self.record_pieces = record_pieces
@@ -69,8 +66,7 @@ def shard_id(self) -> int:
 
     def __iter__(self) -> Iterable[pd.DataFrame]:
         """
-        Returns the item_generator required from ParallelIteratorWorker
-        """
+        Returns the item_generator required from ParallelIteratorWorker."""
         if self.shuffle:
             np.random.seed(self.shuffle_seed)
             np.random.shuffle(self.record_pieces)
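
Note on the shuffle above: `np.random.seed` reseeds NumPy's global RNG, which makes the shuffle reproducible for a fixed seed but also affects every other `np.random` call in the process. A local generator avoids that side effect (a sketch, not part of this commit):

import numpy as np

pieces = list(range(8))
rng = np.random.default_rng(seed=42)  # private RNG state
rng.shuffle(pieces)                   # in-place, reproducible, no global effect
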
@@ -116,7 +112,6 @@ def from_mars(df,
         df = df.rebalance(axis=0, num_partitions=num_shards)
         df.execute()
         # it's ensured that df has been executed
-        chunk_addr_refs: List[Tuple(str, ray.ObjectRef)] = df.fetch(only_refs=True)
+        chunk_addr_refs = df.fetch(only_refs=True)
         record_pieces = [RayObjectPiece(addr, obj_ref, None) for addr, obj_ref in chunk_addr_refs]
-        return _create_ml_dataset("from_mars", record_pieces, num_shards,
-                                  shuffle, shuffle_seed)
+        return _create_ml_dataset("from_mars", record_pieces, num_shards, shuffle, shuffle_seed)
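
The deleted annotation had a second bug besides the eager `ray.ObjectRef` lookup: `Tuple(str, ray.ObjectRef)` calls `typing.Tuple` instead of subscripting it; the correct spelling is `Tuple[str, ray.ObjectRef]`. Local-variable annotations are not evaluated at runtime (PEP 526), so the line never crashed, but it misled type checkers. A corrected variant, had the hint been kept (hypothetical sketch):

from typing import List, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
    import ray  # seen only by static type checkers, never imported at runtime

def fetch_refs(df) -> "List[Tuple[str, ray.ObjectRef]]":
    # Square brackets subscript the generic type; Tuple(...) would be a call.
    return df.fetch(only_refs=True)
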
10 changes: 7 additions & 3 deletions mars/dataframe/tests/test_mldataset.py
@@ -21,8 +21,8 @@
 from mars.dataframe.dataset import RayMLDataset
 from mars.deploy.oscar.ray import new_cluster, _load_config
 from mars.deploy.oscar.session import new_session
-from mars.tests.conftest import *  # noqa
 from mars.tests.core import require_ray
+from ...tests.conftest import *  # noqa
 from mars.utils import lazy_import
 
 
@@ -33,13 +33,15 @@
 sklearn_datasets = lazy_import('sklearn.datasets')
 
 
+# TODO: register custom marks
 def require_xgboost_ray(func):
     if pytest:
         func = pytest.mark.xgboost_ray(func)
         func = pytest.mark.skipif(xgboost_ray is None, reason='xgboost_ray not installed')(func)
     return func
 
 
+# TODO: register custom marks
 def require_sklearn(func):
     if pytest:
         func = pytest.mark.sklearn(func)
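
The two TODOs refer to pytest's mark registration: applying `pytest.mark.xgboost_ray` without registering the mark emits PytestUnknownMarkWarning. One way to resolve them later (a sketch assuming a conftest.py in the test tree):

# conftest.py
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "xgboost_ray: tests that require the xgboost_ray package")
    config.addinivalue_line(
        "markers", "sklearn: tests that require scikit-learn")
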
@@ -72,7 +74,8 @@ async def test_convert_to_mldataset(ray_large_cluster, create_cluster):
     df.execute()
 
     ds = RayMLDataset.from_mars(df, num_shards=4)
-    assert isinstance(ds, ml_dataset.MLDataset)
+    if ml_dataset:
+        assert isinstance(ds, ml_dataset.MLDataset)
 
 
 @require_ray
@@ -91,7 +94,8 @@ async def test_mars_with_xgboost(ray_large_cluster, create_cluster):
 
     num_shards = 4
     ds = RayMLDataset.from_mars(df, num_shards=num_shards)
-    assert isinstance(ds, ml_dataset.MLDataset)
+    if ml_dataset:
+        assert isinstance(ds, ml_dataset.MLDataset)
 
     # train
     train_set = RayDMatrix(ds, "target")
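
In both tests, the `if ml_dataset:` guard keeps the assertion from raising AttributeError when `ray.util.data` is unavailable, at the price of passing vacuously. An explicit skip records the reason instead (a sketch, not what this commit does):

import pytest

def test_convert_to_mldataset_checked(df):
    # 'df' stands in for the executed Mars DataFrame built above.
    ml_dataset = pytest.importorskip("ray.util.data")  # skip, don't silently pass
    ds = RayMLDataset.from_mars(df, num_shards=4)
    assert isinstance(ds, ml_dataset.MLDataset)
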
