# CodeBuggy Inference (code + diff)

Nhap code buggy va diff (unified diff) de suy ra fixed code, sau do chay model RGCN de du doan bug o muc do graph va node.

In [1]:
import os
import re
import torch
import numpy as np
import javalang
import joblib

from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import RGCNConv

from utils.gumtree_diff import GumTreeDiff, EditType

torch.manual_seed(42)
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
GUMTREE_PATH = "./gumtree-4.0.0-beta4/bin/gumtree"
MODEL_NAME = "microsoft/graphcodebert-base"
NODE_TYPE_PATH = "output/node_type_to_id.joblib"
CHECKPOINT_PATH = "output/rgcn_detector.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

RELATIONS = [
    "AST_CHILD",
    "AST_PARENT",
    "CFG_NEXT",
    "CFG_TRUE",
    "CFG_FALSE",
    "CFG_LOOP",
    "DEF_USE",
    "USE_DEF",
    "DIFF_PARENT",
    "DIFF_SIBLING",
]
RELATION_TO_ID = {r: i for i, r in enumerate(RELATIONS)}

node_type_to_id = joblib.load(NODE_TYPE_PATH)
gumtree_diff = GumTreeDiff(gumtree_path=GUMTREE_PATH)

for candidate in [CHECKPOINT_PATH, "output/rgcn_graphs/rgcn_detector.pt"]:
    if os.path.exists(candidate):
        CHECKPOINT_PATH = candidate
        break
else:
    raise FileNotFoundError(
        "Cannot find checkpoint. Tried output/rgcn_detector.pt and output/rgcn_graphs/rgcn_detector.pt"
    )

checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False, map_location="cpu")

class RGCNDetector(torch.nn.Module):
    def __init__(
        self,
        base_in_dim: int,
        hidden_dim: int,
        num_relations: int,
        num_node_types: int,
        node_type_emb_dim: int,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.base_in_dim = base_in_dim
        self.node_type_emb_dim = node_type_emb_dim
        self.node_type_emb = torch.nn.Embedding(num_node_types, node_type_emb_dim)
        conv_in_dim = base_in_dim + node_type_emb_dim

        self.convs = torch.nn.ModuleList()
        self.convs.append(RGCNConv(conv_in_dim, hidden_dim, num_relations=num_relations))
        for _ in range(num_layers - 1):
            self.convs.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relations))
        self.dropout = torch.nn.Dropout(dropout)
        self.node_head = torch.nn.Linear(hidden_dim, 1)
        self.graph_head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x, edge_index, edge_type, batch = data.x, data.edge_index, data.edge_type, data.batch

        if x.shape[1] == self.base_in_dim:
            node_type_feats = self.node_type_emb(data.node_type_ids)
            x = torch.cat([x, node_type_feats], dim=1)
        elif x.shape[1] != self.base_in_dim + self.node_type_emb_dim:
            raise ValueError(
                f"Unexpected x dim {x.shape[1]} (expected {self.base_in_dim} or {self.base_in_dim + self.node_type_emb_dim})"
            )

        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            x = torch.relu(x)
            x = self.dropout(x)
        node_logits = self.node_head(x).squeeze(-1)
        graph_emb = global_mean_pool(x, batch)
        graph_logits = self.graph_head(graph_emb).squeeze(-1)
        return node_logits, graph_logits


model = RGCNDetector(
    base_in_dim=checkpoint["base_in_dim"],
    hidden_dim=checkpoint["hidden_dim"],
    num_relations=len(checkpoint["relations"]),
    num_node_types=len(node_type_to_id),
    node_type_emb_dim=64,
).to(DEVICE)
model.load_state_dict(checkpoint["model_state"])
model.eval()

print(f"Device: {DEVICE}")

Device: cpu


In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
encoder = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
encoder.eval()

STATEMENT_NODES = {
    "StatementExpression",
    "ReturnStatement",
    "IfStatement",
    "ForStatement",
    "WhileStatement",
    "DoStatement",
    "SwitchStatement",
    "TryStatement",
    "ThrowStatement",
    "BreakStatement",
    "ContinueStatement",
    "BlockStatement",
}

