In [1]:
import torch

# Encoding
def encode(list_of_strings, pad_token_id=0):
    """Convert a batch of strings into right-padded byte-level ID tensors.

    Every item is turned into bytes (``str`` items via ``str.encode``, i.e. UTF-8;
    ``bytes`` items are used as-is) and each byte value is shifted by 2 to obtain
    its ID. Rows shorter than the longest item are padded with ``pad_token_id``.

    Args:
        list_of_strings: sequence of ``str`` or ``bytes`` items.
        pad_token_id: ID written into padding positions (default 0).

    Returns:
        Tuple ``(input_ids, attention_masks)`` of ``torch.long`` tensors, both of
        shape ``(len(list_of_strings), longest_item_in_bytes)``. The mask is 1 on
        real bytes and 0 on padding. An empty batch yields two ``(0, 0)`` tensors.
    """
    # Convert to bytes *before* measuring lengths: a non-ASCII str has more UTF-8
    # bytes than characters, so sizing the tensors from len(str) makes the row
    # assignment below fail with a shape mismatch.
    byte_strings = [
        string if isinstance(string, bytes) else str.encode(string)
        for string in list_of_strings
    ]
    # default=0 keeps an empty batch from raising inside max().
    max_length = max((len(byte_string) for byte_string in byte_strings), default=0)

    # create empty tensors
    attention_masks = torch.zeros((len(byte_strings), max_length), dtype=torch.long)
    input_ids = torch.full((len(byte_strings), max_length), pad_token_id, dtype=torch.long)

    for idx, byte_string in enumerate(byte_strings):
        length = len(byte_string)
        # Iterating over bytes yields ints; +2 shifts them past the two lowest IDs.
        input_ids[idx, :length] = torch.tensor(
            [x + 2 for x in byte_string], dtype=torch.long
        )
        attention_masks[idx, :length] = 1

    return input_ids, attention_masks
    
# Decoding
def decode(outputs_ids):
    """Turn a batch of byte-level ID sequences back into strings.

    IDs in ``2..257`` are mapped to the byte value ``id - 2``; the bytes of each
    row are then decoded together as UTF-8, so multi-byte characters are
    reassembled instead of being rendered one byte at a time.

    Args:
        outputs_ids: 2-D tensor (anything whose ``.tolist()`` yields a list of
            rows of ints).

    Returns:
        List with one ``str`` per row. IDs below 2 (including negative values
        such as -100) and IDs above 257 have no corresponding byte and are
        dropped. Byte runs that are not valid UTF-8, e.g. a sequence cut off in
        the middle of a character, decode to U+FFFD rather than raising.
    """
    decoded_outputs = []
    for output_ids in outputs_ids.tolist():
        # Keep only IDs that correspond to a real byte, undoing the +2 shift.
        raw_bytes = bytes(x - 2 for x in output_ids if 2 <= x <= 257)
        decoded_outputs.append(raw_bytes.decode("utf-8", errors="replace"))
    return decoded_outputs

In [None]:
from transformers import ReformerModelWithLMHead

# Load the pretrained checkpoint named "google/reformer-enwik8". The cells below
# feed it IDs produced by encode() and turn its outputs back into text with decode().
model = ReformerModelWithLMHead.from_pretrained("google/reformer-enwik8")

In [52]:
# Encode a single prompt. With only one string in the batch there is no padding;
# the returned attention mask is not referenced by any of the cells shown here.
encoded, attention_masks = encode(["In the year 1961"])
# do_sample=True draws the continuation randomly and no seed is set, so the text
# printed below will differ between runs. max_length=20 caps the total length
# (prompt included) at 20 IDs.
decode(model.generate(encoded, do_sample=True, max_length=20))

['In the year 1961, re']

In [53]:
input_ids = encoded
# Work on a copy so the in-place write below leaves `input_ids` untouched.
masked = input_ids.clone()
# Overwrite every position except the last one with -100.
# NOTE(review): -100 is presumably intended as the "ignore this position" label
# value for the loss, yet the forward cell passes labels=input_ids rather than
# `masked` -- confirm which was intended.
masked[:, :-1] = -100
masked

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   51]])

In [54]:
# decode() drops every ID below 2, so the -100 entries vanish and only the
# character for the final, unmasked ID remains.
decode(masked)

['1']

In [55]:
# Call the module itself instead of .forward(): nn.Module.__call__ wraps forward()
# and runs any registered hooks, which a direct .forward() call silently skips.
# Passing `labels` makes the output carry a loss next to the logits; both
# out.loss and out.logits are read by the cells below.
out = model(input_ids=input_ids, labels=input_ids)

In [56]:
# IPython-only introspection: `??` shows the source of model.forward. This is an
# exploration aid, not part of the computation, and is not valid plain Python.
model.forward??

In [58]:
# exp(-loss). NOTE(review): if the loss is a mean per-token negative
# log-likelihood, this is the geometric-mean probability the model assigns to
# the labelled tokens -- confirm against the model's loss definition.
torch.exp(-out.loss)

tensor(0.2227, grad_fn=<ExpBackward0>)

In [45]:
# Logits of the first (only) sequence at the second-to-last position.
# NOTE(review): despite the variable name this is index -2, not -1; presumably
# position t scores the token at t+1, making this the prediction for the last
# prompt character -- confirm.
last_token_logits = out.logits[0, -2, :]
# Decode the 20 highest-scoring IDs as one string; unsqueeze(0) adds the batch
# dimension that decode() iterates over.
decode(last_token_logits.topk(20).indices.unsqueeze(0))

["1602345879 .],<\nt&)'"]

In [62]:
# Probability given to ID 51 at every sequence position. 51 is the ID encode()
# produces for the character "1" (ord("1") + 2).
torch.softmax(out.logits, dim=2)[:, :, 51]

tensor([[4.0522e-03, 3.4218e-03, 7.4647e-02, 4.2927e-04, 1.0230e-05, 1.6872e-05,
         2.0796e-02, 6.7687e-06, 1.1147e-07, 3.2344e-06, 3.3391e-05, 4.1315e-01,
         9.1986e-02, 1.0991e-01, 1.3634e-01, 7.9240e-07]],
       grad_fn=<SelectBackward0>)

In [46]:
# Normalise the selected logits into probabilities and show the 20 largest,
# in descending order.
torch.topk(torch.softmax(last_token_logits, dim=0), 20).values

tensor([1.3634e-01, 1.3585e-01, 1.2241e-01, 1.0184e-01, 1.0049e-01, 1.0026e-01,
        8.9645e-02, 8.1478e-02, 7.2285e-02, 5.6066e-02, 1.1347e-03, 3.4175e-04,
        2.1189e-04, 1.5333e-04, 1.0002e-04, 9.7569e-05, 9.2751e-05, 9.1451e-05,
        7.6007e-05, 5.8729e-05], grad_fn=<TopkBackward0>)

In [48]:
import numpy as np
# Natural log of 0.13634, the largest probability printed by the top-k cell
# above. The value was typed in by hand rather than read from the tensor.
np.log(1.3634e-1)

-1.992603513047498

In [49]:
# NOTE(review): this divides by 1.3634, not by the probability 1.3634e-1 used in
# the log cell above. The reciprocal of that probability would be about 7.33 --
# confirm which value was intended.
1/1.3634

0.7334604664808567