In [None]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from data.dataset.aigcodeset import AIGCodeSet

In [89]:
from tree_sitter import Language, Parser
import tree_sitter_python as tspython

# Set up the parser
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)

def extract_features(code_snippet):
    tree = parser.parse(bytes(code_snippet, "utf8"))
    root_node = tree.root_node

    # Queries to find specific patterns
    function_query = PY_LANGUAGE.query("(function_definition) @func")
    if_query = PY_LANGUAGE.query("(if_statement) @if_stmt")
    while_query = PY_LANGUAGE.query("(while_statement) @while_stmt")
    for_query = PY_LANGUAGE.query("(for_statement) @for_stmt")
    comment_query = PY_LANGUAGE.query("(comment) @comment")
    import_query = PY_LANGUAGE.query("(import_statement) @import")
    import_from_query = PY_LANGUAGE.query("(import_from_statement) @import_from")
    class_query = PY_LANGUAGE.query("(class_definition) @class_def")
    binary_op_query = PY_LANGUAGE.query("(binary_operator) @binop")
    error_query = PY_LANGUAGE.query("(ERROR) @error")

    # Calculate max nesting depth
    max_nesting_depth = 0
    current_depth = 0

    def traverse(node, depth):
        nonlocal max_nesting_depth, current_depth
        current_depth = depth
        max_nesting_depth = max(max_nesting_depth, current_depth)

        # Increment depth only for block-level constructs
        node_type = node.type
        if node_type in {"function_definition", "class_definition", "if_statement", "for_statement", "while_statement"}:
            for child in node.children:
                traverse(child, depth + 1)
        else:
            for child in node.children:
                traverse(child, depth)

    # Start traversal from the root
    traverse(root_node, 0)

    # Extract features
    features = {
        "function_defs": len(function_query.captures(root_node)),
        "if_statements": len(if_query.captures(root_node)),
        "while_loops": len(while_query.captures(root_node)),
        "for_loops": len(for_query.captures(root_node)),
        "imports": len(import_query.captures(root_node)) + len(import_from_query.captures(root_node)),
        "comments": len(comment_query.captures(root_node)),
        "class_defs": len(class_query.captures(root_node)),
        "max_nesting_depth": max_nesting_depth,
        "binary_ops": len(binary_op_query.captures(root_node)),
        "errors": len(error_query.captures(root_node)),
    }
    return features

# Example usage
code_with_error = """
def add(a, b)  # Missing colon
    return a + b
"""
code_valid = """
import os
import sys
def subtract(a, b):
    if a > b:
        for i in range(3):
            for hhh in range(45):
                c = a - b
        return c
    return a - b
"""
features_err = extract_features(code_with_error)
features_valid = extract_features(code_valid)
print("Errored code:", features_err)
print("Valid code:", features_valid)

Errored code: {'function_defs': 0, 'if_statements': 0, 'while_loops': 0, 'for_loops': 0, 'imports': 0, 'comments': 1, 'class_defs': 0, 'max_nesting_depth': 0, 'binary_ops': 0, 'errors': 1}
Valid code: {'function_defs': 1, 'if_statements': 1, 'while_loops': 0, 'for_loops': 1, 'imports': 1, 'comments': 0, 'class_defs': 0, 'max_nesting_depth': 4, 'binary_ops': 1, 'errors': 0}


In [60]:
train, val, test = AIGCodeSet(cache_dir='../../data/').get_dataset(split=True)

In [61]:
codes = train['code']

In [91]:
features = {}
for code in codes:
    features[code] = extract_features(code)

In [94]:
list(features.items())[2][1]

{'function_defs': 1,
 'if_statements': 1,
 'while_loops': 1,
 'for_loops': 1,
 'imports': 2,
 'comments': 1,
 'class_defs': 1,
 'max_nesting_depth': 4,
 'binary_ops': 1,
 'errors': 0}