Skip to content

Commit

Permalink
Reworked reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Feb 2, 2024
1 parent a98ede9 commit d57f67b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 68 deletions.
126 changes: 65 additions & 61 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DeferredExecution:
If the input is a DeferredExecution node, it is executed first and the execution
output is used as the input for this one. All the executions are performed in a
single batch (i.e. using a single remote call) and the results are saved in all
the nodes with ref_counter > 1.
the nodes that have multiple subscribers.
Parameters
----------
Expand Down Expand Up @@ -98,43 +98,39 @@ def __init__(
kwargs: Dict[str, Any],
num_returns=1,
):
ref = DeferredExecution._ref
ref(data)
if isinstance(data, DeferredExecution):
data.subscribe()
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = sum(ref(a) for a in args) == 0
self.flat_kwargs = sum(ref(a) for a in kwargs.values()) == 0
self.ref_counter = 1
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0

@staticmethod
def _ref(obj):
@classmethod
def _flat_args(cls, args: Iterable):
"""
Increment the `ref_counter` if `obj` is a `DeferredExecution`.
If `obj` is a `ListOrTuple`, this method is called recursively for each element.
Check if the arguments list is flat and subscribe to all `DeferredExecution` objects.
Parameters
----------
obj : Any
args : Iterable
Returns
-------
int
The number of `ListOrTuple` or `DeferredExecution` objects found.
"""
if isinstance(obj, DeferredExecution):
obj.ref_count(1)
return 1
if isinstance(obj, ListOrTuple):
ref = DeferredExecution._ref
result = 1
for o in obj:
result += ref(o)
return result
return 0
bool
"""
flat = True
for arg in args:
if isinstance(arg, DeferredExecution):
flat = False
arg.subscribe()

Check warning on line 129 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L128-L129

Added lines #L128 - L129 were not covered by tests
elif isinstance(arg, ListOrTuple):
flat = False
cls._flat_args(arg)
return flat

def exec(
self,
Expand Down Expand Up @@ -166,6 +162,9 @@ def exec(
self._set_result(result, meta, 0)
return result, meta, 0

# If there are no subscribers, we still need the result here. We don't need to decrement
# the counter back: after the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()
# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
Expand Down Expand Up @@ -195,16 +194,22 @@ def has_result(self):
"""
return not hasattr(self, "func")

def ref_count(self, diff: int):
def subscribe(self):
"""
Increment the `ref_counter`.
Increment the `subscribers` counter.
Parameters
----------
diff : int
A subscriber is any instance that could trigger the execution of this task.
In case of multiple subscribers, the execution could be triggered multiple
times. To prevent multiple executions, the execution result is returned
from the worker and saved in this instance. Subsequent calls to `exec()`
return the previously saved result.
"""
self.ref_counter += diff
assert self.ref_counter >= 0
self.subscribers += 1

def unsubscribe(self):
"""Decrement the `subscribers` counter."""
self.subscribers -= 1
assert self.subscribers >= 0