WRAP_TEMPLATE = """public class Dummy {{
    {method_code}
}}"""


def wrap_method(method_code: str) -> str:
    return WRAP_TEMPLATE.format(method_code=method_code)


def iter_children(node):
    for child in node.children:
        if child is None:
            continue
        if isinstance(child, list):
            for item in child:
                if item is not None:
                    yield item
        else:
            yield child


def get_node_label(node) -> str | None:
    if hasattr(node, "name") and node.name:
        return str(node.name)
    if hasattr(node, "member") and node.member:
        return str(node.member)
    if hasattr(node, "value") and node.value is not None:
        return str(node.value)
    if hasattr(node, "operator") and node.operator:
        return str(node.operator)
    if hasattr(node, "type") and isinstance(node.type, str):
        return str(node.type)
    return None


def compute_line_offsets(code: str) -> list[int]:
    offsets = [0]
    for idx, ch in enumerate(code):
        if ch == "\n":
            offsets.append(idx + 1)
    return offsets


def line_col_to_offset(code: str, line: int, col: int) -> int | None:
    if line <= 0 or col <= 0:
        return None
    line_offsets = compute_line_offsets(code)
    if line > len(line_offsets):
        return None
    return line_offsets[line - 1] + (col - 1)


def offset_to_line_col(code: str, offset: int) -> tuple[int, int]:
    if offset < 0:
        return (1, 1)
    line_offsets = compute_line_offsets(code)
    line = 1
    for i, start in enumerate(line_offsets, 1):
        if start <= offset:
            line = i
        else:
            break
    col = offset - line_offsets[line - 1] + 1
    return (line, col)


def build_ast_graph(code: str):
    tree = javalang.parse.parse(wrap_method(code))

    nodes = []
    parents = []
    children = []
    id_to_index = {}

    def is_wrapper(node) -> bool:
        if isinstance(node, javalang.tree.CompilationUnit):
            return True
        if isinstance(node, javalang.tree.ClassDeclaration) and getattr(node, "name", None) == "Dummy":
            return True
        return False

    def visit(node, parent_idx: int | None):
        if not isinstance(node, javalang.tree.Node):
            return
        if is_wrapper(node):
            for child in iter_children(node):
                visit(child, parent_idx)
            return

        idx = len(nodes)
        id_to_index[id(node)] = idx

        label = get_node_label(node)
        line = None
        col = None
        if getattr(node, "position", None):
            line = node.position.line
            col = node.position.column
        start_pos = line_col_to_offset(code, line, col) if line and col else None
        end_pos = start_pos + len(label) if (start_pos is not None and label) else start_pos

        nodes.append(
            {
                "raw": node,
                "node_type": node.__class__.__name__,
                "label": label,
                "line": line,
                "col": col,
                "start_pos": start_pos,
                "end_pos": end_pos,
            }
        )
        parents.append(parent_idx)
        children.append([])
        if parent_idx is not None:
            children[parent_idx].append(idx)

        for child in iter_children(node):
            visit(child, idx)

    visit(tree, None)

    edges = []
    edge_types = []

    def add_edge(src: int, dst: int, rel: str):
        edges.append((src, dst))
        edge_types.append(RELATION_TO_ID[rel])

    for parent_idx, child_list in enumerate(children):
        for child_idx in child_list:
            add_edge(parent_idx, child_idx, "AST_CHILD")
            add_edge(child_idx, parent_idx, "AST_PARENT")

    for parent_idx, child_list in enumerate(children):
        stmt_children = [c for c in child_list if nodes[c]["node_type"] in STATEMENT_NODES]
        for a, b in zip(stmt_children, stmt_children[1:]):
            add_edge(a, b, "CFG_NEXT")

    for idx, node in enumerate(nodes):
        node_type = node["node_type"]
        raw = node["raw"]

        if node_type == "IfStatement":
            then_node = getattr(raw, "then_statement", None)
            else_node = getattr(raw, "else_statement", None)
            then_idx = id_to_index.get(id(then_node))
            else_idx = id_to_index.get(id(else_node))
            if then_idx is not None:
                add_edge(idx, then_idx, "CFG_TRUE")
            if else_idx is not None:
                add_edge(idx, else_idx, "CFG_FALSE")

        if node_type in {"ForStatement", "WhileStatement", "DoStatement"}:
            body = getattr(raw, "body", None)
            body_idx = id_to_index.get(id(body))
            if body_idx is not None:
                add_edge(idx, body_idx, "CFG_LOOP")

    last_def = {}

    def extract_assigned_name(raw_node) -> str | None:
        target = getattr(raw_node, "expressionl", None) or getattr(raw_node, "left", None)
        if target is None:
            return None
        if hasattr(target, "member") and target.member:
            return str(target.member)
        if hasattr(target, "name") and target.name:
            return str(target.name)
        return None

    for idx, node in enumerate(nodes):
        node_type = node["node_type"]
        label = node["label"]
        raw = node["raw"]

        if node_type in {"VariableDeclarator", "FormalParameter"} and label:
            last_def[label] = idx

        if node_type == "Assignment":
            assigned = extract_assigned_name(raw)
            if assigned:
                last_def[assigned] = idx

        if node_type == "MemberReference" and label:
            if label in last_def:
                def_idx = last_def[label]
                add_edge(def_idx, idx, "DEF_USE")
                add_edge(idx, def_idx, "USE_DEF")

    return nodes, parents, children, edges, edge_types


