# 06 — Model Interpretability

**Feature importance, SHAP values, and party deviation analysis**

Export findings for the research paper.

In [None]:
from pathlib import Path

# Every figure below is written into ../outputs; create it up front so savefig never fails.
outputs_dir = Path("..") / "outputs"
outputs_dir.mkdir(exist_ok=True)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from src.ml.features import load_pairs, get_train_val_test, build_basic_features, add_enhanced_features
from src.ml.models import train_model_a, predict_model_a, build_X_for_model_a

# Feature switches / hyperparameters for the TF-IDF + logistic-regression model ("model A").
MODEL_KW = dict(
    max_features=2000, ngram_range=(1, 1), min_df=1,
    use_besluit_tfidf=True, use_speech_position=True, use_speaker_loyalty=True,
    use_kabinetsappreciatie=True, use_zaak_soort=True, use_is_coalition=True,
)
# Only these two vote values are modelled; every split is restricted to them.
BINARY_VOTES = ["Voor", "Tegen"]
# Number of speech-vote pairs to load.
# NOTE(review): presumably a random sample — confirm load_pairs is seeded so that
# Restart & Run All reproduces the same figures.
SAMPLE_SIZE = 50000

df = load_pairs(sample=SAMPLE_SIZE)
df = df[df["datum"].notna()]  # drop pairs without a date (datum)
df = build_basic_features(df)
train, val, test = get_train_val_test(df)

# Apply the same label filter to all splits so the reported sizes are comparable
# (previously only train and val were filtered).
train = train[train["vote"].isin(BINARY_VOTES)]
val = val[val["vote"].isin(BINARY_VOTES)]
test = test[test["vote"].isin(BINARY_VOTES)]
train, val, test = add_enhanced_features(train, val, test)

assert len(train) > 0 and len(val) > 0, "empty train/val split after filtering to Voor/Tegen"

model = train_model_a(train, **MODEL_KW)

print(f"Train: {len(train):,} | Val: {len(val):,} | Test: {len(test):,}")

## 1. Feature Importance (Logistic Regression Coefficients)

Which words drive the model toward Voor vs Tegen?

In [None]:
# Rank speech tf-idf terms by their logistic-regression coefficient.
TOP_K = 15  # terms listed per direction

clf = model["clf"]
coef = clf.coef_[0]
# For a binary sklearn classifier, positive coefficients push toward clf.classes_[1].
# String labels sort as ["Tegen", "Voor"]; flip the sign if "Voor" is instead the first
# class, so that "positive = Voor" holds for everything below.
classes = list(getattr(clf, "classes_", []))
if classes and classes[0] == "Voor":
    coef = -coef

n_party = len(model["party_enc"].get_feature_names_out())
tfidf_names = model["tfidf"].get_feature_names_out()
n_tfidf = len(tfidf_names)

# Assumes the design matrix starts with [party one-hot | speech tf-idf | ...]
# — TODO confirm against build_X_for_model_a.
tfidf_coef = coef[n_party : n_party + n_tfidf]
order = np.argsort(tfidf_coef)      # ascending: most Tegen-ward first
top_voor = order[-TOP_K:][::-1]     # largest coefficients, strongest first
top_tegen = order[:TOP_K]           # most negative coefficients, strongest first

print("Top terms predicting VOOR:")
for i in top_voor:
    print(f"  {tfidf_names[i]}: {tfidf_coef[i]:.3f}")
print("\nTop terms predicting TEGEN:")
for i in top_tegen:
    print(f"  {tfidf_names[i]}: {tfidf_coef[i]:.3f}")

In [None]:
# Side-by-side bar charts of the strongest Voor and Tegen terms.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (axes[0], top_voor, "green", "Pro-Voor terms"),
    (axes[1], top_tegen, "red", "Pro-Tegen terms"),
)
for ax, term_idx, color, title in panels:
    # Size the axis from the selection itself: the vocabulary may hold fewer than 15 terms.
    positions = np.arange(len(term_idx))
    ax.barh(positions, tfidf_coef[term_idx], color=color, alpha=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels([tfidf_names[i] for i in term_idx], fontsize=9)
    ax.set_title(title)
    ax.set_xlabel("Logistic-regression coefficient")
    ax.invert_yaxis()  # strongest term at the top
plt.tight_layout()
plt.savefig("../outputs/06_feature_importance.png", dpi=150, bbox_inches="tight")
plt.show()

## 2. SHAP Values (Per-Prediction Explanations)

Goal: explanations of the form "this speech was predicted Tegen because of words X, Y, Z".

In [None]:
# Per-prediction SHAP explanations for the linear model; skipped if shap is not installed.
N_EXPLAIN = 100    # validation rows to explain
MAX_DISPLAY = 50   # features shown in the summary plot

try:
    import shap
except ImportError:
    # Only a missing shap install is optional; errors in the analysis itself surface normally.
    print("Install shap: pip install shap")
else:
    X_val = build_X_for_model_a(val, model)
    if hasattr(X_val, "toarray"):
        X_val = X_val.toarray()  # densify sparse output for the explainer and the plot

    # The whole validation matrix is used as the explainer's background data.
    explainer = shap.LinearExplainer(model["clf"], X_val)
    X_explain = X_val[:N_EXPLAIN]
    shap_values = explainer.shap_values(X_explain)

    # Column names, assuming build_X_for_model_a stacks the blocks in this order — TODO confirm.
    feature_names = (
        list(model["party_enc"].get_feature_names_out())
        + list(model["tfidf"].get_feature_names_out())
    )
    if model.get("tfidf_besluit") is not None:
        feature_names += list(model["tfidf_besluit"].get_feature_names_out())
    if model.get("use_speech_position"):
        feature_names.append("speech_position")
    if model.get("use_speaker_loyalty"):
        feature_names.append("speaker_loyalty")
    if model.get("use_kabinetsappreciatie") and model.get("ka_enc") is not None:
        feature_names += list(model["ka_enc"].get_feature_names_out())
    if model.get("use_zaak_soort") and model.get("zs_enc") is not None:
        feature_names += list(model["zs_enc"].get_feature_names_out())
    if model.get("use_is_coalition"):
        feature_names.append("is_coalition")

    # summary_plot needs exactly one name per column: warn on a mismatch, then pad with
    # generic names (or trim), rather than silently shifting every label.
    n_cols = X_val.shape[1]
    if len(feature_names) != n_cols:
        print(
            f"Warning: {len(feature_names)} feature names for {n_cols} columns; "
            "labels may be misaligned."
        )
        feature_names = (
            feature_names + [f"feature_{i}" for i in range(len(feature_names), n_cols)]
        )[:n_cols]

    # Limit the plot with max_display; slicing feature_names would break the
    # name/column mapping and can index past the end of the list.
    shap.summary_plot(
        shap_values,
        X_explain,
        feature_names=feature_names,
        max_display=MAX_DISPLAY,
        show=False,
    )
    plt.title(f"SHAP values for vote prediction (top {MAX_DISPLAY} features)")
    plt.tight_layout()
    plt.savefig("../outputs/06_shap_summary.png", dpi=150, bbox_inches="tight")
    plt.show()

## 3. Party Deviation Analysis

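Where does the speech-based model most often disagree with the actual vote? Low accuracy for a party or speaker suggests that their speeches diverge from how they end up voting.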

In [None]:
# Per-party accuracy of model A on the validation set.
MIN_PARTY_ROWS = 5  # skip parties with too few validation pairs for a meaningful rate

val_copy = val.copy()
val_copy["pred"] = predict_model_a(model, val)
val_copy["correct"] = val_copy["vote"] == val_copy["pred"]

# groupby drops rows whose fractie is missing (pandas default dropna=True).
party_acc = val_copy.groupby("fractie").agg(
    total=("correct", "count"),
    correct=("correct", "sum"),
)
party_acc["accuracy"] = party_acc["correct"] / party_acc["total"]
party_acc = party_acc[party_acc["total"] >= MIN_PARTY_ROWS].sort_values("accuracy")

print("Parties where speech-based model diverges most from party (lowest accuracy):")
print(party_acc.head(10).to_string())

In [None]:
# Per-speaker accuracy: which speakers' votes the model gets wrong most often.
speaker_dev = (
    val_copy.groupby(["persoon_id", "achternaam", "fractie"])
    .agg(total=("correct", "count"), correct=("correct", "sum"))
    .assign(accuracy=lambda g: g["correct"] / g["total"])
)
speaker_dev = speaker_dev.loc[speaker_dev["total"] >= 3].sort_values("accuracy")

print("\nSpeakers whose speech most diverges from prediction (lowest accuracy):")
print(speaker_dev.head(15).to_string())

## 4. Linking Quality Analysis

How well does the speech-vote linking work? Check the coverage of Kabinetsappreciatie, Zaak.Soort and the agendapunt topic (`agendapunt_onderwerp`) in the training split.

In [None]:
# Linking quality: coverage of enriched features
if "kabinetsappreciatie" in train.columns:
    ka_dist = train["kabinetsappreciatie"].fillna("Onbekend").value_counts()
    print("Kabinetsappreciatie distribution (train):")
    print(ka_dist.head(10).to_string())
if "zaak_soort" in train.columns:
    zs_dist = train["zaak_soort"].fillna("Onbekend").value_counts()
    print("\nZaak.Soort distribution (train):")
    print(zs_dist.head(10).to_string())
if "agendapunt_onderwerp" in train.columns:
    has_topic = train["agendapunt_onderwerp"].notna() & (train["agendapunt_onderwerp"].astype(str).str.len() > 5)
    print(f"\nPairs with non-empty agendapunt_onderwerp: {has_topic.sum():,} / {len(train):,} ({100*has_topic.mean():.1f}%)")

## 5. RobBERT Evaluation (if checkpoint exists)

Load the fine-tuned RobBERT transformer and evaluate it on the validation set.

In [None]:
from pathlib import Path
from src.ml.models import load_model_robbert, predict_model_robbert, evaluate

# Evaluate the fine-tuned RobBERT checkpoint on the validation set, if it has been trained.
robbert_path = Path("../models/robbert_vote_classifier")
if not (robbert_path / "config.json").exists():
    print("RobBERT checkpoint not found. Run: python scripts/train_robbert.py --epochs 5")
else:
    model_robbert = load_model_robbert(str(robbert_path))
    pred_robbert = predict_model_robbert(model_robbert, val)
    r_robbert = evaluate(val["vote"].values, pred_robbert)
    accuracy_pct = r_robbert["accuracy"] * 100
    print(f"RobBERT val accuracy: {accuracy_pct:.1f}%")
    print(f"RobBERT F1 (macro): {r_robbert['f1_macro']:.3f}")

## 6. RobBERT Attention Visualization

Which tokens does the model attend to for a given prediction? (Requires RobBERT checkpoint.)

In [None]:
# CLS-position attention over the input tokens for one validation example.
ATTN_SAMPLE_IDX = 10  # validation row to visualise (clamped to the size of val)
ATTN_TOP_N = 25       # number of tokens to show
# Special tokens typically absorb much of the CLS attention and carry no content.
# RobBERT presumably uses RoBERTa-style specials (<s>, </s>, ...); BERT-style names are
# listed too — TODO confirm against the tokenizer.
SPECIAL_TOKENS = {"<s>", "</s>", "<pad>", "<unk>", "<mask>", "[CLS]", "[SEP]", "[PAD]"}

if (robbert_path / "config.json").exists():
    from src.ml.models import get_robbert_attention

    sample_row = val.iloc[min(ATTN_SAMPLE_IDX, len(val) - 1)]
    cls_attn, tokens, pred = get_robbert_attention(model_robbert, sample_row)
    cls_attn = np.asarray(cls_attn)

    # Rank content tokens only: special tokens are skipped before sorting by attention.
    content_idx = np.array(
        [i for i, tok in enumerate(tokens) if tok not in SPECIAL_TOKENS], dtype=int
    )
    ranked = content_idx[np.argsort(cls_attn[content_idx])[::-1]]
    idx = ranked[:ATTN_TOP_N]
    n_show = len(idx)

    if n_show == 0:
        print("No content tokens to show for this example.")
    else:
        fig, ax = plt.subplots(figsize=(10, 6))
        y_pos = np.arange(n_show)
        ax.barh(y_pos, cls_attn[idx], color="steelblue", alpha=0.8)
        ax.set_yticks(y_pos)
        ax.set_yticklabels([tokens[i] for i in idx], fontsize=9)
        ax.invert_yaxis()  # highest attention at the top
        ax.set_xlabel("CLS attention weight")
        ax.set_title(f"RobBERT: tokens most attended to (prediction: {pred})")
        plt.tight_layout()
        plt.savefig("../outputs/06_robbert_attention.png", dpi=150, bbox_inches="tight")
        plt.show()
else:
    print("RobBERT checkpoint not found. Skip attention visualization.")

In [None]:
# Party deviation chart: the parties whose votes the speech model predicts worst.
N_PARTIES_PLOT = 15

fig, ax = plt.subplots(figsize=(8, 5))
# party_acc is sorted by ascending accuracy, so head() yields the most divergent parties.
party_acc_plot = party_acc.head(N_PARTIES_PLOT)
colors = ["green" if a > 0.7 else "orange" if a > 0.5 else "red" for a in party_acc_plot["accuracy"]]
ax.barh(party_acc_plot.index, party_acc_plot["accuracy"], color=colors, alpha=0.7)
ax.axvline(0.5, color="gray", linestyle="--")  # chance level for a binary vote
ax.invert_yaxis()  # lowest accuracy at the top
ax.set_xlim(0, 1)
ax.set_xlabel("Accuracy (speech model vs actual vote)")
ax.set_title("Party deviation: model accuracy by party")
plt.tight_layout()
plt.savefig("../outputs/06_party_deviation.png", dpi=150, bbox_inches="tight")
plt.show()