Skip to content

Commit

Permalink
new approach to non-hashables
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Jan 7, 2016
1 parent 6954708 commit 5496564
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 73 deletions.
11 changes: 6 additions & 5 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,24 +679,25 @@ def test_observe_via_tags(self):

class A(HasTraits):
foo = Int()
bar = Int().tag(type=None)
bar = Int().tag(type='a')

a = A()

def _test_observer(change):
a.foo = change['new']

a.observe(_test_observer, tags={'type': lambda v: v is None})
a.observe(_test_observer, tags={'type': lambda v: v in 'abc'})
a.bar = 1
self.assertEqual(a.foo, a.bar)

a.add_traits(baz=Int().tag(type=None))
a.add_traits(baz=Int().tag(type='b'))
a.baz = 2
self.assertEqual(a.foo, a.baz)

a.unobserve(_test_observer)
a.unobserve(_test_observer, 'bar')
a.bar = 3
self.assertNotEqual(a.bar, a.baz)
print(a._trait_notifiers)
self.assertNotEqual(a.foo, a.bar)


class TestHasTraits(TestCase):
Expand Down
168 changes: 100 additions & 68 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .utils.getargspec import getargspec
from .utils.importstring import import_item
from .utils.sentinel import Sentinel
from .utils.index_table import itable

SequenceTypes = (list, tuple, set, frozenset)

