From 5496564c7a40bd631adbacc07a4d686bf343edea Mon Sep 17 00:00:00 2001 From: Ryan Morshead Date: Wed, 6 Jan 2016 22:57:53 -0600 Subject: [PATCH] new approach to non-hashables --- traitlets/tests/test_traitlets.py | 11 +- traitlets/traitlets.py | 168 ++++++++++++++++++------------ traitlets/utils/index_table.py | 92 ++++++++++++++++ 3 files changed, 198 insertions(+), 73 deletions(-) create mode 100644 traitlets/utils/index_table.py diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 1b8074d2b..928eae5ac 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -679,24 +679,25 @@ def test_observe_via_tags(self): class A(HasTraits): foo = Int() - bar = Int().tag(type=None) + bar = Int().tag(type='a') a = A() def _test_observer(change): a.foo = change['new'] - a.observe(_test_observer, tags={'type': lambda v: v is None}) + a.observe(_test_observer, tags={'type': lambda v: v in 'abc'}) a.bar = 1 self.assertEqual(a.foo, a.bar) - a.add_traits(baz=Int().tag(type=None)) + a.add_traits(baz=Int().tag(type='b')) a.baz = 2 self.assertEqual(a.foo, a.baz) - a.unobserve(_test_observer) + a.unobserve(_test_observer, 'bar') a.bar = 3 - self.assertNotEqual(a.bar, a.baz) + print(a._trait_notifiers) + self.assertNotEqual(a.foo, a.bar) class TestHasTraits(TestCase): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index e18e98b80..9b8aab9d0 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -58,6 +58,7 @@ from .utils.getargspec import getargspec from .utils.importstring import import_item from .utils.sentinel import Sentinel +from .utils.index_table import itable SequenceTypes = (list, tuple, set, frozenset) @@ -79,12 +80,22 @@ ''' ) +_Unhashable = Sentinel('Unhashable', 'traitlets', +''' +Used as a key in `HasTraits._trait_notifiers` for unhashable types. +''' +) + # Deprecated alias NoDefaultSpecified = Undefined class TraitError(Exception): pass +# used to lock reference counts to at least one +# when observing tags with unhashable types. +_lock_reference_count = {} + #----------------------------------------------------------------------------- # Utilities #----------------------------------------------------------------------------- @@ -178,23 +189,6 @@ def parse_notifier_name(names): return names -class _TaggedNotifierContainer: - - def __init__(self, value): - self.notifiers = [] - if isinstance(value, types.FunctionType): - self.eval = True - else: - self.eval = False - self.value = value - - def __eq__(self, other): - if not self.eval or isinstance(other, types.FunctionType): - return self.value == other - else: - return self.value(other) - - class _SimpleTest: def __init__ ( self, value ): self.value = value def __call__ ( self, test ): @@ -205,6 +199,23 @@ def __str__(self): return self.__repr__() +class _KeyValuePair: + def __init__(self, key, value): + self.key = key + self.value = value + def __eq__(self, other): + return self.key == other + + +class _SimpleEval: + def __init__(self, func): self.func = func + def __eq__(self, other): + if isinstance(other, types.FunctionType): + return self.func == other + else: + return self.func(other) + + def getmembers(object, predicate=None): """A safe version of inspect.getmembers that handles missing attributes. @@ -1043,27 +1054,23 @@ def notify_change(self, change): name, type = change['name'], change['type'] callables = [] - callables.extend(self._trait_notifiers['names'].get(name, {}).get(type, [])) - callables.extend(self._trait_notifiers['names'].get(name, {}).get(All, [])) - callables.extend(self._trait_notifiers['names'].get(All, {}).get(type, [])) - callables.extend(self._trait_notifiers['names'].get(All, {}).get(All, [])) + d = self._trait_notifiers + callables.extend(d['names'].get(name, {}).get(type, [])) + callables.extend(d['names'].get(name, {}).get(All, [])) + callables.extend(d['names'].get(All, {}).get(type, [])) + callables.extend(d['names'].get(All, {}).get(All, [])) trait = getattr(self.__class__, name) + for k, v in trait.metadata.items(): - try: - if type in self._trait_notifiers['tags']: - d = self._trait_notifiers['tags'][type] - else: - d = self._trait_notifiers['tags'][All] - contianer = d[k][d[k].index(v)] - except KeyError: - pass - except ValueError: - pass - else: - for c in contianer.notifiers: - if c not in callables: - callables.append(c) + if k in d['tags']: + for t in (All, type): + if v in d['tags'][k] and t in d['tags'][k][v]: + callables.extend(d['tags'][k][v][t]) + elif _Unhashable in d['tags'][k]: + mapping = d['tags'][k][_Unhashable] + if v in mapping and t in mapping[v]: + callables.extend(mapping[v][t]) # Now static ones magic_name = '_%s_changed' % name @@ -1105,25 +1112,46 @@ def _add_notifiers(self, handler, name, tags, type): if tags: tagged = self._trait_notifiers['tags'] - if type in tagged: - d = tagged[type] - else: - d = {} - tagged[type] = d for k, v in tags.items(): - if k in d: - l = d[k] + if k in tagged: + values = tagged[k] + else: + values = {} + tagged[k] = values + + if isinstance(v, types.FunctionType): + v = _SimpleEval(v) + + try: + hash(v) + except TypeError: + # handle unhashable types + if _Unhashable in values: + mapping = values[_Unhashable] + else: + mapping = itable() + values[_Unhashable] = mapping + + if v in mapping: + d = mapping[d] + else: + d = {} + mapping[v] = d else: - l = [] - d[k] = l - if v in l: - i = l.index(v) - c = l[i] + if v in values: + d = values[v] + else: + d = {} + values[v] = d + + if type in d: + nlist = d[type] else: - c = _TaggedNotifierContainer(v) - l.append(c) - if handler not in c.notifiers: - c.notifiers.append(handler) + nlist = [] + d[type] = nlist + + if handler not in nlist: + nlist.append(handler) def _remove_notifiers(self, handler, name, tags, type): try: @@ -1137,22 +1165,26 @@ def _remove_notifiers(self, handler, name, tags, type): if name is not All: trait = getattr(self.__class__, name, None) if isinstance(trait, TraitType): - for k, v in trait.metadata.items(): - try: - if type in self._trait_notifiers['tags']: - d = self._trait_notifiers['tags'][type] - else: - d = self._trait_notifiers['tags'][All] - i = d[k].index(v) - except KeyError: - pass - except ValueError: - pass - else: - if hander is None: - d[k].remove(i) - else: - d[k][i].remove(handler) + d = self._trait_notifiers + for tags in (tags, trait.metadata): + for k, v in tags.items(): + if k in d['tags']: + try: + if handler is None: + del d['tags'][k][v][type] + else: + d['tags'][k][v][type].remove(handler) + except: + # check in the unhashables + if _Unhashable in d['tags'][k]: + mapping = d['tags'][k][_Unhashable] + try: + if handler is None: + del mapping[v][type] + else: + mapping[v][type].remove(handler) + except: + pass def on_trait_change(self, handler=None, name=None, remove=False): """DEPRECATED: Setup a handler to be called when a trait changes. @@ -1251,7 +1283,7 @@ def unobserve(self, handler, names=All, tags=None, type='change'): """ names = parse_notifier_name(names) for n in names: - self._remove_notifiers(handler, n, tags, type) + self._remove_notifiers(handler, n, tags or {}, type) def unobserve_all(self, name=All): """Remove trait change handlers of any type for the specified name. diff --git a/traitlets/utils/index_table.py b/traitlets/utils/index_table.py new file mode 100644 index 000000000..24e0835a7 --- /dev/null +++ b/traitlets/utils/index_table.py @@ -0,0 +1,92 @@ +"""A dict-like table mapping keys to values based on indexing""" + +import warnings + +class itable(object): + + def __init__(self, keys=None, values=None): + """A dict-like table mapping keys to values based on indexing""" + if keys is not None and values is not None: + self.update(keys, values) + elif keys is None and values is None: + self.__keys__ = [] + self.__values__ = [] + else: + raise TypeError('Must provide a sequence for both keys and values') + + def __getitem__(self, key): + try: + i = self.__keys__.index(key) + except ValueError: + raise KeyError(key) + else: + return self.__values__[i] + + def __setitem__(self, key, value): + try: + i = self.__keys__.index(key) + except ValueError: + try: + hash(key) + except: + pass + else: + warnings.warn('got a hashable mapping; use `dict` instead') + self.__keys__.append(key) + self.__values__.append(value) + else: + self.__values__[i] = value + + def __delitem__(self, key): + try: + i = self.__keys__.index(key) + except ValueError: + raise KeyError(key) + else: + del self.__keys__[i] + del self.__values__[i] + + def __iter__(self): + return self.__keys__.__iter__() + + def update(self, keys, values): + if len(keys) != len(values): + raise ValueError('keys and values must have the same length') + try: + map(hash, keys) + except: + pass + else: + warnings.warn('got a hashable mapping; use `dict` instead') + + try: + # repeated keys are not removed, + # however they will be ignored + self.__keys__ = list(keys)[::-1] + self.__values__ = list(values)[::-1] + except: + raise ValueError('keys and values must be iterable') + + def pop(self, key): + try: + i = self.__keys__.index(key) + except ValueError: + raise KeyError(key) + else: + del self.__keys__[i] + return self.__values__.pop(i) + + def keys(self): + return self.__keys__[:] + + def values(self): + return self.__values__[:] + + def items(self): + return zip(self.__keys__, self.__values__) + + def __repr__(self): + body = [] + for k, v in self.items(): + body.append(repr(k) + ': ' + repr(v)) + return '{' + ', '.join(body) + '}'