In [None]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from transformers import BertConfig, BertModel, BertTokenizer
torch.manual_seed(1)
%matplotlib inline

In [None]:
def show_bert_doctrine(model_name="bert-base-cased", text="I am a boy"):
    """Print what the BERT tokenizer and encoder produce for one sample sentence.

    Prints every tensor returned by the tokenizer, the word-piece tokens behind
    ``input_ids``, and the shape of BERT's first output (the last hidden layer).

    Args:
        model_name: Hugging Face checkpoint to load. The default matches the
            ``BERT_MODEL_NAME`` used elsewhere in this notebook. It is spelled out
            here so this cell does not depend on a constant defined in a later cell.
        text: Sentence to encode.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    bert = BertModel.from_pretrained(model_name)
    encoding = tokenizer(text, return_tensors='pt')  # tokenize once, reuse below
    for key, value in encoding.items():
        print(key, value)
        # The tokenizer's key is 'input_ids'. The old check for 'inputs_ids' never matched.
        if key == 'input_ids':
            print(tokenizer.convert_ids_to_tokens(value.squeeze(0)))
    with torch.no_grad():  # inspection only, so no autograd graph is needed
        hidden = bert(**encoding)[0]
    print(hidden.shape)


show_bert_doctrine()

In [None]:
from transformers import BertTokenizer, BertModel, BertConfig
BERT_MODEL_NAME = "bert-base-cased"

class BertEmbedding(nn.Module):
    """Wrap a pretrained ``BertModel`` and return one contextual vector per word piece.

    ``forward`` drops the first and last sequence positions (the special tokens the
    tokenizer adds around a sentence). Output position ``i`` therefore lines up with
    word piece ``i`` of ``tokenizer.convert_ids_to_tokens(ids)[1:-1]``.
    """

    def __init__(self, model_name=None):
        """Load the pretrained encoder.

        Args:
            model_name: Hugging Face checkpoint to load. ``None`` (the default) uses
                the notebook-level ``BERT_MODEL_NAME``, as before.
        """
        super().__init__()
        name = BERT_MODEL_NAME if model_name is None else model_name
        self.bert = BertModel.from_pretrained(name)

    def _set_requires_grad(self, flag):
        """Set ``requires_grad`` to ``flag`` on every BERT parameter."""
        for param in self.bert.parameters():
            param.requires_grad = flag

    def fix_params(self):
        """Freeze BERT: none of its parameters will receive gradients."""
        self._set_requires_grad(False)

    def free_params(self):
        """Unfreeze BERT: all of its parameters will receive gradients again."""
        self._set_requires_grad(True)

    def forward(self, inputs):
        """Encode a tokenized sentence.

        Args:
            inputs: mapping unpacked as keyword arguments to ``BertModel``. This is
                presumably the tokenizer output (``input_ids``, ``attention_mask``,
                ...); verify against callers.

        Returns:
            The model's first output with the first and last positions removed
            along dim 1. Assumed shape is (batch, seq_len - 2, hidden); TODO confirm.
        """
        return self.bert(**inputs)[0][:, 1:-1, :]

In [None]:
def argmax(vec):
    """Return the column index of the largest entry of ``vec`` as a Python int.

    ``vec`` is expected to hold a single row (e.g. shape (1, N)). ``.item()``
    raises if the reduction leaves more than one element.
    """
    best_per_row = torch.max(vec, dim=1).indices
    return best_per_row.item()

def prepare_sequence(seq, tags, tokenizer, tag_to_ix):
    """Encode one sentence and its tags for the BERT-BiLSTM-CRF model.

    Args:
        seq: sentence text.
        tags: whitespace-separated tag string, one tag per word of ``seq``.
        tokenizer: Hugging Face tokenizer. It is called with ``return_tensors='pt'``
            and is also used to map ids back to word-piece strings.
        tag_to_ix: mapping from tag string to integer index.

    Returns:
        ``(encoding, targets, token_starts)``:
          * encoding: the tokenizer output, to be passed as-is to the BERT model.
          * targets: LongTensor of tag indices, one per word.
          * token_starts: LongTensor of positions, counted after dropping the first
            and last special tokens, of each word's first word piece. Pieces that
            start with ``"##"`` are continuations and are skipped.

    Raises:
        KeyError: if a tag is not in ``tag_to_ix``.
        ValueError: if the tokenizer yields a different number of word starts than
            there are tags (e.g. punctuation split off a word). Without this check,
            emissions and gold tags would be silently misaligned in the CRF.
    """
    tag_list = tags.split()
    targets = torch.tensor([tag_to_ix[t] for t in tag_list], dtype=torch.long)

    encoding = tokenizer(seq, return_tensors='pt')
    # Drop the leading/trailing special tokens so indices match BertEmbedding's output.
    word_pieces = tokenizer.convert_ids_to_tokens(encoding['input_ids'].squeeze(0))[1:-1]
    token_starts = torch.tensor(
        [i for i, wp in enumerate(word_pieces) if not wp.startswith("##")],
        dtype=torch.long,
    )

    if len(token_starts) != len(targets):
        raise ValueError(
            f"tokenizer found {len(token_starts)} word starts but {len(targets)} tags "
            f"were given for sentence {seq!r}"
        )
    return encoding, targets, token_starts

def log_sum_exp(vec):
    """Numerically stable ``log(sum(exp(vec)))`` over every element of ``vec``.

    Used by the CRF forward algorithm on a single-row ``(1, N)`` score vector.
    Returns a 0-d tensor. It delegates to ``torch.logsumexp``, which returns
    ``-inf`` when every entry is ``-inf``. The previous max-subtraction version
    produced NaN in that case (``-inf - -inf``).
    """
    return torch.logsumexp(vec.reshape(-1), dim=0)

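## Create the model

BERT word-piece embeddings feed a BiLSTM. Its per-word emission scores are combined with learned tag-transition scores in a linear-chain CRF: the forward algorithm gives the training loss, and Viterbi decoding gives predictions.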

In [None]:
class BiLSTM_CRF(nn.Module):
    """BERT embeddings -> BiLSTM emission scores -> linear-chain CRF tagger.

    Processes one sentence at a time (the LSTM state built by ``init_hidden`` has
    batch size 1).

    Args:
        tag_to_ix: mapping from tag string to index. It must contain the
            module-level ``START_TAG`` and ``STOP_TAG`` sentinels.
        embedding_dim: size of the vectors fed to the LSTM.
        hidden_dim: total BiLSTM output size, split evenly between the two
            directions, so it must be even.

    Raises:
        ValueError: if ``hidden_dim`` is odd.
        KeyError: if ``START_TAG`` or ``STOP_TAG`` is missing from ``tag_to_ix``.
    """

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=768):
        super(BiLSTM_CRF, self).__init__()
        if hidden_dim % 2:
            # Each direction gets hidden_dim // 2 units. An odd size would make the LSTM
            # output (hidden_dim - 1) disagree with hidden2tag's input size.
            raise ValueError(f"hidden_dim must be even for a BiLSTM, got {hidden_dim}")
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        # Cache the sentinel indices so the CRF methods don't re-look them up.
        self.start_ix = tag_to_ix[START_TAG]
        self.stop_ix = tag_to_ix[STOP_TAG]
        self.word_embeds = BertEmbedding()
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        # maps the output of the LSTM into tag space
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
        # transitions[i, j] is the score of transitioning *from* tag j *to* tag i.
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        # Strongly discourage transitions into START and out of STOP.
        self.transitions.data[self.start_ix, :] = -1000
        self.transitions.data[:, self.stop_ix] = -1000
        self.hidden = self.init_hidden()

    def fix_bert(self):
        """Freeze the BERT encoder's parameters."""
        self.word_embeds.fix_params()

    def free_bert(self):
        """Unfreeze the BERT encoder's parameters."""
        self.word_embeds.free_params()

    def init_hidden(self):
        """Return a fresh random ``(h0, c0)`` pair for the BiLSTM.

        Each tensor has shape ``(2, 1, hidden_dim // 2)`` (two directions, one
        sentence) and lives on the same device as the model's parameters.
        NOTE(review): random rather than zero initial states make decoding vary
        between calls. This is kept to preserve the original behaviour.
        """
        device = self.transitions.device
        return (torch.randn(2, 1, self.hidden_dim // 2, device=device),
                torch.randn(2, 1, self.hidden_dim // 2, device=device))

    def _forward_alg(self, feats):
        """Log partition function: log-sum-exp of the scores of all tag paths.

        Args:
            feats: ``(seq_len, tagset_size)`` emission scores.

        Returns:
            0-d tensor.
        """
        # Log-space scores over the previous tag. All mass starts on START.
        forward_var = self.transitions.new_full((1, self.tagset_size), -10000.)
        forward_var[0, self.start_ix] = 0.
        for feat in feats:
            # scores[j, i] = alpha[j] + transitions[i, j] + feat[i]; reduce over previous tag j
            scores = forward_var.t() + self.transitions.t() + feat.unsqueeze(0)
            forward_var = torch.logsumexp(scores, dim=0, keepdim=True)
        terminal_var = forward_var + self.transitions[self.stop_ix]
        return torch.logsumexp(terminal_var.view(-1), dim=0)

    def _get_lstm_features(self, embeds):
        """Run the BiLSTM and project its output to per-tag emission scores.

        Args:
            embeds: batch-first embeddings, ``(batch, seq_len, embedding_dim)``.
                The batch size must be 1 to match ``init_hidden``.

        Returns:
            ``(batch, seq_len, tagset_size)`` emission scores.
        """
        self.hidden = self.init_hidden()
        # nn.LSTM here expects sequence-first input. transpose() (unlike view())
        # moves elements to the right positions for any batch size.
        lstm_out, self.hidden = self.lstm(embeds.transpose(0, 1), self.hidden)
        lstm_out = lstm_out.transpose(0, 1)  # back to (batch, seq_len, hidden_dim)
        return self.hidden2tag(lstm_out)

    def _score_sentence(self, feats, tags):
        """Score of one given tag path (the CRF numerator, in log space).

        Args:
            feats: ``(seq_len, tagset_size)`` emission scores.
            tags: ``(seq_len,)`` LongTensor of tag indices.

        Returns:
            Tensor of shape ``(1,)``.
        """
        score = self.transitions.new_zeros(1)
        start = torch.tensor([self.start_ix], dtype=torch.long, device=tags.device)
        path = torch.cat([start, tags])
        for i, feat in enumerate(feats):
            # transition previous -> current tag, plus the current tag's emission score
            score = score + self.transitions[path[i + 1], path[i]] + feat[path[i + 1]]
        return score + self.transitions[self.stop_ix, path[-1]]

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag path with the Viterbi algorithm.

        Args:
            feats: ``(seq_len, tagset_size)`` emission scores.

        Returns:
            ``(path_score, best_path)``: a 0-d score tensor and a list of
            ``seq_len`` tag indices.
        """
        backpointers = []
        # Viterbi variables over the previous tag, in log space. Everything starts on START.
        forward_var = self.transitions.new_full((1, self.tagset_size), -10000.)
        forward_var[0, self.start_ix] = 0
        for feat in feats:
            # scores[i, j] = transitions[i, j] + viterbi[j]: best previous tag j for each tag i
            scores = self.transitions + forward_var
            best_prev_scores, bptrs = torch.max(scores, dim=1)
            forward_var = (best_prev_scores + feat).view(1, -1)
            backpointers.append(bptrs.tolist())

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.stop_ix]
        best_tag_id = torch.argmax(terminal_var[0]).item()
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        start = best_path.pop()
        assert start == self.start_ix  # sanity check: every path begins at START
        best_path.reverse()
        return path_score, best_path

    def _emissions(self, input_ids, token_starts):
        """Emission scores ``(num_words, tagset_size)``, one row per word start."""
        # keep only the first word piece of every word
        embeds = self.word_embeds(input_ids)[:, token_starts]
        # squeeze(0), not squeeze(): a one-word sentence must stay 2-D (1, tagset_size)
        return self._get_lstm_features(embeds).squeeze(0)

    def neg_log_likelihood(self, input_ids, tags, token_starts):
        """CRF negative log-likelihood of ``tags`` for one sentence (the training loss).

        Args:
            input_ids: tokenizer output passed to the BERT encoder.
            tags: ``(num_words,)`` LongTensor of gold tag indices.
            token_starts: ``(num_words,)`` LongTensor of first-word-piece positions.

        Returns:
            Tensor of shape ``(1,)``: ``log Z - score(tags)``.
        """
        feats = self._emissions(input_ids, token_starts)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, input_ids, token_starts):
        """Decode the best tag sequence for one sentence.

        Returns:
            ``(score, tag_seq)`` as produced by ``_viterbi_decode``.
        """
        lstm_feats = self._emissions(input_ids, token_starts)
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

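## Run training

Train on two toy sentences with SGD. The predicted tag path for the first sentence is printed before and after training.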

In [None]:
START_TAG = '<START>'
STOP_TAG = '<STOP>'
EMBEDDING_DIM = 768
HIDDEN_DIM = 768  # was misspelled HIDDEM_DIM, leaving HIDDEN_DIM (used below) undefined

# Toy corpus of (sentence, tags) pairs: one B/I/O tag per whitespace-separated word.
# Named training_data because that is the name every later use refers to.
training_data = [
    ("the wall street journal reported today that apple corporation made money",
     "B I I I O O O B I O O"),
    ("georgia tech is a university in georgia",
     "B I O O O O B"),
]
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
tag_to_ix = {'B': 0, 'I': 1, 'O': 2, START_TAG: 3, STOP_TAG: 4}

model = BiLSTM_CRF(tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD([param for param in model.parameters() if param.requires_grad],
                      lr=0.01, weight_decay=1e-4)

# Check predictions before training. eval() switches off dropout inside BERT so the
# check is not perturbed by it, and no_grad() skips building an autograd graph.
model.eval()
with torch.no_grad():
    precheck_sent, targets, token_starts = prepare_sequence(
        training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)
    print(precheck_sent)
    print(targets)
    print(token_starts)
    print(model(precheck_sent, token_starts))
model.train()  # restore training mode for the training loop

In [None]:
N_EPOCHS = 300
LOG_EVERY = 50  # print the mean loss every LOG_EVERY epochs instead of after every step

# Tokenization does not change between epochs, so encode the corpus once up front.
encoded_training_data = [
    prepare_sequence(sent, tags, tokenizer, tag_to_ix) for sent, tags in training_data
]

model.train()
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    # Each sentence uses its *own* gold tags. The old code's `tagets` typo reused the
    # stale `targets` left over from the previous cell.
    for sent_in, gold_tags, token_starts in encoded_training_data:
        model.zero_grad()
        # CRF negative log-likelihood of the gold tag path
        loss = model.neg_log_likelihood(sent_in, gold_tags, token_starts)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if (epoch + 1) % LOG_EVERY == 0:
        print(f"epoch {epoch + 1}: mean loss {epoch_loss / len(encoded_training_data):.4f}")

# Check predictions after training, in inference mode.
model.eval()
with torch.no_grad():
    precheck_sent, targets, token_starts = prepare_sequence(
        training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)
    print(model(precheck_sent, token_starts))
        