In [None]:
import os
import sys
# Make the project root (parent of this notebook's directory) importable.
# Guarded so re-running the cell does not append duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)



In [None]:
from data.dataset.codet_m4_cleaned import CoDeTM4Cleaned


In [None]:
# Load all three splits with every column, keeping the on-disk split sizes.
train, val, test = CoDeTM4Cleaned(
    '../../data/codet_cleaned_20250812_201438/'
).get_dataset(
    ['train', 'val', 'test'],
    columns='all',
    dynamic_split_sizing=False,
)


In [None]:
from datasets import concatenate_datasets


In [None]:
# Merge the three splits into one dataset for corpus-wide analysis.
codet = concatenate_datasets(
    [train, val, test]
)


In [None]:
import tree_sitter_python as tspython
import tree_sitter_cpp as tscpp
import tree_sitter_java as tsjava
from tree_sitter import Parser, Language

# Wrap each installed grammar in a tree-sitter Language and build one reusable
# parser per supported language.
TS_PYTHON = Language(tspython.language())
TS_JAVA = Language(tsjava.language())
TS_CPP = Language(tscpp.language())

PYTHON_PARSER = Parser(language=TS_PYTHON)
JAVA_PARSER = Parser(language=TS_JAVA)
CPP_PARSER = Parser(language=TS_CPP)

# tree-sitter keeps comments in the parse tree as regular nodes (the node type
# name differs per grammar), so no extra configuration such as included ranges
# is needed to retain them.


In [None]:
import matplotlib.pyplot as plt


In [None]:
from collections import Counter

# Count samples per language in a single pass. This works for any iterable
# column, whether a plain list or a lazy column object.
language_counts = Counter(codet['language'])

fig, ax = plt.subplots(figsize=(5, 5))
# A bar chart centres each language label under its own bar. plt.hist on
# categorical strings spreads the bins over [0, k-1], which leaves the ticks
# off-centre.
ax.bar(list(language_counts.keys()), list(language_counts.values()), edgecolor='black')
ax.set_xlabel('Language')
ax.set_ylabel('Count')
ax.set_title('Distribution of Programming Languages in CoDeTM4 Cleaned Dataset')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
plt.show()

# Counter returns 0 for languages that never occur.
python_count = language_counts['python']
java_count = language_counts['java']
cpp_count = language_counts['cpp']

print(f'Python count: {python_count}')
print(f'Java count: {java_count}')
print(f'C++ count: {cpp_count}')


In [None]:
def get_parser(language):
    """Return the shared module-level tree-sitter parser for `language`.

    Supported values are 'python', 'java' and 'cpp'; anything else raises
    ValueError.
    """
    if language == 'python':
        return PYTHON_PARSER
    if language == 'java':
        return JAVA_PARSER
    if language == 'cpp':
        return CPP_PARSER
    raise ValueError(f"Unsupported language: {language}")


In [None]:
def create_tree(sample, code_key='cleaned_code'):
    """Parse one dataset sample into a tree-sitter syntax tree.

    Args:
        sample: row with a 'language' entry ('python', 'java' or 'cpp') and the
            source text stored under `code_key`.
        code_key: which text column to parse (default 'cleaned_code').

    Returns:
        The tree produced by the language's parser.

    Raises:
        ValueError: if the language is unsupported (raised by `get_parser`).
    """
    parser = get_parser(sample['language'])
    # tree-sitter parses bytes. Comments stay in the tree as ordinary nodes, so
    # the parser needs no included-ranges setup.
    return parser.parse(sample[code_key].encode('utf-8'))


In [None]:
from tree_sitter import TreeCursor

def walk_tree(cursor: "TreeCursor", depth=0):
    """Print the node types of the subtree under `cursor` in pre-order, one per line.

    Each line is indented by two spaces per level, starting at `depth` levels.
    The walk is iterative, so deeply nested ASTs cannot hit Python's recursion
    limit. The cursor ends on the node it started on.
    """
    level = depth
    while True:
        print(f"{'  ' * level}{cursor.node.type}")
        if cursor.goto_first_child():
            level += 1
            continue
        # Climb until a next sibling exists, never leaving the starting node's subtree.
        while level > depth and not cursor.goto_next_sibling():
            cursor.goto_parent()
            level -= 1
        if level == depth:
            return


