diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 76916cc43..c27dbccfa 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -702,20 +702,35 @@ class A(HasTraits): a = A() - def _test_observer(change): - a.foo = change['new'] + def _test_observer1(change): + a.foo += 1 + def _test_observer2(change): + a.foo += 1 + def _test_observer3(change): + a.foo += 1 + + # test that multiple evals will register together + a.observe(_test_observer1, tags={'type': lambda v: v in 'ac'}) + a.observe(_test_observer2, tags={'type': lambda v: v in 'ab'}) + # test that evals and static tags register together + a.observe(_test_observer3, tags={'type': 'a'}) - a.observe(_test_observer, tags={'type': lambda v: v in 'abc'}) a.bar = 1 - self.assertEqual(a.foo, a.bar) + self.assertEqual(a.foo, 3) + a.foo = 0 - a.add_traits(baz=Int().tag(type='b')) - a.baz = 2 - self.assertEqual(a.foo, a.baz) + a.unobserve(_test_observer1, 'bar') + a.unobserve(_test_observer3, 'bar') + a.bar = 2 + self.assertEqual(a.foo, 1) + a.foo = 0 - a.unobserve(_test_observer, 'bar') - a.bar = 3 - self.assertNotEqual(a.foo, a.bar) + # test tagged notifiers know about + # dynamically added traits + a.add_traits(baz=Int().tag(type='b')) + a.baz = 1 + self.assertEqual(a.foo, 1) + a.foo = 0 class TestHasTraits(TestCase): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index e9d563ebd..78f602f8d 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -59,7 +59,7 @@ from .utils.getargspec import getargspec from .utils.importstring import import_item from .utils.sentinel import Sentinel -from .utils.mapping import isdict +from .utils.dict_types import mapping SequenceTypes = (list, tuple, set, frozenset) @@ -214,7 +214,11 @@ def __eq__(self, other): if isinstance(other, types.FunctionType): return self.func == other else: - return self.func(other) + try: + return self.func(other) + except: + return False + __hash__ = None def getmembers(object, predicate=None): @@ -1142,15 +1146,11 @@ def trait_notifiers(self, name, type): if k in d['tags']: for t in (All, type): if v in d['tags'][k]: - if id(v) not in d['tags'][k].ids(): - # accounts for _SimpleEval - isdictkeys = list(d['tags'][k].keys()) - v = isdictkeys[isdictkeys.index(v)] - if t in d['tags'][k][v]: - for c in d['tags'][k][v][t]: - if c not in notifiers: - notifiers.append(c) - + for m in d['tags'][k][v]: + if t in m: + for c in m[t]: + if c not in notifiers: + notifiers.append(c) return notifiers @@ -1201,18 +1201,21 @@ def _add_notifiers(self, handler, name, tags, type): tagged = self._trait_notifiers['tags'] for k, v in tags.items(): if k in tagged: - values = tagged[k] + notifier_mapping = tagged[k] else: - values = isdict() - tagged[k] = values + # mapping handles all types + # and custom equivalence + notifier_mapping = mapping() + tagged[k] = notifier_mapping if isinstance(v, types.FunctionType): v = _SimpleEval(v) + # get the internal dict to which `v` should be assigned + values = notifier_mapping.get_internal_dict(v) if v in values: d = values[v] else: - # isdict handles unhashable types d = {} values[v] = d @@ -1238,21 +1241,18 @@ def _remove_notifiers(self, handler, name, tags, type): trait = getattr(self.__class__, name, None) if isinstance(trait, TraitType): d = self._trait_notifiers - for tags in (tags, trait.metadata): - for k, v in tags.items(): - if k in d['tags']: - mapping = d['tags'][k] - if v in mapping and id(v) not in mapping.ids(): - # accounts for _SimpleEval - isdictkeys = list(mapping.keys()) - v = isdictkeys[isdictkeys.index(v)] - try: - if handler is None: - del mapping[v][type] - else: - mapping[v][type].remove(handler) - except KeyError: - pass + for k, v in trait.metadata.items(): + if k in d['tags']: + for t in (All, type): + if v in d['tags'][k]: + for m in d['tags'][k][v]: + try: + if handler is None: + del m[t] + else: + m[t].remove(handler) + except: + pass def on_trait_change(self, handler=None, name=None, remove=False): """DEPRECATED: Setup a handler to be called when a trait changes. diff --git a/traitlets/utils/dict_types.py b/traitlets/utils/dict_types.py new file mode 100644 index 000000000..f8e19a837 --- /dev/null +++ b/traitlets/utils/dict_types.py @@ -0,0 +1,255 @@ +# encoding: utf-8 +""" +A utility for mapping unhashable objects to values +""" +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +class isdict(object): + + def __init__(self, pairs=None): + """A dict-like object that maps unhashable keys to values + + NOTE + ---- + Objects used as keys in an isdict are not stored as references. + Thus the keys of an won't be garbage collected until the :class:`isdict` + instance itself is released. + """ + # maps ids to values + self._dict = {} + # maps ids to true key objects + self._refs = {} + self.update(pairs) + + def __getitem__(self, key): + try: + return self._dict[id(key)] + except KeyError: + raise KeyError(key) + + def __setitem__(self, key, value): + i = id(key) + self._refs[i] = key + self._dict[i] = value + + def __delitem__(self, key): + try: + i = id(key) + del self._dict[i] + del self._refs[i] + except ValueError: + raise KeyError(key) + + def __iter__(self): + for k in self._refs.values(): + yield k + + def update(self, pairs=None): + if pairs is not None: + lengths = set(map(len, pairs)) + if 2 not in lengths or len(lengths)>1: + # invalid update sequence + for i in range(len(pairs)): + if len(pairs[i]) != 2: + t = (str(i), str(len(pairs))) + raise ValueError("update sequence element #%s has" + " length %s; 2 is required" % t) + else: + keys, values = zip(*pairs) + + ids = map(id, keys) + self._dict.update(zip(ids, values)) + self._refs.update(zip(ids, keys)) + + def pop(self, key): + try: + i = id(key) + del self._refs[i] + return self._dict.pop(i) + except ValueError: + raise KeyError(key) + + def get(self, key, default=None): + try: + return self.__getitem__(key) + except: + return default + + def ids(self): + return self._dict.keys() + + def keys(self): + return self._refs.values() + + def values(self): + return self._dict.values() + + def items(self): + return zip(self._refs.values(), self._dict.values()) + + def __repr__(self): + body = [] + for k, v in self.items(): + body.append(repr(k) + ': ' + repr(v)) + return '{' + ', '.join(body) + '}' + + +class eqdict(object): + + def __init__(self, pairs=None): + """A dict-like object for mapping equivalent keys to values""" + self._keys = [] + self._vals = [] + self.update(pairs) + + def __getitem__(self, key): + enum = enumerate(self._keys) + values = [self._vals[i] for i, k in enum if k == key] + if len(values) == 0: + raise KeyError(key) + else: + return values + + def __setitem__(self, key, value): + try: + i = self._keys.index(key) + except ValueError: + self._keys.append(key) + self._vals.append(value) + else: + self._vals[i] = value + + def __delitem__(self, key): + try: + i = self._keys.index(key) + except ValueError: + raise KeyError(key) + else: + del self._keys[i] + del self._vals[i] + + def __iter__(self): + for k in self._keys: + yield k + + def update(self, pairs=None): + if pairs is not None: + lengths = set(map(len, pairs)) + if 2 not in lengths or len(lengths)>1: + # invalid update sequence + for i in range(len(pairs)): + if len(pairs[i]) != 2: + t = (str(i), str(len(pairs))) + raise ValueError("update sequence element #%s has" + " length %s; 2 is required" % t) + else: + for k, v in zip(*pairs): + self.__setitem__(k, v) + + def get(self, key, default=None): + try: + return self.__getitem__(key) + except KeyError: + return default + + def pop(self, key): + try: + i = self._keys.index(key) + except ValueError: + raise KeyError(key) + else: + del self._keys[i] + return self._vals.pop(i) + + def keys(self): + return self._keys[:] + + def values(self): + return self._vals[:] + + def items(self): + return zip(self._keys, self._vals) + + def __repr__(self): + body = [] + for k, v in self.items(): + body.append(repr(k) + ': ' + repr(v)) + return '{' + ', '.join(body) + '}' + + +class mapping(object): + + def __init__(self): + """A dict-like object for mapping any key to a set of values""" + # handles unhashables + self._is = isdict() + # handles custom equivalence + self._eq = eqdict() + # all other values + self._dict = {} + + def __getitem__(self, key): + values = [] + try: + hash(key) + except: + if key in self._is[key]: + values.append(self._is[key]) + else: + # note that python 2 classes with + # __eq__ or __cmp__ are hashable + if key in self._dict: + values.append(self._dict[key]) + if key in self._eq: + values.extend(self._eq[key]) + if len(values) == 0: + raise KeyError(key) + return values + + def __setitem__(self, key, value): + try: + hash(key) + except: + if hasattr(key, '__eq__') or hasattr(key, '__cmp__'): + self._eq[key] = value + else: + self._is[key] = value + else: + # note that python 2 classes with + # __eq__ or __cmp__ are hashable + self._dict[key] = value + + def get_internal_dict(self, key): + """Return the internal dict to which this key would be assigned""" + try: + hash(key) + except: + if hasattr(key, '__eq__') or hasattr(key, '__cmp__'): + return self._eq + else: + return self._is + else: + # note that python 2 classes with + # __eq__ or __cmp__ are hashable + return self._dict + + def __iter__(self): + for k in self.keys(): + yield k + + def keys(self): + return self._is.keys() + self._eq.keys() + self._dict.keys() + + def values(self): + return self._is.values() + self._eq.values() + self._dict.values() + + def items(self): + return self._is.items() + self._eq.items() + self._dict.items() + + def __repr__(self): + body = [] + for k, v in self.items(): + body.append(repr(k) + ': ' + repr(v)) + return '{' + ', '.join(body) + '}' diff --git a/traitlets/utils/mapping.py b/traitlets/utils/mapping.py deleted file mode 100644 index 4b3322fc3..000000000 --- a/traitlets/utils/mapping.py +++ /dev/null @@ -1,90 +0,0 @@ -# encoding: utf-8 -""" -A utility for mapping unhashable objects to values -""" -# Copyright (c) IPython Development Team. -# Distributed under the terms of the Modified BSD License. - -import warnings - -class isdict(object): - - def __init__(self, pairs=None): - """A dict-like mapping that keys objects on their id - - Note - ---- - Objects used as keys in an idict are not stored as references. - Thus the keys of an :class:`idict` won't be garbage collected - until the :class:`idict` instance itself is released. - """ - # maps ids to values - self._dict = {} - # maps ids to true key objects - self._refs = {} - self.update(pairs) - - def __getitem__(self, key): - try: - return self._dict[id(key)] - except KeyError: - raise KeyError(key) - - def __setitem__(self, key, value): - i = id(key) - self._refs[i] = key - self._dict[i] = value - - def __delitem__(self, key): - try: - i = id(key) - del self._dict[i] - del self._refs[i] - except ValueError: - raise KeyError(key) - - def __iter__(self): - return self._refs.values().__iter__() - - def update(self, pairs=None): - if pairs is not None: - lengths = set(map(len, pairs)) - if 2 not in lengths or len(lengths)>1: - # invalid update sequence - for i in range(len(pairs)): - if len(pairs[i]) != 2: - t = (str(i), str(len(pairs))) - raise ValueError("update sequence element #%s has" - " length %s; 2 is required" % t) - else: - keys, values = zip(*pairs) - - ids = map(id, keys) - self._dict.update(zip(ids, values)) - self._refs.update(zip(ids, keys)) - - def pop(self, key): - try: - i = id(key) - del self._refs[i] - return self._dict.pop(i) - except ValueError: - raise KeyError(key) - - def ids(self): - return self._dict.keys() - - def keys(self): - return self._refs.values() - - def values(self): - return self._dict.values() - - def items(self): - return zip(self._refs.values(), self._dict.values()) - - def __repr__(self): - body = [] - for k, v in self.items(): - body.append(repr(k) + ': ' + repr(v)) - return '{' + ', '.join(body) + '}'