In [None]:
import os

from huggingface_hub import login

# Do not hard-code an access token in the notebook: it ends up in version
# control and in shared copies. Read it from the HF_TOKEN environment variable
# instead; when the variable is unset, ``token=None`` lets ``login`` fall back
# to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import pandas as pd

# Settings for the balanced evaluation sample.
SRC = "../data/github-labels-top3-803k-test.csv"
label_col = "issue_label"
target_labels = ["bug", "enhancement", "question"]
n_per_class = 10

# Read the source CSV and keep only the rows carrying one of the target labels.
df_all = pd.read_csv(SRC)
is_target = df_all[label_col].isin(target_labels)
df_all = df_all[is_target]

# Draw the same number of rows from every label with a fixed seed, then order
# the sample by label and renumber the rows.
per_label = df_all.groupby(label_col, group_keys=False)
df_eval = per_label.sample(n=n_per_class, random_state=200)
df_eval = df_eval.sort_values([label_col]).reset_index(drop=True)

df_eval.to_csv("balanced_eval_set.csv", index=False)

# Show how many rows each label ended up with.
print(df_eval[label_col].value_counts())

issue_label
bug            10
enhancement    10
question       10
Name: count, dtype: int64


In [4]:
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
from lime.lime_text import LimeTextExplainer
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import shap, torch, os
from typing import Optional

# Load the classifier, its config and its tokenizer from the local checkpoint.
MODEL_PATH = "../models/nlbse/"
config = AutoConfig.from_pretrained(MODEL_PATH)
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)

# Run on the GPU when one is available; switch the model to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Class names shown in reports. The checkpoint's own id2label mapping is used
# unless it is missing or only holds the generic LABEL_0, LABEL_1, ... names.
display_labels = ["bug", "enhancement", "question"]
labels = display_labels
if getattr(model.config, "id2label", None):
    raw = [model.config.id2label[i] for i in range(model.config.num_labels)]
    placeholder_names = {f"LABEL_{i}" for i in range(len(raw))}
    if set(raw) != placeholder_names:
        labels = raw
assert len(labels) == model.config.num_labels, "num_labels와 labels 길이가 다릅니다."


def predict(texts, batch_size=64):
    """Return class probabilities for a collection of texts.

    The texts are tokenized (padded per batch, truncated to 512 tokens) and run
    through ``model`` without gradient tracking; the logits are turned into
    probabilities with a softmax over the last dimension.

    The input is processed in chunks of ``batch_size`` so that callers handing
    over many texts at once (the LIME loop in this notebook requests 600
    perturbed samples per explanation) do not push everything through the
    model as a single batch.

    Args:
        texts: Iterable of strings.
        batch_size: Maximum number of texts per forward pass; must be >= 1.

    Returns:
        ``numpy.ndarray`` with one row per input text and one column per class.
        An empty input yields an array with zero rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    texts = list(texts)
    if not texts:
        # Nothing to classify: keep the (n_texts, n_classes) layout.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        enc = tokenizer(
            texts[start : start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**enc).logits
            chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)


# The evaluation sample is produced by the first cell; fail early with a clear
# message when it has not been written yet.
CSV_PATH = "balanced_eval_set.csv"
if os.path.exists(CSV_PATH):
    df = pd.read_csv(CSV_PATH).reset_index(drop=True)
else:
    raise FileNotFoundError(
        f"{CSV_PATH} file not found. Run the first cell to create it."
    )

# Column names used to build the model input and to read the gold label.
text_col_title = "issue_title"
text_col_body = "issue_body"
gold_col = "issue_label"


def _cell_text(value):
    """Return ``value`` as a string, mapping a missing cell (None/NaN/NA) to ""."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value)


