Merge pull request #689 from jnishi/add-ctc-padding
Add ctc padding
okuta committed Mar 1, 2016
2 parents 7a33884 + 9300f72 commit a76abe1
Showing 2 changed files with 257 additions and 91 deletions.
217 changes: 144 additions & 73 deletions chainer/functions/loss/ctc.py
@@ -2,6 +2,7 @@
import numpy
import six

import chainer
from chainer import cuda
from chainer import function
from chainer import utils
@@ -16,8 +17,8 @@ def _logsumexp(a, xp, axis=None):


def _softmax(x, xp):
val = xp.exp(x - xp.amax(x, axis=1, keepdims=True))
val /= xp.sum(val, axis=1, keepdims=True)
val = xp.exp(x - xp.amax(x, axis=2, keepdims=True))
val /= xp.sum(val, axis=2, keepdims=True)
return val
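
A minimal sketch (plain numpy, invented shapes) of why the softmax now normalises over axis 2: with this change the per-timestep score arrays are stacked into a single (time, batch, classes) array before activation (see forward() later in this diff), so the class axis is the last one.

import numpy as np

T, B, C = 3, 2, 4                         # time steps, batch size, classes
scores = np.random.randn(T, B, C).astype(np.float32)
val = np.exp(scores - np.amax(scores, axis=2, keepdims=True))
val /= np.sum(val, axis=2, keepdims=True)
assert np.allclose(val.sum(axis=2), 1.0)  # each (t, b) slice is a distribution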


@@ -32,8 +33,18 @@ def _log_dot(prob, rr, xp):
return _logsumexp(prob + xp.swapaxes(rr, 1, 2), xp, axis=2)


def _activate(yseq, xp):
return [_softmax(y, xp) for y in yseq]
def _move_label_to_back(path, path_length, xp):
s1 = path.shape[1] # TODO(okuta): Change name
index = (xp.arange(0, path.size, s1, dtype=numpy.int32)[:, None]
+ (xp.arange(s1) + path_length[:, None])[:, ::-1] % s1)
return xp.take(path, index)


def _move_inputs(prob, input_length, xp):
seq, batch, ch = prob.shape
rotate = (xp.arange(seq)[:, None] + input_length) % seq
index = rotate * batch + xp.arange(batch)
return xp.take(prob.reshape(seq * batch, ch), index, axis=0)
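
These two helpers implement the padding trick: each sample's valid frames are rotated so that the last valid frame lands at the final position, which lets the backward recursion start from the true final frame even when sequences in the mini-batch have different lengths. Below is a minimal sketch of the rotation done by _move_inputs, assuming plain numpy and invented shapes:

import numpy as np

seq, batch, ch = 4, 2, 3
prob = np.arange(seq * batch * ch, dtype=np.float32).reshape(seq, batch, ch)
input_length = np.array([4, 2], dtype=np.int32)   # sample 1 is padded after 2 frames

rotate = (np.arange(seq)[:, None] + input_length) % seq   # shape (seq, batch)
index = rotate * batch + np.arange(batch)                  # rows of the flattened array
moved = np.take(prob.reshape(seq * batch, ch), index, axis=0)
# moved[t, b] == prob[(t + input_length[b]) % seq, b]; sample 1's two valid
# frames now sit at t = 2 and t = 3, i.e. at the back of the time axis.

Rotating again with -input_length (as calc_trans does before returning) restores the original time order; its companion _move_label_to_back builds the corresponding reindexing of the label path for the backward pass.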


class ConnectionistTemporalClassification(function.Function):
@@ -52,13 +63,13 @@ def __init__(self, blank_symbol):
self.zero_padding = -10000000000.0

def check_type_forward(self, in_types):
type_check.expect(in_types.size() > 1)
l_type = in_types[0]
type_check.expect(in_types.size() > 3) # TODO(okuta): > 3?
l_type = in_types[2]
type_check.expect(l_type.dtype == numpy.int32)

x_basetype = in_types[1]
x_basetype = in_types[3] # TODO(okuta): Check x_basetype size

for i in six.moves.range(2, len(in_types)):
for i in six.moves.range(3, len(in_types)):
x_type = in_types[i]
type_check.expect(
x_type.dtype == numpy.float32,
@@ -74,110 +85,147 @@ def log_matrix(self, x, xp):
'y = x == 0 ? e : log(x)',
'create_recurrence_relation')
res = create_recurrence_relation(x, self.zero_padding)
return res
return res.astype(numpy.float32)

def recurrence_relation(self, size, dtype, xp):
def recurrence_relation(self, path_length, max_length, dtype, xp):
"""Transition in forword and backword algorithms is represented as matrix.
See also
https://blog.wtf.sg/2014/10/06/connectionist-temporal-classification-ctc-with-theano/
"""

rr = (xp.eye(size, dtype=dtype) +
xp.eye(size, k=1, dtype=dtype) +
xp.eye(size, k=2, dtype=dtype) *
(xp.arange(size, dtype=dtype) % dtype(2)))
return self.log_matrix(rr, xp)
rr = (xp.eye(max_length, dtype=dtype) +
xp.eye(max_length, k=1, dtype=dtype) +
xp.eye(max_length, k=2, dtype=dtype) *
(xp.arange(max_length, dtype=dtype) % dtype(2)))
return self.log_matrix(
rr * (path_length[:, None] > xp.arange(max_length))[..., None], xp)
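
For a concrete picture of what this method now builds, here is a hedged numpy sketch (invented lengths) of the batched transition matrix, stay, advance by one, or skip to the next label position, with rows beyond each sample's own path length zeroed out before log_matrix maps zeros to the log-scale padding value:

import numpy as np

max_length = 5                               # longest padded path in the batch
path_length = np.array([5, 3], dtype=np.int32)

rr = (np.eye(max_length, dtype='f') +
      np.eye(max_length, k=1, dtype='f') +
      np.eye(max_length, k=2, dtype='f') *
      (np.arange(max_length, dtype='f') % 2))
# rr[i, j] == 1 iff the transition i -> j is allowed: stay (j == i),
# advance (j == i + 1), or skip the blank between two labels
# (j == i + 2, with j an odd, i.e. label, position in the path).

batched = rr * (path_length[:, None] > np.arange(max_length))[..., None]
# batched[1] has rows 3 and 4 zeroed, so probability mass sitting in the
# padded tail of the shorter sample's path cannot propagate any further.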

# path probability to label probability
def label_probability(self, label_size, path, multiply, xp):
def label_probability(self, label_size, path, path_length,
multiply_seq, xp):
labels_prob = self.log_matrix(xp.zeros((len(path), label_size),
dtype=multiply.dtype), xp)
dtype=multiply_seq.dtype), xp)
ret = xp.empty(
(len(multiply_seq),) + labels_prob.shape, dtype=labels_prob.dtype)
ret[...] = labels_prob
if xp == numpy:
for b in six.moves.range(len(path)):
chars = {c for c in path[b]}
target_path = path[b][0:path_length[b]]
chars = {c for c in target_path}
for c in chars:
labels_prob[b, c] = _logsumexp(
multiply[b, path[b] == c], numpy)
ret[:, b, c] = _logsumexp(
multiply_seq[:, b, 0:path_length[b]]
[:, target_path == c], numpy, axis=1)
else:
cuda.cupy.ElementwiseKernel(
'raw T x, raw I y, I b_max, I c_max',
'T z',
'''
T value = z;
I c = i % b_max, b = i / b_max;
int ind[2] = {b, -1};
for (int index = 0; index < c_max; ++index) {
ind[1] = index;
if (y[ind] == c) {
T xvalue = x[ind];
if (value > xvalue) {
value = value + log(1 + exp(xvalue - value));
} else {
value = xvalue + log(1 + exp(value - xvalue));
for i, multiply in enumerate(multiply_seq):
# TODO(okuta): remove loop
cuda.cupy.ElementwiseKernel(
'raw T x, raw I y, raw I l, I b_max, I c_max',
'T z',
'''
T value = z;
I c = i % b_max, b = i / b_max;
int ind[2] = {b, -1};
for (int index = 0; index < c_max; ++index) {
ind[1] = index;
if (ind[1] < l[ind[0]] && y[ind] == c) {
T xvalue = x[ind];
T at = xvalue, bt = value;
if (value > xvalue) {
at = value;
bt = xvalue;
}
value = at + log(1 + exp(bt - at));
}
}
z = value;
}
''',
'reduce_probability')(multiply, path, labels_prob.shape[1],
path.shape[1], labels_prob)
return labels_prob
''',
'reduce_probability')(multiply, path, path_length,
labels_prob.shape[1],
path.shape[1], ret[i])
return ret

def calc_trans(self, path, yseq, rr, xp):
def calc_trans(self, path, yseq, xp):
forward_prob = self.log_matrix(
xp.eye(path.shape[1], dtype='f')[0], xp)[None, :]
backward_prob = forward_prob
offset = xp.arange(
0, yseq[0].size, yseq[0].shape[1], dtype=path.dtype)[:, None]

# prob[i] := forward[i] + backward[-i-1]
prob = []
index = offset + path
for y in yseq:
frr = self.recurrence_relation(
self.path_length, path.shape[1], numpy.float32, xp)
prob = xp.empty(
(len(yseq),) + index.shape, dtype=forward_prob.dtype)
# forward computation.
for i, y in enumerate(yseq):
# calc forward probability in log scale
forward_prob = xp.take(y, index) + _log_dot(
forward_prob[:, None, :], rr, xp)
prob.append(forward_prob)

r_index = offset + path[:, ::-1]
for i, y_inv in enumerate(yseq[::-1]):
forward_prob[:, None, :], frr, xp)
prob[i] = forward_prob
r_index = offset + _move_label_to_back(path, self.path_length, xp)

# rotate yseq with input_length
yseq_inv = _move_inputs(yseq, self.input_length, xp)[::-1]
brr = self.recurrence_relation(
self.path_length, path.shape[1], numpy.float32, xp)

# move to back.
prob = _move_inputs(prob, self.input_length, xp)

# backward computation.
ps1 = path.shape[1]
backward_prob_index = (
xp.arange(0, path.size, ps1, dtype=numpy.int32)[:, None]
+ (xp.arange(ps1) - self.path_length[:, None]) % ps1)
for i, y_inv in enumerate(yseq_inv):
# calc backward probability
backward_prob = _log_dot(backward_prob[:, None, :], rr, xp)
prob[-i - 1] += backward_prob[:, ::-1]
backward_prob = _log_dot(backward_prob[:, None, :], brr, xp)
prob[-i - 1] += xp.take(
backward_prob[:, ::-1], backward_prob_index)
backward_prob = xp.take(y_inv, r_index) + backward_prob
return prob

# move to front.
return _move_inputs(prob, -self.input_length, xp)

def forward(self, inputs):
xp = cuda.get_array_module(inputs[0])
batch_size = len(inputs[0])
self.yseq = _activate(inputs[1::], xp)
log_yseq = [self.log_matrix(y, xp) for y in self.yseq]
self.path = _label_to_path(inputs[0], self.blank_symbol, xp)
rr = self.recurrence_relation(
self.path.shape[1], numpy.float32, xp)[None, :, :]
self.prob_trans = self.calc_trans(self.path, log_yseq, rr, xp)
self.input_length = inputs[0]

# The length of path is (2 * label_length + 1)
self.path_length = 2 * inputs[1] + 1

batch_size = len(inputs[2])
yseq_shape = (len(inputs) - 3,) + inputs[3].shape
self.yseq = _softmax(xp.vstack(inputs[3::]).reshape(yseq_shape), xp)
log_yseq = self.log_matrix(self.yseq, xp)
self.path = _label_to_path(inputs[2], self.blank_symbol, xp)
self.prob_trans = self.calc_trans(self.path, log_yseq, xp)

loss = utils.force_array(xp.sum(
_logsumexp(self.prob_trans[-1], xp, axis=1)))
_logsumexp(self.prob_trans[0], xp, axis=1)))
loss /= -batch_size
return loss,

def backward(self, inputs, grad_output):
xp = cuda.get_array_module(inputs[0])
batch_size = len(inputs[0])
batch_size = len(inputs[2])

total_probability = _logsumexp(self.prob_trans[0], xp, axis=1)
scale = grad_output[0] / batch_size
for y, prob in zip(self.yseq, self.prob_trans):
label_prob = self.label_probability(
y.shape[1], self.path, prob, xp)
y -= xp.exp(label_prob - total_probability[:, None])
y *= scale
return (None,) + tuple(self.yseq)


def connectionist_temporal_classification(x, t, blank_symbol):
label_prob = self.label_probability(
self.yseq.shape[2], self.path, self.path_length,
self.prob_trans, xp)
self.yseq -= xp.exp(label_prob - total_probability[:, None])
self.yseq *= grad_output[0] / batch_size
# mask
self.yseq *= (
xp.arange(len(self.yseq))[:, None] < self.input_length)[..., None]
return (None, None, None) + tuple([y for y in self.yseq])


def connectionist_temporal_classification(
x, t, blank_symbol, input_length=None, label_length=None):
"""Connectionist Temporal Classification loss function.
Connectionist Temporal Classification (CTC) [Graves2006]_ is a loss function
@@ -191,7 +239,13 @@ def connectionist_temporal_classification(x, t, blank_symbol):
``i``.
t (Variable): Expected label sequence.
blank_symbol (int): Index of blank_symbol.
This value must be non-negative.
input_length (Variable): Length of the valid sequence for each sample in the
mini-batch ``x`` (optional). If ``input_length`` is omitted, every time step
of ``x`` is treated as valid input.
label_length (Variable): Length of the valid label sequence for each sample in
the mini-batch ``t`` (optional). If ``label_length`` is omitted, every entry
of ``t`` is treated as valid input.
Returns:
Variable: A variable holding a scalar value of the CTC loss.
@@ -230,4 +284,21 @@ def connectionist_temporal_classification(x, t, blank_symbol):
# TODO(jnishi): Support d(>1)-dimensional inputs.
assert(len(x[0].data.shape) == 2)

return ConnectionistTemporalClassification(blank_symbol)(t, *x)
if input_length is None:
xp = cuda.get_array_module(x[0].data)
input_length = chainer.Variable(
xp.full((len(x[0].data),), len(x), dtype=numpy.int32))
label_length = chainer.Variable(
xp.full((len(t.data),), len(t.data[0]), dtype=numpy.int32))

# Batch size check.
assert len(x[0].data) == len(t.data)
assert len(x[0].data) == len(input_length.data)
assert len(x[0].data) == len(label_length.data)

# Length check.
assert len(x) >= max(input_length.data)
assert len(t.data[0]) >= max(label_length.data)

return ConnectionistTemporalClassification(blank_symbol)(
input_length, label_length, t, *x)
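
A hedged usage sketch of the extended signature (shapes and values are invented, and it assumes a Chainer build containing this change): a mini-batch of two sequences padded to five time steps, with label sequences padded to length three.

import numpy as np
import chainer
import chainer.functions as F

# 5 time steps, batch of 2, 4 classes; index 0 is the blank symbol.
xs = [chainer.Variable(np.random.randn(2, 4).astype(np.float32))
      for _ in range(5)]
# Sample 1's last label entry is padding and is ignored via label_length.
t = chainer.Variable(np.array([[1, 2, 3],
                               [2, 3, 1]], dtype=np.int32))
input_length = chainer.Variable(np.array([5, 3], dtype=np.int32))
label_length = chainer.Variable(np.array([3, 2], dtype=np.int32))

loss = F.connectionist_temporal_classification(
    xs, t, blank_symbol=0,
    input_length=input_length, label_length=label_length)

When input_length and label_length are omitted, the function builds full-length defaults itself, so existing callers keep the old behaviour of treating every time step of x and every entry of t as valid.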
