Skip to content

Commit

Permalink
StateManager now returns a wrapper instead of a bool (#165)
Browse files Browse the repository at this point in the history
This breaks one test, but that test is for an internal object
  • Loading branch information
jace committed Feb 7, 2018
1 parent cde8875 commit 8979fa9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
48 changes: 39 additions & 9 deletions coaster/sqlalchemy/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def is_scalar(self):
def __repr__(self):
return '%s.%s' % (self.statemanager.name, self.name)

def __call__(self, obj, cls=None):
def _eval(self, obj, cls=None):
# TODO: Respect cache as specified in `cache_for`
if obj is not None: # We're being called with an instance
if isinstance(self.value, iterables):
Expand All @@ -286,8 +286,18 @@ def __call__(self, obj, cls=None):
else:
return valuematch

def __call__(self, obj, cls=None):
if obj is not None:
return ManagedStateWrapper(self, obj, cls)
else:
return self._eval(obj, cls)


class ManagedStateGroup(object):
"""
Represents a group of managed states in a StateManager. Do not use this
class directly. Use :meth:`~StateManager.add_state_group` instead.
"""
def __init__(self, name, statemanager, states):
self.name = name
self.statemanager = statemanager
Expand Down Expand Up @@ -324,33 +334,53 @@ def __init__(self, name, statemanager, states):
def __repr__(self):
return '%s.%s' % (self.statemanager.name, self.name)

def __call__(self, obj, cls=None):
def _eval(self, obj, cls=None):
if obj is not None: # We're being called with an instance
return any(s(obj, cls) for s in self.states)
else:
return or_(*[s(obj, cls) for s in self.states])

def __call__(self, obj, cls=None):
if obj is not None:
return ManagedStateWrapper(self, obj, cls)
else:
return self._eval(obj, cls)


class ManagedStateWrapper(object):
"""
Wraps a :class:`ManagedState` or :class:`ManagedStateGroup` with
an object or class and otherwise provides transparent access to contents
an object or class, and otherwise provides transparent access to contents.
"""
def __init__(self, mstate, obj, cls=None):
if not isinstance(mstate, (ManagedState, ManagedStateGroup)):
raise TypeError("Parameter is not a managed state: %s" % repr(mstate))
self.__mstate = mstate
self.__obj = obj
self.__cls = cls
self._mstate = mstate
self._obj = obj
self._cls = cls

def __repr__(self):
return '<ManagedStateWrapper %s>' % repr(self.__mstate)
return '<ManagedStateWrapper %s>' % repr(self._mstate)

def __call__(self):
return self.__mstate(self.__obj, self.__cls)
return self._mstate._eval(self._obj, self._cls)

def __getattr__(self, attr):
return getattr(self.__mstate, attr)
return getattr(self._mstate, attr)

def __eq__(self, other):
return (isinstance(other, ManagedStateWrapper) and
self._mstate == other._mstate and
self._obj == other._obj and
self._cls == other._cls)

def __ne__(self, other):
return not self.__eq__(other)

def __bool__(self):
return self()

__nonzero__ = __bool__


class StateTransition(object):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_conditional_state_unmanaged_state(self):
def test_conditional_state_label(self):
"""Conditional states can have labels"""
self.assertEqual(MyPost.__dict__['state'].RECENT.label.name, 'recent')
self.assertEqual(self.post.state.RECENT.label.name, 'recent')

def test_transition_invalid_from_to(self):
"""
Expand Down Expand Up @@ -515,13 +516,13 @@ def test_current_states(self):
def test_managed_state_wrapper(self):
"""ManagedStateWrapper will only wrap a managed state or group"""
draft = MyPost.__dict__['state'].DRAFT
wdraft = ManagedStateWrapper(draft, self.post)
wdraft = ManagedStateWrapper(draft, self.post, MyPost)
self.assertEqual(draft.value, wdraft.value)
self.assertTrue(wdraft())
self.assertEqual(self.post.state.DRAFT, wdraft())
self.assertEqual(self.post.state.DRAFT, wdraft)
self.post.submit()
self.assertFalse(wdraft())
self.assertEqual(self.post.state.DRAFT, wdraft())
self.assertEqual(self.post.state.DRAFT, wdraft)

with self.assertRaises(TypeError):
ManagedStateWrapper(MY_STATE.DRAFT, self.post)
Expand Down

0 comments on commit 8979fa9

Please sign in to comment.