def build_text(row, title_col=text_col_title, body_col=text_col_body, max_chars=800):
    """Build the classifier input for one row: title, blank line, truncated body.

    Missing cells are treated as empty strings. Calling ``str`` on them directly
    would insert the literal text "nan" into the input for rows whose title or
    body is empty in the CSV.

    Args:
        row: Mapping-like object with a ``get`` method (e.g. a pandas Series or
            a dict).
        title_col: Key of the title field.
        body_col: Key of the body field.
        max_chars: Maximum number of body characters to keep.

    Returns:
        ``title + "\\n\\n" + body`` with surrounding whitespace stripped.
    """
    title = _cell_text(row.get(title_col, ""))
    body = _cell_text(row.get(body_col, ""))[:max_chars]
    return (title + "\n\n" + body).strip()


# One input text per evaluation row, in DataFrame order.
samples = [build_text(record) for _, record in df.iterrows()]

explainer = LimeTextExplainer(class_names=labels, random_state=42)

# Explain the predicted class of every sample with LIME and save each
# explanation as a standalone HTML file (lime_00.html, lime_01.html, ...).
for n, (row, txt) in enumerate(zip(df.itertuples(index=False), samples)):
    probs = predict([txt])[0]
    idx = int(probs.argmax())
    predlab, conf = labels[idx], float(probs[idx])

    print(
        f"\n[LIME {n:02d}] gold: {getattr(row, gold_col)} | pred: {predlab} (conf {conf:.2f})"
    )

    exp = explainer.explain_instance(
        txt,
        predict,
        labels=[idx],
        num_features=12,
        num_samples=600,
    )
    html = exp.as_html(text=txt)

    lime_path = f"lime_{n:02d}.html"
    with open(lime_path, "w", encoding="utf-8") as f:
        f.write(html)


# Extra <style> block that is prepended to every SHAP HTML page saved by this
# notebook (it is interpolated into `page` in the SHAP export loop).
CSS_FIX = """
<style>
  html, body { margin: 8px; font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
  .shap { position: relative; }
  .shap .top { margin-bottom: 36px !important; position: relative; z-index: 1; }
  .shap .text, .shap .inputs, .shap .labels { position: relative; z-index: 9999 !important; }
  .shap .text span { line-height: 1.65 !important; } 
  svg { overflow: visible !important; }
</style>
"""


# Text masker built on the model tokenizer; hidden tokens are replaced by the
# tokenizer's mask token (or the literal "[MASK]" when it defines none).
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=(tokenizer.mask_token or "[MASK]"),
)
shap_explainer = shap.Explainer(
    predict,
    masker,
    output_names=labels,
    algorithm="partition",
)

shap_values = shap_explainer(samples)
print("\nSHAP ready:", len(shap_values), "samples")

# Render the explanation of the predicted class of every sample as a
# standalone HTML page (shap_00.html, shap_01.html, ...).
for i, (sv, txt) in enumerate(zip(shap_values, samples)):
    probs = predict([txt])[0]
    cls = int(np.argmax(probs))

    html_obj = shap.plots.text(sv[..., cls], display=False)
    # Use the object's `data` attribute when it has one, else its string form.
    body = getattr(html_obj, "data", str(html_obj))

    out = f"shap_{i:02d}.html"
    page = f"<!doctype html><meta charset='utf-8'>{CSS_FIX}{body}"
    with open(out, "w", encoding="utf-8") as f:
        f.write(page)
    print(f"[SHAP {i:02d}] saved → {out} (pred={labels[cls]}, p={probs[cls]:.2f})")
# Id of the mask token; looked up by name when the tokenizer attribute is unset.
mask_token_id = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")


