In [1]:
import torch
import torch.nn as nn

# LSTM-CRF forward model
IMPOSSIBLE = -1e4

class BiLSTM_CRF(nn.Module):
    def __init__(
        self, vocab_size, num_tags, start_tag, stop_tag,
        embedding_dim, hidden_dim
    ):
        super().__init__()

        self.num_tags = num_tags
        self.START_TAG = start_tag
        self.STOP_TAG = stop_tag
            
        # CRF parameters
        self.transitions = nn.Parameter(torch.randn(self.num_tags, self.num_tags))
            
        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[:, self.START_TAG] = IMPOSSIBLE
        self.transitions.data[self.STOP_TAG, :] = IMPOSSIBLE

        # LSTM parameters
        self.hidden_dim = hidden_dim
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True,
                            batch_first=False)
            
        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.num_tags)
            
        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))
        
    def _get_emissions(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        emissions = self.hidden2tag(lstm_out)
        return emissions
        
    def neg_log_likelihood(self, sentence, tags):
        emissions = self._get_emissions(sentence)
        forward_score = self._forward_alg(emissions)
        gold_score = self._score_sentence(emissions, tags)
        return forward_score - gold_score
        
    def _forward_alg(self, emissions):

        init_alphas = self.transitions[self.START_TAG] + emissions[0]
            
        # Wrap in a variable so that we will get automatic backprop
        alphas = init_alphas

        for emission in emissions[1:]:
            alphas_t = [] # The forward tensors at this timestep
            for next_tag in range(self.num_tags):
                # emission score for the next tag
                emit_score = emission[next_tag].view(1, -1).expand(1, self.num_tags)
                # transition score from any previous tag to the next tag
                trans_score = self.transitions[:, next_tag].view(1, -1)
                # combine current scores with previous alphas 
                # since alphas are in log space (see logsumexp below),
                # we add them instead of multiplying
                next_tag_var = alphas + trans_score + emit_score

                alphas_t.append(torch.logsumexp(next_tag_var, 1).view(1))

            alphas = torch.cat(alphas_t).view(1, -1)

        terminal_alphas = alphas + self.transitions[:, self.STOP_TAG]
        alphas = torch.logsumexp(terminal_alphas, 1)

        return alphas
    
    def _viterbi_decode(self, emissions):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_alphas = self.transitions[self.START_TAG] + emissions[:1]

        # alphas at step i holds the viterbi variables for step i-1
        alphas = init_alphas
        for emission in emissions[1:]:
            bptrs_t = [] # holds the backpointers for this step
            viterbivars_t = [] # holds the viterbi variables for this step

            for next_tag in range(self.num_tags):
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                next_tag_var = alphas + self.transitions[:, next_tag] + emission[next_tag]
                best_tag_score, best_tag_id = torch.max(next_tag_var, dim=-1)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign alphas to the set
            # of viterbi variables we just computed
            alphas = (torch.cat(viterbivars_t)).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_alphas = alphas + self.transitions[:, self.STOP_TAG]
        best_tag_score, best_tag_id = torch.max(terminal_alphas, dim=-1)
        path_score = terminal_alphas[0][best_tag_id]
            
        # Follow the back pointers to decode the best path.
        # Append terminal tag 
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
                
        best_path.reverse()
        best_path = torch.cat(best_path)
            
        return path_score, best_path

    def forward(self, sentence): 
        # Get the emission scores from the BiLSTM
        emissions = self._get_emissions(sentence)
        print(emissions)

        # Find the best path, given the emission scores.
        score, tag_seq = self._viterbi_decode(emissions)
        return score, tag_seq
        
    def _score_sentence(self, emissions, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.START_TAG], dtype=torch.long), tags[0]])
        for i, emission in enumerate(emissions):
            score = score + self.transitions[tags[i], tags[i+1]] + emission[tags[i+1]]
        score = score + self.transitions[tags[-1], self.STOP_TAG]
        return score

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


if __name__ == '__main__':
    
    START_TAG = "<START>"
    STOP_TAG = "<STOP>"
    EMBEDDING_DIM = 5
    HIDDEN_DIM = 4

    training_data = [
        (
            "Google Deepmind company".split(), 
            "B I O".split(),
        )
    ]


    word_to_ix = {START_TAG: 0, STOP_TAG: 1}
    for sentence, tags in training_data:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)

    tag_to_ix = {START_TAG: 0, STOP_TAG: 1, 'B': 2, 'I': 3, 'O': 4}
    print(word_to_ix)

    crf_mod = BiLSTM_CRF(len(word_to_ix), len(tag_to_ix), tag_to_ix[START_TAG], tag_to_ix[STOP_TAG], 
                            embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM)

        
        
    sentence, tags = training_data[0]
    sentence_in = prepare_sequence(sentence, word_to_ix)
    targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)


    torch.manual_seed(1)
    print(sentence_in, targets)
        
    score, tag_seq = crf_mod(sentence_in)
    print(score, tag_seq)

{'<START>': 0, '<STOP>': 1, 'Google': 2, 'Deepmind': 3, 'company': 4}
tensor([2, 3, 4]) tensor([2, 3, 4])
tensor([[ 0.1878,  0.0654,  0.4902,  0.1473, -0.1422],
        [ 0.3721,  0.1898,  0.3060,  0.0743, -0.0744],
        [ 0.2491,  0.0582,  0.4111,  0.1462, -0.1255]],
       grad_fn=<AddmmBackward0>)
tensor([1.7896], grad_fn=<IndexBackward0>) tensor([2, 4, 2])