In [None]:
from typing import Set

def get_node_types_from_tree(cursor: "TreeCursor", types: Set[str]=None) -> Set[str]:
    """Collect the set of node-type names in the subtree under `cursor`.

    Args:
        cursor: tree-sitter cursor positioned on the subtree root.
        types: optional set to extend in place; a new set is created if None.

    Returns:
        The (possibly caller-supplied) set of node-type strings.

    The walk is iterative, so deep ASTs cannot hit the recursion limit. The
    cursor ends on the node it started on.
    """
    if types is None:
        types = set()

    level = 0  # number of levels below the starting node
    while True:
        types.add(cursor.node.type)
        if cursor.goto_first_child():
            level += 1
            continue
        # Climb until a next sibling exists or we are back at the starting node.
        while level > 0 and not cursor.goto_next_sibling():
            cursor.goto_parent()
            level -= 1
        if level == 0:
            return types


In [None]:
def extract_types(sample):
    # Import everything needed inside the function
    import tree_sitter_python as tspython
    import tree_sitter_cpp as tscpp
    import tree_sitter_java as tsjava
    from tree_sitter import Parser, Language, TreeCursor
    from typing import Set
    
    # Create parsers locally
    TS_PYTHON = Language(tspython.language())
    TS_JAVA = Language(tsjava.language())
    TS_CPP = Language(tscpp.language())
    
    PYTHON_PARSER = Parser(language=TS_PYTHON)
    JAVA_PARSER = Parser(language=TS_JAVA)
    CPP_PARSER = Parser(language=TS_CPP)
    
    def get_parser(language):
        match language:
            case 'python':
                return PYTHON_PARSER
            case 'java':
                return JAVA_PARSER
            case 'cpp':
                return CPP_PARSER
        raise ValueError(f"Unsupported language: {language}")
    
    def get_node_types_from_tree(cursor: TreeCursor, types: Set[str]=None) -> Set[str]:
        if types is None:
            types = set()
        
        types.add(cursor.node.type)

        if cursor.goto_first_child():
            get_node_types_from_tree(cursor, types)
        
            while cursor.goto_next_sibling():
                get_node_types_from_tree(cursor, types)

            cursor.goto_parent()

        return types
    
    parser = get_parser(sample['language'])
    code_bytes = sample['code'].encode('utf-8')
    # parser.set_included_ranges([ (0, len(code_bytes)) ])
    tree = parser.parse(code_bytes)
    # parser.set_included_ranges(None)
    cursor = tree.walk()
    types = get_node_types_from_tree(cursor)
    return {"types": list(types)}

result = codet.map(extract_types, batched=False, num_proc=8)

# Union every sample's type list into a single vocabulary set.
all_types = set().union(*result['types'])

print(f"Collected {len(all_types)} unique node types")


In [None]:
# Turn the vocabulary set into a sorted list so node-type indices are stable.
all_types = sorted(all_types)


In [None]:
# Map each node type to its position in the sorted vocabulary.
type_to_ind = dict(zip(all_types, range(len(all_types))))


In [None]:
# Vocabulary size: number of distinct AST node types collected above.
len(all_types)

In [None]:
# List every node type in the vocabulary, one per line.
if all_types:
    print(*all_types, sep='\n')

In [None]:
from torch_geometric.data import Data
from typing import List, Tuple, Dict
from tree_sitter import Node


In [None]:
from torch_geometric.data import Data
from typing import List, Tuple, Dict
from torch import tensor, long as tlong
from tree_sitter import TreeCursor

