In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
from data.dataset import CoDeTM4

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train, val, test  = CoDeTM4('../../data/').get_dataset(['train','val','test'], columns='all', dynamic_split_sizing=False)

In [4]:
from datasets import concatenate_datasets

In [5]:
codet = concatenate_datasets([train, val, test])

In [6]:
import tree_sitter_python as tspython
import tree_sitter_cpp as tscpp
import tree_sitter_java as tsjava
from tree_sitter import Parser, Language

TS_PYTHON = Language(tspython.language())
TS_JAVA = Language(tsjava.language())
TS_CPP = Language(tscpp.language())

PYTHON_PARSER, JAVA_PARSER, CPP_PARSER = Parser(language=TS_PYTHON), Parser(language=TS_JAVA), Parser(language=TS_CPP)

In [7]:
# Install and import UniXCoder
import torch
from transformers import RobertaTokenizer, RobertaModel

# Load UniXCoder model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = RobertaTokenizer.from_pretrained("microsoft/unixcoder-base")
model = RobertaModel.from_pretrained("microsoft/unixcoder-base")
model.to(device)
model.eval()

print(f"UniXCoder model loaded successfully")

Using device: cuda
UniXCoder model loaded successfully
UniXCoder model loaded successfully


In [8]:
def get_parser(language):
    match language:
        case 'python':
            return PYTHON_PARSER
        case 'java':
            return JAVA_PARSER
        case 'cpp':
            return CPP_PARSER
    raise ValueError(f"Unsupported language: {language}")

In [9]:
def create_tree(sample, code_key='cleaned_code'):
    language = sample['language']
    parser = get_parser(language)
    tree = parser.parse(bytes(sample[code_key], 'utf-8'))
    return tree

In [10]:
def get_node_text(node, source_code_bytes):
    """Extract the text content of a node from the source code."""
    return source_code_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='ignore')

In [11]:
# Create a mapping for root node types and special embeddings
def create_root_embedding(node, language):
    """Create a special embedding for root nodes based on node type and language."""
    # Create a simple text representation for the root
    root_text = f"<ROOT_{language.upper()}_{node.type}>"
    
    # Generate embedding for this special token
    inputs = tokenizer(root_text, return_tensors="pt", max_length=32, 
                      truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :].squeeze().cpu()
    
    return embedding

In [12]:
def get_unixcoder_embedding(text, max_length=512, is_root=False, node=None, language=None):
    """Generate UniXCoder embedding for a given text."""
    if is_root and node is not None and language is not None:
        # For root nodes, create a special embedding based on node type and language
        return create_root_embedding(node, language)
    
    if not text.strip():
        # Return zero embedding for empty text
        return torch.zeros(768)  # UniXCoder base has 768 dimensions
    
    # Tokenize and truncate if necessary
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, 
                      truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        # Use the [CLS] token embedding (first token)
        embedding = outputs.last_hidden_state[:, 0, :].squeeze().cpu()
    
    return embedding