def predict_with_manual_mask(text, mask_indices, keep=False, max_length=512):
    """Return class probabilities for ``text`` with selected tokens masked.

    The text is split into tokens with ``tokenizer.tokenize``, cut to
    ``max_length - 2`` tokens and wrapped in the CLS/SEP ids. Indices in
    ``mask_indices`` are 0-based positions in that token list (the special
    tokens are not counted).

    Args:
        text: Input string.
        mask_indices: Token positions to mask (``keep=False``) or the only
            positions to leave unmasked (``keep=True``).
        keep: Selects between the two modes described above.
        max_length: Maximum sequence length including the two special tokens.

    Returns:
        1-D ``numpy.ndarray`` of class probabilities.
    """
    pieces = tokenizer.tokenize(text)[: max_length - 2]
    piece_ids = tokenizer.convert_tokens_to_ids(pieces)
    input_ids = [tokenizer.cls_token_id] + piece_ids + [tokenizer.sep_token_id]

    if keep:
        # Mask every token that was not selected.
        kept = set(mask_indices)
        for t_idx in range(len(pieces)):
            if t_idx not in kept:
                input_ids[t_idx + 1] = mask_token_id
    else:
        # Mask only the selected tokens; positions outside the token span
        # (including the two special tokens) are ignored.
        last = len(input_ids) - 1
        for t_idx in mask_indices:
            pos = t_idx + 1
            if 0 < pos < last:
                input_ids[pos] = mask_token_id

    batch = torch.tensor([input_ids]).to(device)
    attention = torch.ones_like(batch)
    with torch.no_grad():
        logits = model(input_ids=batch, attention_mask=attention).logits
        probs = torch.softmax(logits, dim=-1)
    return probs.squeeze(0).detach().cpu().numpy()


def topk_token_indices_for_class(sv, class_idx: int, k: int = 10, use_abs: bool = True):
    """Return the positions of the ``k`` highest-ranked attributions for a class.

    Args:
        sv: Object whose ``values`` attribute is a 2-D array indexed as
            ``[position, class]``.
        class_idx: Column of ``values`` to rank.
        k: Maximum number of positions to return.
        use_abs: Rank by absolute value when True, by signed value otherwise.

    Returns:
        List of row indices of ``sv.values``, highest-ranked first.
    """
    # NOTE(review): the returned indices are rows of `sv.values`. If the SHAP
    # text masker also emits rows for special tokens such as [CLS]/[SEP], they
    # are shifted relative to a token list without special tokens - confirm
    # against the caller that maps them back to tokens.
    scores = sv.values[:, class_idx]
    ranking_key = np.abs(scores) if use_abs else scores
    ranked = np.argsort(-ranking_key)
    limit = min(k, len(ranked))
    return ranked[:limit].tolist()


def faithfulness_metrics(
    text: str, sv, k: int = 10, use_abs: bool = True, class_idx: Optional[int] = None
):
    """Compute comprehensiveness and sufficiency of one explanation.

    The top-``k`` positions of ``sv`` for the evaluated class are masked
    (comprehensiveness) or are the only ones left unmasked (sufficiency); each
    score is the drop in that class's probability relative to the full text.

    Args:
        text: Input string that was explained.
        sv: Attribution object for ``text`` (needs a ``values`` array).
        k: Number of top-ranked positions to use.
        use_abs: Rank attributions by absolute value when True.
        class_idx: Class to evaluate; the predicted class when None.

    Returns:
        Dict with the class name, the three probabilities (``p_full``,
        ``p_drop``, ``p_keep``) and the two scores.
    """
    base_probs = predict([text])[0]
    if class_idx is None:
        yhat = int(np.argmax(base_probs))
    else:
        yhat = int(class_idx)
    p_full = float(base_probs[yhat])

    topk_idx = topk_token_indices_for_class(sv, yhat, k=k, use_abs=use_abs)

    # Probability after masking the top-k tokens, and after keeping only them.
    p_drop = float(predict_with_manual_mask(text, topk_idx, keep=False)[yhat])
    p_keep = float(predict_with_manual_mask(text, topk_idx, keep=True)[yhat])

    return {
        "pred_class": labels[yhat],
        "p_full": p_full,
        "p_drop": p_drop,
        "p_keep": p_keep,
        "comprehensiveness": p_full - p_drop,
        "sufficiency": p_full - p_keep,
    }