def tree_to_graph(cursor: "TreeCursor", id_map: Dict = None, next_id: int = 0, edges: List[Tuple[int, int]] = None) -> Tuple[List[Tuple[int, int]], Dict, int]:
    """Flatten the subtree under `cursor` into parent->child edges with integer ids.

    Ids are assigned in pre-order starting at `next_id`, and the starting node
    gets the first id. An edge (parent_id, child_id) is recorded for every
    child. The walk is iterative, so very deep trees cannot hit Python's
    recursion limit. The cursor ends on the node it started on.

    Args:
        cursor: tree-sitter cursor positioned on the subtree root.
        id_map: optional existing node -> id mapping to extend.
        next_id: first id to hand out to an unseen node.
        edges: optional existing edge list to extend (appended to in place).

    Returns:
        (edges, id_map, next_id), where next_id is one past the last id assigned.
    """
    if edges is None:
        edges = []
    if id_map is None:
        id_map = {}

    def _id_of(node) -> int:
        # Reuse the id of an already-seen node; otherwise hand out the next one.
        nonlocal next_id
        if node not in id_map:
            id_map[node] = next_id
            next_id += 1
        return id_map[node]

    # Ids of the nodes on the path from the starting node down to the cursor.
    path_ids = [_id_of(cursor.node)]
    while True:
        if cursor.goto_first_child():
            child_id = _id_of(cursor.node)
            edges.append((path_ids[-1], child_id))
            path_ids.append(child_id)
            continue
        # Subtree finished: move to the next sibling, climbing as needed, but
        # never past the starting node.
        while len(path_ids) > 1:
            if cursor.goto_next_sibling():
                path_ids.pop()
                child_id = _id_of(cursor.node)
                edges.append((path_ids[-1], child_id))
                path_ids.append(child_id)
                break
            cursor.goto_parent()
            path_ids.pop()
        else:
            return edges, id_map, next_id


In [None]:
from tqdm import tqdm


In [None]:
def create_graph(sample, code_key='cleaned_code'):
    """Convert one dataset sample into a PyG `Data` graph of its syntax tree.

    Args:
        sample: row with 'language', 'target', 'target_binary', 'code',
            'cleaned_code' and a 'features' mapping.
        code_key: which text column to parse.

    Returns:
        `Data` with `x` (node-type indices in node-id order), `edge_index` of
        shape (2, num_edges) holding parent->child edges, `y` (binary label,
        shape (1,)), `graph_features` and a `metadata` dict.
    """
    tree = create_tree(sample, code_key=code_key)
    edges, id_map, _ = tree_to_graph(tree.walk())
    # reshape(-1, 2) keeps shape (2, 0) for a root-only tree (e.g. empty code).
    # Without it, an empty edge list would become a 1-D tensor that PyG rejects.
    edge_index = tensor(edges, dtype=tlong).reshape(-1, 2).t().contiguous()
    # Node ids run 0..n-1, so sorting by id lines features up with edge_index.
    # NOTE(review): KeyError if a node type is missing from `type_to_ind`
    # (vocabulary was collected from the 'code' column) — confirm for other code_keys.
    x = [type_to_ind[node.type] for node, _ in sorted(id_map.items(), key=lambda kv: kv[1])]
    x = tensor(x, dtype=tlong)
    y = tensor([sample['target_binary']], dtype=tlong)

    # dtype is inferred from the feature values; relies on a stable key order
    # in the 'features' mapping.
    graph_features = tensor(list(sample['features'].values()))

    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }

    data = Data(
        x=x,
        y=y,
        edge_index=edge_index,
        graph_features=graph_features,
        metadata=metadata
    )
    return data



In [None]:
# Display a summary of the concatenated train/val/test dataset.
codet


