Skip to content

Commit

Permalink
Minor cleanup (documentation & validation)
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Nov 10, 2017
1 parent aecc2e9 commit d68d5d4
Showing 1 changed file with 51 additions and 17 deletions.
68 changes: 51 additions & 17 deletions coaster/sqlalchemy/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,34 @@ def __init__(self, name, statemanager, states):
self.name = name
self.statemanager = statemanager
self.states = []
values = []

# First, ensure all provided states are StateManager instances and associated with the state manager
for state in states:
if not isinstance(state, ManagedState) or state.statemanager != statemanager:
raise ValueError("Invalid state %s for state group %s" (repr(state), repr(self)))

# Second, separate conditional from regular states
regular_states = [s for s in states if not s.validator]
conditional_states = [s for s in states if s.validator]

# Third, add all the regular states and keep a copy of their state values
values = set()
for state in regular_states:
self.states.append(state)
if isinstance(state.value, iterables):
values.update(state.value)
else:
values.add(state.value)

# Fourth, prevent adding a conditional state if the value is already present from a
# regular state. This is an error as the condition will never be tested
for state in conditional_states:
# Prevent grouping of conditional states with their original states
if state.value in values:
state_values = set(state.value if isinstance(state.value, iterables) else [state.value])
if state_values & values: # They overlap
raise ValueError("The value for state %s is already in this state group" % repr(state))
self.states.append(state)
values.append(state.value)
values.update(state_values)

def __repr__(self):
return "%s.%s" % (self.statemanager.name, self.name)
Expand Down Expand Up @@ -287,6 +306,7 @@ def __init__(self, func, statemanager, from_, to, if_=None, data=None):
# by calling add_transition directly
self.transitions = {}
# Repeated use of @StateManager.transition will update this dictionary
# instead of replacing it
self.data = {}
self.add_transition(statemanager, from_, to, if_, data)

Expand All @@ -300,7 +320,10 @@ def add_transition(self, statemanager, from_, to, if_=None, data=None):
elif to.value not in statemanager.lenum:
raise StateTransitionError("To state is not a valid state value: %s" % repr(to))
if data:
if 'name' in data:
raise TypeError("Invalid transition data parameter 'name'")
self.data.update(data)
self.data['name'] = self.name

if if_ is None:
if_ = []
Expand Down Expand Up @@ -340,24 +363,24 @@ def __get__(self, obj, cls=None):


class _StateTransitionWrapper(object):
def __init__(self, st, obj):
self.st = st
def __init__(self, statetransition, obj):
self.statetransition = statetransition
self.obj = obj

@property
def data(self):
"""
Transition descriptive data
"""
return self.st.data
return self.statetransition.data

def _state_invalid(self):
"""
If the state is invalid for the transition, return details on what didn't match
:return: Tuple of (state manager, current state, label for current state)
"""
for statemanager, conditions in self.st.transitions.items():
for statemanager, conditions in self.statetransition.transitions.items():
current_state = getattr(self.obj, statemanager.propname)
if conditions['from'] is None:
state_valid = True
Expand All @@ -381,28 +404,28 @@ def __call__(self, *args, **kwargs):
# Validate that each of the state managers is in the correct state
state_invalid = self._state_invalid()
if state_invalid:
transition_error.send(self.obj, transition=self.st, statemanager=state_invalid[0])
transition_error.send(self.obj, transition=self.statetransition, statemanager=state_invalid[0])
raise StateTransitionError(
u"Invalid state for transition {transition}: {state} = {value} ({label})".format(
transition=self.st.name,
transition=self.statetransition.name,
state=repr(state_invalid[0]),
value=repr(state_invalid[1]),
label=repr(state_invalid[2])
))

# Raise a transition-before signal
transition_before.send(self.obj, transition=self.st)
transition_before.send(self.obj, transition=self.statetransition)
# Call the transition function
try:
result = self.st.func(self.obj, *args, **kwargs)
result = self.statetransition.func(self.obj, *args, **kwargs)
except Exception as e:
transition_exception.send(self.obj, transition=self.st, exception=e)
transition_exception.send(self.obj, transition=self.statetransition, exception=e)
raise
# Change the state for each of the state managers
for statemanager, conditions in self.st.transitions.items():
for statemanager, conditions in self.statetransition.transitions.items():
statemanager._set(self.obj, conditions['to'].value, force=True) # Change state
# Raise a transition-after signal
transition_after.send(self.obj, transition=self.st)
transition_after.send(self.obj, transition=self.statetransition)
return result


Expand All @@ -422,10 +445,13 @@ def __init__(self, propname, lenum, readonly=True, doc=None):
self.lenum = lenum
self.readonly = readonly
self.__doc__ = doc
self.states = {} # name: ManagedState
self.states = {} # name: ManagedState/ManagedStateGroup
self.transitions = [] # names of transitions linked to this state manager

# Make a copy of all states in the lenum within the state manager as a ManagedState.
# We do NOT convert grouped states into a ManagedStateGroup instance as ManagedState
# is more efficient at testing whether a value is in a group: it uses the `in` operator
# while ManagedStateGroup does `any(s() for s in states)`.
for state_name, value in lenum.__names__.items():
self._add_state_internal(state_name, value,
# Grouped states are represented as sets and can't have labels, so be careful about those
Expand Down Expand Up @@ -474,7 +500,13 @@ def _add_state_internal(self, name, value, label=None,

def add_state_group(self, name, *states):
"""
Add a group of states (including conditional states)
Add a group of managed states. Groups can be specified directly in the
:class:`~coaster.utils.classes.LabeledEnum`. This method is only useful
for grouping a conditional state with existing states. It cannot be
used to form a group of groups.
:param str name: Name of this group
:param states: :class:`ManagedState` instances to be grouped together
"""
# See `_add_state_internal` for explanation of the following
if hasattr(self, name):
Expand All @@ -501,6 +533,7 @@ def add_conditional_state(self, name, state, validator, class_validator=None, ca
"""
if name in self.lenum.__dict__ or name in self.states:
raise AttributeError("State %s already exists" % name)
# We'll accept a ManagedState with grouped values, but not a ManagedStateGroup
if not isinstance(state, ManagedState):
raise ValueError("Invalid state: %s" % repr(state))
elif state.statemanager != self:
Expand Down Expand Up @@ -542,7 +575,8 @@ def __call__(self, obj, cls=None):
@staticmethod
def check_constraint(column, lenum, **kwargs):
"""
Returns a SQL CHECK constraint string given a column name and a LabeledEnum
Returns a SQL CHECK constraint string given a column name and a
:class:`~coaster.utils.classes.LabeledEnum`
:param str column: Column name
:param LabeledEnum lenum: LabeledEnum to retrieve valid values from
Expand Down

0 comments on commit d68d5d4

Please sign in to comment.