Integrate Mars DataFrame with Ray MLDataset #2294

Merged
merged 20 commits on Aug 26, 2021
Changes from 13 commits
3 changes: 2 additions & 1 deletion mars/dataframe/core.py
@@ -519,7 +519,8 @@ def fetch(self, session=None, **kw):
from .indexing.iloc import DataFrameIlocGetItem, SeriesIlocGetItem

batch_size = kw.pop('batch_size', 1000)
if isinstance(self.op, (DataFrameIlocGetItem, SeriesIlocGetItem)):
only_refs = kw.get('only_refs', False)
if isinstance(self.op, (DataFrameIlocGetItem, SeriesIlocGetItem)) or only_refs:
# see GH#1871
# already iloc, do not trigger batch fetch
return self._fetch(session=session, **kw)
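For context, a minimal usage sketch of the new keyword (not part of the diff; `only_refs` is the kwarg this PR adds to `fetch`):

import mars.dataframe as md

df = md.DataFrame([[1, 2], [3, 4]])
df.execute()
data = df.fetch()                # materialized pandas DataFrame, batched via iloc
refs = df.fetch(only_refs=True)  # bypasses batching; returns (band, object_id)
                                 # pairs, one per chunk, as used by to_ray_mldataset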
141 changes: 141 additions & 0 deletions mars/dataframe/dataset.py
@@ -0,0 +1,141 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
from ..utils import ceildiv, lazy_import
from collections import defaultdict
from typing import Iterable

ray = lazy_import('ray')
parallel_it = lazy_import('ray.util.iter')
ml_dataset = lazy_import('ray.util.data')


class ObjRefBatch:
def __init__(self,
shard_id,
obj_refs):
"""Iterable batch holding a list of ray.ObjectRefs.

Args:
shard_id (int): id of the shard
prefix (str): prefix name of the batch
obj_refs (List[ray.ObjectRefs]): list of ray.ObjectRefs
"""
self._shard_id = shard_id
self._obj_refs = obj_refs

@property
def shard_id(self) -> int:
return self._shard_id

def __iter__(self) -> Iterable[pd.DataFrame]:
"""Returns the item_generator required from ParallelIteratorWorker."""
for obj_ref in self._obj_refs:
yield ray.get(obj_ref)


def _group_obj_refs(chunk_addr_refs,
num_shards):
"""Group fetched ray.ObjectRefs into a dict for later use.

Args:
chunk_addr_refs (List[Tuple[str, ray.ObjectRef]]): a list of (addr, ray.ObjectRef)
pairs, one for each chunk.
num_shards (int): the number of shards that will be created for the MLDataset.

Returns:
Dict[str, List[ray.ObjectRef]]: a dict that defines which group of ray.ObjectRefs will
be in an ObjRefBatch.
"""
group_to_obj_refs = defaultdict(list)
if not num_shards:
for addr, obj_ref in chunk_addr_refs:
group_to_obj_refs[addr].append(obj_ref)
else:
splits = np.array_split([ref for _, ref in chunk_addr_refs],
num_shards)
for idx, split in enumerate(splits):
group_to_obj_refs['group-' + str(idx)] = list(split)
return group_to_obj_refs
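# A hypothetical illustration of the two grouping modes above:
#   chunk_addr_refs = [('addr1', r1), ('addr2', r2), ('addr2', r3)]
#   num_shards=None -> {'addr1': [r1], 'addr2': [r2, r3]}     (grouped by address)
#   num_shards=2    -> {'group-0': [r1, r2], 'group-1': [r3]} (split by position)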


def _create_ml_dataset(name,
group_to_obj_refs):
record_batches = []
for rank, obj_refs in enumerate(group_to_obj_refs.values()):
record_batches.append(ObjRefBatch(shard_id=rank,
obj_refs=obj_refs))
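# note: the workers below only iterate over already-stored object refs,
# so num_cpus=0 lets them run without reserving CPU resources on the cluster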
worker_cls = ray.remote(num_cpus=0)(parallel_it.ParallelIteratorWorker)
actors = [worker_cls.remote(g, False) for g in record_batches]
it = parallel_it.from_actors(actors, name)
ds = ml_dataset.from_parallel_iter(
it, need_convert=False, batch_size=0, repeated=False)
return ds


def _rechunk_if_needed(df, num_shards: int = None):
chunk_size = df.extra_params.raw_chunk_size or max(df.shape)
num_rows = df.shape[0]
num_columns = df.shape[1]
# if chunk size is not set, num_chunks_in_row = 1
# if chunk size is set to more than max(df.shape), num_chunks_in_row = 1
# otherwise, num_chunks_in_row = ceildiv(num_rows, chunk_size)
num_chunks_in_row = ceildiv(num_rows, chunk_size)
naive_num_partitions = ceildiv(num_rows, num_columns)

need_re_execute = False
# ensure each part holds all columns
if chunk_size < num_columns:
df = df.rebalance(axis=1, num_partitions=1)
need_re_execute = True
if num_shards and num_chunks_in_row < num_shards:
df = df.rebalance(axis=0, num_partitions=num_shards)
need_re_execute = True
if not num_shards and num_chunks_in_row == 1:
df = df.rebalance(axis=0, num_partitions=naive_num_partitions)
need_re_execute = True
if need_re_execute:
df.execute()
return df
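# A hypothetical walk-through of the rules above for a 20 x 10 df:
#   chunk_size=5: chunk_size < num_columns (10), so columns are first merged
#     into a single partition; num_chunks_in_row = ceildiv(20, 5) = 4, so
#     num_shards=4 needs no row rebalance
#   chunk_size unset: chunk_size = max(20, 10) = 20 and num_chunks_in_row = 1,
#     so with num_shards=None rows are rebalanced into ceildiv(20, 10) = 2 parts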


def to_ray_mldataset(df,
num_shards: int = None):
"""Create a MLDataset from Mars DataFrame

Args:
df (mars.dataframe.DataFrame): the Mars DataFrame
num_shards (int, optional): the number of shards that will be created
for the MLDataset. Defaults to None.
If num_shards is None, chunks will be grouped by the nodes where they lie.
If num_shards equals the number of nodes, chunks will also be grouped by
the nodes where they lie.
Otherwise, chunks will be grouped by their order in the DataFrame.

Returns:
an MLDataset
"""
df = _rechunk_if_needed(df, num_shards)
# chunk_addr_refs is fetched directly rather than in batches;
# during the `fetch` procedure it is checked that df has been executed.
# items in chunk_addr_refs are ordered by their positions in df,
# while adjacent chunks may belong to different addrs, e.g.
# chunk1 on addr1,
# chunk2 & chunk3 on addr2,
# chunk4 on addr1
chunk_addr_refs = df.fetch(only_refs=True)
group_to_obj_refs = _group_obj_refs(chunk_addr_refs, num_shards)
return _create_ml_dataset("from_mars", group_to_obj_refs)
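For reference, a minimal end-to-end sketch of the API this file adds (assuming a Ray cluster with a Mars session already running on it):

import numpy as np
import mars.dataframe as md
from mars.dataframe.dataset import to_ray_mldataset

df = md.DataFrame(np.random.rand(1000, 4), chunk_size=250)
df.execute()
ds = to_ray_mldataset(df, num_shards=4)  # a ray.util.data.MLDataset with 4 shards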
133 changes: 133 additions & 0 deletions mars/dataframe/tests/test_mldataset.py
@@ -0,0 +1,133 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import pandas as pd
import pytest

import mars.dataframe as md
import mars.dataframe.dataset as mds
from mars.deploy.oscar.ray import new_cluster, _load_config
from mars.deploy.oscar.session import new_session
from mars.tests.core import require_ray
from mars.utils import lazy_import


ray = lazy_import('ray')
ml_dataset = lazy_import('ray.util.data')
try:
import xgboost_ray
except ImportError: # pragma: no cover
xgboost_ray = None
try:
import sklearn
except ImportError: # pragma: no cover
sklearn = None


@pytest.fixture
async def create_cluster(request):
param = getattr(request, "param", {})
ray_config = _load_config()
ray_config.update(param.get('config', {}))
client = await new_cluster('test_cluster',
worker_num=4,
worker_cpu=2,
worker_mem=1 * 1024 ** 3,
config=ray_config)
async with client:
yield client


@require_ray
@pytest.mark.asyncio
async def test_dataset_related_classes(ray_large_cluster):
from mars.dataframe.dataset import ObjRefBatch
# in order to pass checks
value1 = np.random.rand(10, 10)
value2 = np.random.rand(10, 10)
df1 = pd.DataFrame(value1)
df2 = pd.DataFrame(value2)
if ray:
obj_ref1, obj_ref2 = ray.put(df1), ray.put(df2)
batch = ObjRefBatch(shard_id=0,
obj_refs=[obj_ref1, obj_ref2])
assert batch.shard_id == 0
# unpack both DataFrames from the batch
data1, data2 = batch.__iter__()
pd.testing.assert_frame_equal(data1, df1)
pd.testing.assert_frame_equal(data2, df2)


@require_ray
@pytest.mark.asyncio
@pytest.mark.parametrize('test_option', [[5, 5], [5, 4],
[None, None]])
async def test_convert_to_ray_mldataset(ray_large_cluster, create_cluster, test_option):
assert create_cluster.session
session = new_session(address=create_cluster.address, backend='oscar', default=True)
with session:
value = np.random.rand(20, 10)
chunk_size, num_shards = test_option
df: md.DataFrame = md.DataFrame(value, chunk_size=chunk_size)
df.execute()

ds = mds.to_ray_mldataset(df, num_shards=num_shards)
if ml_dataset:
assert isinstance(ds, ml_dataset.MLDataset)


@require_ray
@pytest.mark.asyncio
@pytest.mark.skipif(sklearn is None, reason='sklearn not installed')
@pytest.mark.skipif(xgboost_ray is None, reason='xgboost_ray not installed')
async def test_mars_with_xgboost(ray_large_cluster, create_cluster):
from xgboost_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer

assert create_cluster.session
session = new_session(address=create_cluster.address, backend='oscar', default=True)
with session:
train_x, train_y = load_breast_cancer(return_X_y=True, as_frame=True)
pd_df = pd.concat([train_x, train_y], axis=1)
df: md.DataFrame = md.DataFrame(pd_df)
df.execute()

num_shards = 4
ds = mds.to_ray_mldataset(df)
if ml_dataset:
assert isinstance(ds, ml_dataset.MLDataset)

# train
train_set = RayDMatrix(ds, "target")
evals_result = {}
bst = train(
{
"objective": "binary:logistic",
"eval_metric": ["logloss", "error"],
},
train_set,
evals_result=evals_result,
evals=[(train_set, "train")],
verbose_eval=False,
ray_params=RayParams(
num_actors=num_shards, # Number of remote actors
cpus_per_actor=1)
)
bst.save_model("model.xgb")
assert os.path.exists("model.xgb")
os.remove("model.xgb")
print("Final training error: {:.4f}".format(
evals_result["train"]["error"][-1]))
48 changes: 28 additions & 20 deletions mars/deploy/oscar/session.py
@@ -796,7 +796,8 @@ async def fetch(self, *tileables, **kwargs) -> list:
from ...tensor.core import TensorOrder
from ...tensor.array_utils import get_array_module

if kwargs: # pragma: no cover
only_refs = kwargs.get('only_refs', False)
if kwargs and 'only_refs' not in kwargs: # pragma: no cover
unexpected_keys = ', '.join(list(kwargs.keys()))
raise TypeError(f'`fetch` got unexpected '
f'arguments: {unexpected_keys}')
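# note: `only_refs` is the only keyword accepted here, hence its explicit
# exclusion from the unexpected-kwargs check above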
@@ -837,33 +838,40 @@ async def fetch(self, *tileables, **kwargs) -> list:
chunk = fetch_info.chunk
band = chunk_to_band[chunk]
storage_api = await self._get_storage_api(band)
storage_api_to_gets[storage_api].append(
storage_api.get.delay(chunk.key, conditions=conditions))
get = storage_api.get.delay(chunk.key, conditions=conditions) \
if not only_refs else storage_api.get_infos.delay(chunk.key)
storage_api_to_gets[storage_api].append(get)
storage_api_to_fetch_infos[storage_api].append(fetch_info)
for storage_api in storage_api_to_gets:
fetched_data = await storage_api.get.batch(
*storage_api_to_gets[storage_api])
*storage_api_to_gets[storage_api]) if not only_refs else \
await storage_api.get_infos.batch(
*storage_api_to_gets[storage_api])
infos = storage_api_to_fetch_infos[storage_api]
for info, data in zip(infos, fetched_data):
info.data = data

result = []
for tileable, fetch_infos in zip(tileables, fetch_infos_list):
index_to_data = [(fetch_info.chunk.index, fetch_info.data)
for fetch_info in fetch_infos]
merged = merge_chunks(index_to_data)
if hasattr(tileable, 'order') and tileable.ndim > 0:
module = get_array_module(merged)
if tileable.order == TensorOrder.F_ORDER and \
hasattr(module, 'asfortranarray'):
merged = module.asfortranarray(merged)
elif tileable.order == TensorOrder.C_ORDER and \
hasattr(module, 'ascontiguousarray'):
merged = module.ascontiguousarray(merged)
if hasattr(tileable, 'isscalar') and tileable.isscalar() and \
getattr(merged, 'size', None) == 1:
merged = merged.item()
result.append(self._process_result(tileable, merged))
if only_refs:
result += [(chunk_to_band[fetch_info.chunk], fetch_info.data[0].object_id)
for fetch_info in fetch_infos]
else:
index_to_data = [(fetch_info.chunk.index, fetch_info.data)
for fetch_info in fetch_infos]
merged = merge_chunks(index_to_data)
if hasattr(tileable, 'order') and tileable.ndim > 0:
module = get_array_module(merged)
if tileable.order == TensorOrder.F_ORDER and \
hasattr(module, 'asfortranarray'):
merged = module.asfortranarray(merged)
elif tileable.order == TensorOrder.C_ORDER and \
hasattr(module, 'ascontiguousarray'):
merged = module.ascontiguousarray(merged)
if hasattr(tileable, 'isscalar') and tileable.isscalar() and \
getattr(merged, 'size', None) == 1:
merged = merged.item()
result.append(self._process_result(tileable, merged))
return result

async def decref(self, *tileable_keys):
@@ -1378,7 +1386,7 @@ async def _fetch(tileable: TileableType,
tileable, tileables = tileable[0], tileable[1:]
session = _get_isolated_session(session)
data = await session.fetch(tileable, *tileables, **kwargs)
return data[0] if len(tileables) == 0 else data
return data[0] if len(tileables) == 0 and len(data) == 1 else data
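# note on the added len(data) == 1 guard: with only_refs=True a single
# tileable fans out into one (band, object_id) entry per chunk, e.g.
# (hypothetical) df.fetch(only_refs=True) -> [(band1, ref1), (band2, ref2)];
# unconditionally returning data[0] would drop every reference but the first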


def fetch(tileable: TileableType,