In [None]:
import os

from huggingface_hub import login

# Never hardcode an access token in a notebook (it leaks through version
# control and shared copies). Read it from the HF_TOKEN environment variable;
# when the variable is unset, token=None makes login() fall back to its
# interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import pandas as pd

# Source CSV, the column holding the label, the labels to keep, and the
# number of rows to draw per label.
SRC = "../data/github-labels-top3-803k-test.csv"
label_col = "issue_label"
target_labels = ["bug", "enhancement", "question"]
n_per_class = 10

df_all = pd.read_csv(SRC)
# Keep only rows whose label is one of target_labels.
df_all = df_all[df_all[label_col].isin(target_labels)]


# Draw n_per_class rows from every label group with a fixed random_state,
# order the result by label and renumber the index from 0.
df_eval = (
    df_all.groupby(label_col, group_keys=False)
    .sample(n=n_per_class, random_state=200)
    .sort_values([label_col])
    .reset_index(drop=True)
)

# Persist the sample; the later cells read this file back as CSV_PATH.
df_eval.to_csv("balanced_eval_set.csv", index=False)

# Show the number of rows per label.
print(df_eval[label_col].value_counts())

issue_label
bug            10
enhancement    10
question       10
Name: count, dtype: int64


In [4]:
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
from lime.lime_text import LimeTextExplainer
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import shap, torch, os
from typing import Optional

# Load the config, classifier and tokenizer from the local model directory.
MODEL_PATH = "../models/nlbse/"
config = AutoConfig.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Use the GPU when one is available and put the model into evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Human-readable class names. They are used when the config has no id2label,
# or when its names are exactly the placeholder set LABEL_0..LABEL_{n-1};
# otherwise the names stored in the config are used as-is.
display_labels = ["bug", "enhancement", "question"]
if getattr(model.config, "id2label", None):
    raw = [model.config.id2label[i] for i in range(model.config.num_labels)]
    labels = (
        display_labels if set(raw) == {f"LABEL_{i}" for i in range(len(raw))} else raw
    )
else:
    labels = display_labels
# The (Korean) assertion message says that num_labels and len(labels) differ.
assert len(labels) == model.config.num_labels, "num_labels와 labels 길이가 다릅니다."