In [None]:
from torch_geometric.data import Data
from torch import tensor, long as tlong
from typing import List, Dict, Tuple
from collections import defaultdict, deque
import torch
def compute_depths(num_nodes: int, edges: List[Tuple[int, int]]) -> torch.Tensor:
    """Compute each node's depth (edge distance from the root) by breadth-first search.

    Args:
        num_nodes: number of nodes. Ids are 0..num_nodes-1, and the root is
            assumed to be id 0, which `tree_to_graph` assigns to the starting node.
        edges: (parent, child) id pairs.

    Returns:
        int64 tensor of shape (num_nodes,). Nodes unreachable from node 0 keep
        depth 0. An empty graph yields an empty tensor.
    """
    # Work in a Python list (cheap scalar writes) and convert once at the end.
    depth = [0] * num_nodes
    if num_nodes > 0:
        children = defaultdict(list)
        for parent, child in edges:
            children[parent].append(child)
        visited = [False] * num_nodes
        visited[0] = True
        queue = deque([0])
        while queue:
            node = queue.popleft()
            for child in children[node]:
                if not visited[child]:
                    depth[child] = depth[node] + 1
                    visited[child] = True
                    queue.append(child)
    return torch.tensor(depth, dtype=torch.long)

def compute_child_indices(num_nodes: int, edges: List[Tuple[int, int]]) -> torch.Tensor:
    """Compute each node's position among its siblings (0 for first child and root).

    Args:
        num_nodes: number of nodes; ids are 0..num_nodes-1.
        edges: (parent, child) id pairs. Sibling order follows the order of
            the edges.

    Returns:
        int64 tensor of shape (num_nodes,).
    """
    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)
    # Fill a Python list and convert once, avoiding per-element tensor writes.
    child_idx = [0] * num_nodes
    for siblings in children.values():
        for position, child in enumerate(siblings):
            child_idx[child] = position
    return torch.tensor(child_idx, dtype=torch.long)

def create_graph(sample, code_key='cleaned_code'):
    """Create a PyG Data object with node type, depth, and child index features.

    Args:
        sample: row with 'language', 'target', 'target_binary', 'code',
            'cleaned_code' and a 'features' mapping.
        code_key: which text column to parse.

    Returns:
        `Data` with `x` (node-type indices), `edge_index` of shape (2, num_edges),
        `node_depth`, `child_index`, `y` (shape (1,)), `graph_features` and
        `metadata`.
    """
    tree = create_tree(sample, code_key=code_key)
    edges, id_map, _ = tree_to_graph(tree.walk())

    # reshape(-1, 2) keeps shape (2, 0) for a root-only tree (e.g. empty code)
    # instead of producing a 1-D empty tensor.
    edge_index = tensor(edges, dtype=tlong).reshape(-1, 2).t().contiguous()

    # Node type IDs, in node-id order.
    # NOTE(review): KeyError if a node type is missing from `type_to_ind`.
    x = [type_to_ind[node.type] for node, _ in sorted(id_map.items(), key=lambda kv: kv[1])]
    x = tensor(x, dtype=tlong)
    num_nodes = x.size(0)

    # Compute depth and child index
    node_depth = compute_depths(num_nodes, edges)
    child_index = compute_child_indices(num_nodes, edges)

    # Target
    y = tensor([sample['target_binary']], dtype=tlong)

    # Graph features (dtype inferred from the values)
    graph_features = tensor(list(sample['features'].values()))

    # Metadata
    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }

    data = Data(
        x=x,
        y=y,
        edge_index=edge_index,
        node_depth=node_depth,
        child_index=child_index,
        graph_features=graph_features,
        metadata=metadata
    )

    return data


def create_graphs(dataset, desc_keyword, code_key='cleaned_code'):
    """Convert every sample in `dataset` to a graph, with a tqdm progress bar.

    Returns the graphs as a list, in dataset order.
    """
    progress = tqdm(dataset, desc=f'Creating {desc_keyword} graphs')
    return [create_graph(sample, code_key) for sample in progress]
    


In [None]:
# Running statistics, updated by the depth/child-index helpers and create_graph.
max_depth_global = max_child_index_global = 0
depth_stats, child_index_stats = [], []

