In [None]:
# CLEANED VERSION OF GRAPH CREATION WITH COMMENTS
# This notebook uses the CoDeTM4Cleaned dataset to avoid data leakage

<VSCode.Cell id="#VSC-2de3e2b1" language="python">
import os
import sys
# Put the parent of the current working directory on sys.path so project-local
# packages (e.g. `data.dataset`, imported in the next cell) can be imported.
# NOTE(review): this depends on the notebook being started from its own directory.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))


</VSCode.Cell>
<VSCode.Cell id="#VSC-78e6b420" language="python">
from data.dataset.codet_m4_cleaned import CoDeTM4Cleaned

</VSCode.Cell>
<VSCode.Cell id="#VSC-ca99a2fb" language="python">
# Load the train/val/test partitions of the cleaned CoDeT-M4 dataset with all columns.
# NOTE(review): `dynamic_split_sizing=False` presumably keeps the split sizes stored on disk — confirm in CoDeTM4Cleaned.
train, val, test  = CoDeTM4Cleaned('../../data/codet_cleaned_20250812_201438/').get_dataset(['train','val','test'], columns='all', dynamic_split_sizing=False)

</VSCode.Cell>
<VSCode.Cell id="#VSC-36e9e2bd" language="python">
from datasets import concatenate_datasets

</VSCode.Cell>
<VSCode.Cell id="#VSC-db3db8b2" language="python">
# One dataset spanning all splits. It is used for the language statistics and the
# node-type vocabulary below; the graphs themselves are still built per split.
codet = concatenate_datasets([train, val, test])

</VSCode.Cell>
<VSCode.Cell id="#VSC-9715ed3c" language="python">
import tree_sitter_python as tspython
import tree_sitter_cpp as tscpp
import tree_sitter_java as tsjava
from tree_sitter import Parser, Language

# One tree-sitter Language and one shared Parser per supported source language.
TS_PYTHON, TS_JAVA, TS_CPP = (
    Language(grammar.language()) for grammar in (tspython, tsjava, tscpp)
)

PYTHON_PARSER = Parser(language=TS_PYTHON)
JAVA_PARSER = Parser(language=TS_JAVA)
CPP_PARSER = Parser(language=TS_CPP)

# Comments stay in the trees: tree-sitter grammars parse comments as ordinary
# nodes, so no `parser.set_included_ranges` call or extra flag is needed.
# Just make sure later steps do not filter comment nodes out.

</VSCode.Cell>
<VSCode.Cell id="#VSC-9fd6153b" language="python">
import matplotlib.pyplot as plt

</VSCode.Cell>
<VSCode.Cell id="#VSC-819bfc7e" language="python">
from collections import Counter

# Number of samples per language across all splits, computed in a single pass.
language_counts = Counter(codet['language'])

# Languages are categories, so draw one bar per language. `plt.hist` would bin the
# category indices numerically and leave the bars offset from their tick labels.
languages = sorted(language_counts)
plt.figure(figsize=(5, 5))
plt.bar(languages, [language_counts[lang] for lang in languages], edgecolor='black')
plt.xlabel('Language')
plt.ylabel('Count')
plt.title('Distribution of Programming Languages in CoDeTM4 Cleaned Dataset')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Counter yields 0 for a language that never occurs, just as list.count did.
python_count = language_counts["python"]
java_count = language_counts["java"]
cpp_count = language_counts["cpp"]

print(f'Python count: {python_count}')
print(f'Java count: {java_count}')
print(f'C++ count: {cpp_count}')

</VSCode.Cell>
<VSCode.Cell id="#VSC-d1700f0c" language="python">
def get_parser(language):
    """Return the shared tree-sitter parser for ``language``.

    Supported values are 'python', 'java' and 'cpp'; any other value raises
    ``ValueError``.
    """
    if language == 'python':
        return PYTHON_PARSER
    if language == 'java':
        return JAVA_PARSER
    if language == 'cpp':
        return CPP_PARSER
    raise ValueError(f"Unsupported language: {language}")

