In [1]:
import html
import os

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA instead of failing
# on the first `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Special tokens used to build fill-in-the-middle (FIM) prompts of the form
# FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
# Registered as the tokenizer's padding token.
FIM_PAD = "<fim-pad>"
# Registered as the tokenizer's end-of-sequence token.
EOD = "<|endoftext|>"

# Special-token configuration: the FIM markers are registered as additional
# special tokens, EOD doubles as the end-of-sequence token and FIM_PAD as the
# padding token.
special_tokens = {
    "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
    "eos_token": EOD,
    "pad_token": FIM_PAD,
}

# Pad on the left so that, in a padded batch, every prompt ends exactly where
# generation begins.
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder", padding_side="left")
tokenizer.add_special_tokens(special_tokens)

# Display the id assigned to the padding token.
tokenizer.pad_token_id


49156

In [2]:
# Load the SantaCoder checkpoint. `trust_remote_code=True` executes modelling
# code shipped in the model repository, so the revision is pinned to a fixed
# commit rather than tracking the default branch.
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(
    "bigcode/santacoder",
    revision="aaeed52",
    trust_remote_code=True,
)
# Move the weights onto the device selected earlier in the notebook.
model = model.to(device)


In [3]:
import re

# The snippet to be checked for typos. It is kept verbatim as a string and is
# never executed; it is only split into tokens below.
sample_code = """
prompts = []
suggestion_index_to_token_index = []

for i, token in enumerate(tokens):
    if token.isspace() == "":
        # Don't generate suggestions for whitespace
        continue
    else:
        prefix = ''.join(tokens[:i])
        suffix = ''.join(tokens[i + 1 :])
        prompt = FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE

        prompts.append(prompt)
        suggestion_index_to_token_index.append(i)"""

# Naive tokenisation: split on runs of whitespace. The capturing group keeps
# each whitespace run as its own token, so "".join(tokens) reproduces
# sample_code exactly.
whitespace_run = re.compile(r"(\s+)")
tokens = whitespace_run.split(sample_code)

len(tokens)


91

In [9]:
# Build one fill-in-the-middle prompt per meaningful token: the token itself is
# cut out and the model is asked to regenerate it from the surrounding code.
prompts = []
# suggestion_index_to_token_index[k] is the index in `tokens` of the token that
# prompts[k] asks the model to fill in.
suggestion_index_to_token_index = []

for i, token in enumerate(tokens):
    # Don't generate suggestions for whitespace, nor for the empty string that
    # re.split yields when the text starts or ends with a separator.
    # (The previous check, `token.isspace() == ""`, compared a bool with a str
    # and was therefore never true, so nothing was ever skipped.)
    if not token.strip():
        continue

    prefix = "".join(tokens[:i])
    suffix = "".join(tokens[i + 1 :])

    prompts.append(FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE)
    suggestion_index_to_token_index.append(i)

<fim-prefix>
prompts = []
suggestion_index_to_token_index = []

for i, token in enumerate(tokens):
    if token.isspace() == <fim-suffix>
        # Don't generate suggestions for whitespace
        continue
    else:
        prefix = ''.join(tokens[:i])
        suffix = ''.join(tokens[i + 1 :])
        prompt = FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE

        prompts.append(prompt)
        suggestion_index_to_token_index.append(i)<fim-middle>