def predict(texts, batch_size=32):
    """Return softmax class probabilities for a collection of texts.

    The inputs are run through the model in chunks of ``batch_size`` so that
    memory use stays bounded: the LIME call in this notebook hands this
    function hundreds of perturbed texts in a single call, and forwarding them
    as one batch can exhaust GPU/CPU memory.

    Args:
        texts: Iterable of strings (list, tuple or array of str).
        batch_size: Maximum number of texts per forward pass; must be >= 1.

    Returns:
        numpy array of shape (len(texts), num_labels) with one probability
        row per input text, in input order. An empty input yields an array
        with zero rows instead of raising inside the tokenizer.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Move every tokenizer output tensor to the model's device.
            enc = {k: v.to(device) for k, v in enc.items()}
            logits = model(**enc).logits
            chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)


# Load the balanced evaluation set written by the sampling cell; fail early
# with an explicit message when the file is not in the working directory.
CSV_PATH = "balanced_eval_set.csv"
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(
        f"{CSV_PATH} file not found. Run the first cell to create it."
    )
df = pd.read_csv(CSV_PATH).reset_index(drop=True)

# Column names used to build the model input and to read the gold label.
text_col_title = "issue_title"
text_col_body = "issue_body"
gold_col = "issue_label"


def build_text(row, title_col=text_col_title, body_col=text_col_body, max_chars=800):
    """Build the model input text for one row: title, blank line, truncated body.

    Args:
        row: Mapping-like object with a ``.get`` method (e.g. a pandas Series).
        title_col: Key of the title field.
        body_col: Key of the body field.
        max_chars: Maximum number of body characters kept.

    Returns:
        ``title + "\\n\\n" + body[:max_chars]`` with surrounding whitespace
        stripped. Missing values (None / NaN / NA) count as empty strings.
    """

    def _as_text(value):
        # pandas reads empty CSV cells as NaN; str(NaN) would otherwise inject
        # the literal word "nan" into the text fed to the model.
        return "" if pd.isna(value) else str(value)

    title = _as_text(row.get(title_col, ""))
    body = _as_text(row.get(body_col, ""))[:max_chars]
    return (title + "\n\n" + body).strip()


# One input string per row of df, in row order.
samples = [build_text(row) for _, row in df.iterrows()]

# LIME text explainer with the class names and a fixed random_state.
explainer = LimeTextExplainer(class_names=labels, random_state=42)

# Explain every sample for its predicted class and save lime_NN.html.
for n, row in enumerate(df.itertuples(index=False)):
    txt = samples[n]
    probs = predict([txt])[0]
    idx = int(probs.argmax())
    predlab = labels[idx]
    conf = float(probs[idx])

    print(
        f"\n[LIME {n:02d}] gold: {getattr(row, gold_col)} | pred: {predlab} (conf {conf:.2f})"
    )
    # Only the predicted class (labels=[idx]) is explained.
    exp = explainer.explain_instance(
        txt, predict, labels=[idx], num_features=12, num_samples=600
    )

    html = exp.as_html(text=txt)
    # display(HTML(html))
    with open(f"lime_{n:02d}.html", "w", encoding="utf-8") as f:
        f.write(html)


# Extra CSS placed in front of every saved SHAP page (margins, stacking order
# and line height of the SHAP text plot elements).
CSS_FIX = """
<style>
  html, body { margin: 8px; font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
  .shap { position: relative; }
  .shap .top { margin-bottom: 36px !important; position: relative; z-index: 1; }
  .shap .text, .shap .inputs, .shap .labels { position: relative; z-index: 9999 !important; }
  .shap .text span { line-height: 1.65 !important; } 
  svg { overflow: visible !important; }
</style>
"""


# SHAP text masker built on the model tokenizer; falls back to the literal
# "[MASK]" when the tokenizer defines no mask token.
masker = shap.maskers.Text(
    tokenizer=tokenizer, mask_token=(tokenizer.mask_token or "[MASK]")
)
shap_explainer = shap.Explainer(
    predict, masker, output_names=labels, algorithm="partition"
)

# Explain all samples in one call.
shap_values = shap_explainer(samples)
print("\nSHAP ready:", len(shap_values), "samples")

# Save one HTML page per sample showing the SHAP text plot for the predicted class.
for i, sv in enumerate(shap_values):
    txt = samples[i]
    probs = predict([txt])[0]
    cls = int(np.argmax(probs))

    html_obj = shap.plots.text(sv[..., cls], display=False)
    # Use the object's .data attribute when present, else its string form.
    body = getattr(html_obj, "data", str(html_obj))
    page = f"<!doctype html><meta charset='utf-8'>{CSS_FIX}{body}"
    out = f"shap_{i:02d}.html"
    with open(out, "w", encoding="utf-8") as f:
        f.write(page)
    print(f"[SHAP {i:02d}] saved → {out} (pred={labels[cls]}, p={probs[cls]:.2f})")
# Id of the mask token, used by predict_with_manual_mask below; looked up from
# the literal "[MASK]" when the tokenizer's mask_token_id is falsy.
mask_token_id = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")


def predict_with_manual_mask(text, mask_indices, keep=False, max_length=512):
    """Return class probabilities after replacing selected tokens by the mask token.

    The text is split into tokens by the tokenizer, cut to ``max_length - 2``
    tokens and wrapped in the CLS/SEP ids. Indices in ``mask_indices`` refer to
    positions in that token list (0 = first token after CLS).

    Args:
        text: Input string.
        mask_indices: Iterable of token indices.
        keep: When False, the listed tokens are masked. When True, every token
            that is NOT listed is masked.
        max_length: Maximum sequence length including the two special tokens.

    Returns:
        1-D numpy array of class probabilities.
    """
    wp_tokens = tokenizer.tokenize(text)[: max_length - 2]
    input_ids = [
        tokenizer.cls_token_id,
        *tokenizer.convert_tokens_to_ids(wp_tokens),
        tokenizer.sep_token_id,
    ]

    if keep:
        kept = set(mask_indices)
        to_mask = [t for t in range(len(wp_tokens)) if t not in kept]
    else:
        # Ignore indices that would fall on CLS/SEP or outside the sequence.
        to_mask = [t for t in mask_indices if 0 < 1 + t < len(input_ids) - 1]

    for t in to_mask:
        input_ids[1 + t] = mask_token_id  # +1 skips the leading CLS id

    batch = torch.tensor([input_ids]).to(device)
    attention = torch.ones_like(batch).to(device)
    with torch.no_grad():
        logits = model(input_ids=batch, attention_mask=attention).logits
        probs = torch.softmax(logits, dim=-1)
    return probs.squeeze(0).detach().cpu().numpy()


def topk_token_indices_for_class(sv, class_idx: int, k: int = 10, use_abs: bool = True):
    """Return the indices of the ``k`` highest-scoring tokens for one class.

    Args:
        sv: Object whose ``values`` attribute is a 2-D array indexed as
            ``[token, class]``.
        class_idx: Column of ``sv.values`` to rank.
        k: Maximum number of indices returned.
        use_abs: Rank by absolute value when True, by signed value otherwise.

    Returns:
        List of token indices, highest score first, at most ``k`` long.
    """
    class_scores = sv.values[:, class_idx]
    sort_key = np.abs(class_scores) if use_abs else class_scores
    ranking = np.argsort(-sort_key)
    limit = min(k, len(ranking))
    return ranking[:limit].tolist()


def faithfulness_metrics(
    text: str, sv, k: int = 10, use_abs: bool = True, class_idx: Optional[int] = None
):
    """Compute comprehensiveness and sufficiency of a SHAP explanation for one text.

    The top-``k`` tokens for the target class are first masked out
    (comprehensiveness = p_full - p_drop) and then kept as the only unmasked
    tokens (sufficiency = p_full - p_keep).

    NOTE(review): the indices come from ``sv.values`` rows and are applied to
    the tokenizer's token list by predict_with_manual_mask; this assumes both
    use the same token positions -- TODO confirm.

    Args:
        text: Input string that was explained.
        sv: Explanation for ``text`` (``sv.values`` indexed ``[token, class]``).
        k: Number of top tokens to mask / keep.
        use_abs: Rank tokens by absolute attribution when True.
        class_idx: Class to evaluate; the predicted class when None.

    Returns:
        Dict with keys pred_class, p_full, p_drop, p_keep, comprehensiveness
        and sufficiency.
    """
    full_probs = predict([text])[0]
    target = int(class_idx) if class_idx is not None else int(np.argmax(full_probs))
    p_full = float(full_probs[target])

    top_tokens = topk_token_indices_for_class(sv, target, k=k, use_abs=use_abs)

    def _target_prob(keep_only):
        # Probability of the target class after masking with the top tokens.
        return float(predict_with_manual_mask(text, top_tokens, keep=keep_only)[target])

    p_drop = _target_prob(False)
    p_keep = _target_prob(True)

    return {
        "pred_class": labels[target],
        "p_full": p_full,
        "p_drop": p_drop,
        "p_keep": p_keep,
        "comprehensiveness": p_full - p_drop,
        "sufficiency": p_full - p_keep,
    }


def evaluate_faithfulness(shap_values, samples, k: int = 10, use_abs: bool = True):
    """Compute faithfulness metrics for every (explanation, text) pair.

    Each row holds the output of faithfulness_metrics plus the sample position
    ``i``, the gold label read from the module-level ``df``, ``k``, the
    comprehensiveness divided by ``p_full`` (``comp_norm``) and the absolute
    sufficiency (``suff_abs``).

    Returns:
        pandas DataFrame with one row per sample.
    """
    records = []
    for pos, (explanation, sample_text) in enumerate(zip(shap_values, samples)):
        record = faithfulness_metrics(sample_text, explanation, k=k, use_abs=use_abs)
        record.update(i=pos, gold=str(df.iloc[pos][gold_col]), k=k)
        # The small epsilon guards against division by zero.
        record["comp_norm"] = record["comprehensiveness"] / (record["p_full"] + 1e-8)
        record["suff_abs"] = abs(record["sufficiency"])
        records.append(record)
    return pd.DataFrame(records)


# Evaluate faithfulness for several values of k and stack the per-k tables.
ks = [5, 10, 20]
res = pd.concat(
    [evaluate_faithfulness(shap_values, samples, k=k) for k in ks], ignore_index=True
)

# Mean metrics per k, rounded to three decimals.
print(
    res.groupby("k")[["comprehensiveness", "comp_norm", "sufficiency", "suff_abs"]]
    .mean()
    .round(3)
)

# Mean metrics per (k, gold label), rounded to three decimals.
print(
    res.groupby(["k", "gold"])[
        ["comprehensiveness", "comp_norm", "sufficiency", "suff_abs"]
    ]
    .mean()
    .round(3)
)

# Persist the per-sample results.
res.to_csv("faithfulness_results.csv", index=False)
print("\nSaved: faithfulness_results.csv")


from PIL import Image, ImageChops
import os, glob, math


def trim_png(path, out=None, bg=(255, 255, 255), margin=4):
    """Crop away the border of an image that matches the background colour.

    Args:
        path: Image file to read.
        out: Destination file; the input file is overwritten when omitted.
        bg: RGB background colour treated as empty.
        margin: Pixels of background left around the content.

    Returns:
        The path the (RGB) image was written to. When the whole image equals
        the background, it is saved without cropping.
    """
    target = out or path
    image = Image.open(path).convert("RGB")
    backdrop = Image.new("RGB", image.size, bg)
    content_box = ImageChops.difference(image, backdrop).getbbox()

    if content_box is not None:
        x0, y0, x1, y1 = content_box
        # Grow the content box by the margin, clamped to the image bounds.
        image = image.crop(
            (
                max(0, x0 - margin),
                max(0, y0 - margin),
                min(image.width, x1 + margin),
                min(image.height, y1 + margin),
            )
        )

    image.save(target)
    return target


def stack_wide_png(
    path, out=None, max_panel_width=1200, n_cols=None, overlap=32, pad=12, bg="white"
):
    """Split a wide image into vertical strips and stack them top to bottom.

    Args:
        path: Image file to read.
        out: Destination file; defaults to ``<root>_stack<n_cols><ext>``.
        max_panel_width: Used to derive ``n_cols`` when it is None.
        n_cols: Number of strips; when None, ``max(2, ceil(W / max_panel_width))``.
        overlap: Extra pixels added to the right of every strip but the last.
        pad: Vertical gap in pixels between stacked strips.
        bg: ``"transparent"`` for a transparent canvas, anything else for white.

    Returns:
        The path the stacked image was written to.
    """
    source = Image.open(path).convert("RGBA")
    width, height = source.size
    if n_cols is None:
        n_cols = max(2, math.ceil(width / max_panel_width))
    step = math.ceil(width / n_cols)
    last = n_cols - 1

    panels = [
        source.crop(
            (
                idx * step,
                0,
                min(width, idx * step + step + (overlap if idx < last else 0)),
                height,
            )
        )
        for idx in range(n_cols)
    ]

    canvas_w = max(panel.width for panel in panels)
    canvas_h = sum(panel.height for panel in panels) + pad * (len(panels) - 1)
    fill = (255, 255, 255, 0) if bg == "transparent" else (255, 255, 255, 255)
    canvas = Image.new("RGBA", (canvas_w, canvas_h), fill)

    offset = 0
    for panel in panels:
        canvas.paste(panel, (0, offset))
        offset += panel.height + pad

    root, ext = os.path.splitext(path)
    out = out or f"{root}_stack{n_cols}{ext}"
    canvas.save(out)
    print("saved:", out)
    return out


# Trim the background border of every PNG in Fig/ (files are overwritten).
for p in sorted(glob.glob("Fig/*.png")):
    trim_png(p, margin=6)

# For SHAP figures, additionally create a two-strip stacked version and trim it.
for p in sorted(glob.glob("Fig/shap_*.png")):
    stacked = stack_wide_png(p, n_cols=2, max_panel_width=1400, overlap=36, pad=18)
    trim_png(stacked, margin=4)


from PIL import Image
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager


def _open_driver(width=2400, scale=1.8):
    """Start a headless Chrome WebDriver.

    Args:
        width: Window width in pixels (the initial height is fixed at 1200).
        scale: Value passed as the device scale factor.

    Returns:
        A selenium Chrome driver using the binary installed by ChromeDriverManager.
    """
    chrome_flags = (
        "--headless=new",
        "--disable-gpu",
        "--hide-scrollbars",
        f"--window-size={width},1200",
        f"--force-device-scale-factor={scale}",
    )
    opts = Options()
    for flag in chrome_flags:
        opts.add_argument(flag)
    return webdriver.Chrome(
        service=Service(ChromeDriverManager().install()), options=opts
    )


def _smart_lime_bbox(driver, container_sel="#explanation", min_area=35000):
    """Return the union bounding box of the "interesting" visible elements.

    Runs a JavaScript snippet in the current page. The search root is the first
    element matching ``container_sel``, or ``document.body`` when nothing
    matches. Every visible descendant at least 2x2 px is included when it is
    an svg/canvas/img, when its class name matches
    highlight|text|raw|table|explanation|prob|prediction, or when its area
    exceeds ``min_area``.

    Returns:
        Dict ``{x, y, w, h}`` built from the elements' getBoundingClientRect()
        values, or None when no element qualified.
    """
    script = """
    const root = document.querySelector(arguments[0]) || document.body;
    const minArea = arguments[1];
    const els = root.querySelectorAll("*");
    const boxes = [];
    for (const el of els) {
      const cs = getComputedStyle(el);
      if (cs.display==='none' || cs.visibility==='hidden' || parseFloat(cs.opacity)===0) continue;
      const r = el.getBoundingClientRect();
      if (r.width < 2 || r.height < 2) continue;
      const tag = el.tagName.toLowerCase();
      const area = r.width * r.height;
      const cls = (el.className || "").toString();
      const importantText = /highlight|text|raw|table|explanation|prob|prediction/i.test(cls);
      const graphic = (tag==='svg' || tag==='canvas' || tag==='img');
      if (graphic || importantText || area > minArea) {
        boxes.push(r);
      }
    }
    if (boxes.length === 0) return null;
    let minX=Infinity, minY=Infinity, maxX=-Infinity, maxY=-Infinity;
    for (const b of boxes){ minX=Math.min(minX,b.left); minY=Math.min(minY,b.top); maxX=Math.max(maxX,b.right); maxY=Math.max(maxY,b.bottom); }
    return {x:minX, y:minY, w:maxX-minX, h:maxY-minY};
    """
    # arguments[0] -> container_sel, arguments[1] -> min_area inside the script.
    return driver.execute_script(script, container_sel, min_area)


def _fallback_union_bbox(driver):
    """Return the union bounding box of all visible elements under ``document.body``.

    Elements that are hidden (display none, visibility hidden, opacity 0) or
    smaller than 2x2 px are skipped.

    Returns:
        Dict ``{x, y, w, h}`` built from getBoundingClientRect() values, or
        None when no element qualified.
    """
    return driver.execute_script(
        """
      const els = document.body.querySelectorAll('*');
      let minX=Infinity,minY=Infinity,maxX=-Infinity,maxY=-Infinity, found=false;
      for(const el of els){
        const cs = getComputedStyle(el);
        if(cs.display==='none'||cs.visibility==='hidden'||parseFloat(cs.opacity)===0) continue;
        const r = el.getBoundingClientRect();
        if(r.width<2||r.height<2) continue;
        minX=Math.min(minX,r.left); minY=Math.min(minY,r.top);
        maxX=Math.max(maxX,r.right); maxY=Math.max(maxY,r.bottom);
        found=true;
      }
      if(!found) return null;
      return {x:minX,y:minY,w:maxX-minX,h:maxY-minY};
    """
    )


def html_to_png_lime(
    html_path,
    out_png,
    container_candidates=(
        "#explanation",
        ".explanation",
        ".explanation-container",
        ".lime",
        ".lime-container",
    ),
    width=2400,
    scale=1.8,
    pad_x=28,
    pad_y=16,
    extra_h=800,
):
    """Render a saved LIME HTML page in headless Chrome and save a cropped PNG.

    The page is opened, the window is resized to the document height plus
    ``extra_h``, a bounding box of the content is computed in the page, and the
    screenshot is cropped to that box plus padding.

    Args:
        html_path: Path of the HTML file to render.
        out_png: Destination PNG path.
        container_candidates: CSS selectors tried in order as search root for
            the content bounding box.
        width: Browser window width in pixels.
        scale: Device scale factor passed to the browser.
        pad_x: Horizontal padding (CSS px) added around the bounding box.
        pad_y: Vertical padding (CSS px) added around the bounding box.
        extra_h: Extra window height added below the document height.

    Raises:
        RuntimeError: When no visible content is found in the page.
    """
    # Local imports: the function needs these stdlib modules and the
    # surrounding code does not import them (the original raised NameError on
    # `time` / `io`).
    import io
    import time
    from pathlib import Path

    # as_uri() builds a well-formed, percent-encoded file:// URL on every OS.
    url = Path(os.path.abspath(html_path)).as_uri()
    drv = _open_driver(width=width, scale=scale)
    try:
        drv.get(url)
        time.sleep(0.6)  # give the page scripts time to render

        doc_h = drv.execute_script(
            "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight);"
        )
        # Make the window tall enough that the whole page fits in one screenshot.
        drv.set_window_size(width, int(doc_h) + extra_h)
        time.sleep(0.1)

        rect = None
        for sel in container_candidates:
            rect = _smart_lime_bbox(drv, sel, min_area=35000)
            if rect:
                break
        if rect is None:
            rect = _fallback_union_bbox(drv)
            if rect is None:
                raise RuntimeError("LIME: no visible content to capture.")

        # Convert CSS-pixel coordinates to screenshot pixels.
        dpr = drv.execute_script("return window.devicePixelRatio || 1;")
        left = int((rect["x"] - pad_x) * dpr)
        top = int((rect["y"] - pad_y) * dpr)
        right = int((rect["x"] + rect["w"] + pad_x) * dpr)
        bottom = int((rect["y"] + rect["h"] + pad_y) * dpr)

        png = drv.get_screenshot_as_png()
        img = Image.open(io.BytesIO(png))
        # Clamp the crop box to the screenshot bounds.
        left = max(0, left)
        top = max(0, top)
        right = min(img.width, right)
        bottom = min(img.height, bottom)
        img.crop((left, top, right, bottom)).save(out_png)
        print(f"saved (LIME): {out_png}")
    finally:
        # Always shut the browser down, even when rendering failed.
        drv.quit()


# Render every saved LIME page to a PNG with the same base name.
for p in sorted(glob.glob("lime_*.html")):
    html_to_png_lime(
        p, p.replace(".html", ".png"), width=2400, scale=1.9, pad_x=36, pad_y=18
    )


[LIME 00] gold: bug | pred: bug (conf 0.93)

[LIME 01] gold: bug | pred: bug (conf 0.93)

[LIME 02] gold: bug | pred: bug (conf 0.79)

[LIME 03] gold: bug | pred: bug (conf 0.90)

[LIME 04] gold: bug | pred: bug (conf 0.95)

[LIME 05] gold: bug | pred: bug (conf 0.94)

[LIME 06] gold: bug | pred: bug (conf 0.95)

[LIME 07] gold: bug | pred: bug (conf 0.95)

[LIME 08] gold: bug | pred: question (conf 0.57)

[LIME 09] gold: bug | pred: bug (conf 0.92)

[LIME 10] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 11] gold: enhancement | pred: enhancement (conf 0.95)

[LIME 12] gold: enhancement | pred: enhancement (conf 0.91)

[LIME 13] gold: enhancement | pred: enhancement (conf 0.85)

[LIME 14] gold: enhancement | pred: enhancement (conf 0.82)

[LIME 15] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 16] gold: enhancement | pred: enhancement (conf 0.98)

[LIME 17] gold: enhancement | pred: enhancement (conf 0.96)

[LIME 18] gold: enhancement | pred: bug (conf 0.97)

[LI

PartitionExplainer explainer: 31it [00:55,  2.06s/it]                        



SHAP ready: 30 samples
[SHAP 00] saved → shap_00.html (pred=bug, p=0.93)
[SHAP 01] saved → shap_01.html (pred=bug, p=0.93)
[SHAP 02] saved → shap_02.html (pred=bug, p=0.79)
[SHAP 03] saved → shap_03.html (pred=bug, p=0.90)
[SHAP 04] saved → shap_04.html (pred=bug, p=0.95)
[SHAP 05] saved → shap_05.html (pred=bug, p=0.94)
[SHAP 06] saved → shap_06.html (pred=bug, p=0.95)
[SHAP 07] saved → shap_07.html (pred=bug, p=0.95)
[SHAP 08] saved → shap_08.html (pred=question, p=0.57)
[SHAP 09] saved → shap_09.html (pred=bug, p=0.92)
[SHAP 10] saved → shap_10.html (pred=enhancement, p=0.98)
[SHAP 11] saved → shap_11.html (pred=enhancement, p=0.95)
[SHAP 12] saved → shap_12.html (pred=enhancement, p=0.91)
[SHAP 13] saved → shap_13.html (pred=enhancement, p=0.85)
[SHAP 14] saved → shap_14.html (pred=enhancement, p=0.82)
[SHAP 15] saved → shap_15.html (pred=enhancement, p=0.98)
[SHAP 16] saved → shap_16.html (pred=enhancement, p=0.98)
[SHAP 17] saved → shap_17.html (pred=enhancement, p=0.96)
[SHAP 1

In [None]:
# ===================== ALL-IN-ONE: LIME + SHAP + Faithfulness =====================
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
from lime.lime_text import LimeTextExplainer
from IPython.display import HTML, display
import pandas as pd
import numpy as np
import shap, torch, os

# ---- 0) Model / tokenizer / labels ----------------------------------------------
MODEL_PATH = "../models/nlbse/"
config = AutoConfig.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Use the GPU when one is available and put the model into evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Human-readable labels (swapped in for display when the model's own names are LABEL_0/1/2)
display_labels = ["bug", "enhancement", "question"]
if getattr(model.config, "id2label", None):
    raw = [model.config.id2label[i] for i in range(model.config.num_labels)]
    labels = (
        display_labels if set(raw) == {f"LABEL_{i}" for i in range(len(raw))} else raw
    )
else:
    labels = display_labels
# The (Korean) assertion message says that num_labels and len(labels) differ.
assert len(labels) == model.config.num_labels, "num_labels와 labels 길이가 다릅니다."


def predict(texts, batch_size=32):
    """Return softmax class probabilities for a collection of texts.

    The inputs are run through the model in chunks of ``batch_size`` so that
    memory use stays bounded: the LIME call in this notebook hands this
    function hundreds of perturbed texts in a single call, and forwarding them
    as one batch can exhaust GPU/CPU memory.

    Args:
        texts: Iterable of strings (list, tuple or array of str).
        batch_size: Maximum number of texts per forward pass; must be >= 1.

    Returns:
        numpy array of shape (len(texts), num_labels) with one probability
        row per input text, in input order. An empty input yields an array
        with zero rows instead of raising inside the tokenizer.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Move every tokenizer output tensor to the model's device.
            enc = {k: v.to(device) for k, v in enc.items()}
            logits = model(**enc).logits
            chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)


