In [None]:
!pip install pytorch-crf

In [103]:
import torch
from torch import nn
from torch import optim
from torchcrf import CRF

In [24]:
import numpy

In [259]:
def sigmoid(a):
    return 1./(1. + numpy.exp(-a))

In [262]:
''' generate data '''
transition = numpy.array([[0.8, 0.2], [0.1, 0.9]])

N = 50
L = 50
st_seqs = []
obs_seqs = []

for n in range(N):
    init = int(numpy.random.rand() > 0.5)
    s = []
    o = []
    for t in range(L):
        tp = transition[init][0]
        s.append(int(numpy.random.rand() > tp))
        op = sigmoid((2. * s[-1] - 1.) + 0.3 * numpy.random.randn())
        iop = 1. - op
        if s == 0:
            o.append([op, iop])
        else:
            o.append([iop, op])
    st_seqs.append(s)
    obs_seqs.append(o)
    
st_seqs = torch.from_numpy(numpy.array(st_seqs))
obs_seqs = torch.from_numpy(numpy.log(numpy.array(obs_seqs)))

In [267]:
model = CRF(2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [264]:
loss = None
disp_inter = 100

for ei in range(5000):
    optimizer.zero_grad()
    loss_ = -model(obs_seqs, st_seqs)
    loss_.backward()
    optimizer.step()
    
    if loss is None:
        loss = loss_.item()
    else:
        loss = 0.9 * loss + 0.1 * loss_.item()
        
    if numpy.mod(ei, disp_inter) == 0:
        print(F'loss {loss} at update {ei+1}')

loss 30.995348678424037 at update 1
loss 26.11239015826308 at update 101
loss 23.028353541460493 at update 201
loss 21.434001959671964 at update 301
loss 20.63814698263612 at update 401
loss 20.239464190448686 at update 501
loss 20.034024937085633 at update 601
loss 19.924984346355533 at update 701
loss 19.86593551525732 at update 801
loss 19.833452191134437 at update 901
loss 19.815233843849377 at update 1001
loss 19.80476783415728 at update 1101
loss 19.798618443412316 at update 1201
loss 19.794958430289736 at update 1301
loss 19.792782779533358 at update 1401
loss 19.79150978868243 at update 1501
loss 19.79078475966418 at update 1601
loss 19.790386144161072 at update 1701
loss 19.790175743860825 at update 1801
loss 19.790069472792123 at update 1901
loss 19.790018352021466 at update 2001
loss 19.789994945698922 at update 2101
loss 19.78998466662761 at update 2201
loss 19.789980392314583 at update 2301
loss 19.789978704492857 at update 2401
loss 19.7899782624431 at update 2501
loss 19

In [265]:
numpy.array(model.decode(obs_seqs[1].unsqueeze(0))).T[:20]

array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 0, 1]])

In [266]:
st_seqs[1]

tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0,
        0, 0])

In [255]:
def softmax(x):
    return numpy.exp(x) / numpy.exp(x).sum(-1, keepdims=True)

In [256]:
softmax(model.transitions.data.numpy())

array([[0.42429057, 0.5757094 ],
       [0.6192078 , 0.38079214]], dtype=float32)

In [257]:
model.transitions

Parameter containing:
tensor([[ 0.2340,  0.5392],
        [ 0.0278, -0.4584]], requires_grad=True)