In [1]:
import os
import joblib
import numpy as np
import torch
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import RGCNConv

# Seed torch (all devices) and NumPy's legacy global RNG so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Feature / model widths.
NODE_TYPE_EMB_DIM = 64  # width of the learned node-type embedding
DIFF_FEATURE_DIM = 6    # extra feature columns added to the 768-wide base features (see expected_base_in_dim)

# All artifacts (history, checkpoint) are written under this directory.
OUTPUT_DIR = "output/rgcn_graphs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU when one is visible to torch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

Using device: cpu


In [None]:
# Edge relation vocabulary. len(RELATIONS) sets RGCNConv's num_relations;
# NOTE(review): presumably list index == edge_type id in the saved graphs — confirm
# against the graph-building code before reordering.
RELATIONS = [
    # syntax tree
    "AST_CHILD", "AST_PARENT",
    # control flow
    "CFG_NEXT", "CFG_TRUE", "CFG_FALSE", "CFG_LOOP",
    # data flow
    "DEF_USE", "USE_DEF",
    # diff context
    "DIFF_PARENT", "DIFF_SIBLING",
]


# NOTE(review): both loaders unpickle their input (joblib.load, and torch.load with
# weights_only=False), which can execute arbitrary code — only point these at
# artifacts produced by this project.
NODE_TYPE_VOCAB_PATH = "node_type_to_id.joblib"
GRAPHS_PATH = "output/megadiff_graphs.pt"

node_type_to_id = joblib.load(NODE_TYPE_VOCAB_PATH)  # node-type name -> embedding index (presumably)
data_list = torch.load(GRAPHS_PATH, weights_only=False)  # list of graph objects used below


In [None]:
# Graph-level labels for bug detection
# Graph-level label for bug detection: 1.0 when the graph has any positive node label.
for graph in data_list:
    has_buggy_node = int(graph.y.sum()) > 0
    graph.graph_y = torch.tensor([1.0 if has_buggy_node else 0.0], dtype=torch.float)

# 80/10/10 train/val/test split; the remainder after flooring goes to test.
num_total = len(data_list)
train_size, val_size = (int(frac * num_total) for frac in (0.8, 0.1))
test_size = num_total - (train_size + val_size)
split_rng = torch.Generator().manual_seed(42)  # fixed seed => identical split every run
train_ds, val_ds, test_ds = random_split(
    data_list, [train_size, val_size, test_size], generator=split_rng
)

BATCH_SIZE = 64
train_loader, val_loader, test_loader = (
    DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True),  # only training data is shuffled
    DataLoader(val_ds, batch_size=BATCH_SIZE),
    DataLoader(test_ds, batch_size=BATCH_SIZE),
)

