Skip to content

Commit

Permalink
[Ray] Enable CI of mars/learn for Ray DAG (#3261)
Browse files Browse the repository at this point in the history
* Fix Ray DAG learn shuffle

* Fix special cases that some meta has been updated in previous stages

* Enable tests and remove init metrics

* Fix Ray executor can't run learn LabelEncoder

* Fix TensorPower on sparse matrix raises WRITEBACKIFCOPY base is read-only

* Fix check_binarized_results raises WRITEBACKIFCOPY base is read-only

* Fix shuffle colouring bug

* RemoteFunction support resolve_tileable_input

* Enable CI of mars/learn

* Remove @pytest.mark.ray_dag

* Pin pandas<1.5.0

* Fix

* Fix bincount shuffle

* Log not init error once

* Fix Ray executor result tileable ref

* Remove ensure_coverage() in ray executor

* Fix

* Fix ray.init twice?

* Try to fix Ray DAG coverage

* Improve coverage

* Coverage

* Fix

* Revert "Coverage"

This reverts commit 4174168.

* Revert "Try to fix Ray DAG coverage"

This reverts commit f7e8276.

* Coverage

Co-authored-by: 刘宝 <po.lb@antgroup.com>
  • Loading branch information
fyrestone and 刘宝 committed Oct 12, 2022
1 parent 6e2f7c9 commit 3a0a99c
Show file tree
Hide file tree
Showing 20 changed files with 225 additions and 128 deletions.
1 change: 1 addition & 0 deletions .github/workflows/platform-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ jobs:
export MARS_CI_BACKEND=ray
export RAY_idle_worker_killing_time_threshold_ms=60000
pytest $PYTEST_CONFIG --durations=0 --timeout=500 mars/dataframe -v -s -m "not skip_ray_dag"
pytest $PYTEST_CONFIG --durations=0 --timeout=500 mars/learn --ignore mars/learn/contrib --ignore mars/learn/utils/tests/test_collect_ports.py -v -s -m "not skip_ray_dag"
pytest $PYTEST_CONFIG --durations=0 --timeout=200 -v -s -m ray_dag
mv .coverage build/.coverage.ray_dag.file
pytest $PYTEST_CONFIG --durations=0 --timeout=200 -v -s mars/deploy/oscar/tests/test_ray_dag.py
Expand Down
4 changes: 2 additions & 2 deletions mars/core/operand/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_dependent_data_keys(self):
return deps
return super().get_dependent_data_keys()

def _iter_mapper_keys(self, input_id=0):
def iter_mapper_keys(self, input_id=0):
# key is mapper chunk key, index is mapper chunk index.
input_chunk = self.inputs[input_id]
if isinstance(input_chunk.op, ShuffleProxy):
Expand All @@ -103,7 +103,7 @@ def _iter_mapper_keys(self, input_id=0):
return keys

def iter_mapper_data(self, ctx, input_id=0, pop=False, skip_none=False):
for key in self._iter_mapper_keys(input_id):
for key in self.iter_mapper_keys(input_id):
try:
if pop:
yield ctx.pop((key, self.reducer_index))
Expand Down
21 changes: 2 additions & 19 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from ...utils import (
implements,
merge_chunks,
merged_chunk_as_tileable_type,
register_asyncio_task_timeout_detector,
classproperty,
copy_tileables,
Expand Down Expand Up @@ -1084,9 +1085,6 @@ async def _get_storage_api(self, band: BandType):
return storage_api

async def fetch(self, *tileables, **kwargs) -> list:
from ...tensor.core import TensorOrder
from ...tensor.array_utils import get_array_module

if kwargs: # pragma: no cover
unexpected_keys = ", ".join(list(kwargs.keys()))
raise TypeError(f"`fetch` got unexpected arguments: {unexpected_keys}")
Expand Down Expand Up @@ -1139,22 +1137,7 @@ async def fetch(self, *tileables, **kwargs) -> list:
for fetch_info in fetch_infos
]
merged = merge_chunks(index_to_data)
if hasattr(tileable, "order") and tileable.ndim > 0:
module = get_array_module(merged)
if tileable.order == TensorOrder.F_ORDER and hasattr(
module, "asfortranarray"
):
merged = module.asfortranarray(merged)
elif tileable.order == TensorOrder.C_ORDER and hasattr(
module, "ascontiguousarray"
):
merged = module.ascontiguousarray(merged)
if (
hasattr(tileable, "isscalar")
and tileable.isscalar()
and getattr(merged, "size", None) == 1
):
merged = merged.item()
merged = merged_chunk_as_tileable_type(merged, tileable)
result.append(self._process_result(tileable, merged))
return result

Expand Down
7 changes: 3 additions & 4 deletions mars/learn/decomposition/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from ... import tensor as mt
from ... import remote as mr
from ...core import ENTITY_TYPE
from ...tensor.array_utils import get_array_module
from ...tensor.core import TENSOR_TYPE
from ...tensor.utils import check_random_state
Expand Down Expand Up @@ -114,8 +113,6 @@ def _infer_dimension(spectrum, n_samples):
The returned value will be in [1, n_features - 1].
"""
if isinstance(spectrum, ENTITY_TYPE):
spectrum = spectrum.fetch()
xp = get_array_module(spectrum, nosparse=True)

ll = xp.empty_like(spectrum)
Expand Down Expand Up @@ -481,7 +478,9 @@ def _fit_full(self, X, n_components, session=None, run_kwargs=None):
# Postprocess the number of components required
if n_components == "mle":
n_components = mr.spawn(
_infer_dimension, args=(explained_variance_, n_samples)
_infer_dimension,
args=(explained_variance_, n_samples),
resolve_tileable_input=True,
)
ExecutableTuple([n_components, U, V]).execute(
session=session, **(run_kwargs or dict())
Expand Down
7 changes: 2 additions & 5 deletions mars/learn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ def _execute_map(cls, ctx, op: "BaggingSample"):
),
) in result_store.items():
ctx[out_samples.key, (reducer_id, 0)] = (
ctx.get_current_chunk().key,
ctx.get_current_chunk().index,
tuple(samples + labels + weights + feature_idx_array),
)
Expand Down Expand Up @@ -569,10 +568,8 @@ def _execute_reduce(cls, ctx, op: "BaggingSample"):
else None
)

input_indexes = [
(source_key, idx) for source_key, idx, _ in op.iter_mapper_data(ctx)
]
for input_key, input_idx in input_indexes:
input_indexes = [idx for idx, _ in op.iter_mapper_data(ctx)]
for input_key, input_idx in zip(op.iter_mapper_keys(), input_indexes):
add_feature_index = input_idx[0] == 0
add_label_weight = input_idx[1] == op.chunk_shape[1] - 1
chunk_data = ctx[input_key, out_data.index][-1]
Expand Down
28 changes: 19 additions & 9 deletions mars/learn/ensemble/tests/test_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from sklearn.svm import SVC

from .... import tensor as mt, dataframe as md, execute
from ....conftest import MARS_CI_BACKEND
from ....core import enter_mode
from ....services.task.execution.api import Fetcher
from .._bagging import (
_extract_bagging_io,
BaggingSample,
Expand All @@ -41,16 +43,24 @@ async def _async_fetch():
meta_api = async_session._meta_api

t, indexes = async_session._get_to_fetch_tileable(tileable)
fetcher = Fetcher.create(
MARS_CI_BACKEND, get_storage_api=async_session._get_storage_api
)

get_metas = []
for chunk in t.chunks:
get_metas.append(
meta_api.get_chunk_meta.delay(
chunk.key, fields=fetcher.required_meta_keys
)
)
metas = await meta_api.get_chunk_meta.batch(*get_metas)

for chunk, meta in zip(t.chunks, metas):
await fetcher.append(chunk.key, meta)
all_data = await fetcher.get()

delays = [
meta_api.get_chunk_meta.delay(chunk.key, fields=["bands"])
for chunk in t.chunks
]
band_infos = await meta_api.get_chunk_meta.batch(*delays)
for chunk, band_info in zip(t.chunks, band_infos):
band = band_info["bands"][0]
storage_api = await async_session._get_storage_api(band)
data = await storage_api.get(chunk.key)
for chunk, data in zip(t.chunks, all_data):
tuples.append((t, chunk, data))
return tuples

Expand Down
9 changes: 8 additions & 1 deletion mars/learn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def check_binarized_results(y, classes, pos_label, neg_label, expected):

else:
inversed = _inverse_binarize_thresholding(
binarized,
binarized.copy(), # https://github.com/mars-project/mars/issues/3268
output_type=y_type,
classes=classes,
threshold=((neg_label + pos_label) / 2.0),
Expand Down Expand Up @@ -304,6 +304,13 @@ def test_label_encoder(setup, values, classes, unknown):
le.transform(unknown)


def test_label_encoder_missing_values_numeric(setup):
values = np.array([3, 1, np.nan, 5, 3, np.nan], dtype=float)
values_t = mt.tensor(values)
le = LabelEncoder()
assert_array_equal(le.fit_transform(values_t).fetch(), [1, 0, 3, 2, 1, 3])


def test_label_encoder_negative_ints(setup):
le = LabelEncoder()
le.fit(mt.tensor([1, 1, 4, 5, -1, 0]))
Expand Down
7 changes: 5 additions & 2 deletions mars/learn/utils/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def _unique(values, *, return_inverse=False):
if return_inverse:

def inv_mapper(c, idx):
c[c > idx] = idx
if c.flags.writeable:
c[c > idx] = idx
else: # pragma: no cover
# If c is got from the shared memory, it is immutable.
c = np.select([c <= idx], [c], idx)
return c

inverse = inverse.map_chunk(
Expand All @@ -79,7 +83,6 @@ def inv_mapper(c, idx):
shape=((np.nan,),) * inverse.ndim,
)

if return_inverse:
return uniques, inverse
return uniques

Expand Down
6 changes: 5 additions & 1 deletion mars/learn/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def concat_chunks(chunks):


def copy_learned_attributes(from_estimator: BaseEstimator, to_estimator: BaseEstimator):
attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith("_")}
attrs = {
k: v
for k, v in vars(from_estimator).items()
if k.endswith("_") or k.startswith("_")
}
for k, v in attrs.items():
setattr(to_estimator, k, v)

Expand Down
10 changes: 6 additions & 4 deletions mars/learn/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,12 @@

from .. import remote as mr
from .. import tensor as mt
from ..core import ENTITY_TYPE
from .base import BaseEstimator, RegressorMixin, ClassifierMixin
from .metrics import get_scorer
from .utils import copy_learned_attributes, check_array


def _wrap(estimator: SklearnBaseEstimator, method, X, y, **kwargs):
X = X.fetch() if isinstance(X, ENTITY_TYPE) else X
y = y.fetch() if isinstance(y, ENTITY_TYPE) else y
return getattr(estimator, method)(X, y, **kwargs)


Expand Down Expand Up @@ -145,7 +142,12 @@ def __init__(
def _make_fit(self, method):
def _fit(X, y=None, **kwargs):
result = (
mr.spawn(_wrap, args=(self.estimator, method, X, y), kwargs=kwargs)
mr.spawn(
_wrap,
args=(self.estimator, method, X, y),
kwargs=kwargs,
resolve_tileable_input=True,
)
.execute()
.fetch()
)
Expand Down
9 changes: 8 additions & 1 deletion mars/lib/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,14 @@ def __pow__(self, other, modulo=None):
except TypeError:
return NotImplemented
if get_array_module(naked_other).isscalar(naked_other):
x = self.spmatrix.power(naked_other)
try:
x = self.spmatrix.power(naked_other)
except ValueError as e: # pragma: no cover
# https://github.com/mars-project/mars/issues/3268
# https://github.com/scipy/scipy/issues/8678
assert "WRITEBACKIFCOPY" in e.args[0]
self.spmatrix = self.spmatrix.copy()
x = self.spmatrix.power(naked_other)
else:
if issparse(naked_other):
naked_other = other.toarray()
Expand Down
5 changes: 4 additions & 1 deletion mars/metrics/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def shutdown_metrics():

class _MetricWrapper(AbstractMetric):
_metric: AbstractMetric
_log_not_init_error: bool

def __init__(
self,
Expand All @@ -96,6 +97,7 @@ def __init__(
self._tag_keys = tag_keys or tuple()
self._type = metric_type
self._metric = None
self._log_not_init_error = False

@property
def type(self):
Expand All @@ -115,7 +117,8 @@ def set_metric(self, metric):
def record(self, value=1, tags: Optional[Dict[str, str]] = None):
if self._metric is not None:
self._metric.record(value, tags)
else:
elif not self._log_not_init_error:
self._log_not_init_error = True
logger.warning(
"Metric is not initialized, please call `init_metrics()` before using metrics."
)
Expand Down
42 changes: 33 additions & 9 deletions mars/remote/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import UserDict
from collections.abc import Iterable
from functools import partial

from .. import opcodes
from ..core import ENTITY_TYPE, ChunkData
from ..core import ENTITY_TYPE, ChunkData, Tileable
from ..core.custom_log import redirect_custom_log
from ..core.operand import ObjectOperand
from ..dataframe.core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE
Expand All @@ -33,6 +33,8 @@
enter_current_session,
find_objects,
replace_objects,
merge_chunks,
merged_chunk_as_tileable_type,
)
from .operands import RemoteOperandMixin

Expand All @@ -45,6 +47,7 @@ class RemoteFunction(RemoteOperandMixin, ObjectOperand):
function_args = ListField("function_args")
function_kwargs = DictField("function_kwargs")
retry_when_fail = BoolField("retry_when_fail")
resolve_tileable_input = BoolField("resolve_tileable_input", default=False)
n_output = Int32Field("n_output", default=None)

@property
Expand Down Expand Up @@ -109,7 +112,7 @@ def tile(cls, op):
# if input is tensor, DataFrame etc,
# do not prepare data, because the data may be to huge,
# and users can choose to fetch slice of the data themselves
pure_depends.extend([True] * len(inp.chunks))
pure_depends.extend([not op.resolve_tileable_input] * len(inp.chunks))
else:
pure_depends.extend([False] * len(inp.chunks))
chunk_inputs.extend(inp.chunks)
Expand Down Expand Up @@ -141,11 +144,21 @@ def tile(cls, op):
@redirect_custom_log
@enter_current_session
def execute(cls, ctx, op: "RemoteFunction"):
mapping = {
inp: ctx[inp.key]
for inp, is_pure_dep in zip(op.inputs, op.pure_depends)
if not is_pure_dep
}
class MapperWrapper(UserDict):
def __getitem__(self, item):
if op.resolve_tileable_input and isinstance(item, Tileable):
index_chunks = [(c.index, ctx[c.key]) for c in item.chunks]
merged = merge_chunks(index_chunks)
return merged_chunk_as_tileable_type(merged, item)
return super().__getitem__(item)

mapping = MapperWrapper(
{
inp: ctx[inp.key]
for inp, is_pure_dep in zip(op.inputs, op.pure_depends)
if not is_pure_dep
}
)

function = op.function
function_args = replace_objects(op.function_args, mapping)
Expand All @@ -171,7 +184,15 @@ def execute(cls, ctx, op: "RemoteFunction"):
ctx[out.key] = r


def spawn(func, args=(), kwargs=None, retry_when_fail=False, n_output=None, **kw):
def spawn(
func,
args=(),
kwargs=None,
retry_when_fail=False,
resolve_tileable_input=False,
n_output=None,
**kw,
):
"""
Spawn a function and return a Mars Object which can be executed later.
Expand All @@ -185,6 +206,8 @@ def spawn(func, args=(), kwargs=None, retry_when_fail=False, n_output=None, **kw
Kwargs to pass to function
retry_when_fail: bool, default False
If True, retry when function failed.
resolve_tileable_input: bool default False
If True, resolve tileable inputs as values.
n_output: int
Count of outputs for the function
Expand Down Expand Up @@ -277,6 +300,7 @@ def spawn(func, args=(), kwargs=None, retry_when_fail=False, n_output=None, **kw
function_args=args,
function_kwargs=kwargs,
retry_when_fail=retry_when_fail,
resolve_tileable_input=resolve_tileable_input,
n_output=n_output,
**kw,
)
Expand Down

0 comments on commit 3a0a99c

Please sign in to comment.