def match_actions_to_nodes(code: str, nodes: list[dict], parents: list[int | None], actions):
    action_map: dict[int, EditType] = {}

    for action in actions:
        node_type = action.node.node_type
        label = action.node.label
        pos = action.node.position
        line = None
        col = None
        if pos is not None:
            line, col = offset_to_line_col(code, pos[0])

        candidates = [i for i, n in enumerate(nodes) if n["node_type"] == node_type]
        if label:
            label_candidates = [i for i in candidates if nodes[i]["label"] == label]
            if label_candidates:
                candidates = label_candidates
        if line:
            line_candidates = [i for i in candidates if nodes[i]["line"] == line]
            if line_candidates:
                candidates = line_candidates

        if not candidates:
            continue

        if col is not None:
            candidates.sort(key=lambda i: abs((nodes[i]["col"] or col) - col))

        matched_idx = candidates[0]
        action_map[matched_idx] = action.action_type

    subtree_changed = [0] * len(nodes)
    for idx in action_map.keys():
        cur = parents[idx]
        while cur is not None:
            subtree_changed[cur] = 1
            cur = parents[cur]

    return action_map, subtree_changed


def build_diff_features(action_map: dict[int, EditType], subtree_changed: list[int], num_nodes: int):
    diff_feats = []
    labels = []
    for idx in range(num_nodes):
        action = action_map.get(idx)
        is_diff = 1 if action is not None else 0
        action_none = 1 if action is None else 0
        action_update = 1 if action == EditType.UPDATE else 0
        action_delete = 1 if action == EditType.DELETE else 0
        action_move = 1 if action == EditType.MOVE else 0
        diff_feats.append([
            is_diff,
            action_none,
            action_update,
            action_delete,
            action_move,
            subtree_changed[idx],
        ])
        labels.append(1 if action in {EditType.UPDATE, EditType.DELETE} else 0)
    return torch.tensor(diff_feats, dtype=torch.float), torch.tensor(labels, dtype=torch.long)


def add_diff_edges(children, parents, action_map, edges, edge_types):
    changed_nodes = set(action_map.keys())
    for idx in changed_nodes:
        parent = parents[idx]
        if parent is not None:
            edges.append((idx, parent))
            edge_types.append(RELATION_TO_ID["DIFF_PARENT"])
            for sibling in children[parent]:
                if sibling == idx:
                    continue
                edges.append((idx, sibling))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])
                edges.append((sibling, idx))
                edge_types.append(RELATION_TO_ID["DIFF_SIBLING"])


