Skip to content

Commit

Permalink
Reworked reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Feb 2, 2024
1 parent a98ede9 commit 95392a9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 53 deletions.
95 changes: 49 additions & 46 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DeferredExecution:
If the input is a DeferredExecution node, it is executed first and the execution
output is used as the input for this one. All the executions are performed in a
single batch (i.e. using a single remote call) and the results are saved in all
the nodes with ref_counter > 1.
the nodes that have multiple subscribers.
Parameters
----------
Expand Down Expand Up @@ -98,43 +98,39 @@ def __init__(
kwargs: Dict[str, Any],
num_returns=1,
):
ref = DeferredExecution._ref
ref(data)
if isinstance(data, DeferredExecution):
data.subscribe()
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = sum(ref(a) for a in args) == 0
self.flat_kwargs = sum(ref(a) for a in kwargs.values()) == 0
self.ref_counter = 1
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0

@staticmethod
def _ref(obj):
@classmethod
def _flat_args(cls, args: Iterable):
"""
Increment the `ref_counter` if `obj` is a `DeferredExecution`.
If `obj` is a `ListOrTuple`, this method is called recursively for each element.
Check if the arguments list is flat and subscribe to all `DeferredExecution` objects.
Parameters
----------
obj : Any
args : Iterable
Returns
-------
int
The number of `ListOrTuple` or `DeferredExecution` objects found.
"""
if isinstance(obj, DeferredExecution):
obj.ref_count(1)
return 1
if isinstance(obj, ListOrTuple):
ref = DeferredExecution._ref
result = 1
for o in obj:
result += ref(o)
return result
return 0
bool
"""
flat = True
for arg in args:
if isinstance(arg, DeferredExecution):
flat = False
arg.subscribe()
elif isinstance(arg, ListOrTuple):
flat = False
cls._flat_args(arg)
return flat

def exec(
self,
Expand Down Expand Up @@ -166,6 +162,9 @@ def exec(
self._set_result(result, meta, 0)
return result, meta, 0

# If there are no subscribers, we still need the result here. There is no need to decrement
# the counter back: after the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()
# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
Expand Down Expand Up @@ -195,16 +194,22 @@ def has_result(self):
"""
return not hasattr(self, "func")

def ref_count(self, diff: int):
def subscribe(self):
"""
Increment the `ref_counter`.
Increment the `subscribers` counter.
Parameters
----------
diff : int
A subscriber is any instance that could trigger the execution of this task.
In case of multiple subscribers, the execution could be triggered multiple
times. To prevent multiple executions, the execution result is returned
from the worker and saved in this instance. Subsequent calls to `exec()`
return the previously saved result.
"""
self.ref_counter += diff
assert self.ref_counter >= 0
self.subscribers += 1

def unsubscribe(self):
"""Decrement the `subscribers` counter."""
self.subscribers -= 1
assert self.subscribers >= 0

def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
"""
Expand Down Expand Up @@ -243,7 +248,7 @@ def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
-------
tuple of list
* The first list is the result consumers.
If a DeferredExecution has multiple references, the execution result
If a DeferredExecution has multiple subscribers, the execution result
should be returned and saved in order to avoid duplicate executions.
These DeferredExecution tasks are added to this list and, after the
execution, the results are passed to the ``_set_result()`` method of
Expand All @@ -253,7 +258,6 @@ def _deconstruct(self) -> Tuple[List["DeferredExecution"], List[Any]]:
stack = []
result_consumers = []
output = []
self.ref_count(1)
# Using stack and generators to avoid the ``RecursionError``s.
stack.append(self._deconstruct_chain(self, output, stack, result_consumers))
while stack:
Expand Down Expand Up @@ -296,12 +300,17 @@ def _deconstruct_chain(
out_append = output.append
out_extend = output.extend
while True:
de.ref_count(-1)
de.unsubscribe()
if (out_pos := getattr(de, "out_pos", None)) and not de.has_result:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if de.ref_counter == 0:
if de.subscribers == 0:
# We may have subscribed to the same node multiple times.
# It could happen, for example, if it's passed to the args
# multiple times, or it's one of the parent nodes and also
# passed to the args. In this case, there are no multiple
# subscribers, and we don't need to return the result.
output[out_pos + 1] = 0
result_consumers.remove(de)
break
Expand All @@ -319,7 +328,6 @@ def _deconstruct_chain(
stack.append(de)
de = data

assert stack and isinstance(stack[-1], DeferredExecution)
while stack and isinstance(stack[-1], DeferredExecution):
de: DeferredExecution = stack.pop()
args = de.args
Expand All @@ -345,7 +353,7 @@ def _deconstruct_chain(
out_extend(kwargs)

out_append(0) # Placeholder for ref id
if de.ref_counter > 0:
if de.subscribers > 0:
# Ref id. This is the index in the output list.
de.out_pos = len(output) - 1
result_consumers.append(de)
Expand Down Expand Up @@ -382,19 +390,14 @@ def _deconstruct_list(
for obj in lst:
if isinstance(obj, DeferredExecution):
if out_pos := getattr(obj, "out_pos", None):
obj.ref_count(-1)
obj.unsubscribe()
if obj.has_result:
if isinstance(obj.data, ListOrTuple):
yield cls._deconstruct_list(
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.ref_counter == 0:
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
meta_offset: int = 0,
):
super().__init__()
if isinstance(data, DeferredExecution):
data.subscribe()
self._data_ref = data
# The metadata is stored in the MetaList at 0 offset. If the data is
# a DeferredExecution, the _meta will be replaced with the list, returned
Expand All @@ -93,9 +95,9 @@ def __init__(
)

def __del__(self):
"""Decrement the reference counter."""
"""Unsubscribe from DeferredExecution."""
if isinstance(self._data_ref, DeferredExecution):
self._data_ref.ref_count(-1)
self._data_ref.unsubscribe()

def apply(self, func: Callable, *args, **kwargs):
"""
Expand Down Expand Up @@ -151,7 +153,6 @@ def drain_call_queue(self):
f"ENTER::Partition.drain_call_queue::{self._identity}"
)
self._data_ref, self._meta, self._meta_offset = data.exec()
data.ref_count(-1)
self._is_debug(log) and log.debug(
f"EXIT::Partition.drain_call_queue::{self._identity}"
)
Expand All @@ -170,11 +171,8 @@ def __copy__(self):
PandasOnRayDataframePartition
A copy of this partition.
"""
data = self._data_ref
if isinstance(data, DeferredExecution):
data.ref_count(1)
return self.__constructor__(
data,
self._data_ref,
meta=self._meta,
meta_offset=self._meta_offset,
)
Expand Down

0 comments on commit 95392a9

Please sign in to comment.