import collections

import numpy
import six

import chainer
from chainer.backends import cuda
from chainer import function
from chainer import utils
from chainer.utils import type_check

def _logsumexp(a, xp, axis=None):
    vmax = xp.amax(a, axis=axis, keepdims=True)
    if xp is numpy:
        vmax += xp.log(xp.sum(xp.exp(a - vmax),
                              axis=axis, keepdims=True, dtype=a.dtype))
    else:
        _logsumexp_impl = cuda.reduce(
            'T x, T vmax', 'T y',
            'exp(x - vmax)', 'a + b', 'y += log(a)', '0',
            'logsumexp_impl')
        _logsumexp_impl(a, vmax, vmax, axis=axis, keepdims=True)
    return xp.squeeze(vmax, axis=axis)
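
# Note (illustrative, not part of the original code): _logsumexp computes
# log(sum(exp(a))) along `axis` in a numerically stable way by factoring
# out the maximum, e.g. with NumPy:
#   a = numpy.array([1000.0, 1000.0], dtype=numpy.float32)
#   numpy.log(numpy.sum(numpy.exp(a)))   # overflows to inf
#   _logsumexp(a, numpy, axis=0)         # ~= 1000.0 + log(2)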

def _softmax(x, xp):
    val = xp.exp(x - xp.amax(x, axis=2, keepdims=True))
    val /= xp.sum(val, axis=2, keepdims=True)
    return val

def _label_to_path(labels, blank_symbol, xp):
    path = xp.full((len(labels), labels.shape[1] * 2 + 1),
                   blank_symbol, dtype=numpy.int32)
    path[:, 1::2] = labels
    return path
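
# Example: with blank_symbol == 0, labels [[1, 2]] expand to the path
# [[0, 1, 0, 2, 0]], i.e. the label sequence interleaved with blanks as in
# the standard CTC formulation.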

def _log_dot(prob, rr, xp):
    return _logsumexp(prob + xp.swapaxes(rr, 1, 2), xp, axis=2)

def _move_label_to_back(path, path_length, xp):
    s1 = path.shape[1]  # TODO(okuta): Change name
    index = (xp.arange(0, path.size, s1, dtype=numpy.int32)[:, None] +
             (xp.arange(s1) + path_length[:, None])[:, ::-1] % s1)
    return xp.take(path, index)
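
# Note (a reading aid, not part of the original code): for each row this
# gathers the valid prefix in reverse order at the front and pushes the
# padding behind it, e.g. [a, b, c, _, _] with path_length 3 becomes
# [c, b, a, _, _]; the backward recursion in calc_trans consumes the path
# in this time-reversed layout.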

def _move_inputs(prob, input_length, xp):
    seq, batch, ch = prob.shape
    rotate = (xp.arange(seq)[:, None] + input_length) % seq
    index = rotate * batch + xp.arange(batch)
    return xp.take(prob.reshape(seq * batch, ch), index, axis=0)
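
# Note: out[t, b] = prob[(t + input_length[b]) % seq, b], i.e. each batch
# element is rotated along the time axis by its own length. Calling it
# again with -input_length (as calc_trans does) undoes the rotation.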

class ConnectionistTemporalClassification(function.Function):

    """The implementation of Connectionist Temporal Classification loss functions.

    To make it usable for real-world cases, this class has two policies below.
    1. This class computes forward and backward variables in the log domain.
    2. This class applies the softmax function to inputs. The backward
    values of the CTC loss often overflow. This is avoided by computing
    backward values before the activation function is applied.
    """

    def __init__(self, blank_symbol, reduce='mean'):
        self.blank_symbol = blank_symbol
        self.zero_padding = -10000000000.0

        if reduce not in ('mean', 'no'):
            raise ValueError(
                "only 'mean' and 'no' are valid "
                "for 'reduce', but '%s' is given" % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() > 3)  # TODO(okuta): > 3?
        l_type = in_types[2]
        type_check.expect(l_type.dtype == numpy.int32)

        x_basetype = in_types[3]  # TODO(okuta): Check x_basetype size

        for i in six.moves.range(3, len(in_types)):
            x_type = in_types[i]
            type_check.expect(
                x_type.dtype == numpy.float32,
                x_type.shape == x_basetype.shape,
            )

    def log_matrix(self, x, xp):
        if xp == numpy:
            res = numpy.ma.log(x).filled(fill_value=self.zero_padding)
        else:
            create_recurrence_relation = cuda.cupy.ElementwiseKernel(
                'T x, T e', 'T y',
                'y = x == 0 ? e : log(x)',
                'create_recurrence_relation')
            res = create_recurrence_relation(x, self.zero_padding)
        return res.astype(numpy.float32)
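
    # Note: log_matrix maps zero entries to self.zero_padding (a large
    # negative constant) instead of -inf, so that impossible transitions
    # act as zero probability without producing inf/NaN downstream.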

    def recurrence_relation(self, label, path_length, max_length, dtype, xp):
        """Transitions in the forward and backward algorithms are represented as a matrix.

        See also
        https://blog.wtf.sg/2014/10/06/connectionist-temporal-classification-ctc-with-theano/
        """
        batch, lab = label.shape
        repeat_mask = xp.ones((batch, lab * 2 + 1))
        repeat_mask[:, 1::2] = (label !=
                                xp.take(label, xp.arange(-1, lab - 1)
                                        % lab + xp.arange(0, batch * lab,
                                                          lab)[:, None]))
        repeat_mask[:, 1] = 1
        rr = (xp.eye(max_length, dtype=dtype)[None, :] +
              xp.eye(max_length, k=1, dtype=dtype)[None, :] +
              (xp.eye(max_length, k=2, dtype=dtype) *
               (xp.arange(max_length, dtype=dtype) % dtype(2))[None, :]
               * repeat_mask[:, None]))
        return self.log_matrix(
            rr * (path_length[:, None] > xp.arange(max_length))[..., None], xp)
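
    # Note (illustrative): for a path like [_, a, _, b, _] each position may
    # stay, advance by one, or skip the intervening blank (advance by two)
    # when the neighbouring labels differ; rr is the corresponding band
    # matrix with ones on the main diagonal, the first superdiagonal, and
    # the repeat-masked second superdiagonal, zeroed beyond path_length.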

    # path probability to label probability
    def label_probability(self, label_size, path, path_length,
                          multiply_seq, xp):
        labels_prob = self.log_matrix(xp.zeros((len(path), label_size),
                                               dtype=multiply_seq.dtype), xp)
        ret = xp.empty(
            (len(multiply_seq),) + labels_prob.shape, dtype=labels_prob.dtype)
        ret[...] = labels_prob
        if xp == numpy:
            for b in six.moves.range(len(path)):
                target_path = path[b][0:path_length[b]]
                chars = {c for c in target_path}
                for c in chars:
                    ret[:, b, c] = _logsumexp(
                        multiply_seq[:, b, 0:path_length[b]]
                        [:, target_path == c], numpy, axis=1)
        else:
            for i, multiply in enumerate(multiply_seq):
                # TODO(okuta): remove loop
                cuda.cupy.ElementwiseKernel(
                    'raw T x, raw I y, raw I l, I b_max, I c_max',
                    'T z',
                    '''
                    T value = z;
                    I b = i / b_max;
                    I c = i - b * b_max;
                    int ind[2] = {b, -1};
                    for (int index = 0; index < c_max; ++index) {
                        ind[1] = index;
                        if (ind[1] < l[ind[0]] && y[ind] == c) {
                            T xvalue = x[ind];
                            T at = xvalue, bt = value;
                            if (value > xvalue) {
                                at = value;
                                bt = xvalue;
                            }
                            value = at + log1p(exp(bt - at));
                        }
                    }
                    z = value;
                    ''',
                    'reduce_probability')(multiply, path, path_length,
                                          labels_prob.shape[1],
                                          path.shape[1], ret[i])
        return ret

    def calc_trans(self, yseq, input_length,
                   label, label_length, path, path_length, xp):
        forward_prob = self.log_matrix(
            xp.eye(path.shape[1], dtype='f')[0], xp)[None, :]
        backward_prob = forward_prob
        offset = xp.arange(
            0, yseq[0].size, yseq[0].shape[1], dtype=path.dtype)[:, None]

        # prob[i] := forward[i] + backward[-i-1]
        index = offset + path
        frr = self.recurrence_relation(
            label, path_length, path.shape[1], numpy.float32, xp)
        prob = xp.empty(
            (len(yseq),) + index.shape, dtype=forward_prob.dtype)
        # forward computation.
        for i, y in enumerate(yseq):
            # calc forward probability in log scale
            forward_prob = xp.take(y, index) + _log_dot(
                forward_prob[:, None, :], frr, xp)
            prob[i] = forward_prob

        r_index = offset + _move_label_to_back(path, path_length, xp)
        # rotate yseq with input_length
        yseq_inv = _move_inputs(yseq, input_length, xp)[::-1]
        brr = self.recurrence_relation(
            _move_label_to_back(label, label_length, xp),
            path_length, path.shape[1], numpy.float32, xp)
        # move to back.
        prob = _move_inputs(prob, input_length, xp)

        # backward computation.
        ps1 = path.shape[1]
        backward_prob_index = (
            xp.arange(0, path.size, ps1, dtype=numpy.int32)[:, None] +
            (xp.arange(ps1) - path_length[:, None]) % ps1)
        for i, y_inv in enumerate(yseq_inv):
            # calc backward probability
            backward_prob = _log_dot(backward_prob[:, None, :], brr, xp)
            prob[-i - 1] += xp.take(
                backward_prob[:, ::-1], backward_prob_index)
            backward_prob += xp.take(y_inv, r_index)

        # move to front.
        return _move_inputs(prob, -self.input_length, xp)

    def forward(self, inputs):
        xp = cuda.get_array_module(inputs[0])
        self.input_length = inputs[0]
        label_length = inputs[1]
        t = inputs[2]
        xs = inputs[3:]

        if chainer.is_debug():
            # Batch size check.
            assert len(xs[0]) == len(t)
            assert len(xs[0]) == len(self.input_length)
            assert len(xs[0]) == len(label_length)

            # Length check.
            assert len(xs) >= xp.max(self.input_length)
            assert len(t[0]) >= xp.max(label_length)

        self.path_length = 2 * label_length + 1

        yseq_shape = (len(xs),) + xs[0].shape
        self.yseq = _softmax(xp.vstack(xs).reshape(yseq_shape), xp)
        log_yseq = self.log_matrix(self.yseq, xp)
        self.path = _label_to_path(t, self.blank_symbol, xp)
        self.prob_trans = self.calc_trans(
            log_yseq, self.input_length, t,
            label_length, self.path, self.path_length, xp)

        loss = -_logsumexp(self.prob_trans[0], xp, axis=1)
        if self.reduce == 'mean':
            loss = utils.force_array(xp.mean(loss))
        return loss,

    def backward(self, inputs, grad_output):
        xp = cuda.get_array_module(inputs[0])
        batch_size = len(inputs[2])
        total_probability = _logsumexp(self.prob_trans[0], xp, axis=1)
        label_prob = self.label_probability(
            self.yseq.shape[2], self.path, self.path_length,
            self.prob_trans, xp)
        self.yseq -= xp.exp(label_prob - total_probability[:, None])
        if self.reduce == 'mean':
            self.yseq *= grad_output[0] / batch_size
        else:
            self.yseq *= grad_output[0][..., None]
        # mask
        self.yseq *= (
            xp.arange(len(self.yseq))[:, None] < self.input_length)[..., None]
        return (None, None, None) + tuple([y for y in self.yseq])

def connectionist_temporal_classification(
        x, t, blank_symbol, input_length=None, label_length=None,
        reduce='mean'):
    """Connectionist Temporal Classification loss function.

    Connectionist Temporal Classification (CTC) [Graves2006]_ is a loss
    function of sequence labeling where the alignment between the inputs
    and target is unknown. See also [Graves2012]_

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the samplewise
    loss values. If it is ``'mean'``, it takes the mean of loss values.

    Args:
        x (list or tuple of :class:`~chainer.Variable`):
            RNN output at each time. Each element of ``x``, ``x[i]``
            is a :class:`~chainer.Variable` representing output of RNN at
            time ``i``.
        t (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Expected label sequence.
        blank_symbol (int): Index of blank_symbol.
            This value must be non-negative.
        input_length (:class:`~chainer.Variable` or :class:`numpy.ndarray` \
        or :class:`cupy.ndarray`):
            Length of valid sequence for each element of the mini-batch
            ``x`` (optional). If ``input_length`` is omitted, all of ``x``
            is regarded as valid input.
        label_length (:class:`~chainer.Variable` or :class:`numpy.ndarray` \
        or :class:`cupy.ndarray`):
            Length of valid sequence for each element of the mini-batch
            ``t`` (optional). If ``label_length`` is omitted, all of ``t``
            is regarded as a valid label.
        reduce (str): Reduction option. Its value must be either
            ``'mean'`` or ``'no'``. Otherwise,
            :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding a scalar value of the CTC loss.
            If ``reduce`` is ``'no'``, the output variable holds an array
            whose shape is `(B,)` where `B` is the number of samples.
            If it is ``'mean'``, it holds a scalar.

    .. note::
       You need to input ``x`` without applying any activation function
       (e.g. softmax), because this function applies softmax to ``x``
       before calculating the CTC loss to avoid numerical limitations.
       You also need to apply softmax to the forwarded values before you
       decode them.

    .. note::
       This function is differentiable only by ``x``.

    .. note::
       This function supports (batch, sequence, 1-dimensional input)-data.

    .. [Graves2006] Alex Graves, Santiago Fernandez,\
    Faustino Gomez, Jurgen Schmidhuber,\
    `Connectionist Temporal Classification: Labelling Unsegmented\
    Sequence Data with Recurrent Neural Networks\
    <ftp://ftp.idsia.ch/pub/juergen/icml2006.pdf>`_

    .. [Graves2012] Alex Graves,\
    `Supervised Sequence Labelling with Recurrent Neural Networks\
    <http://www.cs.toronto.edu/~graves/preprint.pdf>`_

    """
    if not isinstance(x, collections.Sequence):
        raise TypeError('x must be a list of Variables')
    if not isinstance(blank_symbol, int):
        raise TypeError('blank_symbol must be a non-negative integer.')
    assert 0 <= blank_symbol < x[0].shape[1]
    # This implementation only supports 1-dimensional data.
    # TODO(jnishi): Support d(>1)-dimensional inputs.
    assert x[0].ndim == 2

    xp = cuda.get_array_module(x[0])
    if input_length is None:
        input_length = xp.full(len(x[0]), len(x), dtype=numpy.int32)
    if label_length is None:
        label_length = xp.full(len(t), t.shape[1], dtype=numpy.int32)

    return ConnectionistTemporalClassification(blank_symbol, reduce)(
        input_length, label_length, t, *x)
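

# A minimal usage sketch (an assumption-laden illustration, not part of the
# original module): toy shapes, random scores, and blank symbol 0.
if __name__ == '__main__':
    batch, seq_len, vocab, label_len = 2, 4, 3, 1
    # x is a length-seq_len list of (batch, vocab) float32 score arrays;
    # softmax is applied inside the loss, so these are unnormalized.
    x = [chainer.Variable(
        numpy.random.randn(batch, vocab).astype(numpy.float32))
        for _ in range(seq_len)]
    # Labels avoid index 0, which is reserved for the blank symbol here.
    t = numpy.random.randint(
        1, vocab, (batch, label_len)).astype(numpy.int32)
    loss = connectionist_temporal_classification(x, t, blank_symbol=0)
    print(loss.data)  # a scalar, since reduce defaults to 'mean'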