def get_node_type_id(node_type: str) -> int:
    return node_type_to_id.get(node_type, node_type_to_id["UNK"])


def compute_code_embeddings(code: str, nodes: list[dict], children: list[list[int]]):
    enc = tokenizer(
        code,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )
    offsets = enc.pop("offset_mapping")[0].tolist()
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    with torch.inference_mode():
        outputs = encoder(**enc)
    token_embs = outputs.last_hidden_state[0].cpu()

    num_nodes = len(nodes)
    node_embs = torch.zeros((num_nodes, token_embs.shape[-1]), dtype=torch.float)

    for idx, node in enumerate(nodes):
        start = node["start_pos"]
        end = node["end_pos"]
        if start is None or end is None or start == end:
            continue
        token_idxs = [
            t for t, (s, e) in enumerate(offsets)
            if not (s == 0 and e == 0) and s < end and e > start
        ]
        if token_idxs:
            node_embs[idx] = token_embs[token_idxs].mean(dim=0)

    for idx in reversed(range(num_nodes)):
        if torch.all(node_embs[idx] == 0) and children[idx]:
            child_embs = [node_embs[c] for c in children[idx] if torch.any(node_embs[c] != 0)]
            if child_embs:
                node_embs[idx] = torch.stack(child_embs).mean(dim=0)

    return node_embs


def build_graph_parts(buggy_code: str, fixed_code: str, method_id: str):
    nodes, parents, children, edges, edge_types = build_ast_graph(buggy_code)

    diff_result = gumtree_diff.diff(buggy_code, fixed_code)
    action_map, subtree_changed = match_actions_to_nodes(buggy_code, nodes, parents, diff_result.actions)

    diff_feats, labels = build_diff_features(action_map, subtree_changed, len(nodes))
    add_diff_edges(children, parents, action_map, edges, edge_types)

    node_type_ids = torch.tensor([get_node_type_id(n["node_type"]) for n in nodes], dtype=torch.long)

    return {
        "buggy_code": buggy_code,
        "nodes": nodes,
        "parents": parents,
        "children": children,
        "edges": edges,
        "edge_types": edge_types,
        "diff_feats": diff_feats,
        "labels": labels,
        "node_type_ids": node_type_ids,
        "method_id": method_id,
    }


def build_graph_sample(
    buggy_code: str,
    fixed_code: str,
    method_id: str,
):
    parts = build_graph_parts(buggy_code, fixed_code, method_id)

    code_embs = compute_code_embeddings(
        parts["buggy_code"], parts["nodes"], parts["children"]
    )

    x = torch.cat([code_embs, parts["diff_feats"]], dim=1)

    if parts["edges"]:
        edge_index = torch.tensor(parts["edges"], dtype=torch.long).t().contiguous()
        edge_type = torch.tensor(parts["edge_types"], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_type = torch.empty((0,), dtype=torch.long)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_type=edge_type,
        y=parts["labels"],
        node_type_ids=parts["node_type_ids"],
        method_id=parts["method_id"],
    )
    return data, parts

Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
HUNK_RE = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@")


def _find_sublist(haystack: list[str], needle: list[str], start: int = 0) -> int | None:
    if not needle:
        return start
    max_i = len(haystack) - len(needle)
    for i in range(start, max_i + 1):
        if haystack[i : i + len(needle)] == needle:
            return i
    return None


