Skip to content

Commit

Permalink
changes to tagged event handlers following review
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Mar 21, 2017
1 parent d758882 commit 782f695
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 128 deletions.
5 changes: 1 addition & 4 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,8 @@ def _test_observer2(change):

a.unobserve_all()

# test that tagged notifiers know
# about dynamically added traits
a.observe(_test_observer1, tags={'type': 'b'})
a.add_traits(baz=Int().tag(type='b'))

a.observe(_test_observer1, tags={'type': 'b'})
a.baz = 1
self.assertEqual(a.foo, 1)
a.foo = 0
Expand Down
178 changes: 54 additions & 124 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def class_init(self, cls, name):
def subclass_init(self, cls):
if '_%s_default' % self.name in cls.__dict__:
method = getattr(cls, '_%s_default' % self.name)
cls._trait_default_generators[self.name] = method
cls._register_default_generators(method, self.name)

def __init__(self, default_value=Undefined, allow_none=False, read_only=None, help=None,
config=None, **kwargs):
Expand Down Expand Up @@ -928,6 +928,19 @@ def __get__(self, inst, cls=None):
return self
return types.MethodType(self.func, inst)

def __eq__(self, other):
if isinstance(other, EventHandler):
return self is other
else:
return other == self.func

def __hash__(self):
"""This hash is ONLY safe for hash comparisons to other EventHandlers!"""
if self.func is None:
return id(self)
else:
return hash(self.func)


class TraitEventHandler(EventHandler):

Expand All @@ -940,78 +953,30 @@ def __init__(self, names, tags):
self.trait_names = parse_notifier_name(names)
self.metadata = tags or {}

def register(self, obj):
"""Associate this event with traits on an instance"""
raise NotImplementedError()

def event_of(self, obj, names=None):
"""Get the object's trait names this instance is an event of
Parameters
----------
obj: HasTraits instance or class
The instance or class being examined
name: list, tuple
A list of those trait's names. Only those names which
this instance is an event of will be returned. If no
names are given, then all the trait names of this event
are returned instead.
Returns
-------
A ``set`` of trait names
"""
tagged = parse_notifier_tags(obj, self.metadata)
if names is not None:
return set(names).intersection(self.trait_names + tagged)
else:
return set(self.trait_names + tagged)

def __eq__(self, other):
if isinstance(other, TraitEventHandler):
return self is other
else:
return other == self.func


class ObserveHandler(TraitEventHandler):

caches_instance_resources = True

def __init__(self, names, type, tags):
super(ObserveHandler, self).__init__(names, tags)
self.type = type

def instance_init(self, inst):
self.register(inst)

def register(self, inst):
inst.observe(self, self.event_of(inst), type=self.type)
def instance_init(self, obj):
super(ObserveHandler, self).instance_init(obj)
obj.observe(self, self.trait_names, self.type, self.metadata)


class ValidateHandler(TraitEventHandler):

caches_instance_resources = True

def instance_init(self, inst):
self.register(inst)

def register(self, inst, names=None):
inst._register_validator(self, self.event_of(inst))
def instance_init(self, obj):
super(ValidateHandler, self).instance_init(obj)
obj._register_validator(self, self.trait_names, self.metadata)


class DefaultHandler(TraitEventHandler):

caches_class_resources = True

def class_init(self, cls, name):
super(DefaultHandler, self).class_init(cls, name)
self.register(cls)

def register(self, cls):
"""Associate this event with traits on a class"""
for n in self.event_of(cls):
cls._trait_default_generators[n] = self
cls._register_default_generators(self, self.trait_names, self.metadata)


class HasDescriptors(six.with_metaclass(MetaHasDescriptors, object)):
Expand Down Expand Up @@ -1066,7 +1031,6 @@ def setup_instance(*args, **kwargs):
self._trait_values = {}
self._trait_notifiers = {}
self._trait_validators = {}
self._static_trait_notifiers = []
super(HasTraits, self).setup_instance(*args, **kwargs)

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -1109,7 +1073,6 @@ def __getstate__(self):
# recall of instance_init during __setstate__
d['_trait_notifiers'] = {}
d['_trait_validators'] = {}
d['_static_trait_notifiers'] = []
return d

def __setstate__(self, state):
Expand Down Expand Up @@ -1273,8 +1236,6 @@ def _add_notifiers(self, handler, names, type):
nlist = self._trait_notifiers[name][type]
if handler not in nlist:
nlist.append(handler)
if handler.name is None and handler not in self._static_trait_notifiers:
self._static_trait_notifiers.append(handler)

def _remove_notifiers(self, handler, names, type):
for name in names:
Expand All @@ -1283,8 +1244,6 @@ def _remove_notifiers(self, handler, names, type):
event = self._trait_notifiers[name][type].pop(i)
except KeyError:
pass
if getattr(handler, 'name', None) is None:
self._static_trait_notifiers.remove(handler)

