In [None]:
# Pin numpy/pandas first, then install the attribution stack.
# NOTE: pip warns that the preinstalled shap needs numpy>=2; this notebook never
# imports shap, so the conflict presumably does not matter here.
# NOTE(review): changing numpy inside an already-running Colab kernel usually needs a
# runtime restart before the imports below pick it up — confirm on your runtime.
# TODO: "transformers" and "matplotlib" are unpinned; pin them for reproducible reruns.
!pip -q install -U "numpy==1.26.4" "pandas==2.2.2"
!pip -q install -U "captum==0.8.0" "transformers" "matplotlib"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.7/12.7 MB[0m [31m56.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients

In [None]:
# ---------- 1) CONFIG (EDIT THESE) ----------
# Use the GPU when the runtime provides one; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Model / checkpoint ---
# (A) Hugging Face model name used during training (recommended)
MODEL_NAME = "roberta-base"           # TODO: change (e.g., "climatebert/distilroberta-base-climate-f")
NUM_LABELS = 3                        # TODO: change; should match the classes the .pth was trained with
PTH_PATH = "/content/best_model.pth"  # TODO: upload .pth to Colab or mount Drive
# Set USE_LOCAL = True to load tokenizer/architecture from LOCAL_DIR instead of MODEL_NAME.
USE_LOCAL = False
LOCAL_DIR = "/content/my_model_dir"

# --- Integrated Gradients settings ---
MAX_LENGTH = 256  # token length for analysis (longer inputs are truncated)
N_STEPS = 50      # IG steps (higher = smoother but slower)

# --- Output ---
OUT_DIR = "/content/ig_report"  # all report artifacts are written here

# --- Two examples (one correct and one wrong) ---
RIGHT_TEXT = (
    "Verizon's environmental, health and safety management system provides a framework "
    "for identifying, controlling, and reducing the risks associated with the environments "
    "in which we operate. Besides regular management system assessments, internal and "
    "third-party compliance audits and inspections are performed annually at hundreds of "
    "facilities worldwide. The goal of these assessments is to identify and correct "
    "site-specific issues, and to educate and empower facility managers and supervisors "
    "to implement corrective actions. Verizon's environment, health and safety efforts "
    "are directed and supported by experienced experts around the world that support our "
    "operations and facilities."
)
RIGHT_TRUE_LABEL = 1  # int label

WRONG_TEXT = (
    "Sustainable strategy 'red lines' For our sustainable strategy range, we incorporate "
    "a series of proprietary 'red lines' in order to ensure the poorest- performing "
    "companies from an ESG perspective are not eligible for investment."
)
WRONG_TRUE_LABEL = 2  # int label

In [None]:
# ---------- 2) Load tokenizer + model ----------
# Tokenizer and architecture come from the Hub id (or LOCAL_DIR); the fine-tuned
# weights are then overlaid from the .pth checkpoint below.
model_source = LOCAL_DIR if USE_LOCAL else MODEL_NAME

tokenizer = AutoTokenizer.from_pretrained(model_source)
model = AutoModelForSequenceClassification.from_pretrained(
    model_source,
    num_labels=NUM_LABELS
)

# Load .pth (state_dict).
# weights_only=True refuses to unpickle arbitrary Python objects, so a tampered
# checkpoint cannot run code on load. A plain tensor state_dict (optionally inside a
# dict of primitives) loads fine. Only switch to weights_only=False for files you trust.
state = torch.load(PTH_PATH, map_location="cpu", weights_only=True)

# Common checkpoint wrappers: {"model_state_dict": ...} or {"state_dict": ...}.
# Unwrap the first one present.
if isinstance(state, dict):
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in state and isinstance(state[wrapper_key], dict):
            state = state[wrapper_key]
            break

# DataParallel/DDP checkpoints prefix every key with "module.". Strip only that
# leading prefix; str.replace would also mangle any key containing "module." mid-name.
_DP_PREFIX = "module."
if isinstance(state, dict):
    state = {
        (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
        for k, v in state.items()
    }

# strict=False tolerates minor buffer differences. It also hides a checkpoint that
# did not match at all, which would leave the classifier head randomly initialized.
# Print the offending key names, not just their counts.
missing, unexpected = model.load_state_dict(state, strict=False)

model.to(DEVICE).eval()

print("Loaded on:", DEVICE)
print("Missing keys:", len(missing), "| Unexpected keys:", len(unexpected))
if missing:
    print("  first missing keys:", list(missing)[:10])
if unexpected:
    print("  first unexpected keys:", list(unexpected)[:10])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded on: cpu
Missing keys: 0 | Unexpected keys: 0


In [None]:
# ---------- 3) Prediction + IG core ----------
def predict_logits(input_ids, attention_mask):
    """Run the notebook's global ``model`` and return its raw class scores.

    Args:
        input_ids: token id tensor (e.g. from ``encode``).
        attention_mask: attention mask tensor matching ``input_ids``.

    Returns:
        The model output's ``logits`` tensor, [B, C].
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    return logits

In [None]:
def encode(text, max_length=MAX_LENGTH):
    """Tokenize ``text`` to exactly ``max_length`` tokens and move the tensors to ``DEVICE``.

    Longer inputs are truncated and shorter ones are padded up to ``max_length``.

    Returns:
        (input_ids, attention_mask) as PyTorch tensors. For a single string each
        has shape [1, max_length].
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)
    return ids, mask

In [None]:
def get_pred(text, max_length=MAX_LENGTH):
    """Classify ``text`` without tracking gradients.

    Returns:
        (pred, probs): the arg-max class index as a Python ``int``, and the softmax
        probabilities of the first (only) sequence as a NumPy array.
    """
    ids, mask = encode(text, max_length=max_length)
    with torch.no_grad():
        scores = predict_logits(ids, mask)
        class_probs = F.softmax(scores, dim=-1).squeeze(0).cpu().numpy()
    best = int(np.argmax(class_probs))
    return best, class_probs

In [None]:
def get_word_embeddings_layer(m):
    """Return the token-embedding module of a Hugging Face sequence classifier.

    The notebook hands this module to LayerIntegratedGradients, so attributions are
    computed at the embedding layer while the model is still fed integer token ids.

    Known backbones are resolved explicitly. Any other model falls back to the
    generic ``get_input_embeddings()`` accessor when it provides one.

    Raises:
        ValueError: if no embedding layer can be located.
    """
    # Add cases as needed for your backbone.
    for attr in ("bert", "roberta", "distilbert", "albert", "deberta"):
        backbone = getattr(m, attr, None)
        if backbone is not None:
            return backbone.embeddings.word_embeddings

    # Generic fallback for other architectures (electra, xlnet, ...).
    getter = getattr(m, "get_input_embeddings", None)
    if callable(getter):
        try:
            emb = getter()
        except NotImplementedError:
            emb = None
        if emb is not None:
            return emb

    raise ValueError("Backbone not found. Add a case in get_word_embeddings_layer().")

emb_layer = get_word_embeddings_layer(model)


def _lig_forward(input_ids, attention_mask):
    # Resolve predict_logits at call time, as the original lambda did, so a
    # re-run of the cell that defines it is picked up without rebuilding `lig`.
    return predict_logits(input_ids, attention_mask)


# Attribute with respect to the word-embedding layer; the attention mask reaches
# the forward function as an additional forward argument.
lig = LayerIntegratedGradients(forward_func=_lig_forward, layer=emb_layer)


In [None]:
def integrated_gradients_for_text(text, target_label=None, max_length=MAX_LENGTH, n_steps=N_STEPS,
                                  internal_batch_size=None):
    """Compute token-level Integrated Gradients for one text.

    Args:
        text: input string.
        target_label: class index to explain; ``None`` explains the predicted class.
        max_length: tokenization length (see ``encode``).
        n_steps: number of IG interpolation steps.
        internal_batch_size: forwarded to Captum to evaluate the ``n_steps`` expanded
            inputs in chunks, which lowers peak memory. ``None`` (default) evaluates
            them all in one batch, as before.

    Returns:
        (tokens, scores, delta, target):
        tokens  - token strings for every position, including padding;
        scores  - NumPy array of per-token attributions divided by max |attribution|
                  (so in [-1, 1]), with padded positions set to 0;
        delta   - Captum convergence delta as a float;
        target  - the class index that was explained, as an int.
    """
    input_ids, attn = encode(text, max_length=max_length)

    # Default: explain the predicted class.
    if target_label is None:
        with torch.no_grad():
            target_label = int(torch.argmax(predict_logits(input_ids, attn), dim=-1).item())

    # Baseline: every position replaced by the pad token. full_like keeps the
    # device of input_ids, which encode() already placed on DEVICE.
    # Fall back to id 0 for tokenizers without a pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    baseline_ids = torch.full_like(input_ids, fill_value=pad_id)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attn,),
        target=target_label,
        n_steps=n_steps,
        internal_batch_size=internal_batch_size,
        return_convergence_delta=True,
    )

    # [1, seq_len, emb_dim] -> [seq_len]: sum over the embedding dim, then zero the pads.
    token_attr = attributions.sum(dim=-1).squeeze(0) * attn.squeeze(0)

    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).detach().cpu().tolist())
    scores = token_attr.detach().cpu().numpy()

    # Normalize to [-1, 1] for consistent visuals.
    # NOTE(review): the max also covers special tokens (<s>, </s>, [CLS], ...), which
    # the report plots skip. If those dominate, content-token scores get compressed.
    peak = float(np.max(np.abs(scores))) if scores.size else 0.0
    scores = scores / (peak if peak > 0 else 1.0)

    return tokens, scores, float(delta.detach().cpu().item()), int(target_label)