def apply_unified_diff(source: str, diff_text: str) -> str:
    src_lines = source.splitlines()
    result = []
    src_pos = 0

    lines = diff_text.splitlines()
    i = 0
    while i < len(lines):
        line = lines[i]
        if not line.startswith("@@"):
            i += 1
            continue

        m = HUNK_RE.match(line)
        if not m:
            raise ValueError(f"Invalid hunk header: {line}")
        src_start = int(m.group(1)) - 1

        i += 1
        hunk_lines = []
        while i < len(lines) and not lines[i].startswith("@@"):
            hline = lines[i]
            if hline.startswith(("---", "+++", "diff ", "index ")):
                i += 1
                continue
            if hline == "":
                i += 1
                continue
            if hline[0] not in {" ", "-", "+", "\\"}:
                raise ValueError(f"Unexpected diff line: {hline}")
            hunk_lines.append(hline)
            i += 1

        before_lines = [h[1:] for h in hunk_lines if h.startswith((" ", "-"))]
        found = _find_sublist(src_lines, before_lines, start=src_pos)
        if found is None:
            found = src_start
        if found < src_pos:
            raise ValueError("Hunk overlaps previous content")

        result.extend(src_lines[src_pos:found])
        src_pos = found

        for hline in hunk_lines:
            if hline.startswith(" "):
                result.append(src_lines[src_pos])
                src_pos += 1
            elif hline.startswith("-"):
                src_pos += 1
            elif hline.startswith("+"):
                result.append(hline[1:])
            elif hline.startswith("\\"):
                pass

    result.extend(src_lines[src_pos:])
    fixed = "\n".join(result)
    if source.endswith("\n"):
        fixed += "\n"
    return fixed

# Infer

In [5]:
# BUGGY_CODE = """
# public int sum(int[] arr) {
#     int s = 0;
#     for (int i = 0; i <= arr.length; i++) {
#         s += arr[i];
#     }
#     return s;
# }
# """.strip()

# DIFF_TEXT = """
# --- a/Snippet.java
# +++ b/Snippet.java
# @@ -2,7 +2,7 @@
#  public int sum(int[] arr) {
#      int s = 0;
# -    for (int i = 0; i <= arr.length; i++) {
# +    for (int i = 0; i < arr.length; i++) {
#          s += arr[i];
#      }
#      return s;
# """.strip()
# FIXED_CODE = """""".strip()

# if DIFF_TEXT:
#     fixed_code = apply_unified_diff(BUGGY_CODE, DIFF_TEXT)
# else:
#     fixed_code = FIXED_CODE if FIXED_CODE else BUGGY_CODE


In [6]:
import pandas as pd
df=pd.read_csv('megadiff_single_function_100.csv')
buggy_funcs = df['buggy_function']
fixed_funcs = df['fixed_function']

In [7]:
def infer():
    data, parts = build_graph_sample(
        BUGGY_CODE,
        FIXED_CODE,
        method_id="manual",
    )

    loader = DataLoader([data], batch_size=1)
    batch = next(iter(loader)).to(DEVICE)

    with torch.inference_mode():
        node_logits, graph_logits = model(batch)

    node_probs = torch.sigmoid(node_logits).detach().cpu().numpy()
    graph_prob = float(torch.sigmoid(graph_logits).item())
    if graph_prob < 0.5:
        return
    print(f"Graph bug probability: {graph_prob:.4f}")

    # Node Inference
    TOP_K = 10

    nodes = parts["nodes"]
    lines = BUGGY_CODE.splitlines()

    ranked = sorted(range(len(node_probs)), key=lambda i: node_probs[i], reverse=True)
    print("Top node predictions:")

    for rank, idx in enumerate(ranked[:TOP_K], 1):
        node = nodes[idx]
        line = node.get("line")
        col = node.get("col")
        label = node.get("label") or ""
        node_type = node.get("node_type")
        prob = float(node_probs[idx])

        line_text = ""
        if line is not None and 1 <= line <= len(lines):
            line_text = lines[line - 1].strip()

        if prob < 0.2:
            return
        print(
            f"{rank:02d}. prob={prob:.4f} type={node_type} label={label} line={line} col={col} | {line_text}"
        )

In [None]:
for index in [62, 92, 179, 186, 839, 846, 888]:
    print(index)
    BUGGY_CODE = buggy_funcs[index]
    FIXED_CODE = fixed_funcs[index]
    infer()

