Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT-#7001: Do not force materialization in MetaList.__getitem__() #7006

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion modin/core/execution/ray/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

"""Common utilities for Ray execution engine."""

from .engine_wrapper import RayWrapper, SignalActor
from .engine_wrapper import ObjectRefMapper, RayWrapper, SignalActor
from .utils import initialize_ray

__all__ = [
"initialize_ray",
"RayWrapper",
"ObjectRefMapper",
"SignalActor",
]
47 changes: 43 additions & 4 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ray._private.services import get_node_ip_address
from ray.util.client.common import ClientObjectRef

from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import ObjectRefMapper, RayWrapper
from modin.logging import get_logger

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
Expand Down Expand Up @@ -491,9 +491,7 @@ def __getitem__(self, index):
Any
"""
obj = self._obj
if not isinstance(obj, list):
self._obj = obj = RayWrapper.materialize(obj)
return obj[index]
return obj[index] if isinstance(obj, list) else MetaListMapper(self, index)

def __setitem__(self, index, value):
"""
Expand All @@ -510,6 +508,47 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListMapper(ObjectRefMapper):
"""
Used by MetaList.__getitem__() for lazy materialization.

Parameters
----------
meta : MetaList
idx : int
AndreyPavlenko marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, meta: MetaList, idx: int):
self.meta = meta
self.idx = idx

def get(self):
"""
Get item at self.idx or object ref if not materialized.

Returns
-------
object
"""
obj = self.meta._obj
return obj[self.idx] if isinstance(obj, list) else obj

def map(self, materialized):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am kind of confused why this method is named map. Can you elaborate?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main idea is to apply some calculation to a single materialized object and get one or multiple values. List with lengths is mapped to multiple lengths with MetaListMapper. Length is mapped to a different length with SlicedLenMapper.
If it's confusing, I let's rename it fo, for example, transfrom.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to MaterializationHook.

"""
Save the materialized list in self.meta and get the item at self.idx.

Parameters
----------
materialized : list

Returns
-------
object
"""
self.meta._obj = materialized
return materialized[self.idx]


class _Tag(Enum): # noqa: PR01
"""
A set of special values used for the method arguments de/construction.
Expand Down
105 changes: 96 additions & 9 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import asyncio
import os
from types import FunctionType
from typing import Sequence

import ray
from ray.util.client.common import ClientObjectRef
Expand Down Expand Up @@ -96,8 +97,7 @@
boolean
If the value is a future.
"""
ObjectIDType = (ray.ObjectRef, ClientObjectRef)
return isinstance(item, ObjectIDType)
return isinstance(item, ObjectRefTypes)

Check warning on line 100 in modin/core/execution/ray/common/engine_wrapper.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/engine_wrapper.py#L100

Added line #L100 was not covered by tests

@classmethod
def materialize(cls, obj_id):
Expand All @@ -114,7 +114,56 @@
object
Whatever was identified by `obj_id`.
"""
return ray.get(obj_id)
if isinstance(obj_id, ObjectRefMapper):
obj = obj_id.get()
return (
obj_id.map(ray.get(obj)) if isinstance(obj, RayObjectRefTypes) else obj
)

if not isinstance(obj_id, Sequence):
return ray.get(obj_id) if isinstance(obj_id, RayObjectRefTypes) else obj_id

if all(isinstance(obj, RayObjectRefTypes) for obj in obj_id):
return ray.get(obj_id)

ids = {}
result = []
for obj in obj_id:
if isinstance(obj, ObjectRefTypes):
if isinstance(obj, ObjectRefMapper):
oid = obj.get()
if isinstance(oid, RayObjectRefTypes):
mapper = obj
obj = oid
else:
result.append(oid)
continue

Check warning on line 140 in modin/core/execution/ray/common/engine_wrapper.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/engine_wrapper.py#L139-L140

Added lines #L139 - L140 were not covered by tests
else:
mapper = None
else:
result.append(obj)
continue

idx = ids.get(obj, None)
if idx is None:
ids[obj] = idx = len(ids)
YarShev marked this conversation as resolved.
Show resolved Hide resolved
if mapper is None:
result.append(obj)
else:
mapper._materialized_idx = idx
result.append(mapper)

if len(ids) == 0:
return result

materialized = ray.get(list(ids.keys()))
for i in range(len(result)):
if isinstance((obj := result[i]), ObjectRefTypes):
if isinstance(obj, ObjectRefMapper):
result[i] = obj.map(materialized[obj._materialized_idx])
else:
result[i] = materialized[ids[obj]]
return result

@classmethod
def put(cls, data, **kwargs):
Expand Down Expand Up @@ -161,12 +210,18 @@
obj_ids : list, scalar
num_returns : int, optional
"""
if not isinstance(obj_ids, list):
obj_ids = [obj_ids]
unique_ids = list(set(obj_ids))
if num_returns is None:
num_returns = len(unique_ids)
ray.wait(unique_ids, num_returns=num_returns)
if not isinstance(obj_ids, Sequence):
obj_ids = list(obj_ids)

Check warning on line 214 in modin/core/execution/ray/common/engine_wrapper.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/engine_wrapper.py#L214

Added line #L214 was not covered by tests

ids = set()
for obj in obj_ids:
if isinstance(obj, ObjectRefMapper):
obj = obj.get()

Check warning on line 219 in modin/core/execution/ray/common/engine_wrapper.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/engine_wrapper.py#L219

Added line #L219 was not covered by tests
if isinstance(obj, RayObjectRefTypes):
ids.add(obj)

if num_ids := len(ids):
ray.wait(list(ids), num_returns=num_returns or num_ids)
YarShev marked this conversation as resolved.
Show resolved Hide resolved


@ray.remote
Expand Down Expand Up @@ -218,3 +273,35 @@
bool
"""
return self.events[event_idx].is_set()


