In [2]:
import torch

In [16]:
class CRF(torch.nn.Module):
    def __init__(self, num_tags):
        super(CRF, self).__init__()
        self.num_tags = num_tags + 2
        self.start = self.num_tags-2
        self.end = self.start+1
        self.transitions = torch.nn.Parameter(torch.randn(self.num_tags, self.num_tags))
    
    def forward_score(self,features):
        scores = torch.ones(features.shape[0], self.num_tags) * -6969
        scores[:,self.start] = 0
        for i in range(features.shape[1]):
            feat = features[:,i]
            score = scores.unsqueeze(1) + feat.unsqueeze(2) + self.transitions.unsqueeze(0)
            scores = torch.logsumexp(score, dim=-1)
        scores = scores + self.transitions[self.end]
        return torch.logsumexp(scores, dim=-1)
    
    def score_sentence(self,features,tags):
        scores = features.gather(2, tags.unsqueeze(2)).squeeze(2)
        start = torch.ones(features.shape[0],1,dtype=torch.long) * self.start
        tags = torch.cat([start,tags],dim=1)
        trans_scores = self.transitions[tags[:,:-1],tags[:,1:]]
        last_tag = tags.gather(1,torch.ones(features.shape[0],1,dtype=torch.long) * features.shape[1])
        last_scores = self.transitions[self.end,last_tag]
        return (trans_scores + scores).sum(dim=1) + last_scores
    
    def viterbi_decode(self,features):
        scores = torch.ones(features.shape[0], self.num_tags) * -6969
        ptrs = torch.zeros_like(features)
        scores[:,self.start] = 0
        for i in range(features.shape[1]):
            feat = features[:,i]
            score = scores.unsqueeze(1) + self.transitions

            score, ptrs[:,i,:] = score.max(dim=-1)
            score += feat
            scores = score

        scores += self.transitions[self.end]
        scores, idx = scores.max(dim=-1)
        best_paths = []
        ptrs = ptrs.cpu().numpy()
        for i in range(features.shape[0]):
            bt = idx[i].item()
            best_path = [bt]
            for ptr in reversed(ptrs[i]):
                bt = int(ptr[bt])
                best_path.append(bt)
            best_path.pop()
            best_paths.append(best_path[::-1])
        return scores, best_paths
    
    def forward(self,features):
        return self.viterbi_decode(features)

    def loss(self,features,tags):
        forward_score = self.forward_score(features)
        gold_score = self.score_sentence(features,tags.long())
        return (forward_score - gold_score).mean()

In [17]:
crf = CRF(10)
a = torch.randn(2,40,12)
b = torch.randint(0,10,(2,40))
c = crf.loss(a,b)
print(c)

torch.Size([2, 40]) torch.Size([2, 40]) torch.Size([2, 1])
tensor(153.3246, grad_fn=<MeanBackward0>)
