# black_out.py
from chainer.functions.array import broadcast
from chainer.functions.array import concat
from chainer.functions.array import expand_dims
from chainer.functions.array import reshape
from chainer.functions.connection import embed_id
from chainer.functions.math import average
from chainer.functions.math import exponential
from chainer.functions.math import logsumexp
from chainer.functions.math import matmul
from chainer.functions.math import sum as _sum
def black_out(x, t, W, samples, reduce='mean'):
    """BlackOut loss function.

    BlackOut loss function is defined as

    .. math::

      -\\log(p(t)) - \\sum_{s \\in S} \\log(1 - p(s)),

    where :math:`t` is the correct label, :math:`S` is a set of negative
    examples and :math:`p(\\cdot)` is likelihood of a given label.
    And, :math:`p` is defined as

    .. math::

       p(y) = \\frac{\\exp(W_y^\\top x)}{
       \\sum_{s \\in samples} \\exp(W_s^\\top x)}.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the
    elementwise loss values. If it is ``'mean'``, this function takes
    a mean of loss values.

    Args:
        x (~chainer.Variable): Batch of input vectors.
            Its shape should be :math:`(N, D)`.
        t (~chainer.Variable): Vector of ground truth labels.
            Its shape should be :math:`(N,)`. Each element :math:`v`
            should satisfy :math:`0 \\leq v < V` or :math:`-1`
            where :math:`V` is the number of label types.
        W (~chainer.Variable): Weight matrix.
            Its shape should be :math:`(V, D)`
        samples (~chainer.Variable): Negative samples.
            Its shape should be :math:`(N, S)` where :math:`S` is
            the number of negative samples.
        reduce (str): Reduction option. Its value must be either
            ``'no'`` or ``'mean'``. Otherwise,
            :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable object holding loss value(s).
            If ``reduce`` is ``'no'``, the output variable holds an
            array whose shape is :math:`(N,)` .
            If it is ``'mean'``, it holds a scalar.

    Raises:
        ValueError: If ``reduce`` is neither ``'no'`` nor ``'mean'``.

    See: `BlackOut: Speeding up Recurrent Neural Network Language Models With \
Very Large Vocabularies <https://arxiv.org/abs/1511.06909>`_

    .. seealso:: :class:`~chainer.links.BlackOut`.

    """
    # Enforce the documented contract; previously any value other than
    # 'mean' was silently treated as 'no'.
    if reduce not in ('mean', 'no'):
        raise ValueError(
            "only 'mean' and 'no' are valid for 'reduce', but '%s' is "
            'given' % reduce)

    batch_size = x.shape[0]
    # Scores of the negative samples: W_s^T x for each sample -> shape (N, S).
    neg_emb = embed_id.embed_id(samples, W)
    neg_y = matmul.matmul(neg_emb, x[:, :, None])
    neg_y = reshape.reshape(neg_y, neg_y.shape[:-1])
    # Score of the ground-truth label: W_t^T x -> shape (N, 1).
    pos_emb = expand_dims.expand_dims(embed_id.embed_id(t, W), 1)
    pos_y = matmul.matmul(pos_emb, x[:, :, None])
    pos_y = reshape.reshape(pos_y, pos_y.shape[:-1])
    # Log partition function over {t} plus the negative samples -> shape (N,).
    logz = logsumexp.logsumexp(concat.concat([pos_y, neg_y]), axis=1)
    blogz, bneg_y = broadcast.broadcast(
        reshape.reshape(logz, (batch_size, 1)), neg_y)
    # log(1 - p(s)) for each negative sample s.
    ny = exponential.log(1 - exponential.exp(bneg_y - blogz))
    # Unnormalized log-score of the correct label; py - logz == log(p(t)).
    py = reshape.reshape(pos_y, (batch_size,))
    loss = -(py - logz + _sum.sum(ny, axis=1))
    if reduce == 'mean':
        loss = average.average(loss)
    return loss