In [None]:
# ---------- 4) Report-ready visual exports ----------
# Marker tokens hidden from the figures and tables:
# BERT-style ([CLS]/[SEP]/[PAD]) and RoBERTa-style (<s>/</s>/<pad>).
SPECIAL = {"[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>"}

In [None]:
def _rgba_pos(a):  # positive -> red-ish
    a = float(np.clip(a, 0, 1))
    return (1.0, 0.55, 0.45, 0.12 + 0.65*a)

def _rgba_neg(a):  # negative -> blue-ish
    a = float(np.clip(a, 0, 1))
    return (0.35, 0.55, 1.0, 0.12 + 0.65*a)

In [None]:
def tokens_to_report_figure(tokens, scores, title, out_png, out_pdf=None,
                            max_tokens=180, per_line=14, skip_special=True):
    """Render tokens as a grid of coloured chips: red for positive, blue for negative attribution.

    Chip opacity is |score| relative to the largest |score| among the rendered tokens.

    Args:
        tokens, scores: parallel sequences (e.g. from integrated_gradients_for_text).
        title: figure title.
        out_png: PNG output path, saved at 300 dpi.
        out_pdf: optional PDF output path.
        max_tokens: render at most this many tokens, counted after filtering.
        per_line: chips per row.
        skip_special: drop tokens listed in SPECIAL.
    """
    pairs = [
        (t, float(s)) for t, s in zip(tokens, scores)
        if not (skip_special and t in SPECIAL)
    ][:max_tokens]

    vals = np.array([abs(s) for _, s in pairs], dtype=float)
    # .max() raises on an empty array, so guard the "no visible tokens" case.
    denom = vals.max() if vals.size and vals.max() > 0 else 1.0

    n_lines = int(np.ceil(len(pairs) / per_line)) if pairs else 1
    fig_h = 1.6 + 0.55 * n_lines

    fig, ax = plt.subplots(figsize=(12, fig_h))
    try:
        ax.axis("off")
        ax.set_title(title, fontsize=14, pad=14)

        x0, y0 = 0.02, 0.92
        dx, dy = 0.95 / per_line, 0.85 / max(n_lines, 1)

        for i, (tok, s) in enumerate(pairs):
            line, col = divmod(i, per_line)

            # Strip subword markers: WordPiece "##", byte-level BPE "\u0120" ("Ġ",
            # RoBERTa/GPT-2) and SentencePiece "\u2581". Keep the raw token if
            # nothing is left after stripping.
            disp = tok.replace("##", "").replace("\u0120", "").replace("\u2581", "") or tok
            a = abs(s) / denom
            face = _rgba_pos(a) if s >= 0 else _rgba_neg(a)

            ax.text(
                x0 + col * dx, y0 - line * dy, disp,
                transform=ax.transAxes,
                fontsize=11,
                va="top", ha="left",
                bbox=dict(boxstyle="round,pad=0.25", facecolor=face, edgecolor=(0, 0, 0, 0.08)),
            )

        fig.tight_layout()
        fig.savefig(out_png, dpi=300, bbox_inches="tight")
        if out_pdf:
            fig.savefig(out_pdf, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)