def compute_depths(num_nodes: int, edges: List[Tuple[int, int]]) -> torch.Tensor:
    """Compute each node's depth (edge distance from the root) and track the maximum.

    Also raises the module-level `max_depth_global` to this graph's maximum
    depth.

    Args:
        num_nodes: number of nodes; ids are 0..num_nodes-1, with the root at id 0.
        edges: (parent, child) id pairs.

    Returns:
        int64 tensor of shape (num_nodes,). Nodes unreachable from node 0 keep
        depth 0.
    """
    global max_depth_global
    depth = [0] * num_nodes
    if num_nodes == 0:
        # Nothing to traverse. Returning here also avoids indexing the root and
        # taking max() of an empty sequence.
        return torch.tensor(depth, dtype=torch.long)

    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)
    visited = [False] * num_nodes
    visited[0] = True  # root assumed to have id 0
    queue = deque([0])
    while queue:
        node = queue.popleft()
        for child in children[node]:
            if not visited[child]:
                depth[child] = depth[node] + 1
                visited[child] = True
                queue.append(child)

    # Update global max depth
    max_depth_global = max(max_depth_global, max(depth))
    return torch.tensor(depth, dtype=torch.long)

def compute_child_indices(num_nodes: int, edges: List[Tuple[int, int]]) -> torch.Tensor:
    """Compute each node's position among its siblings and track the maximum.

    Also raises the module-level `max_child_index_global` to this graph's
    largest sibling index, when the graph is non-empty.

    Args:
        num_nodes: number of nodes; ids are 0..num_nodes-1.
        edges: (parent, child) id pairs. Sibling order follows edge order.

    Returns:
        int64 tensor of shape (num_nodes,).
    """
    global max_child_index_global
    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)
    # Fill a Python list and convert once, avoiding per-element tensor writes.
    child_idx = [0] * num_nodes
    for siblings in children.values():
        for position, child in enumerate(siblings):
            child_idx[child] = position

    # Update global max child index
    if child_idx:
        max_child_index_global = max(max_child_index_global, max(child_idx))

    return torch.tensor(child_idx, dtype=torch.long)

def create_graph(sample, code_key='cleaned_code'):
    """Create a PyG Data object with node type, depth, and child index features.

    Also appends this graph's maximum depth and maximum child index to the
    module-level `depth_stats` / `child_index_stats` lists.

    Args:
        sample: row with 'language', 'target', 'target_binary', 'code',
            'cleaned_code' and a 'features' mapping.
        code_key: which text column to parse.

    Returns:
        `Data` with `x`, `edge_index` of shape (2, num_edges), `node_depth`,
        `child_index`, `y` (shape (1,)), `graph_features` and `metadata`.
    """
    tree = create_tree(sample, code_key=code_key)
    edges, id_map, _ = tree_to_graph(tree.walk())

    # reshape(-1, 2) keeps shape (2, 0) for a root-only tree (e.g. empty code)
    # instead of producing a 1-D empty tensor.
    edge_index = tensor(edges, dtype=tlong).reshape(-1, 2).t().contiguous()

    # Node type IDs, in node-id order.
    # NOTE(review): KeyError if a node type is missing from `type_to_ind`.
    x = [type_to_ind[node.type] for node, _ in sorted(id_map.items(), key=lambda kv: kv[1])]
    x = tensor(x, dtype=tlong)
    num_nodes = x.size(0)

    # Compute depth and child index
    node_depth = compute_depths(num_nodes, edges)
    child_index = compute_child_indices(num_nodes, edges)

    # Record per-graph maxima. The module-level lists are mutated in place, so
    # no `global` statement is needed.
    if len(node_depth) > 0:
        depth_stats.append(node_depth.max().item())
    if len(child_index) > 0:
        child_index_stats.append(child_index.max().item())

    # Target
    y = tensor([sample['target_binary']], dtype=tlong)

    # Graph features (dtype inferred from the values)
    graph_features = tensor(list(sample['features'].values()))

    # Metadata
    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }

    data = Data(
        x=x,
        y=y,
        edge_index=edge_index,
        node_depth=node_depth,
        child_index=child_index,
        graph_features=graph_features,
        metadata=metadata
    )

    return data


def create_graphs(dataset, desc_keyword, code_key='cleaned_code'):
    """Convert every sample in `dataset` to a graph, with a tqdm progress bar.

    Returns the graphs as a list, in dataset order.
    """
    progress = tqdm(dataset, desc=f'Creating {desc_keyword} graphs')
    return [create_graph(sample, code_key) for sample in progress]

