In [1]:
# Sample Python snippet used as the default input for the scoring demos below.
code_example = """# Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()"""
in_text = code_example

In [2]:
import functools
from typing import List, Optional, Tuple

import torch
import transformers
from torch import Tensor


@functools.lru_cache(maxsize=None)
def _load_model_and_tokenizer(model_name: str):
    """Load the causal LM and its tokenizer once per ``model_name`` and cache them."""
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def get_src_tokens_and_logits(in_text: str, model_name: str, device: Optional[str] = None) -> Tuple[List[str], Tensor]:
    """Run ``in_text`` through a causal LM and return per-token prediction logits.

    A BOS token is prepended so that every token of ``in_text`` (including the
    first one) has a preceding position whose logits predict it.

    Args:
        in_text: Source text to score.
        model_name: Model id/path passed to ``from_pretrained``. Loaded models are
            cached, so repeated calls with the same name do not reload from disk.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(src_tokens, logits)``. ``src_tokens[i]`` is the decoded text of the
        i-th input token after BOS. ``logits[i]`` is the vocabulary logit vector
        the model produced when predicting that token. ``logits`` has shape
        ``(len(src_tokens), vocab_size)`` and lives on ``device``.

    Raises:
        ValueError: if the tokenizer defines no ``bos_token``.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)
    model, tokenizer = _load_model_and_tokenizer(model_name)
    # Weights must live on the same device as the input ids, otherwise the
    # forward pass fails whenever a GPU is selected.
    model = model.to(torch_device)
    print('Model: ', model.__class__.__name__)
    if tokenizer.bos_token is None:
        raise ValueError(f'Tokenizer for {model_name!r} has no bos_token; the first token cannot be scored.')
    print('Adding bos token...')
    in_text = f'{tokenizer.bos_token}{in_text}'
    print('Input:')
    print(in_text)
    print('------')
    # Take the single sequence as 1-D ids. Unlike .squeeze(), [0] keeps a
    # 1-token input (empty in_text) 1-D, so slicing below still works.
    input_ids: Tensor = tokenizer(in_text, return_tensors='pt')['input_ids'][0]
    print("Input token ids' shape: ", input_ids.shape)
    print('Device: ', device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(input_ids.unsqueeze(0).to(torch_device))
    # NOTE: The last logit vector predicts the next token, which is not part of
    # the input, so it is excluded. [0] drops the batch dimension added above.
    print('Discarding last logit vector...')
    logits: Tensor = outputs.logits[0, :-1]
    print('logits.shape: ', logits.shape)
    # Decode each non-BOS token on its own so tokens align 1:1 with logits rows.
    src_tokens: List[str] = tokenizer.batch_decode([[token_id] for token_id in input_ids[1:].tolist()])
    return src_tokens, logits

def get_scores(logits: Tensor) -> Tensor:
    """Score each position by the entropy of the categorical distribution over its logits.

    Args:
        logits: Unnormalised logits; the last dimension indexes the vocabulary.

    Returns:
        Entropy per position (natural-log units), with ``logits``' shape minus the
        last dimension. Higher means the model was less certain at that position.
    """
    distribution = torch.distributions.Categorical(logits=logits)
    entropies: Tensor = distribution.entropy()
    print('scores.shape: ', entropies.shape)
    return entropies

# def get_topk(logits: Tensor, topk: int = 5):
#     topk_prob_values, topk_prob_inds = torch.topk(logits, k=topk, dim=1)
#     return topk_prob_values, topk_prob_inds

# tokens, logits = get_src_tokens_and_logits(code_example, model_name='Salesforce/codegen-350M-mono')
# scores = get_scores(logits)

Model:  CodeGenForCausalLM
Adding bos token...
Input:
<|endoftext|># Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()
------
Input token ids' shape:  torch.Size([98])
Device:  cpu


  attn_weights = torch.where(causal_mask, attn_weights, mask_value)


Discarding last logit vector...
logits.shape:  torch.Size([97, 51200])
scores.shape:  torch.Size([97])


In [3]:
# logits -= logits.min(dim=-1, keepdim=True).values
# logits /= logits.max(dim=-1, keepdim=True).values
# logits

In [20]:
import gradio as gr
import matplotlib as mpl
import numpy as np
from typing import List


def get_colors(scores: Tensor) -> List[str]:
    """Map per-token scores to hex colours from matplotlib's ``YlOrBr`` colormap.

    Scores are min-max normalised to [0, 1] over the given input first. A
    colormap expects inputs in that range, and entropy scores can exceed 1;
    un-normalised, every such score would saturate to the same top colour.

    Args:
        scores: 1-D tensor (or array-like) of scores, one per token.

    Returns:
        A list of ``'#rrggbb'`` strings, one per score (lowest score -> low end
        of the colormap, highest -> high end). Empty input gives ``[]``.
    """
    if isinstance(scores, Tensor):
        # numpy cannot read CUDA tensors or tensors that require grad.
        scores = scores.detach().cpu()
    values = np.asarray(scores, dtype=float).reshape(-1)
    if values.size == 0:
        return []
    lowest, highest = values.min(), values.max()
    spread = highest - lowest
    # Constant input (e.g. a single token) has no spread: avoid dividing by zero.
    normed = (values - lowest) / spread if spread > 0 else np.zeros_like(values)
    cmap = mpl.colormaps['YlOrBr']
    return [mpl.colors.rgb2hex(rgba) for rgba in cmap(normed)]

def get_html(in_text: str, model_name: str = 'Salesforce/codegen-350M-mono') -> str:
    """Render ``in_text`` as HTML with each token shaded by the model's entropy score.

    Args:
        in_text: Source code to visualise.
        model_name: Causal LM used for scoring (defaults to the CodeGen 350M mono model).

    Returns:
        A ``<pre><code>`` HTML fragment. Each token is a ``<span>`` whose background
        colour encodes its score and whose ``title`` tooltip shows the token and score.
    """
    # Imported locally so a notebook-global variable named ``html`` cannot
    # shadow the stdlib module.
    from html import escape

    tokens, logits = get_src_tokens_and_logits(in_text, model_name=model_name)
    scores = get_scores(logits)
    colors = get_colors(scores)
    assert len(tokens) == len(colors), f'len(tokens)={len(tokens)} != len(colors)={len(colors)}'
    spans = []
    for token, score, color in zip(tokens, scores.tolist(), colors):
        # Code tokens routinely contain <, >, & and quotes. Escape them so they
        # neither create tags nor terminate the title="..." attribute.
        safe_token = escape(token)
        spans.append(
            f'<span style="background-color: {color}" title="token={safe_token}, score={score:.5f}">{safe_token}</span>'
        )
    body = ''.join(spans)
    return f'<pre><code class="python">{body}</code></pre>'

# Web UI: paste code into the textbox; get_html returns the token-shaded HTML,
# which is shown in the output component.
demo = gr.Interface(
    get_html,
    gr.Textbox(label='Code example', placeholder=code_example, value=code_example),
    gr.Markdown(),
)

if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__" too, so running this cell
    # starts the local server.
    demo.launch()

Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.


In [26]:
from IPython.display import display, HTML

# Render the sample snippet inline. The variable is deliberately not named
# `html`, so it does not shadow the stdlib `html` module in the kernel namespace.
snippet_html = get_html(code_example)
display(HTML(snippet_html))

Model:  CodeGenForCausalLM
Adding bos token...
Input:
<|endoftext|># Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()
------
Input token ids' shape:  torch.Size([98])
Device:  cpu
Discarding last logit vector...
logits.shape:  torch.Size([97, 51200])
scores.shape:  torch.Size([97])
