# 2 - Pretrained Transformers for PoS Tagging

## Introduction

In the previous notebook we showed how to use a BiLSTM with pretrained GloVe embeddings for PoS tagging. In this notebook, we'll be using a pretrained [Transformer](https://arxiv.org/abs/1706.03762) model, specifically the pretrained [BERT](https://arxiv.org/abs/1810.04805) model. The transformer will replace the embedding layer of our BiLSTM, and the rest of the model will stay the same.

## Preparing Data

First, let's import the necessary Python modules.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

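# note: torchtext 0.9-0.11 moved these modules to torchtext.legacy, and later releases removed them
from torchtext import data
from torchtext import datasets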

from transformers import BertTokenizer, BertModel

import numpy as np

import time
import random
import functools

Next, we'll set the random seeds for reproducibility.

In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

Then, we'll load the BERT tokenizer. This defines how text should be processed before it is fed into the model and, more importantly, contains the vocabulary that the BERT model was pretrained with. We'll be using the `bert-base-uncased` tokenizer and model, which were trained on lowercased text.

In order to use a pretrained NLP model, our vocabulary needs to exactly match the one the model was pretrained with.

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Another thing we need to do is make sure the input sequence is formatted in the same way as the sequences the BERT model was trained on.

BERT was trained on sequences that begin with a `[CLS]` token and end with a `[SEP]` token.

So the sequence of tokens

```python
text = ['jack', 'went', 'to', 'the', 'shop']
```

should become:

```python
text = ['[CLS]', 'jack', 'went', 'to', 'the', 'shop', '[SEP]']
```

Along with making our vocabularies match, we also need to make sure our padding and unknown tokens match those used by the pretrained model. By default, TorchText uses `<pad>` and `<unk>`, but the BERT model uses `[PAD]` and `[UNK]`.

Let's get the special tokens:

In [4]:
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


We are mainly interested in the integer ids of the special tokens, because we aren't using TorchText's vocabulary module but the one provided by the pretrained model.

We get the indexes of the special tokens by passing them through the tokenizer's `convert_tokens_to_ids` function.

In [5]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


The pretrained model was also trained on sequences up to a maximum length, so we need to ensure our sequences are trimmed to this length.

In [6]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

512


Next, we'll define two helper functions to preprocess our sequences.

The first cuts the sequence of tokens to the maximum length specified by our pretrained model and then converts the tokens into indexes using the pretrained model's vocabulary. We will apply this to the input sequences we want to tag.

Note that we actually cut tokens to `max_input_length-2`, because we need to add the special `[CLS]` and `[SEP]` tokens to either end of the sequence.

In [7]:
def cut_and_convert_to_id(tokens, tokenizer, max_input_length):
    tokens = tokens[:max_input_length-2]
    tokens = tokenizer.convert_tokens_to_ids(tokens)
    return tokens
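To see why the cut is `max_input_length-2`, here is a minimal pure-Python sketch of the same cut-then-convert logic followed by the wrapping with special tokens. The toy vocabulary `TOY_VOCAB`, its ids and the helper `cut_convert_and_wrap` are hypothetical stand-ins for BERT's real vocabulary; only the special-token ids 101, 102 and 100 are taken from the output above.

```python
# Hypothetical toy vocabulary; the real BERT ids for these words differ.
TOY_VOCAB = {'jack': 1, 'went': 2, 'to': 3, 'the': 4, 'shop': 5}
CLS_IDX, SEP_IDX, UNK_IDX = 101, 102, 100  # ids printed earlier for bert-base-uncased


def cut_convert_and_wrap(tokens, vocab, max_input_length):
    """Cut to max_input_length - 2, map tokens to ids, then add [CLS] and [SEP]."""
    tokens = tokens[:max_input_length - 2]          # leave room for two special tokens
    ids = [vocab.get(t, UNK_IDX) for t in tokens]   # unknown words fall back to [UNK]
    return [CLS_IDX] + ids + [SEP_IDX]


ids = cut_convert_and_wrap(['jack', 'went', 'to', 'the', 'shop', 'yesterday'], TOY_VOCAB, 6)
print(ids)  # [101, 1, 2, 3, 4, 102]
```

The wrapped sequence is exactly `max_input_length` long. In the notebook itself, the wrapping is done by the `Field`'s `init_token` and `eos_token` rather than inside the helper.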

The second helper function simply cuts the sequence to the maximum length. This is used for our tags. We do not pass the tags through the pretrained model's vocabulary, as that vocabulary was built for English text, not for part-of-speech tags. Instead, we will build the tag vocabulary ourselves.

In [8]:
def cut_to_max_length(tokens, max_input_length):
    tokens = tokens[:max_input_length-2]
    return tokens

We need to pass the above two functions to the `Field`, the TorchText abstraction that handles much of the data processing for us. We use Python's `functools.partial`, which lets us create functions that already have some of their arguments supplied.

In [9]:
text_preprocessor = functools.partial(cut_and_convert_to_id,
                                      tokenizer = tokenizer,
                                      max_input_length = max_input_length)

tag_preprocessor = functools.partial(cut_to_max_length,
                                     max_input_length = max_input_length)
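`functools.partial` freezes some keyword arguments so that the resulting callable can be invoked with the token list alone, which is the one-argument form the `preprocessing` argument expects. A standalone sketch, using a small made-up maximum length instead of 512:

```python
import functools


def cut_to_max_length(tokens, max_input_length):
    # same helper as above: leave room for the two special tokens
    tokens = tokens[:max_input_length - 2]
    return tokens


# freeze max_input_length so the new callable only takes `tokens`
tag_preprocessor = functools.partial(cut_to_max_length, max_input_length=5)

print(tag_preprocessor(['DET', 'NOUN', 'VERB', 'ADP', 'DET', 'NOUN']))
# ['DET', 'NOUN', 'VERB']
print(tag_preprocessor.keywords)
# {'max_input_length': 5}
```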

Next, we define our fields.

For the `TEXT` field, which processes the sequences we want to tag, we first tell TorchText not to build a vocabulary with `use_vocab = False`. As our model is uncased, we also ensure all text is lowercased with `lower = True`. The `preprocessing` argument is a function applied to sequences after they have been tokenized but before they are numericalized. As we have set `use_vocab` to `False`, TorchText will never numericalize them itself, and as TorchText's PoS datasets are already tokenized, this function is simply applied to each sequence of tokens. This is where our helper functions come in handy: `text_preprocessor` both numericalizes our data using the pretrained model's vocabulary and cuts it to the maximum length. The remaining four arguments set the special tokens required by the pretrained model, given as their integer ids.

For the `UD_TAGS` field, we need to ensure the length of our tags matches the length of our text sequence. As we add a token to either end of the text sequence, we need to do the same to the sequence of tags. We do this by adding a `<pad>` token at each end, which we will later tell our model to ignore when calculating the loss and accuracy. We won't encounter unknown tags, so we set `unk_token` to `None`. Finally, we pass the `tag_preprocessor` defined above, which simply cuts the tags to the maximum length our pretrained model can handle.

In [10]:
TEXT = data.Field(use_vocab = False,
                  lower = True,
                  preprocessing = text_preprocessor,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

UD_TAGS = data.Field(unk_token = None,
                     init_token = '<pad>',
                     eos_token = '<pad>',
                     preprocessing = tag_preprocessor)
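A rough pure-Python imitation of what the two fields do to each example and then to a padded batch. `process_example` and `pad_batch` are hypothetical helpers, not TorchText functions, and the word ids are made up; the point is that both fields add one token at each end, so the text and tags stay the same length, and padding keeps them aligned.

```python
CLS_IDX, SEP_IDX, PAD_IDX = 101, 102, 0  # BERT special-token ids printed earlier
TAG_PAD = '<pad>'


def process_example(token_ids, tags):
    """Mimic the TEXT and UD_TAGS fields adding one token at each end."""
    text = [CLS_IDX] + token_ids + [SEP_IDX]
    udtags = [TAG_PAD] + tags + [TAG_PAD]
    return text, udtags


def pad_batch(sequences, pad_value):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_value] * (max_len - len(s)) for s in sequences]


examples = [process_example([1, 2, 3], ['PROPN', 'VERB', 'NOUN']),
            process_example([4, 5], ['DET', 'NOUN'])]

texts = pad_batch([text for text, _ in examples], PAD_IDX)
tags = pad_batch([udtags for _, udtags in examples], TAG_PAD)

for text, udtags in zip(texts, tags):
    print(text, udtags)
```

Here each row is one sentence; the batches TorchText produces are laid out the other way round, as `[sent len, batch size]`.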

Then, we specify which of the fields above corresponds to which field in the dataset.

In [11]:
fields = (("text", TEXT), ("udtags", UD_TAGS))

Next, we load the data using our fields.

In [12]:
train_data, valid_data, test_data = datasets.UDPOS.splits(fields)

We can check an example by printing it. As we have already numericalized `text` using the vocabulary of the pretrained model, it is already a sequence of integers, while the tags have yet to be numericalized.

In [13]:
print(vars(train_data.examples[0]))

{'text': [2632, 1011, 100, 1024, 2137, 2749, 2730, 100, 14093, 2632, 1011, 100, 1010, 1996, 14512, 2012, 1996, 8806, 1999, 1996, 2237, 1997, 100, 1010, 2379, 1996, 9042, 3675, 1012], 'udtags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']}


Our next step is to build the tag vocabulary so the tags can be numericalized during training. We do this by calling the field's `.build_vocab` method on `train_data`.

In [14]:
UD_TAGS.build_vocab(train_data)

print(UD_TAGS.vocab.stoi)

defaultdict(None, {'<pad>': 0, 'NOUN': 1, 'PUNCT': 2, 'VERB': 3, 'PRON': 4, 'ADP': 5, 'DET': 6, 'PROPN': 7, 'ADJ': 8, 'AUX': 9, 'ADV': 10, 'CCONJ': 11, 'PART': 12, 'NUM': 13, 'SCONJ': 14, 'X': 15, 'INTJ': 16, 'SYM': 17})
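The ordering in that mapping (the `<pad>` special token first, then tags from most to least frequent) can be imitated with `collections.Counter`. This is a simplified sketch, not TorchText's implementation; `build_tag_vocab` and the toy tag sequences are hypothetical, and breaking frequency ties alphabetically is an assumption about TorchText's behaviour.

```python
from collections import Counter


def build_tag_vocab(tag_sequences, specials=('<pad>',)):
    """Return a string-to-index dict: specials first, then tags by descending frequency."""
    counts = Counter(tag for seq in tag_sequences for tag in seq)
    # sort alphabetically, then stably by frequency, so ties stay alphabetical
    ordered = sorted(sorted(counts), key=lambda tag: counts[tag], reverse=True)
    itos = list(specials) + [tag for tag in ordered if tag not in specials]
    return {tag: idx for idx, tag in enumerate(itos)}


toy_train_tags = [['DET', 'NOUN', 'VERB', 'PUNCT'],
                  ['PRON', 'VERB', 'DET', 'NOUN', 'PUNCT'],
                  ['NOUN', 'NOUN', 'PUNCT']]

print(build_tag_vocab(toy_train_tags))
# {'<pad>': 0, 'NOUN': 1, 'PUNCT': 2, 'DET': 3, 'VERB': 4, 'PRON': 5}
```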


Next, we'll define our iterators, which determine how batches of data are provided during training. We set a batch size and define `device`, which puts each batch onto the GPU if one is available.

In [15]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

## Building the Model

Next up is defining our model. This is very similar to the BiLSTM model from the previous notebook, but we now use the pretrained BERT model to provide embeddings for our input text sequence instead of an embedding layer.

![](assets/pos-bidirectional-lstm.png)

Look at the image of the model above. Previously, the yellow squares were the embeddings provided by the embedding layer. Now they are embeddings provided by the pretrained BERT model.

One thing to note is that we do not define an `embedding_dim` for our model: it is the size of the output of the pretrained BERT model, and we cannot change it. Instead, we read the `embedding_dim` from the `hidden_size` value in the BERT model's config.

We will also not be fine-tuning the weights of the BERT model. It is a relatively large model with many parameters, and storing gradients and optimizer state for all of them is unlikely to fit into our GPU's memory. As we aren't training BERT, we never need to calculate or store gradients for its parameters, so we wrap its part of the forward pass in a `torch.no_grad()` block.

In [16]:
class BiLSTMBERTTagger(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim, 
                 output_dim, 
                 n_layers, 
                 bidirectional, 
                 dropout):
        
        super().__init__()
        
        self.bert = bert
        
        embedding_dim = bert.config.to_dict()['hidden_size']
        
        self.lstm = nn.LSTM(embedding_dim, 
                            hidden_dim, 
                            num_layers = n_layers, 
                            bidirectional = bidirectional,
                            dropout = 0 if n_layers < 2 else dropout)
        
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
  
        #text = [sent len, batch size]
    
        text = text.permute(1, 0)
        
        #text = [batch size, sent len]
        
        with torch.no_grad():
            embedded = self.dropout(self.bert(text)[0])
        
        #embedded = [batch size, sent len, emb dim]
                
        embedded = embedded.permute(1, 0, 2)
        
        #embedded = [sent len, batch size, emb dim]
            
        outputs, (hidden, cell) = self.lstm(embedded)
        
        #outputs = [sent len, batch size, hid dim * n directions]
        #hidden/cell = [n layers * n directions, batch size, hid dim]
        
        predictions = self.fc(self.dropout(outputs))
        
        #predictions = [sent len, batch size, output dim]
        
        return predictions

Next, we load the actual pretrained BERT uncased model; previously we only loaded the tokenizer associated with it.

The first time we run this, it will have to download the pretrained parameters.

In [17]:
bert = BertModel.from_pretrained('bert-base-uncased')

## Training the Model

We finally get to instantiate our model: a BiLSTM that uses a pretrained BERT model to get word embeddings.

The hyperparameters have been chosen because they are sensible defaults. There may be a configuration that performs better for this model and dataset.

In [18]:
HIDDEN_DIM = 256
OUTPUT_DIM = len(UD_TAGS.vocab)
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

model = BiLSTMBERTTagger(bert,
                         HIDDEN_DIM, 
                         OUTPUT_DIM, 
                         N_LAYERS, 
                         BIDIRECTIONAL, 
                         DROPOUT)

We can then count the number of trainable parameters, which includes all of the BERT parameters too. 113M is a lot of parameters, probably too many to train within our GPU's memory.

In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 113,169,682 trainable parameters


To handle this, we *freeze* all of the parameters of the pretrained BERT model. This means they will not change from their pretrained values, and the only parts we train are the BiLSTM and the linear layer that predicts the tags from the output of the LSTM.

In [20]:
for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

We can check the number of trainable parameters again. We've gone down from 113M to 3.7M, a much more sensible value.

In [21]:
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 3,687,442 trainable parameters
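The 3,687,442 figure can be reproduced by hand. The sketch below assumes PyTorch's LSTM layout of four gates per layer and direction, each with an input-to-hidden weight, a hidden-to-hidden weight and two bias vectors, plus the 768-dimensional output of `bert-base-uncased` and the 18 tags in `UD_TAGS.vocab`; `lstm_layer_params` is a hypothetical helper.

```python
def lstm_layer_params(input_size, hidden_size):
    # one direction: weight_ih (4H x input), weight_hh (4H x H), bias_ih (4H), bias_hh (4H)
    return 4 * hidden_size * (input_size + hidden_size) + 2 * 4 * hidden_size


EMBEDDING_DIM = 768   # hidden_size of bert-base-uncased
HIDDEN_DIM = 256
OUTPUT_DIM = 18       # entries in UD_TAGS.vocab, including <pad>
N_DIRECTIONS = 2      # bidirectional

layer_0 = N_DIRECTIONS * lstm_layer_params(EMBEDDING_DIM, HIDDEN_DIM)
# the second layer reads the concatenated forward and backward outputs
layer_1 = N_DIRECTIONS * lstm_layer_params(HIDDEN_DIM * N_DIRECTIONS, HIDDEN_DIM)
fc = HIDDEN_DIM * N_DIRECTIONS * OUTPUT_DIM + OUTPUT_DIM  # fc.weight + fc.bias

total = layer_0 + layer_1 + fc
print(f'{total:,}')  # 3,687,442
```

Subtracting this from the 113,169,682 printed before freezing leaves 109,482,240 parameters in BERT itself, all of which are now frozen.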


We can also print out the names of all of the trainable parameters, which is useful for checking that the correct parameters are still trainable.

As we can see, only the BiLSTM (`lstm`) and the prediction linear layer (`fc`) are still trainable.

In [22]:
for name, param in model.named_parameters():                
    if param.requires_grad:
        print(name)

lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l0_reverse
lstm.weight_hh_l0_reverse
lstm.bias_ih_l0_reverse
lstm.bias_hh_l0_reverse
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
lstm.weight_ih_l1_reverse
lstm.weight_hh_l1_reverse
lstm.bias_ih_l1_reverse
lstm.bias_hh_l1_reverse
fc.weight
fc.bias


The rest of the notebook is fairly standard. We define our optimizer, which we use to update our trainable parameters with respect to the calculated gradients.

In [23]:
optimizer = optim.Adam(model.parameters())

We define a loss function, making sure to ignore losses whenever the target tag is a padding token.

In [24]:
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)
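With `ignore_index`, padded positions are dropped before averaging, so the loss is the mean over real tags only. Below is a plain-Python sketch of that behaviour for the default mean reduction without class weights; `cross_entropy_ignore_index` is a hypothetical name, not a PyTorch function, and the logits are made up.

```python
import math


def cross_entropy_ignore_index(logits, targets, ignore_index):
    """Mean negative log-softmax over positions whose target is not ignore_index."""
    losses = []
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padded positions contribute nothing, not even to the count
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_sum_exp - scores[target])
    return sum(losses) / len(losses)


TAG_PAD_IDX = 0
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.0, 3.0]]
targets = [1, TAG_PAD_IDX, 2]

loss = cross_entropy_ignore_index(logits, targets, TAG_PAD_IDX)
print(loss)
```

Whatever the model predicts at a padded position, the loss (and therefore the gradient) is unchanged.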

Then, we place the model on to the GPU, if we have one.

In [25]:
model = model.to(device)
criterion = criterion.to(device)

As in the previous notebook, we define a function that calculates the accuracy of our tag predictions, ignoring predictions over padding tokens.

In [26]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1) # get the index of the highest-scoring tag
    non_pad_elements = y != tag_pad_idx # boolean mask of the non-pad targets
    correct = max_preds[non_pad_elements].eq(y[non_pad_elements])
    # both tensors stay on the same device, avoiding a CPU/GPU mismatch
    return correct.sum().float() / non_pad_elements.sum()
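The same masking rule can be sketched in plain Python with lists instead of tensors; `masked_accuracy` is a hypothetical name and the scores are made up.

```python
def masked_accuracy(scores, targets, pad_idx):
    """Fraction of non-pad positions whose highest-scoring tag equals the target."""
    correct = total = 0
    for position_scores, target in zip(scores, targets):
        if target == pad_idx:
            continue  # skip padding, exactly like the loss does
        predicted = max(range(len(position_scores)), key=position_scores.__getitem__)
        correct += int(predicted == target)
        total += 1
    return correct / total


scores = [[0.1, 0.7, 0.2],    # predicts 1
          [0.9, 0.05, 0.05],  # predicts 0, but this position is padding
          [0.2, 0.5, 0.3],    # predicts 1
          [0.3, 0.3, 0.4]]    # predicts 2
targets = [1, 0, 2, 2]

print(masked_accuracy(scores, targets, pad_idx=0))  # 0.6666666666666666
```

Two of the three non-pad positions are correct; had the padded position been counted, the accuracy would have been inflated to 0.75.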

We then define our `train` and `evaluate` functions to train and test our model. 

In [27]:
def train(model, iterator, optimizer, criterion, tag_pad_idx):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        text = batch.text
        tags = batch.udtags
                
        optimizer.zero_grad()
        
        #text = [sent len, batch size]
        
        predictions = model(text)
        
        #predictions = [sent len, batch size, output dim]
        #tags = [sent len, batch size]
        
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)
        
        #predictions = [sent len * batch size, output dim]
        #tags = [sent len * batch size]
        
        loss = criterion(predictions, tags)
                
        acc = categorical_accuracy(predictions, tags, tag_pad_idx)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [28]:
def evaluate(model, iterator, criterion, tag_pad_idx):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            text = batch.text
            tags = batch.udtags
            
            predictions = model(text)
            
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            
            loss = criterion(predictions, tags)
            
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

Then, we define a helper function used to see how long an epoch takes.

In [29]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

Finally, we can train our model!

This model takes considerably longer per epoch than the previous one because it has far more parameters. Even though we do not train BERT, every batch still has to pass forward through it, which is where the added time comes from.

In [30]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut2-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 28s
	Train Loss: 0.906 | Train Acc: 72.16%
	 Val. Loss: 0.460 |  Val. Acc: 85.99%
Epoch: 02 | Epoch Time: 0m 28s
	Train Loss: 0.282 | Train Acc: 91.14%
	 Val. Loss: 0.370 |  Val. Acc: 88.01%
Epoch: 03 | Epoch Time: 0m 28s
	Train Loss: 0.234 | Train Acc: 92.54%
	 Val. Loss: 0.359 |  Val. Acc: 88.33%
Epoch: 04 | Epoch Time: 0m 28s
	Train Loss: 0.206 | Train Acc: 93.33%
	 Val. Loss: 0.327 |  Val. Acc: 89.45%
Epoch: 05 | Epoch Time: 0m 27s
	Train Loss: 0.183 | Train Acc: 94.09%
	 Val. Loss: 0.303 |  Val. Acc: 90.18%
Epoch: 06 | Epoch Time: 0m 28s
	Train Loss: 0.166 | Train Acc: 94.57%
	 Val. Loss: 0.287 |  Val. Acc: 90.47%
Epoch: 07 | Epoch Time: 0m 28s
	Train Loss: 0.152 | Train Acc: 95.02%
	 Val. Loss: 0.283 |  Val. Acc: 91.08%
Epoch: 08 | Epoch Time: 0m 28s
	Train Loss: 0.140 | Train Acc: 95.38%
	 Val. Loss: 0.279 |  Val. Acc: 90.66%
Epoch: 09 | Epoch Time: 0m 28s
	Train Loss: 0.130 | Train Acc: 95.70%
	 Val. Loss: 0.274 |  Val. Acc: 91.62%
Epoch: 10 | Epoch T

We can then load our "best" performing model and try it out on the test set. 

We beat our previous model by 2%! Note how our validation loss was still decreasing, so there is potential to train the model for longer and achieve a higher test accuracy.
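Since a checkpoint is saved whenever the validation loss improves, the logged values show both the trend and which of the complete epochs was saved last. A small check over the nine complete epochs, with the values copied from the log above (the tenth line of the log is cut off, so it is left out):

```python
# validation losses for epochs 1-9, copied from the training log above
valid_losses = [0.460, 0.370, 0.359, 0.327, 0.303, 0.287, 0.283, 0.279, 0.274]

# True if every epoch's validation loss is lower than the previous one
still_decreasing = all(later < earlier
                       for earlier, later in zip(valid_losses, valid_losses[1:]))
# epochs are numbered from 1 in the log
best_logged_epoch = valid_losses.index(min(valid_losses)) + 1

print(still_decreasing, best_logged_epoch)  # True 9
```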

In [31]:
model.load_state_dict(torch.load('tut2-model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.324 | Test Acc: 90.29%


## Inference

We'll now see how to use our model to tag actual sentences. This is similar to the inference function from the previous notebook with the tokenization changed to match the format of our pretrained model.

If we pass in a string, we need to split it into individual tokens, which we do using the `tokenize` method of the `tokenizer`. Afterwards, we `encode` the tokens, which both adds the special tokens to either end of the sequence and converts the tokens into integers using the vocabulary.

We then pass the text sequence through our model to get a prediction for each token, and slice off the predictions for the `[CLS]` and `[SEP]` tokens as we do not care about them.

In [32]:
def tag_sentence(model, device, sentence, tokenizer, text_field, tag_field):
    
    model.eval()
    
    if isinstance(sentence, str):
        tokens = tokenizer.tokenize(sentence)
    else:
        # lowercase pre-split tokens to match the uncased vocabulary
        tokens = [token.lower() for token in sentence]
    
    # encode adds [CLS] and [SEP] around the ids of the tokens
    numericalized_tokens = tokenizer.encode(tokens)
        
    unk_idx = text_field.unk_token
    
    # skip the [CLS] and [SEP] ids so each token is paired with its own id
    unks = [t for t, n in zip(tokens, numericalized_tokens[1:-1]) if n == unk_idx]
    
    token_tensor = torch.LongTensor(numericalized_tokens)
    
    #token_tensor = [sent len, 1], i.e. a batch containing one sentence
    token_tensor = token_tensor.unsqueeze(-1).to(device)
    
    with torch.no_grad():
        predictions = model(token_tensor)
    
    top_predictions = predictions.argmax(-1)
    
    predicted_tags = [tag_field.vocab.itos[t.item()] for t in top_predictions]
    
    # drop the predictions made for [CLS] and [SEP]
    predicted_tags = predicted_tags[1:-1]
    
    assert len(tokens) == len(predicted_tags)
    
    return tokens, predicted_tags, unks
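Because `encode` puts `[CLS]` at the front, token *i* sits at position *i + 1* of both the encoded ids and the predictions. This plain-Python sketch, with made-up word ids and predictions (only 101, 102 and 100 come from the output earlier), shows why both must be sliced with `[1:-1]` before being paired with the tokens:

```python
CLS_IDX, SEP_IDX, UNK_IDX = 101, 102, 100  # special-token ids printed earlier

tokens = ['the', 'queen', 'visited', 'zzyzx']
# made-up encode() output: [CLS], one id per token, then [SEP]
encoded = [CLS_IDX, 11, 12, 13, UNK_IDX, SEP_IDX]
# made-up per-position predictions, one for every encoded id
predicted = ['X', 'DET', 'NOUN', 'VERB', 'PROPN', 'X']

# pairing without slicing shifts every id by one position
misaligned = [t for t, n in zip(tokens, encoded) if n == UNK_IDX]

# dropping [CLS] and [SEP] lines each id up with its own token
unks = [t for t, n in zip(tokens, encoded[1:-1]) if n == UNK_IDX]
tags = predicted[1:-1]

print(misaligned)  # []
print(unks)        # ['zzyzx']
print(list(zip(tokens, tags)))
```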

We can then run an example sentence through our model and receive the predicted tags.

In [33]:
sentence = 'The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow.'

tokens, tags, unks = tag_sentence(model, 
                                  device, 
                                  sentence,
                                  tokenizer,
                                  TEXT, 
                                  UD_TAGS)

print(unks)

[]


We can then print out the tokens and their corresponding tags.

Notice how "1pm" in the input sequence has been converted into the two tokens "1" and "##pm". What's with the two hash symbols in front of "pm"? This is due to the way the tokenizer splits sentences. BERT's tokenizer uses WordPiece, a subword scheme closely related to [byte pair encoding](https://en.wikipedia.org/wiki/Byte_pair_encoding), to split words into more common subsequences of characters. The `##` prefix marks a piece that continues the previous token rather than starting a new word.

In [34]:
print("Pred. Tag\tToken\n")

for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

Pred. Tag	Token

DET		the
NOUN		queen
AUX		will
VERB		deliver
DET		a
NOUN		speech
ADP		about
DET		the
NOUN		conflict
ADP		in
PROPN		north
PROPN		korea
ADP		at
NUM		1
NOUN		##pm
NOUN		tomorrow
PUNCT		.
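BERT's tokenizer first splits text on whitespace and punctuation (lowercasing it for the uncased model) and then applies WordPiece to each word: it repeatedly takes the longest prefix of the remaining characters that is in the vocabulary, marks continuation pieces with `##`, and maps the whole word to `[UNK]` if some part cannot be matched. A minimal sketch of that greedy step, assuming a tiny made-up vocabulary (`TOY_VOCAB`) in place of BERT's real one; `wordpiece` is a hypothetical helper:

```python
def wordpiece(word, vocab, unk_token='[UNK]'):
    """Greedy longest-match-first split of one word into subword pieces."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate  # mark pieces that continue a word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1  # try a shorter substring
        if piece is None:
            return [unk_token]  # no piece matched: the whole word becomes [UNK]
        pieces.append(piece)
        start = end
    return pieces


# hypothetical toy vocabulary, not BERT's real one
TOY_VOCAB = {'1', 'pm', '##pm', 'to', '##morrow'}

print(wordpiece('1pm', TOY_VOCAB))       # ['1', '##pm']
print(wordpiece('tomorrow', TOY_VOCAB))  # ['to', '##morrow']
print(wordpiece('xyz', TOY_VOCAB))       # ['[UNK]']
```

During training, however, the UDPOS words were passed straight to `convert_tokens_to_ids`, which skips this splitting step, so any word not in the vocabulary as a whole became `[UNK]` (the 100s in the training example printed earlier).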


We've now implemented a BiLSTM tagger using pretrained BERT embeddings! Well done us!