Skip to content

Commit

Permalink
fix: Reconstructing callable references when using deepcopy over a SM. (
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Apr 18, 2024
1 parent f236cad commit 16f2a73
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
3 changes: 3 additions & 0 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def register(self, callbacks: CallbackMetaList, resolver):
executor_list.add(callbacks, resolver)
return executor_list

def clear(self):
self._registry.clear()

def __getitem__(self, callbacks: CallbackMetaList) -> CallbacksExecutor:
return self._registry[callbacks]

Expand Down
40 changes: 24 additions & 16 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -78,9 +79,9 @@ def __init__(
if self._abstract:
raise InvalidDefinition(_("There are no states or transitions."))

initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
self._setup(initial_transition)
self._activate_initial_state(initial_transition)
self._initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
self._setup()
self._activate_initial_state()

def __init_subclass__(cls, strict_states: bool = False):
cls._strict_states = strict_states
Expand All @@ -98,27 +99,39 @@ def __repr__(self):
f"current_state={current_state_id!r})"
)

def __deepcopy__(self, memo):
deepcopy_method = self.__deepcopy__
self.__deepcopy__ = None
try:
cp = deepcopy(self, memo)
finally:
self.__deepcopy__ = deepcopy_method
cp.__deepcopy__ = deepcopy_method
cp._callbacks_registry.clear()
cp._setup()
return cp

def _get_initial_state(self):
current_state_value = self.start_value if self.start_value else self.initial_state.value
try:
return self.states_map[current_state_value]
except KeyError as err:
raise InvalidStateValue(current_state_value) from err

def _activate_initial_state(self, initial_transition):
def _activate_initial_state(self):
if self.current_state_value is None:
# send an one-time event `__initial__` to enter the current state.
# current_state = self.current_state
initial_transition.before.clear()
initial_transition.on.clear()
initial_transition.after.clear()
self._initial_transition.before.clear()
self._initial_transition.on.clear()
self._initial_transition.after.clear()

event_data = EventData(
trigger_data=TriggerData(
machine=self,
event=initial_transition.event,
event=self._initial_transition.event,
),
transition=initial_transition,
transition=self._initial_transition,
)
self._activate(event_data)

Expand All @@ -142,12 +155,7 @@ def _visit_states_and_transitions(self, visitor):
for transition in state.transitions:
visitor(transition)

def _setup(self, initial_transition: Transition):
"""
Args:
initial_transition: A special :ref:`transition` that triggers the enter on the
`initial` :ref:`State`.
"""
def _setup(self):
machine = ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs())
model = ObjectConfig.from_obj(self.model, skip_attrs={self.state_field})
default_resolver = resolver_factory(machine, model)
Expand All @@ -162,7 +170,7 @@ def setup_visitor(visited):

self._visit_states_and_transitions(setup_visitor)

initial_transition._setup(register)
self._initial_transition._setup(register)

def _build_observers_visitor(self, *observers):
registry_callbacks = [
Expand Down
25 changes: 25 additions & 0 deletions tests/test_deepcopy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from copy import deepcopy

import pytest

from statemachine import State
from statemachine import StateMachine
from statemachine.exceptions import TransitionNotAllowed


def test_deepcopy():
class MySM(StateMachine):
draft = State("Draft", initial=True, value="draft")
published = State("Published", value="published")

publish = draft.to(published, cond="let_me_be_visible")

class MyModel:
let_me_be_visible = False

sm = MySM(MyModel())

sm2 = deepcopy(sm)

with pytest.raises(TransitionNotAllowed):
sm2.send("publish")

0 comments on commit 16f2a73

Please sign in to comment.