Skip to content

Commit

Permalink
Trying to fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
vcfgv committed Aug 5, 2021
1 parent 68ffe42 commit 611a2ca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
23 changes: 8 additions & 15 deletions mars/dataframe/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
import pandas as pd
from ..utils import lazy_import
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional

ray = lazy_import('ray')
parallel_it = lazy_import('ray.util.iter')
Expand All @@ -24,12 +24,10 @@

class RayObjectPiece:
def __init__(self,
addr: str,
obj_ref: ray.ObjectRef,
addr,
obj_ref,
row_ids: Optional[List[int]]):
"""
RayObjectPiece is a single entity holding the object ref
"""
"""RayObjectPiece is a single entity holding the object ref"""
self.row_ids = row_ids
self.addr = addr
self.obj_ref = obj_ref
Expand All @@ -51,9 +49,7 @@ def __init__(self,
record_pieces: List[RayObjectPiece],
shuffle: bool,
shuffle_seed: int):
"""
RecordBatch holds a list of RayObjectPieces
"""
"""RecordBatch holds a list of RayObjectPieces"""
self._shard_id = shard_id
self._prefix = prefix
self.record_pieces = record_pieces
Expand All @@ -68,9 +64,7 @@ def shard_id(self) -> int:
return self._shard_id

def __iter__(self) -> Iterable[pd.DataFrame]:
"""
Returns the item_generator required from ParallelIteratorWorker
"""
"""Returns the item_generator required from ParallelIteratorWorker"""
if self.shuffle:
np.random.seed(self.shuffle_seed)
np.random.shuffle(self.record_pieces)
Expand Down Expand Up @@ -116,7 +110,6 @@ def from_mars(df,
df = df.rebalance(axis=0, num_partitions=num_shards)
df.execute()
# it's ensured that df has been executed
chunk_addr_refs: List[Tuple(str, ray.ObjectRef)] = df.fetch(only_refs=True)
chunk_addr_refs = df.fetch(only_refs=True)
record_pieces = [RayObjectPiece(addr, obj_ref, None) for addr, obj_ref in chunk_addr_refs]
return _create_ml_dataset("from_mars", record_pieces, num_shards,
shuffle, shuffle_seed)
return _create_ml_dataset("from_mars", record_pieces, num_shards, shuffle, shuffle_seed)
4 changes: 3 additions & 1 deletion mars/dataframe/tests/test_mldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from mars.dataframe.dataset import RayMLDataset
from mars.deploy.oscar.ray import new_cluster, _load_config
from mars.deploy.oscar.session import new_session
from mars.tests.conftest import * # noqa
from mars.tests.core import require_ray
from mars.tests.conftest import * # noqa
from mars.utils import lazy_import


Expand All @@ -33,13 +33,15 @@
sklearn_datasets = lazy_import('sklearn.datasets')


# TODO: register custom marks
def require_xgboost_ray(func):
    """Mark *func* as an xgboost_ray test, skipped when xgboost_ray is absent.

    Applies the custom ``xgboost_ray`` pytest mark and a ``skipif`` guard so
    the test only runs in environments where xgboost_ray is installed.
    """
    if pytest:
        marks = (
            pytest.mark.xgboost_ray,
            pytest.mark.skipif(xgboost_ray is None,
                               reason='xgboost_ray not installed'),
        )
        for mark in marks:
            func = mark(func)
    return func


# TODO: register custom marks
def require_sklearn(func):
if pytest:
func = pytest.mark.sklearn(func)
Expand Down

0 comments on commit 611a2ca

Please sign in to comment.