def evaluate_faithfulness(shap_values, samples, k: int = 10, use_abs: bool = True):
    """Compute faithfulness metrics for every (explanation, text) pair.

    Each result row is extended with the sample index, the gold label read from
    the module-level ``df``, the value of ``k``, comprehensiveness divided by
    ``p_full`` (plus a small epsilon) and the absolute value of sufficiency.

    Returns:
        ``pandas.DataFrame`` with one row per sample.
    """
    rows = []
    for i, (sv, txt) in enumerate(zip(shap_values, samples)):
        m = faithfulness_metrics(txt, sv, k=k, use_abs=use_abs)
        m.update(i=i, gold=str(df.iloc[i][gold_col]), k=k)
        # Epsilon guards against division by zero.
        m["comp_norm"] = m["comprehensiveness"] / (m["p_full"] + 1e-8)
        m["suff_abs"] = abs(m["sufficiency"])
        rows.append(m)
    return pd.DataFrame(rows)


# Evaluate faithfulness for several top-k sizes and stack the results.
ks = [5, 10, 20]
per_k_frames = [evaluate_faithfulness(shap_values, samples, k=k) for k in ks]
res = pd.concat(per_k_frames, ignore_index=True)

metric_cols = ["comprehensiveness", "comp_norm", "sufficiency", "suff_abs"]

# Mean of each metric per k, then per (k, gold label).
print(res.groupby("k")[metric_cols].mean().round(3))

print(res.groupby(["k", "gold"])[metric_cols].mean().round(3))

res.to_csv("faithfulness_results.csv", index=False)
print("\nSaved: faithfulness_results.csv")


from PIL import Image, ImageChops
import os, glob, math


def trim_png(path, out=None, bg=(255, 255, 255), margin=4):
    """Crop an image to its non-background content plus a margin.

    The image is converted to RGB and compared with a solid ``bg`` image; the
    bounding box of the differing pixels, grown by ``margin`` pixels and
    clipped to the image, is kept. When nothing differs from the background
    the converted image is saved uncropped.

    Args:
        path: Image file to read.
        out: Destination path; ``path`` is overwritten when omitted.
        bg: Background colour as an RGB tuple.
        margin: Number of pixels to keep around the content.

    Returns:
        The path that was written.
    """
    im = Image.open(path).convert("RGB")
    background = Image.new("RGB", im.size, bg)
    bbox = ImageChops.difference(im, background).getbbox()
    out = out or path

    if bbox is None:
        im.save(out)
        return out

    left, top, right, bottom = bbox
    crop_box = (
        max(0, left - margin),
        max(0, top - margin),
        min(im.width, right + margin),
        min(im.height, bottom + margin),
    )
    im.crop(crop_box).save(out)
    return out


def stack_wide_png(
    path, out=None, max_panel_width=1200, n_cols=None, overlap=32, pad=12, bg="white"
):
    """Cut a wide image into vertical slices and stack them top to bottom.

    Args:
        path: Image file to read.
        out: Destination path; defaults to ``<root>_stack<n_cols><ext>``.
        max_panel_width: Used only when ``n_cols`` is None, to derive the
            number of slices (at least 2).
        n_cols: Number of slices to cut the image into.
        overlap: Extra pixels each slice except the last takes from the start
            of the next one.
        pad: Vertical gap in pixels between stacked slices.
        bg: ``"transparent"`` for a fully transparent canvas; any other value
            gives an opaque white canvas.

    Returns:
        The path that was written.
    """
    im = Image.open(path).convert("RGBA")
    W, H = im.size
    if n_cols is None:
        n_cols = max(2, math.ceil(W / max_panel_width))
    col_w = math.ceil(W / n_cols)

    # Slice i starts at i * col_w; all but the last reach `overlap` pixels
    # into the next slice (clipped at the right edge of the image).
    panels = []
    for i in range(n_cols):
        x0 = i * col_w
        extra = overlap if i < n_cols - 1 else 0
        x1 = min(W, x0 + col_w + extra)
        panels.append(im.crop((x0, 0, x1, H)))

    newW = max(panel.width for panel in panels)
    newH = sum(panel.height for panel in panels) + pad * (len(panels) - 1)
    fill = (255, 255, 255, 0) if bg == "transparent" else (255, 255, 255, 255)
    canvas = Image.new("RGBA", (newW, newH), fill)

    y = 0
    for panel in panels:
        canvas.paste(panel, (0, y))
        y += panel.height + pad

    root, ext = os.path.splitext(path)
    out = out or f"{root}_stack{n_cols}{ext}"
    canvas.save(out)
    print("saved:", out)
    return out