</VSCode.Cell>
<VSCode.Cell id="#VSC-6f87d16c" language="python">
def create_tree(sample, code_key='cleaned_code'):
    """Parse one dataset sample into a tree-sitter syntax tree.

    Args:
        sample: mapping with a 'language' entry ('python', 'java' or 'cpp')
            and the source text stored under ``code_key``.
        code_key: name of the text column to parse (default 'cleaned_code').

    Returns:
        The tree-sitter ``Tree`` for the sample's source. Comment nodes are
        included; tree-sitter keeps them without extra configuration.

    Raises:
        ValueError: if the sample's language is unsupported (from ``get_parser``).
    """
    parser = get_parser(sample['language'])
    # tree-sitter parses bytes; node byte offsets refer to this UTF-8 encoding.
    return parser.parse(sample[code_key].encode('utf-8'))

</VSCode.Cell>
<VSCode.Cell id="#VSC-9882b03e" language="python">
from tree_sitter import TreeCursor

def walk_tree(cursor: TreeCursor, depth=0):
    """Print the node types of the subtree under ``cursor``, one per line.

    Nodes are printed in pre-order. Each line is indented by two spaces per
    level, and the cursor's current node is printed at level ``depth``.

    The walk is iterative, so very deep ASTs cannot exceed Python's recursion
    limit. On return the cursor is back on the node it started at.
    """
    start_depth = depth
    while True:
        print(f"{'  ' * depth}{cursor.node.type}")
        if cursor.goto_first_child():
            depth += 1
            continue
        # No children: climb until some ancestor below the start node has a next
        # sibling. The start node's own siblings are never visited.
        while depth > start_depth and not cursor.goto_next_sibling():
            cursor.goto_parent()
            depth -= 1
        if depth == start_depth:
            return

</VSCode.Cell>
<VSCode.Cell id="#VSC-c0ee1c51" language="python">
from typing import Set

def get_node_types_from_tree(cursor: TreeCursor, types: Set[str] | None = None) -> Set[str]:
    """Collect the node types of the subtree under ``cursor``.

    Args:
        cursor: tree-sitter cursor positioned on the subtree's root.
        types: optional set to add the types to. It is mutated in place and
            also returned. A fresh set is used when omitted.

    Returns:
        The set of ``node.type`` strings seen in the subtree.

    The walk is iterative, so very deep ASTs cannot exceed Python's recursion
    limit. On return the cursor is back on the node it started at.
    """
    if types is None:
        types = set()

    depth = 0  # levels below the start node
    while True:
        types.add(cursor.node.type)
        if cursor.goto_first_child():
            depth += 1
            continue
        # Leaf: climb until an ancestor (below the start node) has a next sibling.
        while depth > 0 and not cursor.goto_next_sibling():
            cursor.goto_parent()
            depth -= 1
        if depth == 0:
            return types

</VSCode.Cell>
<VSCode.Cell id="#VSC-dde8c2ee" language="python">
def extract_types(sample, code_key='code'):
    """Return the distinct tree-sitter node types that occur in one sample.

    This function is used with ``Dataset.map(..., num_proc=8)``. Its imports
    live inside the body so it does not rely on notebook globals, presumably so
    it stays usable in the worker processes.

    Args:
        sample: mapping with a 'language' entry ('python', 'java' or 'cpp')
            and source text under ``code_key``.
        code_key: text column to parse. The default 'code' is the column the
            graphs below are built from. Pass the same column (e.g. through
            ``fn_kwargs``) if graphs are built from another one.

    Returns:
        ``{"types": [...]}`` with the sample's distinct node types, in arbitrary order.

    Raises:
        ValueError: if the sample's language is unsupported.
    """
    import tree_sitter_python as tspython
    import tree_sitter_cpp as tscpp
    import tree_sitter_java as tsjava
    from tree_sitter import Parser, Language

    grammars = {'python': tspython, 'java': tsjava, 'cpp': tscpp}
    language = sample['language']
    grammar = grammars.get(language)
    if grammar is None:
        raise ValueError(f"Unsupported language: {language}")
    # Build only the parser this sample needs, instead of all three on every call.
    parser = Parser(language=Language(grammar.language()))

    tree = parser.parse(sample[code_key].encode('utf-8'))
    cursor = tree.walk()

    # Iterative pre-order walk: deep ASTs cannot hit the recursion limit.
    types = set()
    depth = 0
    while True:
        types.add(cursor.node.type)
        if cursor.goto_first_child():
            depth += 1
            continue
        while depth > 0 and not cursor.goto_next_sibling():
            cursor.goto_parent()
            depth -= 1
        if depth == 0:
            break
    return {"types": list(types)}