In [5]:
@torch.no_grad()
def get_suggestions(input_prompts, outputs_per_input=5, max_new_tokens=20):
    """Generate ranked fill-in-the-middle completions for a batch of prompts.

    All prompts are tokenized together (with padding) and completed with beam
    search by the notebook-level `model` / `tokenizer` on `device`.

    Args:
        input_prompts: List of prompt strings.
        outputs_per_input: Number of beams, and of completions returned, per
            prompt.
        max_new_tokens: Upper bound on the number of tokens generated per
            completion.

    Returns:
        A list with one entry per prompt, in input order. Each entry is a list
        of `(text, probability)` tuples sorted by descending probability, where
        `probability` is the product of the per-step softmax values of the
        generated tokens (padding positions count as 1). An empty
        `input_prompts` yields an empty list.
    """
    # Nothing to do; also avoids handing an empty batch to the tokenizer.
    if not input_prompts:
        return []

    inputs = tokenizer(
        input_prompts, return_tensors="pt", padding=True, return_token_type_ids=False
    ).to(device)

    outputs = model.generate(
        **inputs,
        # Keep completions short (20 tokens by default). If we need more than
        # that to fix your code it's not a typo anymore!
        max_new_tokens=max_new_tokens,
        num_return_sequences=outputs_per_input,
        num_beams=outputs_per_input,
        early_stopping=True,
        output_scores=True,
        return_dict_in_generate=True,
        # Always force the model to generate the eos token. If an EOS is
        # improbable, that's very important information!
        forced_eos_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Adapted from https://discuss.huggingface.co/t/generation-probabilities-how-to-compute-probabilities-of-output-scores-for-gpt2/3175
    # Just the generated sequences, ignoring the prompt
    gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1] :]

    # For each sequence, get the probability of each token at each step.
    # NOTE(review): under beam search, row i of `outputs.scores` at a given
    # step may belong to a different beam than the i-th returned sequence.
    # If exact probabilities matter, consider
    # `model.compute_transition_scores(...)` with `outputs.beam_indices` -
    # confirm against the installed transformers version.
    probs = torch.stack(outputs.scores, dim=1).softmax(-1)

    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)

    # Wherever the sequence is the pad token, set the probability to 1. That
    # way we don't penalize short sequences. We already penalize premature
    # truncation by forcing the model to generate the EOS token.
    gen_probs = torch.where(
        gen_sequences == tokenizer.pad_token_id,
        torch.ones_like(gen_probs),
        gen_probs,
    )

    # Multiply all the probabilities together to get the probability of the
    # entire sequence being generated
    unique_prob_per_sequence = gen_probs.prod(-1)

    # Decode the returned sequences into strings
    sequences = tokenizer.batch_decode(gen_sequences, skip_special_tokens=True)

    sequences_with_probs = list(zip(sequences, unique_prob_per_sequence.tolist()))

    # The flat output holds `outputs_per_input` consecutive rows per prompt;
    # regroup them per prompt and rank each group by probability.
    all_results = []
    for i in range(len(input_prompts)):
        batch_sequences = sequences_with_probs[
            i * outputs_per_input : (i + 1) * outputs_per_input
        ]
        all_results.append(sorted(batch_sequences, key=lambda x: x[1], reverse=True))

    return all_results


# Quick check on a small batch: suggestions for the first three prompts only.
get_suggestions(prompts[:3])

[[('\nFIM_PREFIX = "<s>"\nFIM_SUFFIX = "</s>"\n', 0.0),
  ('\nFIM_PREFIX = "<FIM>"\nFIM_SUFFIX = "</FIM', 0.0),
  ('\nFIM_PREFIX = "<FIM>"\nFIM_SUFFIX = "<FIM', 0.0),
  ('\nFIM_PREFIX = "FIM: "\nFIM_SUFFIX = " "', 0.0),
  ('\nFIM_PREFIX = "FIM: "\nFIM_MIDDLE = " "', 0.0)],
 [('=', 0.9707012176513672),
  ('= ', 0.00410509156063199),
  ('+=', 0.0005870978930033743),
  ('= []\nsuggestion_tokens =', 1.3415783241266244e-12),
  ('= []\nsuggestions =', 2.3447771211285294e-25)],
 [('[]', 0.6721534132957458),
  ('[]\n', 0.041928596794605255),
  ('list()', 0.01677822507917881),
  ('[', 0.00012675553443841636),
  ('[""]', 1.106857010560458e-12)]]

In [6]:

# Run the model once over every prompt.
model_outputs = get_suggestions(prompts)

# Re-align the results with `tokens`: build a list of the same length as
# `tokens` that holds the suggestions for each token that got a prompt and
# `None` for every token that was skipped.
suggestion_by_token_index = dict(zip(suggestion_index_to_token_index, model_outputs))
suggestions = [suggestion_by_token_index.get(position) for position in range(len(tokens))]

