# Sentiment Analysis with Frozen BERT

In a previous BERT text classification tutorial, we fine-tuned the BERT model (i.e. updated all of its weights). In this tutorial, we freeze the transformer (do not train it) and only train the rest of the model, which learns from the representations the transformer produces. Here we use a **multi-layer bidirectional GRU**, but any other model could learn from these representations.

Note: for simplicity, code that is identical to the previous tutorial is not explained again.

## Preparing the Data

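```python
# this notebook was run with transformers 2.4.1; newer versions may change some APIs used below
!pip install transformers
```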


```python
import torch
import random
import numpy as np

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
```

```python
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
```


```python
len(tokenizer.vocab)
```

```
30522
```

```python
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)
```

```
[CLS] [SEP] [PAD] [UNK]
```


```python
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)
```

```
101 102 0 100
```


```python
MAX_INPUT_LENGTH = tokenizer.max_model_input_sizes['bert-base-uncased']
print(MAX_INPUT_LENGTH)
```

```
512
```


```python
def tokenize_and_cut(sentence, max_length):
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_length-2]  # keep 2 slots free for [CLS] and [SEP]
    return tokens
```
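Slicing to `max_length-2` leaves room for the two special tokens BERT expects around every sequence, so the final input never exceeds `MAX_INPUT_LENGTH`. The plain-Python sketch below reproduces that arithmetic on token ids; the helper name `cut_and_wrap` is hypothetical, and 101 and 102 are the `[CLS]` and `[SEP]` ids printed above.

```python
def cut_and_wrap(token_ids, max_length, cls_id=101, sep_id=102):
    """Hypothetical helper: truncate to max_length - 2 ids, then add [CLS] and [SEP].

    Mirrors tokenize_and_cut followed by adding the special tokens,
    which is roughly what tokenizer.encode(..., max_length=...) does for one sentence.
    """
    body = token_ids[:max_length - 2]  # reserve two slots for the special tokens
    return [cls_id] + body + [sep_id]

# A toy "review" of 600 ids (values are arbitrary placeholders)
long_review = list(range(1000, 1600))
wrapped = cut_and_wrap(long_review, max_length=512)
print(len(wrapped))             # 512
print(wrapped[0], wrapped[-1])  # 101 102

# Short inputs are only wrapped, not truncated ("hello world" ids from the cell below)
print(cut_and_wrap([7592, 2088], max_length=512))  # [101, 7592, 2088, 102]
```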

```python
tokens = tokenizer.encode('Hello WORLD how ARE yoU?', max_length=MAX_INPUT_LENGTH)

print(tokens)
```

```
[101, 7592, 2088, 2129, 2024, 2017, 1029, 102]
```


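```python
# Legacy torchtext API (Field, LabelField, BucketIterator); later torchtext
# releases moved these classes to torchtext.legacy and then removed them
from torchtext import data, datasets

TEXT = data.Field(
    batch_first = True,
    use_vocab = False,  # BERT's tokenizer already maps tokens to ids
    # encode() adds [CLS]/[SEP] and truncates to MAX_INPUT_LENGTH
    tokenize = lambda x: tokenizer.encode(x, max_length=MAX_INPUT_LENGTH),
    # alternative: tokenize manually and let the Field add the special tokens
    # tokenize = lambda x: tokenize_and_cut(x, max_length=MAX_INPUT_LENGTH),
    # preprocessing = tokenizer.convert_tokens_to_ids,
    # init_token = init_token_idx,
    # eos_token = eos_token_idx,
    pad_token = pad_token_idx,
    unk_token = unk_token_idx
)

LABEL = data.LabelField(dtype=torch.float)
```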

```python
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
# random.seed() returns None, so seed first and pass the RNG state explicitly
random.seed(SEED)
train_data, valid_data = train_data.split(random_state = random.getstate())
```

```
downloading aclImdb_v1.tar.gz
aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:07<00:00, 10.5MB/s]
```


```python
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")
```

```
Number of training examples: 17500
Number of validation examples: 7500
Number of testing examples: 25000
```
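These counts follow from the split ratio: as far as I know, the legacy torchtext `Dataset.split` defaults to `split_ratio=0.7`, so 70% of the 25,000 IMDB training reviews stay in the training set and the remaining 30% become the validation set. The arithmetic, as a quick check:

```python
n_imdb_train = 25_000  # size of the original IMDB training split
split_ratio = 0.7      # assumed legacy torchtext default for Dataset.split

n_train = round(n_imdb_train * split_ratio)
n_valid = n_imdb_train - n_train
print(n_train, n_valid)  # 17500 7500
```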


```python
print(vars(train_data.examples[0]))
```

```
{'text': [101, 2023, 2003, 2069, 5399, 8702, 2005, 4599, 1997, 1000, 2919, 3185, 1000, 4024, 1012, 2009, 2003, 2062, 4276, 19927, 2005, 2493, 1997, 3359, 1005, 1055, 3769, 3226, 1024, 1996, 4827, 2015, 1010, 1996, 7390, 1010, 1996, 13818, 1010, 1998, 2008, 2307, 1000, 2308, 1005, 1055, 5622, 2497, 1000, 2617, 1997, 1996, 2220, 3359, 1005, 1055, 1010, 2043, 2009, 2001, 2145, 4840, 1998, 3117, 2005, 1037, 2969, 1011, 4846, 1010, 2981, 2450, 2000, 4839, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1000, 3565, 5428, 3600, 1000, 1006, 11830, 10454, 3385, 1007, 2018, 1037, 12256, 12928, 11272, 10377, 2075, 2065, 24646, 7096, 11787, 2476, 1006, 2044, 2035, 1010, 2054, 2003, 1037, 3462, 16742, 2021, 1037, 13877, 2012, 2382, 1010, 2199, 2519, 1011, 1011, 2008, 3632, 2005, 1996, 3287, 3924, 2205, 1007, 1010, 2016, 7771, 2105, 2007, 3674, 2273, 1010, 2071, 4047, 2841, 1998, 2500, 1006, 2007, 16894, 1007, 1998, 2347, 1005, 1056, 5079, 2091, 2000, 2505, 1012, 2023, 2003, 1996, 2785, 1997, 
```

```python
tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[0])['text'])

print(tokens)
```

```
['[CLS]', 'this', 'is', 'only', 'somewhat', 'attractive', 'for', 'fans', 'of', '"', 'bad', 'movie', '"', 'entertainment', '.', 'it', 'is', 'more', 'worth', '##while', 'for', 'students', 'of', '1970', "'", 's', 'pop', 'culture', ':', 'the', 'fashion', '##s', ',', 'the', 'furniture', ',', 'the', 'attitudes', ',', 'and', 'that', 'great', '"', 'women', "'", 's', 'li', '##b', '"', 'moment', 'of', 'the', 'early', '1970', "'", 's', ',', 'when', 'it', 'was', 'still', 'fresh', 'and', 'novel', 'for', 'a', 'self', '-', 'employed', ',', 'independent', 'woman', 'to', 'exist', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', '"', 'super', '##chi', '##ck', '"', '(', 'joyce', 'jill', '##son', ')', 'had', 'a', 'mon', '##eta', '##rily', 'reward', '##ing', 'if', 'stu', '##lt', '##ifying', 'career', '(', 'after', 'all', ',', 'what', 'is', 'a', 'flight', 'attendant', 'but', 'a', 'waitress', 'at', '30', ',', '000', 'feet', '-', '-', 'that', 'goes', 'for', 'the', 'male', 'ones', 'too', ')', ',', 'she', 'slept'
```

```python
LABEL.build_vocab(train_data)
print(LABEL.vocab.stoi)
```

```
defaultdict(<function _default_unk_index at 0x7fcfff64c488>, {'neg': 0, 'pos': 1})
```


```python
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)
```
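`BucketIterator` groups examples of similar length into the same batch, so each batch needs less padding than a randomly assembled one. The plain-Python sketch below only illustrates that idea; it is not torchtext's implementation (the real iterator differs in details such as shuffling), and the helper names and toy lengths are made up. Padding uses id 0, matching `pad_token_idx` above.

```python
import random

PAD_ID = 0  # pad_token_idx for bert-base-uncased, printed earlier

def pad_batch(batch, pad_id=PAD_ID):
    """Pad every sequence in a batch to the length of its longest member."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]

def count_padding(batches):
    """Total number of pad ids that padding these batches would add."""
    total = 0
    for batch in batches:
        longest = max(len(seq) for seq in batch)
        total += sum(longest - len(seq) for seq in batch)
    return total

def make_batches(examples, batch_size, bucket=False):
    """Hypothetical helper: chunk examples into batches, sorting by length first if bucket=True."""
    ordered = sorted(examples, key=len) if bucket else list(examples)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

print(pad_batch([[101, 7592, 102], [101, 102]]))  # [[101, 7592, 102], [101, 102, 0]]

rng = random.Random(1234)
# 64 toy "reviews" with random lengths between 10 and 500 tokens
examples = [[1] * rng.randint(10, 500) for _ in range(64)]

random_pad = count_padding(make_batches(examples, batch_size=8))
bucket_pad = count_padding(make_batches(examples, batch_size=8, bucket=True))
print(random_pad, bucket_pad)  # the bucketed count is never larger
```

With full batches of a fixed size, sorting before chunking minimizes the sum of per-batch maximum lengths, which is why the bucketed padding count cannot exceed the unsorted one.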

## Build the Model

```python
from transformers import BertModel
bert = BertModel.from_pretrained('bert-base-uncased')
```


Next, we build the actual model on top of the base BERT model. Instead of using an embedding layer to get embeddings for our text, we use the pre-trained transformer. These embeddings are then fed into a GRU, which produces a prediction for the sentiment of the input sentence. We get the embedding dimension (called `hidden_size`, 768 for `bert-base-uncased`) from the transformer's config attribute. The rest of the initialization is standard.

Within the forward pass, we wrap the transformer in `torch.no_grad()` so that no gradients are calculated for this part of the model. The transformer returns the embeddings for the whole sequence as well as a *pooled* output. The [documentation](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel) states that the pooled output is "usually not a good summary of the semantic content of the input, you’re often better with averaging or pooling the sequence of hidden-states for the whole input sequence", so we do not use it. The rest of the forward pass is a standard **bidirectional GRU**: we take the final hidden states of the forward and backward directions of the top layer, concatenate them, and pass the result through a linear layer to get our prediction.
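To see where `hidden[-2]` and `hidden[-1]` come from, the sketch below runs a small bidirectional GRU with random weights on a toy tensor standing in for BERT's output (the sizes are illustrative, not BERT's 768). `hidden` is laid out as `[n_layers * n_directions, batch, hidden_dim]`, so its last two entries are the forward and backward directions of the top layer. The forward state matches the top-layer output at the last time step, while the backward state matches the output at the first time step, because the backward GRU reads the sequence right to left.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, emb_dim, hidden_dim, n_layers = 4, 7, 16, 8, 2

# Toy stand-in for BERT's [batch, seq_len, emb_dim] output
embedded = torch.randn(batch_size, seq_len, emb_dim)

rnn = nn.GRU(emb_dim, hidden_dim, num_layers=n_layers,
             bidirectional=True, batch_first=True)
outputs, hidden = rnn(embedded)
# outputs = [batch, seq_len, 2 * hidden_dim]  (top layer only, forward then backward)
# hidden  = [n_layers * 2, batch, hidden_dim]

forward_final = hidden[-2]   # top layer, forward direction
backward_final = hidden[-1]  # top layer, backward direction

# The forward direction finishes at the last time step ...
print(torch.allclose(forward_final, outputs[:, -1, :hidden_dim]))  # True
# ... while the backward direction finishes at the first time step
print(torch.allclose(backward_final, outputs[:, 0, hidden_dim:]))  # True

summary = torch.cat((forward_final, backward_final), dim=1)
print(summary.shape)  # torch.Size([4, 16])
```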





```python
import torch.nn as nn

class BertGRUSentiment(nn.Module):
    def __init__(self, bert, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.bert = bert
        embedding_dim = bert.config.hidden_size  # 768 for bert-base-uncased
        self.rnn = nn.GRU(embedding_dim, hidden_dim,
                          num_layers=n_layers, bidirectional=bidirectional,
                          batch_first=True,  # BERT outputs [batch, seq_len, hidden]
                          dropout=0 if n_layers < 2 else dropout)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)

    def forward(self, text):
        # text = [batch_size, sent_len]
        with torch.no_grad():  # BERT is frozen: no gradients needed
            embedded = self.bert(text)[0]  # last-layer hidden states for every token
        # embedded = [batch_size, sent_len, emb_dim]

        _, hidden = self.rnn(embedded)
        # hidden = [n_layers * n_directions, batch_size, hidden_dim]

        if self.rnn.bidirectional:
            # concatenate the top layer's final forward and backward states
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
        # hidden = [batch_size, hidden_dim * n_directions]

        output = self.out(hidden)
        # output = [batch_size, output_dim]

        return output
```






Next, we create an instance of our model using standard hyperparameters.

```python
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

model = BertGRUSentiment(bert, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)
```

To freeze parameters (not train them), we set their `requires_grad` attribute to `False`. We loop through all of the `named_parameters` in our model and, for every parameter that belongs to the `bert` transformer, set `requires_grad = False`.
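The freezing rule can be tried on a toy model first. In the sketch below, `TinyModel` is a hypothetical stand-in whose `bert` attribute is just a linear layer; after applying the same name-prefix loop, only the head's parameters remain trainable, and one Adam step leaves the frozen weights untouched while updating the head.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class TinyModel(nn.Module):
    """Hypothetical stand-in: 'bert' plays the frozen encoder, 'out' the trainable head."""
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(10, 6)
        self.out = nn.Linear(6, 1)

    def forward(self, x):
        return self.out(self.bert(x))

model = TinyModel()

# Same rule as the tutorial: freeze every parameter whose name starts with 'bert'
for name, param in model.named_parameters():
    if name.startswith('bert'):
        param.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['out.weight', 'out.bias']

bert_before = model.bert.weight.clone()
out_before = model.out.weight.clone()

# Passing all parameters is fine: Adam skips those whose .grad is None
optimizer = optim.Adam(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()

print(torch.equal(model.bert.weight, bert_before))  # True: frozen weights unchanged
print(torch.equal(model.out.weight, out_before))    # False: the head was updated
```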

```python
for name, param in model.named_parameters():
    if name.startswith('bert'):
        param.requires_grad = False
```

```python
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
```

```
The model has 2,759,169 trainable parameters
```


Compared with fine-tuning the whole BERT model, the number of trainable parameters drops from over 100M (`bert-base-uncased` alone has about 110M) to under 3M.
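The 2,759,169 figure can be reproduced by hand. Each GRU direction in a layer has an input-to-hidden weight (`3h x d_in`), a hidden-to-hidden weight (`3h x h`) and two bias vectors of size `3h`, i.e. `3*h*(d_in + h) + 6*h` parameters for input size `d_in` and hidden size `h`. The first layer reads BERT's 768-dimensional embeddings; the second reads the 512-dimensional (2 x 256) bidirectional output of the first. The helper name below is illustrative.

```python
def gru_layer_params(input_size, hidden_size, bidirectional=True):
    """Parameters of one nn.GRU layer: weight_ih, weight_hh, bias_ih, bias_hh for 3 gates."""
    per_direction = 3 * hidden_size * (input_size + hidden_size) + 6 * hidden_size
    return per_direction * (2 if bidirectional else 1)

EMB_DIM = 768     # bert-base hidden_size
HIDDEN_DIM = 256
OUTPUT_DIM = 1

layer0 = gru_layer_params(EMB_DIM, HIDDEN_DIM)         # reads BERT embeddings
layer1 = gru_layer_params(2 * HIDDEN_DIM, HIDDEN_DIM)  # reads both directions of layer 0
linear = 2 * HIDDEN_DIM * OUTPUT_DIM + OUTPUT_DIM      # weight + bias of self.out

total = layer0 + layer1 + linear
print(f'{total:,}')  # 2,759,169
```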

We can double check the names of the trainable parameters, ensuring they make sense. As we can see, they are all the parameters of the GRU (rnn) and the linear layer (out).

```python
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
```

```
rnn.weight_ih_l0
rnn.weight_hh_l0
rnn.bias_ih_l0
rnn.bias_hh_l0
rnn.weight_ih_l0_reverse
rnn.weight_hh_l0_reverse
rnn.bias_ih_l0_reverse
rnn.bias_hh_l0_reverse
rnn.weight_ih_l1
rnn.weight_hh_l1
rnn.bias_ih_l1
rnn.bias_hh_l1
rnn.weight_ih_l1_reverse
rnn.weight_hh_l1_reverse
rnn.bias_ih_l1_reverse
rnn.bias_hh_l1_reverse
out.weight
out.bias
```


## Train and Test the Model

```python
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
```

```python
criterion = nn.BCEWithLogitsLoss()
```

```python
model = model.to(device)
criterion = criterion.to(device)
```

```python
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    # convert logits to probabilities, then round to 0 or 1
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc
```
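Because the model outputs raw logits (the loss is `BCEWithLogitsLoss`), `binary_accuracy` first squashes them with a sigmoid and then rounds to 0 or 1. The plain-Python mirror below traces the same steps on made-up logits; the function name is hypothetical. Like `torch.round`, Python's `round` sends an exact 0.5 to 0, so a logit of exactly 0 counts as negative.

```python
import math

def binary_accuracy_py(logits, labels):
    """Pure-Python mirror of binary_accuracy: sigmoid, round to 0/1, compare, average."""
    probs = [1 / (1 + math.exp(-z)) for z in logits]
    rounded = [round(p) for p in probs]  # 0 or 1
    correct = [float(r == y) for r, y in zip(rounded, labels)]
    return sum(correct) / len(correct)

# Made-up logits for 10 reviews; labels follow LABEL.vocab.stoi (neg=0, pos=1)
logits = [2.3, -1.1, 0.7, -3.0, 4.2, -0.2, 1.5, -2.4, 0.9, -0.5]
labels = [1, 0, 1, 0, 1, 0, 1, 0, 0, 1]
print(binary_accuracy_py(logits, labels))  # 0.8
```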

```python
def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    epoch_acc = 0

    for batch in iterator:
        optimizer.zero_grad()

        predictions = model(batch.text).squeeze(1)
        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
```

```python
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch.text).squeeze(1)
            loss = criterion(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
```

```python
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
```

```python
N_EPOCHS = 5
best_valid_loss = float("inf")
best_model = "bert_frozen.pt"

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # keep only the checkpoint with the lowest validation loss
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), best_model)

    print(f'Epoch: {epoch+1} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
```

```
Epoch: 1 | Epoch Time: 7m 14s
	Train Loss: 0.505 | Train Acc: 74.03%
	 Val. Loss: 0.395 |  Val. Acc: 83.69%
Epoch: 2 | Epoch Time: 7m 14s
	Train Loss: 0.275 | Train Acc: 88.85%
	 Val. Loss: 0.258 |  Val. Acc: 89.70%
Epoch: 3 | Epoch Time: 7m 14s
	Train Loss: 0.233 | Train Acc: 90.97%
	 Val. Loss: 0.227 |  Val. Acc: 91.00%
Epoch: 4 | Epoch Time: 7m 14s
	Train Loss: 0.207 | Train Acc: 92.04%
	 Val. Loss: 0.232 |  Val. Acc: 90.90%
Epoch: 5 | Epoch Time: 7m 14s
	Train Loss: 0.181 | Train Acc: 93.02%
	 Val. Loss: 0.228 |  Val. Acc: 91.64%
```
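Because a checkpoint is written only when the validation loss improves, the weights reloaded for testing are not necessarily those from the last epoch. Replaying the same selection rule over the validation losses in the log shows that the saved model comes from epoch 3 (0.227), not epoch 5 (0.228), even though epoch 5 has the higher validation accuracy. The logged losses are rounded, but since the rounded values differ, their order is reliable.

```python
# Validation losses copied from the training log (epochs 1-5)
valid_losses = [0.395, 0.258, 0.227, 0.232, 0.228]

best_valid_loss = float('inf')
saved_epoch = None
for epoch, valid_loss in enumerate(valid_losses, start=1):
    if valid_loss < best_valid_loss:  # same rule as the training loop
        best_valid_loss = valid_loss
        saved_epoch = epoch           # torch.save(...) would run here

print(saved_epoch, best_valid_loss)  # 3 0.227
```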


```python
model.load_state_dict(torch.load(best_model))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
```

```
Test Loss: 0.202 | Test Acc: 91.91%
```


The test accuracy of 91.91% is close to, or even slightly better than, the result of fine-tuning the whole BERT model in the previous tutorial, while training only about 2.8M parameters.

## Inference on New Sentence

```python
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    # encode() adds [CLS] and [SEP]; max_length guards against inputs longer than BERT accepts
    indexed = tokenizer.encode(sentence, max_length=MAX_INPUT_LENGTH)
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(0)  # add a batch dimension: [1, sent_len]
    with torch.no_grad():
        prediction = torch.sigmoid(model(tensor))
    return prediction.item()  # probability that the sentence is positive
```
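`predict_sentiment` returns the sigmoid probability of the positive class, since `LABEL.vocab.stoi` maps `'neg'` to 0 and `'pos'` to 1. A hypothetical helper can turn that number into a label; the 0.5 threshold is an assumption, chosen so that an exact 0.5 maps to `'neg'`, matching the rounding in `binary_accuracy`.

```python
def to_label(probability, threshold=0.5):
    """Hypothetical helper: map P(pos) to 'pos'/'neg'; exactly 0.5 maps to 'neg' like torch.round."""
    return 'pos' if probability > threshold else 'neg'

# Probabilities returned by predict_sentiment for the two example sentences
print(to_label(0.01239826250821352))  # neg
print(to_label(0.9740698933601379))   # pos
```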

```python
predict_sentiment(model, tokenizer, "This film is terrible")
```

```
0.01239826250821352
```

```python
predict_sentiment(model, tokenizer, "This film is great")
```

```
0.9740698933601379
```

## References

- https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb#scrollTo=hSIop_NnLdCx