diff --git a/python/ray/dag/py_obj_scanner.py b/python/ray/dag/py_obj_scanner.py index 20798b8441912..478763bc3ba99 100644 --- a/python/ray/dag/py_obj_scanner.py +++ b/python/ray/dag/py_obj_scanner.py @@ -1,6 +1,6 @@ import io import sys -from typing import Generic, List, Dict, Any, Type, TypeVar +from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union # For python < 3.8 we need to explicitly use pickle5 to support protocol 5 if sys.version_info < (3, 8): @@ -32,27 +32,22 @@ def _get_node(instance_id: int, node_index: int) -> SourceType: return _instances[instance_id]._replace_index(node_index) -def _get_object(instance_id: int, node_index: int) -> Any: - """Used to get arbitrary object other than SourceType. - - Note: This function should be static and globally importable, - otherwise the serialization overhead would be very significant. - """ - return _instances[instance_id]._objects[node_index] - - class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]): """Utility to find and replace the `source_type` in Python objects. - This uses pickle to walk the PyObj graph and find first-level DAGNode - instances on ``find_nodes()``. The caller can then compute a replacement - table and then replace the nodes via ``replace_nodes()``. + `source_type` can either be a single type or a tuple of multiple types. + + The caller must first call `find_nodes()`, then compute a replacement table and + pass it to `replace_nodes`. + + This uses cloudpickle under the hood, so all sub-objects that are not `source_type` + must be serializable. Args: - source_type: the type of object to find and replace. Default to DAGNodeBase. + source_type: the type(s) of object to find and replace. Default to DAGNodeBase. """ - def __init__(self, source_type: Type = DAGNodeBase): + def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase): self.source_type = source_type # Buffer to keep intermediate serialized state. self._buf = io.BytesIO() @@ -70,21 +65,17 @@ def __init__(self, source_type: Type = DAGNodeBase): def reducer_override(self, obj): """Hook for reducing objects. - The function intercepts serialization of all objects and store them - to internal data structures, preventing actually writing them to - the buffer. + Objects of `self.source_type` are saved to `self._found` and a global map so + they can later be replaced. + + All other objects fall back to the default `CloudPickler` serialization. """ - if obj is _get_node or obj is _get_object: - # Only fall back to cloudpickle for these two functions. - return super().reducer_override(obj) - elif isinstance(obj, self.source_type): + if isinstance(obj, self.source_type): index = len(self._found) self._found.append(obj) return _get_node, (id(self), index) - else: - index = len(self._objects) - self._objects.append(obj) - return _get_object, (id(self), index) + + return super().reducer_override(obj) def find_nodes(self, obj: Any) -> List[SourceType]: """Find top-level DAGNodes.""" diff --git a/python/ray/dag/tests/test_py_obj_scanner.py b/python/ray/dag/tests/test_py_obj_scanner.py index e70bf0b220c7a..c07fdd499e38f 100644 --- a/python/ray/dag/tests/test_py_obj_scanner.py +++ b/python/ray/dag/tests/test_py_obj_scanner.py @@ -1,5 +1,7 @@ -from ray.dag.py_obj_scanner import _PyObjScanner, _instances import pytest +from typing import Any + +from ray.dag.py_obj_scanner import _PyObjScanner, _instances class Source: @@ -17,21 +19,40 @@ def test_simple_replace(): assert replaced == [1, [1, {"key": 1}]] -class NotSerializable: - def __reduce__(self): - raise Exception("don't even try to serialize me.") +def test_replace_multiple_types(): + class OtherSource: + pass + + scanner = _PyObjScanner(source_type=(Source, OtherSource)) + my_objs = [Source(), [Source(), {"key": Source(), "key2": OtherSource()}]] + + found = scanner.find_nodes(my_objs) + assert len(found) == 4 + replaced = scanner.replace_nodes( + {obj: 1 if isinstance(obj, Source) else 2 for obj in found} + ) + assert replaced == [1, [1, {"key": 1, "key2": 2}]] -def test_not_serializing_objects(): + +def test_replace_nested_in_obj(): + """Test that the source can be nested in arbitrary objects.""" scanner = _PyObjScanner(source_type=Source) - not_serializable = NotSerializable() - my_objs = [not_serializable, {"key": Source()}] + + class Outer: + def __init__(self, inner: Any): + self._inner = inner + + def __eq__(self, other): + return self._inner == other._inner + + my_objs = [Outer(Source()), Outer(Outer(Source())), Outer((Source(),))] found = scanner.find_nodes(my_objs) - assert len(found) == 1 + assert len(found) == 3 replaced = scanner.replace_nodes({obj: 1 for obj in found}) - assert replaced == [not_serializable, {"key": 1}] + assert replaced == [Outer(1), Outer(Outer(1)), Outer((1,))] def test_scanner_clear(): diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index d59aac388dcc1..d38922401d604 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -98,85 +98,75 @@ class Query: kwargs: Dict[Any, Any] metadata: RequestMetadata - async def resolve_async_tasks(self): - """Find all unresolved asyncio.Task and gather them all at once. - - This is used for the old serve handle API and should be removed once that API - is fully deprecated & removed. - """ - scanner = _PyObjScanner(source_type=asyncio.Task) - - try: - tasks = scanner.find_nodes((self.args, self.kwargs)) - - if len(tasks) > 0: - resolved = await asyncio.gather(*tasks) - replacement_table = dict(zip(tasks, resolved)) - self.args, self.kwargs = scanner.replace_nodes(replacement_table) - finally: - # Make the scanner GC-able to avoid memory leaks. - scanner.clear() - - async def resolve_deployment_handle_results_to_object_refs(self): - """Replace DeploymentHandleResults with their resolved ObjectRefs. - - DeploymentResponseGenerators are rejected (not currently supported). + async def replace_known_types_in_args(self): + """Uses the `_PyObjScanner` to find and replace known types. + + 1) Replaces `asyncio.Task` objects with their results. This is used for the old + serve handle API and should be removed once that API is deprecated & removed. + 2) Replaces `DeploymentResponse` objects with their resolved object refs. This + enables composition without explicitly calling `_to_object_ref`. + 3) Buffers the bodies of `starlette.requests.Request` objects to avoid them + being unserializable. This is a temporary compatibility measure and passing + the objects should be fully disallowed in a future release. """ from ray.serve.handle import ( _DeploymentResponseBase, + DeploymentResponse, DeploymentResponseGenerator, ) - scanner = _PyObjScanner(source_type=_DeploymentResponseBase) + scanner = _PyObjScanner( + source_type=(asyncio.Task, _DeploymentResponseBase, Request) + ) try: - result_to_object_ref_coros = [] - results = scanner.find_nodes((self.args, self.kwargs)) - for result in results: - result_to_object_ref_coros.append(result._to_object_ref()) - if isinstance(result, DeploymentResponseGenerator): + tasks = [] + responses = [] + replacement_table = {} + objs = scanner.find_nodes((self.args, self.kwargs)) + for obj in objs: + if isinstance(obj, asyncio.Task): + tasks.append(obj) + elif isinstance(obj, DeploymentResponseGenerator): raise RuntimeError( "Streaming deployment handle results cannot be passed to " "downstream handle calls. If you have a use case requiring " "this feature, please file a feature request on GitHub." ) + elif isinstance(obj, DeploymentResponse): + responses.append(obj) + elif isinstance(obj, Request): + global WARNED_ABOUT_STARLETTE_REQUESTS_ONCE + if not WARNED_ABOUT_STARLETTE_REQUESTS_ONCE: + # TODO(edoakes): fully disallow this in the future. + warnings.warn( + "`starlette.Request` objects should not be directly passed " + "via `ServeHandle` calls. Not all functionality is " + "guaranteed to work (e.g., detecting disconnects) and this " + "may be disallowed in a future release." + ) + WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = True - if len(results) > 0: - obj_refs = await asyncio.gather(*result_to_object_ref_coros) - replacement_table = dict(zip(results, obj_refs)) - self.args, self.kwargs = scanner.replace_nodes(replacement_table) - finally: - # Make the scanner GC-able to avoid memory leaks. - scanner.clear() + async def empty_send(): + pass - async def buffer_starlette_requests_and_warn(self): - """Buffer any `starlette.request.Requests` objects to make them serializable. + obj._send = empty_send + obj._receive = make_buffered_asgi_receive(await obj.body()) + replacement_table[obj] = obj - This is an anti-pattern because the requests will not be fully functional, so - warn the user. We may fully disallow it in the future. - """ - global WARNED_ABOUT_STARLETTE_REQUESTS_ONCE - scanner = _PyObjScanner(source_type=Request) + # Gather `asyncio.Task` results concurrently. + if len(tasks) > 0: + resolved = await asyncio.gather(*tasks) + replacement_table.update(zip(tasks, resolved)) - try: - requests = scanner.find_nodes((self.args, self.kwargs)) - if len(requests) > 0 and not WARNED_ABOUT_STARLETTE_REQUESTS_ONCE: - WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = True - # TODO(edoakes): fully disallow this in the future. - warnings.warn( - "`starlette.Request` objects should not be directly passed via " - "`ServeHandle` calls. Not all functionality is guaranteed to work " - "(e.g., detecting disconnects) and this may be disallowed in a " - "future release." + # Gather `DeploymentResponse` object refs concurrently. + if len(responses) > 0: + obj_refs = await asyncio.gather( + *[r._to_object_ref() for r in responses] ) + replacement_table.update((zip(responses, obj_refs))) - for request in requests: - - async def empty_send(): - pass - - request._send = empty_send - request._receive = make_buffered_asgi_receive(await request.body()) + self.args, self.kwargs = scanner.replace_nodes(replacement_table) finally: # Make the scanner GC-able to avoid memory leaks. scanner.clear() @@ -1219,10 +1209,7 @@ async def assign_request( kwargs=request_kwargs, metadata=request_meta, ) - await query.resolve_async_tasks() - await query.resolve_deployment_handle_results_to_object_refs() - await query.buffer_starlette_requests_and_warn() - + await query.replace_known_types_in_args() return await self._replica_scheduler.assign_replica(query) finally: # If the query is disconnected before assignment, this coroutine diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index c13b8a4d99906..73ace64b44e1c 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -1057,6 +1057,33 @@ async def __call__(self): assert app_handle.remote().result() == "hello world!!" +def test_deployment_handle_nested_in_obj(serve_instance): + """Test binding a handle within a custom object.""" + + class HandleWrapper: + def __init__(self, handle: RayServeHandle): + self._handle = handle + + def get(self) -> DeploymentHandle: + return self._handle.options(use_new_handle_api=True) + + @serve.deployment + def f() -> str: + return "hi" + + @serve.deployment + class MyDriver: + def __init__(self, handle_wrapper: HandleWrapper): + self.handle_wrapper = handle_wrapper + + async def __call__(self) -> str: + return await self.handle_wrapper.get().remote() + + handle_wrapper = HandleWrapper(f.bind()) + h = serve.run(MyDriver.bind(handle_wrapper)).options(use_new_handle_api=True) + assert h.remote().result() == "hi" + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/test_deploy_2.py b/python/ray/serve/tests/test_deploy_2.py index 1c85dad78baca..0d0cd04e09d29 100644 --- a/python/ray/serve/tests/test_deploy_2.py +++ b/python/ray/serve/tests/test_deploy_2.py @@ -2,6 +2,7 @@ import functools import os import sys +import threading import time from typing import Dict @@ -345,8 +346,6 @@ async def __call__(self): def test_nonserializable_deployment(serve_instance): - import threading - lock = threading.Lock() class D: @@ -365,16 +364,10 @@ class E: def __init__(self, arg): self.arg = arg - with pytest.raises( - TypeError, - match=r"Could not serialize the deployment init args:[\s\S]*was found to be non-serializable.*", # noqa - ): + with pytest.raises(TypeError, match="pickle"): serve.run(E.bind(lock)) - with pytest.raises( - TypeError, - match=r"Could not serialize the deployment init kwargs:[\s\S]*was found to be non-serializable.*", # noqa - ): + with pytest.raises(TypeError, match="pickle"): serve.run(E.bind(arg=lock))