In [7]:
from IPython.display import HTML

# Styling for the rendered snippet: a dark code panel in which every token
# reveals a pop-up table of the model's alternatives on hover.
css = """
#output {
  font-family: monospace;
  white-space: pre;
  background-color: #333;
  padding: 20px;
  padding-bottom: 140px;
}

.token {
  color: #000;
  position: relative;
  cursor: pointer;
}

.alternatives {
  position: absolute;
  display: none;
}

.token:hover {
  background-color: #ddd;
}

.token:hover .alternatives {
  color: #000;
  position: absolute;
  top: 14px;
  left: 0;
  display: block;
  background-color: #fff;
  z-index: 1;
  padding: 8px;
  border-radius: 8px;
  border: 1px solid #ccc;
}
"""

# A token is flagged as a likely typo when the model is confident about its
# best alternative but gives the token that was actually written little weight.
CONFIDENT_TOP_PROBABILITY = 0.5
UNLIKELY_TOKEN_PROBABILITY = 0.1

spans = []
for token, suggestion in zip(tokens, suggestions):
    # Tokens and suggestions are raw source text and may contain `<`, `>` or
    # `&` (e.g. "<s>"); escape them so they are displayed instead of being
    # interpreted as markup.
    token_html = html.escape(token)

    if suggestion is None:
        spans.append(f"<span class='token'>{token_html}</span>")
        continue

    # Find the token in the suggestions list and determine how likely it is
    # (0 when the model did not propose it at all).
    token_probability = next((p for s, p in suggestion if s == token), 0)

    # Suggestions are sorted by descending probability, so the first is the best.
    top_probability = suggestion[0][1]

    token_color = "white"
    if (
        top_probability > CONFIDENT_TOP_PROBABILITY
        and token_probability < UNLIKELY_TOKEN_PROBABILITY
    ):
        token_color = "red"

    suggestions_items = []
    for s, p in suggestion:
        # An empty completion means the model would rather drop the token.
        label = html.escape(s) if s else "<i>remove token</i>"
        if s == token:
            label = f"<strong>{label}</strong>"

        suggestions_items.append(f"<tr><td>{p:.2f}</td><td>{label}</td></tr>")

    suggestions_table = f"<table>{''.join(suggestions_items)}</table>"

    spans.append(
        f"<span class='token' style='color: {token_color}'>{token_html}<span class='alternatives'>{suggestions_table}</span></span>"
    )

display(
    HTML(f"<style>{css}</style><body><div id='output'>{''.join(spans)}</div></body>")
)


0,1
0.0,"FIM_PREFIX = """" FIM_SUFFIX = """""
0.0,"FIM_PREFIX = """" FIM_SUFFIX = ""0.00 FIM_PREFIX = """" FIM_SUFFIX = ""0.00 FIM_PREFIX = ""FIM: "" FIM_SUFFIX = "" ""0.00 FIM_PREFIX = ""FIM: "" FIM_MIDDLE = "" """

0,1
0.97,=
0.0,=
0.0,+=
0.0,= [] suggestion_tokens =
0.0,= [] suggestions =