print(f"Split: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

In [None]:
def count_labels(samples, label_attr: str):
    """Tally binary labels stored on each sample under ``label_attr``.

    Every sample's label tensor is flattened before counting.

    Returns:
        ``(positives, negatives, total, negatives / positives)``; the ratio is
        ``inf`` when there are no positives.
    """
    positives = total = 0
    for sample in samples:
        flat = getattr(sample, label_attr).view(-1)
        positives += int(flat.sum().item())
        total += int(flat.numel())
    negatives = total - positives
    if positives <= 0:
        return positives, negatives, total, float("inf")
    return positives, negatives, total, negatives / positives

# Class balance over the whole dataset, node-level and graph-level.
node_pos, node_neg, node_total, node_ratio = count_labels(data_list, "y")
graph_pos, graph_neg, graph_total, graph_ratio = count_labels(data_list, "graph_y")

print("overall imbalance")
print(f"  Node labels: pos={node_pos} neg={node_neg} total={node_total} neg/pos={node_ratio:.2f}")
print(f"  Graph labels: pos={graph_pos} neg={graph_neg} total={graph_total} neg/pos={graph_ratio:.2f}")

# Same breakdown per split, to check the random split kept the balance comparable.
split_datasets = {"train": train_ds, "val": val_ds, "test": test_ds}
for split_name, split_ds in split_datasets.items():
    n_pos, n_neg, n_total, n_ratio = count_labels(split_ds, "y")
    g_pos, g_neg, g_total, g_ratio = count_labels(split_ds, "graph_y")
    print(f"\n{split_name} imbalance")
    print(f"  Node: pos={n_pos} neg={n_neg} total={n_total} neg/pos={n_ratio:.2f}")
    print(f"  Graph: pos={g_pos} neg={g_neg} total={g_total} neg/pos={g_ratio:.2f}")

In [None]:
class RGCNDetector(torch.nn.Module):
    """Relational GCN with a per-node bug head and a per-graph bug head.

    ``data.x`` may arrive in one of two widths:
      * ``base_in_dim``: a learned node-type embedding (looked up from
        ``data.node_type_ids``) is concatenated before the first layer;
      * ``base_in_dim + node_type_emb_dim``: used as-is (node-type block
        already present).
    Any other width raises ``ValueError`` in ``forward``.

    Args:
        base_in_dim: width of ``data.x`` without node-type features.
        hidden_dim: output width of every RGCN layer.
        num_relations: number of edge types accepted via ``data.edge_type``.
        num_node_types: vocabulary size of the node-type embedding.
        node_type_emb_dim: width of the node-type embedding.
        num_layers: number of stacked ``RGCNConv`` layers; must be >= 1.
        dropout: dropout probability applied after every layer.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(
        self,
        base_in_dim: int,
        hidden_dim: int,
        num_relations: int,
        num_node_types: int,
        node_type_emb_dim: int,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        if num_layers < 1:
            # Zero layers would skip message passing entirely, and the heads
            # (built for hidden_dim inputs) would be fed conv_in_dim-wide features.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.base_in_dim = base_in_dim
        self.node_type_emb_dim = node_type_emb_dim
        self.node_type_emb = torch.nn.Embedding(num_node_types, node_type_emb_dim)
        conv_in_dim = base_in_dim + node_type_emb_dim

        self.convs = torch.nn.ModuleList()
        self.convs.append(RGCNConv(conv_in_dim, hidden_dim, num_relations=num_relations))
        for _ in range(num_layers - 1):
            self.convs.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relations))
        self.dropout = torch.nn.Dropout(dropout)
        self.node_head = torch.nn.Linear(hidden_dim, 1)
        self.graph_head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, data):
        """Return ``(node_logits, graph_logits)`` for a (batched) graph.

        ``node_logits`` has one entry per node and ``graph_logits`` one entry
        per graph in ``data.batch`` (both are squeezed from a width-1 Linear).

        Raises:
            ValueError: if ``data.x`` has an unexpected feature width.
        """
        x, edge_index, edge_type, batch = data.x, data.edge_index, data.edge_type, data.batch

        if x.shape[1] == self.base_in_dim:
            # Flatten so ids stored as (num_nodes, 1) still give a 2-D embedding
            # that concatenates with x along dim 1.
            node_type_feats = self.node_type_emb(data.node_type_ids.view(-1))
            x = torch.cat([x, node_type_feats], dim=1)
        elif x.shape[1] != self.base_in_dim + self.node_type_emb_dim:
            raise ValueError(
                f"Unexpected x dim {x.shape[1]} (expected {self.base_in_dim} or {self.base_in_dim + self.node_type_emb_dim})"
            )

        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            x = torch.relu(x)
            x = self.dropout(x)
        node_logits = self.node_head(x).squeeze(-1)
        # Mean-pool node embeddings per graph id in `batch` for the graph-level head.
        graph_emb = global_mean_pool(x, batch)
        graph_logits = self.graph_head(graph_emb).squeeze(-1)
        return node_logits, graph_logits


# Base node-feature width: 768 columns (presumably a text/code encoder embedding —
# confirm) plus the diff feature columns.
expected_base_in_dim = 768 + DIFF_FEATURE_DIM
raw_in_dim = data_list[0].x.shape[1]

# If the stored features match the known layout (with or without a precomputed
# node-type block) use the known base width; otherwise treat the raw width as base.
known_widths = (expected_base_in_dim, expected_base_in_dim + NODE_TYPE_EMB_DIM)
base_in_dim = expected_base_in_dim if raw_in_dim in known_widths else raw_in_dim

HIDDEN_DIM = 256
model = RGCNDetector(
    base_in_dim,
    HIDDEN_DIM,
    num_node_types=len(node_type_to_id),
    num_relations=len(RELATIONS),
    node_type_emb_dim=NODE_TYPE_EMB_DIM,
)
model = model.to(device)

In [None]:
def compute_pos_weight(samples, label_attr: str):
    """Return ``negatives / positives`` for the binary labels ``label_attr`` of ``samples``.

    Intended as ``BCEWithLogitsLoss(pos_weight=...)`` to counter class imbalance.
    Each sample's label tensor is flattened before counting.

    Falls back to ``1.0`` (i.e. unweighted) whenever the ratio is undefined or
    degenerate:
      * no samples / no labels at all;
      * no positives (ratio would divide by zero);
      * no negatives (ratio would be 0, which would zero out the loss of every
        positive example and stop them from producing gradients).

    Returns:
        A 0-d float tensor.
    """
    pos = 0.0
    total = 0
    # Stream the counts instead of concatenating every label tensor in memory.
    for g in samples:
        labels = getattr(g, label_attr).view(-1).float()
        pos += labels.sum().item()
        total += labels.numel()
    neg = total - pos
    if pos <= 0 or neg <= 0:
        return torch.tensor(1.0)
    return torch.tensor(neg / pos)


# Class-imbalance weights come from the training split only, so val/test labels
# never influence the loss.
pos_weights = {attr: compute_pos_weight(train_ds, attr) for attr in ("y", "graph_y")}
node_pos_weight, graph_pos_weight = pos_weights["y"], pos_weights["graph_y"]

node_criterion, graph_criterion = (
    torch.nn.BCEWithLogitsLoss(pos_weight=weight.to(device))
    for weight in (node_pos_weight, graph_pos_weight)
)

optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-4, lr=1e-3)


def step_metrics(logits, labels, threshold: float = 0.5):
    """Binary classification metrics for a set of logits.

    Args:
        logits: raw scores; a prediction is positive when ``sigmoid(logit) > threshold``.
        labels: 0/1 targets with the same shape as ``logits``.
        threshold: probability cut-off for a positive prediction (default 0.5).

    Returns:
        ``(precision, recall, f1, accuracy)`` as Python floats. A ratio whose
        denominator is zero (e.g. precision with no positive predictions, or any
        metric on an empty input) is reported as 0.0.
    """
    probs = torch.sigmoid(logits)
    preds = (probs > threshold).long()
    labels = labels.long()
    tp = ((preds == 1) & (labels == 1)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    tn = ((preds == 0) & (labels == 0)).sum().item()

    def _ratio(num, den):
        # Exact division with an explicit zero guard (an epsilon in the
        # denominator would make perfect scores read slightly below 1.0).
        return num / den if den else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    # 2PR / (P + R) expressed in counts: identical when defined, 0 when tp == 0.
    f1 = _ratio(2 * tp, 2 * tp + fp + fn)
    acc = _ratio(tp + tn, tp + tn + fp + fn)
    return precision, recall, f1, acc


def _epoch_metrics(logit_chunks, label_chunks):
    """Precision/recall/F1/accuracy over all batches of an epoch, as a length-4 array.

    Returns an all-NaN array when no batches were seen so callers can still
    index positions 0..3 and call ``.tolist()``.
    """
    if not logit_chunks:
        return np.full(4, np.nan)
    return np.asarray(step_metrics(torch.cat(logit_chunks), torch.cat(label_chunks)), dtype=float)


def run_epoch(loader, is_train: bool):
    """Run one pass over ``loader``; train when ``is_train``, otherwise evaluate.

    Uses the module-level ``model``, ``optimizer``, ``node_criterion``,
    ``graph_criterion`` and ``device``.

    Metrics are computed once over the whole epoch from the concatenated logits
    and labels. Averaging per-batch precision/recall/F1 (the previous approach)
    is biased under class imbalance: e.g. a batch with no positive predictions
    contributes precision 0 regardless of its size.

    Returns:
        ``(mean_batch_loss, node_metrics, graph_metrics)`` where each metrics
        value is a numpy array ``[precision, recall, f1, accuracy]``
        (all NaN, with loss 0.0, if the loader yields no batches).
    """
    model.train(is_train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    num_batches = 0
    node_logit_chunks, node_label_chunks = [], []
    graph_logit_chunks, graph_label_chunks = [], []
    with torch.set_grad_enabled(is_train):
        for batch in loader:
            batch = batch.to(device)
            node_logits, graph_logits = model(batch)
            # Flatten labels to match the squeezed, one-per-node / one-per-graph logits.
            node_labels = batch.y.view(-1).float()
            graph_labels = batch.graph_y.view(-1).float()

            node_loss = node_criterion(node_logits, node_labels)
            graph_loss = graph_criterion(graph_logits, graph_labels)
            loss = node_loss + graph_loss

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            node_logit_chunks.append(node_logits.detach().cpu())
            node_label_chunks.append(node_labels.detach().cpu())
            graph_logit_chunks.append(graph_logits.detach().cpu())
            graph_label_chunks.append(graph_labels.detach().cpu())

    node_metrics = _epoch_metrics(node_logit_chunks, node_label_chunks)
    graph_metrics = _epoch_metrics(graph_logit_chunks, graph_label_chunks)
    return total_loss / max(num_batches, 1), node_metrics, graph_metrics


EPOCHS = 10
hist = []
for epoch in range(1, EPOCHS + 1):
    train_loss, train_node, train_graph = run_epoch(train_loader, True)
    val_loss, val_node, val_graph = run_epoch(val_loader, False)

    # One history record per epoch; metric arrays are [precision, recall, f1, acc].
    record = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}
    metric_arrays = {
        "train_node": train_node,
        "val_node": val_node,
        "train_graph": train_graph,
        "val_graph": val_graph,
    }
    record.update({key: arr.tolist() for key, arr in metric_arrays.items()})
    hist.append(record)

    summary = (
        f"Epoch {epoch:02d} | "
        f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} | "
        f"node_f1={val_node[2]:.4f} node_acc={val_node[3]:.4f} | "
        f"graph_f1={val_graph[2]:.4f} graph_acc={val_graph[3]:.4f}"
    )
    print(summary)

HIST_OUT = os.path.join(OUTPUT_DIR, "train_hist.pt")
torch.save(hist, HIST_OUT)
print(f"Saved history to {HIST_OUT}")

In [None]:
# Final held-out evaluation of the last-epoch model.
test_loss, test_node, test_graph = run_epoch(test_loader, False)
metric_names = ("precision", "recall", "f1", "acc")
node_part = " ".join(f"node_{name}={value:.4f}" for name, value in zip(metric_names, test_node))
graph_part = " ".join(f"graph_{name}={value:.4f}" for name, value in zip(metric_names, test_graph))
print(f"Test | loss={test_loss:.4f} | {node_part} | {graph_part}")

In [None]:
MODEL_OUT = os.path.join(OUTPUT_DIR, "rgcn_detector.pt")
# Store every RGCNDetector constructor argument so the model can be rebuilt from this
# file alone: num_relations = len(relations), num_node_types = len(node_type_to_id).
checkpoint = {
    "model_state": model.state_dict(),
    "base_in_dim": base_in_dim,
    "hidden_dim": HIDDEN_DIM,
    "relations": RELATIONS,
    "node_type_to_id": node_type_to_id,
    # Previously missing; required to reconstruct the exact architecture.
    "node_type_emb_dim": NODE_TYPE_EMB_DIM,
    "num_layers": len(model.convs),
    "dropout": model.dropout.p,
}
torch.save(checkpoint, MODEL_OUT)
print(f"Saved model to {MODEL_OUT}")