In [274]:
in_text = code_example = \
"""# Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()"""

In [None]:
import transformers
import torch
from torch import Tensor
from typing import Tuple, List


def get_tokens_and_logits(in_text: str, model_name: str, device: str = None) -> Tuple[List[str], Tensor]:
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    print('Model: ', model.__class__.__name__)
    print('Adding bos token...')
    in_text = f'{tokenizer.bos_token}{in_text}'
    print('Input:')
    print(in_text)
    print('------')
    inputs: Tensor = tokenizer(in_text, return_tensors='pt').data['input_ids'].squeeze()
    print("Input token ids' shape: ", inputs.shape)
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print('Device: ', device)
    outputs: Tensor = model(inputs.to(torch.device(device)))
    # NOTE: The last logit vector is used for predicting the next token, which is not part of the input. Therefore, we exclude it.
    print('Discarding last logit vector...')
    logits: Tensor = outputs.logits.data[:-1]
    print('logits.shape: ', logits.shape)
    tokens: List[str] = tokenizer.decode(inputs[1:])
    return tokens, logits

def get_scores(logits: Tensor) -> Tensor:
    ret: Tensor = torch.distributions.Categorical(logits=logits).entropy()
    print('scores.shape: ', ret.shape)
    return ret

# def get_topk(logits: Tensor, topk: int = 5):
#     topk_prob_values, topk_prob_inds = torch.topk(logits, k=topk, dim=1)
#     return topk_prob_values, topk_prob_inds

tokens, logits = get_tokens_and_logits(code_example, model_name='Salesforce/codegen-350M-mono')
scores = get_scores(logits)

Model:  CodeGenForCausalLM
Adding bos token...
Input:
<|endoftext|># Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()
------
Input token ids' shape:  torch.Size([98])
Device:  cpu
Discarding last logit vector...
logits.shape:  torch.Size([97, 51200])


In [174]:
# logits -= logits.min(dim=-1, keepdim=True).values
# logits /= logits.max(dim=-1, keepdim=True).values
# logits

tensor([[0.8421, 0.8407, 0.8675,  ..., 0.0477, 0.0475, 0.0475],
        [0.6350, 0.5575, 0.5599,  ..., 0.0682, 0.0678, 0.0679],
        [0.8095, 0.7691, 0.7231,  ..., 0.0457, 0.0455, 0.0456],
        ...,
        [0.4795, 0.5214, 0.6179,  ..., 0.0676, 0.0675, 0.0676],
        [0.6694, 0.6995, 0.7520,  ..., 0.0511, 0.0512, 0.0512],
        [0.7615, 0.7607, 0.8356,  ..., 0.0477, 0.0478, 0.0478]])

In [340]:
import gradio as gr
import matplotlib as mpl
import numpy as np
from typing import List


def get_colors(scores: Tensor) -> List[str]:
    cmap = mpl.colormaps['YlOrBr']
    rgbas: np.ndarray = cmap(scores)
    return np.apply_along_axis(mpl.colors.rgb2hex, -1, rgbas)

def get_markdown_heatmap(in_text: str):
        tokens, logits = get_tokens_and_logits(in_text, model_name='Salesforce/codegen-350M-mono')
        scores = get_scores(logits)
        colors = get_colors(scores)
        assert len(tokens) == len(colors)
        ret = ''.join([f'<span style="background-color: {c}">{t}</span>' for t, c in zip(tokens, colors)])
        return f'```{ret}'

# with gr.Blocks() as demo:
#     gr.Markdown(
#     """
#     # Demo
#     """)
#     inp = gr.Textbox(label='Code example', placeholder=code_example, value=code_example)
#     tokens, logits = get_tokens_and_logits(inp, model_name='Salesforce/codegen-350M-mono')
#     scores = get_scores(logits)
#     colors = get_colors(scores)
#     out = gr.Markdown(label="Heatmap")
#     btn = gr.Button(value="Run")
#     btn.click(get_markdown_heatmap, inputs=[tokens, colors], outputs=out)
demo = gr.Interface(
    fn=get_markdown_heatmap,
    inputs=gr.Textbox(label='Code example', placeholder=code_example, value=code_example),
    outputs=gr.Markdown()
)


if __name__ == "__main__":
    demo.launch()

Running on local URL:  http://127.0.0.1:7888

To create a public link, set `share=True` in `launch()`.


Model:  CodeGenForCausalLM
Adding bos token...
Input:
<|endoftext|># Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()
------
Input token ids' shape:  torch.Size([98])
Device:  cpu
Discarding last logit vector...
logits.shape:  torch.Size([97, 51200])
scores.shape:  torch.Size([97])


Traceback (most recent call last):
  File "/Users/nadavt/opt/anaconda3/envs/detecting-fake-text/lib/python3.8/site-packages/gradio/routes.py", line 384, in run_predict
    output = await app.get_blocks().process_api(
  File "/Users/nadavt/opt/anaconda3/envs/detecting-fake-text/lib/python3.8/site-packages/gradio/blocks.py", line 1032, in process_api
    result = await self.call_function(
  File "/Users/nadavt/opt/anaconda3/envs/detecting-fake-text/lib/python3.8/site-packages/gradio/blocks.py", line 844, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/Users/nadavt/opt/anaconda3/envs/detecting-fake-text/lib/python3.8/site-packages/anyio/to_thread.py", line 28, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(func, *args, cancellable=cancellable,
  File "/Users/nadavt/opt/anaconda3/envs/detecting-fake-text/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 818, in run_sync_in_worker_thread
    return await future
  File "/Users/na