In [1]:
# Sample Python snippet used as the default input: it pre-fills the Gradio text box
# and is rendered directly in a later cell. Bound to both `in_text` and `code_example`.
in_text = code_example = \
"""# Handshake successful, kill previous client if there is any.
with current_client_pid.get_lock():
    old_pid = current_client_pid.value
    if old_pid != 0:
        print(f"Booting previous client (pid={old_pid})")
        os.kill(old_pid, signal.SIGKILL)
        current_client_pid.value = os.getpid()"""

In [2]:
import transformers
import torch
from torch import Tensor
from typing import Tuple, List
import logging


def get_src_tokens_and_logits(in_text: str, model_name: str, device: str = None, verbose=False) -> Tuple[List[str], Tensor]:
    """
    Run a pretrained causal language model over `in_text` and pair every input
    token with the logit vector that was used to predict it.

    A BOS token is prepended to the text so that the first real token also has a
    preceding position whose logits predict it. The logit vector at the last
    position (which predicts a token that is not part of the input) is dropped,
    so the returned tokens and logits have the same length.

    Args:
        in_text: Source text to analyse.
        model_name: Name or path passed to `from_pretrained` for both the model
            and the tokenizer.
        device: Torch device string. When falsy, "cuda" is used if available,
            otherwise "cpu".
        verbose: When true, request DEBUG-level logging instead of INFO.

    Returns:
        A tuple `(src_tokens, logits)`: the decoded text of each input token
        (BOS excluded) and the corresponding logits, moved to the CPU.

    Raises:
        ValueError: If the tokenizer defines no BOS token.
    """
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    logging.info(f'Model: {model.__class__.__name__}')
    if tokenizer.bos_token is None:
        # Without this guard the f-string below would prepend the literal text "None".
        raise ValueError(f'Tokenizer for {model_name!r} does not define a BOS token')
    logging.debug('Adding bos token...')
    in_text = f'{tokenizer.bos_token}{in_text}'
    logging.debug('Input:')
    logging.debug(in_text)
    logging.debug('------')
    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # collapse a one-token sequence into a 0-d tensor.
    inputs: Tensor = tokenizer(in_text, return_tensors='pt').data['input_ids'].squeeze(0)
    logging.debug(f"Input token ids' shape: {inputs.shape}")
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f'Device: {device}')
    torch_device = torch.device(device)
    # The model must live on the same device as its inputs.
    model.to(torch_device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(inputs.to(torch_device))
    # NOTE: The last logit vector is used for predicting the next token, which is not part of the input. Therefore, we exclude it.
    logging.debug('Discarding last logit vector...')
    # Return CPU logits so downstream NumPy/matplotlib code can consume them.
    logits: Tensor = outputs.logits[:-1].cpu()
    logging.debug(f'logits.shape: {logits.shape}')
    # tokens: List[str] = tokenizer.convert_ids_to_tokens(inputs[1:])
    # Decode each id on its own so every entry is the text of exactly one token.
    src_tokens: List[str] = tokenizer.batch_decode([[token_id] for token_id in inputs[1:].tolist()])
    return src_tokens, logits

def get_scores(logits: Tensor, verbose=False) -> Tensor:
    """
    Compute the entropy of the categorical distribution defined by `logits`.

    The last dimension of `logits` is treated as the category axis, so the
    result has the shape of `logits` without that last dimension.

    Args:
        logits: Unnormalised log-probabilities.
        verbose: When true, request DEBUG-level logging instead of INFO.

    Returns:
        Tensor of entropies, one per distribution.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(level=log_level)
    distribution = torch.distributions.Categorical(logits=logits)
    entropies: Tensor = distribution.entropy()
    logging.debug(f'scores.shape: {entropies.shape}')
    return entropies

# def get_topk(logits: Tensor, topk: int = 5):
#     topk_prob_values, topk_prob_inds = torch.topk(logits, k=topk, dim=1)
#     return topk_prob_values, topk_prob_inds

# tokens, logits = get_src_tokens_and_logits(code_example, model_name='Salesforce/codegen-350M-mono')
# scores = get_scores(logits)

In [3]:
# logits -= logits.min(dim=-1, keepdim=True).values
# logits /= logits.max(dim=-1, keepdim=True).values
# logits

In [4]:
import gradio as gr
import matplotlib as mpl
import numpy as np
from typing import List


def get_colors(scores: Tensor) -> List[str]:
    """
    Map each score to a hex colour string on matplotlib's 'YlOrBr' colormap.

    Scores are normalised between the minimum and maximum of the values passed
    in, so the colours are relative to this call's input only.

    Args:
        scores: Scores to colour; a torch tensor or anything `np.asarray` accepts.

    Returns:
        Hex colour strings, one per score. Empty input yields an empty list.
    """
    # Convert explicitly: implicit NumPy conversion fails for CUDA or
    # grad-tracking tensors.
    values: np.ndarray = scores.detach().cpu().numpy() if isinstance(scores, Tensor) else np.asarray(scores)
    if values.size == 0:
        # min()/max() are undefined on empty input; there is nothing to colour.
        return []
    cmap = mpl.colormaps['YlOrBr']
    norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
    rgbas: np.ndarray = cmap(norm(values))
    # tolist() turns the NumPy string array into plain Python strings,
    # matching the declared return type.
    return np.apply_along_axis(mpl.colors.rgb2hex, -1, rgbas).tolist()

def get_html(in_text: str):
    """
    Render `in_text` as HTML in which every token is highlighted by its score.

    The text is tokenised and scored with the 'Salesforce/codegen-350M-mono'
    model; each token becomes a <span> whose background colour encodes its
    score and whose tooltip shows the token and the score.

    Args:
        in_text: Source text to visualise.

    Returns:
        An HTML string: a <pre><code> element containing one <span> per token.
    """
    # Imported locally on purpose: this notebook rebinds the global name `html`
    # to a string in a later cell, so a module-level `import html` would be shadowed.
    from html import escape

    tokens, logits = get_src_tokens_and_logits(in_text, model_name='Salesforce/codegen-350M-mono')
    scores = get_scores(logits)
    colors = get_colors(scores)
    assert len(tokens) == len(colors), f'len(tokens)={len(tokens)} != len(colors)={len(colors)}'
    spans = []
    for t, s, c in zip(tokens, scores, colors):
        # Token text comes from arbitrary user input: escape it so characters
        # such as <, >, & and quotes cannot break the markup or the attribute.
        safe = escape(t, quote=True)
        spans.append(f'<span style="background-color: {c}" title="token={safe}, score={s:.5f}">{safe}</span>')
    ret = ''.join(spans)
    return f'<pre><code class="python">{ret}</code></pre>'

# Gradio UI: a single text box whose content is passed to get_html, with the
# returned markup shown as the output.
demo = gr.Interface(
    fn=get_html,
    # The sample snippet serves as both the placeholder and the initial value.
    inputs=gr.Textbox(label='Code example', placeholder=code_example, value=code_example),
    # NOTE(review): get_html returns raw HTML; confirm gr.Markdown renders the
    # inline styles as intended (gr.HTML may be a closer fit).
    outputs=gr.Markdown()
)

if __name__ == "__main__":
    # Presumably closes Gradio servers left over from earlier runs of this cell
    # so the fixed port below is free again - TODO confirm.
    gr.close_all()
    demo.launch(server_port=7860)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


In [5]:
from IPython.display import display, HTML
# Render the highlighted sample snippet inline in the notebook, without going
# through the Gradio UI. Note that this rebinds the global name `html` to a string.
html = get_html(code_example)
display(HTML(html))

INFO:root:Model: CodeGenForCausalLM
INFO:root:Device: cpu
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)


**Reentrancy** example:
```solidity
function Collect(uint _am)
public
payable
{
    if(balances[msg.sender]>=MinSum && balances[msg.sender]>=_am)
    {
        if(msg.sender.call.value(_am)())
        {
            balances[msg.sender]-=_am;
            Log.AddMessage(msg.sender,_am,"Collect");
        }
    }
}
```

In [6]:
import re

# Regex for annotation comments of the form `// <yes> <report> LABEL`, with
# optional leading whitespace; LABEL consists of upper-case letters and underscores.
# Presumably the SmartBugs dataset's vulnerability markers - verify against the data source.
PATTERN_SMARTBUGS = r"\s*\/\/ <yes> <report> [A-Z_]+"

# Sample smart-contract snippet containing two such annotation lines, used to
# exercise the annotation-stripping helper below.
smart_contract = """    function withdraw() {
        withdrawalCounter += 1;
        // calculate the fibonacci number for the current withdrawal user
        // this sets calculatedFibNumber
        // <yes> <report> ACCESS_CONTROL
        require(fibonacciLibrary.delegatecall(fibSig, withdrawalCounter));
        msg.sender.transfer(calculatedFibNumber * 1 ether);
    }

    // allow users to call fibonacci library functions
    function() public {
        // <yes> <report> ACCESS_CONTROL
        require(fibonacciLibrary.delegatecall(msg.data));
    }"""

In [8]:
def get_raw_code_and_tgt_line_numbers(annotation_pattern: str, annotated_code: str, verbose=False) -> Tuple[List[int], str]:
    """
    Remove lines that match the regular expression and extract their line numbers.

    Args:
        annotation_pattern: Regular expression applied with `re.match`, i.e.
            anchored at the start of each line.
        annotated_code: Code containing annotation lines, split on '\\n'.
        verbose: When true, request DEBUG-level logging instead of INFO.

    Returns:
        A tuple `(line_numbers, raw_code)`. `line_numbers` holds the 1-based
        positions of the matching lines within `annotated_code` (not within
        the returned code). `raw_code` is `annotated_code` with those lines
        removed.
    """
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    line_numbers: List[int] = []
    new_lines: List[str] = []
    # Iterate over the argument, not a module-level variable, so the function
    # works for any input.
    for line_number, line in enumerate(annotated_code.split('\n'), start=1):
        if re.match(annotation_pattern, line):
            logging.debug(f"Match found in line {line_number}: {line}")
            line_numbers.append(line_number)
        else:
            new_lines.append(line)
    raw_code: str = '\n'.join(new_lines)
    logging.debug(f"Raw code (without annotation): {raw_code}")
    return line_numbers, raw_code

# Try the helper on the sample contract: print the line numbers of the
# annotation lines, then the code with those lines removed.
line_numbers, raw_code = get_raw_code_and_tgt_line_numbers(PATTERN_SMARTBUGS, smart_contract)
print(line_numbers)
print(raw_code)

[5, 12]
    function withdraw() {
        withdrawalCounter += 1;
        // calculate the fibonacci number for the current withdrawal user
        // this sets calculatedFibNumber
        require(fibonacciLibrary.delegatecall(fibSig, withdrawalCounter));
        msg.sender.transfer(calculatedFibNumber * 1 ether);
    }

    // allow users to call fibonacci library functions
    function() public {
        require(fibonacciLibrary.delegatecall(msg.data));
    }