# Trim the border of every PNG in Fig/ (in place).
for p in sorted(glob.glob("Fig/*.png")):
    trim_png(p, margin=6)

# Split each wide SHAP figure into two stacked rows and trim the result.
# The glob also matches the "*_stack<N>.png" files written by a previous run
# of this loop; skip them so that re-running the cell does not stack an
# already stacked image again.
for p in sorted(glob.glob("Fig/shap_*.png")):
    if "_stack" in os.path.basename(p):
        continue
    stacked = stack_wide_png(p, n_cols=2, max_panel_width=1400, overlap=36, pad=18)
    trim_png(stacked, margin=4)


from PIL import Image
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager


def _open_driver(width=2400, scale=1.8):
    """Start a headless Chrome WebDriver.

    Args:
        width: Initial window width in pixels (the initial height is 1200).
        scale: Value passed to ``--force-device-scale-factor``.

    Returns:
        A ``webdriver.Chrome`` instance; the caller is responsible for
        quitting it.
    """
    chrome_flags = (
        "--headless=new",
        "--disable-gpu",
        "--hide-scrollbars",
        f"--window-size={width},1200",
        f"--force-device-scale-factor={scale}",
    )
    opts = Options()
    for flag in chrome_flags:
        opts.add_argument(flag)

    driver_service = Service(ChromeDriverManager().install())
    return webdriver.Chrome(service=driver_service, options=opts)


def _smart_lime_bbox(driver, container_sel="#explanation", min_area=35000):
    """Return the union bounding box of the "interesting" elements on the page.

    Runs a script in the browser that walks every descendant of the element
    matched by ``container_sel`` (or of ``document.body`` when the selector
    matches nothing) and collects the bounding rectangles of visible elements
    that are at least 2x2 pixels and are either an ``svg``/``canvas``/``img``,
    have a class name matching
    ``highlight|text|raw|table|explanation|prob|prediction``, or cover more
    than ``min_area`` square pixels.

    Args:
        driver: Object providing ``execute_script`` (a Selenium WebDriver).
        container_sel: CSS selector of the preferred root element.
        min_area: Area threshold (width * height) for unclassified elements.

    Returns:
        The script result: a mapping with keys ``x``, ``y``, ``w``, ``h``
        (viewport coordinates from ``getBoundingClientRect``), or ``None`` when
        no element qualified.
    """
    script = """
    const root = document.querySelector(arguments[0]) || document.body;
    const minArea = arguments[1];
    const els = root.querySelectorAll("*");
    const boxes = [];
    for (const el of els) {
      const cs = getComputedStyle(el);
      if (cs.display==='none' || cs.visibility==='hidden' || parseFloat(cs.opacity)===0) continue;
      const r = el.getBoundingClientRect();
      if (r.width < 2 || r.height < 2) continue;
      const tag = el.tagName.toLowerCase();
      const area = r.width * r.height;
      const cls = (el.className || "").toString();
      const importantText = /highlight|text|raw|table|explanation|prob|prediction/i.test(cls);
      const graphic = (tag==='svg' || tag==='canvas' || tag==='img');
      if (graphic || importantText || area > minArea) {
        boxes.push(r);
      }
    }
    if (boxes.length === 0) return null;
    let minX=Infinity, minY=Infinity, maxX=-Infinity, maxY=-Infinity;
    for (const b of boxes){ minX=Math.min(minX,b.left); minY=Math.min(minY,b.top); maxX=Math.max(maxX,b.right); maxY=Math.max(maxY,b.bottom); }
    return {x:minX, y:minY, w:maxX-minX, h:maxY-minY};
    """
    # The selector and the area threshold reach the script as arguments[0/1].
    return driver.execute_script(script, container_sel, min_area)


