In [1]:
import pandas as pd
import ast
import os
from transformers import BertTokenizerFast, BertForSequenceClassification
import numpy as np
import shap
import torch

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Instantiate the bert-base-cased architecture with a 2-way classification head, then
# overwrite its weights with the fine-tuned checkpoint. That is why the "newly initialized
# ['classifier.bias', 'classifier.weight']" warning from from_pretrained is harmless here:
# load_state_dict (strict by default) replaces those tensors.
# NOTE(review): torch.load unpickles the checkpoint — only point it at files you trust.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
model.load_state_dict(
    torch.load("../outputs/model/bert_toxic_classifier.pt", map_location=device)
)
# nn.Module.to and .eval both return the module itself, so `model` stays the same object.
model = model.to(device).eval()

# Wrapper so SHAP can call the model with raw texts
class WrappedModel:
    """Callable adapter exposing a sequence classifier as ``texts -> P(label)``.

    SHAP's Text masker calls the prediction function with batches of (masked)
    strings. This adapter tokenizes them, runs the model without gradients and
    returns the softmax probability of a single label for each input.

    Args:
        model: classifier whose output object exposes ``.logits`` of shape
            (batch, num_labels); it is expected to already live on ``device``.
        tokenizer: tokenizer callable compatible with ``model``.
        device: torch device the encoded batch is moved to.
        max_length: truncation length in tokens (default 128).
        label_index: logit column whose probability is returned (default 1,
            presumably the toxic class — confirm against the training labels).
    """

    def __init__(self, model, tokenizer, device, max_length=128, label_index=1):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.label_index = label_index

    def __call__(self, texts):
        """Return ``P(label_index)`` for each input text.

        Args:
            texts: a single string, a pandas Series / numpy array of strings, or
                any other iterable of strings.

        Returns:
            1-D numpy array with one probability per input (empty for empty input).
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, (pd.Series, np.ndarray)):
            texts = texts.tolist()
        else:
            # Normalize tuples/generators to a list the tokenizer accepts.
            texts = list(texts)
        if not texts:
            # Tokenizing an empty batch fails; nothing to score anyway.
            return np.empty(0, dtype=np.float32)

        enc = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        with torch.no_grad():
            logits = self.model(**enc).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
        return probs[:, self.label_index].cpu().numpy()

# SHAP explainer: the Text masker hides tokens using the same tokenizer the model uses,
# and the wrapped model maps each batch of (masked) strings to the class-1 probability.
text_masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(WrappedModel(model, tokenizer, device), text_masker)

🟢 Using device: cuda


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Load the Toxic Spans data and clean it so row positions line up with the `index`
# column of the SHAP results (presumably the same cleaning the SHAP run applied — verify).
# NOTE(review): astype(str) turns missing posts into the literal "nan", which the
# empty-string filter does not drop; kept as-is because changing it would shift indices.
df_orig = (
    pd.read_csv("../data/toxic_spans.csv")
    .assign(
        # gold toxic character offsets are stored as Python-literal strings
        position=lambda d: d["position"].apply(ast.literal_eval),
        text_of_post=lambda d: d["text_of_post"].astype(str).str.strip(),
    )
    .loc[lambda d: d["text_of_post"] != ""]
    .reset_index(drop=True)
)

# Per-sample SHAP metrics; look up each sample's gold offsets by its `index` value.
df_results = pd.read_csv("../outputs/shap/evaluation_metrics.csv")
df_results["true_spans"] = df_results["index"].apply(
    lambda sample_idx: df_orig["position"].loc[sample_idx]
)

In [4]:
# Color legend: gold ∩ pred = gold, gold only = lightgreen, pred only = lightcoral
def mark_spans_combined(text, true_spans, pred_spans):
    """Render ``text`` as HTML, highlighting gold vs. predicted toxic characters.

    Each character is wrapped in a colored ``<span>`` when it falls in a span:
    gold when it is in both the gold and predicted sets, lightgreen when it is
    gold only, and lightcoral when it is predicted only. Other characters are
    emitted without a wrapper.

    Args:
        text: post text to render.
        true_spans: iterable of gold character offsets into ``text``.
        pred_spans: iterable of predicted character offsets into ``text``.

    Returns:
        HTML fragment string. Every character is HTML-escaped, so markup-like
        content in a post (``<``, ``&``, quotes) cannot break the page.
    """
    # Local import: the notebook reuses the module-level name `html` for a string.
    from html import escape

    gold = set(true_spans)  # O(1) membership regardless of the caller's container
    pred = set(pred_spans)
    parts = []
    for i, c in enumerate(text):
        ch = escape(c)
        in_gold = i in gold
        in_pred = i in pred
        if in_gold and in_pred:
            color = "gold"
        elif in_gold:
            color = "lightgreen"
        elif in_pred:
            color = "lightcoral"
        else:
            parts.append(ch)
            continue
        parts.append(f'<span style="background-color:{color};">{ch}</span>')
    # join avoids quadratic repeated string concatenation on long posts
    return "".join(parts)


In [6]:
# Take the top-N examples by SHAP F1 and write combined gold/pred HTML overlays.
top_n = 10
# Per-token SHAP value at or above which a token counts as "toxic-driving".
# NOTE(review): confirm this matches the threshold used when evaluation_metrics.csv was
# produced; otherwise the highlighted spans and the reported F1 describe different predictions.
threshold = 0.5
top_samples_html_path = "../outputs/figures/shap_top_samples.html"

samples = df_results.sort_values("f1", ascending=False).head(top_n)

os.makedirs("../outputs/figures", exist_ok=True)
with open(top_samples_html_path, "w", encoding="utf-8") as f:
    # Declare the encoding so browsers don't guess a legacy codepage for non-ASCII posts.
    f.write('<!DOCTYPE html>\n<meta charset="utf-8">\n<title>SHAP top samples</title>\n')
    for _, row in samples.iterrows():
        text = row["text"]
        true_spans = set(ast.literal_eval(str(row["true_spans"])))
        # 'scores' was saved as a printed array such as "[0.1 -0.02 ...]". Accept space- or
        # comma-separated values and fail loudly on unparsable content (np.fromstring would
        # silently stop early and misalign the scores).
        scores = np.asarray(
            [float(tok) for tok in row["scores"].strip("[]").replace(",", " ").split()]
        )

        # Project token-level scores back to character offsets. Special tokens have (0, 0)
        # offsets and are skipped. Index j assumes the scores are aligned 1:1 with this
        # tokenizer's output including [CLS]/[SEP] — TODO confirm against the SHAP run.
        # max_length=128 matches the truncation length WrappedModel uses.
        offsets = tokenizer(
            text, return_offsets_mapping=True, truncation=True, max_length=128
        )["offset_mapping"]

        pred_spans = set()
        for j, (start, end) in enumerate(offsets):
            if start == end or j >= len(scores):
                continue
            if scores[j] >= threshold:
                pred_spans.update(range(start, end))

        html = mark_spans_combined(text, true_spans, pred_spans)
        f.write(f"""
        <h4>Sample {int(row['index'])}</h4>
        <p><strong>F1:</strong> {row['f1']:.2f}, <strong>Precision:</strong> {row['precision']:.2f}, <strong>Recall:</strong> {row['recall']:.2f}</p>
        <pre style='font-size: 15px; line-height: 1.5; white-space: pre-wrap'>{html}</pre>
        <hr>
        """)

print(f"Saved HTML visualizations: {top_samples_html_path}")


✅ Saved HTML visualizations: ../outputs/figures/shap_top_samples.html


In [7]:
# Save SHAP's native text plots for the first few top samples. Files and messages are keyed
# by the dataset `index` column, the same id shown as "Sample N" in shap_top_samples.html,
# rather than by the df_results row label, so the two outputs can be cross-referenced.
n_native_plots = 2
for _, row in samples.head(n_native_plots).iterrows():
    sample_id = int(row["index"])
    # The explainer takes a batch; index 0 is this sample's explanation.
    shap_values = explainer([row["text"]])
    try:
        # With display=False, shap.plots.text returns the HTML instead of rendering it.
        shap_html = shap.plots.text(shap_values[0], display=False)
        with open(f"../outputs/figures/shap_explanation_{sample_id}.html", "w", encoding="utf-8") as f:
            f.write(shap_html)
        print(f"Saved SHAP explanation HTML for sample {sample_id}")
    except Exception as e:
        # Best-effort export: report the failure and continue with the next sample.
        print(f"Failed to save SHAP plot for {sample_id}: {e}")


✅ Saved SHAP explanation HTML for sample 2502
✅ Saved SHAP explanation HTML for sample 6939


In [9]:
# Persist every column of the selected top-N rows (including the attached `true_spans`)
# next to the HTML figures, without the DataFrame row labels.
top_samples_csv_path = "../outputs/figures/shap_top_samples.csv"
samples.to_csv(top_samples_csv_path, index=False)
print(f"Saved top sample data: {top_samples_csv_path}")

✅ Saved top sample data: ../outputs/figures/shap_top_samples.csv
