In [1]:
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
device = 'cpu' # 'cuda'
model_id = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

# from nlp import load_dataset
# test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
# encodings = tokenizer('\n\n'.join(test['text']), return_tensors='pt')

encodings = tokenizer('An orange ruled the world.', return_tensors='pt')

def compute_perplexity(model, encodings):
    max_length = model.config.n_positions
    stride = 512

    lls = []
    for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i    # may be different from stride on last loop
        input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:,:-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            log_likelihood = outputs[0] * trg_len

        lls.append(log_likelihood)

    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
    return ppl, outputs

In [2]:
ppl, outputs = compute_perplexity(model, encodings)
ppl

100%|██████████| 1/1 [00:00<00:00,  2.27it/s]


tensor(465.2043)

In [13]:
input_ids = encodings.input_ids.to(device)
result = model.generate(input_ids).to('cpu')
# [r.size() for r in result]
tokenizer.decode(list(result[0]))

Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence


'An orange ruled the world.\n\nThe first time I saw it was in the movie "The'

In [18]:
outputs[1][:, -1, :]

IndexError: too many indices for tensor of dimension 0

In [15]:
outputs[1].shape

torch.Size([1, 6, 50257])

In [17]:
outputs[0]

tensor(6.1425)

In [21]:
outputs[1]

tensor([[[ -30.1623,  -28.9555,  -32.0044,  ...,  -35.6717,  -35.8829,
           -29.2568],
         [ -74.5398,  -74.5179,  -78.7193,  ...,  -85.7082,  -80.6271,
           -75.4816],
         [ -83.7110,  -85.1910,  -92.2418,  ...,  -93.3988,  -92.9523,
           -87.2709],
         [-102.3252, -100.4652, -104.4386,  ..., -105.3670, -106.3231,
          -101.8432],
         [ -95.8678,  -96.9816, -102.3193,  ..., -110.5038, -106.3977,
           -98.8357],
         [-123.4892, -122.2919, -123.8654,  ..., -134.8279, -132.7908,
          -115.4254]]])

In [22]:
dir(outputs)

['__add__',
 '__class__',
 '__contains__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getnewargs__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__mul__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__rmul__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'count',
 'index']

In [23]:
outputs.last_hidden_state

AttributeError: 'tuple' object has no attribute 'last_hidden_state'

In [40]:

model2 = GPT2LMHeadModel.from_pretrained(model_id, return_dict=True).to(device)
ppl2, outputs2 = compute_perplexity(model2, encodings)

100%|██████████| 1/1 [00:00<00:00,  9.27it/s]


In [27]:
outputs2.loss

tensor(6.1425)

In [28]:
outputs[0]

tensor(6.1425)

In [31]:
(outputs2.logits == outputs[1]).all()

tensor(True)

In [32]:
tokenizer.vocab_size

50257

In [25]:
logits = outputs[1][0]
values, indices = torch.topk(logits[1], k=1000)
values

tensor([-68.9441, -69.7438, -69.8351, -70.0671, -70.6838, -71.2393, -71.2426,
        -71.3316, -71.3925, -71.5246, -71.6758, -71.6897, -71.7453, -71.7616,
        -71.8045, -71.8453, -71.8596, -71.8740, -71.8797, -71.9139, -71.9500,
        -71.9620, -72.0048, -72.0842, -72.1317, -72.1525, -72.1947, -72.1952,
        -72.2284, -72.2377, -72.2444, -72.2655, -72.2743, -72.2872, -72.3579,
        -72.4047, -72.4137, -72.4287, -72.4522, -72.4990, -72.5173, -72.5305,
        -72.5458, -72.5458, -72.5676, -72.5762, -72.5795, -72.6138, -72.6259,
        -72.6304, -72.6330, -72.6692, -72.6926, -72.6959, -72.7065, -72.7100,
        -72.7359, -72.7483, -72.7581, -72.7614, -72.8115, -72.8233, -72.8269,
        -72.8365, -72.8471, -72.8576, -72.8826, -72.8951, -72.9084, -72.9086,
        -72.9407, -72.9418, -72.9577, -72.9657, -73.0087, -73.0144, -73.0311,
        -73.0468, -73.0479, -73.0577, -73.0579, -73.0824, -73.0859, -73.0865,
        -73.0876, -73.1149, -73.1164, -73.1211, -73.1681, -73.17

In [32]:
vocab = tokenizer.get_vocab()

In [36]:
''.join([k for i in range(10) for k, v in vocab.items() if v == i])

'!"#$%&\'()*'

In [31]:
tokenizer.decode(list(range(10)))

'!"#$%&\'()*'

In [29]:
tokenizer.decode(list(indices))



In [41]:
outputs2.__dict__

{'loss': tensor(6.1425),
 'logits': tensor([[[ -30.1623,  -28.9555,  -32.0044,  ...,  -35.6717,  -35.8829,
            -29.2568],
          [ -74.5398,  -74.5179,  -78.7193,  ...,  -85.7082,  -80.6271,
            -75.4816],
          [ -83.7110,  -85.1910,  -92.2418,  ...,  -93.3988,  -92.9523,
            -87.2709],
          [-102.3252, -100.4652, -104.4386,  ..., -105.3670, -106.3231,
           -101.8432],
          [ -95.8678,  -96.9816, -102.3193,  ..., -110.5038, -106.3977,
            -98.8357],
          [-123.4892, -122.2919, -123.8654,  ..., -134.8279, -132.7908,
           -115.4254]]]),
 'past_key_values': (tensor([[[[[-1.2651e+00,  2.2377e+00,  6.4798e-01,  ..., -9.4572e-01,
              -9.0716e-01,  1.5312e+00],
             [-2.0415e+00,  3.0485e+00,  1.2401e+00,  ..., -1.0140e+00,
              -1.7704e+00,  3.1532e+00],
             [-1.6852e+00,  2.2756e+00,  1.3499e+00,  ..., -1.8068e+00,
              -2.2438e+00,  8.0326e-01],
             [-2.2933e+00,  2.6495