0,1
0.67,[]
0.04,[]
0.02,list()
0.0,[
0.0,"[""""]"

0,1
0.92,suggestion_index_to_token_index
0.0,suggestion_index_to_token_index
0.0,"for i, token in enumerate(tokens):  if token.isspace() == """":"
0.0,"suggestion_index_to_token_index = [] for i, token in enumerate"
0.0,suggestions = [] suggestion_index_to_token_index

0,1
0.93,=
0.0,=
0.0,= [] tokens =
0.0,= [] suggestions =
0.0,= [] suggestion_tokens =

0,1
0.33,[]
0.0,[]
0.0,[] tokens = tokenizer.tokenize(text)
0.0,[] tokens = tokenizer.tokenize(sentence)
0.0,[] tokens = tokenizer.tokenize(text)

0,1
0.42,for
0.0,for
0.0,"for i, token in enumerate(tokens):  if token.isspace() == """":"
0.0,"for i, token in enumerate(tokens):  if token.isspace() == True:"
0.0,"for i, token in enumerate(tokens):  if token.isspace() == """":"

0,1
0.85,"i,"
0.01,"index,"
0.0,"idx,"
0.0,"(i,"
0.0,"i,"

0,1
0.7,token
0.0,token in enumerate(tokens):  prefix = ''.join(tokens[:i])  suffix =
0.0,"token in enumerate(tokens):  if token.isspace() == """":  # Don't"
0.0,token in enumerate(tokens):  if token.isspace() == False:  prefix =
0.0,token in enumerate(tokens):  if token.isspace() == True:  # Don

0,1
0.92,in
0.0,in en
0.0,"in enumerate(tokens):  if token.isspace() == """":  # Don't generate"
0.0,in enumerate(tokens):  if token.isspace() == True:  # Don't
0.0,in enumerate(tokens):  prefix = ''.join(tokens[:i])  suffix = ''.

0,1
0.32,enumerate(tokens):
0.0,enumerate(tokens):
0.0,enumerate(tokens):
0.0,"enumerate(tokens):  if token.isspace() == """":  # Don't generate suggestions"
0.0,enumerate(tokens):  if token.isspace():  # Don't generate suggestions for whitespace

0,1
0.76,if
0.0,if not
0.0,if token.isalpha() or
0.0,if token.isalpha() == False and
0.0,if token.isalnum() or

0,1
0.4,token
0.0,"token == "" "" or token"
0.0,token.strip()
0.0,len(token)
0.0,token.isspace()

0,1
0.53,or token ==
0.28,or
0.0,and
0.0,or token in
0.0,or token.strip() ==

0,1
0.55,True:
0.19,False:
0.0,True or token.isnumeric() == True:
0.0,True or token == '\n':
0.0,True:  # Don't generate suggestions for spaces

0,1
0.77,#
0.0,continue  elif i == 0:  #
0.0,if i == 0:  #
0.0,continue  elif token == FIM_MIDDLE:  #
0.0,continue  elif token == FIM_PREFIX:  #

0,1
0.0,don't
0.0,skip whitespace  continue  elif token.isalpha() == False:  #
0.0,skip whitespace  continue  elif token.isalpha() == True:  #
0.0,skip whitespace  continue  elif token.isalpha() or token.isdigit() or token ==
0.0,skip empty tokens  continue  elif token.isalpha() == False:  #

0,1
0.24,add
0.09,include
0.06,suggest
0.03,show
0.03,generate

0,1
0.15,a prompt
0.07,prompts
0.01,any prompts
0.0,prompt
0.0,empty prompts

0,1
0.37,for
0.04,with
0.02,for a
0.0,for empty
0.0,for the

0,1
0.0,spaces.
0.0,spaces  continue  elif token.isupper() == False:  # Don't generate suggestions
0.0,spaces  continue  elif token.isalpha() == False:  # Don't generate suggestions
0.0,spaces  continue  elif token.isupper() == True:  # Don't generate suggestions
0.0,spaces.  continue  elif token.isupper() == True:  # Don't generate

0,1
0.3,continue
0.0,continue
0.0,pass
0.0,continue  if i == len(tokens) - 1:  prefix = ''.join(
0.0,continue  if i == 0:  prefix = ''.join(tokens[:i])

0,1
0.02,remove token
0.0,elif token.startswith(FIM_PREFIX) and token.endswith(FIM_SUFFIX
0.0,elif token == FIM_PREFIX or token == FIM_SUFFIX or token == FIM
0.0,elif token == FIM_PREFIX or token == FIM_SUFFIX:  # Don't
0.0,elif token.startswith(FIM_PREFIX):  # Don't generate suggestions for FIM

0,1
0.8,prefix
0.0,# Generate suggestions  prefix
0.0,if i == 0:  prefix
0.0,if i == 0:  prefix = token  else:  prefix
0.0,if i == 0:  prefix = ''  else:  prefix

0,1
0.98,=
0.0,= '
0.0,=
0.0,= FIM_PREFIX +
0.0,= ''  if i > 0:  prefix =

0,1
0.34,''.join(tokens[:i])
0.0,prefix + token
0.0,''.join(tokens[0:i])
0.0,''.join(tokens[: i])
0.0,''.join(tokens[:i]) +''

0,1
0.75,suffix
0.0,suffix = ''.join(tokens[i + 1 :])  middle
0.0,suffix = ''.join(tokens[i:])  middle
0.0,suffix = ''.join(tokens[i + 1 :])  middle = ''.join(tokens
0.0,suffix = ''.join(tokens[i + 1 :])  prompt = FIM_PREFIX

0,1
0.99,=
0.0,=
0.0,= '
0.0,= '...'  middle =
0.0,= ''.join(tokens[i + 1 :])  middle =

0,1
0.54,''.join(tokens[i
0.03,' '.join(tokens[i
0.0,''.join(tokens[i +
0.0,''.join(tokens[i
0.0,""""".join(tokens[i"

0,1
0.85,+
0.0,:
0.0,: len(tokens) -
0.0,+ 1 :])  middle = ''.join(tokens[i +
0.0,+ 1:])  middle = ''.join(tokens[i +

0,1
0.39,1
0.38,1:
0.0,remove token
0.0,1 :
0.0,1:len(tokens)

0,1
0.4,:])
0.0,:])
0.0,: len(tokens)])
0.0,:len(tokens)])
0.0,: len(tokens)])

