Guide to fixed-length model perplexity evaluation #5449

Merged 6 commits on Jul 7, 2020
Binary file added docs/source/imgs/ppl_chunked.gif
Binary file added docs/source/imgs/ppl_full.gif
Binary file added docs/source/imgs/ppl_sliding.gif
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -161,6 +161,7 @@ conversion utilities for the following models:
:caption: Research

bertology
perplexity
benchmarks

.. toctree::
148 changes: 148 additions & 0 deletions docs/source/perplexity.rst
@@ -0,0 +1,148 @@
Perplexity of fixed-length models
=================================

Perplexity (PPL) is one of the most common metrics for evaluating language
models. Note that the metric applies specifically to classical language
models (sometimes called autoregressive or causal language models) and is not
well defined for masked language models like BERT (see :doc:`summary of the
models <model_summary>`).

Perplexity is defined as the exponentiated average negative log-likelihood of
a sequence. If we have a tokenized sequence :math:`X = (x_0, x_1, \dots, x_t)`,
then the perplexity of :math:`X` is,

.. math::

\text{PPL}(X) = \exp \left\{ -\frac{1}{t}\sum_{i=1}^{t} \log p_\theta (x_i|x_{<i}) \right\}

where :math:`\log p_\theta (x_i|x_{<i})` is the log-likelihood of the ith
token conditioned on the preceding tokens :math:`x_{<i}` according to our
model.

This is also equivalent to the exponentiation of the cross-entropy between
the data and model predictions. For more intuition about perplexity and its
relationship to Bits Per Character (BPC) and data compression, check out this
`fantastic blog post on The Gradient
<https://thegradient.pub/understanding-evaluation-metrics-for-language-models/>`_.
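
To make the definition concrete, here is a minimal sketch of the formula in
code, using a handful of made-up per-token log-probabilities rather than the
output of any particular model:

.. code-block:: python

import torch

# hypothetical values of log p(x_i|x_{<i}) for a five-token sequence
log_probs = torch.tensor([-2.1, -0.9, -3.3, -1.2, -0.7])

# PPL(X) = exp(-1/t * sum_i log p(x_i|x_{<i})), i.e. the exponentiated
# average negative log-likelihood (equivalently, exponentiated cross-entropy)
ppl = torch.exp(-log_probs.mean())  # roughly 5.16 for these values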

Calculating PPL with fixed-length models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If we weren't limited by a model's context size, we would evaluate the
model's perplexity by autoregressively factorizing a sequence and
conditioning on the entire preceding subsequence at each step, as shown
below.

.. image:: imgs/ppl_full.gif
:width: 600
:alt: Full decomposition of a sequence with unlimited context length

When working with approximate models, however, we typically have a constraint
on the number of tokens the model can process. The largest version
of :doc:`GPT-2 <model_doc/gpt2>`, for example, has a fixed length of 1024
tokens, so we cannot calculate :math:`p_\theta(x_t|x_{<t})` directly when
:math:`t` is greater than 1024.

Instead, the sequence is typically broken into subsequences equal to the
model's maximum input size. If a model's max input size is :math:`k`, we
then approximate the likelihood of a token :math:`x_t` by conditioning only
on the :math:`k-1` tokens that precede it rather than the entire context.
When evaluating a model's perplexity on a sequence, a tempting but suboptimal
approach is to break the sequence into disjoint chunks and average the
decomposed log-likelihoods of each segment independently.

.. image:: imgs/ppl_chunked.gif
:width: 600
:alt: Suboptimal PPL not taking advantage of full available context

This is quick to compute since the perplexity of each segment can be computed
in one forward pass, but serves as a poor approximation of the
fully-factorized perplexity and will typically yield a higher (worse) PPL
because the model will not be able to take advantage of all the available
context at most of the prediction steps.
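
For illustration, a minimal sketch of this chunked approach might look like
the following (it assumes the ``model``, ``encodings`` and ``device`` objects
set up in the GPT-2 example below, and it is what the strided code there
reduces to when the stride equals the model's max length):

.. code-block:: python

import torch

# assumes `model`, `encodings` and `device` from the GPT-2 example below
max_length = model.config.n_positions
seq_len = encodings.input_ids.size(1)

nlls = []
for begin_loc in range(0, seq_len, max_length):
    end_loc = min(begin_loc + max_length, seq_len)
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)

    with torch.no_grad():
        # the returned loss is the average negative log-likelihood over this
        # chunk only, so multiply by the chunk length to get its summed NLL
        neg_log_likelihood = model(input_ids, labels=input_ids)[0] * (end_loc - begin_loc)

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / seq_len)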

Instead, the PPL of fixed-length models should be evaluated with a
sliding-window strategy. This involves repeatedly sliding the
context window so that the model has more context when making each
prediction.

.. image:: imgs/ppl_sliding.gif
:width: 600
:alt: Sliding window PPL taking advantage of all available context

This is a closer approximation to the true decomposition of the
sequence probability and will typically yield a more favorable score.
However, the downside is that it would require a separate forward
pass for each token in the corpus. A good practical compromise is to employ a
strided sliding window, moving the context by larger strides rather than
sliding by 1 token at a time. This allows computation to proceed much faster
while still giving the model a large context to make predictions at each
step.
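
As a rough sense of the cost trade-off, here is a back-of-the-envelope count
of forward passes for a hypothetical corpus length (the token count below is
made up purely for illustration):

.. code-block:: python

import math

T = 300_000  # hypothetical number of tokens in the evaluation corpus
for stride in (1, 512, 1024):
    # one forward pass per window; windows advance by `stride` tokens
    print(f"stride={stride:>4}: {math.ceil(T / stride)} forward passes")

# stride=   1: 300000 forward passes
# stride= 512: 586 forward passes
# stride=1024: 293 forward passes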

Example: Calculating PPL with GPT-2 in 🤗 Transformers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's demonstrate this process with GPT-2.

.. code-block:: python

from transformers import GPT2LMHeadModel, GPT2TokenizerFast
device = 'cuda'
model_id = 'gpt2-large'
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

We'll load in the WikiText-2 dataset and evaluate the perplexity using a few
different sliding-window strategies. Since this dataset is small and we're
just doing one forward pass over the set, we can just load and encode the
entire dataset in memory.

.. code-block:: python

from nlp import load_dataset
test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
encodings = tokenizer('\n\n'.join(test['text']), return_tensors='pt')

With 🤗 Transformers, we can simply pass the ``input_ids`` as the ``labels``
to our model, and the average negative log-likelihood for each token is
returned as the loss. With our sliding window approach, however, there is
overlap in the
tokens we pass to the model at each iteration. We don't want the
log-likelihood for the tokens we're just treating as context to be included
in our loss, so we can set these targets to ``-100`` so that they are
ignored. The following is an example of how we could do this with a stride of
``512``. This means that the model will have at least 512 tokens for context
when calculating the conditional likelihood of any one token (provided there
are 512 preceding tokens available to condition on).

.. code-block:: python

import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be shorter than stride on the last window
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # only the last trg_len tokens are scored; the rest serve as context
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        # outputs[0] is the average negative log-likelihood per scored token
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)

Running this with the stride length equal to the max input length is
equivalent to the suboptimal, non-sliding-window strategy we discussed above.
The smaller the stride, the more context the model will have in making each
prediction, and the better the PPL will typically be.

When we run the above with ``stride = 1024``, i.e. with no overlap between
windows, the resulting PPL is ``19.64``, which is about the same as the
``19.93`` reported in the GPT-2 paper. By instead using ``stride = 512`` and
thereby employing our strided sliding-window strategy, the PPL drops to
``16.53``. This is not only a more favorable score, but it is calculated in a
way that is closer to the true autoregressive decomposition of the sequence
likelihood.
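
If you'd like to reproduce the comparison yourself, one option (a sketch that
wraps the loop above in a hypothetical ``evaluate_ppl`` helper and reuses
``model``, ``encodings``, ``max_length`` and ``device`` from the example) is
to sweep over a few stride values:

.. code-block:: python

def evaluate_ppl(stride):
    # hypothetical helper: the same strided loop as above, parameterized by stride
    nlls = []
    for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            neg_log_likelihood = model(input_ids, labels=target_ids)[0] * trg_len
        nlls.append(neg_log_likelihood)
    return torch.exp(torch.stack(nlls).sum() / end_loc)

for stride in (1024, 512, 256):
    print(f"stride={stride}: PPL={evaluate_ppl(stride).item():.2f}")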