def _fallback_union_bbox(driver):
    """Return the union bounding box of all visible elements under ``body``.

    The in-browser script skips elements that are hidden (``display: none``,
    ``visibility: hidden`` or zero opacity) or smaller than 2x2 pixels and
    merges the bounding rectangles of the rest.

    Args:
        driver: Object providing ``execute_script`` (a Selenium WebDriver).

    Returns:
        The script result: a mapping with keys ``x``, ``y``, ``w``, ``h``
        (viewport coordinates from ``getBoundingClientRect``), or ``None`` when
        no element qualified.
    """
    return driver.execute_script(
        """
      const els = document.body.querySelectorAll('*');
      let minX=Infinity,minY=Infinity,maxX=-Infinity,maxY=-Infinity, found=false;
      for(const el of els){
        const cs = getComputedStyle(el);
        if(cs.display==='none'||cs.visibility==='hidden'||parseFloat(cs.opacity)===0) continue;
        const r = el.getBoundingClientRect();
        if(r.width<2||r.height<2) continue;
        minX=Math.min(minX,r.left); minY=Math.min(minY,r.top);
        maxX=Math.max(maxX,r.right); maxY=Math.max(maxY,r.bottom);
        found=true;
      }
      if(!found) return null;
      return {x:minX,y:minY,w:maxX-minX,h:maxY-minY};
    """
    )


def html_to_png_lime(
    html_path,
    out_png,
    container_candidates=(
        "#explanation",
        ".explanation",
        ".explanation-container",
        ".lime",
        ".lime-container",
    ),
    width=2400,
    scale=1.8,
    pad_x=28,
    pad_y=16,
    extra_h=800,
):
    """Render a saved LIME HTML page in headless Chrome and save a cropped PNG.

    The page is opened from disk, the window is resized to the document height
    plus ``extra_h`` so the whole page fits into one screenshot, and the
    screenshot is cropped to the bounding box of the explanation content grown
    by ``pad_x``/``pad_y`` CSS pixels.

    Args:
        html_path: Path of the HTML file to render.
        out_png: Path of the PNG file to write.
        container_candidates: CSS selectors tried in order with
            ``_smart_lime_bbox``; the first one yielding a box is used.
        width: Browser window width in pixels.
        scale: Device scale factor passed to the browser.
        pad_x: Horizontal padding around the content, in CSS pixels.
        pad_y: Vertical padding around the content, in CSS pixels.
        extra_h: Extra window height added on top of the document height.

    Raises:
        RuntimeError: If no visible content is found, or if the region to
            capture is empty after clipping it to the screenshot.
    """
    # Imported locally so the function does not rely on notebook-level imports
    # of these modules being present.
    import io
    import time
    from pathlib import Path

    # as_uri() produces a well-formed, percent-encoded file:// URL on every
    # platform (plain concatenation gives "file:////..." for POSIX paths and
    # leaves spaces and other special characters unescaped).
    url = Path(os.path.abspath(html_path)).as_uri()

    drv = _open_driver(width=width, scale=scale)
    try:
        drv.get(url)
        time.sleep(0.6)  # give the page scripts time to draw the explanation

        doc_h = drv.execute_script(
            "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight);"
        )
        drv.set_window_size(width, int(doc_h) + extra_h)
        time.sleep(0.1)  # let the layout settle after the resize

        rect = None
        for sel in container_candidates:
            rect = _smart_lime_bbox(drv, sel, min_area=35000)
            if rect:
                break
        if rect is None:
            rect = _fallback_union_bbox(drv)
            if rect is None:
                raise RuntimeError("LIME: no visible content to capture.")

        # The box is in CSS pixels; the screenshot is in device pixels.
        dpr = drv.execute_script("return window.devicePixelRatio || 1;")
        left = int((rect["x"] - pad_x) * dpr)
        top = int((rect["y"] - pad_y) * dpr)
        right = int((rect["x"] + rect["w"] + pad_x) * dpr)
        bottom = int((rect["y"] + rect["h"] + pad_y) * dpr)

        png = drv.get_screenshot_as_png()
        img = Image.open(io.BytesIO(png))
        left = max(0, left)
        top = max(0, top)
        right = min(img.width, right)
        bottom = min(img.height, bottom)
        if right <= left or bottom <= top:
            raise RuntimeError("LIME: capture region is empty after clipping.")
        img.crop((left, top, right, bottom)).save(out_png)
        print(f"saved (LIME): {out_png}")
    finally:
        # Always shut the browser down, even when rendering failed.
        drv.quit()


