Integrate Mars DataFrame with Ray MLDataset #2294

Merged 20 commits on Aug 26, 2021
Changes from all commits
14 changes: 14 additions & 0 deletions mars/core/entity/executable.py
@@ -135,6 +135,13 @@ def fetch_log(self,
return fetch_log(self, session=session,
offsets=offsets, sizes=sizes)[0]

def _fetch_infos(self, fields=None, session=None, **kw):
from ...deploy.oscar.session import fetch_infos

session = _get_session(self, session)
self._check_session(session, 'fetch_infos')
return fetch_infos(self, fields=fields, session=session, **kw)

def _attach_session(self, session: SessionType):
if session not in self._executed_sessions:
_cleaner.register(self, session)
@@ -233,6 +240,13 @@ def _fetch(self, session: SessionType = None, **kw):
self._check_session(session, 'fetch')
return fetch(*self, session=session, **kw)

def _fetch_infos(self, fields=None, session=None, **kw):
from ...deploy.oscar.session import fetch_infos

session = _get_session(self, session)
self._check_session(session, 'fetch_infos')
return fetch_infos(*self, fields=fields, session=session, **kw)

def fetch(self, session: SessionType = None, **kw):
if len(self) == 0:
return tuple()
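These `_fetch_infos` hooks back the public `fetch_infos` method added to `mars/dataframe/core.py` at the end of this diff; `to_ray_mldataset` relies on it to locate each chunk. A minimal usage sketch (assuming a Ray-backed Mars session with an executed DataFrame):

import numpy as np
import mars.dataframe as md

df = md.DataFrame(np.random.rand(10, 10), chunk_size=5)
df.execute()
# a dict mapping each requested field to one value per chunk, e.g.
# infos['band'][i] is the band holding chunk i, infos['object_id'][i] its ref
infos = df.fetch_infos(fields=['band', 'object_id'])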
15 changes: 15 additions & 0 deletions mars/dataframe/contrib/raydataset/__init__.py
@@ -0,0 +1,15 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .mldataset import to_ray_mldataset, ChunkRefBatch
142 changes: 142 additions & 0 deletions mars/dataframe/contrib/raydataset/mldataset.py
@@ -0,0 +1,142 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import numpy as np
import pandas as pd

from ....utils import ceildiv, lazy_import

ray = lazy_import('ray')
parallel_it = lazy_import('ray.util.iter')
ml_dataset = lazy_import('ray.util.data')


class ChunkRefBatch:
def __init__(self,
shard_id: int,
obj_refs: List['ray.ObjectRef']):
"""Iterable batch holding a list of ray.ObjectRefs.

Args:
shard_id (int): id of the shard
obj_refs (List[ray.ObjectRef]): list of ray.ObjectRefs
"""
self._shard_id = shard_id
self._obj_refs = obj_refs

@property
def shard_id(self) -> int:
return self._shard_id

def __iter__(self) -> Iterable[pd.DataFrame]:
"""Returns the item_generator required from ParallelIteratorWorker."""
for obj_ref in self._obj_refs:
yield ray.get(obj_ref)


def _group_chunk_refs(chunk_addr_refs: List[Tuple[Tuple, 'ray.ObjectRef']],
num_shards: int):
"""Group fetched ray.ObjectRefs into a dict for later use.

Args:
chunk_addr_refs (List[Tuple[Tuple, ray.ObjectRef]]): a list of tuples of
band & ray.ObjectRef of each chunk.
num_shards (int): the number of shards that will be created for the MLDataset.

Returns:
Dict[str, List[ray.ObjectRef]]: a dict mapping each group key to the
ray.ObjectRefs that will form one ChunkRefBatch.
"""
group_to_obj_refs = defaultdict(list)
if not num_shards:
for addr, obj_ref in chunk_addr_refs:
group_to_obj_refs[addr].append(obj_ref)
else:
splits = np.array_split([ref for _, ref in chunk_addr_refs],
num_shards)
for idx, split in enumerate(splits):
group_to_obj_refs['group-' + str(idx)] = list(split)
return group_to_obj_refs
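# Worked sketch of the grouping above, with hypothetical bands b1/b2 and
# object refs r1..r3 arriving as [(b1, r1), (b2, r2), (b1, r3)]:
#   num_shards=None -> grouped by band: {b1: [r1, r3], b2: [r2]}
#   num_shards=2    -> bands ignored, refs split evenly in order:
#                      {'group-0': [r1, r2], 'group-1': [r3]}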


def _create_ml_dataset(name: str,
group_to_obj_refs: Dict[str, List['ray.ObjectRef']]):
record_batches = []
for rank, obj_refs in enumerate(group_to_obj_refs.values()):
record_batches.append(ChunkRefBatch(shard_id=rank,
obj_refs=obj_refs))
# actors reserve no CPU (num_cpus=0): they only relay already-stored object refs
worker_cls = ray.remote(num_cpus=0)(parallel_it.ParallelIteratorWorker)
actors = [worker_cls.remote(g, False) for g in record_batches]
it = parallel_it.from_actors(actors, name)
ds = ml_dataset.from_parallel_iter(
it, need_convert=False, batch_size=0, repeated=False)
return ds


def _rechunk_if_needed(df, num_shards: int = None):
chunk_size = df.extra_params.raw_chunk_size or max(df.shape)
num_rows = df.shape[0]
num_columns = df.shape[1]
# if chunk_size is unset, it falls back to max(df.shape), so num_chunks_in_row = 1
# if chunk_size is larger than num_rows, num_chunks_in_row is also 1
# otherwise, num_chunks_in_row = ceildiv(num_rows, chunk_size)
num_chunks_in_row = ceildiv(num_rows, chunk_size)
naive_num_partitions = ceildiv(num_rows, num_columns)

need_re_execute = False
# ensure each part holds all columns
if chunk_size < num_columns:
df = df.rebalance(axis=1, num_partitions=1)
need_re_execute = True
if num_shards and num_chunks_in_row < num_shards:
df = df.rebalance(axis=0, num_partitions=num_shards)
need_re_execute = True
if not num_shards and num_chunks_in_row == 1:
df = df.rebalance(axis=0, num_partitions=naive_num_partitions)
need_re_execute = True
if need_re_execute:
df.execute()
return df
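# Worked example of the sizing logic above for a hypothetical 100x10 DataFrame:
#   chunk_size=30 -> num_chunks_in_row = ceildiv(100, 30) = 4; with
#                    num_shards=5 the frame is rebalanced on axis 0 into 5 parts
#   chunk_size=5  -> smaller than the 10 columns, so axis 1 is rebalanced
#                    into a single partition holding all columns
#   chunk_size unset -> falls back to max(df.shape) = 100, so num_chunks_in_row
#                    is 1; without num_shards, axis 0 is rebalanced into
#                    naive_num_partitions = ceildiv(100, 10) = 10 parts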


def to_ray_mldataset(df,
num_shards: int = None):
"""Create a MLDataset from Mars DataFrame

Args:
df (mars.dataframe.Dataframe): the Mars DataFrame
num_shards (int, optional): the number of shards that will be created
for the MLDataset. Defaults to None.
If num_shards is None, chunks will be grouped by nodes where they lie.
Otherwise, chunks will be grouped by their order in DataFrame.

Returns:
a MLDataset
"""
df = _rechunk_if_needed(df, num_shards)
# chunk infos are fetched for all chunks at once rather than in batches;
# `fetch_infos` verifies that df has already been executed.
# items in chunk_addr_refs are ordered by chunk position in df,
# but adjacent chunks may live on different bands, e.g.
# chunk1 on addr1,
# chunk2 & chunk3 on addr2,
# chunk4 on addr1
fetched_infos: Dict[str, List] = df.fetch_infos(fields=['band', 'object_id'])
chunk_addr_refs: List[Tuple[Tuple, 'ray.ObjectRef']] = list(zip(fetched_infos['band'],
fetched_infos['object_id']))
group_to_obj_refs: Dict[str, List[ray.ObjectRef]] = _group_chunk_refs(chunk_addr_refs, num_shards)
return _create_ml_dataset("from_mars", group_to_obj_refs)
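End to end, the new API can be exercised as in this sketch (mirroring the tests below; the `chunk_size` and `num_shards` values are illustrative):

import numpy as np
import mars.dataframe as md
import mars.dataframe.contrib.raydataset as mdd

df = md.DataFrame(np.random.rand(100, 4), chunk_size=25)
df.execute()  # chunks must be materialized before their infos can be fetched
ds = mdd.to_ray_mldataset(df, num_shards=4)  # a ray.util.data.MLDataset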
131 changes: 131 additions & 0 deletions mars/dataframe/contrib/raydataset/tests/test_mldataset.py
@@ -0,0 +1,131 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import pandas as pd
import pytest

import mars.dataframe as md
import mars.dataframe.contrib.raydataset as mdd
from mars.deploy.oscar.ray import new_cluster, _load_config
from mars.deploy.oscar.session import new_session
from mars.tests.core import require_ray
from mars.utils import lazy_import


ray = lazy_import('ray')
ml_dataset = lazy_import('ray.util.data')
try:
import xgboost_ray
except ImportError: # pragma: no cover
xgboost_ray = None
try:
import sklearn
except ImportError: # pragma: no cover
sklearn = None


@pytest.fixture
async def create_cluster(request):
param = getattr(request, "param", {})
ray_config = _load_config()
ray_config.update(param.get('config', {}))
client = await new_cluster('test_cluster',
worker_num=4,
worker_cpu=2,
worker_mem=1 * 1024 ** 3,
config=ray_config)
async with client:
yield client


@require_ray
@pytest.mark.asyncio
async def test_dataset_related_classes(ray_large_cluster):
from mars.dataframe.contrib.raydataset import ChunkRefBatch
# prepare two dataframes to wrap as ray object refs
value1 = np.random.rand(10, 10)
value2 = np.random.rand(10, 10)
df1 = pd.DataFrame(value1)
df2 = pd.DataFrame(value2)
if ray:
obj_ref1, obj_ref2 = ray.put(df1), ray.put(df2)
batch = ChunkRefBatch(shard_id=0,
obj_refs=[obj_ref1, obj_ref2])
assert batch.shard_id == 0
# the batch yields its dataframes in insertion order
batch = iter(batch)
pd.testing.assert_frame_equal(next(batch), df1)
pd.testing.assert_frame_equal(next(batch), df2)


@require_ray
@pytest.mark.asyncio
@pytest.mark.parametrize('test_option', [[5, 5], [5, 4],
[None, None]])
async def test_convert_to_ray_mldataset(ray_large_cluster, create_cluster, test_option):
assert create_cluster.session
session = new_session(address=create_cluster.address, backend='oscar', default=True)
with session:
value = np.random.rand(10, 10)
chunk_size, num_shards = test_option
df: md.DataFrame = md.DataFrame(value, chunk_size=chunk_size)
df.execute()

ds = mdd.to_ray_mldataset(df, num_shards=num_shards)
assert isinstance(ds, ml_dataset.MLDataset)
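# A sketch of consuming the dataset from here (assuming MLDataset inherits
# gather_sync() from ray.util.iter.ParallelIterator):
#     for batch in ds.gather_sync():
#         print(batch.shape)  # each batch is a pandas DataFrame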


@require_ray
@pytest.mark.asyncio
@pytest.mark.skipif(sklearn is None, reason='sklearn not installed')
@pytest.mark.skipif(xgboost_ray is None, reason='xgboost_ray not installed')
async def test_mars_with_xgboost(ray_large_cluster, create_cluster):
from xgboost_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer

assert create_cluster.session
session = new_session(address=create_cluster.address, backend='oscar', default=True)
with session:
train_x, train_y = load_breast_cancer(return_X_y=True, as_frame=True)
pd_df = pd.concat([train_x, train_y], axis=1)
df: md.DataFrame = md.DataFrame(pd_df)
df.execute()

num_shards = 4
ds = mdd.to_ray_mldataset(df)
assert isinstance(ds, ml_dataset.MLDataset)

# train
train_set = RayDMatrix(ds, "target")
evals_result = {}
bst = train(
{
"objective": "binary:logistic",
"eval_metric": ["logloss", "error"],
},
train_set,
evals_result=evals_result,
evals=[(train_set, "train")],
verbose_eval=False,
ray_params=RayParams(
num_actors=num_shards, # Number of remote actors
cpus_per_actor=1)
)
bst.save_model("model.xgb")
assert os.path.exists("model.xgb")
os.remove("model.xgb")
print("Final training error: {:.4f}".format(
evals_result["train"]["error"][-1]))
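The test stops after training; a follow-up prediction step would look roughly like the sketch below (assuming `xgboost_ray.predict`, the counterpart of `train`, accepts the same matrix and `RayParams`):

from xgboost_ray import predict

pred = predict(bst, train_set, ray_params=RayParams(num_actors=num_shards))
print(pred[:5])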
3 changes: 3 additions & 0 deletions mars/dataframe/core.py
@@ -528,6 +528,9 @@ def fetch(self, session=None, **kw):
session=session, **kw))
return pd.concat(batches) if len(batches) > 1 else batches[0]

def fetch_infos(self, fields=None, session=None, **kw):
return self._fetch_infos(fields=fields, session=session, **kw)


class IndexData(HasShapeTileableData, _ToPandasMixin):
__slots__ = ()