-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
link_hook.py
190 lines (151 loc) · 6.39 KB
/
link_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import chainer
from chainer import utils
class _ForwardPreprocessCallbackArgs(object):
"""Callback data for LinkHook.forward_preprocess"""
def __init__(self, link, forward_name, args, kwargs):
self.link = link
self.forward_name = forward_name
self.args = args
self.kwargs = kwargs
def __repr__(self):
return utils._repr_with_named_data(
self, link=self.link, forward_name=self.forward_name,
args=self.args, kwargs=self.kwargs)
class _ForwardPostprocessCallbackArgs(object):
"""Callback data for LinkHook.forward_postrocess"""
def __init__(self, link, forward_name, args, kwargs, out):
self.link = link
self.forward_name = forward_name
self.args = args
self.kwargs = kwargs
self.out = out
def __repr__(self):
return utils._repr_with_named_data(
self, link=self.link, forward_name=self.forward_name,
args=self.args, kwargs=self.kwargs, out=self.out)
class LinkHook(object):
"""Base class of hooks for links.
:class:`~chainer.LinkHook` is a callback object
that is registered to a :class:`~chainer.Link`.
Registered link hooks are invoked before and after calling
:meth:`Link.forward() <chainer.Link.forward>` method of each link.
Link hooks that derive from :class:`LinkHook` may override the following
method:
* :meth:`~chainer.LinkHook.added`
* :meth:`~chainer.LinkHook.deleted`
* :meth:`~chainer.LinkHook.forward_preprocess`
* :meth:`~chainer.LinkHook.forward_postprocess`
By default, these methods do nothing.
Specifically, when the :meth:`~chainer.Link.__call__`
method of some link is invoked,
:meth:`~chainer.LinkHook.forward_preprocess`
(resp. :meth:`~chainer.LinkHook.forward_postprocess`)
of all link hooks registered to this link are called before (resp. after)
:meth:`Link.forward() <chainer.Link.forward>` method of the link.
There are two ways to register :class:`~chainer.LinkHook`
objects to :class:`~chainer.Link` objects.
The first one is to use ``with`` statement. Link hooks hooked
in this way are registered to all links within ``with`` statement
and are unregistered at the end of ``with`` statement.
.. admonition:: Example
The following code is a simple example in which
we measure the elapsed time of a part of forward propagation procedure
with :class:`~chainer.link_hooks.TimerHook`, which is a subclass of
:class:`~chainer.LinkHook`.
>>> class Model(chainer.Chain):
... def __init__(self):
... super(Model, self).__init__()
... with self.init_scope():
... self.l = L.Linear(10, 10)
... def forward(self, x1):
... return F.exp(self.l(x1))
>>> model1 = Model()
>>> model2 = Model()
>>> x = chainer.Variable(np.zeros((1, 10), np.float32))
>>> with chainer.link_hooks.TimerHook() as m:
... _ = model1(x)
... y = model2(x)
>>> model3 = Model()
>>> z = model3(y)
>>> print('Total time : {}'.format(m.total_time()))
... # doctest:+ELLIPSIS
Total time : ...
In this example, we measure the elapsed times for each forward
propagation of all functions in ``model1`` and ``model2``.
Note that ``model3`` is not a target measurement
as :class:`~chainer.link_hooks.TimerHook` is unregistered
before forward propagation of ``model3``.
.. note::
Chainer stores the dictionary of registered link hooks
as a thread local object. So, link hooks registered
are different depending on threads.
The other one is to register directly to
a :class:`~chainer.Link` object by calling its
:meth:`~chainer.Link.add_hook` method.
Link hooks registered in this way can be removed by
:meth:`~chainer.Link.delete_hook` method.
Contrary to former registration method, link hooks are registered
only to the link which :meth:`~chainer.Link.add_hook`
is called.
Args:
name(str): Name of this link hook.
"""
name = 'LinkHook'
def __enter__(self):
link_hooks = chainer._get_link_hooks()
if self.name in link_hooks:
raise KeyError('hook %s already exists' % self.name)
link_hooks[self.name] = self
self.added(None)
return self
def __exit__(self, *_):
link_hooks = chainer._get_link_hooks()
link_hooks[self.name].deleted(None)
del link_hooks[self.name]
def added(self, link):
"""Callback function invoked when the link hook is registered
Args:
link(~chainer.Link): Link object to which
the link hook is registered. ``None`` if the link hook is
registered globally.
"""
pass
def deleted(self, link):
"""Callback function invoked when the link hook is unregistered
Args:
link(~chainer.Link): Link object to which
the link hook is unregistered. ``None`` if the link hook had
been registered globally.
"""
pass
# forward
def forward_preprocess(self, args):
"""Callback function invoked before a forward call of a link.
Args:
args: Callback data. It has the following attributes:
* link (:class:`~chainer.Link`)
Link object.
* forward_name (:class:`str`)
Name of the forward method.
* args (:class:`tuple`)
Non-keyword arguments given to the forward method.
* kwargs (:class:`dict`)
Keyword arguments given to the forward method.
"""
pass
def forward_postprocess(self, args):
"""Callback function invoked after a forward call of a link.
Args:
args: Callback data. It has the following attributes:
* link (:class:`~chainer.Link`)
Link object.
* forward_name (:class:`str`)
Name of the forward method.
* args (:class:`tuple`)
Non-keyword arguments given to the forward method.
* kwargs (:class:`dict`)
Keyword arguments given to the forward method.
* out
Return value of the forward method.
"""
pass