https://github.com/huggingface/evaluate/blob/0ca575d7aa0764ea646dcd5a27cb952e587ce9eb/metrics/perplexity/perplexity.py#L14

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss
from typing import Iterable, Union

In [2]:
MODEL = "/home/arjun/bigcode/models/bigcode_15b_1000m"
model = AutoModelForCausalLM.from_pretrained(MODEL).half().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

### The Real Thing

In [3]:
loss_fct = CrossEntropyLoss(reduction="none")
def compute_last_n_perplexity(snippets: Iterable, last_n: int):
    """
      Computes the perplexity of the last n tokens of a result.

      Args:
        snippets: The code snippets to compute perplexity over.
        last_n: The last n tokens to compute perplexity from.
        max_length: The most tokens to feed into the model

      Returns:
        The average peroplexity for each result in the batch over the last n tokens.
    """
    def compute(snippet):
        input_toks = tokenizer(snippet, return_tensors="pt").to(model.device)

        with torch.no_grad():
            r = model(**input_toks).logits
        shift_logits = r[..., :-1, :].contiguous()
        shift_labels = input_toks.input_ids[..., 1:].contiguous()
        shift_attention_mask_batch = input_toks.attention_mask[..., 1:].contiguous()

        shift_logits = shift_logits[..., -last_n:, :].contiguous()
        shift_labels = shift_labels[..., -last_n:].contiguous()
        shift_attention_mask_batch = shift_attention_mask_batch[..., -last_n:].contiguous()

        perplexity_batch = torch.exp(
                        (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                        / shift_attention_mask_batch.sum(1)
                    )

        return perplexity_batch
    return [compute(snippet).cpu().numpy()[0] for snippet in snippets]

### Strangeness of Perplexity
Perplexity seems to change when the initial padding of input changes. Unsure why this is ask Daniel

In [4]:
INPUT1 = "def foo(n):\n\treturn n + 1"
INPUT2 = "def foo(n: str):\n\treturn n + 1" 
INPUT3 = "# s is an integer\ndef foo(s):\n\treturn s + 1"
INPUT4 = "def foo(s: int):\n\treturn s + 1"
INPUT5 = "def foo(s):\n\treturn s + 1"
INPUT6 = "this is really long input that is not a valid python code. this is really long input that is not a valid python code"

In [5]:
compute_last_n_perplexity([INPUT1, INPUT2], 2)

[1.2295, 5.535]

In [6]:
compute_last_n_perplexity([INPUT1, INPUT3], 2)

[1.2295, 1.417]

In [7]:
compute_last_n_perplexity([INPUT1, INPUT6], 2)

[1.2295, 1.004]

In [8]:
from pathlib import Path
import itertools

d = "../TypeWhich/src/"

def read_file(p):
    with open(p, "r", encoding="utf-8") as f:
        return f.read()

all_files = "".join([ read_file(p) for p in itertools.chain(Path(d).glob("*.rs")) ])

In [9]:
all_files



In [10]:
def compute_last_n_perplexity_for_last_k_tokens(string: str, last_k: Union[int, Iterable], last_n: int):
    """
      Computes the average perplexity of the last n tokens of a string, considering the last_k tokens of the entire string.

      Args:
        string: String to compute perplexity over
        last_k: Length or lengths of the last tokens of the string to compute perplexity over
        last_n: How many last tokens to compute the average perplexity over.

      Returns:
        The average perplexity of the last n tokens for each string subset of the last k tokens of the entire string.
        Format is a dictionary of last_k tokens to perplexity.
    """
    tokens = tokenizer(string, return_tensors="pt").to(model.device)

    def compute_perplexity(input_toks):
        with torch.no_grad():
            r = model(**input_toks).logits
        shift_logits = r[..., :-1, :].contiguous()
        shift_labels = input_toks["input_ids"][..., 1:].contiguous()
        shift_attention_mask_batch = input_toks["attention_mask"][..., 1:].contiguous()

        shift_logits = shift_logits[..., -last_n:, :].contiguous()
        shift_labels = shift_labels[..., -last_n:].contiguous()
        shift_attention_mask_batch = shift_attention_mask_batch[..., -last_n:].contiguous()

        return torch.exp(
                        (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                        / shift_attention_mask_batch.sum(1)
                    )
    
    if type(last_k) is int:
        return {last_k : compute_perplexity({
            "input_ids" : tokens.input_ids[:, -last_k:],
            "attention_mask" : tokens.attention_mask[:, -last_k:]
        }).cpu().numpy()[0]}
    else: 
        return {k: compute_perplexity({
            "input_ids" : tokens.input_ids[:, -k:],
            "attention_mask" : tokens.attention_mask[:, -k:]
        }).cpu().numpy()[0] for k in last_k}

In [11]:
compute_last_n_perplexity_for_last_k_tokens(INPUT1, [1000, 500], 50)

{1000: 7.12, 500: 7.12}

In [12]:
files = [ p for p in itertools.chain(Path("./").glob("*.txt")) ]
files

[PosixPath('php_code.txt'),
 PosixPath('js_code.txt'),
 PosixPath('java_code.txt'),
 PosixPath('cs_code.txt'),
 PosixPath('rb_code.txt'),
 PosixPath('rust_code.txt'),
 PosixPath('r_code.txt'),
 PosixPath('go_code.txt'),
 PosixPath('c_code.txt'),
 PosixPath('cpp_code.txt')]

In [13]:
for f in files:
    with open(f, "r", encoding="utf-8") as data:
        print(f)
        print(compute_last_n_perplexity_for_last_k_tokens(data.read(), [8000, 2000], 50))

php_code.txt
{8000: 1.598, 2000: 1.596}
js_code.txt
{8000: 1.674, 2000: 1.696}
java_code.txt
{8000: 1.011, 2000: 1.019}
cs_code.txt
{8000: 1.881, 2000: 1.918}
rb_code.txt
{8000: 1.4375, 2000: 1.417}
rust_code.txt
{8000: 2.04, 2000: 1.894}
r_code.txt
{8000: 1.103, 2000: 1.132}
go_code.txt
{8000: 1.602, 2000: 1.558}
c_code.txt
{8000: 1.626, 2000: 1.541}
cpp_code.txt
{8000: 2.826, 2000: 2.81}