class ObjectRefMapper:
"""Map the materialized object to a different value."""

def get(self):
"""
Get an object reference or the cached, previously mapped value.

Returns
-------
ray.ObjectRef or object
"""
raise NotImplementedError()

def map(self, materialized):
"""
Map the materialized object.

Parameters
----------
materialized : object

Returns
-------
object
"""
raise NotImplementedError()


RayObjectRefTypes = (ray.ObjectRef, ClientObjectRef)
ObjectRefTypes = (*RayObjectRefTypes, ObjectRefMapper)
5 changes: 2 additions & 3 deletions modin/core/execution/ray/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import psutil
import ray
from packaging import version
from ray.util.client.common import ClientObjectRef

from modin.config import (
CIAWSAccessKeyID,
Expand All @@ -40,7 +39,7 @@
from modin.core.execution.utils import set_env
from modin.error_message import ErrorMessage

from .engine_wrapper import RayWrapper
from .engine_wrapper import ObjectRefTypes, RayWrapper

_OBJECT_STORE_TO_SYSTEM_MEMORY_RATIO = 0.6
# This constant should be in sync with the limit in ray, which is private,
Expand All @@ -50,7 +49,7 @@

_RAY_IGNORE_UNHANDLED_ERRORS_VAR = "RAY_IGNORE_UNHANDLED_ERRORS"

ObjectIDType = (ray.ObjectRef, ClientObjectRef)
ObjectIDType = ObjectRefTypes


def initialize_ray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@

from modin.config import LazyExecution
from modin.core.dataframe.pandas.partitioning.partition import PandasDataframePartition
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import ObjectRefMapper, RayWrapper
from modin.core.execution.ray.common.deferred_execution import (
DeferredExecution,
MetaList,
MetaListMapper,
)
from modin.core.execution.ray.common.utils import ObjectIDType
from modin.logging import get_logger
from modin.pandas.indexing import compute_sliced_len
from modin.utils import _inherit_docstrings

compute_sliced_len = ray.remote(compute_sliced_len)


class PandasOnRayDataframePartition(PandasDataframePartition):
"""
Expand Down Expand Up @@ -199,25 +198,21 @@
self._is_debug(log) and log.debug(f"ENTER::Partition.mask::{self._identity}")
new_obj = super().mask(row_labels, col_labels)
if isinstance(row_labels, slice) and isinstance(
self._length_cache, ObjectIDType
(len_cache := self._length_cache), ObjectIDType
dchigarev marked this conversation as resolved.
Show resolved Hide resolved
):
if row_labels == slice(None):
# fast path - full axis take
new_obj._length_cache = self._length_cache
new_obj._length_cache = len_cache
else:
new_obj._length_cache = compute_sliced_len.remote(
row_labels, self._length_cache
)
new_obj._length_cache = SlicedLenMapper(len_cache, row_labels)
if isinstance(col_labels, slice) and isinstance(
self._width_cache, ObjectIDType
(width_cache := self._width_cache), ObjectIDType
):
if col_labels == slice(None):
# fast path - full axis take
new_obj._width_cache = self._width_cache
new_obj._width_cache = width_cache
else:
new_obj._width_cache = compute_sliced_len.remote(
col_labels, self._width_cache
)
new_obj._width_cache = SlicedLenMapper(width_cache, col_labels)

Check warning on line 215 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py#L215

Added line #L215 was not covered by tests
self._is_debug(log) and log.debug(f"EXIT::Partition.mask::{self._identity}")
return new_obj

Expand Down Expand Up @@ -421,3 +416,51 @@


LazyExecution.subscribe(_configure_lazy_exec)


class SlicedLenMapper(ObjectRefMapper):
"""
Used by mask() for the slilced length computation.

Parameters
----------
ref : ObjectIDType
slc : slice
AndreyPavlenko marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, ref: ObjectIDType, slc: slice):
self.ref = ref
self.slc = slc

def get(self):
"""
Get the sliced length or object ref if not materialized.

Returns
-------
int or ObjectIDType
"""
if isinstance(self.ref, MetaListMapper):
len_or_ref = self.ref.get()
return (

Check warning on line 445 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py#L443-L445

Added lines #L443 - L445 were not covered by tests
compute_sliced_len(self.slc, len_or_ref)
if isinstance(len_or_ref, int)
else len_or_ref
)
return self.ref

Check warning on line 450 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py#L450

Added line #L450 was not covered by tests

def map(self, materialized):
"""
Get the sliced length.

Parameters
----------
materialized : list or int

Returns
-------
int
"""
if isinstance(self.ref, MetaListMapper):
materialized = self.ref.map(materialized)
return compute_sliced_len(self.slc, materialized)

Check warning on line 466 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py#L464-L466

Added lines #L464 - L466 were not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PandasOnRayDataframePartitionManager(GenericRayDataframePartitionManager):
_column_partitions_class = PandasOnRayDataframeColumnPartition
_row_partition_class = PandasOnRayDataframeRowPartition
_execution_wrapper = RayWrapper
materialize_futures = RayWrapper.materialize

@classmethod
def wait_partitions(cls, partitions):
Expand Down