/
hierarchical_softmax.py
364 lines (290 loc) · 11.6 KB
/
hierarchical_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import copy
import heapq

import numpy
import six

from chainer.backends import cuda
from chainer import function
from chainer.initializers import uniform
from chainer import link
from chainer.utils import type_check
from chainer import variable
class TreeParser(object):

    """Walks a tuple-based binary tree and records the route to every leaf.

    Internal nodes are 2-tuples ``(left, right)``; any object that is not a
    tuple is treated as a leaf label.  Internal nodes are numbered in
    pre-order, starting from 0.  For each leaf the parser stores

    * ``paths[leaf]``: ``int32`` array with the ids of the internal nodes
      visited from the root down to the leaf, and
    * ``codes[leaf]``: ``float32`` array of the same length holding ``1.0``
      where the route continues into the left child and ``-1.0`` where it
      continues into the right child.

    If the same leaf label occurs more than once in the tree, the entry
    recorded last overwrites the earlier ones.
    """

    def __init__(self):
        self.next_id = 0
        # Stacks describing the route from the root to the node currently
        # being visited.  They are created here (and not only in ``parse``)
        # so that the accessors work on a parser that has not parsed yet.
        self.path = []
        self.code = []
        # Results of the last ``parse`` call, keyed by leaf label.
        self.paths = {}
        self.codes = {}

    def size(self):
        """Returns the number of internal nodes found by the last parse."""
        return self.next_id

    def get_paths(self):
        """Returns the dict mapping each leaf label to its node-id array."""
        return self.paths

    def get_codes(self):
        """Returns the dict mapping each leaf label to its code array."""
        return self.codes

    def parse(self, tree):
        """Parses ``tree``, replacing the results of any previous parse.

        Args:
            tree: A leaf label, or a binary tree built from nested 2-tuples
                such as ``((1, 2), 3)``.

        Raises:
            ValueError: If a tuple in the tree does not have exactly two
                elements.
        """
        self.next_id = 0
        self.path = []
        self.code = []
        self.paths = {}
        self.codes = {}
        self._parse(tree)

        # Sanity checks on internal invariants: the traversal stacks must be
        # fully unwound and both result dicts must cover the same leaves.
        assert(len(self.path) == 0)
        assert(len(self.code) == 0)
        assert(len(self.paths) == len(self.codes))

    def _parse(self, node):
        """Recursively visits ``node`` and records paths/codes of leaves."""
        if isinstance(node, tuple):
            # internal node
            if len(node) != 2:
                raise ValueError(
                    'All internal nodes must have two child nodes')
            left, right = node
            self.path.append(self.next_id)
            self.next_id += 1
            # Everything below the left child sees code 1.0 for this node ...
            self.code.append(1.0)
            self._parse(left)

            # ... and everything below the right child sees code -1.0.
            self.code[-1] = -1.0
            self._parse(right)

            self.path.pop()
            self.code.pop()

        else:
            # leaf node: snapshot the current route as fixed-dtype arrays
            self.paths[node] = numpy.array(self.path, dtype=numpy.int32)
            self.codes[node] = numpy.array(self.code, dtype=numpy.float32)
