In [2]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import random
from transformers import BertTokenizer, BertModel

# Seed every RNG the notebook imports (`random` and torch) so Restart & Run All is reproducible.
SEED = 1
random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the cell output

<torch._C.Generator at 0x25b9facc850>

#### BiLSTM CRF model

In [7]:
class BiLSTM_CRF(nn.Module):
    """BiLSTM encoder followed by a linear-chain CRF, for one (unbatched) sentence at a time.

    Transition convention: ``transitions_to[i, j]`` is the score of moving *from*
    tag ``j`` *to* tag ``i``. Row ``start_tag`` and column ``end_tag`` are pinned to
    -10000 so that no path can move into START or out of END.

    Args:
        vocab_size: number of words (stored only).
        embedding_dim: width of the rows of ``embedding_mat``; the LSTM input size.
        hidden_dim: total BiLSTM output width (each direction gets ``hidden_dim // 2``).
        target_size: number of tags, including the START and END tags.
        embedding_mat: pretrained embedding table for ``nn.Embedding.from_pretrained``
            (frozen, per that call's default ``freeze=True``).
        start_tag: integer index of the START tag.
        end_tag: integer index of the END/STOP tag.
        tag_to_ix: tag -> index mapping (stored for reference).
        batch_size: stored only; this model scores one sentence at a time.
        device: device on which scratch tensors (initial LSTM state, score
            accumulators) are created. NOTE(review): the module's parameters are not
            moved here; call ``.to(device)`` on the model as well.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, target_size, embedding_mat, start_tag, end_tag, tag_to_ix, batch_size=1, device='cpu'):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.target_size = target_size
        self.batch_size = batch_size
        self.device = device
        self.tag_to_ix = tag_to_ix
        # Boundary tag indices are passed explicitly; use them instead of looking up
        # literal key names in tag_to_ix (whose actual keys are caller-defined).
        self.start_tag = start_tag
        self.end_tag = end_tag

        self.embedding = nn.Embedding.from_pretrained(embedding_mat)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, target_size)

        self.transitions_to = nn.Parameter(torch.randn(target_size, target_size))
        self.transitions_to.data[start_tag, :] = -10000  # nothing transitions into START
        self.transitions_to.data[:, end_tag] = -10000    # nothing transitions out of END

        # NOTE(review): transitions_from is never read by the scoring code below; it is
        # kept only so parameter lists / saved state_dicts stay compatible.
        self.transitions_from = nn.Parameter(torch.randn(target_size, target_size))
        self.transitions_from.data[:, start_tag] = -10000
        self.transitions_from.data[end_tag, :] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh random ``(h0, c0)`` pair for the 1-layer BiLSTM with batch size 1."""
        return (torch.randn(2, 1, self.hidden_dim // 2, device=self.device),
                torch.randn(2, 1, self.hidden_dim // 2, device=self.device))

    def get_lstm_features(self, sentence):
        """Return emission scores of shape ``(len(sentence), target_size)``.

        ``sentence`` is presumably a 1-D LongTensor of word indices (as built by
        ``prepare_sequence``) -- TODO confirm for other callers.
        """
        # Start every sentence from a fresh state: feeding back the previous call's
        # output state would leak context across sentences and keep the old autograd
        # graph attached (a second backward() through it raises a RuntimeError).
        self.hidden = self.init_hidden()
        embeds = self.embedding(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        return self.hidden2tag(lstm_out)

    def _forward_algo(self, lstm_features):
        """Return the log partition function ``log Z`` over all tag paths (1-element tensor).

        Args:
            lstm_features: emission scores, shape ``(seq_len, target_size)``.
        """
        # Log-space alphas: only START is a valid predecessor of the first token.
        forward_var = torch.full((1, self.target_size), -10000., device=self.device)
        forward_var[0][self.start_tag] = 0.

        for feat in lstm_features:
            # next_tag_var[i, j] = alpha[j] + score(j -> i) + emission[i]
            next_tag_var = self.transitions_to + feat.view(-1, 1) + forward_var
            # Marginalise over the previous tag j (dim=1). torch.logsumexp is already
            # numerically stable, so no manual max-shift is needed.
            forward_var = torch.logsumexp(next_tag_var, dim=1).view(1, -1)

        # Add the transition from each last tag into END, then sum over that last tag.
        terminal_var = forward_var + self.transitions_to[self.end_tag].view(1, -1)
        return torch.logsumexp(terminal_var, dim=1)

    def _score_sentence(self, lstm_features, tags):
        """Return the unnormalised score of the path START -> tags -> END (shape ``(1,)``).

        Args:
            lstm_features: emission scores, shape ``(seq_len, target_size)``.
            tags: 1-D LongTensor of gold tag indices, one per token.
        """
        score = torch.zeros(1, device=self.device)
        start = torch.tensor([self.start_tag], dtype=torch.long, device=tags.device)
        tags = torch.cat([start, tags])
        for i, feat in enumerate(lstm_features):
            # transition tags[i] -> tags[i + 1], plus the emission of tags[i + 1]
            score = score + self.transitions_to[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        return score + self.transitions_to[self.end_tag, tags[-1]]

    def neg_log_likelihood(self, sentence, tags):
        """Return the CRF loss ``log Z - score(gold path)`` as a 1-element tensor (>= 0)."""
        lstm_feats = self.get_lstm_features(sentence)
        forward_score = self._forward_algo(lstm_feats)
        gold_score = self._score_sentence(lstm_feats, tags)
        return forward_score - gold_score
        

        

In [8]:
def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D LongTensor of their indices in ``to_ix``.

    Raises:
        KeyError: if any token of ``seq`` is missing from ``to_ix``.
    """
    indices = []
    for token in seq:
        indices.append(to_ix[token])
    return torch.tensor(indices, dtype=torch.long)

In [10]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5
HIDDEN_DIM = 4

# Make up some training data: (tokens, BIO tags) pairs.
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]

# Give each distinct word an index in first-seen order.
word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        word_to_ix.setdefault(word, len(word_to_ix))

tag_to_ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}

# Random embedding table; its width must agree with EMBEDDING_DIM (the LSTM input size).
embedding_mat = torch.randn(len(word_to_ix), EMBEDDING_DIM)
model = BiLSTM_CRF(len(word_to_ix), EMBEDDING_DIM, HIDDEN_DIM, len(tag_to_ix), embedding_mat,
                   tag_to_ix[START_TAG], tag_to_ix[STOP_TAG])
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

sentence_in = prepare_sequence(training_data[0][0], word_to_ix)
print(sentence_in.shape)
print(model.get_lstm_features(sentence_in).shape)

torch.Size([11])
torch.Size([11, 1, 5])
torch.Size([11, 1, 4])
torch.Size([11, 4])
torch.Size([11, 5])


In [41]:
# Scratch: a (1, 5) row of -1s with a 0 at index 0, viewed as 5 identical rows
# (shape (1, 5, 5)) to check how `expand` broadcasts a score row.
temp = torch.full((1, 5), -1)
temp[0, 0] = 0
temp[:, None, :].expand(-1, 5, -1)

tensor([[[ 0, -1, -1, -1, -1],
         [ 0, -1, -1, -1, -1],
         [ 0, -1, -1, -1, -1],
         [ 0, -1, -1, -1, -1],
         [ 0, -1, -1, -1, -1]]])

In [42]:
# Scratch: a 5x5 transition matrix with row 0 and column 4 masked to -10000.
transition = nn.Parameter(torch.randn(5, 5))
with torch.no_grad():
    transition[0] = -10000
    transition[:, 4] = -10000
print(transition)

Parameter containing:
tensor([[-1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04],
        [-1.6608e+00, -1.1081e+00,  1.1770e+00, -1.0175e-02, -1.0000e+04],
        [ 5.8099e-01, -1.4000e+00,  2.3599e-01,  4.4388e-01, -1.0000e+04],
        [-1.2586e+00, -6.9026e-01, -7.9675e-01,  3.1644e-01, -1.0000e+04],
        [ 1.4568e+00,  9.1951e-01,  3.5823e-01,  4.0518e-01, -1.0000e+04]],
       requires_grad=True)


In [49]:
# Scratch: add the score row `temp` to every row of `transition`, stacked 5 times -> (5, 5, 5).
transition.expand(5, -1, -1) + temp

tensor([[[-1.0000e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04],
         [-1.6608e+00, -2.1081e+00,  1.7705e-01, -1.0102e+00, -1.0001e+04],
         [ 5.8099e-01, -2.4000e+00, -7.6401e-01, -5.5612e-01, -1.0001e+04],
         [-1.2586e+00, -1.6903e+00, -1.7967e+00, -6.8356e-01, -1.0001e+04],
         [ 1.4568e+00, -8.0491e-02, -6.4177e-01, -5.9482e-01, -1.0001e+04]],

        [[-1.0000e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04],
         [-1.6608e+00, -2.1081e+00,  1.7705e-01, -1.0102e+00, -1.0001e+04],
         [ 5.8099e-01, -2.4000e+00, -7.6401e-01, -5.5612e-01, -1.0001e+04],
         [-1.2586e+00, -1.6903e+00, -1.7967e+00, -6.8356e-01, -1.0001e+04],
         [ 1.4568e+00, -8.0491e-02, -6.4177e-01, -5.9482e-01, -1.0001e+04]],

        [[-1.0000e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04, -1.0001e+04],
         [-1.6608e+00, -2.1081e+00,  1.7705e-01, -1.0102e+00, -1.0001e+04],
         [ 5.8099e-01, -2.4000e+00, -7.6401e-01, -5.5612e-01, -1.0001e+04],
        

In [52]:
# Scratch: a random 5-vector repeated across 5 columns (row i holds feat[i]).
feat = torch.randn(5)
print(feat)

feat.unsqueeze(1).expand(5, 5)

tensor([ 1.0919,  0.5058,  1.3415,  1.5653, -1.1470])


tensor([[ 1.0919,  1.0919,  1.0919,  1.0919,  1.0919],
        [ 0.5058,  0.5058,  0.5058,  0.5058,  0.5058],
        [ 1.3415,  1.3415,  1.3415,  1.3415,  1.3415],
        [ 1.5653,  1.5653,  1.5653,  1.5653,  1.5653],
        [-1.1470, -1.1470, -1.1470, -1.1470, -1.1470]])