def on_trait_change(self, handler=None, name=None, remove=False):
"""DEPRECATED: Setup a handler to be called when a trait changes.
Expand Down Expand Up @@ -1348,14 +1307,19 @@ def observe(self, handler, names=None, type='change', tags=None):
type : str, All (default: 'change')
The type of notification to filter by. If equal to All, then all
notifications are passed to the observe handler.
Returns
-------
An :class:`ObserveHandler`.
"""
if names is None and tags is None:
names = (All,)
if isinstance(handler, ObserveHandler):
self._add_notifiers(handler, names, type)
else:
event = ObserveHandler(names, type=type, tags=tags)
event(handler).register(self)
all_names = (
parse_notifier_name(names) +
parse_notifier_tags(self, tags)
)
self._add_notifiers(handler, all_names, type)

def unobserve(self, handler, names=None, type='change', tags=None):
"""Remove a trait change handler.
Expand All @@ -1374,10 +1338,6 @@ def unobserve(self, handler, names=None, type='change', tags=None):
The type of notification to filter by. If All, the specified handler
is uninstalled from the list of notifiers corresponding to all types.
"""
if isinstance(tags, six.string_types):
warn("new argument 'tags' introduced: use 'type' as keyword",
DeprecationWarning, stacklevel=2)
type, tags = tags, None
if names is None and tags is None:
names = All
named = parse_notifier_name(names)
Expand All @@ -1400,7 +1360,6 @@ def unobserve_all(self, name=All, type=All):
"""
if name is All and type is All:
self._trait_notifiers = {}
self._static_trait_notifiers = []
else:
for n in self._trait_notifiers if name is All else [name]:
if type is All:
Expand All @@ -1417,11 +1376,8 @@ def unobserve_all(self, name=All, type=All):
pass
else:
del self._trait_notifiers[n][t]
for h in handlers:
if h.name is None:
self._static_trait_notifiers.remove(h)

def _register_validator(self, handler, names):
def _register_validator(self, handler, names=None, tags=None):
"""Setup a handler to be called when a trait should be cross validated.
This is used to setup dynamic notifications for cross-validation.
Expand All @@ -1441,28 +1397,36 @@ def _register_validator(self, handler, names):
names : List of strings
The names of the traits that should be cross-validated
"""
for name in names:
all_names = (
parse_notifier_name(names) +
parse_notifier_tags(self, tags)
)
for name in all_names:
magic_name = '_%s_validate' % name
if hasattr(self, magic_name):
class_value = getattr(self.__class__, magic_name)
if not isinstance(class_value, ValidateHandler):
_deprecated_method(class_value, self.__class, magic_name,
"use @validate decorator instead.")
for name in names:
for name in all_names:
self._trait_validators[name] = handler

@classmethod
def _register_default_generators(cls, handler, names=None, tags=None):
all_names = (
parse_notifier_name(names) +
parse_notifier_tags(cls, tags)
)
for name in all_names:
cls._trait_default_generators[name] = handler

def add_traits(self, **traits):
"""Dynamically add trait attributes to the HasTraits instance."""
self.__class__ = type(self.__class__.__name__, (self.__class__,),
traits)

events = self._static_trait_notifiers
events.extend(self.trait_events().values())

for e in events:
if e.caches_instance_resources:
e.register(self)

self.__class__ = type(
self.__class__.__name__,
(self.__class__,),
traits
)
for trait in traits.values():
trait.instance_init(self)

Expand Down Expand Up @@ -1539,7 +1503,8 @@ def trait_defaults(self, *names, **metadata):
Notes
-----
Dynamically generated default values may
depend on the current state of the object."""
depend on the current state of the object.
"""
if len(names) == 1 and len(metadata) == 0:
return self._trait_default_generators[names[0]](self)

Expand Down Expand Up @@ -1604,41 +1569,6 @@ def trait_metadata(self, traitname, key, default=None):
else:
return trait.metadata.get(key, default)

@classmethod
def class_own_trait_events(cls, name):
"""Get a dict of all event handlers defined on this class, not a parent.
Works like ``event_handlers``, except for excluding traits from parents.
"""
sup = super(cls, cls)
return {n: e for (n, e) in cls.events(name).items()
if getattr(sup, n, None) is not e}

@classmethod
def trait_events(cls, name=None):
"""Get a ``dict`` of all the event handlers of this class.
Parameters
----------
name: str (default: None)
The name of a trait of this class. If name is ``None`` then all
the event handlers of this class will be returned instead.
Returns
-------
The event handlers associated with a trait name, or all event handlers.
"""
events = {}
for k, v in getmembers(cls):
if isinstance(v, EventHandler):
if name is None:
events[k] = v
elif name in v.trait_names:
events[k] = v
elif hasattr(v, 'tags'):
if cls.trait_names(**v.tags):
events[k] = v
return events

#-----------------------------------------------------------------------------
# Actual TraitTypes implementations/subclasses
Expand Down

0 comments on commit 782f695

Please sign in to comment.