800
801
802
Graph bug probability: 1.0000
Top node predictions:
01. prob=1.0000 type=ReturnStatement label= line=21 col=17 | 
02. prob=1.0000 type=MethodInvocation label=floor line=12 col=36 | final double lonMinFloor = Math.floor(lonMin);
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
Graph bug probability: 1.0000
Top node predictions:
01. prob=0.9901 type=BreakStatement label= line=41 col=9 | case CLOSE:
819
820
821
822
823
824
825
826
827
828
829
830
Graph bug probability: 1.0000
Top node predictions:
831
832
833
834
835
Graph bug probability: 1.0000
Top node predictions:
836
837
838
839
Graph bug probability: 0.9979
Top node predictions:
01. prob=1.0000 type=Assignment label=Literal(postfix_operators=[], prefix_operators=[], qualifier=None, selectors=[], value=null) line=None col=None | 
02. prob=0.3495 type=MethodInvocation label=substring line=141 col=37 | break;
840
841
842
843
844
845
846
Graph bug probability: 1.0000
Top node predictions:
01. prob=1.0000 type=

KeyboardInterrupt: 

In [None]:
802 808 818 839 846 888

(125, 146, 177, 173, 179, 186)

In [None]:
62, 92, 179, 186, 839, 846, 888

(62, 92)

In [65]:
index = 62
BUGGY_CODE = buggy_funcs[index]
FIXED_CODE = fixed_funcs[index]

BUGGY_CODE
infer()

Graph bug probability: 0.9976
Top node predictions:
01. prob=1.0000 type=MethodInvocation label=getAbsolutePath line=25 col=47 | when(project.getBuild()).thenReturn(build);


In [61]:
import difflib 

diff_output = list(difflib.unified_diff(
    BUGGY_CODE.splitlines(keepends=True), 
    FIXED_CODE.splitlines(keepends=True),
))
print(''.join(diff_output))