# Convert every saved LIME page in the working directory to a PNG next to it.
for html_file in sorted(glob.glob("lime_*.html")):
    png_file = html_file.replace(".html", ".png")
    html_to_png_lime(
        html_file,
        png_file,
        width=2400,
        scale=1.9,
        pad_x=36,
        pad_y=18,
    )


[LIME 00] gold: bug | pred: bug (conf 0.93)

[LIME 01] gold: bug | pred: bug (conf 0.93)

[LIME 02] gold: bug | pred: bug (conf 0.79)

[LIME 03] gold: bug | pred: bug (conf 0.90)

[LIME 04] gold: bug | pred: bug (conf 0.95)

[LIME 05] gold: bug | pred: bug (conf 0.94)

[LIME 06] gold: bug | pred: bug (conf 0.95)

[LIME 07] gold: bug | pred: bug (conf 0.95)

[LIME 08] gold: bug | pred: question (conf 0.57)

[LIME 09] gold: bug | pred: bug (conf 0.92)

[LIME 10] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 11] gold: enhancement | pred: enhancement (conf 0.95)

[LIME 12] gold: enhancement | pred: enhancement (conf 0.91)

[LIME 13] gold: enhancement | pred: enhancement (conf 0.85)

[LIME 14] gold: enhancement | pred: enhancement (conf 0.82)

[LIME 15] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 16] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 17] gold: enhancement | pred: enhancement (conf 0.96)

[LIME 18] gold: enhancement | pred: bug (conf 0.97)

[LI

PartitionExplainer explainer: 31it [00:55,  2.06s/it]                        



SHAP ready: 30 samples
[SHAP 00] saved → shap_00.html (pred=bug, p=0.93)
[SHAP 01] saved → shap_01.html (pred=bug, p=0.93)
[SHAP 02] saved → shap_02.html (pred=bug, p=0.79)
[SHAP 03] saved → shap_03.html (pred=bug, p=0.90)
[SHAP 04] saved → shap_04.html (pred=bug, p=0.95)
[SHAP 05] saved → shap_05.html (pred=bug, p=0.94)
[SHAP 06] saved → shap_06.html (pred=bug, p=0.95)
[SHAP 07] saved → shap_07.html (pred=bug, p=0.95)
[SHAP 08] saved → shap_08.html (pred=question, p=0.57)
[SHAP 09] saved → shap_09.html (pred=bug, p=0.92)
[SHAP 10] saved → shap_10.html (pred=enhancement, p=0.98)
[SHAP 11] saved → shap_11.html (pred=enhancement, p=0.95)
[SHAP 12] saved → shap_12.html (pred=enhancement, p=0.91)
[SHAP 13] saved → shap_13.html (pred=enhancement, p=0.85)
[SHAP 14] saved → shap_14.html (pred=enhancement, p=0.82)
[SHAP 15] saved → shap_15.html (pred=enhancement, p=0.98)
[SHAP 16] saved → shap_16.html (pred=enhancement, p=0.98)
[SHAP 17] saved → shap_17.html (pred=enhancement, p=0.96)
[SHAP 1