In [None]:
# Reset statistics before creating graphs
max_depth_global = 0
max_child_index_global = 0
depth_stats = []
child_index_stats = []

print("Starting graph creation with statistics tracking...")
print("This will track maximum depth and child index values across all graphs.")
print()

from torch_geometric.data import Data
from torch import save
import gc

GRAPH_DIR = '../../data/codet_graphs'
# torch.save does not create missing directories.
os.makedirs(GRAPH_DIR, exist_ok=True)

# Build, save and free one split at a time to limit peak memory. Graphs are
# parsed from the 'code' column, the same column extract_types used to build
# type_to_ind.
splits = {'train': train, 'val': val, 'test': test}
del train, val, test
for split_name in list(splits):
    split_graphs = create_graphs(splits.pop(split_name), split_name, 'code')
    save(split_graphs, f'{GRAPH_DIR}/{split_name}_graphs_cleaned_comments_depth.pt')
    print(f"After {split_name} graphs: max_depth={max_depth_global}, max_child_index={max_child_index_global}")
    del split_graphs
    gc.collect()

save(type_to_ind, f'{GRAPH_DIR}/type_to_ind_cleaned_comments_depth.pt')

print()
print("Graph creation completed!")
print(f"Final statistics: max_depth={max_depth_global}, max_child_index={max_child_index_global}")
print(f"Total graphs processed: {len(depth_stats)}")

In [None]:
# Save statistics about depth and child indices
import json
import numpy as np


def _summary_stats(values):
    """Mean/std/min/max/95th/99th percentile of per-graph maxima; all zeros when empty."""
    arr = np.asarray(values)
    if arr.size == 0:
        return {'mean': 0.0, 'std': 0.0, 'min': 0, 'max': 0,
                'percentile_95': 0.0, 'percentile_99': 0.0}
    return {
        'mean': float(np.mean(arr)),
        'std': float(np.std(arr)),
        'min': int(np.min(arr)),
        'max': int(np.max(arr)),
        'percentile_95': float(np.percentile(arr, 95)),
        'percentile_99': float(np.percentile(arr, 99)),
    }


def _print_summary(title, summary):
    """Print one statistics section in the report format used below."""
    print(f"{title}:")
    print(f"  Mean: {summary['mean']:.2f}")
    print(f"  Std:  {summary['std']:.2f}")
    print(f"  Min:  {summary['min']}")
    print(f"  Max:  {summary['max']}")
    print(f"  95th percentile: {summary['percentile_95']:.2f}")
    print(f"  99th percentile: {summary['percentile_99']:.2f}")


statistics = {
    'max_depth_global': int(max_depth_global),
    'max_child_index_global': int(max_child_index_global),
    'total_graphs_processed': len(depth_stats),
    'depth_statistics': _summary_stats(depth_stats),
    'child_index_statistics': _summary_stats(child_index_stats),
}

# Save statistics to JSON file
stats_file_path = '../../data/codet_graphs/depth_child_index_stats_cleaned_comments_depth.json'
with open(stats_file_path, 'w') as f:
    json.dump(statistics, f, indent=2)

print("Graph Statistics Summary:")
print("=" * 50)
print(f"Total graphs processed: {statistics['total_graphs_processed']}")
print(f"Global maximum depth: {statistics['max_depth_global']}")
print(f"Global maximum child index: {statistics['max_child_index_global']}")
print()
_print_summary("Depth Statistics", statistics['depth_statistics'])
print()
_print_summary("Child Index Statistics", statistics['child_index_statistics'])
print()
print(f"Statistics saved to: {stats_file_path}")

In [None]:
# Visualize depth and child index distributions
import matplotlib.pyplot as plt
import numpy as np


