In [8]:
import os
import sys
import json
import math
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import javalang
import multiprocessing as mp

# Avoid tokenizers fork warning by fixing parallelism behavior early.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

from transformers import AutoTokenizer, AutoModel

sys.path.insert(0, "E:/CodeBuggy")

from pipeline.gumtree_diff import GumTreeDiff, EditType

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import RGCNConv, global_mean_pool


torch.manual_seed(42)
np.random.seed(42)

In [9]:
DATA_PATH = "inputs/megadiff_single_function.parquet"
GUMTREE_PATH = "E:/CodeBuggy/gumtree-4.0.0-beta4/bin/gumtree.bat"
MODEL_NAME = "microsoft/graphcodebert-base"
OUTPUT_DIR = "output/rgcn_graphs"
MAX_SAMPLES = min(71150, 71150) # 71150 is num of samples
USE_CODE_EMBEDDINGS = True
NODE_TYPE_EMB_DIM = 64
DIFF_FEATURE_DIM = 6
LABEL_PARENTS = False

EMB_BATCH_SIZE = 8
SAVE_CHUNK_SIZE = 10000
USE_MULTIPROCESSING = True
NUM_WORKERS = max(1, (os.cpu_count() or 2) - 1)
MP_CHUNKSIZE = 8
os.makedirs(OUTPUT_DIR, exist_ok=True)

df = pd.read_parquet(DATA_PATH)

print(f"Loaded {len(df)} valid samples")
df.head()

Loaded 71150 valid samples


