https://github.com/huggingface/evaluate/blob/0ca575d7aa0764ea646dcd5a27cb952e587ce9eb/metrics/perplexity/perplexity.py#L14

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss

In [2]:
MODEL = "bigcode/gpt_bigcode-santacoder"
model = AutoModelForCausalLM.from_pretrained(MODEL).half().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")

In [3]:
#TODO Confirm that this is correct, and if necessary, get someone at bigcode to fix it so it defaults to this
tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})

0

### The Real Thing

In [4]:
def get_tokens(inputs):
    return tokenizer(inputs, return_tensors="pt", padding=True).to(model.device)
    

In [5]:
loss_fct = CrossEntropyLoss(reduction="none")
def compute_last_n_perplexity(input_toks, last_n):
    """
      Computes the perplexity of the last n tokens of a result.

      Args:
        r: The result logits from the model.
        last_n: The last n tokens to compute perplexity from.

      Returns:
        The average peroplexity for each result in the batch over the last n tokens.
    """
    with torch.no_grad():
        r = model(**input_toks).logits
    shift_logits = r[..., :-1, :].contiguous()
    shift_labels = input_toks.input_ids[..., 1:].contiguous()
    shift_attention_mask_batch = input_toks.attention_mask[..., 1:].contiguous()

    shift_logits = shift_logits[..., :-last_n, :]
    shift_labels = shift_labels[..., :-last_n]
    shift_attention_mask_batch = shift_attention_mask_batch[..., :-last_n]

    perplexity_batch = torch.exp(
                    (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                    / shift_attention_mask_batch.sum(1)
                )

    return perplexity_batch

### Strangeness of Perplexity
Perplexity seems to change when the initial padding of input changes. Unsure why this is ask Daniel

In [6]:
INPUT1 = "def foo(n):\n\treturn n + 1"
INPUT2 = "def foo(n: str):\n\treturn n + 1" 
INPUT3 = "# s is an integer\ndef foo(s):\n\treturn s + 1"
INPUT4 = "def foo(s: int):\n\treturn s + 1"
INPUT5 = "def foo(s):\n\treturn s + 1"
INPUT6 = "this is really long input that is not a valid python code. this is really long input that is not a valid python code"

In [7]:
compute_last_n_perplexity(get_tokens([INPUT1, INPUT2]), 2)

tensor([14.0469, 13.8281], device='cuda:0', dtype=torch.float16)

In [9]:
compute_last_n_perplexity(get_tokens([INPUT1, INPUT3]), 2)

tensor([14.4141, 11.8984], device='cuda:0', dtype=torch.float16)

In [8]:
compute_last_n_perplexity(get_tokens([INPUT1, INPUT6]), 2)

tensor([14.7812, 16.1094], device='cuda:0', dtype=torch.float16)