def _plot_max_distribution(ax, values, xlabel, title, color=None):
    """Histogram of per-graph maxima with mean and 95th-percentile lines; no-op when empty."""
    if len(values) == 0:
        return
    mean = np.mean(values)
    p95 = np.percentile(values, 95)
    ax.hist(values, bins=50, alpha=0.7, edgecolor='black', color=color)
    ax.axvline(mean, color='red', linestyle='--', label=f'Mean: {mean:.2f}')
    ax.axvline(p95, color='orange', linestyle='--', label=f'95th percentile: {p95:.2f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _value_range(values):
    """'min to max' of a list, or a placeholder when it is empty (min/max would raise)."""
    return f"{min(values)} to {max(values)}" if values else "n/a (no graphs processed)"


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
_plot_max_distribution(ax1, depth_stats, 'Maximum Depth per Graph',
                       'Distribution of Maximum Depths')
_plot_max_distribution(ax2, child_index_stats, 'Maximum Child Index per Graph',
                       'Distribution of Maximum Child Indices', color='green')
plt.tight_layout()
plt.show()

print("\nDetailed Analysis:")
print(f"Depth range: {_value_range(depth_stats)}")
print(f"Child index range: {_value_range(child_index_stats)}")
print()
print("These statistics help determine appropriate embedding dimensions for:")
print("- Node depth embeddings (should be >= max_depth + 1)")
print("- Child index embeddings (should be >= max_child_index + 1)")

In [None]:
from torch import load


In [None]:
from torch_geometric.data import Data


In [None]:
# Full unpickling (weights_only=False) is needed for the PyG `Data` objects.
# It can execute code, so only load files this notebook produced itself.
train_graphs = load(
    '../../data/codet_graphs/train_graphs_cleaned_comments_depth.pt',
    weights_only=False,
)


In [None]:
# The vocabulary is a plain dict mapping node-type strings to ints. torch's
# restricted `weights_only=True` loader handles that, so the code-executing
# full unpickler is unnecessary here.
type_to_ind = load('../../data/codet_graphs/type_to_ind_cleaned_comments_depth.pt', weights_only=True)


In [None]:
# Number of training graphs loaded from disk.
len(train_graphs)


In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_graph(data, figsize=(15, 10), show_labels=True):
    """Draw a PyG AST graph as a top-down tree, with one row per depth level.

    Args:
        data: PyG `Data` whose `x` holds node-type indices into the global
            `type_to_ind` vocabulary.
        figsize: matplotlib figure size.
        show_labels: label each node with its AST node-type name.
    """
    from collections import deque

    # Convert to NetworkX graph
    G = to_networkx(data, to_undirected=False)

    plt.figure(figsize=figsize)

    # Root = first node with no incoming edges (falls back to node 0).
    root = next((node for node in G.nodes() if G.in_degree(node) == 0), 0)

    # Breadth-first search from the root. deque.popleft() is O(1), whereas
    # list.pop(0) is O(n).
    levels = {}
    queue = deque([(root, 0)])
    while queue:
        node, level = queue.popleft()
        levels[node] = level
        for child in G.successors(node):
            queue.append((child, level + 1))

    # Group nodes by level, preserving BFS order within a level.
    level_groups = {}
    for node, level in levels.items():
        level_groups.setdefault(level, []).append(node)

    # Centre each level horizontally; deeper levels are drawn lower.
    pos = {}
    for level, nodes in level_groups.items():
        for i, node in enumerate(nodes):
            pos[node] = (i - len(nodes) / 2, -level)

    # Create labels if requested
    labels = None
    if show_labels:
        ind_to_type = {v: k for k, v in type_to_ind.items()}
        labels = {}
        for node in G.nodes():
            node_type_idx = data.x[node].item()
            labels[node] = ind_to_type.get(node_type_idx, f"idx_{node_type_idx}")

    nx.draw(G, pos, with_labels=show_labels, labels=labels,
            node_color='lightblue', node_size=500, arrows=True,
            font_size=6 if show_labels else 8)

    plt.title("AST Tree")
    plt.show()



In [None]:
# Inspect the attributes of one training graph.
train_graphs[1]

In [None]:
# Draw the same graph as a tree labelled with its AST node types.
visualize_graph(train_graphs[1], show_labels=True)