0,1
0.67,prompt
0.0,prompt
0.0,prompt = FIM_PREFIX + prefix + FIM_MIDDLE + token + FIM_
0.0,if len(prefix) > 0 and len(suffix) > 0:  prompt
0.0,prompt = FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_

0,1
0.73,=
0.02,=
0.0,= F
0.0,= prefix +
0.0,= PREFIX +

0,1
0.33,prefix
0.02,PREFIX
0.0,prefix + PREFIX_MIDDLE
0.0,prefix + FIM_MIDDLE
0.0,prefix + FIM_PREFIX

0,1
0.89,+
0.03,+
0.01,+ FIM_MIDDLE +
0.0,+ prefix + FIM_MIDDLE +
0.0,+ prefix + FIM_MIDDLE + FIM_SUGGESTION +

0,1
0.85,prefix
0.0,prefix +''
0.0,prefix
0.0,prefix + token
0.0,"prefix + "" """

0,1
0.78,+
0.02,+ F
0.0,+
0.0,+ FIM_MIDDLE
0.0,+ FIM_MIDDLE +

0,1
0.25,FIM_MIDDLE
0.01,""" """
0.0,FIM_SUFFIX
0.0,SUGGESTION
0.0,FIM_MIDDLE

0,1
0.32,+
0.0,+ prefix +
0.0,+ token + FIM_SUFFIX +
0.0,+ FIM_MIDDLE +
0.0,+ FIM_MIDDLE + token + FIM_MIDDLE +

0,1
0.18,suffix
0.0,prefix + FIM_MIDDLE
0.0,suffix  prompt = prompt
0.0,prefix + FIM_MIDDLE + suffix
0.0,prefix + FIM_MIDDLE + suffix + FIM_SUFFIX

0,1
0.74,+
0.0,+
0.0,+ F
0.0,+ FIM_MIDDLE  prompts.append(prompt)  suggestion_index_to_
0.0,+ FIM_MIDDLE  prompts.append(prompt)  suggestion_index_to_

0,1
0.0,""" """
0.0,' '
0.0,'\n'
0.0,FIM_PROMPT
0.0,FIM_PROMPT

0,1
0.11,prompts.append(prompt)
0.0,suggestions.append(prompt)
0.0,prompts.append(prompt)
0.0,suggestions.append(prompt)
0.0,if prompt not in prompts:  prompts.append(prompt)

0,1
0.0,suggestion_index_to_token_index.append(i + 1)
0.0,suggestion_index_to_token_index.append(i) print(prompts
0.0,suggestion_index_to_token_index.append(i) with open(
0.0,suggestion_index_to_token_index.append(i) with open('
0.0,"suggestion_index_to_token_index.append(i) with open("""
