Skip to content

Commit

Permalink
Remove storage service from supervisor (#3254)
Browse files Browse the repository at this point in the history
  • Loading branch information
vcfgv committed Sep 14, 2022
1 parent 6c6fc48 commit 71b1035
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 187 deletions.
48 changes: 22 additions & 26 deletions mars/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import sys
import cloudpickle
import functools
import itertools
import logging
Expand All @@ -32,7 +33,6 @@
from ..core import Entity, ExecutableTuple
from ..core.context import Context, get_context
from ..lib.mmh3 import hash as mmh_hash
from ..services.task.execution.ray.context import RayExecutionContext
from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes
from ..typing import ChunkType, TileableType
from ..utils import (
Expand All @@ -42,6 +42,7 @@
ModulePlaceholder,
is_full_slice,
parse_readable_size,
is_on_ray,
)

try:
Expand Down Expand Up @@ -1432,6 +1433,9 @@ def _concat_chunks(merge_chunks: List[ChunkType], output_index: int):
return new_op.new_tileable(df_or_series.op.inputs, kws=[params])


# TODO: clean_up_func, is_on_ray and restore_func functions may be
# removed or refactored in the future to calculate func size
# with more accuracy as well as address some serialization issues.
def clean_up_func(op):
closure_clean_up_bytes_threshold = int(
os.getenv("MARS_CLOSURE_CLEAN_UP_BYTES_THRESHOLD", 10**4)
Expand All @@ -1441,20 +1445,6 @@ def clean_up_func(op):
ctx = get_context()
if ctx is None:
return
# Before PR #3165 is merged, func cleanup is temporarily disabled under ray task mode.
# https://github.com/mars-project/mars/pull/3165
if isinstance(ctx, RayExecutionContext):
logger.warning("Func cleanup is currently disabled under ray task mode.")
return
# Note: Vineyard internally uses `pickle` which fails to pickle
# cell objects and corresponding functions.
if vineyard is not None:
storage_backend = ctx.get_storage_info()
if storage_backend.get("name", None) == "vineyard":
logger.warning(
"Func cleanup is currently disabled when vineyard is used as storage backend."
)
return

func = op.func
if hasattr(func, "__closure__") and func.__closure__ is not None:
Expand All @@ -1464,21 +1454,27 @@ def clean_up_func(op):
if counted_bytes >= closure_clean_up_bytes_threshold:
op.need_clean_up_func = True
break
# Note: op.func_key is set only when op.need_clean_up_func is True.
# Note: op.func_key is set only when func was put into storage.
if op.need_clean_up_func:
assert (
op.logic_key is not None
), "Logic key wasn't calculated before cleaning up func."
op.func_key = ctx.storage_put(op.func)
op.func = None
), f"Logic key of {op} wasn't calculated before cleaning up func."
logger.debug(f"{op} need cleaning up func.")
if is_on_ray(ctx):
import ray

op.func_key = ray.put(op.func)
op.func = None
else:
op.func = cloudpickle.dumps(op.func)


def restore_func(ctx: Context, op):
if op.need_clean_up_func and ctx is not None:
assert (
op.func_key is not None
), "Func key wasn't properly set while cleaning up func."
assert (
op.func is None
), "While restoring func, op.func should be None to ensure that cleanup was executed."
op.func = ctx.storage_get(op.func_key)
logger.debug(f"{op} need restoring func.")
if is_on_ray(ctx):
import ray

op.func = ray.get(op.func_key)
else:
op.func = cloudpickle.loads(op.func)
2 changes: 1 addition & 1 deletion mars/deploy/oscar/base_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ storage:
default_config:
transfer_block_size: 5 * 1024 ** 2
plasma:
store_memory: 12%
store_memory: 20%
"@overriding_fields": ["backends"]
meta:
store: dict
Expand Down
77 changes: 1 addition & 76 deletions mars/deploy/oscar/tests/test_checked_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
from typing import Any, Dict

import numpy as np
import pandas as pd
import pytest

from .... import dataframe as md
from .... import tensor as mt
from ....dataframe.base.apply import ApplyOperand
from ....config import option_context
from ....core import TileableGraph, TileableType, OperandType
from ....core import TileableType, OperandType
from ....services.task.supervisor.tests import CheckedTaskPreprocessor
from ....services.subtask.worker.tests import CheckedSubtaskProcessor
from ..local import _load_config
Expand All @@ -42,42 +39,6 @@ def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
return super()._execute_operand(ctx, op)


class FuncKeyCheckedTaskPreprocessor(CheckedTaskPreprocessor):
    """Checked task preprocessor that asserts ``func``/``func_key`` state around tiling.

    Only operands that are ``ApplyOperand`` instances are inspected.
    """

    def tile(self, tileable_graph: TileableGraph):
        """Tile ``tileable_graph`` via the parent class, asserting apply-op state.

        Before tiling, every apply operand must have ``func_key`` unset,
        ``func`` present and ``need_clean_up_func`` equal to ``False``.
        After tiling, an operand either has a ``func_key`` (then
        ``need_clean_up_func`` is ``True`` and ``func`` is ``None``) or has
        none (then the flag is ``False`` and ``func`` is still present).

        Returns whatever the parent ``tile`` returns.
        """
        apply_ops = [
            tileable.op
            for tileable in tileable_graph
            if isinstance(tileable.op, ApplyOperand)
        ]
        assert all(hasattr(apply_op, "func_key") for apply_op in apply_ops)
        assert not any(apply_op.func_key is not None for apply_op in apply_ops)
        assert not any(apply_op.func is None for apply_op in apply_ops)
        assert not any(
            apply_op.need_clean_up_func is not False for apply_op in apply_ops
        )

        tiled = super().tile(tileable_graph)

        for apply_op in apply_ops:
            assert hasattr(apply_op, "func_key")
            # True exactly when a func key was set on the operand.
            has_key = apply_op.func_key is not None
            assert apply_op.need_clean_up_func is has_key
            assert (apply_op.func is None) is has_key
        return tiled


class FuncKeyCheckedSubtaskProcessor(CheckedSubtaskProcessor):
    """Checked subtask processor that asserts ``func`` state around execution.

    Operands that are not ``ApplyOperand`` instances are passed straight to
    the parent implementation.
    """

    def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
        """Execute ``op`` via the parent class, asserting apply-op state.

        Before execution an apply operand either has a ``func_key`` (then
        ``need_clean_up_func`` is ``True`` and ``func`` is ``None``) or has
        none (then the flag is ``False`` and ``func`` is present).  After
        execution ``func`` must not be ``None``.

        Returns whatever the parent ``_execute_operand`` returns.
        """
        if not isinstance(op, ApplyOperand):
            return super()._execute_operand(ctx, op)

        assert hasattr(op, "func_key")
        # True exactly when a func key was set on the operand.
        has_key = op.func_key is not None
        assert op.need_clean_up_func is has_key
        assert (op.func is None) is has_key

        outcome = super()._execute_operand(ctx, op)
        assert op.func is not None
        return outcome


@pytest.fixture(scope="module")
def setup():
    """Module-scoped fixture that runs the tests with ``show_progress`` set to ``False``."""
    quiet_options = {"show_progress": False}
    with option_context(quiet_options):
        # The option context stays active until the fixture is torn down.
        yield
b.execute(extra_config={"check_all": False})

sess.stop_server()


def test_clean_up_and_restore_func(setup):
    """Execute ``DataFrame.apply`` with a small and a large closure.

    The session is configured to use ``FuncKeyCheckedTaskPreprocessor`` and
    ``FuncKeyCheckedSubtaskProcessor``, whose assertions perform the actual
    checks; this function only drives the two executions.

    The session is stopped in a ``finally`` clause so that a failing
    ``execute()`` does not leave the default test session running.
    """
    config = _load_config(CONFIG_FILE)
    config["task"][
        "task_preprocessor_cls"
    ] = "mars.deploy.oscar.tests.test_checked_session.FuncKeyCheckedTaskPreprocessor"
    config["subtask"][
        "subtask_processor_cls"
    ] = "mars.deploy.oscar.tests.test_checked_session.FuncKeyCheckedSubtaskProcessor"

    sess = new_test_session(default=True, config=config)
    try:
        cols = [chr(ord("A") + i) for i in range(10)]
        df_raw = pd.DataFrame({c: [i**2 for i in range(20)] for c in cols})
        df = md.DataFrame(df_raw, chunk_size=5)

        x_small = pd.Series(list(range(10)))
        y_small = pd.Series(list(range(10)))
        x_large = pd.Series(list(range(10**4)))
        y_large = pd.Series(list(range(10**4)))

        def closure_small(z):
            return pd.concat([x_small, y_small], ignore_index=True)

        def closure_large(z):
            return pd.concat([x_large, y_large], ignore_index=True)

        # no need to clean up func, func_key won't be set
        r_small = df.apply(closure_small, axis=1)
        r_small.execute()
        # need to clean up func, func_key will be set
        r_large = df.apply(closure_large, axis=1)
        r_large.execute()
    finally:
        # Always release the session, even when an assertion in the checked
        # preprocessor/processor makes execute() raise.
        sess.stop_server()
149 changes: 149 additions & 0 deletions mars/deploy/oscar/tests/test_clean_up_and_restore_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

import pandas as pd
import pytest

from .... import dataframe as md
from ....dataframe.base.apply import ApplyOperand
from ....config import option_context
from ....core import TileableGraph, OperandType
from ....services.task.supervisor.tests import CheckedTaskPreprocessor
from ....services.subtask.worker.tests import CheckedSubtaskProcessor
from ....services.task.supervisor.preprocessor import TaskPreprocessor
from ....services.subtask.worker.processor import SubtaskProcessor

from ....utils import lazy_import
from ..local import _load_config as _load_mars_config
from ..tests.session import new_test_session, CONFIG_FILE

ray = lazy_import("ray")


class MarsBackendFuncCheckedTaskPreprocessor(CheckedTaskPreprocessor):
    """Checked task preprocessor asserting apply-op ``func`` state around tiling.

    Only operands that are ``ApplyOperand`` instances are inspected.
    """

    def tile(self, tileable_graph: TileableGraph):
        """Tile ``tileable_graph`` via the parent class, asserting apply-op state.

        Before tiling, every apply operand must have ``func_key`` unset, a
        callable ``func`` and ``need_clean_up_func`` equal to ``False``.
        After tiling, ``func_key`` must still be ``None``; ``func`` must be
        ``bytes`` when ``need_clean_up_func`` is set and callable otherwise.

        Returns whatever the parent ``tile`` returns.
        """
        apply_ops = [
            tileable.op
            for tileable in tileable_graph
            if isinstance(tileable.op, ApplyOperand)
        ]
        for apply_op in apply_ops:
            assert hasattr(apply_op, "func_key")
            assert apply_op.func_key is None
            assert apply_op.func is not None
            assert callable(apply_op.func)
            assert apply_op.need_clean_up_func is False

        tiled = super().tile(tileable_graph)

        for apply_op in apply_ops:
            assert hasattr(apply_op, "func_key")
            assert apply_op.func_key is None
            # A cleaned-up func is expected as bytes, an untouched one as a callable.
            assert (
                isinstance(apply_op.func, bytes)
                if apply_op.need_clean_up_func
                else callable(apply_op.func)
            )
        return tiled


class MarsBackendFuncCheckedSubtaskProcessor(CheckedSubtaskProcessor):
    """Checked subtask processor asserting apply-op ``func`` state around execution.

    Operands that are not ``ApplyOperand`` instances are passed straight to
    the parent implementation.
    """

    def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
        """Execute ``op`` via the parent class, asserting apply-op state.

        Before execution ``func_key`` must be ``None`` and ``func`` must be
        ``bytes`` when ``need_clean_up_func`` is set, callable otherwise.
        After execution ``func`` must be a callable again.

        Returns whatever the parent ``_execute_operand`` returns.
        """
        if not isinstance(op, ApplyOperand):
            return super()._execute_operand(ctx, op)

        assert hasattr(op, "func_key")
        assert op.func_key is None
        # A cleaned-up func is expected as bytes, an untouched one as a callable.
        assert (
            isinstance(op.func, bytes)
            if op.need_clean_up_func
            else callable(op.func)
        )

        outcome = super()._execute_operand(ctx, op)
        assert op.func is not None
        assert callable(op.func)
        return outcome


class RayBackendFuncTaskPreprocessor(TaskPreprocessor):
    """Task preprocessor asserting apply-op ``func``/``func_key`` state around tiling.

    Only operands that are ``ApplyOperand`` instances are inspected.
    """

    def tile(self, tileable_graph: TileableGraph):
        """Tile ``tileable_graph`` via the parent class, asserting apply-op state.

        Before tiling, every apply operand must have ``func_key`` unset, a
        callable ``func`` and ``need_clean_up_func`` equal to ``False``.
        After tiling, an operand with ``need_clean_up_func`` set must have
        ``func`` equal to ``None`` and a ``ray.ObjectRef`` as ``func_key``;
        any other operand must keep a callable ``func`` and no ``func_key``.

        Returns whatever the parent ``tile`` returns.
        """
        apply_ops = [
            tileable.op
            for tileable in tileable_graph
            if isinstance(tileable.op, ApplyOperand)
        ]
        for apply_op in apply_ops:
            assert hasattr(apply_op, "func_key")
            assert apply_op.func_key is None
            assert apply_op.func is not None
            assert callable(apply_op.func)
            assert apply_op.need_clean_up_func is False

        tiled = super().tile(tileable_graph)

        for apply_op in apply_ops:
            assert hasattr(apply_op, "func_key")
            if not apply_op.need_clean_up_func:
                # Untouched operand: func stays in place and no key is set.
                assert callable(apply_op.func)
                assert apply_op.func_key is None
                continue
            # Cleaned-up operand: func is gone and the key is a Ray object ref.
            assert apply_op.func is None
            assert isinstance(apply_op.func_key, ray.ObjectRef)
        return tiled


class RayBackendFuncSubtaskProcessor(SubtaskProcessor):
    """Subtask processor asserting apply-op ``func``/``func_key`` state around execution.

    Operands that are not ``ApplyOperand`` instances are passed straight to
    the parent implementation.
    """

    def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
        """Execute ``op`` via the parent class, asserting apply-op state.

        Before execution an operand with ``need_clean_up_func`` set must have
        ``func`` equal to ``None`` and a ``ray.ObjectRef`` as ``func_key``;
        any other operand must have a callable ``func`` and no ``func_key``.
        After execution ``func`` must be a callable.

        Returns whatever the parent ``_execute_operand`` returns.
        """
        if not isinstance(op, ApplyOperand):
            return super()._execute_operand(ctx, op)

        assert hasattr(op, "func_key")
        if not op.need_clean_up_func:
            # Untouched operand: func is in place and no key is set.
            assert callable(op.func)
            assert op.func_key is None
        else:
            # Cleaned-up operand: func is gone and the key is a Ray object ref.
            assert op.func is None
            assert isinstance(op.func_key, ray.ObjectRef)

        outcome = super()._execute_operand(ctx, op)
        assert op.func is not None
        assert callable(op.func)
        return outcome


@pytest.fixture(scope="module")
def setup():
    """Module-scoped fixture that runs the tests with ``show_progress`` set to ``False``."""
    quiet_options = {"show_progress": False}
    with option_context(quiet_options):
        # The option context stays active until the fixture is torn down.
        yield


def test_mars_backend_clean_up_and_restore_func(setup):
    """Execute ``DataFrame.apply`` with a small and a large closure.

    The session is configured to use ``MarsBackendFuncCheckedTaskPreprocessor``
    and ``MarsBackendFuncCheckedSubtaskProcessor``, whose assertions perform
    the actual checks; this function only drives the two executions.

    The session is stopped in a ``finally`` clause so that a failing
    ``execute()`` does not leave the default test session running.
    """
    config = _load_mars_config(CONFIG_FILE)
    config["task"][
        "task_preprocessor_cls"
    ] = "mars.deploy.oscar.tests.test_clean_up_and_restore_func.MarsBackendFuncCheckedTaskPreprocessor"
    config["subtask"][
        "subtask_processor_cls"
    ] = "mars.deploy.oscar.tests.test_clean_up_and_restore_func.MarsBackendFuncCheckedSubtaskProcessor"

    sess = new_test_session(default=True, config=config)
    try:
        cols = [chr(ord("A") + i) for i in range(10)]
        df_raw = pd.DataFrame({c: [i**2 for i in range(20)] for c in cols})
        df = md.DataFrame(df_raw, chunk_size=5)

        x_small = pd.Series(list(range(10)))
        y_small = pd.Series(list(range(10)))
        # NOTE(review): the 10**4-element series are presumably sized to push the
        # closure past the clean-up byte threshold -- confirm against clean_up_func.
        x_large = pd.Series(list(range(10**4)))
        y_large = pd.Series(list(range(10**4)))

        def closure_small(z):
            return pd.concat([x_small, y_small], ignore_index=True)

        def closure_large(z):
            return pd.concat([x_large, y_large], ignore_index=True)

        r_small = df.apply(closure_small, axis=1)
        r_small.execute()
        r_large = df.apply(closure_large, axis=1)
        r_large.execute()
    finally:
        # Always release the session, even when an assertion in the checked
        # preprocessor/processor makes execute() raise.
        sess.stop_server()
2 changes: 0 additions & 2 deletions mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,6 @@ async def test_execute_describe(create_cluster):
)


