<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/pos_tagging/POS%20Tagging%20with%20BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |▌                               | 10kB 23.3MB/s eta 0:00:01[K     |█                               | 20kB 3.0MB/s eta 0:00:01[K     |█▌                              | 30kB 3.9MB/s eta 0:00:01[K     |██                              | 40kB 4.1MB/s eta 0:00:01[K     |██▍                             | 51kB 3.5MB/s eta 0:00:01[K     |███                             | 61kB 3.8MB/s eta 0:00:01[K     |███▍                            | 71kB 4.2MB/s eta 0:00:01[K     |███▉                            | 81kB 4.4MB/s eta 0:00:01[K     |████▍                           | 92kB 4.7MB/s eta 0:00:01[K     |████▉                           | 102kB 4.7MB/s eta 0:00:01[K     |█████▍                          | 112kB 4.7MB/s eta 0:00:01[K     |█████▉                          | 122kB 4.7

In [2]:
import time
import random
import functools

import spacy
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchtext import data, datasets
import torchtext.vocab as vocab

from transformers import BertTokenizer, BertModel

import matplotlib.pyplot as plt
import seaborn as sns

  import pandas.util.testing as tm


In [3]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [5]:
# get the special tokens, it is required to use these while encoding in Field
init_token = tokenizer.cls_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, pad_token_idx, unk_token_idx)

101 0 100


In [6]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

512


In [7]:
# limit the input tokens to max length - 1, [CLS] is the initial token
def cut_and_convert_to_id(tokens, tokenizer, max_input_length):
    tokens = tokens[:max_input_length-1]
    tokens = tokenizer.convert_tokens_to_ids(tokens)
    return tokens

In [8]:
# limit the tags to max length - 1
def cut_to_max_length(tokens, max_input_length):
    tokens = tokens[:max_input_length-1]
    return tokens

In [9]:
text_preprocessor = functools.partial(cut_and_convert_to_id,
                                      tokenizer = tokenizer,
                                      max_input_length = max_input_length)

tag_preprocessor = functools.partial(cut_to_max_length,
                                     max_input_length = max_input_length)

In [10]:
# initial token is [CLS]
# pad token is [PAD]
# unknown token is [UNK]
# lower=True as the model used is bert-base-uncased
TEXT = data.Field(use_vocab = False,
                  lower = True,
                  preprocessing = text_preprocessor,
                  init_token = init_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

# init token is <pad> as the pos corresponding to [CLS]
# it will be ignored while cal. the loss
# unknown token is not required as the tags are limited
UD_TAGS = data.Field(unk_token = None,
                     init_token = '<pad>',
                     preprocessing = tag_preprocessor)

In [11]:
fields = (("text", TEXT), ("udtags", UD_TAGS))

In [12]:
train_data, valid_data, test_data = datasets.UDPOS.splits(fields)

downloading en-ud-v2.zip


en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 2.19MB/s]


extracting


In [13]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 12543
Number of validation examples: 2002
Number of testing examples: 2077


In [14]:
print(vars(train_data.examples[0]))

{'text': [2632, 1011, 100, 1024, 2137, 2749, 2730, 100, 14093, 2632, 1011, 100, 1010, 1996, 14512, 2012, 1996, 8806, 1999, 1996, 2237, 1997, 100, 1010, 2379, 1996, 9042, 3675, 1012], 'udtags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']}


In [15]:
print(vars(train_data.examples[0])['text'])

[2632, 1011, 100, 1024, 2137, 2749, 2730, 100, 14093, 2632, 1011, 100, 1010, 1996, 14512, 2012, 1996, 8806, 1999, 1996, 2237, 1997, 100, 1010, 2379, 1996, 9042, 3675, 1012]


In [16]:
print(vars(train_data.examples[0])['udtags'])

['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']


In [17]:
UD_TAGS.build_vocab(train_data)

In [18]:
print(f"Tokens in UD_TAG vocabulary: {len(UD_TAGS.vocab)}")

Tokens in UD_TAG vocabulary: 18


In [19]:
print(UD_TAGS.vocab.itos)

['<pad>', 'NOUN', 'PUNCT', 'VERB', 'PRON', 'ADP', 'DET', 'PROPN', 'ADJ', 'AUX', 'ADV', 'CCONJ', 'PART', 'NUM', 'SCONJ', 'X', 'INTJ', 'SYM']


In [20]:
BATCH_SIZE = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

In [21]:
class POSTagger(nn.Module):
    def __init__(self, output_dim, dropout):
        super().__init__()
        
        # bert model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # bert model hidden size
        d_model = self.bert.config.to_dict()['hidden_size']
        
        # prediction layer
        self.fc = nn.Linear(d_model, output_dim)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, text):
        # text => [seq_len, batch_size]

        text = text.permute(1, 0)
        # text => [batch_size, seq_len]

        embedded = self.dropout(self.bert(text)[0])
        # only take the outputs, pooled output is not required
        # embedded => [batch_size, seq_len, d_model]

        embedded = embedded.permute(1, 0 , 2)
        # embedded => [seq_len, batch_size, d_model]

        predictions = self.fc(self.dropout(embedded))
        # predictions => [seq_len, batch_size, output_dim]

        return predictions

In [22]:
OUTPUT_DIM = len(UD_TAGS.vocab)
DROPOUT = 0.25

model = POSTagger(OUTPUT_DIM, DROPOUT)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




In [23]:
model = model.to(device)

In [24]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 109,496,082 trainable parameters


In [25]:
# learning rate should be low, as this is a fine-tuning process
LR = 5e-5
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX).to(device)

In [26]:
def categorical_accuracy(preds, y, tag_pad_idx):
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]])

In [27]:
def train(model, iterator, criterion, optimizer, tag_pad_idx):
    model.train()

    epoch_loss = 0
    epoch_acc = 0

    for batch in iterator:
        text = batch.text
        tags = batch.udtags
        # text => [seq_len, batch_size]
        # tags => [seq_len, batch_size]

        optimizer.zero_grad()

        logits = model(text)
        # logits => [seq_len, batch_size, output_dim]

        logits = logits.view(-1, logits.shape[-1])
        # logits => [seq_len * batch_size, output_dim]

        tags = tags.view(-1)
        # tags => [seq_len * batch_size]

        loss = criterion(logits, tags)
        acc = categorical_accuracy(logits, tags, tag_pad_idx)

        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [28]:
def evaluate(model, iterator, criterion, trg_pad_idx):
    model.eval()

    epoch_loss = 0
    epoch_acc = 0

    with torch.no_grad():
        for batch in iterator:
            text = batch.text
            tags = batch.udtags
            # text => [seq_len, batch_size]
            # tags => [seq_len, batch_size]

            optimizer.zero_grad()

            logits = model(text)
            # logits => [seq_len, batch_size, output_dim]

            logits = logits.view(-1, logits.shape[-1])
            # logits => [seq_len * batch_size, output_dim]

            tags = tags.view(-1)
            # tags => [seq_len * batch_size]

            loss = criterion(logits, tags)
            acc = categorical_accuracy(logits, tags, trg_pad_idx)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator)


In [29]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


In [30]:
N_EPOCHS = 10
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, criterion, optimizer, TAG_PAD_IDX)
    valid_loss, val_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f} | Val. Loss: {valid_loss:.3f} | Val Acc: {val_acc * 100:.2f}')



Epoch: 01 | Epoch Time: 1m 15s
	Train Loss: 0.385 | Train Acc: 89.09 | Val. Loss: 0.296 | Val Acc: 90.95
Epoch: 02 | Epoch Time: 1m 15s
	Train Loss: 0.117 | Train Acc: 96.62 | Val. Loss: 0.275 | Val Acc: 91.65
Epoch: 03 | Epoch Time: 1m 15s
	Train Loss: 0.077 | Train Acc: 97.81 | Val. Loss: 0.259 | Val Acc: 92.08
Epoch: 04 | Epoch Time: 1m 15s
	Train Loss: 0.054 | Train Acc: 98.46 | Val. Loss: 0.269 | Val Acc: 91.95
Epoch: 05 | Epoch Time: 1m 14s
	Train Loss: 0.040 | Train Acc: 98.83 | Val. Loss: 0.287 | Val Acc: 92.25
Epoch: 06 | Epoch Time: 1m 15s
	Train Loss: 0.030 | Train Acc: 99.14 | Val. Loss: 0.298 | Val Acc: 92.99
Epoch: 07 | Epoch Time: 1m 15s
	Train Loss: 0.025 | Train Acc: 99.27 | Val. Loss: 0.306 | Val Acc: 92.68
Epoch: 08 | Epoch Time: 1m 15s
	Train Loss: 0.020 | Train Acc: 99.40 | Val. Loss: 0.382 | Val Acc: 91.72
Epoch: 09 | Epoch Time: 1m 15s
	Train Loss: 0.018 | Train Acc: 99.48 | Val. Loss: 0.338 | Val Acc: 92.23
Epoch: 10 | Epoch Time: 1m 15s
	Train Loss: 0.015 | Tra

In [31]:
model.load_state_dict(torch.load('model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc * 100:.2f}')


Test Loss: 0.285 | Test Acc: 91.12


In [32]:
def tag_sentence(model, device, sentence, tokenizer, text_field, tag_field):
    
    model.eval()
    
    if isinstance(sentence, str):
        tokens = tokenizer.tokenize(sentence)
    else:
        tokens = sentence
    
    numericalized_tokens = tokenizer.convert_tokens_to_ids(tokens)
    numericalized_tokens = [text_field.init_token] + numericalized_tokens
        
    unk_idx = text_field.unk_token
    
    unks = [t for t, n in zip(tokens, numericalized_tokens) if n == unk_idx]
    
    token_tensor = torch.LongTensor(numericalized_tokens)
    
    token_tensor = token_tensor.unsqueeze(-1).to(device)
         
    predictions = model(token_tensor)
    
    top_predictions = predictions.argmax(-1)
    
    predicted_tags = [tag_field.vocab.itos[t.item()] for t in top_predictions]
    
    predicted_tags = predicted_tags[1:]
        
    assert len(tokens) == len(predicted_tags)
    
    return tokens, predicted_tags, unks

In [40]:
sentence = 'The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow.'

tokens, tags, unks = tag_sentence(model, 
                                  device, 
                                  sentence,
                                  tokenizer,
                                  TEXT, 
                                  UD_TAGS)

In [41]:
unks

[]

In [42]:
print("Pred. Tag\tToken\n")

for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")


Pred. Tag	Token

DET		the
PROPN		queen
AUX		will
VERB		deliver
DET		a
NOUN		speech
ADP		about
DET		the
NOUN		conflict
ADP		in
PROPN		north
PROPN		korea
ADP		at
NUM		1
NOUN		##pm
NOUN		tomorrow
PUNCT		.


In [36]:
sentence = 'I love this movie'

tokens, tags, unks = tag_sentence(model, 
                                  device, 
                                  sentence,
                                  tokenizer,
                                  TEXT, 
                                  UD_TAGS)


print("Pred. Tag\tToken\n")

for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

Pred. Tag	Token

PRON		i
VERB		love
DET		this
NOUN		movie