# Parse every sample in 8 worker processes, then merge the per-sample type lists
# into a single vocabulary of node types.
result = codet.map(extract_types, batched=False, num_proc=8)

all_types = set().union(*result['types'])

print(f"Collected {len(all_types)} unique node types")

</VSCode.Cell>
<VSCode.Cell id="#VSC-698dca1b" language="python">
# Fix the vocabulary in sorted order so type indices are deterministic across runs
# (set iteration order is not). sorted() accepts any iterable, so no list() is needed.
all_types = sorted(all_types)

</VSCode.Cell>
<VSCode.Cell id="#VSC-e0b7f9c4" language="python">
# Map each node type to its position in the sorted vocabulary (used as node features).
type_to_ind = dict(zip(all_types, range(len(all_types))))

</VSCode.Cell>
<VSCode.Cell id="#VSC-011b30fb" language="python">
# Vocabulary size: number of distinct AST node types (shown as the cell's output).
len(all_types)
</VSCode.Cell>
<VSCode.Cell id="#VSC-27c1a491" language="python">
# List every node type in the vocabulary in index order, one per line.
for t in all_types:
    print(t)
</VSCode.Cell>
<VSCode.Cell id="#VSC-f85aa862" language="python">
from torch_geometric.data import Data
from typing import List, Tuple, Dict
from tree_sitter import Node

</VSCode.Cell>
<VSCode.Cell id="#VSC-5f05d8d7" language="python">
from torch_geometric.data import Data
from typing import List, Tuple, Dict
from torch import tensor, long as tlong
from tree_sitter import TreeCursor

def tree_to_graph(cursor: TreeCursor, id_map: Dict | None = None, next_id: int = 0, edges: List[Tuple[int, int]] | None = None) -> Tuple[List[Tuple[int, int]], Dict, int]:
    """Turn the subtree under ``cursor`` into a parent -> child edge list.

    Nodes get consecutive integer ids in pre-order, starting at ``next_id``.
    Nodes already present in ``id_map`` keep their id.

    Args:
        cursor: tree-sitter cursor positioned on the subtree's root.
        id_map: optional node -> id dict. It is mutated in place and returned.
        next_id: first id to hand out to a node not yet in ``id_map``.
        edges: optional edge list. It is appended to in place and returned.

    Returns:
        ``(edges, id_map, next_id)``:
        - ``edges`` holds one ``(parent_id, child_id)`` pair per child, in pre-order.
        - ``id_map`` maps tree-sitter nodes to their ids.
        - ``next_id`` is the next unused id.

    The walk is iterative, so very deep ASTs cannot exceed Python's recursion
    limit. On return the cursor is back on the node it started at.
    """
    if edges is None:
        edges = []
    if id_map is None:
        id_map = {}

    def _node_id(node):
        """Return ``node``'s id, assigning the next free id on first sight."""
        nonlocal next_id
        if node not in id_map:
            id_map[node] = next_id
            next_id += 1
        return id_map[node]

    def _enter_child():
        """Record the edge from the current parent to the cursor's node and descend."""
        child_id = _node_id(cursor.node)
        edges.append((path[-1], child_id))
        path.append(child_id)

    # Ids of the nodes from the start node down to the cursor's node (top = cursor).
    path = [_node_id(cursor.node)]
    while True:
        if cursor.goto_first_child():
            _enter_child()
            continue
        # Leaf reached: climb until some node on the path has a next sibling.
        while len(path) > 1:
            path.pop()  # path[-1] is now the parent of the cursor's node
            if cursor.goto_next_sibling():
                _enter_child()
                break
            cursor.goto_parent()
        else:
            # Back at the start node: every descendant has been visited.
            return edges, id_map, next_id

