In [None]:
import ast
import io
import pickle
import tokenize
from pathlib import Path

import datasets
import evaluate
import numpy as np
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments

In [None]:
# Checkpoint identifier used to load the tokenizer that encodes the model inputs.
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Load the sequence-classification model from the "finetuned_bigger" checkpoint.
# NOTE(review): the tokenizer is loaded from `model_id`, not from this
# checkpoint — confirm the two share the same vocabulary.
model = AutoModelForSequenceClassification.from_pretrained("finetuned_bigger")

In [None]:
sep = "[SEP]"

def prepare_input(example):
    """Tokenize a single parsed comment record.

    The record's function definition, surrounding code and comment text are
    concatenated in that order, separated by the module-level ``sep`` marker,
    then tokenized with truncation at 1024 tokens.

    Args:
        example: Mapping with string values under the keys
            "function_definition", "code" and "comment".

    Returns:
        Whatever the tokenizer returns for one text with
        ``return_tensors="pt"``.
    """
    pieces = (
        example["function_definition"],
        example["code"],
        example["comment"],
    )
    joined = sep.join(pieces)
    return tokenizer(joined, truncation=True, max_length=1024, return_tensors="pt")

In [None]:
def parse_text(text):
    """Collect every comment in Python source together with its context.

    Comments are located with the ``tokenize`` module, so a ``#`` inside a
    string literal is not mistaken for a comment. Each comment is paired with
    the innermost function (plain, async, method or nested) whose line range
    contains it.

    NOTE: the signature text is still extracted heuristically: it runs from the
    ``def`` line up to the first line of the function containing "):" or "->".
    If no such line exists only the ``def`` line is used.

    Args:
        text: Source code of a syntactically valid Python module.

    Returns:
        A list with one dict per comment, in source order, with the keys:
            "function_definition": signature text of the enclosing function,
                or the string "None" when the comment is outside any function.
            "code": up to four lines before the comment line, the comment line
                itself and up to three lines after it, joined with newlines.
            "comment": the comment text, starting at ``#``.
            "lineno": zero-based index of the line holding the comment.

    Raises:
        SyntaxError: if ``text`` cannot be parsed by ``ast.parse``.
    """
    lines = text.split('\n')

    # Zero-based (first line, last line, column) of every function definition.
    defs = [
        (node.lineno - 1, node.end_lineno - 1, node.col_offset)
        for node in ast.walk(ast.parse(text))
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]

    inputs = []
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type != tokenize.COMMENT:
            continue
        lineno = tok.start[0] - 1

        # Function ranges are either nested or disjoint, so among the ranges
        # containing this line the one starting last is the innermost.
        enclosing = [d for d in defs if d[0] <= lineno <= d[1]]

        fdef = "None"
        if enclosing:
            def_l, def_el, def_off = max(enclosing, key=lambda d: d[0])
            fdef_lines = [lines[def_l][def_off:]]
            for cur_lineno in range(def_l, def_el + 1):
                if "):" in lines[cur_lineno] or "->" in lines[cur_lineno]:
                    fdef_lines += lines[def_l + 1:cur_lineno + 1]
                    break
            fdef = '\n'.join(fdef_lines).strip()

        # Clamp the start: a negative index would wrap around to the end of
        # the file and yield the wrong (often empty) context window.
        code = '\n'.join(lines[max(0, lineno - 4):lineno + 4])

        inputs.append({
            "function_definition": fdef,
            "code": code,
            "comment": tok.string,
            "lineno": lineno
        })
    return inputs

In [None]:
# Parse the example file and tokenize each comment record on its own (no
# padding), then show the token-tensor shapes of the first two records.
# read_text() opens and closes the file itself, so no handle is left open.
text = Path("example.py").read_text()
parsed = parse_text(text)
prepared = list(map(prepare_input, parsed))
prepared[0]["input_ids"].shape, prepared[1]["input_ids"].shape

(torch.Size([1, 179]), torch.Size([1, 146]))

In [None]:
sep = "[SEP]"

