Skip to content

Commit

Permalink
Ensure key is a valid Python identifier (#3190)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone committed Jul 28, 2022
1 parent dbbbcaa commit 5c3093a
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 21 deletions.
17 changes: 15 additions & 2 deletions mars/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import numpy as np
import pandas as pd
import cloudpickle
cimport cython
from libc.stdint cimport uint_fast64_t
from cpython cimport PyBytes_FromStringAndSize
from libc.stdint cimport uint_fast64_t, uint32_t, uint8_t
from libc.stdlib cimport malloc, free
from .lib.cython.libcpp cimport mt19937_64
try:
Expand All @@ -46,6 +47,18 @@ cdef bint _has_cudf = bool(pkgutil.find_loader('cudf'))
cdef bint _has_sqlalchemy = bool(pkgutil.find_loader('sqlalchemy'))


cdef extern from "MurmurHash3.h":
void MurmurHash3_x64_128(const void * key, Py_ssize_t len, uint32_t seed, void * out)


cdef bytes _get_mars_key(const uint8_t[:] bufferview):
cdef const uint8_t *data = &bufferview[0]
cdef uint8_t out[16]
MurmurHash3_x64_128(data, len(bufferview), 0, out)
out[0] |= 0xC0
return PyBytes_FromStringAndSize(<char*>out, 16)


cpdef str to_str(s, encoding='utf-8'):
if type(s) is str:
return <str>s
Expand Down Expand Up @@ -161,7 +174,7 @@ cdef inline build_canonical_bytes(tuple args, kwargs):


def tokenize(*args, **kwargs):
return mmh_hash_bytes(build_canonical_bytes(args, kwargs)).hex()
return _get_mars_key(build_canonical_bytes(args, kwargs)).hex()


def tokenize_int(*args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions mars/core/operand/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pandas as pd
import pytest

from ....dataframe import core # noqa: F401 # pylint: disable=unused-variable
from ... import OutputType
from .. import Operand, TileableOperandMixin, execute, estimate_size, ShuffleProxy

Expand Down
5 changes: 3 additions & 2 deletions mars/deploy/oscar/tests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import uuid

from ....conftest import MARS_CI_BACKEND
from ....core import OBJECT_TYPE
from ....deploy.oscar.local import LocalCluster, LocalClient
from ....tests.core import _check_args, ObjectCheckMixin
Expand Down Expand Up @@ -72,7 +73,7 @@ async def fetch(self, *tileables, **kwargs):
async def _new_test_session(
address: str,
session_id: str = None,
backend: str = "mars",
backend: str = MARS_CI_BACKEND,
default: bool = False,
new: bool = True,
timeout: float = None,
Expand Down Expand Up @@ -129,7 +130,7 @@ async def _new_test_cluster_in_isolation(**new_cluster_kwargs):
def new_test_session(
address: str = None,
session_id: str = None,
backend: str = "mars",
backend: str = MARS_CI_BACKEND,
default: bool = False,
new: bool = True,
**kwargs,
Expand Down
2 changes: 2 additions & 0 deletions mars/services/subtask/worker/tests/subtask_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
continue
if self._check_keys and out.key not in self._check_keys:
continue
# The first char of key is a letter.
assert out.key[0] in {"c", "d", "e", "f"}, out.key
if out.key not in ctx and any(
k[0] == out.key for k in ctx if isinstance(k, tuple)
):
Expand Down
15 changes: 10 additions & 5 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,16 @@ def _on_execute_aiotask_done(_):
output_object_refs = set()
for chunk in chunk_graph.result_chunks:
chunk_key = chunk.key
object_ref = task_context[chunk_key]
output_object_refs.add(object_ref)
chunk_params = key_to_meta.get(chunk_key)
if chunk_params is not None:
chunk_to_meta[chunk] = ExecutionChunkResult(chunk_params, object_ref)
# The result chunk may be in previous stage result,
# then the chunk does not have to be processed.
if chunk_key in task_context:
object_ref = task_context[chunk_key]
output_object_refs.add(object_ref)
chunk_params = key_to_meta.get(chunk_key)
if chunk_params is not None:
chunk_to_meta[chunk] = ExecutionChunkResult(
chunk_params, object_ref
)

logger.info("Waiting for stage %s complete.", stage_id)
# Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
Expand Down
2 changes: 1 addition & 1 deletion mars/services/task/execution/ray/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ async def get(self):
refs[index] = fetch_info.object_ref
else:
refs[index] = self._remote_query_object_with_condition().remote(
fetch_info.object_ref, fetch_info.conditions
fetch_info.object_ref, tuple(fetch_info.conditions)
)
return await asyncio.gather(*refs)
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ async def test_ray_fetcher(ray_start_regular_shared2):
fetcher = RayFetcher()
await fetcher.append("pd_key", {"object_refs": [pd_object_ref]})
await fetcher.append("np_key", {"object_refs": [np_object_ref]})
await fetcher.append("pd_key", {"object_refs": [pd_object_ref]}, [1, 3])
await fetcher.append("np_key", {"object_refs": [np_object_ref]}, [1, 3])
await fetcher.append("pd_key", {"object_refs": [pd_object_ref]}, [slice(1, 3, 1)])
await fetcher.append("np_key", {"object_refs": [np_object_ref]}, [slice(1, 3, 1)])
results = await fetcher.get()
pd.testing.assert_frame_equal(results[0], pd_value)
np.testing.assert_array_equal(results[1], np_value)
pd.testing.assert_frame_equal(results[2], pd_value.iloc[[1, 3]])
np.testing.assert_array_equal(results[3], np_value[[1, 3]])
pd.testing.assert_frame_equal(results[2], pd_value.iloc[1:3])
np.testing.assert_array_equal(results[3], np_value[1:3])


@require_ray
Expand Down
1 change: 1 addition & 0 deletions mars/tensor/datastore/tests/test_datastore_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_store_tiledb_execution(setup):


@pytest.mark.skipif(h5py is None, reason="h5py not installed")
@pytest.mark.ray_dag
def test_store_hdf5_execution(setup):
raw = np.random.RandomState(0).rand(10, 20)

Expand Down
1 change: 1 addition & 0 deletions mars/tests/test_eager_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_mixed_config(setup):
np.testing.assert_array_equal(r.execute(), np.ones((10, 10)) * 10)


@pytest.mark.ray_dag
def test_index(setup):
with option_context({"eager_mode": True}):
a = mt.random.rand(10, 5, chunk_size=5)
Expand Down
1 change: 1 addition & 0 deletions mars/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_without_fuse(setup):
np.testing.assert_array_equal(r1, r2)


@pytest.mark.ray_dag
def test_fetch_slices(setup):
arr1 = mt.random.rand(10, 8, chunk_size=3)
r1 = arr1.execute().fetch()
Expand Down
9 changes: 4 additions & 5 deletions mars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,11 +1726,10 @@ def sync_to_async(func):
if inspect.iscoroutinefunction(func):
return func
else:

async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)

return async_wrapper
# Wrap the sync call to thread to avoid blocking the
# asyncio event loop. e.g. acquiring a threading.Lock()
# in the sync call.
return functools.partial(asyncio.to_thread, func)


def retry_callable(
Expand Down
17 changes: 15 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,32 @@
cy_extension_kw["extra_compile_args"] = extra_compile_args


# The pyx with C sources.
ext_include_source_map = {
"mars/_utils.pyx": [["mars/lib/mmh3_src"], ["mars/lib/mmh3_src/MurmurHash3.cpp"]],
}


def _discover_pyx():
exts = dict()
for root, _, files in os.walk(os.path.join(repo_root, "mars")):
for fn in files:
if not fn.endswith(".pyx"):
continue
full_fn = os.path.relpath(os.path.join(root, fn), repo_root)
include_dirs, source = ext_include_source_map.get(
full_fn.replace(os.path.sep, "/"), [[], []]
)
mod_name = full_fn.replace(".pyx", "").replace(os.path.sep, ".")
exts[mod_name] = Extension(mod_name, [full_fn], **cy_extension_kw)
exts[mod_name] = Extension(
mod_name,
[full_fn] + source,
include_dirs=[np.get_include()] + include_dirs,
**cy_extension_kw,
)
return exts


cy_extension_kw["include_dirs"] = [np.get_include()]
extensions_dict = _discover_pyx()
cy_extensions = list(extensions_dict.values())

Expand Down

0 comments on commit 5c3093a

Please sign in to comment.