# ---- 1) Load the same samples (balanced sample) ---------------------------------
CSV_PATH = "balanced_eval_set.csv"
# The (Korean) error message says the file is missing and that the balanced
# sample must be created and saved first.
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(
        f"{CSV_PATH} 파일이 없습니다. 먼저 균형 샘플을 만들어 저장하세요."
    )
df = pd.read_csv(CSV_PATH).reset_index(drop=True)

# Column names used to build the model input and to read the gold label.
text_col_title = "issue_title"
text_col_body = "issue_body"
gold_col = "issue_label"


def build_text(row, title_col=text_col_title, body_col=text_col_body, max_chars=800):
    """Build the model input text for one row: title, blank line, truncated body.

    Args:
        row: Mapping-like object with a ``.get`` method (e.g. a pandas Series).
        title_col: Key of the title field.
        body_col: Key of the body field.
        max_chars: Maximum number of body characters kept.

    Returns:
        ``title + "\\n\\n" + body[:max_chars]`` with surrounding whitespace
        stripped. Missing values (None / NaN / NA) count as empty strings.
    """

    def _as_text(value):
        # pandas reads empty CSV cells as NaN; str(NaN) would otherwise inject
        # the literal word "nan" into the text fed to the model.
        return "" if pd.isna(value) else str(value)

    title = _as_text(row.get(title_col, ""))
    body = _as_text(row.get(body_col, ""))[:max_chars]
    return (title + "\n\n" + body).strip()


# One input string per row of df, in row order.
samples = [build_text(row) for _, row in df.iterrows()]

# ---- 2) LIME --------------------------------------------------------------------
explainer = LimeTextExplainer(class_names=labels, random_state=42)

# Explain every sample for its predicted class, show it inline and save lime_NN.html.
for n, row in enumerate(df.itertuples(index=False)):
    txt = samples[n]
    probs = predict([txt])[0]
    idx = int(probs.argmax())
    predlab = labels[idx]
    conf = float(probs[idx])

    print(
        f"\n[LIME {n:02d}] gold: {getattr(row, gold_col)} | pred: {predlab} (conf {conf:.2f})"
    )
    # Only the predicted class (labels=[idx]) is explained.
    exp = explainer.explain_instance(
        txt, predict, labels=[idx], num_features=12, num_samples=600
    )

    html = exp.as_html(text=txt)
    display(HTML(html))
    with open(f"lime_{n:02d}.html", "w", encoding="utf-8") as f:
        f.write(html)

# ---- 3) SHAP (same samples) -----------------------------------------------------
masker = shap.maskers.Text(
    tokenizer=tokenizer, mask_token=(tokenizer.mask_token or "[MASK]")
)
shap_explainer = shap.Explainer(
    predict, masker, output_names=labels, algorithm="partition"
)

# If this is slow, run on a subset only: e.g., idx = list(range(30)); samples = [samples[i] for i in idx]; df = df.iloc[idx].reset_index(drop=True)
shap_values = shap_explainer(samples)
print("\nSHAP ready:", len(shap_values), "samples")

