# Bringing contextual word representations into your models

In [1]:
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2022"

## Contents

1. [Overview](#Overview)
1. [General set-up](#General-set-up)
1. [Hugging Face BERT models and tokenizers](#Hugging-Face-BERT-models-and-tokenizers)
1. [BERT featurization with Hugging Face](#BERT-featurization-with-Hugging-Face)
    1. [Simple feed-forward experiment](#Simple-feed-forward-experiment)
    1. [A feed-forward experiment with the sst module](#A-feed-forward-experiment-with-the-sst-module)
    1. [An RNN experiment with the sst module](#An-RNN-experiment-with-the-sst-module)
1. [BERT fine-tuning with Hugging Face](#BERT-fine-tuning-with-Hugging-Face)
    1. [HfBertClassifier](#HfBertClassifier)
    1. [HfBertClassifier experiment](#HfBertClassifier-experiment)

## Overview

This notebook provides a basic introduction to using pre-trained [BERT](https://github.com/google-research/bert) representations with the Hugging Face library. It is meant as a practical companion to our lecture on contextual word representations. The goal of this notebook is just to help you use these representations in your own work.

If you haven't already, I encourage you to review the notebook [vsm_03_contextualreps.ipynb](vsm_03_contextualreps.ipynb) before working with this one. That notebook covers the fundamentals of these models; this one dives into the details more quickly.

A number of the experiments in this notebook are resource-intensive. I've included timing information for the expensive steps, to give you a sense for how long things are likely to take. I ran this notebook on a laptop with a single NVIDIA RTX 2080 GPU. 

## General set-up

The following are requirements that you'll already have met if you've been working in this repository. As you can see, we'll use the [Stanford Sentiment Treebank](sst_01_overview.ipynb) for illustrations, and we'll try out a few different deep learning models.

In [1]:
%env TOKENIZERS_PARALLELISM=true
import os
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import transformers
from transformers import BertModel, DistilBertTokenizerFast

from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
from torch_rnn_classifier import TorchRNNModel
from torch_rnn_classifier import TorchRNNClassifier
from torch_rnn_classifier import TorchRNNClassifierModel
import sst
import utils


env: TOKENIZERS_PARALLELISM=true


In [2]:
utils.fix_random_seeds()
device = "cuda:0" if torch.cuda.is_available() else "cpu"


In [3]:
SST_HOME = os.path.join("data", "sentiment")

The `transformers` library does a lot of logging. To avoid ending up with a cluttered notebook, I am changing the logging level. You might want to skip this as you scale up to building production systems, since the logging is very good – it gives you a lot of insights into what the models and code are doing.

In [4]:
transformers.logging.set_verbosity_error()

## Hugging Face BERT models and tokenizers

We'll illustrate with the BERT-base cased model:

In [5]:
weights_name = 'bert-base-cased'

There are lots of other options for pretrained weights. See [this Hugging Face directory](https://huggingface.co/models).

Next, we specify a tokenizer and a model that match both each other and our choice of pretrained weights:

In [6]:
bert_tokenizer = DistilBertTokenizerFast.from_pretrained(weights_name)

In [7]:
bert_model = BertModel.from_pretrained(weights_name).to(device)

For modeling (as opposed to creating static representations), we will mostly process examples in batches – generally very small ones, as these models consume _a lot_ of memory. Here's a small batch of texts to use as the starting point for illustrations:

In [8]:
example_texts = [
    "Encode sentence 1. [SEP] And sentence 2!",
    "Bert knows Snuffleupagus"]

We will often need to pad (and perhaps truncate) token lists so that we can work with fixed-dimensional tensors. The `batch_encode_plus` method has a lot of options for doing this:

In [9]:
example_ids = bert_tokenizer.batch_encode_plus(
    example_texts,
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest')

In [10]:
example_ids.keys()

dict_keys(['input_ids', 'attention_mask'])

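The tokenizer class used here returns only these two fields. BERT's own tokenizer classes additionally return `token_type_ids`, which is used for multi-text inputs like NLI. The `'input_ids'` field gives the token indices for each of the two examples: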

In [11]:
example_ids['input_ids']

[[101, 13832, 13775, 5650, 122, 119, 102, 1262, 5650, 123, 106, 102],
 [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102, 0, 0]]

Notice that the final two tokens of the second example are pad tokens.

For fine-tuning, we want to avoid attending to padded tokens. The `'attention_mask'` captures the needed mask, which we'll be able to feed directly to the pretrained BERT model:

In [12]:
example_ids['attention_mask']

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]
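
To make the relationship between padding and the attention mask concrete, here is a minimal pure-Python sketch of what `padding='longest'` amounts to. The function `pad_batch` is a hypothetical helper written for illustration, not part of the tokenizer API, and it assumes that the pad token has index 0, as in the output above.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad each list of token ids to the length of the longest one.

    `pad_id=0` is an assumption that matches the pad index shown above.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks a padded position.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


# Unpadded ids copied from the two examples above:
batch = [
    [101, 13832, 13775, 5650, 122, 119, 102, 1262, 5650, 123, 106, 102],
    [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102]]

padded_ids, mask = pad_batch(batch)
```

The second example gains two pad indices, and its mask gets two matching zeros, which is the pattern in `example_ids` above.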

Finally, we can run these indices and masks through the pretrained model:

In [13]:
X_example = torch.tensor(example_ids['input_ids']).to(device)
X_example_mask = torch.tensor(example_ids['attention_mask']).to(device)

with torch.no_grad():
    reps = bert_model(X_example, attention_mask=X_example_mask)

Hugging Face BERT models create a special `pooler_output` representation: the final hidden state above the `[CLS]` token, passed through one additional dense layer with a tanh activation:

In [14]:
reps.pooler_output.shape

torch.Size([2, 768])

We have two examples, each represented by a single vector of dimension 768, which is $d_{model}$ for BERT base using the notation from [the original Transformer paper](https://arxiv.org/abs/1706.03762). This is an easy basis for fine-tuning, as we will see.

We can also access the final output for each state:

In [15]:
reps.last_hidden_state.shape

torch.Size([2, 12, 768])

Here, we have 2 examples, each padded to the length of the longer one (12), and each of those representations has dimension 768. These representations can be used for sequence modeling, or pooled somehow for simple classifiers.
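
One simple way to pool these per-token representations is to average them while skipping padded positions. The following is a minimal pure-Python sketch on toy data. The function `masked_mean` and the toy inputs are hypothetical, and nested lists stand in for the tensors purely for readability.

```python
def masked_mean(hidden_states, attention_mask):
    """Average the token vectors of each example, skipping padded positions.

    hidden_states: nested lists with shape (batch, seq_len, dim)
    attention_mask: nested lists with shape (batch, seq_len); 1 = real token
    """
    pooled = []
    for vectors, mask in zip(hidden_states, attention_mask):
        kept = [vec for vec, m in zip(vectors, mask) if m == 1]
        dim = len(vectors[0])
        pooled.append(
            [sum(vec[j] for vec in kept) / len(kept) for j in range(dim)])
    return pooled


# Toy batch: 2 examples, 3 positions, dimension 2.
toy_states = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]]]  # Last position is padding.
toy_mask = [[1, 1, 1], [1, 1, 0]]

pooled = masked_mean(toy_states, toy_mask)
```

Without the mask, the padded vector `[9.0, 9.0]` would distort the average for the second example.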

Those are all the essential ingredients for working with these parameters in Hugging Face. Of course, the library has a lot of other functionality, but the above suffices to featurize and to fine-tune.

## BERT featurization with Hugging Face

To start, we'll use the Hugging Face interfaces just to featurize examples to create inputs to a separate model. In this setting, the BERT parameters are frozen.

In [16]:
def bert_phi(text):
    input_ids = bert_tokenizer.encode(text, add_special_tokens=True)
    X = torch.tensor([input_ids]).to(device)
    with torch.no_grad():
        reps = bert_model(X)
        return reps.last_hidden_state.squeeze(0).to("cpu").numpy()

### Simple feed-forward experiment

For a simple feed-forward experiment, we can get the representation of the `[CLS]` tokens and use them as the inputs to a shallow neural network:

In [17]:
def bert_classifier_phi(text):
    reps = bert_phi(text)
    #return reps.mean(axis=0)  # Another good, easy option.
    return reps[0]

Next we read in the SST train and dev splits:

In [18]:
train = sst.train_reader(SST_HOME)

dev = sst.dev_reader(SST_HOME)

Split the input/output pairs out into separate lists:

In [19]:
X_str_train = train.sentence.values
y_train = train.label.values

X_str_dev = dev.sentence.values
y_dev = dev.label.values

In the next step, we featurize all of the examples. These steps are likely to be the slowest in these experiments:

In [21]:
%time X_train = [bert_classifier_phi(text) for text in X_str_train]

CPU times: user 2min 13s, sys: 835 ms, total: 2min 14s
Wall time: 2min 14s


In [22]:
%time X_dev = [bert_classifier_phi(text) for text in X_str_dev]

CPU times: user 17.2 s, sys: 67.9 ms, total: 17.3 s
Wall time: 17.3 s


Now that all the examples are featurized, we can fit a model and evaluate it:

In [28]:
model = TorchShallowNeuralClassifier(
    early_stopping=True,
    hidden_dim=300, )

In [29]:
%time _ = model.fit(X_train, y_train)

Stopping after epoch 130. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.291560411453247

CPU times: user 2min 3s, sys: 223 ms, total: 2min 3s
Wall time: 13.2 s
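
The stopping message above reflects patience-based early stopping. Here is a minimal sketch of that logic, assuming that a higher validation score is better and that training halts once the score has failed to improve by `tol` for more than `n_iter_no_change` consecutive epochs. The function `early_stopping_epoch` is hypothetical and is not the code used by `TorchShallowNeuralClassifier`.

```python
def early_stopping_epoch(scores, tol=1e-5, n_iter_no_change=10):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("-inf")
    stale = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best + tol:
            # A real improvement: record it and reset the counter.
            best = score
            stale = 0
        else:
            stale += 1
        if stale > n_iter_no_change:
            return epoch
    return None


# Toy validation scores: two improving epochs, then a long plateau.
toy_scores = [0.5, 0.6] + [0.6] * 11
stop_epoch = early_stopping_epoch(toy_scores)
```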


In [30]:
preds = model.predict(X_dev)

In [31]:
print(classification_report(y_dev, preds, digits=3))

              precision    recall  f1-score   support

    negative      0.700     0.808     0.751       428
     neutral      0.500     0.192     0.278       229
    positive      0.715     0.836     0.771       444

    accuracy                          0.691      1101
   macro avg      0.638     0.612     0.600      1101
weighted avg      0.665     0.691     0.660      1101



### A feed-forward experiment with the sst module

It is straightforward to conduct experiments like the above using `sst.experiment`, which will enable you to do a wider range of experiments without writing or copy-pasting a lot of code. 

In [35]:
def fit_shallow_network(X, y):
    mod = TorchShallowNeuralClassifier(
        hidden_dim=300,
        early_stopping=True)
    mod.fit(X, y)
    return mod

In [36]:
%%time
_ = sst.experiment(
    sst.train_reader(SST_HOME),
    bert_classifier_phi,
    fit_shallow_network,
    assess_dataframes=sst.dev_reader(SST_HOME),
    vectorize=False)  # Pass in the BERT reps directly!

Stopping after epoch 137. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.239248812198639

              precision    recall  f1-score   support

    negative      0.699     0.799     0.746       428
     neutral      0.463     0.162     0.239       229
    positive      0.705     0.845     0.768       444

    accuracy                          0.685      1101
   macro avg      0.622     0.602     0.585      1101
weighted avg      0.652     0.685     0.650      1101

CPU times: user 4min 40s, sys: 1.14 s, total: 4min 41s
Wall time: 2min 45s


### An RNN experiment with the sst module

We can also use BERT representations as the input to an RNN. There is just one key change from how we used these models before:

* Previously, we would feed in lists of tokens, and they would be converted to indices into a fixed embedding space. This presumes that all words have the same representation no matter what their context is. 

* With BERT, we skip the embedding entirely and just feed in lists of BERT vectors, which means that the same word can be represented in different ways.

`TorchRNNClassifier` supports this via `use_embedding=False`. In turn, you needn't supply a vocabulary:

In [37]:
def fit_rnn(X, y):
    mod = TorchRNNClassifier(
        vocab=[],
        early_stopping=True,
        use_embedding=False)  # Pass in the BERT hidden states directly!
    mod.fit(X, y)
    return mod

In [38]:
%%time
_ = sst.experiment(
    sst.train_reader(SST_HOME),
    bert_phi,
    fit_rnn,
    assess_dataframes=sst.dev_reader(SST_HOME),
    vectorize=False)  # Pass in the BERT hidden states directly!

Stopping after epoch 34. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.5571811199188232

              precision    recall  f1-score   support

    negative      0.741     0.694     0.717       428
     neutral      0.370     0.266     0.310       229
    positive      0.690     0.831     0.754       444

    accuracy                          0.660      1101
   macro avg      0.600     0.597     0.593      1101
weighted avg      0.643     0.660     0.647      1101

CPU times: user 6min 56s, sys: 1min 21s, total: 8min 17s
Wall time: 3min 7s


## BERT fine-tuning with Hugging Face

The above experiments are quite successful – BERT gives us a reliable boost compared to other methods we've explored for the SST task. However, we might expect to do even better if we fine-tune the BERT parameters as part of fitting our SST classifier. To do that, we need to incorporate the Hugging Face BERT model into our classifier. This too is quite straightforward.

### HfBertClassifier

The most important step is to create an `nn.Module` subclass that has, for its parameters, both the BERT model and parameters for our own classifier. Here we define a very simple fine-tuning set-up in which some layers built on top of the output corresponding to `[CLS]` are used as the basis for the SST classifier:

In [39]:
class HfBertClassifierModel(nn.Module):
    def __init__(self, n_classes, weights_name='bert-base-cased'):
        super().__init__()
        self.n_classes = n_classes
        self.weights_name = weights_name
        self.bert = BertModel.from_pretrained(self.weights_name)
        self.bert.train()
        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim
        # The only new parameters -- the classifier:
        self.classifier_layer = nn.Linear(
            self.hidden_dim, self.n_classes)

    def forward(self, indices, mask):
        reps = self.bert(
            indices, attention_mask=mask)
        return self.classifier_layer(reps.pooler_output)

As you can see, `self.bert` does the heavy lifting: it reads in all the pretrained BERT parameters. I've called `self.bert.train()` to put the model in training mode (which, for example, turns on dropout). The BERT parameters are trainable by default, so they will be updated during our training process.

In `forward`, `self.bert` is used to process inputs, and then `pooler_output` is fed into `self.classifier_layer`. Hugging Face has already added a layer on top of the actual output for `[CLS]`, so we can specify the model as

$$
\begin{align}
[h_{1}, \ldots, h_{n}] &= \text{BERT}([x_{1}, \ldots, x_{n}]) \\
h &= \tanh(h_{1}W_{hh} + b_{h}) \\
y &= \textbf{softmax}(hW_{hy} + b_{y})
\end{align}$$

for a tokenized input sequence $[x_{1}, \ldots, x_{n}]$. 
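
To make the last two equations concrete, here is a minimal pure-Python sketch of the head applied to a toy `[CLS]` vector. All the weights and dimensions below are made up for illustration. Note that `forward` above returns the pre-softmax scores; the softmax is spelled out here only to mirror the equations.

```python
import math


def affine(x, W, b):
    """Compute xW + b for a row vector x and a matrix W given as a list of rows."""
    return [
        sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
        for j in range(len(b))]


def softmax(z):
    m = max(z)  # Subtract the max for numerical stability.
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def classify(h1, W_hh, b_h, W_hy, b_y):
    h = [math.tanh(v) for v in affine(h1, W_hh, b_h)]  # Pooler-style layer.
    return softmax(affine(h, W_hy, b_y))               # Classifier layer.


# Toy parameters: hidden dimension 2, three classes.
h1 = [0.5, -1.0]
W_hh = [[1.0, 0.0], [0.0, 1.0]]
b_h = [0.0, 0.0]
W_hy = [[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]
b_y = [0.0, 0.0, 0.0]

probs = classify(h1, W_hh, b_h, W_hy, b_y)
```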

The Hugging Face documentation somewhat amusingly says, of `pooler_output`,

> This output is usually _not_ a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence.

which is entirely reasonable, but it will require more resources, so we'll do the simpler thing here.

For the training and prediction interface, we can subclass `TorchShallowNeuralClassifier` so that we don't have to write any of our own data-handling, training, or prediction code. The central changes are using `HfBertClassifierModel` in `build_graph` and processing the data with `batch_encode_plus`.

In [44]:
class HfBertClassifier(TorchShallowNeuralClassifier):
    def __init__(self, weights_name, *args, **kwargs):
        self.weights_name = weights_name
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.weights_name)
        super().__init__(*args, **kwargs)
        self.params += ['weights_name']

    def build_graph(self):
        return HfBertClassifierModel(self.n_classes_, self.weights_name)

    def build_dataset(self, X, y=None):
        data = self.tokenizer.batch_encode_plus(
            X,
            max_length=None,
            add_special_tokens=True,
            padding='longest',
            return_attention_mask=True)
        indices = torch.tensor(data['input_ids'])
        mask = torch.tensor(data['attention_mask'])
        if y is None:
            dataset = torch.utils.data.TensorDataset(indices, mask)
        else:
            self.classes_ = sorted(set(y))
            self.n_classes_ = len(self.classes_)
            class2index = dict(zip(self.classes_, range(self.n_classes_)))
            y = [class2index[label] for label in y]
            y = torch.tensor(y)
            dataset = torch.utils.data.TensorDataset(indices, mask, y)
        return dataset
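
The label handling in `build_dataset` can be tried out in isolation. The sketch below uses two hypothetical helpers, `encode_labels` and `decode_predictions`, to show how string labels become integer indices for training and how predicted indices can be mapped back to labels.

```python
def encode_labels(y):
    """Map labels to integer indices, using sorted order as in build_dataset."""
    classes = sorted(set(y))
    class2index = dict(zip(classes, range(len(classes))))
    return classes, [class2index[label] for label in y]


def decode_predictions(classes, indices):
    """Map predicted class indices back to the original labels."""
    return [classes[i] for i in indices]


toy_y = ["positive", "negative", "neutral", "positive"]
classes, y_idx = encode_labels(toy_y)
recovered = decode_predictions(classes, y_idx)
```

Sorting the label set makes the mapping deterministic, so the same labels always receive the same indices across runs.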

### HfBertClassifier experiment

That's it! Let's see how we do on the SST ternary, root-only problem. Because fine-tuning is expensive, we'll conduct a modest hyperparameter search and run the model for just one epoch per setting, as we did when [assessing NLI models](nli_02_models.ipynb).

In [45]:
def bert_fine_tune_phi(text):
    return text

In [46]:
def fit_hf_bert_classifier_with_hyperparameter_search(X, y):
    basemod = HfBertClassifier(
        weights_name='bert-base-cased',
        batch_size=8,  # Small batches to avoid memory overload.
        max_iter=1,  # We'll search based on 1 iteration for efficiency.
        n_iter_no_change=5,   # Early-stopping params are for the
        early_stopping=True)  # final evaluation.

    param_grid = {
        'gradient_accumulation_steps': [1, 4, 8],
        'eta': [0.00005, 0.0001, 0.001],
        'hidden_dim': [100, 200, 300]}

    bestmod = utils.fit_classifier_with_hyperparameter_search(
        X, y, basemod, cv=3, param_grid=param_grid)

    return bestmod
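
To get a rough sense of the cost of this search, the sketch below enumerates the grid. It assumes that the search is exhaustive and that each setting is fit once per cross-validation fold. It also assumes that gradient accumulation multiplies the effective batch size by the number of accumulation steps.

```python
from itertools import product

param_grid = {
    'gradient_accumulation_steps': [1, 4, 8],
    'eta': [0.00005, 0.0001, 0.001],
    'hidden_dim': [100, 200, 300]}

# Every combination of values, one dict per setting:
settings = [
    dict(zip(param_grid, values))
    for values in product(*param_grid.values())]

cv = 3
n_settings = len(settings)
n_fits = n_settings * cv  # One fit per setting per fold.

batch_size = 8
effective_batch_sizes = sorted(
    {batch_size * s['gradient_accumulation_steps'] for s in settings})
```

Under these assumptions there are 27 settings and 81 one-epoch fits, which is why even a modest grid is slow when each fit fine-tunes all of BERT.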

In [47]:
%%time
bert_classifier_xval = sst.experiment(
    sst.train_reader(SST_HOME),
    bert_fine_tune_phi,
    fit_hf_bert_classifier_with_hyperparameter_search,
    assess_dataframes=sst.dev_reader(SST_HOME),
    vectorize=False)  # Pass in the raw texts directly!

Finished epoch 1 of 1; error is 93.697665332816541

Best params: {'eta': 0.0001, 'gradient_accumulation_steps': 8, 'hidden_dim': 200}
Best score: 0.583
              precision    recall  f1-score   support

    negative      0.640     0.967     0.770       428
     neutral      0.375     0.013     0.025       229
    positive      0.794     0.797     0.796       444

    accuracy                          0.700      1101
   macro avg      0.603     0.593     0.530      1101
weighted avg      0.647     0.700     0.625      1101

CPU times: user 1h 15min 20s, sys: 29 s, total: 1h 15min 49s
Wall time: 1h 11min 54s


And now on to the final test-set evaluation, using the best model from above:

In [48]:
optimized_bert_classifier = bert_classifier_xval['model']

In [49]:
# Remove the rest of the experiment results to clear out some memory:
del bert_classifier_xval

In [50]:
def fit_optimized_hf_bert_classifier(X, y):
    optimized_bert_classifier.max_iter = 1000
    optimized_bert_classifier.fit(X, y)
    return optimized_bert_classifier

In [51]:
test_df = sst.sentiment_reader(
    os.path.join(SST_HOME, "sst3-test-labeled.csv"))

In [52]:
%%time
_ = sst.experiment(
    sst.train_reader(SST_HOME),
    bert_fine_tune_phi,
    fit_optimized_hf_bert_classifier,
    assess_dataframes=test_df,
    vectorize=False)  # Pass in the raw texts directly!

Stopping after epoch 10. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 6.126567473358591

              precision    recall  f1-score   support

    negative      0.799     0.707     0.750       912
     neutral      0.312     0.409     0.354       389
    positive      0.830     0.815     0.822       909

    accuracy                          0.699      2210
   macro avg      0.647     0.644     0.642      2210
weighted avg      0.726     0.699     0.710      2210

CPU times: user 9min 1s, sys: 1.09 s, total: 9min 3s
Wall time: 8min 57s