In [13]:
def get_unixcoder_embeddings_batch(texts, max_length=512, batch_size=32):
    """Generate UniXCoder embeddings for a batch of texts efficiently."""
    if not texts:
        return []
    
    all_embeddings = []
    
    # Process in batches to avoid memory issues
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        
        # Filter out empty texts and keep track of indices
        non_empty_texts = []
        text_indices = []
        for j, text in enumerate(batch_texts):
            if text.strip():
                non_empty_texts.append(text)
                text_indices.append(j)
        
        if not non_empty_texts:
            # All texts in this batch are empty
            batch_embeddings = [torch.zeros(768) for _ in batch_texts]
        else:
            # Batch tokenize non-empty texts
            inputs = tokenizer(non_empty_texts, return_tensors="pt", max_length=max_length, 
                              truncation=True, padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model(**inputs)
                embeddings = outputs.last_hidden_state[:, 0, :].cpu()
            
            # Create result list with zeros for empty texts
            batch_embeddings = []
            embedding_idx = 0
            for j, text in enumerate(batch_texts):
                if j in text_indices:
                    batch_embeddings.append(embeddings[embedding_idx])
                    embedding_idx += 1
                else:
                    batch_embeddings.append(torch.zeros(768))
        
        all_embeddings.extend(batch_embeddings)
    
    return all_embeddings

In [14]:
from torch_geometric.data import Data
from typing import List, Tuple, Dict
from torch import tensor, long as tlong
from tree_sitter import TreeCursor

def tree_to_graph_with_embeddings(cursor: TreeCursor, source_code_bytes: bytes, 
                                 language: str,
                                 id_map: Dict = None, next_id: int = 0, 
                                 edges: List[Tuple[int, int]] = None,
                                 node_embeddings: List = None,
                                 is_root_call: bool = True) -> Tuple[List[Tuple[int, int]], Dict, int, List]:
    if edges is None:
        edges = []
    if id_map is None:
        id_map = {}
    if node_embeddings is None:
        node_embeddings = []

    # Assign ID to current node
    if cursor.node not in id_map:
        id_map[cursor.node] = next_id
        # Get node text and generate embedding
        node_text = get_node_text(cursor.node, source_code_bytes)
        # Check if this is the root node (first call to the function)
        if is_root_call:
            embedding = get_unixcoder_embedding(node_text, is_root=True, node=cursor.node, language=language)
        else:
            embedding = get_unixcoder_embedding(node_text, is_root=False)
        node_embeddings.append(embedding)
        next_id += 1
    
    current_id = id_map[cursor.node]

    if cursor.goto_first_child():
        # Process first child
        if cursor.node not in id_map:
            id_map[cursor.node] = next_id
            node_text = get_node_text(cursor.node, source_code_bytes)
            embedding = get_unixcoder_embedding(node_text, is_root=False)
            node_embeddings.append(embedding)
            next_id += 1
        child_id = id_map[cursor.node]
        edges.append((current_id, child_id))
        edges, id_map, next_id, node_embeddings = tree_to_graph_with_embeddings(
            cursor, source_code_bytes, language, id_map, next_id, edges, node_embeddings, is_root_call=False)
        
        # Process siblings
        while cursor.goto_next_sibling():
            if cursor.node not in id_map:
                id_map[cursor.node] = next_id
                node_text = get_node_text(cursor.node, source_code_bytes)
                embedding = get_unixcoder_embedding(node_text, is_root=False)
                node_embeddings.append(embedding)
                next_id += 1
            child_id = id_map[cursor.node]
            edges.append((current_id, child_id))
            edges, id_map, next_id, node_embeddings = tree_to_graph_with_embeddings(
                cursor, source_code_bytes, language, id_map, next_id, edges, node_embeddings, is_root_call=False)
        
        cursor.goto_parent()

    return edges, id_map, next_id, node_embeddings

In [15]:
from tqdm import tqdm

In [16]:
def create_graph_with_unixcoder(sample):
    tree = create_tree(sample)
    source_code_bytes = sample['cleaned_code'].encode('utf-8')
    language = sample['language']
    
    edges, id_map, _, node_embeddings = tree_to_graph_with_embeddings(
        tree.walk(), source_code_bytes, language)
    
    edge_index = tensor(edges, dtype=tlong).t().contiguous() if edges else tensor([], dtype=tlong).reshape(2, 0)
    
    # Stack node embeddings to create feature matrix
    if node_embeddings:
        x = torch.stack(node_embeddings)
    else:
        # Fallback for empty graphs
        x = torch.zeros((1, 768))  # UniXCoder embedding dimension
    
    y = tensor([sample['target_binary']], dtype=tlong)
    graph_features = tensor(list(sample['features'].values()))
    
    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }
    
    data = Data(
        x=x, 
        y=y, 
        edge_index=edge_index, 
        graph_features=graph_features,
        metadata=metadata
    )
    return data

