In [33]:
import html

import IPython.display
import matplotlib.pyplot as plt
import torch as t
import transformers

In [2]:
# Pretrained tokenizer for the "gpt2" checkpoint; used below to turn
# sentences into token ids and token ids back into text.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained(
    "gpt2",
)

Downloading: 100%|██████████| 0.99M/0.99M [00:00<00:00, 2.15MB/s]
Downloading: 100%|██████████| 446k/446k [00:00<00:00, 1.17MB/s]
Downloading: 100%|██████████| 1.29M/1.29M [00:00<00:00, 2.79MB/s]
Downloading: 100%|██████████| 665/665 [00:00<00:00, 174kB/s]


In [4]:
# Pretrained "gpt2" checkpoint with its language-modelling head; its
# `.logits` output is what the analysis below turns into probabilities.
gpt2 = transformers.GPT2LMHeadModel.from_pretrained(
    "gpt2",
)

Downloading: 100%|██████████| 523M/523M [00:21<00:00, 26.0MB/s] 


In [54]:
def ipython_token_prob_table(token_probs):
    """Render rows of ``(token, probability)`` pairs as an HTML table.

    Args:
        token_probs: Iterable of rows. Each row is an iterable of
            ``(token, prob)`` pairs; ``token`` is converted with ``str`` and
            HTML-escaped, ``prob`` must support the ``.2g`` format spec.

    Returns:
        An ``IPython.display.HTML`` object with one ``<tr>`` per row and two
        ``<td>`` cells (token, probability) per pair.
    """
    html_rows = ['<table>']
    for row in token_probs:
        # Escape the token text: decoded tokens can contain '<', '>' or '&'
        # (e.g. '<|endoftext|>'), which would otherwise be parsed as markup
        # and either disappear or break the table.
        cells = ''.join(
            f'<td>{html.escape(str(token))}</td><td>{prob:.2g}</td>'
            for token, prob in row
        )
        html_rows.append(f'<tr>{cells}</tr>')
    html_rows.append('</table>')
    return IPython.display.HTML('\n'.join(html_rows))

In [1]:
def analyze_sentence(sentence, top_k=5):
    """Plot and tabulate GPT-2's next-token predictions for one sentence.

    The sentence is tokenized and run through ``gpt2`` once. For every
    position ``i`` the model's distribution over the next token is compared
    with the token that actually follows (``tokens[i + 1]``):

    * a line plot shows the log-probability assigned to each actual token;
    * an HTML table shows, per position, the actual token and its
      probability followed by the ``top_k`` most likely tokens and theirs.

    Args:
        sentence: Text to analyze. Must encode to at least two tokens, since
            the first token has no preceding context to be predicted from.
        top_k: Number of highest-probability candidates listed per position
            (capped at the vocabulary size). Defaults to 5.

    Raises:
        ValueError: If ``sentence`` encodes to fewer than two tokens.
    """
    tokens = tokenizer.encode(sentence, return_tensors="pt")
    batch_size, seq_len = tokens.shape
    assert batch_size == 1
    if seq_len < 2:
        raise ValueError(
            f"sentence must encode to at least 2 tokens, got {seq_len}"
        )

    with t.no_grad():
        # log_softmax avoids the underflow of log(softmax(x)), which yields
        # -inf for very unlikely tokens.
        log_probs = t.log_softmax(gpt2(tokens).logits, dim=-1)
        assert log_probs.shape == (batch_size, seq_len, tokenizer.vocab_size)
        probs = log_probs.exp()
        # Only the best few candidates are needed, so use topk rather than
        # sorting the whole vocabulary at every position.
        _, top_tokens = t.topk(
            log_probs, k=min(top_k, log_probs.shape[-1]), dim=-1
        )
        # Position i predicts token i + 1, so the index is the token sequence
        # shifted by one; gather accepts an index shorter than the input
        # along the sequence dimension.
        log_probs_of_truth = t.gather(
            input=log_probs,
            dim=-1,
            index=tokens[:, 1:, None],
        ).squeeze(0).squeeze(-1)

    # Plot against integer positions and attach the token text as tick
    # labels. Passing the strings directly as x-values makes matplotlib use a
    # categorical axis, which places a repeated token at the x position of
    # its first occurrence and scrambles the line.
    labels = [tokenizer.decode(token) for token in tokens[0, 1:]]
    positions = list(range(len(labels)))
    fig, ax = plt.subplots(figsize=(6.4, 3.2))
    ax.plot(positions, log_probs_of_truth.tolist())
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_xlabel("token")
    ax.set_ylabel("log-probability")
    plt.show()

    token_probs_table = []
    for i in range(seq_len - 1):
        # First column pair is the token that actually came next, followed by
        # the model's top candidates for that position.
        candidates = [tokens[0, i + 1], *top_tokens[0, i]]
        token_probs_table.append(
            [
                (tokenizer.decode(candidate), probs[0, i, candidate].item())
                for candidate in candidates
            ]
        )
    IPython.display.display(ipython_token_prob_table(token_probs_table))

# analyze_sentence(sentence="the transformer isn't going to fail this task.")