In [None]:
from pathlib import Path
import pickle

import datasets
import evaluate
import numpy as np
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments

In [None]:
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
model = AutoModelForSequenceClassification.from_pretrained("finetuned")

In [None]:
sep = "[SEP]"

def prepare_input(example):
    tokens = tokenizer(
        example["function_definition"] + sep + example["code"] + sep + example["comment"],
        truncation=True,
        max_length=1024,
        return_tensors="pt"
    )
    return tokens

In [None]:
example_fd = """def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]"""
example_code = """    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
    
    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor
    
    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))"""
example_com = "# an identity rearrangement on a 0-dimension tensor"
example = {
    "function_definition": example_fd,
    "code": example_code,
    "comment": example_com
}
print(example)

{'function_definition': 'def _create_rearrange_callable(\n    tensor_ndim: int, pattern: str, **axes_lengths: int\n) -> Callable[[torch.Tensor], torch.Tensor]', 'code': '    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims\n    \n    if n_dims == 0:\n        # an identity rearrangement on a 0-dimension tensor\n        return lambda tensor: tensor\n    \n    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))', 'comment': '# an identity rearrangement on a 0-dimension tensor'}


In [None]:
inp = prepare_input(example)
print(inp)

{'input_ids': tensor([[50281,  1545,   795,  6953,    64,   250,  3298,   912,    64,  4065,
           494,     9,   187, 50274, 26109,    64,  2109,   303,    27,   540,
            13,  3102,    27,  1213,    13,  1401, 44832,    64,  3985,    84,
            27,   540,   187,    10,  5204,  9368,   494, 14598, 13473,   348,
            15, 39596,  1092, 30162,    15, 39596,    62, 50282, 50274,    79,
            64,  4528,    84,   426,   295,    64, 19389,    64,  4528,    84,
           559,   295,    64,   437,  2824,   261,    64,  4528,    84,   559,
           295,    64, 46339,    64,  4528,    84,   187, 50274,   187, 50274,
           338,   295,    64,  4528,    84,  2295,   470,    27,   187, 50270,
             4,   271,  6489, 47410,   327,   247,   470,    14, 39120, 13148,
           187, 50270,  2309, 29331, 13148,    27, 13148,   187, 50274,   187,
         50274,  7053,    64,  2437,    64,  4528,    84,    27,   308, 13932,
            60,  1344,    13,  3346,  

In [None]:
with torch.no_grad():
    out = model(**inp)
out

SequenceClassifierOutput(loss=None, logits=tensor([[-3.5259,  2.5648]]), hidden_states=None, attentions=None)

In [None]:
out.logits.argmax()

tensor(1)

In [None]:
def predict(inp, model=model):
    with torch.no_grad():
        out = model(**inp)
    return out.logits

In [None]:
contents = Path("example.py").open("r").readlines()
len(contents)

57

In [None]:
inputs = []
last_def = -1
for i, l in enumerate(contents):
    if l.find("def ") != -1:
        last_def = i
    if l.find("#") != -1:
        comment = l
        fdef = "None"
        if last_def != -1:
            fdef = contents[last_def]
        code = ''.join(contents[i-4:i+3])
        inputs.append({
            "function_definition": fdef,
            "code": code,
            "comment": comment
        })
print(len(inputs))
for idx, i in enumerate(inputs):
    print(idx, "---------------")
    for k, v in i.items():
        print(k)
        print(v)
        

4
0 ---------------
function_definition
        def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:

code
    if fn is not None:

        @functools.wraps(fn)
        def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # cache this on the first invocation to avoid adding too much overhead.
            disable_fn = getattr(fn, "__dynamo_disable", None)
            if disable_fn is None:

comment
            # cache this on the first invocation to avoid adding too much overhead.

1 ---------------
function_definition
        def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:

code
            if disable_fn is None:
                import torch._dynamo

                disable_fn = torch._dynamo.disable(fn, recursive)
                fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]

            return disable_fn(*args, **kwargs)

comment
                fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]

2 ---------------
function_definition
  

In [None]:
prep_inp = list(map(prepare_input, inputs))

In [None]:
for i in prep_inp:
    print(predict(i))

tensor([[-3.9032,  1.5109]])
tensor([[-4.5503,  2.1453]])
tensor([[-1.2684,  1.1635]])
tensor([[-2.0355,  0.6928]])


In [None]:
logits = predict(i)
probs = nn.functional.softmax(logits, dim=-1)
probs

tensor([[0.0613, 0.9387]])

In [None]:
probs[0, 1]

tensor(0.9387)

In [None]:
def predict(inp, model=model):
    with torch.no_grad():
        out = model(**inp)
    return nn.functional.softmax(out.logits, dim=-1)[0, 1].item()

In [None]:
predict(i)

0.9386754035949707

In [None]:
import ast

In [None]:
tree = ast.parse(Path("example.py").open("r").read())

In [None]:
tree.body

[<ast.Expr at 0x7f67e02dbfa0>,
 <ast.Import at 0x7f67e02dbf40>,
 <ast.ImportFrom at 0x7f67e02dbee0>,
 <ast.ImportFrom at 0x7f67e02dbd60>,
 <ast.Assign at 0x7f67e02dbd00>,
 <ast.Assign at 0x7f67e02dbc10>,
 <ast.FunctionDef at 0x7f67e02dbb20>,
 <ast.FunctionDef at 0x7f67e02db790>,
 <ast.FunctionDef at 0x7f67e02da7d0>]

In [None]:
tree.body[-3].lineno, tree.body[-3].end_lineno, tree.body[-3].col_offset, tree.body[-3].end_col_offset

(16, 18, 0, 26)

In [None]:
''.join(contents[16:18])

'    fn: Callable[_P, _T], recursive: bool = True\n) -> Callable[_P, _T]: ...\n'

In [None]:
''.join(contents[15:18]).strip()

'def _disable_dynamo(\n    fn: Callable[_P, _T], recursive: bool = True\n) -> Callable[_P, _T]: ...'

In [None]:
tree2 = ast.parse("def hello(a):\n    a += a\n    return a")
tree2.body

[<ast.FunctionDef at 0x7f67e02dae90>]

In [None]:
tree2.body[0].name, tree2.body[0].lineno, tree.body[0].end_lineno

('hello', 1, 4)

In [None]:
for i in tree.body:
    if isinstance(i, ast.FunctionDef):
        print(i)
        lineno = i.lineno - 1
        while lineno <= i.end_lineno - 1:
            if contents[lineno].find("):") != -1 or contents[lineno].find("->") != -1:
                print(''.join(contents[i.lineno - 1:lineno+1]))
                break
            lineno += 1

<ast.FunctionDef object at 0x7f67e02afb20>
def _disable_dynamo(
    fn: Callable[_P, _T], recursive: bool = True
) -> Callable[_P, _T]: ...

<ast.FunctionDef object at 0x7f67e02af790>
def _disable_dynamo(
    fn: Literal[None] = None, recursive: bool = True
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...

<ast.FunctionDef object at 0x7f67e02af280>
def _disable_dynamo(
    fn: Optional[Callable[_P, _T]] = None, recursive: bool = True
) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]:



In [None]:
tree.body[-1].lineno, tree.body[-1].end_lineno, tree.body[-1].col_offset, tree.body[-1].end_col_offset

(27, 57, 0, 70)

In [None]:
def parse_text(text):
    # NOTE: Doesn't collect comments and function definitions correctly
    inputs = []
    defs = []
    tree = ast.parse(text)
    for el in tree.body:
        if isinstance(el, ast.FunctionDef):
            defs.append((el.lineno - 1, el.end_lineno - 1, el.col_offset))

    inputs = []
    lines = text.split('\n')
    for lineno, line in enumerate(lines):
        if (offset := line.find('#')) != -1:
            corresponding_def = None
            for (def_l, def_el, def_off) in defs:
                if def_l <= lineno and def_off <= offset:
                    corresponding_def = (def_l, def_el, def_off)

            comment = line[offset:]
            code = '\n'.join(lines[lineno - 4:lineno + 4])
            fdef = "None"
            if corresponding_def is not None:
                fdef = [lines[corresponding_def[0]][corresponding_def[2]:]]
                cur_lineno = corresponding_def[0]
                while cur_lineno <= corresponding_def[1]:
                    if lines[cur_lineno].find("):") != -1 or lines[cur_lineno].find("->") != -1:
                        fdef += lines[corresponding_def[0] + 1:cur_lineno + 1]
                        break
                    cur_lineno += 1
                
                fdef = '\n'.join(fdef).strip()
                    
            inputs.append({
                "function_definition": fdef,
                "code": code,
                "comment": comment,
                "lineno": lineno
            })
    return inputs

In [None]:
inputs = parse_text(Path("example.py").open("r").read())

print(len(inputs))
for idx, i in enumerate(inputs):
    print(idx, "---------------")
    for k, v in i.items():
        print(k)
        print(v)

4
0 ---------------
function_definition
def _disable_dynamo(
    fn: Optional[Callable[_P, _T]] = None, recursive: bool = True

) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]:
code
    if fn is not None:

        @functools.wraps(fn)
        def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # cache this on the first invocation to avoid adding too much overhead.
            disable_fn = getattr(fn, "__dynamo_disable", None)
            if disable_fn is None:
                import torch._dynamo
comment
# cache this on the first invocation to avoid adding too much overhead.
lineno
42
1 ---------------
function_definition
def _disable_dynamo(
    fn: Optional[Callable[_P, _T]] = None, recursive: bool = True

) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]:
code
            if disable_fn is None:
                import torch._dynamo

                disable_fn = torch._dynamo.disable(fn, recursive)
                f

In [None]:
for inp in inputs:
    print(predict(prepare_input(inp)))

0.9996691942214966
0.9998605251312256
0.9934530258178711
0.9993849992752075


In [None]:
def parse_and_predict(text, thrd=None):
    parsed = parse_text(text)
    preds = [predict(prepare_input(p)) for p in parsed]
    result = []
    for i, p in enumerate(preds):
        if thrd:
            p = thrd > p
        result.append((parsed[i]["lineno"], p))
        
    return result


def parse_and_predict_file(path, thrd=None):
    text = Path(path).open("r").read()
    return parse_and_predict(text, thrd)

In [None]:
parse_and_predict_file("example.py")

[(42, 0.9996691942214966),
 (48, 0.9998605251312256),
 (54, 0.9934530258178711),
 (55, 0.9993849992752075)]

In [None]:
parse_and_predict_file("example.py", 0.99)

[(42, False), (48, False), (54, False), (55, False)]