In [1]:
import numpy as np
import torch
import copy
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torch.utils.data
import gc
import csv

In [2]:
def log_sum_exp(input, keepdim=False):
    """Numerically stable ``log(sum(exp(input)))`` over the last dim of a 2-D tensor.

    The row-wise maximum is subtracted before exponentiating and added back
    afterwards so large scores do not overflow.

    Args:
        input: 2-D tensor of shape ``(rows, cols)``.
        keepdim: if True the result has shape ``(rows, 1)``; otherwise ``(rows,)``.

    Returns:
        Tensor holding one log-sum-exp value per row of ``input``.
    """
    assert input.dim() == 2
    max_scores, _ = input.max(dim=-1, keepdim=True)
    shifted = input - max_scores
    # Reduce with keepdim=True so the sum lines up with ``max_scores`` (rows, 1).
    # Adding a (rows,) vector to a (rows, 1) column would broadcast to a
    # (rows, rows) matrix, which is what happened for keepdim=False before.
    result = max_scores + torch.log(
        torch.sum(torch.exp(shifted), dim=-1, keepdim=True))
    return result if keepdim else result.squeeze(-1)


def gather_index(input, index):
    """Pick one element per row of a 2-D tensor.

    Args:
        input: 2-D tensor of shape ``(rows, cols)``.
        index: 1-D integer tensor giving the column to read in each row.

    Returns:
        1-D tensor whose ``r``-th entry is ``input[r, index[r]]``.
    """
    assert input.dim() == 2 and index.dim() == 1
    # Repeat each row's column index across the whole row so it matches
    # ``input``'s shape; every column of the gathered result is then identical.
    column_index = index[:, None].expand_as(input)
    picked = input.gather(1, column_index)
    return picked.select(1, 0)

In [3]:
class CRF(nn.Module):
    """Batched linear-chain CRF layer over per-token emission scores.

    ``transitions[i, j]`` is the learnable score of moving *to* tag ``i``
    *from* tag ``j``. Transitions into ``START_TAG`` and out of ``STOP_TAG``
    are initialised to ``IMPOSSIBLE`` so they are effectively never taken.

    NOTE: every method processes all ``sent_len`` time steps of every batch
    row; no padding mask is applied.

    Args:
        label_size: number of tags; the indices of ``START_TAG`` and
            ``STOP_TAG`` in ``tag_to_ix`` must be smaller than this.
        tag_to_ix: mapping that contains the ``START_TAG`` and ``STOP_TAG``
            keys (both names are module-level constants of this notebook).
        use_cuda: retained for backward compatibility. Work tensors are now
            allocated on the device of the inputs, so this flag only selects
            the value of the legacy ``self.torch`` attribute.
    """

    # Score used to make a transition practically impossible.
    IMPOSSIBLE = -10000.0

    def __init__(self, label_size, tag_to_ix, use_cuda=False):
        super().__init__()
        self.label_size = label_size
        # Resolve the special indices once so the methods do not depend on a
        # module-level ``tag_to_ix`` that may be missing or different.
        self.start_idx = tag_to_ix[START_TAG]
        self.stop_idx = tag_to_ix[STOP_TAG]

        transitions = torch.randn(label_size, label_size)
        transitions[self.start_idx, :] = self.IMPOSSIBLE  # never move *to* START
        transitions[:, self.stop_idx] = self.IMPOSSIBLE   # never move *from* STOP
        self.transitions = nn.Parameter(transitions)

        # Legacy attribute; no longer used internally.
        self.torch = torch.cuda if use_cuda else torch

    def _score_sentence(self, feats, tags):
        """Score the given tag sequences.

        Args:
            feats: emission scores of shape ``(bsz, sent_len, label_size)``.
            tags: integer tag indices of shape ``(bsz, sent_len)``.

        Returns:
            Tensor of shape ``(bsz,)``: sum of emission and transition scores
            along each path, including START -> first tag and last tag -> STOP.
        """
        bsz, sent_len, _ = feats.size()
        # Prepend START so that tags[:, i] is the predecessor of tags[:, i + 1].
        start_col = tags.new_full((bsz, 1), self.start_idx)
        tags = torch.cat([start_col, tags], dim=-1)

        score = feats.new_zeros(bsz)
        for i, feat in enumerate(feats.transpose(0, 1)):
            prev_tags, next_tags = tags[:, i], tags[:, i + 1]
            trans_score = self.transitions[next_tags, prev_tags]
            emit_score = gather_index(feat, next_tags)
            score = score + trans_score + emit_score

        # Close every path with the transition into STOP.
        return score + self.transitions[self.stop_idx, tags[:, -1]]

    def forward(self, feats):
        """Compute the log partition function with the forward algorithm.

        Args:
            feats: emission scores of shape ``(bsz, sent_len, label_size)``.

        Returns:
            Tensor of shape ``(bsz,)`` with the log-sum-exp of the scores of
            all tag paths for each batch row.
        """
        bsz, sent_len, _ = feats.size()
        forward_var = feats.new_full((bsz, self.label_size), self.IMPOSSIBLE)
        forward_var[:, self.start_idx] = 0.0  # all mass starts on START

        trans = self.transitions.unsqueeze(0)  # (1, next, prev)
        for feat in feats.transpose(0, 1):
            # scores[b, next, prev] = alpha[b, prev] + trans[next, prev] + emit[b, next]
            scores = forward_var.unsqueeze(1) + trans + feat.unsqueeze(2)
            # log_sum_exp only accepts 2-D input: fold (bsz, next) into rows.
            forward_var = log_sum_exp(
                scores.view(-1, self.label_size), True
            ).view(bsz, self.label_size)

        terminal_var = forward_var + self.transitions[self.stop_idx].view(1, -1)
        # keepdim=True followed by squeeze gives (bsz,) regardless of how the
        # helper treats keepdim=False.
        return log_sum_exp(terminal_var, True).squeeze(-1)

    def viterbi_decode(self, feats):
        """Find the highest-scoring tag path for each batch row.

        Args:
            feats: emission scores of shape ``(bsz, sent_len, label_size)``.

        Returns:
            Integer tensor of shape ``(bsz, sent_len)`` with the best tag
            index at every position.
        """
        bsz, sent_len, _ = feats.size()
        if sent_len == 0:
            # Nothing to decode; avoid concatenating an empty list below.
            return torch.empty((bsz, 0), dtype=torch.long, device=feats.device)

        # Only tag indices are returned, so no autograd graph is needed.
        with torch.no_grad():
            forward_var = feats.new_full(
                (bsz, self.label_size), self.IMPOSSIBLE)
            forward_var[:, self.start_idx] = 0.0

            trans = self.transitions.unsqueeze(0)  # (1, next, prev)
            backpointers = []
            for feat in feats.transpose(0, 1):
                # For every next tag take the best predecessor; the emission
                # does not depend on the predecessor, so add it afterwards.
                best_scores, best_prev = (
                    forward_var.unsqueeze(1) + trans).max(dim=2)
                forward_var = best_scores + feat
                backpointers.append(best_prev)

            terminal_var = forward_var + self.transitions[self.stop_idx].view(1, -1)
            _, best_tag_ids = terminal_var.max(dim=1)

            # Walk the backpointers from the last step to the first.
            best_path = [best_tag_ids.unsqueeze(1)]
            for bptrs_t in reversed(backpointers):
                best_tag_ids = gather_index(bptrs_t, best_tag_ids)
                best_path.append(best_tag_ids.unsqueeze(1))

            best_path.pop()  # drop the predecessor of the first token (START)
            best_path.reverse()
            return torch.cat(best_path, dim=-1)
