# MindLens-AI — 03: Explainability (RQ1)

**RQ1:** Can explainable NLP models maintain high performance (≥80% accuracy) while providing meaningful, human-interpretable explanations aligned with mental health indicators?

Tools: SHAP (global + local) and LIME (instance-level). Includes a quantitative Interpretability Score.

In [None]:
# Setup
import sys, os
sys.path.insert(0, os.path.abspath(".."))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap

from src.model import load_model
from src.explainability import (
    explain_lime, explain_shap, shap_summary_plot,
    compute_interpretability_score, MH_LEXICON,
)
from src.evaluation import evaluate_single

sns.set_theme(style="whitegrid")

# Load the (model, vectorizer) pair from the saved artifacts file
model, vectorizer = load_model("../data/processed/model_artifacts.joblib")

# Load the saved split and unpack the entries that later cells use.
# NOTE(review): joblib.load unpickles the file, so only load artifacts from a trusted source.
split = joblib.load("../data/processed/test_split.joblib")
X_train, X_test = split["X_train"], split["X_test"]
y_test, texts_test = split["y_test"], split["texts_test"]
feature_names = split["feature_names"]

print(f"Loaded model + {len(texts_test)} test samples ✓")

## 1. SHAP — Global Feature Importance

In [None]:
# SHAP attributions via the project helper; it receives both the training and
# the test matrix (presumably X_train serves as background data — confirm in
# src.explainability) and returns the values together with the explainer.
shap_values, shap_explainer = explain_shap(
    model,
    X_train,
    X_test,
    feature_names,
)
print(f"SHAP values shape: {shap_values.shape}")

# Global view: summary plot restricted to the top 20 features
n_top_features = 20
shap_summary_plot(shap_values, X_test, feature_names, max_display=n_top_features)

## 2. SHAP — Force Plots (Individual Predictions)

In [None]:
# Pick 5 Risk + 5 No-Risk samples.
# np.where yields *positional* indices, so labels and texts are looked up
# positionally as well. np.asarray keeps that correct even if the split stored
# pandas objects whose index is not 0..n-1.
# NOTE(review): the container types of y_test / texts_test are not visible here.
y_test_arr = np.asarray(y_test)
texts_test_arr = np.asarray(texts_test, dtype=object)

risk_idx = np.where(y_test_arr == 1)[0][:5]
safe_idx = np.where(y_test_arr == 0)[0][:5]
sample_indices = np.concatenate([risk_idx, safe_idx])

for i in sample_indices:
    label = "RISK" if y_test_arr[i] == 1 else "NO RISK"
    print(f"\n--- Sample {i} (True: {label}) ---")
    print(f"Text: {texts_test_arr[i][:150]}...")
    # force_plot(matplotlib=True) draws on a figure it creates itself, so no
    # plt.figure() is opened beforehand (that only left an empty figure behind).
    # show=False returns the figure so it can be shown and then released.
    fig = shap.force_plot(
        shap_explainer.expected_value,
        shap_values[i],
        feature_names=feature_names,
        matplotlib=True,
        show=False,
    )
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

## 3. LIME — Instance-Level Explanations

In [None]:
# LIME explanations for the same 10 samples (sample_indices comes from the
# force-plot cell above in the notebook's execution order).
# The indices are positional, so labels and texts are read from positional
# arrays; this stays correct even if the split stored pandas objects.
# NOTE(review): the container types of y_test / texts_test are not visible here.
y_test_arr = np.asarray(y_test)
texts_test_arr = np.asarray(texts_test, dtype=object)

for i in sample_indices:
    label = "RISK" if y_test_arr[i] == 1 else "NO RISK"
    text = texts_test_arr[i]
    print(f"\n--- LIME: Sample {i} (True: {label}) ---")
    print(f"Text: {text[:150]}...")

    exp = explain_lime(model, vectorizer, text, num_features=10)

    # Work on the returned figure explicitly rather than the pyplot "current
    # figure", then close it so the loop does not pile up open figures.
    fig = exp.as_pyplot_figure()
    fig.gca().set_title(f"LIME — Sample {i} ({label})")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

    # Show top words
    print("Top features:", exp.as_list())

## 4. Quantitative Interpretability Score

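**Interpretability Score** = (number of top-k SHAP features that appear in the mental health lexicon) / k. The score is computed per sample; the cell below reports its mean and standard deviation for k = 10 and k = 20 and plots the per-sample distribution for k = 10.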

In [None]:
print(f"Mental Health Lexicon ({len(MH_LEXICON)} words):")
print(sorted(MH_LEXICON))

# Compute at k=10 and k=20, keeping each result so the k=10 scores can be
# reused for the histogram (and by later cells) instead of being recomputed.
results_by_k = {}
for k in [10, 20]:
    result = compute_interpretability_score(shap_values, feature_names, MH_LEXICON, k=k)
    results_by_k[k] = result
    print(f"\n--- Interpretability Score (k={k}) ---")
    print(f"  Mean:  {result['mean_score']:.4f}")
    print(f"  Std:   {result['std_score']:.4f}")

# Distribution plot of the per-sample scores at k=10, with the mean marked
result_k10 = results_by_k[10]
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(result_k10["per_sample_scores"], bins=20, edgecolor="black", alpha=0.7, color="#2196F3")
ax.axvline(result_k10["mean_score"], color="red", linestyle="--", label=f"Mean = {result_k10['mean_score']:.3f}")
ax.set_xlabel("Interpretability Score (k=10)")
ax.set_ylabel("Count")
ax.set_title("Distribution of Interpretability Scores")
ax.legend()
fig.tight_layout()
plt.show()

## 5. RQ1 Conclusion

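**H1:** Explainable models can achieve ≥80% accuracy without major performance loss. In the cell below, H1 is reported as supported only when test accuracy is ≥ 0.80 **and** the mean interpretability score at k=10 is ≥ 0.15.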

In [None]:
# Decision thresholds for H1, named so the pass/fail rule is visible in one place.
# Keep them in sync with the "(target ≥ 0.80)" text printed below.
ACCURACY_TARGET = 0.80
INTERPRETABILITY_TARGET = 0.15

# evaluate_single is imported in the notebook's top import cell.
metrics = evaluate_single(model, X_test, y_test)
accuracy = metrics["accuracy"]
# Mean per-sample interpretability score at k=10, from the previous section
interp_score = result_k10["mean_score"]

print("=" * 50)
print("RQ1 CONCLUSION")
print("=" * 50)
print(f"  Accuracy:              {accuracy:.4f}  (target ≥ 0.80)")
print(f"  F1 Score:              {metrics['f1']:.4f}")
print(f"  Interpretability (k=10): {interp_score:.4f}")
print()

# H1 is reported as supported only when both criteria are met
if accuracy >= ACCURACY_TARGET and interp_score >= INTERPRETABILITY_TARGET:
    print("✓ H1 SUPPORTED: High accuracy with meaningful, interpretable explanations.")
else:
    print("✗ H1 NOT SUPPORTED: Either accuracy < 80% or interpretability score too low.")

print("\nExplainability notebook complete ✓")