# Note: Vineyard internally uses `pickle` which fails to pickle
# cell objects and corresponding functions.
@pytest.mark.asyncio
async def test_execute_apply_closure(create_cluster):
# DataFrame
Expand Down
26 changes: 26 additions & 0 deletions mars/deploy/oscar/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,32 @@ async def test_execute_describe(ray_start_regular_shared, create_cluster):
await test_local.test_execute_describe(create_cluster)


@require_ray
@pytest.mark.asyncio
async def test_execute_apply_closure(ray_start_regular_shared, create_cluster):
    """Run ``test_local.test_execute_apply_closure`` with this module's ``create_cluster``.

    ``ray_start_regular_shared`` is requested as a fixture but is not
    referenced in the body.
    """
    shared_case = test_local.test_execute_apply_closure
    await shared_case(create_cluster)


@require_ray
@pytest.mark.parametrize(
    "create_cluster",
    [
        {
            "config": {
                "task.task_preprocessor_cls": "mars.deploy.oscar.tests.test_clean_up_and_restore_func.RayBackendFuncTaskPreprocessor",
                "subtask.subtask_processor_cls": "mars.deploy.oscar.tests.test_clean_up_and_restore_func.RayBackendFuncSubtaskProcessor",
            }
        }
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_ray_oscar_clean_up_and_restore_func(
    ray_start_regular_shared, create_cluster
):
    """Run the shared apply-closure test with the Ray func-checking classes configured.

    The ``create_cluster`` fixture is parametrized indirectly with a config
    naming ``RayBackendFuncTaskPreprocessor`` and
    ``RayBackendFuncSubtaskProcessor``.  ``ray_start_regular_shared`` is
    requested as a fixture but is not referenced in the body.
    """
    shared_case = test_local.test_execute_apply_closure
    await shared_case(create_cluster)


@require_ray
@pytest.mark.asyncio
async def test_fetch_infos(ray_start_regular_shared, create_cluster):
Expand Down

0 comments on commit 71b1035

Please sign in to comment.