https://github.com/huggingface/evaluate/blob/0ca575d7aa0764ea646dcd5a27cb952e587ce9eb/metrics/perplexity/perplexity.py#L14

In [37]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss
from typing import Iterable

In [38]:
MODEL = "bigcode/gpt_bigcode-santacoder"
model = AutoModelForCausalLM.from_pretrained(MODEL).half().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")

### The Real Thing

In [93]:
loss_fct = CrossEntropyLoss(reduction="none")
def compute_last_n_perplexity(snippets: Iterable, last_n: int, max_length: int=None):
    """
      Computes the perplexity of the last n tokens of a result.

      Args:
        snippets: The code snippets to compute perplexity over.
        last_n: The last n tokens to compute perplexity from.
        max_length: The most tokens to feed into the model

      Returns:
        The average peroplexity for each result in the batch over the last n tokens.
    """
    def compute(snippet):
        input_toks = tokenizer(snippet, return_tensors="pt").to(model.device)
        if max_length:
            input_toks.input_ids = input_toks.input_ids[..., :max_length].contiguous()
            input_toks.attention_mask = input_toks.attention_mask[..., :max_length].contiguous()

        print(input_toks.input_ids.size())
        print(input_toks.attention_mask.size())

        with torch.no_grad():
            r = model(**input_toks).logits
        shift_logits = r[..., :-1, :].contiguous()
        shift_labels = input_toks.input_ids[..., 1:].contiguous()
        shift_attention_mask_batch = input_toks.attention_mask[..., 1:].contiguous()

        shift_logits = shift_logits[..., -last_n:, :].contiguous()
        shift_labels = shift_labels[..., -last_n:].contiguous()
        shift_attention_mask_batch = shift_attention_mask_batch[..., -last_n:].contiguous()

        perplexity_batch = torch.exp(
                        (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                        / shift_attention_mask_batch.sum(1)
                    )

        return perplexity_batch
    return [compute(snippet).cpu().numpy()[0] for snippet in snippets]

### Strangeness of Perplexity
Perplexity seems to change when the initial padding of input changes. Unsure why this is ask Daniel

In [86]:
INPUT1 = "def foo(n):\n\treturn n + 1"
INPUT2 = "def foo(n: str):\n\treturn n + 1" 
INPUT3 = "# s is an integer\ndef foo(s):\n\treturn s + 1"
INPUT4 = "def foo(s: int):\n\treturn s + 1"
INPUT5 = "def foo(s):\n\treturn s + 1"
INPUT6 = "this is really long input that is not a valid python code. this is really long input that is not a valid python code"

In [87]:
compute_last_n_perplexity([INPUT1, INPUT2], 2)

torch.Size([1, 12])
torch.Size([1, 14])


[1.249, 4.434]

In [88]:
compute_last_n_perplexity([INPUT1, INPUT3], 2)

torch.Size([1, 12])
torch.Size([1, 18])


[1.249, 1.259]

In [89]:
compute_last_n_perplexity([INPUT1, INPUT6], 2)

torch.Size([1, 12])
torch.Size([1, 25])


[1.249, 1.019]

In [90]:
from pathlib import Path
import itertools

d = "../TypeWhich/src/"

def read_file(p):
    with open(p, "r", encoding="utf-8") as f:
        return f.read()

results = [ read_file(p) for p in itertools.chain(Path(d).glob("*.rs")) ]

In [91]:
results[0]

'mod benchmark;\nmod cgen;\nmod eval;\nmod grift;\nmod ins_and_outs;\nmod insert_coercions;\nmod parser;\nmod precision;\nmod pretty;\nmod syntax;\nmod type_check;\nmod z3_state;\n\nuse clap::Clap;\nuse std::io::*;\nuse std::path::Path;\n\n#[derive(Clap)]\nenum Parser {\n    Empty,\n    Grift,\n}\n\n#[derive(Clone, Copy, PartialEq)]\nenum Annot {\n    Ignore,\n    Hard,\n}\n\nimpl std::str::FromStr for Annot {\n    type Err = &\'static str;\n\n    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {\n        match s {\n            "ignore" => Ok(Annot::Ignore),\n            "hard" => Ok(Annot::Hard),\n            _ => Err("invalid annotation behavior"),\n        }\n    }\n}\n\nimpl std::str::FromStr for Parser {\n    type Err = &\'static str;\n\n    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {\n        match s {\n            "grift" => Ok(Parser::Grift),\n            "empty" => Ok(Parser::Empty),\n            _ => Err("invalid parser"),\n        }\n    }\n}\

In [94]:
compute_last_n_perplexity(results, 50, max_length=2000)

torch.Size([1, 2000])
torch.Size([1, 2000])


RuntimeError: The size of tensor a (2048) must match the size of tensor b (6151) at non-singleton dimension 2