def prepare_input_batch(examples):
    """Tokenize a list of parsed comment records as one padded batch.

    For every record the function definition, surrounding code and comment
    text are concatenated in that order, separated by the module-level ``sep``
    marker. The resulting texts are tokenized together with padding enabled
    and truncation at 1024 tokens.

    Args:
        examples: Iterable of mappings with string values under the keys
            "function_definition", "code" and "comment".

    Returns:
        Whatever the tokenizer returns for the list of texts with
        ``return_tensors="pt"``.
    """
    texts = []
    for record in examples:
        fields = (record["function_definition"], record["code"], record["comment"])
        texts.append(sep.join(fields))
    return tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=1024,
        return_tensors="pt",
    )

In [None]:
# Parse the example file and tokenize all comment records as one padded batch.
# read_text() opens and closes the file itself, so no handle is left open.
text = Path("example.py").read_text()
parsed = parse_text(text)
prepared = prepare_input_batch(parsed)
prepared["input_ids"].shape

torch.Size([4, 179])

In [None]:
def predict(inp, model=model):
    """Run the classifier on a tokenized batch and return class-1 probabilities.

    Args:
        inp: Mapping of keyword arguments that is unpacked into ``model(...)``.
        model: Model to call; defaults to the notebook-level ``model`` as it
            was bound when this function was defined.

    Returns:
        A Python list holding, for each row of the logits, the softmax
        probability at class index 1.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**inp).logits
    class_probs = nn.functional.softmax(logits, dim=-1)
    positive = class_probs[:, 1]
    return positive.tolist()

In [None]:
# Score the padded batch held in `prepared`: one class-1 probability per row.
predict(prepared)

[0.9620012044906616,
 0.8701790571212769,
 0.9321628212928772,
 0.8722432255744934]

In [None]:
def parse_and_predict(text, thrd=None):
    """Score every comment found in a piece of Python source.

    Args:
        text: Python source code as a string.
        thrd: Optional threshold. When given (including ``0``), each score is
            replaced by the boolean ``thrd > score``. When ``None`` the raw
            scores returned by ``predict`` are kept.

    Returns:
        A list of ``(lineno, value)`` tuples in source order, where ``lineno``
        is the zero-based line index reported by ``parse_text`` and ``value``
        is either the score or the thresholded boolean. An empty list is
        returned when the source contains no comments.
    """
    parsed = parse_text(text)
    if not parsed:
        # Nothing to score; do not hand an empty batch to the tokenizer/model.
        return []

    preds = predict(prepare_input_batch(parsed))

    # Compare against None so that a threshold of 0 / 0.0 is still applied
    # instead of being treated as "no threshold".
    if thrd is not None:
        preds = [thrd > p for p in preds]

    return [(record["lineno"], p) for record, p in zip(parsed, preds)]


def parse_and_predict_file(path, thrd=None):
    """Read a Python file and score every comment in it.

    Args:
        path: Path of the source file to read.
        thrd: Optional threshold forwarded unchanged to ``parse_and_predict``.

    Returns:
        The list returned by ``parse_and_predict`` for the file's contents.
    """
    # read_text() opens and closes the file itself, so no handle is left open.
    text = Path(path).read_text()
    return parse_and_predict(text, thrd)

In [None]:
# End-to-end run on the example file: one (zero-based line index, score) pair per comment.
parse_and_predict_file("example.py")

[(42, 0.9620012044906616),
 (48, 0.8701790571212769),
 (54, 0.9321628212928772),
 (55, 0.8722432255744934)]

In [None]:
# Run the pipeline on an inline snippet. Some of its comments misdescribe the
# code they sit next to (e.g. `a - b` is described as a sum).
parse_and_predict(
"""a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b  # Calculate and store the sum of a and b in c
d = a + b  # Calculate and store the sum of a and b in d
e = c * b  # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""
)

[(2, 0.23048676550388336),
 (3, 0.14431209862232208),
 (4, 0.1821700781583786),
 (5, 0.1499861180782318)]