</VSCode.Cell>
<VSCode.Cell id="#VSC-b8f5d0a1" language="python">
from tqdm import tqdm

</VSCode.Cell>
<VSCode.Cell id="#VSC-095b7222" language="python">
def create_graph(sample, code_key='cleaned_code'):
    """Convert one dataset sample into a PyTorch Geometric ``Data`` graph.

    The nodes are the AST nodes of ``sample[code_key]``, numbered in pre-order
    by ``tree_to_graph``. Directed edges point from parent to child.

    Args:
        sample: dataset row. Reads 'language', 'target', 'target_binary',
            'features', 'code', 'cleaned_code' and ``code_key``.
        code_key: text column to parse (default 'cleaned_code').

    Returns:
        ``Data`` with these attributes:
          - x: long tensor of node-type indices from ``type_to_ind``, shape (num_nodes,).
          - edge_index: long tensor of shape (2, num_edges); (2, 0) for a root-only tree.
          - y: long tensor ``[target_binary]``.
          - graph_features: tensor of ``sample['features']`` values, in dict order.
          - metadata: dict with the raw language, targets and both code variants.

    Raises:
        KeyError: if a node type is missing from ``type_to_ind``. This can happen
            when ``code_key`` differs from the column the vocabulary was built from.
    """
    tree = create_tree(sample, code_key=code_key)
    edges, id_map, _ = tree_to_graph(tree.walk())
    # reshape(-1, 2) keeps the (2, 0) layout PyG expects when there are no edges
    # (e.g. empty source). Plain tensor([]).t() would give a 1-D tensor of shape (0,).
    edge_index = tensor(edges, dtype=tlong).reshape(-1, 2).t().contiguous()

    nodes_in_id_order = sorted(id_map, key=id_map.__getitem__)
    x = tensor([type_to_ind[node.type] for node in nodes_in_id_order], dtype=tlong)
    y = tensor([sample['target_binary']], dtype=tlong)

    # NOTE(review): the dtype is inferred from the feature values, so samples with
    # integer-only vs. mixed numeric features get different dtypes. Consider
    # passing an explicit dtype.
    graph_features = tensor(list(sample['features'].values()))

    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }

    return Data(
        x=x,
        y=y,
        edge_index=edge_index,
        graph_features=graph_features,
        metadata=metadata
    )


</VSCode.Cell>
<VSCode.Cell id="#VSC-8d1ec36b" language="python">
# Display the combined dataset's summary (column features and row count).
codet

</VSCode.Cell>
<VSCode.Cell id="#VSC-89d0fc6e" language="python">
def create_graphs(dataset, desc_keyword, code_key='cleaned_code'):
    """Build one PyG graph per sample of ``dataset``, showing a tqdm progress bar.

    Args:
        dataset: iterable of samples accepted by ``create_graph``.
        desc_keyword: split name shown in the progress-bar label.
        code_key: text column to parse, forwarded to ``create_graph``.

    Returns:
        List of ``Data`` objects in dataset order.
    """
    progress = tqdm(dataset, desc=f'Creating {desc_keyword} graphs')
    return [create_graph(sample, code_key) for sample in progress]
    

</VSCode.Cell>
<VSCode.Cell id="#VSC-1b040f8d" language="python">
from torch_geometric.data import Data
from torch import save
import gc

import os

# Graphs are built from the 'code' column, the same column the node-type vocabulary
# was built from. Each split is saved and then freed before the next one is built,
# so at most one split's graphs are held in memory at a time.
# Create the output directory up front: saving fails if it is missing, which
# would throw away the (slow) train-graph build.
os.makedirs('../../data/codet_graphs', exist_ok=True)