# Save (for the predicted class) and show gold/pred in the page header
for i, sv in enumerate(shap_values):
    txt = samples[i]
    probs = predict([txt])[0]
    cls = int(np.argmax(probs))
    predlab = labels[cls]
    conf = float(probs[cls])
    gold = str(df.iloc[i][gold_col])

    html_obj = shap.plots.text(sv[..., cls], display=False)
    # Use the object's .data attribute when present, else its string form.
    body = getattr(html_obj, "data", str(html_obj))

    with open(f"shap_{i:02d}.html", "w", encoding="utf-8") as f:
        f.write(
            f"<meta charset='utf-8'>\n<h2>[{i:02d}] gold: {gold} | pred: {predlab} (conf {conf:.2f})</h2>\n"
            + body
        )
    print(
        f"[SHAP {i:02d}] gold={gold} | pred={predlab} (conf={conf:.2f}) → shap_{i:02d}.html"
    )

# ---- 4) Faithfulness (Comprehensiveness / Sufficiency) --------------------------
# Id of the mask token; looked up from the literal "[MASK]" when the
# tokenizer's mask_token_id is falsy.
mask_token_id = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")


def predict_with_manual_mask(text, mask_indices, keep=False, max_length=512):
    """Return class probabilities after replacing selected tokens by the mask token.

    The text is split into tokens by the tokenizer, cut to ``max_length - 2``
    tokens and wrapped in the CLS/SEP ids. Indices in ``mask_indices`` refer to
    positions in that token list (0 = first token after CLS).

    Args:
        text: Input string.
        mask_indices: Iterable of token indices.
        keep: When False, the listed tokens are masked. When True, every token
            that is NOT listed is masked.
        max_length: Maximum sequence length including the two special tokens.

    Returns:
        1-D numpy array of class probabilities.
    """
    wp_tokens = tokenizer.tokenize(text)[: max_length - 2]
    input_ids = [
        tokenizer.cls_token_id,
        *tokenizer.convert_tokens_to_ids(wp_tokens),
        tokenizer.sep_token_id,
    ]

    if keep:
        kept = set(mask_indices)
        to_mask = [t for t in range(len(wp_tokens)) if t not in kept]
    else:
        # Ignore indices that would fall on CLS/SEP or outside the sequence.
        to_mask = [t for t in mask_indices if 0 < 1 + t < len(input_ids) - 1]

    for t in to_mask:
        input_ids[1 + t] = mask_token_id  # +1 skips the leading CLS id

    batch = torch.tensor([input_ids]).to(device)
    attention = torch.ones_like(batch).to(device)
    with torch.no_grad():
        logits = model(input_ids=batch, attention_mask=attention).logits
        probs = torch.softmax(logits, dim=-1)
    return probs.squeeze(0).detach().cpu().numpy()

In [None]:
# === 6) Post-process: trim whitespace + split-wide-and-stack ====================
from PIL import Image, ImageChops
import os, glob, math


def trim_png(path, out=None, bg=(255, 255, 255), margin=4):
    """Crop away the border of an image that matches the background colour.

    Args:
        path: Image file to read.
        out: Destination file; the input file is overwritten when omitted.
        bg: RGB background colour treated as empty.
        margin: Pixels of background left around the content.

    Returns:
        The path the (RGB) image was written to. When the whole image equals
        the background, it is saved without cropping.
    """
    target = out or path
    image = Image.open(path).convert("RGB")
    backdrop = Image.new("RGB", image.size, bg)
    content_box = ImageChops.difference(image, backdrop).getbbox()

    if content_box is not None:
        x0, y0, x1, y1 = content_box
        # Grow the content box by the margin, clamped to the image bounds.
        image = image.crop(
            (
                max(0, x0 - margin),
                max(0, y0 - margin),
                min(image.width, x1 + margin),
                min(image.height, y1 + margin),
            )
        )

    image.save(target)
    return target


def stack_wide_png(
    path, out=None, max_panel_width=1200, n_cols=None, overlap=32, pad=12, bg="white"
):
    """Split a wide image into vertical strips and stack them top to bottom.

    Args:
        path: Image file to read.
        out: Destination file; defaults to ``<root>_stack<n_cols><ext>``.
        max_panel_width: Used to derive ``n_cols`` when it is None.
        n_cols: Number of strips; when None, ``max(2, ceil(W / max_panel_width))``.
        overlap: Extra pixels added to the right of every strip but the last,
            so that content at a cut boundary also appears in the left strip.
        pad: Vertical gap in pixels between stacked strips.
        bg: ``"transparent"`` for a transparent canvas, anything else for white.

    Returns:
        The path the stacked image was written to.
    """
    source = Image.open(path).convert("RGBA")
    width, height = source.size
    if n_cols is None:
        n_cols = max(2, math.ceil(width / max_panel_width))
    step = math.ceil(width / n_cols)
    last = n_cols - 1

    panels = [
        source.crop(
            (
                idx * step,
                0,
                min(width, idx * step + step + (overlap if idx < last else 0)),
                height,
            )
        )
        for idx in range(n_cols)
    ]

    canvas_w = max(panel.width for panel in panels)
    canvas_h = sum(panel.height for panel in panels) + pad * (len(panels) - 1)
    fill = (255, 255, 255, 0) if bg == "transparent" else (255, 255, 255, 255)
    canvas = Image.new("RGBA", (canvas_w, canvas_h), fill)

    offset = 0
    for panel in panels:
        canvas.paste(panel, (0, offset))
        offset += panel.height + pad

    root, ext = os.path.splitext(path)
    out = out or f"{root}_stack{n_cols}{ext}"
    canvas.save(out)
    print("saved:", out)
    return out


# --- 6.1 Trim the margins of every PNG ---
for p in sorted(glob.glob("Fig/*.png")):
    trim_png(p, margin=6)

# --- 6.2 Split SHAP figures into 2 strips (3 if very long) and create a vertically stacked version ---
for p in sorted(glob.glob("Fig/shap_*.png")):
    # Switch to n_cols=3 to get shorter panels if needed.
    stacked = stack_wide_png(p, n_cols=2, max_panel_width=1400, overlap=36, pad=18)
    # Trim the stacked result once more (removes border margins)
    trim_png(stacked, margin=4)

# (Optional) apply the same processing to LIME figures when they are wide
# for p in sorted(glob.glob("Fig/lime_*.png")):
#     stacked = stack_wide_png(p, n_cols=2, max_panel_width=1400, overlap=36, pad=18)
#     trim_png(stacked, margin=4)

In [6]:
# ================= MULTI-CLASS SHAP (LIME-Style) =================
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
import torch, numpy as np, pandas as pd, shap, os

# ---------- 0) Model / tokenizer / labels ----------
MODEL_PATH = "../models/nlbse/"
config = AutoConfig.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Use the GPU when one is available and put the model into evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Human-readable class names. They are used when the config has no id2label,
# or when its names are exactly the placeholder set LABEL_0..LABEL_{n-1};
# otherwise the names stored in the config are used as-is.
display_labels = ["bug", "enhancement", "question"]
if getattr(model.config, "id2label", None):
    raw = [model.config.id2label[i] for i in range(model.config.num_labels)]
    labels = (
        display_labels if set(raw) == {f"LABEL_{i}" for i in range(len(raw))} else raw
    )
else:
    labels = display_labels
# The (Korean) assertion message says that len(labels) and num_labels differ.
assert len(labels) == model.config.num_labels, "labels 길이와 num_labels가 다름"