def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
"""
Expand All @@ -214,7 +219,7 @@ def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
materialization before passing the list to a Ray worker.
The format of the list is the following:
<input object> sequence<<function> <n><args> <n><kwargs> <ref> <res>>...
<input object> sequence<<function> <n><args> <n><kwargs> <ref> <nret>>...
If <n> before <args> is >= 0, then the next n objects are the function arguments.
If it is -1, it means that the method arguments contain list and/or
DeferredExecution (chain) objects. In this case the next values are read
Expand All @@ -235,15 +240,17 @@ def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
chain referring to the execution result of this method and, thus, it must
be saved so that other chains could retrieve the object by the id.
<res> is a 'get result' flag. If it's True, then the method execution
result must not only be passed to the next method in the chain, but also
returned to the caller. The object length and width are added to the meta list.
<nret> field contains either the `num_returns` value or 0. If it's 0, the
execution result is not returned, but is just passed to the next task in the
chain. If it's 1, the result is returned as is. Otherwise, it's expected that
the result is iterable and the specified number of values is returned from
the iterator. The lengths and widths of the returned values are added to the meta list.
Returns
-------
tuple of list
* The first list is the result consumers.
If a DeferredExecution has multiple references, the execution result
If a DeferredExecution has multiple subscribers, the execution result
should be returned and saved in order to avoid duplicate executions.
These DeferredExecution tasks are added to this list and, after the
execution, the results are passed to the ``_set_result()`` method of
Expand All @@ -253,7 +260,6 @@ def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
stack = []
result_consumers = []
output = []
self.ref_count(1)
# Using stack and generators to avoid the ``RecursionError``s.
stack.append(self._deconstruct_chain(self, output, stack, result_consumers))
while stack:
Expand Down Expand Up @@ -296,12 +302,17 @@ def _deconstruct_chain(
out_append = output.append
out_extend = output.extend
while True:
de.ref_count(-1)
de.unsubscribe()
if (out_pos := getattr(de, "out_pos", None)) and not de.has_result:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if de.ref_counter == 0:
if de.subscribers == 0:

Check warning on line 310 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L307-L310

Added lines #L307 - L310 were not covered by tests
# We may have subscribed to the same node multiple times.
# It could happen, for example, if it's passed to the args
# multiple times, or it's one of the parent nodes and also
# passed to the args. In this case, there are no multiple
# subscribers, and we don't need to return the result.
output[out_pos + 1] = 0
result_consumers.remove(de)
break

Check warning on line 318 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L316-L318

Added lines #L316 - L318 were not covered by tests
Expand All @@ -319,7 +330,6 @@ def _deconstruct_chain(
stack.append(de)
de = data

assert stack and isinstance(stack[-1], DeferredExecution)
while stack and isinstance(stack[-1], DeferredExecution):
de: DeferredExecution = stack.pop()
args = de.args
Expand All @@ -345,11 +355,11 @@ def _deconstruct_chain(
out_extend(kwargs)

out_append(0) # Placeholder for ref id
if de.ref_counter > 0:
if de.subscribers > 0:
# Ref id. This is the index in the output list.
de.out_pos = len(output) - 1
result_consumers.append(de)
out_append(1) # Return result for this node
out_append(de.num_returns) # Return result for this node
else:
out_append(0) # Do not return result for this node

Expand Down Expand Up @@ -382,19 +392,14 @@ def _deconstruct_list(
for obj in lst:
if isinstance(obj, DeferredExecution):
if out_pos := getattr(obj, "out_pos", None):
obj.ref_count(-1)
obj.unsubscribe()
if obj.has_result:
if isinstance(obj.data, ListOrTuple):
yield cls._deconstruct_list(
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
out_append(obj.data)

Check warning on line 397 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L394-L397

Added lines #L394 - L397 were not covered by tests
else:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.ref_counter == 0:
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)

Check warning on line 404 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L399-L404

Added lines #L399 - L404 were not covered by tests
else:
Expand Down Expand Up @@ -666,16 +671,15 @@ def construct_chain(

if ref := pop(): # <ref> is not 0 - adding the result to refs
refs[ref] = obj
if pop(): # <res> is True - returning the result
if isinstance(obj, ListOrTuple):
for o in obj:
meta.append(len(o) if hasattr(o, "__len__") else 0)
meta.append(len(o.columns) if hasattr(o, "columns") else 0)
yield o
else:
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj
if (num_returns := pop()) == 0:
continue

itr = iter([obj] if num_returns == 1 else obj)
for _ in range(num_returns):
obj = next(itr)
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj

@classmethod
def construct_list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
meta_offset: int = 0,
):
super().__init__()
if isinstance(data, DeferredExecution):
data.subscribe()
self._data_ref = data
# The metadata is stored in the MetaList at 0 offset. If the data is
# a DeferredExecution, the _meta will be replaced with the list, returned
Expand All @@ -93,9 +95,9 @@ def __init__(
)

def __del__(self):
"""Decrement the reference counter."""
"""Unsubscribe from DeferredExecution."""
if isinstance(self._data_ref, DeferredExecution):
self._data_ref.ref_count(-1)
self._data_ref.unsubscribe()

def apply(self, func: Callable, *args, **kwargs):
"""
Expand Down Expand Up @@ -151,7 +153,6 @@ def drain_call_queue(self):
f"ENTER::Partition.drain_call_queue::{self._identity}"
)
self._data_ref, self._meta, self._meta_offset = data.exec()
data.ref_count(-1)
self._is_debug(log) and log.debug(
f"EXIT::Partition.drain_call_queue::{self._identity}"
)
Expand All @@ -170,11 +171,8 @@ def __copy__(self):
PandasOnRayDataframePartition
A copy of this partition.
"""
data = self._data_ref
if isinstance(data, DeferredExecution):
data.ref_count(1)
return self.__constructor__(
data,
self._data_ref,
meta=self._meta,
meta_offset=self._meta_offset,
)
Expand Down

0 comments on commit d57f67b

Please sign in to comment.