In [None]:
!pip install pytorch-crf

In [103]:
import torch
from torch import nn
from torch import optim
from torchcrf import CRF

In [24]:
import numpy

In [90]:
class BinarySmoother(nn.Module):
    def __init__(self):
        super(BinarySmoother, self).__init__()
        
        self.pairwise_pot = nn.Parameter(torch.randn(1,2,2))
        
        
    def forward(self, unary_pot):
        potentials = unary_pot.unsqueeze(2) + self.pairwise_pot
#         print(potentials.size())
        crf = LinearChainCRF(potentials)
        return crf

In [105]:
def sigmoid(a):
    return 1./(1. + numpy.exp(-a))

In [188]:
''' generate data '''
transition = numpy.array([[0.8, 0.2], [0.1, 0.9]])

N = 100
L = 50
st_seqs = []
obs_seqs = []

for n in range(N):
    init = int(numpy.random.rand() > 0.5)
    s = []
    o = []
    for t in range(L):
        tp = transition[init][0]
        s.append(int(numpy.random.rand() > tp))
        op = sigmoid((2. * s[-1] - 1.) * 2. + 0.1 * numpy.random.randn())
        iop = 1. - op
        if s == 0:
            o.append([op, iop])
        else:
            o.append([iop, op])
    st_seqs.append(s)
    obs_seqs.append(o)
    
st_seqs = torch.from_numpy(numpy.array(st_seqs))
obs_seqs = torch.from_numpy(numpy.array(obs_seqs))

In [206]:
model = CRF(2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [207]:
loss = None
disp_inter = 100

for ei in range(5000):
    optimizer.zero_grad()
    loss_ = -model(obs_seqs, st_seqs)
    loss_.backward()
    optimizer.step()
    
    if loss is None:
        loss = loss_.item()
    else:
        loss = 0.9 * loss + 0.1 * loss_.item()
        
    if numpy.mod(ei, disp_inter) == 0:
        print(F'loss {loss} at update {ei+1}')

loss 1917.6931171738486 at update 1
loss 1909.2042420170983 at update 101
loss 1905.467456630413 at update 201
loss 1902.56306218532 at update 301
loss 1900.3066534927855 at update 401
loss 1898.5524277859872 at update 501
loss 1897.1895918313498 at update 601
loss 1896.1288628318264 at update 701
loss 1895.3023763317246 at update 801
loss 1894.6589484288222 at update 901
loss 1894.1574130653146 at update 1001
loss 1893.7666054018673 at update 1101


KeyboardInterrupt: 

In [211]:
numpy.array(model.decode(obs_seqs[4].unsqueeze(0))).T[:20]

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1]])

In [212]:
st_seqs[4]

tensor([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 0])

In [213]:
obs_seqs[4,:,0]

tensor([0.8626, 0.8761, 0.8530, 0.8868, 0.1139, 0.8784, 0.8756, 0.1199, 0.8779,
        0.8809, 0.8895, 0.1176, 0.8796, 0.8964, 0.8927, 0.8662, 0.8771, 0.8880,
        0.8852, 0.8814, 0.1128, 0.8739, 0.1152, 0.8911, 0.8756, 0.8938, 0.1180,
        0.8815, 0.8735, 0.8780, 0.8833, 0.8854, 0.8825, 0.8790, 0.8775, 0.9039,
        0.8804, 0.8830, 0.8728, 0.8876, 0.8842, 0.8653, 0.8813, 0.8599, 0.8883,
        0.8779, 0.8888, 0.1238, 0.1216, 0.8751], dtype=torch.float64)