In [None]:
def topk_barplot(tokens, scores, title, out_png, out_pdf=None, k=20, skip_special=True,
                 min_abs=0.02):
    """Horizontal bar chart of the ``k`` tokens with the largest |attribution|.

    Args:
        tokens, scores: parallel sequences (scores presumably normalized to [-1, 1]).
        title: chart title.
        out_png: PNG output path, saved at 300 dpi.
        out_pdf: optional PDF output path.
        k: maximum number of bars.
        skip_special: drop tokens listed in SPECIAL.
        min_abs: hide tokens with |score| <= this value (default 0.02, the former
            hard-coded threshold).
    """
    pairs = [(t, float(s)) for t, s in zip(tokens, scores) if (not skip_special or t not in SPECIAL)]
    pairs = [(t, s) for t, s in pairs if abs(s) > min_abs]
    # Strongest first, then reversed so the strongest bar ends up at the top of the chart.
    pairs = sorted(pairs, key=lambda x: abs(x[1]), reverse=True)[:k]
    pairs.reverse()

    # Strip WordPiece "##", byte-level BPE "\u0120" ("Ġ") and SentencePiece "\u2581" markers.
    labels = [t.replace("##", "").replace("\u0120", "").replace("\u2581", "") or t for t, _ in pairs]
    vals = [s for _, s in pairs]
    # Plot against numeric positions. Passing the label strings directly makes
    # matplotlib treat them as categories, so repeated tokens ("the", "and", ...)
    # would collapse onto one row and hide each other's bars.
    positions = list(range(len(vals)))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(positions, vals)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
        ax.set_title(title)
        ax.set_xlabel("Integrated Gradients attribution (normalized)")
        fig.tight_layout()
        fig.savefig(out_png, dpi=300, bbox_inches="tight")
        if out_pdf:
            fig.savefig(out_pdf, bbox_inches="tight")
    finally:
        plt.close(fig)

