Skip to content

Commit

Permalink
feat: Allow using events on callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Feb 24, 2023
1 parent 9a64f4b commit 37fd275
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 72 deletions.
12 changes: 12 additions & 0 deletions docs/actions.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ Use the `enter` or `exit` params available on the `State` constructor.

```

```{hint}
It's also possible to use an event name as action.
**Be careful to not introduce recursion errors** that will raise `RecursionError` exception.
```

### Bind state actions using decorator syntax


Expand Down Expand Up @@ -213,6 +219,12 @@ model, using the patterns:

```

```{hint}
It's also possible to use an event name as action to chain transitions.
**Be careful to not introduce recursion errors**, like `loop = initial.to.itself(after="loop")`, that will raise `RecursionError` exception.
```

### Bind event actions using decorator syntax

The action will be registered for every {ref}`transition` associated with the event.
Expand Down
1 change: 1 addition & 0 deletions docs/releases/2.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ See {ref}`internal transition` for more details.
guards using decorators is now possible.
- [#331](https://github.com/fgmacedo/python-statemachine/pull/331): Added a way to generate diagrams using [QuickChart.io](https://quickchart.io) instead of GraphViz. See {ref}`diagrams` for more details.
- [#353](https://github.com/fgmacedo/python-statemachine/pull/353): Support for abstract state machine classes, so you can subclass `StateMachine` to add behavior on your own base class. Abstract `StateMachine` cannot be instantiated.
- [#355](https://github.com/fgmacedo/python-statemachine/pull/355): Now is possible to trigger an event as an action by registering the event name as the callback param.

## Bugfixes in 2.0

Expand Down
11 changes: 10 additions & 1 deletion statemachine/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_func_by_attr(attr, *configs):
return func, config.obj


def ensure_callable(attr, *objects):
def ensure_callable(attr, *objects): # noqa: C901
"""Ensure that `attr` is a callable, if not, tries to retrieve one from any of the given
`objects`.
Expand Down Expand Up @@ -66,6 +66,15 @@ def wrapper(*args, **kwargs):

return wrapper

if getattr(func, "_is_sm_event", False):
"Events already have the 'machine' parameter defined."

def wrapper(*args, **kwargs):
kwargs.pop("machine")
return func(*args, **kwargs)

return wrapper

return SignatureAdapter.wrap(func)


Expand Down
40 changes: 20 additions & 20 deletions statemachine/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .event_data import EventData
from .event_data import TriggerData
from .exceptions import TransitionNotAllowed


Expand All @@ -13,38 +14,36 @@ def __call__(self, machine, *args, **kwargs):
return self.trigger(machine, *args, **kwargs)

def trigger(self, machine, *args, **kwargs):
event_data = EventData(machine, self.name, *args, **kwargs)

def trigger_wrapper():
"""Wrapper that captures event_data as closure."""
return self._trigger(event_data)
trigger_data = TriggerData(
machine=machine,
event=self.name,
args=args,
kwargs=kwargs,
)
return self._trigger(trigger_data)

return machine._process(trigger_wrapper)

def _trigger(self, event_data):
event_data.source = event_data.machine.current_state
event_data.state = event_data.machine.current_state
event_data.model = event_data.machine.model

try:
self._process(event_data)
except Exception as error:
event_data.error = error
# TODO: Log errors
# TODO: Allow exception handlers
raise
def _trigger(self, trigger_data: TriggerData):
event_data = self._process(trigger_data)
return event_data.result

def _process(self, event_data):
for transition in event_data.source.transitions:
if not transition.match(event_data.event):
def _process(self, trigger_data: TriggerData):
state = trigger_data.machine.current_state
for transition in state.transitions:
if not transition.match(trigger_data.event):
continue
event_data._set_transition(transition)

event_data = EventData(trigger_data=trigger_data, transition=transition)
if transition.execute(event_data):
event_data.executed = True
break
else:
raise TransitionNotAllowed(event_data.event, event_data.state)
raise TransitionNotAllowed(trigger_data.event, state)

return event_data


def trigger_event_factory(event):
Expand All @@ -56,5 +55,6 @@ def trigger_event(self, *args, **kwargs):

trigger_event.name = event
trigger_event.identifier = event
trigger_event._is_sm_event = True

return trigger_event
90 changes: 62 additions & 28 deletions statemachine/event_data.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,77 @@
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any

if TYPE_CHECKING:
from .state import State
from .statemachine import StateMachine
from .transition import Transition


@dataclass
class TriggerData:
machine: "StateMachine"
event: str
"""The Event that was triggered."""

model: Any = field(init=False)
"""A reference to the underlying model that holds the current State."""

args: tuple = field(default_factory=tuple)
"""All positional arguments provided on the Event."""

kwargs: dict = field(default_factory=dict)
"""All keyword arguments provided on the Event."""

def __post_init__(self):
self.model = self.machine.model


@dataclass
class EventData:
def __init__(self, machine: "StateMachine", event: str, *args, **kwargs):
self.machine = machine
self.event = event
self.source = kwargs.get("source", None)
self.state = kwargs.get("state", None)
self.model = kwargs.get("model", None)
self.executed = False
self.transition: Transition | None = None
self.target = None
self._set_transition(kwargs.get("transition", None))

# runtime and error
self.args = args
self.kwargs = kwargs
self.error = None
self.result = None

def __repr__(self):
return f"{type(self).__name__}({self.__dict__!r})"

def _set_transition(self, transition: "Transition"):
self.transition = transition
self.target = getattr(transition, "target", None)
trigger_data: TriggerData
transition: "Transition"
"""The Transition instance that was activated by the Event."""

state: "State" = field(init=False)
"""The current State of the state machine."""

source: "State" = field(init=False)
"""The State the state machine was in when the Event started."""

target: "State" = field(init=False)
"""The destination State of the transition."""

result: "Any | None" = None
executed: bool = False

def __post_init__(self):
self.state = self.transition.source
self.source = self.transition.source
self.target = self.transition.target

@property
def machine(self):
return self.trigger_data.machine

@property
def event(self):
return self.trigger_data.event

@property
def args(self):
return self.trigger_data.args

@property
def extended_kwargs(self):
kwargs = self.kwargs.copy()
kwargs = self.trigger_data.kwargs.copy()
kwargs["event_data"] = self
kwargs["event"] = self.event
kwargs["source"] = self.source
kwargs["state"] = self.state
kwargs["model"] = self.model
kwargs["machine"] = self.trigger_data.machine
kwargs["event"] = self.trigger_data.event
kwargs["model"] = self.trigger_data.model
kwargs["transition"] = self.transition
kwargs["state"] = self.state
kwargs["source"] = self.source
kwargs["target"] = self.target
return kwargs
37 changes: 18 additions & 19 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .dispatcher import ObjectConfig
from .dispatcher import resolver_factory
from .event import Event
from .event_data import TriggerData
from .event_data import EventData
from .exceptions import InvalidStateValue
from .exceptions import InvalidDefinition
Expand Down Expand Up @@ -62,30 +63,29 @@ def _activate_initial_state(self, initial_transition):
initial_transition.before.clear()
initial_transition.on.clear()
initial_transition.after.clear()

event_data = EventData(
self,
initial_transition.event,
trigger_data=TriggerData(
machine=self,
event=initial_transition.event,
),
transition=initial_transition,
)
self._activate(event_data)

def _get_protected_attrs(self):
return (
{
"_abstract",
"model",
"state_field",
"start_value",
"initial_state",
"final_states",
"states",
"_events",
"states_map",
"send",
}
| {s.id for s in self.states}
| set(self._events.keys())
)
return {
"_abstract",
"model",
"state_field",
"start_value",
"initial_state",
"final_states",
"states",
"_events",
"states_map",
"send",
} | {s.id for s in self.states}

def _visit_states_and_transitions(self, visitor):
for state in self.states:
Expand Down Expand Up @@ -165,7 +165,6 @@ def _process(self, trigger):

def _activate(self, event_data: EventData):
transition = event_data.transition
assert transition is not None
source = event_data.state
target = transition.target

Expand Down
7 changes: 5 additions & 2 deletions statemachine/transition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from functools import partial
from typing import TYPE_CHECKING

from .callbacks import Callbacks
from .callbacks import ConditionWrapper
from .event_data import EventData
from .events import Events
from .exceptions import InvalidDefinition

if TYPE_CHECKING:
from .event_data import EventData


class Transition:
"""A transition holds reference to the source and target state.
Expand Down Expand Up @@ -119,7 +122,7 @@ def events(self):
def add_event(self, value):
self._events.add(value)

def execute(self, event_data: EventData):
def execute(self, event_data: "EventData"):
self.validators.call(*event_data.args, **event_data.extended_kwargs)
if not self._eval_cond(event_data):
return False
Expand Down
4 changes: 2 additions & 2 deletions tests/examples/order_control_rich_model_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self):
def payments_enough(self, amount):
return sum(self.payments) + amount >= self.order_total

def add_to_order(self, amount):
def before_add_to_order(self, amount):
self.order_total += amount
return self.order_total

Expand All @@ -40,7 +40,7 @@ class OrderControl(StateMachine):
shipping = State()
completed = State(final=True)

add_to_order = waiting_for_payment.to(waiting_for_payment, before="add_to_order")
add_to_order = waiting_for_payment.to(waiting_for_payment)
receive_payment = waiting_for_payment.to(
processing, cond="payments_enough"
) | waiting_for_payment.to(waiting_for_payment, unless="payments_enough")
Expand Down

0 comments on commit 37fd275

Please sign in to comment.