Skip to content

Commit

Permalink
chore: Improved factory type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Aug 2, 2023
1 parent d16ee12 commit cbef007
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 30 deletions.
66 changes: 36 additions & 30 deletions statemachine/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from uuid import uuid4

Expand Down Expand Up @@ -31,7 +32,13 @@ def __init__(cls, name: str, bases: Tuple[type], attrs: Dict[str, Any]):
cls.add_inherited(bases)
cls.add_from_attributes(attrs)

cls._set_special_states()
try:
cls.initial_state: State = next(s for s in cls.states if s.initial)
except StopIteration:
cls.initial_state = None # Abstract SM still don't have states

cls.final_states: List[State] = [state for state in cls.states if state.final]

cls._check()

if TYPE_CHECKING:
Expand All @@ -40,35 +47,6 @@ def __init__(cls, name: str, bases: Tuple[type], attrs: Dict[str, Any]):
def __getattr__(self, attribute: str) -> Any:
...

def _set_special_states(cls):
if not cls.states:
return
initials = [s for s in cls.states if s.initial]
if len(initials) != 1:
raise InvalidDefinition(
_(
"There should be one and only one initial state. "
"Your currently have these: {!r}"
).format([s.id for s in initials])
)
cls.initial_state = initials[0]
cls.final_states = [state for state in cls.states if state.final]

def _disconnected_states(cls, starting_state):
visitable_states = set(visit_connected_states(starting_state))
return set(cls.states) - visitable_states

def _check_disconnected_state(cls):
disconnected_states = cls._disconnected_states(cls.initial_state)
if disconnected_states:
raise InvalidDefinition(
_(
"There are unreachable states. "
"The statemachine graph should have a single component. "
"Disconnected states: {}"
).format([s.id for s in disconnected_states])
)

def _check(cls):
has_states = bool(cls.states)
has_events = bool(cls._events)
Expand All @@ -85,8 +63,21 @@ def _check(cls):
if not has_events:
raise InvalidDefinition(_("There are no events."))

cls._check_initial_state()
cls._check_final_states()
cls._check_disconnected_state()

def _check_initial_state(cls):
initials = [s for s in cls.states if s.initial]
if len(initials) != 1:
raise InvalidDefinition(
_(
"There should be one and only one initial state. "
"Your currently have these: {!r}"
).format([s.id for s in initials])
)

def _check_final_states(cls):
final_state_with_invalid_transitions = [
state for state in cls.final_states if state.transitions
]
Expand All @@ -98,6 +89,21 @@ def _check(cls):
).format([s.id for s in final_state_with_invalid_transitions])
)

def _disconnected_states(cls, starting_state):
visitable_states = set(visit_connected_states(starting_state))
return set(cls.states) - visitable_states

def _check_disconnected_state(cls):
disconnected_states = cls._disconnected_states(cls.initial_state)
if disconnected_states:
raise InvalidDefinition(
_(
"There are unreachable states. "
"The statemachine graph should have a single component. "
"Disconnected states: {}"
).format([s.id for s in disconnected_states])
)

def add_inherited(cls, bases):
for base in bases:
for state in getattr(base, "states", []):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ class CampaignMachine(StateMachine):
deliver = producing.to(closed)


def test_machine_should_activate_initial_state():
class CampaignMachine(StateMachine):
"A workflow machine"
producing = State()
closed = State()
draft = State(initial=True)

add_job = draft.to(draft) | producing.to(producing)
produce = draft.to(producing)
deliver = producing.to(closed)

sm = CampaignMachine()

assert sm.current_state == sm.draft
assert sm.current_state.is_active


def test_machine_should_not_allow_transitions_from_final_state():
with pytest.raises(exceptions.InvalidDefinition):

Expand Down

0 comments on commit cbef007

Please sign in to comment.