Skip to content

Commit

Permalink
[serve] Support passing bound deployments within custom objects (ray-…
Browse files Browse the repository at this point in the history
…project#39015)

As per ray-project#38809, you currently cannot pass bound deployments nested within custom objects. This PR lifts that restriction.

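The approach I took is to remove the "arbitrary object replacement" path in `_PyObjScanner.reducer_override`. That path swapped every object that was not a `SourceType` for a placeholder, so cloudpickle returned early and never descended into the object's contents; any `SourceType` nested inside a custom object was therefore never found. Instead, we now serialize every object other than the `SourceType` through the standard cloudpickle path.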

This has one major downside: all objects that `_PyObjScanner` is called on must now be serializable. This is not an issue for its current usage in the code base, but it required me to also add support for finding and replacing multiple types at once. Previously we made a separate pass over each Serve `Query` object for each type, and with this change each pass would have had to serialize the objects that the other passes are meant to replace.
  • Loading branch information
edoakes committed Sep 6, 2023
1 parent 54f9bf1 commit f2a2560
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 109 deletions.
43 changes: 17 additions & 26 deletions python/ray/dag/py_obj_scanner.py
@@ -1,6 +1,6 @@
import io
import sys
from typing import Generic, List, Dict, Any, Type, TypeVar
from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union

# For python < 3.8 we need to explicitly use pickle5 to support protocol 5
if sys.version_info < (3, 8):
Expand Down Expand Up @@ -32,27 +32,22 @@ def _get_node(instance_id: int, node_index: int) -> SourceType:
return _instances[instance_id]._replace_index(node_index)


def _get_object(instance_id: int, node_index: int) -> Any:
    """Return a stored non-`SourceType` object from a registered scanner.

    The scanner is looked up in the module-level `_instances` table by
    `instance_id`, and the object is read from its `_objects` list at
    `node_index`.

    Note: keep this a static, globally importable function; otherwise the
    serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._objects[node_index]


class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]):
"""Utility to find and replace the `source_type` in Python objects.
This uses pickle to walk the PyObj graph and find first-level DAGNode
instances on ``find_nodes()``. The caller can then compute a replacement
table and then replace the nodes via ``replace_nodes()``.
`source_type` can either be a single type or a tuple of multiple types.
The caller must first call `find_nodes()`, then compute a replacement table and
pass it to `replace_nodes`.
This uses cloudpickle under the hood, so all sub-objects that are not `source_type`
must be serializable.
Args:
source_type: the type of object to find and replace. Default to DAGNodeBase.
source_type: the type(s) of object to find and replace. Default to DAGNodeBase.
"""

def __init__(self, source_type: Type = DAGNodeBase):
def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase):
self.source_type = source_type
# Buffer to keep intermediate serialized state.
self._buf = io.BytesIO()
Expand All @@ -70,21 +65,17 @@ def __init__(self, source_type: Type = DAGNodeBase):
def reducer_override(self, obj):
"""Hook for reducing objects.
The function intercepts serialization of all objects and store them
to internal data structures, preventing actually writing them to
the buffer.
Objects of `self.source_type` are saved to `self._found` and a global map so
they can later be replaced.
All other objects fall back to the default `CloudPickler` serialization.
"""
if obj is _get_node or obj is _get_object:
# Only fall back to cloudpickle for these two functions.
return super().reducer_override(obj)
elif isinstance(obj, self.source_type):
if isinstance(obj, self.source_type):
index = len(self._found)
self._found.append(obj)
return _get_node, (id(self), index)
else:
index = len(self._objects)
self._objects.append(obj)
return _get_object, (id(self), index)

return super().reducer_override(obj)

def find_nodes(self, obj: Any) -> List[SourceType]:
"""Find top-level DAGNodes."""
Expand Down
39 changes: 30 additions & 9 deletions python/ray/dag/tests/test_py_obj_scanner.py
@@ -1,5 +1,7 @@
from ray.dag.py_obj_scanner import _PyObjScanner, _instances
import pytest
from typing import Any

from ray.dag.py_obj_scanner import _PyObjScanner, _instances


class Source:
Expand All @@ -17,21 +19,40 @@ def test_simple_replace():
assert replaced == [1, [1, {"key": 1}]]


class NotSerializable:
def __reduce__(self):
raise Exception("don't even try to serialize me.")
def test_replace_multiple_types():
    """A tuple `source_type` finds and replaces instances of every listed type."""

    class OtherSource:
        pass

    scanner = _PyObjScanner(source_type=(Source, OtherSource))
    my_objs = [Source(), [Source(), {"key": Source(), "key2": OtherSource()}]]

    found = scanner.find_nodes(my_objs)
    # Three `Source` instances plus one `OtherSource`.
    assert len(found) == 4

    # Map `Source` -> 1 and `OtherSource` -> 2 so the two types stay
    # distinguishable in the rebuilt structure.
    replacement_table = {}
    for node in found:
        if isinstance(node, Source):
            replacement_table[node] = 1
        else:
            replacement_table[node] = 2

    replaced = scanner.replace_nodes(replacement_table)
    assert replaced == [1, [1, {"key": 1, "key2": 2}]]

def test_not_serializing_objects():

def test_replace_nested_in_obj():
"""Test that the source can be nested in arbitrary objects."""
scanner = _PyObjScanner(source_type=Source)
not_serializable = NotSerializable()
my_objs = [not_serializable, {"key": Source()}]

class Outer:
def __init__(self, inner: Any):
self._inner = inner

def __eq__(self, other):
return self._inner == other._inner

my_objs = [Outer(Source()), Outer(Outer(Source())), Outer((Source(),))]

found = scanner.find_nodes(my_objs)
assert len(found) == 1
assert len(found) == 3

replaced = scanner.replace_nodes({obj: 1 for obj in found})
assert replaced == [not_serializable, {"key": 1}]
assert replaced == [Outer(1), Outer(Outer(1)), Outer((1,))]


def test_scanner_clear():
Expand Down
115 changes: 51 additions & 64 deletions python/ray/serve/_private/router.py
Expand Up @@ -98,85 +98,75 @@ class Query:
kwargs: Dict[Any, Any]
metadata: RequestMetadata

async def resolve_async_tasks(self):
"""Find all unresolved asyncio.Task and gather them all at once.
This is used for the old serve handle API and should be removed once that API
is fully deprecated & removed.
"""
scanner = _PyObjScanner(source_type=asyncio.Task)

try:
tasks = scanner.find_nodes((self.args, self.kwargs))

if len(tasks) > 0:
resolved = await asyncio.gather(*tasks)
replacement_table = dict(zip(tasks, resolved))
self.args, self.kwargs = scanner.replace_nodes(replacement_table)
finally:
# Make the scanner GC-able to avoid memory leaks.
scanner.clear()

async def resolve_deployment_handle_results_to_object_refs(self):
"""Replace DeploymentHandleResults with their resolved ObjectRefs.
DeploymentResponseGenerators are rejected (not currently supported).
async def replace_known_types_in_args(self):
"""Uses the `_PyObjScanner` to find and replace known types.
1) Replaces `asyncio.Task` objects with their results. This is used for the old
serve handle API and should be removed once that API is deprecated & removed.
2) Replaces `DeploymentResponse` objects with their resolved object refs. This
enables composition without explicitly calling `_to_object_ref`.
3) Buffers the bodies of `starlette.requests.Request` objects to avoid them
being unserializable. This is a temporary compatibility measure and passing
the objects should be fully disallowed in a future release.
"""
from ray.serve.handle import (
_DeploymentResponseBase,
DeploymentResponse,
DeploymentResponseGenerator,
)

scanner = _PyObjScanner(source_type=_DeploymentResponseBase)
scanner = _PyObjScanner(
source_type=(asyncio.Task, _DeploymentResponseBase, Request)
)

try:
result_to_object_ref_coros = []
results = scanner.find_nodes((self.args, self.kwargs))
for result in results:
result_to_object_ref_coros.append(result._to_object_ref())
if isinstance(result, DeploymentResponseGenerator):
tasks = []
responses = []
replacement_table = {}
objs = scanner.find_nodes((self.args, self.kwargs))
for obj in objs:
if isinstance(obj, asyncio.Task):
tasks.append(obj)
elif isinstance(obj, DeploymentResponseGenerator):
raise RuntimeError(
"Streaming deployment handle results cannot be passed to "
"downstream handle calls. If you have a use case requiring "
"this feature, please file a feature request on GitHub."
)
elif isinstance(obj, DeploymentResponse):
responses.append(obj)
elif isinstance(obj, Request):
global WARNED_ABOUT_STARLETTE_REQUESTS_ONCE
if not WARNED_ABOUT_STARLETTE_REQUESTS_ONCE:
# TODO(edoakes): fully disallow this in the future.
warnings.warn(
"`starlette.Request` objects should not be directly passed "
"via `ServeHandle` calls. Not all functionality is "
"guaranteed to work (e.g., detecting disconnects) and this "
"may be disallowed in a future release."
)
WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = True

if len(results) > 0:
obj_refs = await asyncio.gather(*result_to_object_ref_coros)
replacement_table = dict(zip(results, obj_refs))
self.args, self.kwargs = scanner.replace_nodes(replacement_table)
finally:
# Make the scanner GC-able to avoid memory leaks.
scanner.clear()
async def empty_send():
pass

async def buffer_starlette_requests_and_warn(self):
"""Buffer any `starlette.request.Requests` objects to make them serializable.
obj._send = empty_send
obj._receive = make_buffered_asgi_receive(await obj.body())
replacement_table[obj] = obj

This is an anti-pattern because the requests will not be fully functional, so
warn the user. We may fully disallow it in the future.
"""
global WARNED_ABOUT_STARLETTE_REQUESTS_ONCE
scanner = _PyObjScanner(source_type=Request)
# Gather `asyncio.Task` results concurrently.
if len(tasks) > 0:
resolved = await asyncio.gather(*tasks)
replacement_table.update(zip(tasks, resolved))

try:
requests = scanner.find_nodes((self.args, self.kwargs))
if len(requests) > 0 and not WARNED_ABOUT_STARLETTE_REQUESTS_ONCE:
WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = True
# TODO(edoakes): fully disallow this in the future.
warnings.warn(
"`starlette.Request` objects should not be directly passed via "
"`ServeHandle` calls. Not all functionality is guaranteed to work "
"(e.g., detecting disconnects) and this may be disallowed in a "
"future release."
# Gather `DeploymentResponse` object refs concurrently.
if len(responses) > 0:
obj_refs = await asyncio.gather(
*[r._to_object_ref() for r in responses]
)
replacement_table.update((zip(responses, obj_refs)))

for request in requests:

async def empty_send():
pass

request._send = empty_send
request._receive = make_buffered_asgi_receive(await request.body())
self.args, self.kwargs = scanner.replace_nodes(replacement_table)
finally:
# Make the scanner GC-able to avoid memory leaks.
scanner.clear()
Expand Down Expand Up @@ -1219,10 +1209,7 @@ async def assign_request(
kwargs=request_kwargs,
metadata=request_meta,
)
await query.resolve_async_tasks()
await query.resolve_deployment_handle_results_to_object_refs()
await query.buffer_starlette_requests_and_warn()

await query.replace_known_types_in_args()
return await self._replica_scheduler.assign_replica(query)
finally:
# If the query is disconnected before assignment, this coroutine
Expand Down
27 changes: 27 additions & 0 deletions python/ray/serve/tests/test_api.py
Expand Up @@ -1057,6 +1057,33 @@ async def __call__(self):
assert app_handle.remote().result() == "hello world!!"


def test_deployment_handle_nested_in_obj(serve_instance):
    """Test binding a handle within a custom object.

    The bound deployment `f` is wrapped in a plain Python object before being
    passed as an init arg to the driver deployment; the driver then calls
    through the wrapper.
    """

    class HandleWrapper:
        def __init__(self, handle: RayServeHandle):
            self._handle = handle

        def get(self) -> DeploymentHandle:
            return self._handle.options(use_new_handle_api=True)

    @serve.deployment
    def f() -> str:
        return "hi"

    @serve.deployment
    class MyDriver:
        def __init__(self, handle_wrapper: HandleWrapper):
            self.handle_wrapper = handle_wrapper

        async def __call__(self) -> str:
            downstream = self.handle_wrapper.get()
            return await downstream.remote()

    wrapped = HandleWrapper(f.bind())
    app = MyDriver.bind(wrapped)
    app_handle = serve.run(app).options(use_new_handle_api=True)
    assert app_handle.remote().result() == "hi"


if __name__ == "__main__":
import sys

Expand Down
13 changes: 3 additions & 10 deletions python/ray/serve/tests/test_deploy_2.py
Expand Up @@ -2,6 +2,7 @@
import functools
import os
import sys
import threading
import time
from typing import Dict

Expand Down Expand Up @@ -345,8 +346,6 @@ async def __call__(self):


def test_nonserializable_deployment(serve_instance):
import threading

lock = threading.Lock()

class D:
Expand All @@ -365,16 +364,10 @@ class E:
def __init__(self, arg):
self.arg = arg

with pytest.raises(
TypeError,
match=r"Could not serialize the deployment init args:[\s\S]*was found to be non-serializable.*", # noqa
):
with pytest.raises(TypeError, match="pickle"):
serve.run(E.bind(lock))

with pytest.raises(
TypeError,
match=r"Could not serialize the deployment init kwargs:[\s\S]*was found to be non-serializable.*", # noqa
):
with pytest.raises(TypeError, match="pickle"):
serve.run(E.bind(arg=lock))


Expand Down

0 comments on commit f2a2560

Please sign in to comment.