Expand All @@ -79,12 +80,22 @@
'''
)

_Unhashable = Sentinel('Unhashable', 'traitlets',
'''
Used as a key in `HasTraits._trait_notifiers` for unhashable types.
'''
)

# Deprecated alias
NoDefaultSpecified = Undefined

class TraitError(Exception):
pass

# used to lock reference counts to at least one
# when observing tags with unhashable types.
_lock_reference_count = {}

#-----------------------------------------------------------------------------
# Utilities
#-----------------------------------------------------------------------------
Expand Down Expand Up @@ -178,23 +189,6 @@ def parse_notifier_name(names):
return names


class _TaggedNotifierContainer:

def __init__(self, value):
self.notifiers = []
if isinstance(value, types.FunctionType):
self.eval = True
else:
self.eval = False
self.value = value

def __eq__(self, other):
if not self.eval or isinstance(other, types.FunctionType):
return self.value == other
else:
return self.value(other)


class _SimpleTest:
def __init__ ( self, value ): self.value = value
def __call__ ( self, test ):
Expand All @@ -205,6 +199,23 @@ def __str__(self):
return self.__repr__()


class _KeyValuePair:
def __init__(self, key, value):
self.key = key
self.value = value
def __eq__(self, other):
return self.key == other


class _SimpleEval:
def __init__(self, func): self.func = func
def __eq__(self, other):
if isinstance(other, types.FunctionType):
return self.func == other
else:
return self.func(other)


def getmembers(object, predicate=None):
"""A safe version of inspect.getmembers that handles missing attributes.
Expand Down Expand Up @@ -1043,27 +1054,23 @@ def notify_change(self, change):
name, type = change['name'], change['type']

callables = []
callables.extend(self._trait_notifiers['names'].get(name, {}).get(type, []))
callables.extend(self._trait_notifiers['names'].get(name, {}).get(All, []))
callables.extend(self._trait_notifiers['names'].get(All, {}).get(type, []))
callables.extend(self._trait_notifiers['names'].get(All, {}).get(All, []))
d = self._trait_notifiers
callables.extend(d['names'].get(name, {}).get(type, []))
callables.extend(d['names'].get(name, {}).get(All, []))
callables.extend(d['names'].get(All, {}).get(type, []))
callables.extend(d['names'].get(All, {}).get(All, []))

trait = getattr(self.__class__, name)

for k, v in trait.metadata.items():
try:
if type in self._trait_notifiers['tags']:
d = self._trait_notifiers['tags'][type]
else:
d = self._trait_notifiers['tags'][All]
contianer = d[k][d[k].index(v)]
except KeyError:
pass
except ValueError:
pass
else:
for c in contianer.notifiers:
if c not in callables:
callables.append(c)
if k in d['tags']:
for t in (All, type):
if v in d['tags'][k] and t in d['tags'][k][v]:
callables.extend(d['tags'][k][v][t])
elif _Unhashable in d['tags'][k]:
mapping = d['tags'][k][_Unhashable]
if v in mapping and t in mapping[v]:
callables.extend(mapping[v][t])

# Now static ones
magic_name = '_%s_changed' % name
Expand Down Expand Up @@ -1105,25 +1112,46 @@ def _add_notifiers(self, handler, name, tags, type):

if tags:
tagged = self._trait_notifiers['tags']
if type in tagged:
d = tagged[type]
else:
d = {}
tagged[type] = d
for k, v in tags.items():
if k in d:
l = d[k]
if k in tagged:
values = tagged[k]
else:
values = {}
tagged[k] = values

if isinstance(v, types.FunctionType):
v = _SimpleEval(v)

try:
hash(v)
except TypeError:
# handle unhashable types
if _Unhashable in values:
mapping = values[_Unhashable]
else:
mapping = itable()
values[_Unhashable] = mapping

if v in mapping:
d = mapping[d]
else:
d = {}
mapping[v] = d
else:
l = []
d[k] = l
if v in l:
i = l.index(v)
c = l[i]
if v in values:
d = values[v]
else:
d = {}
values[v] = d

if type in d:
nlist = d[type]
else:
c = _TaggedNotifierContainer(v)
l.append(c)
if handler not in c.notifiers:
c.notifiers.append(handler)
nlist = []
d[type] = nlist

if handler not in nlist:
nlist.append(handler)

def _remove_notifiers(self, handler, name, tags, type):
try:
Expand All @@ -1137,22 +1165,26 @@ def _remove_notifiers(self, handler, name, tags, type):
if name is not All:
trait = getattr(self.__class__, name, None)
if isinstance(trait, TraitType):
for k, v in trait.metadata.items():
try:
if type in self._trait_notifiers['tags']:
d = self._trait_notifiers['tags'][type]
else:
d = self._trait_notifiers['tags'][All]
i = d[k].index(v)
except KeyError:
pass
except ValueError:
pass
else:
if hander is None:
d[k].remove(i)
else:
d[k][i].remove(handler)
d = self._trait_notifiers
for tags in (tags, trait.metadata):
for k, v in tags.items():
if k in d['tags']:
try:
if handler is None:
del d['tags'][k][v][type]
else:
d['tags'][k][v][type].remove(handler)
except:
# check in the unhashables
if _Unhashable in d['tags'][k]:
mapping = d['tags'][k][_Unhashable]
try:
if handler is None:
del mapping[v][type]
else:
mapping[v][type].remove(handler)
except:
pass

def on_trait_change(self, handler=None, name=None, remove=False):
"""DEPRECATED: Setup a handler to be called when a trait changes.
Expand Down Expand Up @@ -1251,7 +1283,7 @@ def unobserve(self, handler, names=All, tags=None, type='change'):
"""
names = parse_notifier_name(names)
for n in names:
self._remove_notifiers(handler, n, tags, type)
self._remove_notifiers(handler, n, tags or {}, type)

def unobserve_all(self, name=All):
"""Remove trait change handlers of any type for the specified name.
Expand Down
92 changes: 92 additions & 0 deletions traitlets/utils/index_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""A dict-like table mapping keys to values based on indexing"""

import warnings

class itable(object):

def __init__(self, keys=None, values=None):
"""A dict-like table mapping keys to values based on indexing"""
if keys is not None and values is not None:
self.update(keys, values)
elif keys is None and values is None:
self.__keys__ = []
self.__values__ = []
else:
raise TypeError('Must provide a sequence for both keys and values')

def __getitem__(self, key):
try:
i = self.__keys__.index(key)
except ValueError:
raise KeyError(key)
else:
return self.__values__[i]

def __setitem__(self, key, value):
try:
i = self.__keys__.index(key)
except ValueError:
try:
hash(key)
except:
pass
else:
warnings.warn('got a hashable mapping; use `dict` instead')
self.__keys__.append(key)
self.__values__.append(value)
else:
self.__values__[i] = value

def __delitem__(self, key):
try:
i = self.__keys__.index(key)
except ValueError:
raise KeyError(key)
else:
del self.__keys__[i]
del self.__values__[i]

def __iter__(self):
return self.__keys__.__iter__()

def update(self, keys, values):
if len(keys) != len(values):
raise ValueError('keys and values must have the same length')
try:
map(hash, keys)
except:
pass
else:
warnings.warn('got a hashable mapping; use `dict` instead')

try:
# repeated keys are not removed,
# however they will be ignored
self.__keys__ = list(keys)[::-1]
self.__values__ = list(values)[::-1]
except:
raise ValueError('keys and values must be iterable')

def pop(self, key):
try:
i = self.__keys__.index(key)
except ValueError:
raise KeyError(key)
else:
del self.__keys__[i]
return self.__values__.pop(i)

def keys(self):
return self.__keys__[:]

def values(self):
return self.__values__[:]

def items(self):
return zip(self.__keys__, self.__values__)

def __repr__(self):
body = []
for k, v in self.items():
body.append(repr(k) + ': ' + repr(v))
return '{' + ', '.join(body) + '}'

0 comments on commit 5496564

Please sign in to comment.