class BinaryHierarchicalSoftmaxFunction(function.Function):

    """Hierarchical softmax function based on a binary tree.

    This function object should be allocated beforehand, and be copied on every
    forward computation, since the initializer parses the given tree. See the
    implementation of :class:`BinaryHierarchicalSoftmax` for details.

    The per-leaf paths and codes produced by :class:`TreeParser` are stored
    flattened: ``paths`` and ``codes`` are the concatenation of the per-leaf
    arrays in increasing label order, and the entries belonging to label
    ``t`` are ``paths[begins[t]:begins[t + 1]]`` (likewise for ``codes``).

    Args:
        tree: A binary tree made with tuples like ``((1, 2), 3)``. Leaf
            labels are used as integer indices into ``begins``.

    .. seealso::
       See :class:`BinaryHierarchicalSoftmax` for details.

    """

    def __init__(self, tree):
        parser = TreeParser()
        parser.parse(tree)
        paths = parser.get_paths()
        codes = parser.get_codes()
        n_vocab = max(paths.keys()) + 1

        self.paths = numpy.concatenate(
            [paths[i] for i in range(n_vocab) if i in paths])
        self.codes = numpy.concatenate(
            [codes[i] for i in range(n_vocab) if i in codes])
        begins = numpy.empty((n_vocab + 1,), dtype=numpy.int32)
        begins[0] = 0
        for i in range(0, n_vocab):
            # Labels that do not occur in the tree get an empty slice.
            length = len(paths[i]) if i in paths else 0
            begins[i + 1] = begins[i] + length
        self.begins = begins

        # Number of internal nodes, i.e. the required number of rows of W.
        self.parser_size = parser.size()

    def check_type_forward(self, in_types):
        """Checks the inputs ``(x, t, W)``.

        ``x`` must be a float32 matrix, ``t`` an int32 vector with one label
        per row of ``x``, and ``W`` a float32 matrix with one row per
        internal tree node and as many columns as ``x``.
        """
        type_check.expect(in_types.size() == 3)
        x_type, t_type, w_type = in_types

        type_check.expect(
            x_type.dtype == numpy.float32,
            x_type.ndim == 2,
            t_type.dtype == numpy.int32,
            t_type.ndim == 1,
            x_type.shape[0] == t_type.shape[0],
            w_type.dtype == numpy.float32,
            w_type.ndim == 2,
            w_type.shape[0] == self.parser_size,
            w_type.shape[1] == x_type.shape[1],
        )

    def to_gpu(self, device=None):
        """Transfers the flattened tree tables to the given GPU device."""
        with cuda._get_device(device):
            self.paths = cuda.to_gpu(self.paths)
            self.codes = cuda.to_gpu(self.codes)
            self.begins = cuda.to_gpu(self.begins)

    def to_cpu(self):
        """Transfers the flattened tree tables back to host memory."""
        self.paths = cuda.to_cpu(self.paths)
        self.codes = cuda.to_cpu(self.codes)
        self.begins = cuda.to_cpu(self.begins)

    def forward_cpu(self, inputs):
        """Returns the loss summed over all samples as a 1-tuple."""
        x, t, W = inputs

        loss = numpy.float32(0.0)
        for ix, it in six.moves.zip(x, t):
            loss += self._forward_cpu_one(ix, it, W)
        return numpy.array(loss),

    def _forward_cpu_one(self, x, t, W):
        """Returns the loss of one sample ``x`` with label ``t``."""
        begin = self.begins[t]
        end = self.begins[t + 1]

        # Rows of W for the internal nodes on the path to leaf ``t``.
        w = W[self.paths[begin:end]]
        wxy = w.dot(x) * self.codes[begin:end]
        loss = numpy.logaddexp(0.0, -wxy)  # == log(1 + exp(-wxy))
        return numpy.sum(loss)

    def backward_cpu(self, inputs, grad_outputs):
        """Returns ``(gx, None, gW)``; the labels ``t`` get no gradient."""
        x, t, W = inputs
        gloss, = grad_outputs
        gx = numpy.empty_like(x)
        gW = numpy.zeros_like(W)
        for i, (ix, it) in enumerate(six.moves.zip(x, t)):
            gx[i] = self._backward_cpu_one(ix, it, W, gloss, gW)
        return gx, None, gW

    def _backward_cpu_one(self, x, t, W, gloss, gW):
        """Accumulates one sample's gradient into ``gW`` and returns its gx.

        ``gW`` is modified in place.
        """
        begin = self.begins[t]
        end = self.begins[t + 1]

        path = self.paths[begin:end]
        w = W[path]
        wxy = w.dot(x) * self.codes[begin:end]
        # d/d(w.x) of log(1 + exp(-code * w.x)), scaled by the upstream grad.
        g = -gloss * self.codes[begin:end] / (1.0 + numpy.exp(wxy))
        gx = g.dot(w)
        # Outer product g x^T: one gradient row per node on the path.
        gw = g.reshape((g.shape[0], 1)).dot(x.reshape(1, x.shape[0]))
        gW[path] += gw
        return gx

    def forward_gpu(self, inputs):
        """Returns the loss summed over all samples as a 1-tuple (GPU).

        Stores ``self.max_length`` and ``self.wxy`` for use by
        :meth:`backward_gpu`.
        """
        x, t, W = inputs
        # Longest path among the labels in this batch.
        max_length = cuda.reduce(
            'T t, raw T begins', 'T out', 'begins[t + 1] - begins[t]',
            'max(a, b)', 'out = a', '0',
            'binary_hierarchical_softmax_max_length')(t, self.begins)
        max_length = cuda.to_cpu(max_length)[()]

        # One slot per (sample, path offset) pair; slots past the end of a
        # sample's actual path contribute a loss of 0.
        length = max_length * x.shape[0]
        ls = cuda.cupy.empty((length,), dtype=numpy.float32)
        n_in = x.shape[1]
        wxy = cuda.cupy.empty_like(ls)
        cuda.elementwise(
            '''raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, int32 c, int32 max_length''',
            'T ls, T wxy',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];

              T wx = 0;
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                wx += w[w_ind] * x[x_ind];
              }
              T z = wx * codes[p];
              wxy = z;
              // log(1 + exp(-z)) written as max(-z, 0) + log1p(exp(-|z|))
              // so that exp() cannot overflow for strongly negative z.
              T neg_z = -z;
              ls = (neg_z > 0 ? neg_z : (T)0) + log1p(exp(-fabs(z)));
            } else {
              ls = 0;
            }
            ''',
            'binary_hierarchical_softmax_forward'
        )(x, W, t, self.paths, self.codes, self.begins, n_in, max_length, ls,
          wxy)
        self.max_length = max_length
        self.wxy = wxy
        return ls.sum(),

    def backward_gpu(self, inputs, grad_outputs):
        """Returns ``(gx, None, gW)`` computed on the GPU.

        Relies on ``self.wxy`` and ``self.max_length`` saved by
        :meth:`forward_gpu`.
        """
        x, t, W = inputs
        gloss, = grad_outputs

        n_in = x.shape[1]
        gx = cuda.cupy.zeros_like(x)
        gW = cuda.cupy.zeros_like(W)
        cuda.elementwise(
            '''T wxy, raw T x, raw T w, raw int32 ts, raw int32 paths,
            raw T codes, raw int32 begins, raw T gloss,
            int32 c, int32 max_length''',
            'raw T gx, raw T gw',
            '''
            int ind = i / max_length;
            int offset = i - ind * max_length;
            int t = ts[ind];

            int begin = begins[t];
            int length = begins[t + 1] - begins[t];

            if (offset < length) {
              int p = begin + offset;
              int node = paths[p];
              T code = codes[p];

              T g = -gloss[0] * code / (1.0 + exp(wxy));
              for (int j = 0; j < c; ++j) {
                int w_ind[] = {node, j};
                int x_ind[] = {ind, j};
                atomicAdd(&gx[x_ind], g * w[w_ind]);
                atomicAdd(&gw[w_ind], g * x[x_ind]);
              }
            }
            ''',
            'binary_hierarchical_softmax_bwd'
        )(self.wxy, x, W, t, self.paths, self.codes, self.begins, gloss, n_in,
          self.max_length, gx, gW)
        return gx, None, gW
class BinaryHierarchicalSoftmax(link.Link):

    """Hierarchical softmax layer over binary tree.

    In natural language applications, vocabulary size is too large to use
    softmax loss.
    Instead, the hierarchical softmax uses product of sigmoid functions.
    It costs only :math:`O(\\log(n))` time where :math:`n` is the vocabulary
    size in average.

    At first a user needs to prepare a binary tree in which each leaf
    corresponds to a word in a vocabulary.
    When a word :math:`x` is given, exactly one path from the root of the tree
    to the leaf of the word exists.
    Let :math:`\\mbox{path}(x) = ((e_1, b_1), \\dots, (e_m, b_m))` be the path
    of :math:`x`, where :math:`e_i` is an index of :math:`i`-th internal node,
    and :math:`b_i \\in \\{-1, 1\\}` indicates direction to move at
    :math:`i`-th internal node (1 is left, and -1 is right).
    Then, the probability of :math:`x` is given as below:

    .. math::

       P(x) &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}P(b_i | e_i)  \\\\
            &= \\prod_{(e_i, b_i) \\in \\mbox{path}(x)}\\sigma(b_i x^\\top
               w_{e_i}),

    where :math:`\\sigma(\\cdot)` is a sigmoid function, and :math:`w` is a
    weight matrix.

    This function costs :math:`O(\\log(n))` time as an average length of paths
    is :math:`O(\\log(n))`, and :math:`O(n)` memory as the number of internal
    nodes equals :math:`n - 1`.

    Args:
        in_size (int): Dimension of input vectors.
        tree: A binary tree made with tuples like `((1, 2), 3)`.

    Attributes:
        W (~chainer.Variable): Weight parameter matrix, with one row per
            internal node of the tree and ``in_size`` columns.

    See: Hierarchical Probabilistic Neural Network Language Model [Morin+,
    AISTATS2005].

    """

    def __init__(self, in_size, tree):
        # This function object is copied on every forward computation.
        super(BinaryHierarchicalSoftmax, self).__init__()
        self._func = BinaryHierarchicalSoftmaxFunction(tree)

        with self.init_scope():
            self.W = variable.Parameter(uniform.Uniform(1),
                                        (self._func.parser_size, in_size))

    def to_gpu(self, device=None):
        """Moves the link and its tree tables to the given GPU device."""
        with cuda._get_device(device):
            super(BinaryHierarchicalSoftmax, self).to_gpu(device)
            self._func.to_gpu(device)

    def to_cpu(self):
        """Moves the link and its tree tables back to host memory."""
        super(BinaryHierarchicalSoftmax, self).to_cpu()
        self._func.to_cpu()

    @staticmethod
    def create_huffman_tree(word_counts):
        """Makes a Huffman tree from a dictionary containing word counts.

        This method creates a binary Huffman tree, that is required for
        :class:`BinaryHierarchicalSoftmax`.
        For example, ``{0: 8, 1: 5, 2: 6, 3: 4}`` is converted to
        ``((3, 1), (2, 0))``.

        Args:
            word_counts (dict of int key and int or float values):
                Dictionary representing counts of words.

        Returns:
            Binary Huffman tree with tuples and keys of ``word_counts``.
            If ``word_counts`` has a single entry, its key is returned as is.

        Raises:
            ValueError: If ``word_counts`` is empty.

        """
        if not word_counts:
            raise ValueError('Empty vocabulary')

        # Add unique id to each entry so that we can compare two entries with
        # same counts (the words/subtrees themselves are never compared).
        # Note that iteritems may order the entries arbitrarily.
        heap = [(count, uid, word)
                for uid, (word, count)
                in enumerate(six.iteritems(word_counts))]
        heapq.heapify(heap)

        # Repeatedly merge the two least frequent entries into one subtree.
        while len(heap) >= 2:
            count1, id1, word1 = heapq.heappop(heap)
            count2, id2, word2 = heapq.heappop(heap)
            # Both ids just left the heap, so min(id1, id2) stays unique.
            heapq.heappush(
                heap, (count1 + count2, min(id1, id2), (word1, word2)))

        return heap[0][2]

    def __call__(self, x, t):
        """Computes the loss value for given input and ground truth labels.

        Args:
            x (~chainer.Variable): Input to the classifier at each node.
            t (~chainer.Variable): Batch of ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        """
        # The function object keeps per-call state, so work on a shallow
        # copy instead of the shared instance.
        f = copy.copy(self._func)  # creates a copy of the function node
        return f(x, t, self.W)