-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
crf1d.py
53 lines (38 loc) · 1.68 KB
/
crf1d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from chainer.functions.loss import crf1d
from chainer import initializers
from chainer import link
from chainer import variable
class CRF1d(link.Link):

    """Linear-chain conditional random field loss layer.

    A thin link around :func:`~chainer.functions.crf1d` whose only
    parameter is the square transition cost matrix between labels.

    Args:
        n_label (int): Number of labels. The transition cost matrix is
            created with shape ``(n_label, n_label)``.
        initial_cost (:ref:`initializer <initializer>`): Initializer used
            to fill the transition cost matrix. When it is ``None``
            (the default), the matrix is initialized with zeros.

    .. seealso:: :func:`~chainer.functions.crf1d` for more detail.

    Attributes:
        cost (~chainer.Variable): Transition cost parameter.

    """

    def __init__(self, n_label, initial_cost=None):
        super(CRF1d, self).__init__()

        # Fall back to an all-zero matrix when no initializer is supplied.
        cost_initializer = (
            initializers.constant.Zero() if initial_cost is None
            else initial_cost)
        transition_shape = (n_label, n_label)

        # Parameters must be assigned inside ``init_scope`` so that the
        # link registers them.
        with self.init_scope():
            self.cost = variable.Parameter(
                initializer=cost_initializer, shape=transition_shape)

    def forward(self, xs, ys, reduce='mean'):
        """Computes the CRF loss using this link's transition costs.

        The call is delegated to :func:`~chainer.functions.crf1d` with
        :attr:`cost` as the transition cost matrix.

        Args:
            xs: Input sequence, passed through unchanged to
                :func:`~chainer.functions.crf1d`.
            ys: Label sequence, passed through unchanged to
                :func:`~chainer.functions.crf1d`.
            reduce (str): Reduction option, forwarded as-is. Defaults to
                ``'mean'``; see :func:`~chainer.functions.crf1d` for the
                accepted values.

        Returns:
            The value returned by :func:`~chainer.functions.crf1d`.

        """
        loss = crf1d.crf1d(self.cost, xs, ys, reduce)
        return loss

    def argmax(self, xs):
        """Computes a state that maximizes a joint probability.

        Args:
            xs (list of Variable): Input vector for each label.

        Returns:
            tuple: A tuple of :class:`~chainer.Variable` representing each
            log-likelihood and a list representing the argmax path.

        .. seealso:: See :func:`~chainer.functions.argmax_crf1d` for more
           detail.

        """
        best = crf1d.argmax_crf1d(self.cost, xs)
        return best