In [None]:
def export_attr_table(tokens, scores, out_csv, skip_special=True):
    """Write per-token attributions to ``out_csv`` and return them as a DataFrame.

    Columns:
        token            - the token with subword markers ("##", "Ġ", "▁") stripped,
                           or the raw token if nothing would remain;
        attribution_norm - the given score as a float.
    Tokens in SPECIAL are dropped when ``skip_special`` is true. Row order follows
    the input order.
    """
    rows = []
    for t, s in zip(tokens, scores):
        if skip_special and t in SPECIAL:
            continue
        disp = t.replace("##", "").replace("\u0120", "").replace("\u2581", "") or t
        rows.append({"token": disp, "attribution_norm": float(s)})
    # Explicit columns keep the CSV header even when no token survives the filter.
    df = pd.DataFrame(rows, columns=["token", "attribution_norm"])
    df.to_csv(out_csv, index=False)
    return df


In [None]:
def make_report_artifacts(text, true_label=None, tag="example", out_dir=OUT_DIR,
                          max_length=MAX_LENGTH, n_steps=N_STEPS):
    """Explain one text with Integrated Gradients and write all report artifacts.

    Files written to ``out_dir`` (each prefixed with ``tag``):
        <tag>_highlight.png/.pdf  - token chip figure
        <tag>_topk.png/.pdf       - top-20 attribution bar chart
        <tag>_attributions.csv    - per-token attribution table
        <tag>_meta.txt            - labels, probabilities, IG target and convergence delta

    Args:
        text: input string.
        true_label: gold label, used only in titles/meta/printout (may be None).
        tag: filename prefix and plot label.
        out_dir: output directory, created if missing.
        max_length: tokenization length (defaults to the config value).
        n_steps: IG steps (defaults to the config value).

    Returns:
        (pred, probs, df): predicted class, class probabilities, attribution table.
    """
    os.makedirs(out_dir, exist_ok=True)

    def out_path(suffix):
        # <out_dir>/<tag>_<suffix>, built portably.
        return os.path.join(out_dir, f"{tag}_{suffix}")

    pred, probs = get_pred(text, max_length=max_length)

    # Explain the predicted class (good for analyzing errors).
    tokens, scores, delta, target = integrated_gradients_for_text(
        text, target_label=pred, max_length=max_length, n_steps=n_steps
    )

    title = f"{tag} | true={true_label} pred={pred} | delta={delta:.4g}"

    tokens_to_report_figure(
        tokens, scores, title=title,
        out_png=out_path("highlight.png"),
        out_pdf=out_path("highlight.pdf"),
    )

    topk_barplot(
        tokens, scores,
        title=f"{tag} — top token attributions (pred={pred})",
        out_png=out_path("topk.png"),
        out_pdf=out_path("topk.pdf"),
        k=20,
    )

    df = export_attr_table(tokens, scores, out_csv=out_path("attributions.csv"))

    with open(out_path("meta.txt"), "w", encoding="utf-8") as f:
        f.write(f"tag: {tag}\n")
        f.write(f"true_label: {true_label}\n")
        f.write(f"pred_label: {pred}\n")
        f.write(f"probs: {np.round(probs, 6).tolist()}\n")
        f.write(f"ig_target_explained: {target}\n")
        f.write(f"convergence_delta: {delta}\n")

    print("=" * 90)
    if true_label is not None:
        print(f"{tag}: true={true_label} pred={pred} correct? {pred==true_label}")
    else:
        print(f"{tag}: pred={pred}")
    print("probs:", np.round(probs, 4))
    print("Saved to:", out_dir)

    return pred, probs, df


In [None]:
# ---------- 5) Run for the two examples ----------
# (text, gold label, file tag) for each example to explain.
for example_text, example_label, example_tag in (
    (RIGHT_TEXT, RIGHT_TRUE_LABEL, "A_correct"),
    (WRONG_TEXT, WRONG_TRUE_LABEL, "B_wrong"),
):
    make_report_artifacts(example_text, example_label, tag=example_tag)

print("\nDone. Files are in:", OUT_DIR)
print("Download them from the Colab file browser (left sidebar).")

A_correct: true=1 pred=1 correct? True
probs: [0.202  0.7253 0.0727]
Saved to: /content/ig_report
B_wrong: true=2 pred=1 correct? False
probs: [0.1248 0.79   0.0852]
Saved to: /content/ig_report

Done. Files are in: /content/ig_report
Download them from the Colab file browser (left sidebar).