def predict(texts, batch_size=32):
    """Return softmax class probabilities for a collection of texts.

    The inputs are run through the model in chunks of ``batch_size`` so that
    memory use stays bounded when an explainer hands this function a large
    number of perturbed texts in a single call.

    Args:
        texts: Iterable of strings (list, tuple or array of str).
        batch_size: Maximum number of texts per forward pass; must be >= 1.

    Returns:
        numpy array of shape (N, C): one probability row per input text, in
        input order. An empty input yields an array with zero rows instead of
        raising inside the tokenizer.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Move every tokenizer output tensor to the model's device.
            enc = {k: v.to(device) for k, v in enc.items()}
            probs = torch.softmax(model(**enc).logits, dim=-1).cpu().numpy()
            chunks.append(probs)
    return np.concatenate(chunks, axis=0)  # (N, C)


# ---------- 1) Load the same samples ----------
CSV_PATH = "balanced_eval_set.csv"
# The (Korean) error message says that balanced_eval_set.csv does not exist.
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError("balanced_eval_set.csv 가 없습니다.")

df = pd.read_csv(CSV_PATH).reset_index(drop=True)
# Column names used to build the model input and to read the gold label.
text_col_title = "issue_title"
text_col_body = "issue_body"
gold_col = "issue_label"


def build_text(row, title_col=text_col_title, body_col=text_col_body, max_chars=800):
    """Build the model input text for one row: title, blank line, truncated body.

    Args:
        row: Mapping-like object with a ``.get`` method (e.g. a pandas Series).
        title_col: Key of the title field.
        body_col: Key of the body field.
        max_chars: Maximum number of body characters kept.

    Returns:
        ``title + "\\n\\n" + body[:max_chars]`` with surrounding whitespace
        stripped. Missing values (None / NaN / NA) count as empty strings.
    """

    def _as_text(value):
        # pandas reads empty CSV cells as NaN; str(NaN) would otherwise inject
        # the literal word "nan" into the text fed to the model.
        return "" if pd.isna(value) else str(value)

    title = _as_text(row.get(title_col, ""))
    body = _as_text(row.get(body_col, ""))[:max_chars]
    return (title + "\n\n" + body).strip()


# One input string per row of df, in row order.
samples = [build_text(row) for _, row in df.iterrows()]

# ---------- 2) Compute SHAP values ----------
# Text masker built on the model tokenizer; falls back to the literal "[MASK]"
# when the tokenizer defines no mask token.
masker = shap.maskers.Text(
    tokenizer=tokenizer, mask_token=(tokenizer.mask_token or "[MASK]")
)
explainer = shap.Explainer(predict, masker, output_names=labels, algorithm="partition")
shap_values = explainer(
    samples
)  # len == N, each sv.values.shape == (num_tokens, num_classes)


# ---------- 3) HTML utilities that color only the tokens, LIME-style ----------
def _needs_space(tok):
    # WordPiece 접두 '##'는 앞 공백 X, 일반 토큰은 공백
    if tok.startswith("##"):
        return False
    # 붙여 쓰는 구두점
    if tok in {".", ",", "!", "?", ":", ";", ")", "]", "}", "'", '"'}:
        return False
    return True


def _clean_token(tok):
    return tok[2:] if tok.startswith("##") else tok


def tokens_html_for_class(sv, class_idx, topk=None):
    """Render the tokens of one explanation as colored HTML spans for one class.

    sv: shap.Explanation (sv.values.shape == [T, C])
    class_idx: index of the class to visualize
    topk: emphasize only the top-k tokens by absolute value (the rest get a
        light gray background); None emphasizes every token
    return: <div>...</div> (HTML fragment that contains only the tokens)

    Red marks a value >= 0 for the class, blue a negative one; the opacity
    grows with the absolute value relative to the largest one.
    """
    import html as _html

    tokens = list(sv.data)
    vals = np.array(sv.values[:, class_idx], dtype=float)

    # Boolean flag per token: True when the token gets a colored highlight.
    emphasized = np.ones_like(vals, dtype=bool)
    if topk is not None and topk < len(vals):
        emphasized[:] = False
        emphasized[np.argsort(-np.abs(vals))[:topk]] = True

    max_abs = float(np.max(np.abs(vals)) + 1e-12)  # epsilon avoids division by zero
    spans = []
    for raw_tok, score, strong in zip(tokens, vals, emphasized):
        # No separator in front of the very first span.
        prefix = " " if _needs_space(raw_tok) and len(spans) > 0 else ""
        tok = _html.escape(_clean_token(raw_tok))

        if not strong:
            color = "rgba(128,128,128,0.12)"  # not emphasized
        else:
            alpha = 0.15 + 0.85 * (abs(score) / max_abs)
            if score >= 0:
                color = "rgba(255,0,0,%f)" % alpha
            else:
                color = "rgba(0,102,255,%f)" % alpha

        spans.append(
            f"{prefix}<span style='background:{color}; padding:2px 2px; border-radius:3px'>{tok}</span>"
        )

    return "<div>" + "".join(spans) + "</div>"


# ---------- 4) Save all classes on a single page ----------
TOPK = 20  # strongly highlight only the top 20 (set to None to highlight all)
os.makedirs("multiclass_html", exist_ok=True)

for i, sv in enumerate(shap_values):
    probs = predict([samples[i]])[0]  # (C,)
    order = probs.argsort()[::-1]  # descending probability
    gold = str(df.iloc[i][gold_col])
    # Page head: styles, title with gold/predicted label, color legend and
    # the per-class probabilities.
    header = (
        "<meta charset='utf-8'>"
        "<style>body{font-family:-apple-system,Segoe UI,Roboto,Arial;line-height:1.8;font-size:16px}"
        ".cl{margin:12px 0 22px 0;padding:8px 12px;border:1px solid #eee;border-radius:8px}"
        ".lab{font-weight:600;margin-bottom:6px}"
        ".legend{margin:6px 0 12px 0;font-size:14px}"
        "</style>"
        f"<h2>[{i:02d}] gold: {gold} | pred: {labels[int(order[0])]} "
        f"(p={probs[int(order[0])]:.2f})</h2>"
        "<div class='legend'>"
        "<span style='display:inline-block;width:12px;height:12px;background:rgba(255,0,0,0.6);margin-right:6px;vertical-align:middle'></span>"
        "positive for class&nbsp;&nbsp;"
        "<span style='display:inline-block;width:12px;height:12px;background:rgba(0,102,255,0.6);margin:0 6px 0 16px;vertical-align:middle'></span>"
        "negative"
        "</div>"
        "<div style='margin:6px 0 12px 0'>"
        + " | ".join(f"{labels[c]}: {probs[c]:.2f}" for c in order)
        + "</div>"
    )

    # Add one section per class (sorted by descending probability)
    sections = []
    for c in order:
        sec = (
            f"<div class='cl'>"
            f"<div class='lab'>{labels[int(c)]} (p={probs[int(c)]:.2f})</div>"
            f"{tokens_html_for_class(sv, int(c), topk=TOPK)}"
            f"</div>"
        )
        sections.append(sec)

    page = header + "\n".join(sections)
    out = os.path.join("multiclass_html", f"sample_{i:02d}_multiclass.html")
    with open(out, "w", encoding="utf-8") as f:
        f.write(page)
    print(f"saved: {out}")

PartitionExplainer explainer: 31it [01:04,  2.37s/it]                        


saved: multiclass_html\sample_00_multiclass.html
saved: multiclass_html\sample_01_multiclass.html
saved: multiclass_html\sample_02_multiclass.html
saved: multiclass_html\sample_03_multiclass.html
saved: multiclass_html\sample_04_multiclass.html
saved: multiclass_html\sample_05_multiclass.html
saved: multiclass_html\sample_06_multiclass.html
saved: multiclass_html\sample_07_multiclass.html
saved: multiclass_html\sample_08_multiclass.html
saved: multiclass_html\sample_09_multiclass.html
saved: multiclass_html\sample_10_multiclass.html
saved: multiclass_html\sample_11_multiclass.html
saved: multiclass_html\sample_12_multiclass.html
saved: multiclass_html\sample_13_multiclass.html
saved: multiclass_html\sample_14_multiclass.html
saved: multiclass_html\sample_15_multiclass.html
saved: multiclass_html\sample_16_multiclass.html
saved: multiclass_html\sample_17_multiclass.html
saved: multiclass_html\sample_18_multiclass.html
saved: multiclass_html\sample_19_multiclass.html
saved: multiclass_ht

In [5]:
# ===== Save SHAP results in LIME style (only the tokens are color-highlighted) =====
import re
import numpy as np


def _needs_space(tok: str):
    # BERT WordPiece: '##'로 붙는 토큰은 앞 공백 X, 일반 토큰은 앞에 공백
    if tok.startswith("##"):
        return False
    # 구두점은 앞 공백 없이 붙이기
    if tok in {".", ",", "!", "?", ":", ";", ")", "]", "}", "'", '"'}:
        return False
    return True


def _clean_token(tok: str):
    # BERT WordPiece '##ing' -> 'ing'
    return tok[2:] if tok.startswith("##") else tok


def shap_as_lime_html(sv, class_idx, header="", topk=None):
    """
    Render one SHAP explanation as a standalone, LIME-style HTML page in which
    every token is wrapped in a colored <span>.

    sv: object exposing ``data`` (sequence of token strings) and ``values``
        (array indexed as ``values[:, class_idx]``), e.g. a shap.Explanation.
    class_idx: index of the class whose contributions are visualized.
    header: text placed in the <h2> heading. It is inserted verbatim (not
        escaped), so callers may pass markup but must escape untrusted text.
    topk: if given and smaller than the number of tokens, only the ``topk``
        tokens with the largest |value| get red/blue coloring and the rest get
        a faint gray background. ``None`` colors every token.

    Returns the HTML page as a string. An explanation with zero tokens yields
    a page with an empty token container instead of raising.
    """
    # Function-local import under an alias: the notebook binds a global
    # variable named `html`, which would shadow a module-level `import html`.
    from html import escape as _escape

    tokens = list(sv.data)
    vals = np.array(sv.values[:, class_idx], dtype=float)

    # Highlight only the top-k tokens by |value|; the others become light gray.
    if topk is not None and topk < len(vals):
        keep = np.argsort(-np.abs(vals))[:topk]
        mask = np.zeros_like(vals, dtype=bool)
        mask[keep] = True
    else:
        mask = np.ones_like(vals, dtype=bool)

    # np.max raises on a zero-length array, so guard the empty explanation.
    # The epsilon avoids division by zero when every contribution is 0.
    max_abs = (np.max(np.abs(vals)) if vals.size else 0.0) + 1e-12

    html_tokens = []
    for t, v, keep_it in zip(tokens, vals, mask):
        # No leading space before the very first token.
        space = " " if _needs_space(t) and len(html_tokens) > 0 else ""
        # Token text comes from raw issue text; escape it so characters such
        # as '<', '>' and '&' are displayed instead of being parsed as markup.
        tok = _escape(_clean_token(t))

        if keep_it:
            # Red = positive contribution to the class, blue = negative;
            # opacity grows with the relative magnitude of the contribution.
            alpha = 0.15 + 0.85 * (abs(v) / max_abs)
            color = f"rgba(255,0,0,{alpha})" if v >= 0 else f"rgba(0,102,255,{alpha})"
        else:
            color = "rgba(128,128,128,0.15)"  # light gray (de-emphasized)

        html_tokens.append(
            f"{space}<span style='background:{color}; padding:2px 2px; border-radius:3px'>{tok}</span>"
        )

    # Assemble the page: legend, heading, then the token stream.
    legend = (
        "<div style='margin:6px 0 12px 0;font-size:14px'>"
        "<span style='display:inline-block;width:12px;height:12px;background:rgba(255,0,0,0.6);margin-right:6px;vertical-align:middle'></span>"
        "positive for class &nbsp;&nbsp;"
        "<span style='display:inline-block;width:12px;height:12px;background:rgba(0,102,255,0.6);margin:0 6px 0 16px;vertical-align:middle'></span>"
        "negative"
        "</div>"
    )
    page = (
        "<meta charset='utf-8'>"
        "<style>body{font-family:-apple-system,Segoe UI,Roboto,Arial;line-height:1.8;font-size:16px}</style>"
        f"<h2 style='margin:6px 0 6px'>{header}</h2>"
        f"{legend}"
        f"<div>{''.join(html_tokens)}</div>"
    )
    return page


# Example: save every sample as a LIME-style page for its predicted class
# (only the 20 highest-|SHAP| tokens are highlighted).
# Relies on `shap_values`, `samples`, `df`, `gold_col`, `labels` and `predict`
# already being defined in the kernel.
for i, sv in enumerate(shap_values):
    probs = predict([samples[i]])[0]
    cls = int(np.argmax(probs))  # index of the predicted class
    predlab = labels[cls]
    conf = float(probs[cls])
    gold = str(df.iloc[i][gold_col])  # gold label taken from the evaluation frame

    html = shap_as_lime_html(
        sv,
        class_idx=cls,
        header=f"[{i:02d}] gold: {gold} | pred: {predlab} (conf {conf:.2f})",
        topk=20,  # ← change to None to highlight every token
    )
    # One HTML file per sample, written to the current working directory.
    with open(f"shap_lime_{i:02d}.html", "w", encoding="utf-8") as f:
        f.write(html)
    print(f"saved shap_lime_{i:02d}.html (class={predlab})")

saved shap_lime_00.html (class=bug)
saved shap_lime_01.html (class=bug)
saved shap_lime_02.html (class=bug)
saved shap_lime_03.html (class=bug)
saved shap_lime_04.html (class=bug)
saved shap_lime_05.html (class=bug)
saved shap_lime_06.html (class=bug)
saved shap_lime_07.html (class=bug)
saved shap_lime_08.html (class=question)
saved shap_lime_09.html (class=bug)
saved shap_lime_10.html (class=enhancement)
saved shap_lime_11.html (class=enhancement)
saved shap_lime_12.html (class=enhancement)
saved shap_lime_13.html (class=enhancement)
saved shap_lime_14.html (class=enhancement)
saved shap_lime_15.html (class=enhancement)
saved shap_lime_16.html (class=enhancement)
saved shap_lime_17.html (class=enhancement)
saved shap_lime_18.html (class=bug)
saved shap_lime_19.html (class=enhancement)
saved shap_lime_20.html (class=question)
saved shap_lime_21.html (class=bug)
saved shap_lime_22.html (class=bug)
saved shap_lime_23.html (class=enhancement)
saved shap_lime_24.html (class=bug)
saved shap

In [None]:
# ================== LIME ↑ / SHAP ↓ 패널 그리기 ==================
import numpy as np, matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.stats import spearmanr


# ---------- 0) 유틸: WordPiece -> 단어로 합치기 ----------
def merge_wordpiece(tokens, scores):
    """Collapse WordPiece sub-tokens into whole words, summing their scores.

    A token starting with ``##`` is appended (without the marker) to the word
    currently being built and its score is added to that word's total. Any
    other token closes the current word and starts a new one. A pending word
    that is the empty string is discarded rather than emitted.

    Returns ``(words, values)``: two lists of equal length.
    """
    merged_words, merged_scores = [], []
    pending_word, pending_score = "", 0.0

    for piece, score in zip(tokens, scores):
        if piece.startswith("##"):
            # Continuation piece: extend the word in progress.
            pending_word = pending_word + piece[2:]
            pending_score = pending_score + score
            continue

        # A new word begins; flush the previous one if it is non-empty.
        if pending_word:
            merged_words.append(pending_word)
            merged_scores.append(pending_score)
        pending_word, pending_score = piece, score

    # Flush whatever is still pending after the last token.
    if pending_word:
        merged_words.append(pending_word)
        merged_scores.append(pending_score)

    return merged_words, merged_scores


def topk_pairs_from_dict(d, k=10):
    """Return up to ``k`` ``(key, value)`` pairs of ``d``, largest |value| first.

    Ties keep the dictionary's iteration order (the sort is stable).
    """

    def _negated_magnitude(item):
        _, value = item
        return -abs(value)

    ranked = list(d.items())
    ranked.sort(key=_negated_magnitude)
    return ranked[:k]


# ---------- 1) LIME 특징 상위 K 추출 ----------
def lime_topk(text, lime_explainer, cls_idx, k=10, num_samples=600):
    """Explain ``text`` with LIME and return its ``k`` strongest features.

    ``lime_explainer.explain_instance`` is run against the notebook's
    ``predict`` function for class ``cls_idx`` only, requesting
    ``max(k, 50)`` features and ``num_samples`` perturbed samples. Weights of
    features that appear more than once in the explanation are summed before
    ranking by absolute weight.

    Returns ``(pairs, score)`` where ``pairs`` is ``[(feature, weight), ...]``
    and ``score`` is the explanation's ``score`` attribute (described in the
    original notebook as LIME's local-fidelity R^2).
    """
    explanation = lime_explainer.explain_instance(
        text,
        predict,
        labels=[cls_idx],
        num_features=max(k, 50),
        num_samples=num_samples,
    )

    # Sum weights per feature string, since a feature may be listed twice.
    summed_weights = {}
    for feature, weight in explanation.as_list(label=cls_idx):
        summed_weights[feature] = summed_weights.get(feature, 0.0) + weight

    strongest = topk_pairs_from_dict(summed_weights, k)
    return strongest, explanation.score


# ---------- 2) SHAP 특징 상위 K 추출 ----------
def shap_topk(sv, cls_idx, k=10):
    """Return the ``k`` words with the largest |SHAP value| for class ``cls_idx``.

    Sub-tokens in ``sv.data`` are first merged into words with
    ``merge_wordpiece`` (scores summed); scores of words that occur several
    times are then summed again before ranking.

    Returns ``[(word, score), ...]``.
    """
    words, scores = merge_wordpiece(list(sv.data), sv.values[:, cls_idx])

    per_word_total = {}
    for word, score in zip(words, scores):
        per_word_total[word] = per_word_total.get(word, 0.0) + score

    return topk_pairs_from_dict(per_word_total, k)


# ---------- 3) 간단 메트릭(근사치) ----------
def agreement_spearman_abs(dict_a, dict_b):
    """Spearman rank correlation of |score| over the keys both dicts share.

    Returns ``nan`` when fewer than three keys are shared; otherwise the
    correlation coefficient as a Python float.
    """
    shared = list(set(dict_a.keys()) & set(dict_b.keys()))
    if len(shared) >= 3:
        magnitudes_a = [abs(dict_a[key]) for key in shared]
        magnitudes_b = [abs(dict_b[key]) for key in shared]
        # First element of the result is the correlation; the p-value is ignored.
        return float(spearmanr(magnitudes_a, magnitudes_b)[0])
    return np.nan


def jaccard_at_k(dict_a, dict_b, k=10):
    """Jaccard overlap between the top-``k`` keys (by |value|) of two dicts.

    Returns ``nan`` when both top-k sets are empty, otherwise
    ``|A ∩ B| / |A ∪ B|``.
    """
    top_a = {key for key, _ in topk_pairs_from_dict(dict_a, k)}
    top_b = {key for key, _ in topk_pairs_from_dict(dict_b, k)}
    if not top_a and not top_b:
        return np.nan
    union = top_a | top_b
    overlap = top_a & top_b
    return len(overlap) / len(union)


def prescriptivity_like(text, ranked_words, k=10):
    """
    Placeholder for a "keep only the top-k words" prescriptivity-style ratio
    (p_keep / p_full, where values near 1 would mean the top-k words suffice).

    The ratio itself is NOT computed here: the original author notes that
    mapping words back to WordPiece tokens is too fiddly for a substring-based
    approximation and recommends a SHAP-based prescriptivity instead.

    ranked_words: [('word', score), ...] -- currently unused, as is ``k``.

    Returns ``(nan, p0)`` where ``p0`` is the model's probability for its
    predicted class on the full ``text``.
    """
    # Only the full-text prediction is evaluated; the metric slot stays NaN.
    full_probs = predict([text])[0]
    predicted = int(full_probs.argmax())
    return np.nan, float(full_probs[predicted])


# ---------- 4) 메트릭 패널을 위한 dict 형태로 수집 ----------
def build_metrics_for_lime(text, lime_dict, shap_dict, lime_score):
    """Collect the LIME-side metrics for the panel as ``{name: (mean, std)}``.

    - Reit. Similarity: not computed (no repeated LIME runs) -> (nan, nan).
    - Local Concordance: Spearman agreement of |score| between the LIME and
      SHAP dictionaries.
    - Local Fidelity: ``lime_score`` converted to float.
    - Prescriptivity: whatever ``prescriptivity_like`` returns as its first
      element (currently NaN).
    """
    top_n = min(10, len(lime_dict))

    concordance = agreement_spearman_abs(lime_dict, shap_dict)
    fidelity = float(lime_score)
    prescriptivity, _p_full = prescriptivity_like(
        text, topk_pairs_from_dict(lime_dict, top_n)
    )

    metrics = {}
    metrics["Reit. Similarity"] = (np.nan, np.nan)
    metrics["Local Concordance"] = (concordance, 0.0)
    metrics["Local Fidelity"] = (fidelity, 0.0)
    metrics["Prescriptivity"] = (prescriptivity, 0.0)
    return metrics


def build_metrics_for_shap(text, shap_dict):
    """Collect the SHAP-side metrics for the panel as ``{name: (mean, std)}``.

    All entries are placeholders or approximations:
    - Reit. Similarity: fixed at 1.0 (SHAP treated as deterministic).
    - Local Concordance: NaN (no internal-consistency measure is defined).
    - Local Fidelity: ``1 - |p0 - approx|`` clipped at 0, but ``approx`` is
      currently set to ``p0`` itself, so this compares the predicted-class
      probability with itself rather than with a SHAP reconstruction.
    - Prescriptivity: NaN (not implemented here).

    ``shap_dict`` is accepted for symmetry with the LIME builder but unused.
    """
    class_probs = predict([text])[0]
    top_class = int(class_probs.argmax())
    p_pred = float(class_probs[top_class])

    # Stand-in for "base value + sum of SHAP values": simply reuses p_pred.
    reconstructed = p_pred
    fidelity = max(0.0, 1.0 - abs(p_pred - reconstructed))

    return {
        "Reit. Similarity": (1.0, 0.0),
        "Local Concordance": (np.nan, 0.0),
        "Local Fidelity": (fidelity, 0.0),
        "Prescriptivity": (np.nan, 0.0),
    }


# ---------- 5) 그림 그리기 ----------
def _draw_prob_bar(ax, probs, cls_names):
    ax.barh(range(len(probs)), probs)
    ax.set_yticks(range(len(probs)))
    ax.set_yticklabels(cls_names)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.set_title("Prediction probabilities")


def _draw_feature_table(ax, pairs, title="Feature  Value"):
    # pairs: [('token', score), ...]
    ax.axis("off")
    ax.set_title(title, loc="left")
    rows = len(pairs)
    y = 1.0
    dy = 1.0 / (rows + 1e-9)
    for w, v in pairs:
        y -= dy
        ax.text(0.0, y, str(w), va="center", ha="left")
        ax.text(0.95, y, f"{v:.2f}", va="center", ha="right")


def _draw_metrics(ax, metrics_dict):
    """
    metrics_dict: {name: (mean, std)}
    """
    names = list(metrics_dict.keys())
    vals = [metrics_dict[n][0] for n in names]
    stds = [metrics_dict[n][1] for n in names]
    ax.scatter(vals, range(len(names)))
    # 오차막대(±std)
    for yi, (m, s) in enumerate(zip(vals, stds)):
        if np.isnan(m):
            continue
        ax.hlines(yi, max(0, m - s), min(1, m + s))
    ax.set_xlim(0, 1)
    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names)
    ax.set_xlabel("value")
    ax.invert_yaxis()
    ax.grid(True, axis="x", alpha=0.3)


def plot_lime_shap_panel(
    i=0, K=8, lime_num_samples=600, savepath=None, lime_explainer=None
):
    """Draw and save a combined LIME (top) / SHAP (bottom) panel for sample ``i``.

    Parameters
    ----------
    i : index into the notebook-level ``samples``, ``df`` and ``shap_values``.
    K : number of top features listed for each method.
    lime_num_samples : number of perturbed samples LIME draws.
    savepath : output PNG path; defaults to ``panel_{i:02d}_lime_shap.png``.
    lime_explainer : LIME explainer to use. When omitted, the global
        ``explainer`` is used only if it exposes ``explain_instance``;
        otherwise a fresh ``LimeTextExplainer(class_names=labels)`` is built.
        This matters because the SHAP cell of this notebook rebinds the global
        name ``explainer`` to a ``shap.Explainer``.

    Side effects: writes the PNG, shows the figure, closes it, and prints the
    saved path. Relies on the notebook globals ``samples``, ``df``,
    ``gold_col``, ``labels``, ``shap_values`` and ``predict``.
    """
    text = samples[i]
    probs = predict([text])[0]
    pred = int(np.argmax(probs))
    gold = str(df.iloc[i][gold_col])

    # Resolve the LIME explainer without trusting whatever `explainer`
    # currently points at in the kernel.
    if lime_explainer is None:
        candidate = globals().get("explainer")
        if hasattr(candidate, "explain_instance"):
            lime_explainer = candidate
        else:
            lime_explainer = LimeTextExplainer(class_names=labels)

    # LIME top-K for the predicted class
    lime_pairs, lime_score = lime_topk(
        text, lime_explainer, pred, k=K, num_samples=lime_num_samples
    )
    lime_dict = {w: v for w, v in lime_pairs}

    # SHAP top-K (reuses the already computed shap_values)
    sv = shap_values[i]
    shap_pairs = shap_topk(sv, pred, k=K)
    shap_dict = {w: v for w, v in shap_pairs}

    # Approximate metrics for both methods
    m_lime = build_metrics_for_lime(text, lime_dict, shap_dict, lime_score)
    m_shap = build_metrics_for_shap(text, shap_dict)

    # --- layout: 4 rows x 2 columns; row 2 is a thin spacer ---
    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(
        4, 2, height_ratios=[1, 2, 1, 2], width_ratios=[1.2, 1], hspace=0.5, wspace=0.35
    )

    # (A) top: prediction probabilities, then LIME feature table + metrics
    ax_prob = fig.add_subplot(gs[0, :])
    _draw_prob_bar(ax_prob, probs, labels)
    ax_prob.set_title(
        f"K={K}   gold: {gold} | pred: {labels[pred]} (p={probs[pred]:.2f})"
    )

    ax_lime_feat = fig.add_subplot(gs[1, 0])
    _draw_feature_table(ax_lime_feat, lime_pairs, title="LIME: Feature    Value")

    ax_lime_metrics = fig.add_subplot(gs[1, 1])
    _draw_metrics(ax_lime_metrics, m_lime)

    # (B) bottom: SHAP feature table + metrics
    ax_shap_feat = fig.add_subplot(gs[3, 0])
    _draw_feature_table(ax_shap_feat, shap_pairs, title="SHAP: Feature    Value")

    ax_shap_metrics = fig.add_subplot(gs[3, 1])
    _draw_metrics(ax_shap_metrics, m_shap)

    # (middle spacer row) keeps the LIME and SHAP halves visually apart
    ax_sep = fig.add_subplot(gs[2, :])
    ax_sep.axis("off")
    ax_sep.text(0.01, 0.5, " ", fontsize=6)

    if savepath is None:
        savepath = f"panel_{i:02d}_lime_shap.png"
    fig.savefig(savepath, dpi=200, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)
    print("saved:", savepath)


# ================== 사용법 ==================
# 예: i=28, K=12로 저장
# plot_lime_shap_panel(i=28, K=12)

In [None]:
import shap, numpy as np, pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
import torch

# --- 0) Model / tokenizer / labels (same setup as already used for LIME) ---
MODEL_PATH = "../models/nlbse/"
config = AutoConfig.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Use the GPU when one is available and switch the model to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Resolve human-readable class names: if the checkpoint's id2label only holds
# the placeholder names LABEL_0..LABEL_{n-1}, fall back to display_labels;
# otherwise use the names stored in the config.
display_labels = ["bug", "enhancement", "question"]
if getattr(model.config, "id2label", None):
    raw = [model.config.id2label[i] for i in range(model.config.num_labels)]
    labels = (
        display_labels if set(raw) == {f"LABEL_{i}" for i in range(len(raw))} else raw
    )
else:
    labels = display_labels
# The label list must have exactly one entry per model output.
assert len(labels) == model.config.num_labels


def predict(texts):
    """Return class probabilities for a batch of texts.

    ``texts`` is materialized with ``list()``, tokenized as one padded batch
    (truncated to 512 tokens), moved to ``device`` and run through ``model``
    without gradient tracking. The softmax over the logits is returned as a
    NumPy array with one row per input text.
    """
    encoded = tokenizer(
        list(texts), return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**on_device).logits
        probabilities = torch.softmax(logits, dim=-1)
        return probabilities.cpu().numpy()


# --- 1) Load the same samples (exactly as for LIME) ---
# The index is reset so positional lookups via df.iloc[i] line up with `samples`.
df = pd.read_csv("balanced_eval_set.csv").reset_index(drop=True)

# Column names of the evaluation CSV used throughout this notebook.
text_col_title = "issue_title"
text_col_body = "issue_body"
gold_col = "issue_label"


def build_text(row, title_col=text_col_title, body_col=text_col_body, max_chars=800):
    """Build the model input for one issue: title, blank line, truncated body.

    Parameters
    ----------
    row : mapping-like object with a ``.get`` method (e.g. a pandas Series
        or a dict) holding the issue fields.
    title_col, body_col : keys of the title and body fields.
    max_chars : maximum number of body characters kept.

    Returns
    -------
    str : ``title + "\\n\\n" + body[:max_chars]`` with surrounding whitespace
    stripped. Missing fields (absent key, ``None`` or NaN/NA) contribute an
    empty string instead of the literal text ``"nan"`` / ``"None"``.
    """

    def _as_text(value):
        # pandas reads empty CSV cells as NaN; str(NaN) would otherwise feed
        # the word "nan" to the model as if it were issue text.
        if value is None:
            return ""
        try:
            if pd.isna(value):
                return ""
        except (TypeError, ValueError):
            pass  # non-scalar value: fall through to plain str()
        return str(value)

    title = _as_text(row.get(title_col, ""))
    body = _as_text(row.get(body_col, ""))[:max_chars]
    return (title + "\n\n" + body).strip()


samples = [build_text(row) for _, row in df.iterrows()]  # ★ same inputs as used for LIME

# --- 2) SHAP masker / Explainer ---
masker = shap.maskers.Text(
    tokenizer=tokenizer, mask_token=(tokenizer.mask_token or "[MASK]")
)
# NOTE(review): this binds the global name `explainer` to a SHAP explainer;
# plot_lime_shap_panel also reads a global `explainer` and hands it to
# lime_topk, which calls .explain_instance on it — confirm the intended cell order.
explainer = shap.Explainer(predict, masker, output_names=labels, algorithm="partition")

# --- 3) Compute SHAP values ---
shap_values = explainer(samples)  # len == number of samples; each sv is (tokens × classes)

# --- 4) Save (for text plots, write the returned HTML directly instead of save_html) ---
for i, sv in enumerate(shap_values):
    txt = samples[i]
    probs = predict([txt])[0]
    cls = int(np.argmax(probs))
    predlab = labels[cls]
    conf = float(probs[cls])
    gold = str(df.iloc[i][gold_col])  # gold (true) label

    html_obj = shap.plots.text(sv[..., cls], display=False)  # explanation for the predicted class
    # Use the object's .data attribute when it has one, else its string form.
    body = getattr(html_obj, "data", str(html_obj))

    # Show gold / predicted label at the top of the HTML file
    with open(f"shap_{i:02d}.html", "w", encoding="utf-8") as f:
        f.write(
            f"<meta charset='utf-8'>\n"
            f"<h2>[{i:02d}] gold: {gold} | pred: {predlab} (conf {conf:.2f})</h2>\n"
            + body
        )

    print(
        f"[{i:02d}] gold={gold} | pred={predlab} (conf={conf:.2f}) → shap_{i:02d}.html"
    )

# --- 5) (optional) shape checks / diagnostics ---
print(f"#samples = {len(samples)}, #shap_values = {len(shap_values)}")
sv0 = shap_values[0]
print("sv0.values shape (토큰×클래스):", sv0.values.shape)  # (T, C)
print("sv0.base_values shape (클래스):", np.array(sv0.base_values).shape)  # (C,)
print("정답 라벨 분포:\n", df[gold_col].value_counts())


# --- 6) Faithfulness 지표 함수들 -------------------------------------------
import numpy as np
import pandas as pd
import torch

# ID of the BERT mask token; falls back to looking up "[MASK]" in the vocabulary
# when tokenizer.mask_token_id is falsy (None or 0).
mask_token_id = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")


def predict_with_manual_mask(text, mask_indices, keep=False, max_length=512):
    """
    Replace selected WordPiece tokens of ``text`` with [MASK] and return the
    model's class probabilities for the masked input.

    - keep=False: the tokens at ``mask_indices`` are masked (removal effect,
      used for comprehensiveness).
    - keep=True : every token NOT in ``mask_indices`` is masked (only the
      listed tokens remain, used for sufficiency).

    Indices are 0-based positions in ``tokenizer.tokenize(text)`` truncated to
    ``max_length - 2``; [CLS] and [SEP] are added here and never masked.
    NOTE(review): callers pass indices derived from SHAP explanations — this
    assumes SHAP's token positions coincide with these WordPiece positions
    (no special-token offset); confirm against the masker's output.

    Returns a 1-D NumPy array of class probabilities.
    """
    pieces = tokenizer.tokenize(text)[: max_length - 2]  # leave room for [CLS]/[SEP]
    input_ids = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(pieces)
        + [tokenizer.sep_token_id]
    )

    if keep:
        retained = set(mask_indices)
        for offset in range(len(pieces)):
            if offset not in retained:
                input_ids[offset + 1] = mask_token_id
    else:
        last_slot = len(input_ids) - 1
        for offset in mask_indices:
            slot = offset + 1  # shift by one for the leading [CLS]
            if 0 < slot < last_slot:  # never overwrite [CLS]/[SEP]
                input_ids[slot] = mask_token_id

    batch = torch.tensor([input_ids]).to(device)
    attention = torch.ones_like(batch).to(device)
    with torch.no_grad():
        logits = model(input_ids=batch, attention_mask=attention).logits
        return torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()


def topk_token_indices_for_class(sv, class_idx: int, k: int = 10, use_abs: bool = True):
    """
    Return the indices of the ``k`` highest-ranked tokens of ``sv`` for
    ``class_idx`` as a plain list.

    ``sv.values`` is indexed as ``values[:, class_idx]``. Tokens are ranked by
    descending |value| when ``use_abs`` is true, otherwise by descending
    signed value. Fewer than ``k`` indices are returned when there are fewer
    tokens.
    """
    contributions = sv.values[:, class_idx]
    ranking_key = np.abs(contributions) if use_abs else contributions
    ranked = np.argsort(-ranking_key)
    limit = min(k, len(ranked))
    return ranked[:limit].tolist()


def faithfulness_metrics(
    text: str, sv, k: int = 10, use_abs: bool = True, class_idx: int | None = None
):
    """
    Compute comprehensiveness and sufficiency for a single sample.

    The target class is ``class_idx`` when given, otherwise the model's
    predicted class for ``text``. The top-``k`` SHAP tokens for that class are
    then (a) masked out and (b) kept alone, and the target-class probability
    is re-evaluated each time.

    Returns a dict with:
      pred_class        - label name of the target class
      p_full            - target-class probability on the unmodified text
      p_drop            - probability with the top-k tokens masked
      p_keep            - probability with only the top-k tokens kept
      comprehensiveness - p_full - p_drop
      sufficiency       - p_full - p_keep
    """
    full_probs = predict([text])[0]
    if class_idx is None:
        target = int(np.argmax(full_probs))
    else:
        target = int(class_idx)
    p_full = float(full_probs[target])

    top_tokens = topk_token_indices_for_class(sv, target, k=k, use_abs=use_abs)

    # Comprehensiveness: how much probability is lost when the top-k are removed.
    p_drop = float(predict_with_manual_mask(text, top_tokens, keep=False)[target])
    # Sufficiency: how much probability is lost when only the top-k remain.
    p_keep = float(predict_with_manual_mask(text, top_tokens, keep=True)[target])

    return {
        "pred_class": labels[target],
        "p_full": p_full,
        "p_drop": p_drop,
        "p_keep": p_keep,
        "comprehensiveness": p_full - p_drop,
        "sufficiency": p_full - p_keep,
    }


def evaluate_faithfulness(shap_values, samples, k: int = 10, use_abs: bool = True):
    """
    Run ``faithfulness_metrics`` over every (explanation, text) pair and
    return the results as a DataFrame, one row per sample.

    Besides the per-sample metrics each row carries: ``i`` (sample index),
    ``gold`` (label read from the notebook-level ``df``), ``k``,
    ``comp_norm`` (comprehensiveness divided by p_full) and ``suff_abs``
    (absolute value of sufficiency). Iteration stops at the shorter of
    ``shap_values`` and ``samples``.
    """
    records = []
    for idx, (explanation, text) in enumerate(zip(shap_values, samples)):
        metrics = faithfulness_metrics(text, explanation, k=k, use_abs=use_abs)
        records.append(
            {
                **metrics,
                "i": idx,
                "gold": str(df.iloc[idx][gold_col]),
                "k": k,
                # Small epsilon guards against division by zero.
                "comp_norm": metrics["comprehensiveness"] / (metrics["p_full"] + 1e-8),
                "suff_abs": abs(metrics["sufficiency"]),
            }
        )
    return pd.DataFrame(records)

PartitionExplainer explainer:  23%|██▎       | 7/30 [00:11<00:04,  4.98it/s]

In [None]:
import numpy as np
import pandas as pd
import torch

# --- Helper functions, defined again (these shadow the earlier definitions) ---
# Mask-token ID; falls back to the vocabulary lookup of "[MASK]" when
# tokenizer.mask_token_id is falsy.
mask_token_id = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")


def predict_with_manual_mask(text, mask_indices, keep=False, max_length=512):
    """Mask WordPiece tokens of ``text`` and return the class probabilities.

    With ``keep=False`` the tokens at ``mask_indices`` are replaced by the
    mask token; with ``keep=True`` every token except those is replaced.
    Indices are 0-based positions in ``tokenizer.tokenize(text)`` (truncated
    to ``max_length - 2``); the [CLS]/[SEP] tokens added here are never masked.
    NOTE(review): assumes the caller's indices line up with these WordPiece
    positions — confirm when they come from SHAP explanations.

    Returns a 1-D NumPy array of probabilities.
    """
    word_pieces = tokenizer.tokenize(text)[: max_length - 2]
    sequence = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(word_pieces)
        + [tokenizer.sep_token_id]
    )

    if keep:
        survivors = set(mask_indices)
        for token_pos in range(len(word_pieces)):
            if token_pos not in survivors:
                sequence[token_pos + 1] = mask_token_id
    else:
        upper_bound = len(sequence) - 1
        for token_pos in mask_indices:
            target_slot = token_pos + 1  # account for the leading [CLS]
            if 0 < target_slot < upper_bound:
                sequence[target_slot] = mask_token_id

    ids_tensor = torch.tensor([sequence]).to(device)
    mask_tensor = torch.ones_like(ids_tensor).to(device)

    with torch.no_grad():
        logits = model(input_ids=ids_tensor, attention_mask=mask_tensor).logits
        return torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()


def topk_token_indices_for_class(sv, class_idx: int, k: int = 10, use_abs=True):
    """List the indices of the ``k`` top-ranked tokens of ``sv`` for ``class_idx``.

    Ranking is by descending |value| when ``use_abs`` is true, else by
    descending signed value, over ``sv.values[:, class_idx]``.
    """
    column = sv.values[:, class_idx]
    if use_abs:
        descending = np.argsort(-np.abs(column))
    else:
        descending = np.argsort(-column)
    cutoff = min(k, len(descending))
    return descending[:cutoff].tolist()


def faithfulness_metrics(
    text: str, sv, k: int = 10, use_abs=True, class_idx: int | None = None
):
    """Compute comprehensiveness and sufficiency for one sample.

    This re-definition shadows the earlier ``faithfulness_metrics`` in the
    notebook, so it accepts the same optional ``class_idx`` argument: when
    given, metrics are computed for that class; when ``None`` (default) the
    model's predicted class for ``text`` is used.

    Parameters
    ----------
    text : input text fed to ``predict``.
    sv : SHAP explanation for ``text`` (``sv.values[:, class]`` is read).
    k : number of top-ranked tokens to mask / keep.
    use_abs : rank tokens by |SHAP value| instead of the signed value.
    class_idx : optional class index overriding the predicted class.

    Returns
    -------
    dict with keys ``pred_class``, ``p_full``, ``p_drop``, ``p_keep``,
    ``comprehensiveness`` (p_full - p_drop) and ``sufficiency``
    (p_full - p_keep).
    """
    base_probs = predict([text])[0]
    yhat = int(base_probs.argmax()) if class_idx is None else int(class_idx)
    p_full = float(base_probs[yhat])

    topk_idx = topk_token_indices_for_class(sv, yhat, k=k, use_abs=use_abs)

    # Comprehensiveness: mask the top-k tokens and measure the probability drop.
    probs_drop = predict_with_manual_mask(text, topk_idx, keep=False)
    p_drop = float(probs_drop[yhat])
    compreh = p_full - p_drop

    # Sufficiency: keep only the top-k tokens and measure the probability drop.
    probs_keep = predict_with_manual_mask(text, topk_idx, keep=True)
    p_keep = float(probs_keep[yhat])
    suff = p_full - p_keep

    return {
        "pred_class": labels[yhat],
        "p_full": p_full,
        "p_drop": p_drop,
        "p_keep": p_keep,
        "comprehensiveness": compreh,
        "sufficiency": suff,
    }


def evaluate_faithfulness(shap_values, samples, k=10, use_abs=True):
    """Compute faithfulness metrics for every explanation; return a DataFrame.

    For the i-th explanation the text is taken from ``samples[i]`` and the
    gold label from the notebook-level ``df``. Each row holds the output of
    ``faithfulness_metrics`` plus ``i``, ``gold``, ``k`` and two convenience
    columns: ``comp_norm`` (comprehensiveness / p_full) and ``suff_abs``
    (|sufficiency|).
    """
    table_rows = []
    for position, explanation in enumerate(shap_values):
        metrics = faithfulness_metrics(
            samples[position], explanation, k=k, use_abs=use_abs
        )
        table_rows.append(
            {
                **metrics,
                "i": position,
                "gold": str(df.iloc[position][gold_col]),
                "k": k,
                # Convenience metrics; the epsilon avoids division by zero.
                "comp_norm": metrics["comprehensiveness"] / (metrics["p_full"] + 1e-8),
                "suff_abs": abs(metrics["sufficiency"]),
            }
        )
    return pd.DataFrame(table_rows)


# --- Run: aggregate over several values of k ---
# One DataFrame per k, concatenated into a single long table `res`.
results = []
for k in [5, 10, 20]:
    results.append(evaluate_faithfulness(shap_values, samples, k=k, use_abs=True))
res = pd.concat(results, ignore_index=True)

# Overall mean of each metric per k.
print("=== 전체 평균 ===")
print(
    res.groupby("k")[["comprehensiveness", "comp_norm", "sufficiency", "suff_abs"]]
    .mean()
    .round(3)
)

# Mean of each metric per (k, gold label).
print("\n=== GOLD 라벨별 평균 ===")
print(
    res.groupby(["k", "gold"])[
        ["comprehensiveness", "comp_norm", "sufficiency", "suff_abs"]
    ]
    .mean()
    .round(3)
)

# (optional) save to a file
# res.to_csv("faithfulness_results.csv", index=False)