In [17]:
def collect_nodes_and_texts(cursor, source_code_bytes, language, nodes_info=None, is_root_call=True):
    """Collect all nodes and their texts in a single traversal."""
    if nodes_info is None:
        nodes_info = []
    
    node_text = get_node_text(cursor.node, source_code_bytes)
    if is_root_call:
        # Special handling for root
        root_text = f"<ROOT_{language.upper()}_{cursor.node.type}>"
        nodes_info.append((cursor.node, root_text, True))
    else:
        nodes_info.append((cursor.node, node_text, False))
    
    if cursor.goto_first_child():
        collect_nodes_and_texts(cursor, source_code_bytes, language, nodes_info, False)
        while cursor.goto_next_sibling():
            collect_nodes_and_texts(cursor, source_code_bytes, language, nodes_info, False)
        cursor.goto_parent()
    
    return nodes_info

def build_graph_structure(cursor, id_map=None, next_id=0, edges=None):
    """Build the graph structure (edges and node mapping) without embeddings."""
    if edges is None:
        edges = []
    if id_map is None:
        id_map = {}

    # Assign ID to current node
    if cursor.node not in id_map:
        id_map[cursor.node] = next_id
        next_id += 1
    
    current_id = id_map[cursor.node]

    if cursor.goto_first_child():
        # Process first child
        if cursor.node not in id_map:
            id_map[cursor.node] = next_id
            next_id += 1
        child_id = id_map[cursor.node]
        edges.append((current_id, child_id))
        edges, id_map, next_id = build_graph_structure(cursor, id_map, next_id, edges)
        
        # Process siblings
        while cursor.goto_next_sibling():
            if cursor.node not in id_map:
                id_map[cursor.node] = next_id
                next_id += 1
            child_id = id_map[cursor.node]
            edges.append((current_id, child_id))
            edges, id_map, next_id = build_graph_structure(cursor, id_map, next_id, edges)
        
        cursor.goto_parent()

    return edges, id_map, next_id

def create_graph_with_unixcoder_batch(sample):
    """Optimized graph creation using batch processing for embeddings."""
    tree = create_tree(sample)
    source_code_bytes = sample['cleaned_code'].encode('utf-8')
    language = sample['language']
    
    # First pass: collect all nodes and their texts
    nodes_info = collect_nodes_and_texts(tree.walk(), source_code_bytes, language)
    
    # Extract texts for batch processing
    texts = [text for _, text, _ in nodes_info]
    
    # Generate embeddings in batch
    embeddings = get_unixcoder_embeddings_batch(texts)
    
    # Second pass: build graph structure
    edges, id_map, _ = build_graph_structure(tree.walk())
    
    # Create node feature matrix using batch embeddings
    # Sort nodes by their IDs to ensure correct ordering
    node_embedding_pairs = []
    for i, (node, _, _) in enumerate(nodes_info):
        node_id = id_map[node]
        node_embedding_pairs.append((node_id, embeddings[i]))
    
    # Sort by node ID and extract embeddings in correct order
    node_embedding_pairs.sort(key=lambda x: x[0])
    ordered_embeddings = [emb for _, emb in node_embedding_pairs]
    
    edge_index = tensor(edges, dtype=tlong).t().contiguous() if edges else tensor([], dtype=tlong).reshape(2, 0)
    x = torch.stack(ordered_embeddings) if ordered_embeddings else torch.zeros((1, 768))
    
    y = tensor([sample['target_binary']], dtype=tlong)
    graph_features = tensor(list(sample['features'].values()))
    
    metadata = {
        'language': sample['language'],
        'target': sample['target'],
        'target_binary': sample['target_binary'],
        'code': sample['code'],
        'cleaned_code': sample['cleaned_code']
    }
    
    return Data(x=x, y=y, edge_index=edge_index, graph_features=graph_features, metadata=metadata)

In [18]:
def create_graphs_with_unixcoder_optimized(dataset, desc_keyword):
    """Create graphs using optimized batch processing for embeddings."""
    graphs = []

    for i, sample in enumerate(tqdm(dataset, desc=f'Creating {desc_keyword} graphs with UniXCoder (optimized)')):
        try:
            data = create_graph_with_unixcoder_batch(sample)
            graphs.append(data)
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            # Skip problematic samples
            continue

    return graphs

# Keep the old function for comparison
def create_graphs_with_unixcoder(dataset, desc_keyword):
    """Original (slower) graph creation function."""
    graphs = []

    for i, sample in enumerate(tqdm(dataset, desc=f'Creating {desc_keyword} graphs with UniXCoder')):
        try:
            data = create_graph_with_unixcoder(sample)
            graphs.append(data)
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            # Skip problematic samples
            continue

    return graphs