--- 
+++ 
@@ -1,4 +1,4 @@
-	public void testDefaultAgentRepoAndBundlePoolFromProfileRepo() {
+	public void testDefaultAgentRepoAndBundlePoolFromProfileRepo() throws InterruptedException {
 		File testData = getTestData("0.1", "testData/sdkpatchingtest");
 		// /p2/org.eclipse.equinox.p2.engine/profileRegistry");
 		File tempFolder = getTempFolder();
@@ -38,6 +38,13 @@
 		assertFalse(repoCollector.isEmpty());
 		assertTrue(repoCollector.toCollection().containsAll(profileCollector.toCollection()));
 
-		assertTrue(manager.contains(tempFolder.toURI()));
-		assertTrue(manager.contains(defaultAgenRepositoryDirectory.toURI()));
+		int maxTries = 20;
+		int current = 0;
+		while (true) {
+			if (manager.contains(tempFolder.toURI()) && manager.contains(defaultAgenRepositoryDirectory.toURI()))
+				break;
+			if (++current == maxTries)
+				fail("profile artifact repos not added");
+			Thread.sleep(100);
+		}
 	}



In [66]:
print(BUGGY_CODE)

    public void setup(){
        //Create the temp dir
        final File sysTempDir = new File(System.getProperty("java.io.tmpdir"));
        String dirName = UUID.randomUUID().toString();
        tempDir = new File(sysTempDir, dirName);
        productDir = new File(tempDir,PROJECT_ID);
        tempResourcesDir = new File(productDir,TMP_RESOURCES);
        generatedHomeDir = new File(tempResourcesDir,GENERATED_HOME);
        pluginsDir = new File(generatedHomeDir,PLUGINS);
        bundledPluginsDir = new File(generatedHomeDir,BUNDLED_PLUGINS);

        //setup maven mocks
        MavenProject project = mock(MavenProject.class);
        Build build = mock(Build.class);

        //Mockito throws NoClassDefFoundError: org/apache/maven/project/ProjectBuilderConfiguration
        //when mocking the session
        //MavenSession session = mock(MavenSession.class);

        PluginManager pluginManager = mock(PluginManager.class);
        List<MavenProject> reactor = Collections.<MavenProje

# Visual

In [14]:
import difflib

diff_output = list(difflib.unified_diff(
    BUGGY_CODE.splitlines(keepends=True),
    FIXED_CODE.splitlines(keepends=True),
    fromfile='buggy',
    tofile='fixed'
))
print(''.join(diff_output))


--- buggy
+++ fixed
@@ -7,13 +7,13 @@
         boolean isBetter = false;
         switch (policy) {
             case MINIMIZE:
-                if (bestVal > val) {
+                if (bestVal > val || nbSol==1) {
                     bestVal = val;
                     isBetter = true;
                 }
                 break;
             case MAXIMIZE:
-                if (bestVal < val) {
+                if (bestVal < val || nbSol==1) {
                     bestVal = val;
                     isBetter = true;
                 }



In [18]:
from utils.feature_extractor import DiffFeatureExtractor
extractor = DiffFeatureExtractor(gumtree_path='./gumtree-4.0.0-beta4/bin/gumtree')
features = extractor.extract(BUGGY_CODE, FIXED_CODE)

ModuleNotFoundError: No module named 'utils.feature_extractor'

In [None]:
# Step 4: Visualize buggy regions
from IPython.display import HTML

def visualize_buggy_code(code: str, line_mask: dict) -> str:
    """Highlight buggy lines in HTML"""
    lines = code.split('\n')
    html_lines = []
    
    for i, line in enumerate(lines, 1):
        is_buggy = line_mask.get(i, 0) == 1
        escaped_line = line.replace('<', '&lt;').replace('>', '&gt;')
        
        if is_buggy:
            html_lines.append(f'<span style="background-color: #ffcccc; display: block;">{i:3d} | üêõ {escaped_line}</span>')
        else:
            html_lines.append(f'<span style="display: block;">{i:3d} |    {escaped_line}</span>')
    
    return f'<pre style="font-family: monospace; font-size: 12px; background: #f5f5f5; padding: 10px; border-radius: 5px;">{"".join(html_lines)}</pre>'

# Display buggy code with highlights
html_output = visualize_buggy_code(BUGGY_CODE, features.mask.line_mask)
HTML(html_output)

# Plot hist

In [None]:
import torch
hist = torch.load('output/train_hist.pt')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

epochs = [item["epoch"] for item in hist]
train_loss = [item["train_loss"] for item in hist]
val_loss = [item["val_loss"] for item in hist]

train_node = np.array([item["train_node"] for item in hist])
val_node = np.array([item["val_node"] for item in hist])
train_graph = np.array([item["train_graph"] for item in hist])
val_graph = np.array([item["val_graph"] for item in hist])

fig, axes = plt.subplots(3, 1, figsize=(12, 12), sharex=True)

axes[0].plot(epochs, train_loss, label="train_loss")
axes[0].plot(epochs, val_loss, label="val_loss")
axes[0].set_ylabel("Loss")
axes[0].grid(True, alpha=0.3)
axes[0].legend()

metrics = ['pre', 'rec', 'f1', 'acc']
for i in range(train_node.shape[1]):
    axes[1].plot(epochs, train_node[:, i], label=f"train_node_{metrics[i]}")
    axes[1].plot(epochs, val_node[:, i], linestyle="--", label=f"val_node_{metrics[i]}")
axes[1].set_ylabel("Node metrics")
axes[1].grid(True, alpha=0.3)
axes[1].legend(ncol=2, fontsize=9)

for i in range(train_graph.shape[1]):
    axes[2].plot(epochs, train_graph[:, i], label=f"train_graph_{metrics[i]}")
    axes[2].plot(epochs, val_graph[:, i], linestyle="--", label=f"val_graph_{metrics[i]}")
axes[2].set_ylabel("Graph metrics")
axes[2].set_xlabel("Epoch")
axes[2].grid(True, alpha=0.3)
axes[2].legend(ncol=2, fontsize=9)

plt.tight_layout()
plt.show()