# Bringing contextual word representations into your models

In [1]:
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Fall 2020"

## Contents

1. [Overview](#Overview)
1. [General set-up](#General-set-up)
1. [Hugging Face BERT interfaces](#Hugging-Face-BERT-interfaces)
  1. [Hugging Face BERT set-up](#Hugging-Face-BERT-set-up)
  1. [Hugging Face BERT basics](#Hugging-Face-BERT-basics)
  1. [BERT featurization with Hugging Face](#BERT-featurization-with-Hugging-Face)
    1. [Simple feed-forward experiment](#Simple-feed-forward-experiment)
    1. [A feed-forward experiment with the sst module](#A-feed-forward-experiment-with-the-sst-module)
    1. [An RNN experiment with the sst module](#An-RNN-experiment-with-the-sst-module)
  1. [BERT fine-tuning with Hugging Face](#BERT-fine-tuning-with-Hugging-Face)
    1. [HfBertClassifier](#HfBertClassifier)
    1. [HfBertClassifier experiment](#HfBertClassifier-experiment)
1. [Using ELMo](#Using-ELMo)
  1. [ELMo Allen NLP set-up](#ELMo-Allen-NLP-set-up)
  1. [ELMo fine-tuning](#ELMo-fine-tuning)
    1. [AllenNLP ELMo interfaces](#AllenNLP-ELMo-interfaces)
    1. [ElmoRNNClassifier](#ElmoRNNClassifier)
    1. [ElmoRNNClassifier experiment](#ElmoRNNClassifier-experiment)

## Overview

This notebook provides a basic introduction to using pre-trained [BERT](https://github.com/google-research/bert) and [ELMo](https://allennlp.org/elmo) representations. It is meant as a practical companion to our lecture on contextual word representations. The goal of this notebook is just to help you use these representations in your own work. The BERT and ELMo teams have done amazing work to make these resources available to the community. Many projects can benefit from them, so it is probably worth your time to experiment.

This notebook should be considered an experimental extension to the regular course materials. It has some special requirements – libraries and data files – that are not part of the core requirements for this repository. All these tools are very new and being updated frequently, so you might need to do some fiddling to get all of this to work. As I said, though, it's probably worth the effort!

A number of the experiments in this notebook are resource intensive. I've included timing information for the expensive steps, to give you a sense for how long things are likely to take. I ran this notebook on a laptop with a single NVIDIA RTX 2080 GPU. 

## General set-up

The following are requirements that you'll already have met if you've been working in this repository. As you can see, we'll use the [Stanford Sentiment Treebank](sst_01_overview.ipynb) for illustrations, and we'll try out a few different deep learning models.

In [2]:
import os
import sst
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
from torch_rnn_classifier import TorchRNNModel
from torch_rnn_classifier import TorchRNNClassifier
from torch_rnn_classifier import TorchRNNClassifierModel
from torch_rnn_classifier import TorchRNNClassifier
from sklearn.metrics import classification_report
import utils

In [3]:
utils.fix_random_seeds()

In [5]:
PATH_TO_DATA = utils.get_data_path()

SST_HOME = os.path.join(PATH_TO_DATA, "trees")

## Hugging Face BERT interfaces

### Hugging Face BERT set-up

To install this library, run

```pip install transformers```

I've tested this code with versions 2.4, 2.5, and 2.11 of `transformers`. Try to get at least 2.5. It requires `pip >= 20` and, I think, a version of [Rust](https://www.rust-lang.org) at least as high as 1.21.1.

In [6]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

The `transformers` library does a lot of logging. To avoid ending up with a cluttered notebook, I am changing the logging level. You might want to skip this as you scale up to building production systems, since the logging is very good – it gives you a lot of insights into what the models and code are doing.

In [7]:
import logging
logger = logging.getLogger()
logger.level = logging.ERROR

### Hugging Face BERT basics

To start, let's get a feel for the basic API that `transformers` provides. The first step is specifying the pretrained parameters we'll be using:

In [8]:
hf_weights_name = 'bert-base-cased'

There are lots other options for pretrained weights. See [this section of the project README.md](https://github.com/huggingface/transformers#quick-tour) for a good overview and code that documents how these weights align with different Transformer model classes.

Next, we specify a tokenizer and a model that match both each other and our choice of pretrained weights:

In [9]:
hf_tokenizer = BertTokenizer.from_pretrained(hf_weights_name)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [10]:
hf_model = BertModel.from_pretrained(hf_weights_name)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…




It's illuminating to see what the tokenizer does to example texts:

In [11]:
hf_example_texts = [
    "Encode sentence 1. [SEP] And sentence 2!",
    "Bert knows Snuffleupagus"]

The `encode` method maps individual strings to indices into the underlying embedding used by the model:

In [12]:
ex0_ids = hf_tokenizer.encode(hf_example_texts[0], add_special_tokens=True)

ex0_ids

[101, 13832, 13775, 5650, 122, 119, 102, 1262, 5650, 123, 106, 102]

We can get a better feel for what these representations are like by mapping the indices back to "words":

In [13]:
hf_tokenizer.convert_ids_to_tokens(ex0_ids)

['[CLS]',
 'En',
 '##code',
 'sentence',
 '1',
 '.',
 '[SEP]',
 'And',
 'sentence',
 '2',
 '!',
 '[SEP]']

For modeling, we will often need to pad (and perhaps truncate) token lists so that we can work with fixed-dimensional tensors: The `batch_encode_plus` has a lot of options for doing this:

In [15]:
hf_example_ids = hf_tokenizer.batch_encode_plus(
    hf_example_texts,
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest')

In [16]:
hf_example_ids.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

The `token_type_ids` is used for multi-text inputs like NLI. The `'input_ids'` field gives the indices for each of the two examples:

In [17]:
hf_example_ids['input_ids']

[[101, 13832, 13775, 5650, 122, 119, 102, 1262, 5650, 123, 106, 102],
 [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102, 0, 0]]

For fine-tuning, we want to avoid attending to padded tokens. The `'attention_mask'` captures the needed mask, which we'll be able to feed directly to the pretrained BERT model:

In [18]:
hf_example_ids['attention_mask']

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]

Finally, we can run these indices and masks through the pretrained model:

In [19]:
X_hf_example = torch.tensor(hf_example_ids['input_ids'])
X_hf_example_mask = torch.tensor(hf_example_ids['attention_mask'])

with torch.no_grad():
    hf_final_hidden_states, cls_output = hf_model(
        X_hf_example, attention_mask=X_hf_example_mask)

BERT representations are pretty large – this shows the shape of the tensor for 2 examples, with the second padded to the length of the larger one in the batch (12). The individual representations have dimensionality 768.

In [20]:
hf_final_hidden_states.shape

torch.Size([2, 12, 768])

Those are all the essential ingredients for working with these parameters in Hugging Face. Of course, the library has a lot of other functionality, but the above suffices to featurize and to fine-tune.

### BERT featurization with Hugging Face

To start, we'll use the Hugging Face interfaces just to featurize examples to create inputs to a separate model. In this setting, the BERT parameters are frozen. The heart of this approach is the following featurizer, which flattens an SST tree into a string, tokenizes it, and computes its hidden representations:

In [21]:
def hugging_face_bert_phi(tree):
    s = " ".join(tree.leaves())
    input_ids = hf_tokenizer.encode(s, add_special_tokens=True)
    X = torch.tensor([input_ids])
    with torch.no_grad():
        final_hidden_states, cls_output = hf_model(X)
        return final_hidden_states.squeeze(0).numpy()

#### Simple feed-forward experiment

For a simple feed-forward experiment, we can get the representation of the `[CLS]` tokens and use them as the inputs to a shallow neural network:

In [22]:
def hugging_face_bert_classifier_phi(tree):
    reps = hugging_face_bert_phi(tree)
    #return reps.mean(axis=0)  # Another good, easy option.
    return reps[0]

Next we read in the SST train and dev portions as a lists of `(tree, label)` pairs:

In [23]:
hf_train = list(sst.train_reader(SST_HOME, class_func=sst.binary_class_func))

hf_dev = list(sst.dev_reader(SST_HOME, class_func=sst.binary_class_func))

Split the input/output pairs out into separate lists:

In [24]:
X_hf_tree_train, y_hf_train = zip(*hf_train)

X_hf_tree_dev, y_hf_dev = zip(*hf_dev)

In the next step, we featurize all of the examples. These steps are likely to be the slowest in these experiments:

In [27]:
%time X_hf_train = [hugging_face_bert_classifier_phi(tree) for tree in X_hf_tree_train]

CPU times: user 39min 5s, sys: 30 s, total: 39min 35s
Wall time: 4min 57s


In [30]:
%time X_hf_dev = [hugging_face_bert_classifier_phi(tree) for tree in X_hf_tree_dev]

CPU times: user 4min 55s, sys: 4.07 s, total: 4min 59s
Wall time: 37.5 s


Now that all the examples are featurized, we can fit a model and evaluate it:

In [31]:
hf_mod = TorchShallowNeuralClassifier(
    early_stopping=True,
    hidden_dim=300)

In [32]:
%time _ = hf_mod.fit(X_hf_train, y_hf_train)

Stopping after epoch 38. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 2.018371820449829

CPU times: user 22.9 s, sys: 1.79 s, total: 24.6 s
Wall time: 5.88 s


In [33]:
hf_preds = hf_mod.predict(X_hf_dev)

In [34]:
print(classification_report(y_hf_dev, hf_preds, digits=3))

              precision    recall  f1-score   support

    negative      0.850     0.834     0.842       428
    positive      0.843     0.858     0.850       444

    accuracy                          0.846       872
   macro avg      0.846     0.846     0.846       872
weighted avg      0.846     0.846     0.846       872



#### A feed-forward experiment with the sst module

It is straightforward to conduct experiments like the above using `sst.experiment`, which will enable you to do a wider range of experiments without writing or copy-pasting a lot of code. 

In [35]:
def fit_hf_shallow_network(X, y):
    mod = TorchShallowNeuralClassifier(
        hidden_dim=300,
        early_stopping=True)
    mod.fit(X, y)
    return mod

In [36]:
%%time
_ = sst.experiment(
    SST_HOME,
    hugging_face_bert_classifier_phi,
    fit_hf_shallow_network,
    train_reader=sst.train_reader,
    assess_reader=sst.dev_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the BERT reps directly!

Stopping after epoch 22. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 2.384592831134796

              precision    recall  f1-score   support

    negative      0.839     0.839     0.839       428
    positive      0.845     0.845     0.845       444

    accuracy                          0.842       872
   macro avg      0.842     0.842     0.842       872
weighted avg      0.842     0.842     0.842       872

CPU times: user 44min 32s, sys: 35.6 s, total: 45min 7s
Wall time: 5min 38s


#### An RNN experiment with the sst module

We can also use BERT representations as the input to an RNN. There is just one key change from how we used these models before:

* Previously, we would feed in lists of tokens, and they would be converted to indices into a fixed embedding space. This presumes that all words have the same representation no matter what their context is. 

* With BERT, we skip the embedding entirely and just feed in lists of BERT vectors, which means that the same word can be represented in different ways.

`TorchRNNClassifier` supports this via `use_embedding=False`. In turn, you needn't supply a vocabulary:

In [37]:
def fit_hf_rnn(X, y):
    mod = TorchRNNClassifier(
        vocab=[],
        early_stopping=True,
        use_embedding=False)  # Pass in the BERT hidden states directly!
    mod.fit(X, y)
    return mod

In [39]:
%%time
_ = sst.experiment(
    SST_HOME,
    hugging_face_bert_phi,
    fit_hf_rnn,
    train_reader=sst.train_reader,
    assess_reader=sst.dev_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the BERT hidden states directly!

Stopping after epoch 24. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7002094984054565

              precision    recall  f1-score   support

    negative      0.868     0.827     0.847       428
    positive      0.841     0.878     0.859       444

    accuracy                          0.853       872
   macro avg      0.854     0.853     0.853       872
weighted avg      0.854     0.853     0.853       872

CPU times: user 49min 24s, sys: 1min 11s, total: 50min 35s
Wall time: 6min 20s


### BERT fine-tuning with Hugging Face

The above experiments are quite successful – BERT gives us a reliable boost compared to other methods we've explored for the SST task. However, we might expect to do even better if we fine-tune the BERT parameters as part of fitting our SST classifier. To do that, we need to incorporate the Hugging Face BERT model into our classifier. This too is quite straightforward.

#### HfBertClassifier

The most important step is to create an `nn.Module` subclass that has, for its parameters, both the BERT model and parameters for our own classifier. Here we define a very simple fine-tuning set-up in which some layers built on top of the output corresponding to `[CLS]` are used as the basis for the SST classifier:

In [40]:
class HfBertClassifierModel(nn.Module):
    def __init__(self, n_classes, weights_name='bert-base-cased'):
        super().__init__()
        self.n_classes = n_classes
        self.weights_name = weights_name
        self.bert = BertModel.from_pretrained(self.weights_name)
        self.bert.train()
        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim
        # The only new parameters -- the classifier:
        self.classifier_layer = nn.Linear(
            self.hidden_dim, self.n_classes)

    def forward(self, indices, mask):
        final_hidden_states, cls_output = self.bert(
            indices, attention_mask=mask)
        return self.classifier_layer(cls_output)

As you can see, `self.bert` does the heavy-lifting: it reads in all the pretrained BERT parameters, and I've specified `self.bert.train()` just to make sure that these parameters can be updated during our training process. 

In `forward`, `self.bert` is used to process inputs, and then `cls_output` is fed into `self.classifier_layer`. Hugging Face has already added a layer on top of the actual output for `[CLS]`, so we can specify the model as

$$
\begin{align}
[h_{1}, \ldots, h_{n}] &= \text{BERT}([x_{1}, \ldots, x_{n}]) \\
h &= \tanh(h_{1}W_{hh} + b_{h}) \\
y &= \textbf{softmax}(hW_{hy} + b_{y})
\end{align}$$

for a tokenized input sequence $[x_{1}, \ldots, x_{n}]$. 

The Hugging Face documentation somewhat amusingly says, of `cls_output`,

> This output is usually _not_ a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence.

which is entirely reasonable, but it will require more resources, so we'll do the simpler thing here.

For the training and prediction interface, we can subclass `TorchShallowNeuralClassifier` so that we don't have to write any of our own data-handling, training, or prediction code. The central changes are using `HfBertClassifierModel` in `build_graph` and processing the data with `batch_encode_plus`.

In [41]:
class HfBertClassifier(TorchShallowNeuralClassifier):
    def __init__(self, weights_name, *args, **kwargs):
        self.weights_name = weights_name
        self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)
        super().__init__(*args, **kwargs)
        self.params += ['weights_name']

    def build_graph(self):
        return HfBertClassifierModel(self.n_classes_, self.weights_name)

    def build_dataset(self, X, y=None):
        data = self.tokenizer.batch_encode_plus(
            X,
            max_length=None,
            add_special_tokens=True,
            pad_to_max_length=True,
            return_attention_mask=True)
        indices = torch.tensor(data['input_ids'])
        mask = torch.tensor(data['attention_mask'])
        if y is None:
            dataset = torch.utils.data.TensorDataset(indices, mask)
        else:
            self.classes_ = sorted(set(y))
            self.n_classes_ = len(self.classes_)
            class2index = dict(zip(self.classes_, range(self.n_classes_)))
            y = [class2index[label] for label in y]
            y = torch.tensor(y)
            dataset = torch.utils.data.TensorDataset(indices, mask, y)
        return dataset

#### HfBertClassifier experiment

That's it! Let's see how we do on the SST binary, root-only problem. Because fine-tuning is expensive, we'll conduct a modest hyperparameter search and run the model for just one epoch per setting evaluation, as we did when [assessing NLI models](nli_02_models.ipynb).

In [42]:
def bert_fine_tune_phi(tree):
    return " ".join(tree.leaves())

In [43]:
def fit_hf_bert_classifier_with_hyperparameter_search(X, y):
    basemod = HfBertClassifier(
        weights_name='bert-base-cased',
        batch_size=8,  # Small batches to avoid memory overload.
        max_iter=1,  # We'll search based on 1 iteration for efficiency.
        n_iter_no_change=5,   # Early-stopping params are for the
        early_stopping=True)  # final evaluation.

    param_grid = {
        'gradient_accumulation_steps': [1, 4, 8],
        'eta': [0.00005, 0.0001, 0.001],
        'hidden_dim': [100, 200, 300]}

    bestmod = utils.fit_classifier_with_hyperparameter_search(
        X, y, basemod, cv=3, param_grid=param_grid)

    return bestmod

In [44]:
%%time
hf_bert_classifier_xval = sst.experiment(
    SST_HOME,
    bert_fine_tune_phi,
    fit_hf_bert_classifier_with_hyperparameter_search,
    train_reader=sst.train_reader,
    assess_reader=sst.dev_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the BERT hidden state directly!

Finished epoch 1 of 1; error is 33.944158037076704

Best params: {'eta': 5e-05, 'gradient_accumulation_steps': 8, 'hidden_dim': 300}
Best score: 0.895
              precision    recall  f1-score   support

    negative      0.936     0.820     0.874       428
    positive      0.845     0.946     0.893       444

    accuracy                          0.884       872
   macro avg      0.891     0.883     0.883       872
weighted avg      0.890     0.884     0.884       872

CPU times: user 1h 21min 19s, sys: 1min 52s, total: 1h 23min 12s
Wall time: 1h 23min 3s


And now on to the final test-set evaluation, using the best model from above:

In [45]:
optimized_hf_bert_classifier = hf_bert_classifier_xval['model']

In [46]:
# Remove the rest of the experiment results to clear out some memory:
del hf_bert_classifier_xval

In [47]:
def fit_optimized_hf_bert_classifier(X, y):
    optimized_hf_bert_classifier.max_iter = 1000
    optimized_hf_bert_classifier.fit(X, y)
    return optimized_hf_bert_classifier

In [48]:
%%time
_ = sst.experiment(
    SST_HOME,
    bert_fine_tune_phi,
    fit_optimized_hf_bert_classifier,
    train_reader=(sst.train_reader, sst.dev_reader),
    assess_reader=sst.test_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the BERT hidden state directly!

Stopping after epoch 13. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 1.2714146176476788

              precision    recall  f1-score   support

    negative      0.915     0.910     0.913       912
    positive      0.910     0.915     0.913       909

    accuracy                          0.913      1821
   macro avg      0.913     0.913     0.913      1821
weighted avg      0.913     0.913     0.913      1821

CPU times: user 14min 32s, sys: 7.25 s, total: 14min 39s
Wall time: 14min 38s


The above is just one of the many possible ways to fine-tune BERT using our course modules or new modules you write. The crux of it is creating an `nn.Module` that combines the BERT parameters with your model's new parameters.

## Using ELMo

### ELMo Allen NLP set-up

There are a number of ways to use pre-trained ELMo models. We'll use the simplest of the AllenNLP interfaces. Run the following to install [AllenNLP](https://allennlp.org):

```sh
pip install allennlp
```
I've tested this notebook with versions, 0.8.0, 0.9.0, and 1.0.0.

Mac users: If your installation fails, make sure your Xcode tools are up to date by running `xcode-select --install`.

In [42]:
from allennlp.modules.elmo import Elmo, batch_to_ids
import torch
import torch.nn as nn

We'll use the following models, which will download from S3 to a local temp directory the first time you use them with `ElmoEmbedder` or `Elmo` as described below.

In [43]:
elmo_file_path = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/"

options_file = elmo_file_path + "elmo_2x4096_512_2048cnn_2xhighway_options.json"

weights_file = elmo_file_path + "elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

For more models:

https://allennlp.org/elmo

For additional details:

https://docs.allennlp.org/master/api/modules/elmo/

### ELMo fine-tuning

Fine-tuning ELMo proceeds in essentially the same way it did for BERT: we create an `nn.Module` that combines the parameters from ELMo with our task-specific parameters and then optimize everything on the new task. To illustrate, I'll define an RNN on top of the ELMo model using new subclasses of `TorchRNNClassifier` and `TorchRNNClassifierModel`.

#### AllenNLP ELMo interfaces

To start, let's get a feel for the primary interface, and then we'll write the classes that will allow us to use these components systematically.

The interface to the ELMo parameters in this context is the class `Elmo`:

In [44]:
elmo = Elmo(options_file, weights_file, num_output_representations=1)

This model expects tokenized inputs:

In [45]:
elmo_example_texts = [
    ["Encode", "sentence", "1", "."],
    ["ELMo", "knows" "Snuffleupagus"]]

The ELMo model processes its tokens at the character-level, creating convolutional representations for the words from various character n-gram combinations:

In [46]:
elmo_character_ids = batch_to_ids(elmo_example_texts)

# First word of the first example:
elmo_character_ids[0][0]

tensor([259,  70, 111, 100, 112, 101, 102, 260, 261, 261, 261, 261, 261, 261,
        261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
        261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
        261, 261, 261, 261, 261, 261, 261, 261])

`elmo` embeds these at the word-level:

In [47]:
elmo_embeddings = elmo(elmo_character_ids)

`elmo_embeddings` is a dict. The value of the key `'elmo_representations'` is a list of tensors corresponding to each layer of the model. In other words, each tensor in the list is a complete representation of the example. The final element of the list is the final representation layer. I specified `num_output_representations=1` when initializing `elmo` above, so we get a list of length 1:

In [48]:
elmo_embeddings['elmo_representations']

[tensor([[[-0.7944,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -1.9283],
          [-0.0000,  1.1850,  0.6947,  ...,  0.0000, -0.9297, -0.2358],
          [ 0.0000,  0.5358,  0.7767,  ..., -0.6500, -0.0777, -0.4875],
          [-0.0000, -0.4965, -0.0000,  ..., -0.1605,  0.0000,  0.2256]],
 
         [[ 0.4553, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.5380],
          [ 0.0000, -0.5309, -0.0000,  ..., -0.2244, -0.0000,  0.7476],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
        grad_fn=<MulBackward0>)]

#### ElmoRNNClassifier

The above are the representations we will be fine-tuning. There are many ways to cdo this. In my simple illustration, I just take the top layer, as we did in the simpler featurization example above, but now keeping each word representation separate for use in the input to the task-specific RNN. Here is the `nn.Module` built on `TorchRNNClassifierModel`:

In [49]:
class ElmoRNNClassifierModel(TorchRNNClassifierModel):
    def __init__(self,
            options_file,
            weights_file,
            rnn,
            output_dim,
            classifier_activation):
        super().__init__(rnn, output_dim, classifier_activation)
        self.options_file = options_file
        self.weights_file = weights_file
        self.elmo = Elmo(
            self.options_file,
            self.weights_file,
            num_output_representations=1,
            dropout=0)
        self.elmo.train()

    def forward(self, X, seq_lengths):
        result = self.elmo(X)
        X = result['elmo_representations'][-1]
        outputs, state  = self.rnn(X, seq_lengths)
        state = self.get_batch_final_states(state)
        if self.rnn.bidirectional:
            state = torch.cat((state[0], state[1]), dim=1)
        h = self.classifier_activation(self.hidden_layer(state))
        logits = self.classifier_layer(h)
        return logits

And here is the subclass of `TorchRNNClassifier` that lets us take advantage of all the optimization and prediction methods of that class:

In [50]:
class ElmoRNNClassifier(TorchRNNClassifier):
    def __init__(self, vocab, options_file, weights_file, **model_kwargs):
        self.options_file = options_file
        self.weights_file = weights_file
        # Values determined by using ELMo:
        model_kwargs['use_embedding'] = False
        model_kwargs['embedding'] = None
        model_kwargs['embed_dim'] = 1024
        super().__init__(vocab, **model_kwargs)
        self.params += ['options_file', 'weights_file']

    def build_graph(self):

        # The RNN is setup just as in a regular `TorchRNNClassifier`:
        rnn = TorchRNNModel(
            vocab_size=len(self.vocab),
            embedding=self.embedding,
            use_embedding=self.use_embedding,
            embed_dim=self.embed_dim,
            rnn_cell_class=self.rnn_cell_class,
            hidden_dim=self.hidden_dim,
            bidirectional=self.bidirectional,
            freeze_embedding=self.freeze_embedding)

        # The Classifier layer uses our new `ElmoRNNClassifierModel`:
        model = ElmoRNNClassifierModel(
            options_file=self.options_file,
            weights_file=self.weights_file,
            rnn=rnn,
            output_dim=self.n_classes_,
            classifier_activation=self.classifier_activation)

        return model

    def build_dataset(self, X, y=None):
        seq_lengths = [len(ex) for ex in X]
        seq_lengths = torch.tensor(seq_lengths)
        X = batch_to_ids(X)
        if y is None:
            return torch.utils.data.TensorDataset(X, seq_lengths)
        else:
            self.classes_ = sorted(set(y))
            self.n_classes_ = len(self.classes_)
            class2index = dict(zip(self.classes_, range(self.n_classes_)))
            y = [class2index[label] for label in y]
            y = torch.tensor(y)
            return torch.utils.data.TensorDataset(X, seq_lengths, y)

#### ElmoRNNClassifier experiment

And finally here is a self-contained evaluation involving a modest hyperparameter search:

In [51]:
def elmo_fine_tune_phi(tree):
    return tree.leaves()

In [52]:
def fit_elmo_rnn_with_hyperparameter_search(X, y):
    basemod = ElmoRNNClassifier(
        vocab=[],
        options_file=options_file,
        weights_file=weights_file,
        batch_size=8,  # Kept small so that we can explore large networks.
        max_iter=1,  # We'll search based on 1 iteration for efficiency.
        n_iter_no_change=5,   # Early-stopping params are for the
        early_stopping=True)  # final evalution.

    param_grid = {
        'gradient_accumulation_steps': [1, 4, 8],
        'eta': [0.001, 0.01, 0.05],
        'hidden_dim': [50, 100, 200]}

    bestmod = utils.fit_classifier_with_hyperparameter_search(
        X, y, basemod, cv=3, param_grid=param_grid)

    return bestmod

In [53]:
%%time
elmo_rnn_xval = sst.experiment(
    SST_HOME,
    elmo_fine_tune_phi,
    fit_elmo_rnn_with_hyperparameter_search,
    train_reader=sst.train_reader,
    assess_reader=sst.dev_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the ELMo reps directly!

Finished epoch 1 of 1; error is 330.32408917695284

Best params: {'eta': 0.001, 'gradient_accumulation_steps': 1, 'hidden_dim': 200}
Best score: 0.856
              precision    recall  f1-score   support

    negative      0.848     0.848     0.848       428
    positive      0.854     0.854     0.854       444

    accuracy                          0.851       872
   macro avg      0.851     0.851     0.851       872
weighted avg      0.851     0.851     0.851       872

CPU times: user 2h 4min 55s, sys: 4min 16s, total: 2h 9min 12s
Wall time: 2h 1min 35s


And now we move to the test-set evaluation using the best model we found:

In [54]:
optimized_elmo_rnn = elmo_rnn_xval['model']

In [55]:
# Remove the unneeded experimental data to save memory:
del elmo_rnn_xval

In [56]:
def fit_optimized_elmo_rnn(X, y):
    optimized_elmo_rnn.max_iter = 20
    optimized_elmo_rnn.fit(X, y)
    return optimized_elmo_rnn

In [57]:
%%time
_ = sst.experiment(
    SST_HOME,
    elmo_fine_tune_phi,
    fit_optimized_elmo_rnn,
    train_reader=(sst.train_reader, sst.dev_reader),
    assess_reader=sst.test_reader,
    class_func=sst.binary_class_func,
    vectorize=False)  # Pass in the BERT hidden state directly!

Stopping after epoch 10. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 27.70847176760435

              precision    recall  f1-score   support

    negative      0.891     0.880     0.886       912
    positive      0.882     0.892     0.887       909

    accuracy                          0.886      1821
   macro avg      0.886     0.886     0.886      1821
weighted avg      0.886     0.886     0.886      1821

CPU times: user 15min 31s, sys: 34.3 s, total: 16min 6s
Wall time: 16min