Unnamed: 0,diff,buggy_function,fixed_function
0,diff --git a/src/main/java/hudson/remoting/Pin...,"private void ping() throws IOException, In...","private void ping() throws IOException, In..."
1,diff --git a/choco-parser/src/main/java/parser...,public synchronized boolean newSol(int val...,public synchronized boolean newSol(int val...
2,diff --git a/opentripplanner-routing/src/main/...,public State traverse(State s0) {\n ...,public State traverse(State s0) {\n ...
3,diff --git a/src/java/davmail/ldap/LdapConnect...,public void run() {\n try {...,public void run() {\n try {...
4,diff --git a/src/com/orangeleap/tangerine/web/...,public void addListFieldsToMap(HttpServlet...,public void addListFieldsToMap(HttpServlet...


In [10]:
RELATIONS = [
    "AST_CHILD",
    "AST_PARENT",
    "CFG_NEXT",
    "CFG_TRUE",
    "CFG_FALSE",
    "CFG_LOOP",
    "DEF_USE",
    "USE_DEF",
    "DIFF_PARENT",
    "DIFF_SIBLING",
]
RELATION_TO_ID = {r: i for i, r in enumerate(RELATIONS)}

STATEMENT_NODES = {
    "StatementExpression",
    "ReturnStatement",
    "IfStatement",
    "ForStatement",
    "WhileStatement",
    "DoStatement",
    "SwitchStatement",
    "TryStatement",
    "ThrowStatement",
    "BreakStatement",
    "ContinueStatement",
    "BlockStatement",
}

WRAP_TEMPLATE = """public class Dummy {{
    {method_code}
}}"""


def wrap_method(method_code: str) -> str:
    return WRAP_TEMPLATE.format(method_code=method_code)


def iter_children(node):
    for child in node.children:
        if child is None:
            continue
        if isinstance(child, list):
            for item in child:
                if item is not None:
                    yield item
        else:
            yield child


def get_node_label(node) -> str | None:
    if hasattr(node, "name") and node.name:
        return str(node.name)
    if hasattr(node, "member") and node.member:
        return str(node.member)
    if hasattr(node, "value") and node.value is not None:
        return str(node.value)
    if hasattr(node, "operator") and node.operator:
        return str(node.operator)
    if hasattr(node, "type") and isinstance(node.type, str):
        return str(node.type)
    return None


def compute_line_offsets(code: str) -> list[int]:
    offsets = [0]
    for idx, ch in enumerate(code):
        if ch == "\n":
            offsets.append(idx + 1)
    return offsets


def line_col_to_offset(code: str, line: int, col: int) -> int | None:
    if line <= 0 or col <= 0:
        return None
    line_offsets = compute_line_offsets(code)
    if line > len(line_offsets):
        return None
    return line_offsets[line - 1] + (col - 1)


def offset_to_line_col(code: str, offset: int) -> tuple[int, int]:
    if offset < 0:
        return (1, 1)
    line_offsets = compute_line_offsets(code)
    line = 1
    for i, start in enumerate(line_offsets, 1):
        if start <= offset:
            line = i
        else:
            break
    col = offset - line_offsets[line - 1] + 1
    return (line, col)

In [11]:
def build_ast_graph(code: str):
    tree = javalang.parse.parse(wrap_method(code))

    nodes = []
    parents = []
    children = []
    id_to_index = {}

    def is_wrapper(node) -> bool:
        if isinstance(node, javalang.tree.CompilationUnit):
            return True
        if isinstance(node, javalang.tree.ClassDeclaration) and getattr(node, "name", None) == "Dummy":
            return True
        return False

    def visit(node, parent_idx: int | None):
        if not isinstance(node, javalang.tree.Node):
            return
        if is_wrapper(node):
            for child in iter_children(node):
                visit(child, parent_idx)
            return

        idx = len(nodes)
        id_to_index[id(node)] = idx

        label = get_node_label(node)
        line = None
        col = None
        if getattr(node, "position", None):
            line = node.position.line
            col = node.position.column
        start_pos = line_col_to_offset(code, line, col) if line and col else None
        end_pos = start_pos + len(label) if (start_pos is not None and label) else start_pos

        nodes.append(
            {
                "raw": node,
                "node_type": node.__class__.__name__,
                "label": label,
                "line": line,
                "col": col,
                "start_pos": start_pos,
                "end_pos": end_pos,
            }
        )
        parents.append(parent_idx)
        children.append([])
        if parent_idx is not None:
            children[parent_idx].append(idx)

        for child in iter_children(node):
            visit(child, idx)

    visit(tree, None)

    edges = []
    edge_types = []

    def add_edge(src: int, dst: int, rel: str):
        edges.append((src, dst))
        edge_types.append(RELATION_TO_ID[rel])

    # AST edges
    for parent_idx, child_list in enumerate(children):
        for child_idx in child_list:
            add_edge(parent_idx, child_idx, "AST_CHILD")
            add_edge(child_idx, parent_idx, "AST_PARENT")

    # CFG edges (heuristic)
    for parent_idx, child_list in enumerate(children):
        stmt_children = [c for c in child_list if nodes[c]["node_type"] in STATEMENT_NODES]
        for a, b in zip(stmt_children, stmt_children[1:]):
            add_edge(a, b, "CFG_NEXT")

    for idx, node in enumerate(nodes):
        node_type = node["node_type"]
        raw = node["raw"]

        if node_type == "IfStatement":
            then_node = getattr(raw, "then_statement", None)
            else_node = getattr(raw, "else_statement", None)
            then_idx = id_to_index.get(id(then_node))
            else_idx = id_to_index.get(id(else_node))
            if then_idx is not None:
                add_edge(idx, then_idx, "CFG_TRUE")
            if else_idx is not None:
                add_edge(idx, else_idx, "CFG_FALSE")

        if node_type in {"ForStatement", "WhileStatement", "DoStatement"}:
            body = getattr(raw, "body", None)
            body_idx = id_to_index.get(id(body))
            if body_idx is not None:
                add_edge(idx, body_idx, "CFG_LOOP")

    # DFG edges (simple def-use)
    last_def = {}

    def extract_assigned_name(raw_node) -> str | None:
        target = getattr(raw_node, "expressionl", None) or getattr(raw_node, "left", None)
        if target is None:
            return None
        if hasattr(target, "member") and target.member:
            return str(target.member)
        if hasattr(target, "name") and target.name:
            return str(target.name)
        return None

    for idx, node in enumerate(nodes):
        node_type = node["node_type"]
        label = node["label"]
        raw = node["raw"]

        if node_type in {"VariableDeclarator", "FormalParameter"} and label:
            last_def[label] = idx

        if node_type == "Assignment":
            assigned = extract_assigned_name(raw)
            if assigned:
                last_def[assigned] = idx

        if node_type == "MemberReference" and label:
            if label in last_def:
                def_idx = last_def[label]
                add_edge(def_idx, idx, "DEF_USE")
                add_edge(idx, def_idx, "USE_DEF")

    return nodes, parents, children, edges, edge_types

In [12]:
def match_actions_to_nodes(code: str, nodes: list[dict], parents: list[int | None], actions):
    action_map: dict[int, EditType] = {}

    for action in actions:
        node_type = action.node.node_type
        label = action.node.label
        pos = action.node.position
        line = None
        col = None
        if pos is not None:
            line, col = offset_to_line_col(code, pos[0])

        candidates = [i for i, n in enumerate(nodes) if n["node_type"] == node_type]
        if label:
            label_candidates = [i for i in candidates if nodes[i]["label"] == label]
            if label_candidates:
                candidates = label_candidates
        if line:
            line_candidates = [i for i in candidates if nodes[i]["line"] == line]
            if line_candidates:
                candidates = line_candidates

        if not candidates:
            continue

        if col is not None:
            candidates.sort(key=lambda i: abs((nodes[i]["col"] or col) - col))

        matched_idx = candidates[0]
        action_map[matched_idx] = action.action_type

    # subtree changed
    subtree_changed = [0] * len(nodes)
    for idx in action_map.keys():
        cur = parents[idx]
        while cur is not None:
            subtree_changed[cur] = 1
            cur = parents[cur]

    return action_map, subtree_changed


def build_diff_features(action_map: dict[int, EditType], subtree_changed: list[int], num_nodes: int):
    diff_feats = []
    labels = []
    for idx in range(num_nodes):
        action = action_map.get(idx)
        is_diff = 1 if action is not None else 0
        action_none = 1 if action is None else 0
        action_update = 1 if action == EditType.UPDATE else 0
        action_delete = 1 if action == EditType.DELETE else 0
        action_move = 1 if action == EditType.MOVE else 0
        diff_feats.append([
            is_diff,
            action_none,
            action_update,
            action_delete,
            action_move,
            subtree_changed[idx],
        ])
        labels.append(1 if action in {EditType.UPDATE, EditType.DELETE} else 0)
    return torch.tensor(diff_feats, dtype=torch.float), torch.tensor(labels, dtype=torch.long)


def add_diff_edges(children, action_map, edges, edge_types):
    changed_nodes = set(action_map.keys())
    for idx in changed_nodes:
        parent = None
        for p, child_list in enumerate(children):
            if idx in child_list:
                parent = p
                break
        if parent is not None:
            edges.append((idx, parent))
            edge_types.append(RELATION_TO_ID["DIFF_PARENT"])
            for sibling in children[parent]:
                if sibling == idx:
                    continue
                edges.append((idx, sibling))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])
                edges.append((sibling, idx))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if USE_CODE_EMBEDDINGS:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    encoder = AutoModel.from_pretrained(MODEL_NAME).to(device)
    encoder.eval()
else:
    tokenizer = None
    encoder = None


def compute_code_embeddings(code: str, nodes: list[dict], children: list[list[int]]):
    num_nodes = len(nodes)
    if not USE_CODE_EMBEDDINGS:
        return torch.zeros((num_nodes, 768), dtype=torch.float)

    enc = tokenizer(
        code,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )
    offsets = enc.pop("offset_mapping")[0].tolist()
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.inference_mode():
        outputs = encoder(**enc)
    token_embs = outputs.last_hidden_state[0].cpu()

    node_embs = torch.zeros((num_nodes, token_embs.shape[-1]), dtype=torch.float)

    for idx, node in enumerate(nodes):
        start = node["start_pos"]
        end = node["end_pos"]
        if start is None or end is None or start == end:
            continue
        token_idxs = [
            t for t, (s, e) in enumerate(offsets)
            if not (s == 0 and e == 0) and s < end and e > start
        ]
        if token_idxs:
            node_embs[idx] = token_embs[token_idxs].mean(dim=0)

    # Aggregate from children for non-leaf nodes
    for idx in reversed(range(num_nodes)):
        if torch.all(node_embs[idx] == 0) and children[idx]:
            child_embs = [node_embs[c] for c in children[idx] if torch.any(node_embs[c] != 0)]
            if child_embs:
                node_embs[idx] = torch.stack(child_embs).mean(dim=0)

    return node_embs


def compute_code_embeddings_batch(
    codes: list[str],
    nodes_list: list[list[dict]],
    children_list: list[list[list[int]]],
    batch_size: int = 8,
):
    if not USE_CODE_EMBEDDINGS:
        return [torch.zeros((len(nodes), 768), dtype=torch.float) for nodes in nodes_list]

    all_embs = []
    for start in range(0, len(codes), batch_size):
        batch_codes = codes[start : start + batch_size]
        batch_nodes = nodes_list[start : start + batch_size]
        batch_children = children_list[start : start + batch_size]

        enc = tokenizer(
            batch_codes,
            return_tensors="pt",
            return_offsets_mapping=True,
            truncation=True,
            max_length=512,
            padding=True,
        )
        offsets_batch = enc.pop("offset_mapping").tolist()
        enc = {k: v.to(device) for k, v in enc.items()}

        with torch.inference_mode():
            outputs = encoder(**enc)
        token_embs = outputs.last_hidden_state.cpu()

        for i, nodes in enumerate(batch_nodes):
            offsets = offsets_batch[i]
            token_embs_i = token_embs[i]
            num_nodes = len(nodes)

            node_embs = torch.zeros((num_nodes, token_embs_i.shape[-1]), dtype=torch.float)

            for idx, node in enumerate(nodes):
                start_pos = node["start_pos"]
                end_pos = node["end_pos"]
                if start_pos is None or end_pos is None or start_pos == end_pos:
                    continue
                token_idxs = [
                    t for t, (s, e) in enumerate(offsets)
                    if not (s == 0 and e == 0) and s < end_pos and e > start_pos
                ]
                if token_idxs:
                    node_embs[idx] = token_embs_i[token_idxs].mean(dim=0)

            # Aggregate from children for non-leaf nodes
            for idx in reversed(range(num_nodes)):
                if torch.all(node_embs[idx] == 0) and batch_children[i][idx]:
                    child_embs = [
                        node_embs[c]
                        for c in batch_children[i][idx]
                        if torch.any(node_embs[c] != 0)
                    ]
                    if child_embs:
                        node_embs[idx] = torch.stack(child_embs).mean(dim=0)

            all_embs.append(node_embs)

    return all_embs

Using device: cpu


Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
def build_node_type_vocab(codes: list[str]):
    node_types = set()
    for code in tqdm(codes, desc="Collect node types"):
        try:
            nodes, _, _, _, _ = build_ast_graph(code)
            node_types.update(n["node_type"] for n in nodes)
        except Exception:
            continue
    return {t: i for i, t in enumerate(sorted(node_types))}


# node_type_to_id = build_node_type_vocab(df["buggy_function"].tolist())
# if "UNK" not in node_type_to_id:
#     node_type_to_id["UNK"] = len(node_type_to_id)
# print(f"Node types: {len(node_type_to_id)}")

import joblib
# joblib.dump(node_type_to_id, "output/node_type_to_id.joblib")
node_type_to_id = joblib.load("output/node_type_to_id.joblib")
print(f"Node types: {len(node_type_to_id)}")


Node types: 59


In [15]:
def add_diff_edges(children, parents, action_map, edges, edge_types):
    changed_nodes = set(action_map.keys())
    for idx in changed_nodes:
        parent = parents[idx]
        if parent is not None:
            edges.append((idx, parent))
            edge_types.append(RELATION_TO_ID["DIFF_PARENT"])
            for sibling in children[parent]:
                if sibling == idx:
                    continue
                edges.append((idx, sibling))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])
                edges.append((sibling, idx))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])


def get_node_type_id(node_type: str) -> int:
    return node_type_to_id.get(node_type, node_type_to_id["UNK"])


gumtree_diff = GumTreeDiff(gumtree_path=GUMTREE_PATH)


def init_gumtree_worker(gumtree_path: str):
    global gumtree_diff
    gumtree_diff = GumTreeDiff(gumtree_path=gumtree_path)


def build_graph_parts_worker(args):
    buggy_code, fixed_code, method_id = args
    try:
        return build_graph_parts(buggy_code, fixed_code, method_id)
    except Exception:
        return None


def build_graph_parts(buggy_code: str, fixed_code: str, method_id: str):
    nodes, parents, children, edges, edge_types = build_ast_graph(buggy_code)

    diff_result = gumtree_diff.diff(buggy_code, fixed_code)
    action_map, subtree_changed = match_actions_to_nodes(buggy_code, nodes, parents, diff_result.actions)

    diff_feats, labels = build_diff_features(action_map, subtree_changed, len(nodes))
    add_diff_edges(children, parents, action_map, edges, edge_types)

    if LABEL_PARENTS:
        for idx, val in enumerate(labels.tolist()):
            if val == 1 and parents[idx] is not None:
                labels[parents[idx]] = 1

    node_type_ids = torch.tensor([get_node_type_id(n["node_type"]) for n in nodes], dtype=torch.long)

    return {
        "buggy_code": buggy_code,
        "nodes": nodes,
        "parents": parents,
        "children": children,
        "edges": edges,
        "edge_types": edge_types,
        "diff_feats": diff_feats,
        "labels": labels,
        "node_type_ids": node_type_ids,
        "method_id": method_id,
    }


def build_graph_sample(
    buggy_code: str,
    fixed_code: str,
    method_id: str,
    code_embs: torch.Tensor | None = None,
):
    parts = build_graph_parts(buggy_code, fixed_code, method_id)

    if code_embs is None:
        code_embs = compute_code_embeddings(
            parts["buggy_code"], parts["nodes"], parts["children"]
        )

    x = torch.cat([code_embs, parts["diff_feats"]], dim=1)

    if parts["edges"]:
        edge_index = torch.tensor(parts["edges"], dtype=torch.long).t().contiguous()
        edge_type = torch.tensor(parts["edge_types"], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_type = torch.empty((0,), dtype=torch.long)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_type=edge_type,
        y=parts["labels"],
        node_type_ids=parts["node_type_ids"],
        method_id=parts["method_id"],
    )
    return data

In [16]:
sample = df.iloc[0]
method_id = str(sample.get("method_id", 0))

example_graph = build_graph_sample(
    sample["buggy_function"],
    sample["fixed_function"],
    method_id=method_id,
)

print(example_graph)
print(f"x: {example_graph.x.shape}")
print(f"edge_index: {example_graph.edge_index.shape}")
print(f"edge_type: {example_graph.edge_type.shape}")
print(f"y: {example_graph.y.shape}, positives={int(example_graph.y.sum())}")

Data(x=[66, 774], edge_index=[2, 155], y=[66], edge_type=[155], node_type_ids=[66], method_id='0')
x: torch.Size([66, 774])
edge_index: torch.Size([2, 155])
edge_type: torch.Size([155])
y: torch.Size([66]), positives=1


In [17]:
# data_list = []

# subset = df.head(MAX_SAMPLES)
# batch_parts = []
# chunk_idx = 0
# num_saved = 0
# saved_files = []

# for idx, row in enumerate(
#     tqdm(subset.itertuples(index=False), total=len(subset), desc="Building graphs")
# ):
#     try:
#         method_id = str(getattr(row, "method_id", idx))
#         parts = build_graph_parts(row.buggy_function, row.fixed_function, method_id)
#         batch_parts.append(parts)

#         if len(batch_parts) >= EMB_BATCH_SIZE:
#             code_embs_list = compute_code_embeddings_batch(
#                 [p["buggy_code"] for p in batch_parts],
#                 [p["nodes"] for p in batch_parts],
#                 [p["children"] for p in batch_parts],
#                 batch_size=len(batch_parts),
#             )
#             for parts, code_embs in zip(batch_parts, code_embs_list):
#                 x = torch.cat([code_embs, parts["diff_feats"]], dim=1)

#                 if parts["edges"]:
#                     edge_index = torch.tensor(parts["edges"], dtype=torch.long).t().contiguous()
#                     edge_type = torch.tensor(parts["edge_types"], dtype=torch.long)
#                 else:
#                     edge_index = torch.empty((2, 0), dtype=torch.long)
#                     edge_type = torch.empty((0,), dtype=torch.long)

#                 data_list.append(
#                     Data(
#                         x=x,
#                         edge_index=edge_index,
#                         edge_type=edge_type,
#                         y=parts["labels"],
#                         node_type_ids=parts["node_type_ids"],
#                         method_id=parts["method_id"],
#                     )
#                 )
#             batch_parts = []

#         if len(data_list) >= SAVE_CHUNK_SIZE:
#             chunk_idx += 1
#             output_path = os.path.join(OUTPUT_DIR, f"rgcn_graphs_part{chunk_idx}.pt")
#             torch.save(data_list, output_path)
#             saved_files.append(output_path)
#             num_saved += len(data_list)
#             data_list = []
#     except Exception as exc:
#         print(f"Skip idx={idx}: {exc}")
#         continue

# if batch_parts:
#     code_embs_list = compute_code_embeddings_batch(
#         [p["buggy_code"] for p in batch_parts],
#         [p["nodes"] for p in batch_parts],
#         [p["children"] for p in batch_parts],
#         batch_size=len(batch_parts),
#     )
#     for parts, code_embs in zip(batch_parts, code_embs_list):
#         x = torch.cat([code_embs, parts["diff_feats"]], dim=1)

#         if parts["edges"]:
#             edge_index = torch.tensor(parts["edges"], dtype=torch.long).t().contiguous()
#             edge_type = torch.tensor(parts["edge_types"], dtype=torch.long)
#         else:
#             edge_index = torch.empty((2, 0), dtype=torch.long)
#             edge_type = torch.empty((0,), dtype=torch.long)

#         data_list.append(
#             Data(
#                 x=x,
#                 edge_index=edge_index,
#                 edge_type=edge_type,
#                 y=parts["labels"],
#                 node_type_ids=parts["node_type_ids"],
#                 method_id=parts["method_id"],
#             )
#         )
#     batch_parts = []

# if data_list:
#     chunk_idx += 1
#     output_path = os.path.join(OUTPUT_DIR, f"rgcn_graphs_part{chunk_idx}.pt")
#     torch.save(data_list, output_path)
#     saved_files.append(output_path)
#     num_saved += len(data_list)
#     data_list = []

# meta_path = os.path.join(OUTPUT_DIR, "rgcn_graphs_meta.json")
# with open(meta_path, "w", encoding="utf-8") as f:
#     json.dump(
#         {
#             "num_graphs": num_saved,
#             "relations": RELATIONS,
#             "relation_to_id": RELATION_TO_ID,
#             "node_type_to_id": node_type_to_id,
#             "node_type_emb_dim": NODE_TYPE_EMB_DIM,
#             "diff_feature_dim": DIFF_FEATURE_DIM,
#             "model_name": MODEL_NAME,
#             "parts": [os.path.basename(p) for p in saved_files],
#             "chunk_size": SAVE_CHUNK_SIZE,
#         },
#         f,
#         indent=2,
#         ensure_ascii=True,
#     )

# print(f"Saved {num_saved} graphs into {len(saved_files)} parts")
# print(f"Saved metadata to {meta_path}")

In [None]:
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import glob

candidates = sorted(glob.glob(os.path.join(OUTPUT_DIR, "rgcn_graphs_part*.pt")))
if not candidates:
    raise FileNotFoundError("No graph files found in output/rgcn_graphs")

data_list = []
for path in candidates:
    # Trust local files to load full objects (PyTorch 2.6 defaults to weights_only=True)
    part = torch.load(path, weights_only=False)
    data_list.extend(part)

print(f"Loaded {len(data_list)} graphs from {len(candidates)} files")

# Graph-level labels for bug detection
for g in data_list:
    g.graph_y = torch.tensor([1.0 if int(g.y.sum()) > 0 else 0.0], dtype=torch.float)

# Train/val/test split
num_total = len(data_list)
train_size = int(0.8 * num_total)
val_size = int(0.1 * num_total)
test_size = num_total - train_size - val_size
train_ds, val_ds, test_ds = random_split(
    data_list,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

BATCH_SIZE = 8
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

print(f"Split: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

Loaded 20096 graphs from 2 files


In [19]:
torch.save(data_list, 'output/megadiff_graphs_list.pt')

In [25]:
from torch_geometric.data import InMemoryDataset

# # data_list là list Data
# print(type(data_list), len(data_list))

# data, slices = InMemoryDataset.collate(data_list)
# torch.save({"data": data, "slices": slices}, "output/megadiff_graphs.pt")


In [26]:
# load lại
payload = torch.load("output/megadiff_graphs.pt", weights_only=False)
data, slices = payload["data"], payload["slices"]

class MegadiffGraphDataset(InMemoryDataset):
    def __init__(self, data, slices):
        super().__init__(".")
        self.data = data
        self.slices = slices

dataset = MegadiffGraphDataset(data, slices)
print(len(dataset), dataset[0])

20096 Data(x=[66, 774], edge_index=[2, 152], y=[66], edge_type=[152], node_type_ids=[66], method_id='0', graph_y=[1])


In [14]:
def count_labels(samples, label_attr: str):
    pos = 0
    total = 0
    for g in samples:
        labels = getattr(g, label_attr).view(-1)
        pos += int(labels.sum().item())
        total += int(labels.numel())
    neg = total - pos
    ratio = (neg / pos) if pos > 0 else float("inf")
    return pos, neg, total, ratio

# Overall (node-level + graph-level)
node_pos, node_neg, node_total, node_ratio = count_labels(data_list, "y")
graph_pos, graph_neg, graph_total, graph_ratio = count_labels(data_list, "graph_y")

print("Overall imbalance")
print(f"Node labels: pos={node_pos} neg={node_neg} total={node_total} neg/pos={node_ratio:.2f}")
print(f"Graph labels: pos={graph_pos} neg={graph_neg} total={graph_total} neg/pos={graph_ratio:.2f}")

# Per-split
for name, ds in [("train", train_ds), ("val", val_ds), ("test", test_ds)]:
    npos, nneg, ntotal, nratio = count_labels(ds, "y")
    gpos, gneg, gtotal, gratio = count_labels(ds, "graph_y")
    print(f"\n{name} imbalance")
    print(f"  Node: pos={npos} neg={nneg} total={ntotal} neg/pos={nratio:.2f}")
    print(f"  Graph: pos={gpos} neg={gneg} total={gtotal} neg/pos={gratio:.2f}")

Overall imbalance
Node labels: pos=3170 neg=5168901 total=5172071 neg/pos=1630.57
Graph labels: pos=2257 neg=17839 total=20096 neg/pos=7.90

train imbalance
  Node: pos=2567 neg=4126138 total=4128705 neg/pos=1607.38
  Graph: pos=1838 neg=14238 total=16076 neg/pos=7.75

val imbalance
  Node: pos=313 neg=519548 total=519861 neg/pos=1659.90
  Graph: pos=214 neg=1795 total=2009 neg/pos=8.39

test imbalance
  Node: pos=290 neg=523215 total=523505 neg/pos=1804.19
  Graph: pos=205 neg=1806 total=2011 neg/pos=8.81


In [12]:
from torch_geometric.nn.conv import RGCNConv

class RGCNDetector(torch.nn.Module):
    def __init__(
        self,
        base_in_dim: int,
        hidden_dim: int,
        num_relations: int,
        num_node_types: int,
        node_type_emb_dim: int,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.base_in_dim = base_in_dim
        self.node_type_emb_dim = node_type_emb_dim
        self.node_type_emb = torch.nn.Embedding(num_node_types, node_type_emb_dim)
        conv_in_dim = base_in_dim + node_type_emb_dim

        self.convs = torch.nn.ModuleList()
        self.convs.append(RGCNConv(conv_in_dim, hidden_dim, num_relations=num_relations))
        for _ in range(num_layers - 1):
            self.convs.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relations))
        self.dropout = torch.nn.Dropout(dropout)
        self.node_head = torch.nn.Linear(hidden_dim, 1)
        self.graph_head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x, edge_index, edge_type, batch = data.x, data.edge_index, data.edge_type, data.batch

        if x.shape[1] == self.base_in_dim:
            node_type_feats = self.node_type_emb(data.node_type_ids)
            x = torch.cat([x, node_type_feats], dim=1)
        elif x.shape[1] != self.base_in_dim + self.node_type_emb_dim:
            raise ValueError(
                f"Unexpected x dim {x.shape[1]} (expected {self.base_in_dim} or {self.base_in_dim + self.node_type_emb_dim})"
            )

        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            x = torch.relu(x)
            x = self.dropout(x)
        node_logits = self.node_head(x).squeeze(-1)
        graph_emb = global_mean_pool(x, batch)
        graph_logits = self.graph_head(graph_emb).squeeze(-1)
        return node_logits, graph_logits


expected_base_in_dim = 768 + DIFF_FEATURE_DIM
raw_in_dim = data_list[0].x.shape[1]
base_in_dim = (
    expected_base_in_dim
    if raw_in_dim in {expected_base_in_dim, expected_base_in_dim + NODE_TYPE_EMB_DIM}
    else raw_in_dim
)
HIDDEN_DIM = 256
model = RGCNDetector(
    base_in_dim,
    HIDDEN_DIM,
    num_relations=len(RELATIONS),
    num_node_types=len(node_type_to_id),
    node_type_emb_dim=NODE_TYPE_EMB_DIM,
).to(device)

print(model)

RGCNDetector(
  (node_type_emb): Embedding(59, 64)
  (convs): ModuleList(
    (0): RGCNConv(838, 256, num_relations=10)
    (1): RGCNConv(256, 256, num_relations=10)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (node_head): Linear(in_features=256, out_features=1, bias=True)
  (graph_head): Linear(in_features=256, out_features=1, bias=True)
)


In [13]:
def compute_pos_weight(samples, label_attr: str):
    labels = []
    for g in samples:
        labels.append(getattr(g, label_attr).view(-1).float())
    all_labels = torch.cat(labels, dim=0)
    pos = all_labels.sum().item()
    neg = all_labels.numel() - pos
    if pos == 0:
        return torch.tensor(1.0)
    return torch.tensor(neg / pos)


node_pos_weight = compute_pos_weight(train_ds, "y")
graph_pos_weight = compute_pos_weight(train_ds, "graph_y")

node_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=node_pos_weight.to(device))
graph_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=graph_pos_weight.to(device))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)


def step_metrics(logits, labels):
    probs = torch.sigmoid(logits)
    preds = (probs > 0.5).long()
    labels = labels.long()
    tp = ((preds == 1) & (labels == 1)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    tn = ((preds == 0) & (labels == 0)).sum().item()
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    acc = (tp + tn) / max(tp + tn + fp + fn, 1)
    return precision, recall, f1, acc


def run_epoch(loader, is_train: bool):
    model.train() if is_train else model.eval()
    total_loss = 0.0
    node_stats = []
    graph_stats = []
    with torch.set_grad_enabled(is_train):
        for batch in loader:
            batch = batch.to(device)
            node_logits, graph_logits = model(batch)
            node_labels = batch.y.float()
            graph_labels = batch.graph_y.view(-1).float()

            node_loss = node_criterion(node_logits, node_labels)
            graph_loss = graph_criterion(graph_logits, graph_labels)
            loss = node_loss + graph_loss

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            node_stats.append(step_metrics(node_logits.detach().cpu(), node_labels.detach().cpu()))
            graph_stats.append(step_metrics(graph_logits.detach().cpu(), graph_labels.detach().cpu()))

    node_metrics = np.mean(node_stats, axis=0)
    graph_metrics = np.mean(graph_stats, axis=0)
    return total_loss / max(len(loader), 1), node_metrics, graph_metrics


EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    train_loss, train_node, train_graph = run_epoch(train_loader, True)
    val_loss, val_node, val_graph = run_epoch(val_loader, False)
    print(
        f"Epoch {epoch:02d} | "
        f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} | "
        f"node_f1={val_node[2]:.4f} node_acc={val_node[3]:.4f} | "
        f"graph_f1={val_graph[2]:.4f} graph_acc={val_graph[3]:.4f}"
    )

KeyboardInterrupt: 

In [None]:
test_loss, test_node, test_graph = run_epoch(test_loader, False)
print(
    f"Test | loss={test_loss:.4f} | "
    f"node_precision={test_node[0]:.4f} node_recall={test_node[1]:.4f} node_f1={test_node[2]:.4f} node_acc={test_node[3]:.4f} | "
    f"graph_precision={test_graph[0]:.4f} graph_recall={test_graph[1]:.4f} graph_f1={test_graph[2]:.4f} graph_acc={test_graph[3]:.4f}"
)

Test | loss=4.0097 | node_precision=0.0278 node_recall=0.2500 node_f1=0.0500 node_acc=0.9912 | graph_precision=0.0000 graph_recall=0.0000 graph_f1=0.0000 graph_acc=0.8125


In [None]:
MODEL_OUT = os.path.join(OUTPUT_DIR, "rgcn_detector.pt")
torch.save(
    {
        "model_state": model.state_dict(),
        "base_in_dim": base_in_dim,
        "hidden_dim": HIDDEN_DIM,
        "relations": RELATIONS,
        "node_type_to_id": node_type_to_id,
    },
    MODEL_OUT,
)
print(f"Saved model to {MODEL_OUT}")

Saved model to output/rgcn_graphs\rgcn_detector.pt