train_graphs = create_graphs(train, 'train', 'code')
save(train_graphs, '../../data/codet_graphs/train_graphs_cleaned_comments.pt')
del train, train_graphs
gc.collect()
val_graphs = create_graphs(val, 'val', 'code')
save(val_graphs, '../../data/codet_graphs/val_graphs_cleaned_comments.pt')
del val, val_graphs
gc.collect()
test_graphs = create_graphs(test, 'test', 'code')
save(test_graphs, '../../data/codet_graphs/test_graphs_cleaned_comments.pt')
del test, test_graphs
gc.collect()
save(type_to_ind, '../../data/codet_graphs/type_to_ind_cleaned_comments.pt')

</VSCode.Cell>
<VSCode.Cell id="#VSC-6d2855c3" language="python">
from torch import load

</VSCode.Cell>
<VSCode.Cell id="#VSC-5c327545" language="python">
from torch_geometric.data import Data

</VSCode.Cell>
<VSCode.Cell id="#VSC-2135cbfa" language="python">
# Reload the saved training graphs. weights_only=False lets torch.load unpickle the
# PyG `Data` objects, not just tensors. Unpickling can execute arbitrary code, so
# only load files you produced or trust.
train_graphs = load('../../data/codet_graphs/train_graphs_cleaned_comments.pt', weights_only=False)

</VSCode.Cell>
<VSCode.Cell id="#VSC-bc1a04a5" language="python">
# Reload the node-type -> index vocabulary; visualize_graph uses it to label nodes.
type_to_ind = load('../../data/codet_graphs/type_to_ind_cleaned_comments.pt', weights_only=False)

</VSCode.Cell>
<VSCode.Cell id="#VSC-3fe03c64" language="python">
# Number of training graphs (shown as the cell's output).
len(train_graphs)

</VSCode.Cell>
<VSCode.Cell id="#VSC-98f655f8" language="python">
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_graph(data, figsize=(15, 10), show_labels=True):
    """Draw an AST graph top-down, with one row per tree depth.

    Args:
        data: PyG ``Data`` as produced by ``create_graph``: directed
            parent -> child edges, and ``x`` holding node-type indices.
        figsize: matplotlib figure size.
        show_labels: if True, label each node with its type name, looked up
            through the notebook-level ``type_to_ind`` vocabulary.
    """
    from collections import defaultdict, deque

    G = to_networkx(data, to_undirected=False)

    plt.figure(figsize=figsize)

    # Root = first node without incoming edges (falls back to node 0).
    root = next((node for node in G.nodes() if G.in_degree(node) == 0), 0)

    # Breadth-first walk recording each node's depth. deque.popleft is O(1),
    # whereas list.pop(0) made the walk quadratic in the node count.
    levels = {}
    queue = deque([(root, 0)])
    while queue:
        node, level = queue.popleft()
        levels[node] = level
        for child in G.successors(node):
            queue.append((child, level + 1))

    # Group nodes by depth (BFS order within each row).
    level_groups = defaultdict(list)
    for node, level in levels.items():
        level_groups[level].append(node)

    # Lay each row out roughly centred on x = 0, with deeper levels lower down.
    pos = {}
    for level, nodes in level_groups.items():
        for i, node in enumerate(nodes):
            pos[node] = (i - len(nodes) / 2, -level)

    labels = None
    if show_labels:
        # Reverse the vocabulary so node-type indices can be shown as names.
        ind_to_type = {v: k for k, v in type_to_ind.items()}
        labels = {}
        for node in G.nodes():
            node_type_idx = data.x[node].item()
            labels[node] = ind_to_type.get(node_type_idx, f"idx_{node_type_idx}")

    nx.draw(G, pos, with_labels=show_labels, labels=labels,
            node_color='lightblue', node_size=500, arrows=True,
            font_size=6 if show_labels else 8)

    plt.title("AST Tree")
    plt.show()


</VSCode.Cell>
<VSCode.Cell id="#VSC-5f2ae2d6" language="python">
# Show the second training graph's summary (its attributes and their shapes).
train_graphs[1]
</VSCode.Cell>
<VSCode.Cell id="#VSC-ef568209" language="python">
# Draw the same graph as a top-down tree labelled with node types.
visualize_graph(train_graphs[1], show_labels=True)

</VSCode.Cell>