In [19]:
# Test the optimization on a small subset
import time

print("Testing optimization on small subset...")
test_subset = train.select(range(10))  # Test with 10 samples

# Test original method
print("Testing original method...")
start_time = time.time()
graphs_original = create_graphs_with_unixcoder(test_subset, 'test_original')
original_time = time.time() - start_time

# Test optimized method
print("Testing optimized method...")
start_time = time.time()
graphs_optimized = create_graphs_with_unixcoder_optimized(test_subset, 'test_optimized')
optimized_time = time.time() - start_time

print(f"\nPerformance comparison:")
print(f"Original method: {original_time:.2f} seconds for {len(graphs_original)} graphs")
print(f"Optimized method: {optimized_time:.2f} seconds for {len(graphs_optimized)} graphs")
print(f"Speedup: {original_time/optimized_time:.2f}x")

# Verify that results are similar
if len(graphs_original) > 0 and len(graphs_optimized) > 0:
    print(f"\nVerification:")
    print(f"Original graph shape: {graphs_original[0].x.shape}")
    print(f"Optimized graph shape: {graphs_optimized[0].x.shape}")
    print(f"Shapes match: {graphs_original[0].x.shape == graphs_optimized[0].x.shape}")

del graphs_original, graphs_optimized  # Free memory

Testing optimization on small subset...
Testing original method...


Creating test_original graphs with UniXCoder: 100%|██████████| 10/10 [00:28<00:00,  2.82s/it]


Testing optimized method...


Creating test_optimized graphs with UniXCoder (optimized): 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Performance comparison:
Original method: 28.23 seconds for 10 graphs
Optimized method: 6.14 seconds for 10 graphs
Speedup: 4.60x

Verification:
Original graph shape: torch.Size([2067, 768])
Optimized graph shape: torch.Size([2067, 768])
Shapes match: True





In [20]:
# Create directories if they don't exist
import os
os.makedirs('../../data/codet_graphs', exist_ok=True)

from torch_geometric.data import Data
from torch import save

# Option to process subset first (change to False for full dataset)
USE_SUBSET = False
SUBSET_SIZE = 5000  # Process first 5000 samples for testing

if USE_SUBSET:
    print(f"Processing subset of {SUBSET_SIZE} samples first...")
    train_subset = train.select(range(min(SUBSET_SIZE, len(train))))
    val_subset = val.select(range(min(SUBSET_SIZE, len(val))))  
    test_subset = test.select(range(min(SUBSET_SIZE, len(test))))
    
    print("Creating train graphs with UniXCoder embeddings (subset)...")
    train_graphs_unixcoder = create_graphs_with_unixcoder_optimized(train_subset, 'train_subset')
    save(train_graphs_unixcoder, '../../data/codet_graphs/train_graphs_unixcoder_subset.pt')
    print(f"Saved {len(train_graphs_unixcoder)} train graphs (subset)")
    del train_subset, train_graphs_unixcoder
    
    print("Creating validation graphs with UniXCoder embeddings (subset)...")
    val_graphs_unixcoder = create_graphs_with_unixcoder_optimized(val_subset, 'val_subset')
    save(val_graphs_unixcoder, '../../data/codet_graphs/val_graphs_unixcoder_subset.pt')
    print(f"Saved {len(val_graphs_unixcoder)} validation graphs (subset)")
    del val_subset, val_graphs_unixcoder
    
    print("Creating test graphs with UniXCoder embeddings (subset)...")
    test_graphs_unixcoder = create_graphs_with_unixcoder_optimized(test_subset, 'test_subset')
    save(test_graphs_unixcoder, '../../data/codet_graphs/test_graphs_unixcoder_subset.pt')
    print(f"Saved {len(test_graphs_unixcoder)} test graphs (subset)")
    del test_subset, test_graphs_unixcoder
    
    print("Subset processing completed!")
    
