In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
import transformers
import os
import torch

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentinel strings used to build fill-in-the-middle (FIM) prompts, plus the
# end-of-document marker that is registered as the EOS token below.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"

# Pad on the left so that generated tokens directly follow the prompt.
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder", padding_side="left")

# Register the FIM sentinels as special tokens and tell the tokenizer which of
# them act as the EOS and padding tokens.
tokenizer.add_special_tokens(
    {
        "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
        "eos_token": EOD,
        "pad_token": FIM_PAD,
    }
)

# Displayed as the cell output: the id assigned to the padding token.
tokenizer.pad_token_id

49156

In [2]:

# Load the SantaCoder checkpoint at a pinned revision. trust_remote_code=True
# lets transformers execute the modelling code shipped in the hub repository,
# so the explicit revision also fixes which version of that code is run.
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(
    "bigcode/santacoder",
    revision="aaeed52",
    trust_remote_code=True,
)

# Move the weights onto the device chosen in the setup cell.
model = model.to(device)


In [3]:
import re

sample_code = """
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
"""

# Split the code into tokens naively. Split on whitespace and all punctuation
tokens = re.split(r"(\W)", sample_code)
len(tokens)

91

In [4]:
def get_suggestions(token, prefix, suffix):
    """Print the model's fill-in-the-middle candidates for one gap in the code.

    A FIM prompt is built from ``prefix`` and ``suffix``, the model proposes
    replacements for the text between them using beam search, and the
    candidates are printed together with their sequence probabilities.

    Args:
        token: The text currently occupying the gap. It is only used to size
            the generation budget (its token count plus 10).
        prefix: Source text that precedes the gap.
        suffix: Source text that follows the gap.

    Returns:
        None. The candidates are currently only printed, not returned.
    """
    prompt = FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE
    print(prompt)

    with torch.no_grad():
        inputs = tokenizer(
            prompt, return_tensors="pt", padding=True, return_token_type_ids=False
        ).to(device)

        # Allow at most 10 tokens more than the original token occupies. If we
        # need more than that to fix your code it's not a typo anymore!
        max_new_tokens = len(tokenizer.encode(token, add_special_tokens=False)) + 10

        # Plain beam search (no sampling): keep the 10 best of 20 beams.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=10,
            num_beams=20,
            early_stopping=True,
            output_scores=True,
            return_dict_in_generate=True,
            forced_eos_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        print('shape')
        print(outputs.sequences.shape)
        print(outputs.sequences)

        # Keep only the generated continuation, dropping the user-provided prompt.
        gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1]:]

        # With beam search, each row of outputs.scores belongs to a beam slot at
        # that step, not to a returned sequence, so the scores cannot be indexed
        # directly with gen_sequences. compute_transition_scores follows
        # outputs.beam_indices to pick the log-score of every chosen token and
        # leaves positions after the end of a sequence at 0, so padding does not
        # affect the total. NOTE(review): requires a transformers release that
        # provides GenerationMixin.compute_transition_scores.
        transition_scores = model.compute_transition_scores(
            outputs.sequences,
            outputs.scores,
            outputs.beam_indices,
            normalize_logits=False,
        )

        # Sum of per-token log-probabilities -> probability of the whole sequence.
        unique_prob_per_sequence = transition_scores.sum(dim=-1).exp()
        print(unique_prob_per_sequence)

        # Special tokens are kept so the EOS / padding markers stay visible.
        sequences = tokenizer.batch_decode(gen_sequences, skip_special_tokens=False)
        for text in sequences:
            print("------")
            print(text)

    return None

# Smoke test: blank out the function name in the definition line and ask the
# model what belongs between "def " and "(n):".
get_suggestions(
    "fib",
    "\ndef ",
    "(n):\n    if n < 2:\n        return n\n    return fib(n - 1) + fib(n - 2)\n",
)

<fim-prefix>
def <fim-suffix>(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
<fim-middle>
shape
torch.Size([10, 48])
tensor([[49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947, 49152, 49156,
         49156, 49156, 49156, 49156, 49156, 49156, 49156, 49156],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947,     7,    77,
           399,   258,   356,   293,   428,   207,    17, 49152],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,  

In [5]:


# Ask the model for replacement candidates at each position of interest.
suggestions = []

for i, token in enumerate(tokens):
    # Only occurrences of the identifier 'fib' are checked for now. Because of
    # this filter, whitespace-only and empty tokens never get past this point,
    # so no separate whitespace guard is needed below.
    if token != 'fib':
        continue
    print(i)

    # Everything before / after the current token becomes the FIM context.
    prefix = "".join(tokens[:i])
    suffix = "".join(tokens[i + 1 :])
    suggestions.append(get_suggestions(token, prefix, suffix))


4
<fim-prefix>
def <fim-suffix>(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
<fim-middle>
shape
torch.Size([10, 48])
tensor([[49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947, 49152, 49156,
         49156, 49156, 49156, 49156, 49156, 49156, 49156, 49156],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947,     7,    77,
           399,   258,   356,   293,   428,   207,    17, 49152],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,

In [6]:

# Start from an empty list so that running this loop again (or after an
# earlier run of the same loop) does not append a second copy of every result.
suggestions = []

for i, token in enumerate(tokens):
    # Only occurrences of the identifier 'fib' are checked for now. Because of
    # this filter, whitespace-only and empty tokens never get past this point,
    # so no separate whitespace guard is needed below.
    if token != 'fib':
        continue
    print(i)

    # Everything before / after the current token becomes the FIM context.
    prefix = "".join(tokens[:i])
    suffix = "".join(tokens[i + 1 :])
    suggestions.append(get_suggestions(token, prefix, suffix))


4
<fim-prefix>
def <fim-suffix>(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
<fim-middle>
shape
torch.Size([10, 48])
tensor([[49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947, 49152, 49156,
         49156, 49156, 49156, 49156, 49156, 49156, 49156, 49156],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,   428,   207,    17,    25,   259,   363,   293,   258,   363,
         24240,     7,    77,   459,   207,    16,     8,   385, 24240,     7,
            77,   459,   207,    17,     8,   185, 49154, 27947,     7,    77,
           399,   258,   356,   293,   428,   207,    17, 49152],
        [49153,   185,   563,   207, 49155,     7,    77,   399,   258,   356,
           293,