else:
    print("Creating train graphs with UniXCoder embeddings (full dataset - optimized)...")
    train_graphs_unixcoder = create_graphs_with_unixcoder_optimized(train, 'train')
    save(train_graphs_unixcoder, '../../data/codet_graphs/train_graphs_unixcoder.pt')
    print(f"Saved {len(train_graphs_unixcoder)} train graphs")
    del train, train_graphs_unixcoder
    
    print("Creating validation graphs with UniXCoder embeddings (full dataset - optimized)...")
    val_graphs_unixcoder = create_graphs_with_unixcoder_optimized(val, 'val')
    save(val_graphs_unixcoder, '../../data/codet_graphs/val_graphs_unixcoder.pt')
    print(f"Saved {len(val_graphs_unixcoder)} validation graphs")
    del val, val_graphs_unixcoder
    
    print("Creating test graphs with UniXCoder embeddings (full dataset - optimized)...")
    test_graphs_unixcoder = create_graphs_with_unixcoder_optimized(test, 'test')
    save(test_graphs_unixcoder, '../../data/codet_graphs/test_graphs_unixcoder.pt')
    print(f"Saved {len(test_graphs_unixcoder)} test graphs")
    del test, test_graphs_unixcoder
    
    print("All UniXCoder graph datasets have been created and saved!")

Creating train graphs with UniXCoder embeddings (full dataset - optimized)...


Creating train graphs with UniXCoder (optimized):   0%|          | 161/405069 [00:49<34:30:21,  3.26it/s]



KeyboardInterrupt: 

In [None]:
# Test loading the created graphs
from torch import load

print("Testing loading of created graphs...")
train_graphs_unixcoder = load('../../data/codet_graphs/train_graphs_unixcoder.pt', weights_only=False)
print(f"Loaded {len(train_graphs_unixcoder)} training graphs")
print(f"First graph shape: {train_graphs_unixcoder[0].x.shape}")
print(f"Embedding dimension: {train_graphs_unixcoder[0].x.shape[1]}")

val_graphs_unixcoder = load('../../data/codet_graphs/val_graphs_unixcoder.pt', weights_only=False)
print(f"Loaded {len(val_graphs_unixcoder)} validation graphs")

test_graphs_unixcoder = load('../../data/codet_graphs/test_graphs_unixcoder.pt', weights_only=False)
print(f"Loaded {len(test_graphs_unixcoder)} test graphs")

print("All graphs loaded successfully!")

In [None]:
# Visualize a sample graph with UniXCoder embeddings
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_unixcoder_graph(data, figsize=(15, 10)):
    # Convert to NetworkX graph
    G = to_networkx(data, to_undirected=False)
    
    plt.figure(figsize=figsize)
    
    # Create tree layout manually
    pos = {}
    
    # Find root (node with no incoming edges)
    root = 0
    for node in G.nodes():
        if G.in_degree(node) == 0:
            root = node
            break
    
    # Assign positions level by level
    levels = {}
    queue = [(root, 0)]
    
    while queue:
        node, level = queue.pop(0)
        levels[node] = level
        
        # Add children to next level
        for child in G.successors(node):
            queue.append((child, level + 1))
    
    # Group nodes by level
    level_groups = {}
    for node, level in levels.items():
        if level not in level_groups:
            level_groups[level] = []
        level_groups[level].append(node)
    
    # Position nodes
    for level, nodes in level_groups.items():
        for i, node in enumerate(nodes):
            x = i - len(nodes) / 2  # Center nodes horizontally
            y = -level  # Higher levels at top
            pos[node] = (x, y)
    
    nx.draw(G, pos, with_labels=True, 
            node_color='lightgreen', node_size=300, arrows=True, 
            font_size=8)
    
    plt.title(f"AST Tree with UniXCoder Embeddings\nNodes: {data.x.shape[0]}, Embedding dim: {data.x.shape[1]}")
    plt.show()
    
    # Print some metadata
    print(f"Language: {data.metadata['language']}")
    print(f"Target: {data.metadata['target']}")
    print(f"Code snippet (first 200 chars):\n{data.metadata['cleaned_code'][:200]}...")

# Visualize the first graph
if len(train_graphs_unixcoder) > 0:
    visualize_unixcoder_graph(train_graphs_unixcoder[0])