# ============================================================
# NOTEBOOK: 03_shap_analysis.ipynb
# ANÁLISIS DE INTERPRETABILIDAD CON SHAP
# ============================================================


In [None]:
# ============================================================
# Celda 1 - CONFIGURACIÓN INICIAL E IMPORTS
# ============================================================
"""
# 📊 Análisis de Interpretabilidad con SHAP
## Explicando predicciones de DistilBERT en análisis de sentimientos

Este notebook implementa SHAP (SHapley Additive exPlanations) para interpretar
las decisiones del modelo DistilBERT en la tarea de clasificación de sentimientos.
"""

import warnings
warnings.filterwarnings('ignore')

# Imports b√°sicos
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import pickle
import time
from typing import List, Dict, Any, Tuple

# Plot styling and pandas display options.
# matplotlib 3.6 renamed the bundled seaborn style sheets
# ('seaborn-darkgrid' -> 'seaborn-v0_8-darkgrid'). Use whichever name the
# installed matplotlib offers so the notebook does not abort on older
# installations; if neither exists the default style is kept.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
sns.set_palette("husl")
# Show every column and up to 100 characters per cell when printing frames.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 100)

# Imports de ML
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, classification_report

# SHAP
import shap

# Report the library versions and the compute device seen by this session.
cuda_available = torch.cuda.is_available()
device_label = 'CUDA' if cuda_available else 'CPU'
environment_report = [
    "üîß CONFIGURACI√ìN DEL ENTORNO",
    "=" * 60,
    "üì¶ Versiones de librer√≠as:",
    f"  ‚Ä¢ PyTorch: {torch.__version__}",
    f"  ‚Ä¢ SHAP: {shap.__version__}",
    f"  ‚Ä¢ Device: {device_label}",
]
if cuda_available:
    # The GPU name is only queried when CUDA is actually usable.
    environment_report.append(f"  ‚Ä¢ GPU: {torch.cuda.get_device_name(0)}")
for report_line in environment_report:
    print(report_line)


In [None]:
# ============================================================
# Celda 2 - CARGAR MODELO Y DATOS PREPARADOS
# ============================================================
print("\nüìö CARGANDO MODELO Y DATOS")
print("="*60)

# Configuration: Hugging Face checkpoint name, local download cache, and the
# folder that holds the files written by the 02_model_evaluation notebook.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
CACHE_DIR = "./models_cache"
DATA_DIR = "./explainability_analysis"

# Load the tokenizer and the classification model, then switch the model to
# evaluation mode.
print("\n1Ô∏è‚É£ Cargando modelo DistilBERT...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model.eval()

# Move the model to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"‚úÖ Modelo cargado en {device}")

# Load the cases selected by the 02_model_evaluation notebook.
print("\n2Ô∏è‚É£ Cargando casos seleccionados...")
import glob
import os

# Bind the name up front: later cells read `priority_cases` unconditionally,
# so it must exist even when no file is found or loading fails.
priority_cases = []
priority_files = glob.glob(f"{DATA_DIR}/priority_cases_*.json")
if not priority_files:
    print(f"No se encontraron archivos priority_cases_*.json en {DATA_DIR}")
else:
    try:
        # Pick the most recent export according to os.path.getctime.
        latest_file = max(priority_files, key=os.path.getctime)
        with open(latest_file, 'r', encoding='utf-8') as f:
            priority_cases = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: unreadable/vanished file. ValueError: invalid JSON or
        # undecodable bytes (JSONDecodeError/UnicodeDecodeError subclasses).
        print(f"Error cargando casos: {e}")
        priority_cases = []
    else:
        print(f"‚úÖ Cargados {len(priority_cases)} casos prioritarios desde {latest_file}")


In [None]:
# ============================================================
# Celda 3 - Verificación del modelo en crudo
# ============================================================
# Section banner for the cell that defines the prediction wrapper.
print("\nüîß PREPARANDO FUNCI√ìN DE PREDICCI√ìN PARA SHAP", "=" * 60, sep="\n")

def predict_for_shap(texts: List[str]) -> np.ndarray:
    """Return class probabilities for a batch of texts.

    Tokenizes the texts with the module-level ``tokenizer`` (padding to the
    longest item, truncating at 512 tokens), runs the module-level ``model``
    on ``device`` without gradient tracking, and applies a softmax over the
    logits.

    Args:
        texts: Texts to classify. A single string is treated as a batch of
            one. Any other iterable (for example a NumPy array of strings)
            is converted to a plain list of ``str`` before tokenizing.

    Returns:
        Array of probabilities with shape [n_samples, n_classes]. An empty
        input yields an array with zero rows.
    """
    # Normalise the input to a list of str; a bare string must not be
    # iterated character by character.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(text) for text in texts]

    # Nothing to tokenize: return an empty result instead of failing inside
    # the tokenizer/model.
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    # Tokenize
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    return probs.cpu().numpy()

# Smoke test: push one sentence through the wrapper and show what comes back.
test_text = ["This is a test sentence."]
test_proba = predict_for_shap(test_text)
first_row = test_proba[0]
verification_lines = (
    f"‚úÖ Funci√≥n de predicci√≥n verificada",
    f"   ‚Ä¢ Input: '{test_text[0]}'",
    f"   ‚Ä¢ Output shape: {test_proba.shape}",
    # Column 0 is reported as Neg and column 1 as Pos.
    f"   ‚Ä¢ Probabilidades: Neg={first_row[0]:.3f}, Pos={first_row[1]:.3f}",
)
for verification_line in verification_lines:
    print(verification_line)


In [None]:
# ============================================================
# Celda 4 - INICIALIZAR SHAP EXPLAINER
# ============================================================

# Ignore warnings whose message mentions xformers.
import warnings
warnings.filterwarnings('ignore', message='.*xformers.*')

print("\nüéØ INICIALIZANDO SHAP EXPLAINER")
print("="*60)

# Method 1: explain a Transformers pipeline directly.
print("\nüìå Configurando SHAP con Transformers Pipeline...")

from transformers import pipeline

# GPU 0 when CUDA is available, otherwise CPU (-1).
pipeline_device = 0 if torch.cuda.is_available() else -1

# Pipeline used by the rest of the notebook for plain predictions. It keeps
# the default output (top label only) because other cells read
# `classifier(text)[0]['label']` / `['score']`.
classifier = pipeline(
    "sentiment-analysis",
    model=model,
    tokenizer=tokenizer,
    device=pipeline_device
)

# SHAP attributes every output class, so it needs a score for every class on
# every call. The default pipeline only reports the winning label, which
# leaves the other class without a real score. A second pipeline over the same
# model/tokenizer objects returns all class scores (top_k=None) and is the one
# handed to SHAP.
shap_classifier = pipeline(
    "sentiment-analysis",
    model=model,
    tokenizer=tokenizer,
    device=pipeline_device,
    top_k=None
)

# Initialise the SHAP explainer on the all-scores pipeline.
print("Inicializando explainer (puede tomar unos segundos)...")
explainer = shap.Explainer(shap_classifier)
print("‚úÖ SHAP Explainer inicializado")

# Explainer information
print(f"\nüìä Configuraci√≥n del Explainer:")
print(f"  ‚Ä¢ Tipo: {type(explainer).__name__}")
print(f"  ‚Ä¢ Modelo base: DistilBERT")
print(f"  ‚Ä¢ Clases: ['NEGATIVE', 'POSITIVE']")

In [None]:
# ============================================================
# Celda 5 - ANÁLISIS DE UN CASO SIMPLE
# ============================================================
print("\nüîç AN√ÅLISIS DE CASO SIMPLE")
print("="*60)

# Example sentence expected to be classified as negative.
simple_text = "This movie is absolutely terrible and boring."
print(f"\nüìù Texto: '{simple_text}'")

# Plain prediction first, as a sanity check before explaining it.
pred = classifier(simple_text)[0]
print(f"üéØ Predicci√≥n original: {pred['label']} (confianza: {pred['score']:.3f})")

# Compute the SHAP values and time the call.
print("\n‚è≥ Calculando valores SHAP (puede tomar 10-30 segundos)...")
start_time = time.time()
shap_values = explainer([simple_text])
calc_time = time.time() - start_time
print(f"‚úÖ C√°lculo completado en {calc_time:.2f} segundos")

# Describe the returned explanation object.
for structure_line in (
    f"\nüìä Estructura de valores SHAP:",
    f"  ‚Ä¢ Tipo: {type(shap_values)}",
    f"  ‚Ä¢ Shape: {shap_values.shape}",
    f"  ‚Ä¢ Clases: {shap_values.output_names}",
):
    print(structure_line)

print(f"\nüìù Tokens analizados y sus contribuciones:")
print("-"*50)

# Tokens of the first (only) text and their per-class contributions:
# column 0 is read as NEGATIVE, column 1 as POSITIVE.
tokens = shap_values[0].data
values_neg = shap_values[0].values[:, 0]
values_pos = shap_values[0].values[:, 1]

# Indices of the three largest contributions per class, largest first.
top_pos_idx = np.argsort(values_pos)[::-1][:3]
top_neg_idx = np.argsort(values_neg)[::-1][:3]

for heading, ranked_idx, contributions in (
    (f"\nüéØ Palabras m√°s influyentes para POSITIVE:", top_pos_idx, values_pos),
    (f"\nüéØ Palabras m√°s influyentes para NEGATIVE:", top_neg_idx, values_neg),
):
    print(heading)
    for idx in ranked_idx:
        print(f"  ‚Ä¢ '{tokens[idx]}': {contributions[idx]:+.4f}")


In [None]:
# ============================================================
# Celda 6 - VISUALIZACIÓN BÁSICA DE SHAP -> simple_text
# ============================================================
print("\nüìä VISUALIZACIONES B√ÅSICAS DE SHAP")
print("="*60)
print(f"üìå Analizando: '{simple_text}'")
print(f"   Predicci√≥n: {pred['label']} ({pred['score']:.1%})")

# Plot 1: text plot, token importance shown inside the sentence.
print("\n1Ô∏è‚É£ Text Plot - Importancia de palabras en contexto:")
shap.plots.text(shap_values[0])

# Plots 2 and 3: bar and waterfall views of the same slice. For this binary
# classifier the slice is the target class at index 1 (POSITIVE).
for plot_banner, plot_function in (
    ("\n2Ô∏è‚É£ Bar Plot - Top palabras m√°s influyentes:", shap.plots.bar),
    ("\n3Ô∏è‚É£ Waterfall Plot - Contribuci√≥n acumulativa:", shap.plots.waterfall),
):
    print(plot_banner)
    plt.figure(figsize=(10, 6))
    plot_function(shap_values[0, :, 1], max_display=10)

    plt.show()

In [None]:
# ============================================================
# Celda 7 - ANÁLISIS DE MÚLTIPLES CASOS
# ============================================================
print("\nüîç AN√ÅLISIS DE M√öLTIPLES CASOS PRIORITARIOS")
print("="*60)

# Analyse at most the first five priority cases.
n_cases = min(5, len(priority_cases))
cases_to_analyze = priority_cases[:n_cases]

print(f"\nüìã Analizando {n_cases} casos prioritarios...")
print("-"*40)

# One dict per analysed case: inputs, labels, the SHAP explanation and timing.
shap_results = []

for i, case in enumerate(cases_to_analyze, 1):
    text = case['text']
    true_label = case['true_label']

    print(f"\nCaso {i}/{n_cases}:")
    print(f"  Texto: '{text[:80]}...'")
    print(f"  Categor√≠a: {case['category']}")
    print(f"  Label real: {true_label}")

    # Compute SHAP values for this single text and time the call.
    print("  ‚è≥ Calculando SHAP...")
    start = time.time()
    shap_vals = explainer([text])
    elapsed = time.time() - start

    shap_results.append({
        'text': text,
        'true_label': true_label,
        # -1 marks a case that carries no stored prediction.
        'predicted_label': case.get('predicted_label', -1),
        'category': case['category'],
        'shap_values': shap_vals,
        'computation_time': elapsed
    })

    print(f"  ‚úÖ Completado en {elapsed:.2f}s")

print(f"\n‚úÖ An√°lisis completado para {len(shap_results)} casos")
# np.mean([]) is NaN, which would leak into the summary and the JSON export
# (NaN is not valid JSON). Report 0.0 when no case was analysed.
if shap_results:
    avg_time = float(np.mean([r['computation_time'] for r in shap_results]))
else:
    avg_time = 0.0
print(f"‚è±Ô∏è Tiempo promedio por caso: {avg_time:.2f} segundos")


In [None]:
# ============================================================
# Celda 7.1 - EXTRACCIÓN DE PALABRAS INFLUYENTES
# ============================================================

def _strongest_tokens(case_tokens, scores, k=3):
    """Pair the k highest-scoring tokens with their scores, highest first."""
    ranking = np.argsort(scores)[::-1][:k]
    return [(case_tokens[pos], scores[pos]) for pos in ranking]


# Store, for every analysed case, the three tokens with the largest SHAP value
# for each class so the plots and the console summary can reuse them.
for result in shap_results:
    explanation = result['shap_values'][0]
    case_tokens = explanation.data

    # Column 1 is read as POSITIVE, column 0 as NEGATIVE.
    result['top_positive_words'] = _strongest_tokens(case_tokens, explanation.values[:, 1])
    result['top_negative_words'] = _strongest_tokens(case_tokens, explanation.values[:, 0])

print("\n‚úÖ Extracci√≥n de palabras influyentes completada")
# Console comparison of the stored words, case by case.
for case_number, result in enumerate(shap_results, start=1):
    print(f"\nüìã Caso {case_number} - Palabras m√°s influyentes:")
    print(f"  Texto: '{result['text'][:80]}...'")
    print(f"  Predicci√≥n: {result['predicted_label']}")
    for heading, words_key in (
        (f"  Top POSITIVE:", 'top_positive_words'),
        (f"  Top NEGATIVE:", 'top_negative_words'),
    ):
        print(heading)
        for word, val in result[words_key]:
            print(f"    ‚Ä¢ '{word}': {val:+.4f}")

In [None]:
# ============================================================
# Celda 7.2 - VISUALIZACIÓN DE PALABRAS INFLUYENTES
# ============================================================

# Two bar plots per analysed case: first the POSITIVE slice (index 1), then
# the NEGATIVE slice (index 0).
for case_number, result in enumerate(shap_results, start=1):
    print(f"\nüìä Caso {case_number} - Visualizaci√≥n de palabras influyentes:")
    print(f"  Texto: '{result['text'][:80]}...'")
    print(f"  Predicci√≥n: {result['predicted_label']}")

    case_explanation = result['shap_values']
    for class_index in (1, 0):
        plt.figure(figsize=(10, 6))
        shap.plots.bar(case_explanation[0, :, class_index], max_display=10)

In [None]:
# ============================================================
# Celda 8 - COMPARACIÓN DE EXPLICACIONES
# ============================================================
print("\nüìä COMPARACI√ìN DE EXPLICACIONES ENTRE CASOS")
print("="*60)

n_results = len(shap_results)
if n_results == 0:
    # plt.subplots rejects a grid with zero rows, so skip the figure when no
    # case was analysed.
    print("No hay resultados SHAP para comparar; se omite la figura")
else:
    # One horizontal bar chart per case, stacked vertically. squeeze=False
    # always returns a 2-D array of axes, so a single case needs no special
    # handling.
    fig, axes = plt.subplots(n_results, 1, figsize=(12, 4 * n_results), squeeze=False)
    axes = axes[:, 0]

    for idx, (result, ax) in enumerate(zip(shap_results, axes)):
        # SHAP values of the first (only) text, POSITIVE class (column 1).
        shap_val = result['shap_values'][0]
        tokens = shap_val.data
        values = shap_val.values[:, 1]

        # Ten tokens with the largest absolute contribution (ascending, so the
        # strongest ends up at the top of the horizontal bar chart).
        top_indices = np.argsort(np.abs(values))[-10:]
        top_tokens = [tokens[i] for i in top_indices]
        top_values = [values[i] for i in top_indices]

        # Green bars push towards POSITIVE, red bars away from it.
        colors = ['green' if v > 0 else 'red' for v in top_values]
        ax.barh(range(len(top_tokens)), top_values, color=colors, alpha=0.7)
        ax.set_yticks(range(len(top_tokens)))
        ax.set_yticklabels(top_tokens)
        ax.set_xlabel('SHAP Value (impacto en predicci√≥n POSITIVE)')
        ax.set_title(f"Caso {idx+1}: {result['category']}", fontweight='bold')
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        ax.grid(True, alpha=0.3)

    plt.suptitle("Comparaci√≥n de Palabras M√°s Influyentes por Caso", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


In [None]:
# ============================================================
# Celda 9 - ANÁLISIS DE ESTABILIDAD
# ============================================================
print("\nüî¨ AN√ÅLISIS DE ESTABILIDAD DE SHAP")
print("="*60)

# The analysis targets the borderline case referred to as "index": 301.
# Using 301 directly as a list position raises IndexError whenever fewer
# cases were loaded, so the case is resolved in three steps:
#   1. list position 301 (the original behaviour),
#   2. a case whose 'index' field equals 301
#      NOTE(review): the 'index' field is assumed from the original comment;
#      confirm against the files written by 02_model_evaluation,
#   3. the first loaded case, with a notice.
STABILITY_CASE_INDEX = 301
if len(priority_cases) > STABILITY_CASE_INDEX:
    stability_case = priority_cases[STABILITY_CASE_INDEX]
else:
    stability_case = next(
        (c for c in priority_cases if c.get('index') == STABILITY_CASE_INDEX),
        None,
    )
    if stability_case is None and priority_cases:
        stability_case = priority_cases[0]
        print(f"Caso {STABILITY_CASE_INDEX} no disponible; se usa el primer caso prioritario")
if stability_case is None:
    raise RuntimeError(
        "No hay casos prioritarios cargados; no se puede ejecutar el analisis de estabilidad"
    )
stability_text = stability_case['text']

# Initial prediction; its score is reported instead of a hard-coded figure so
# the printed confidence always matches the text actually analysed.
pred = classifier(stability_text)[0]
print(f"\nüìù Texto ambiguo (confianza: {pred['score']:.1%}):")
print(f"'{stability_text[:100]}...'")
print(f"\nüéØ Predicci√≥n: {pred['label']} (confianza: {pred['score']:.1%})")
print("   ‚ö†Ô∏è Caso fronterizo - ideal para probar estabilidad")

# Explain the same text several times and keep the POSITIVE-class values
# (column 1) of every run.
n_runs = 3
stability_results = []
print(f"\nüîÑ Ejecutando {n_runs} explicaciones del mismo texto...")
for run in range(n_runs):
    print(f"  Run {run+1}/{n_runs}...", end='')
    shap_vals = explainer([stability_text])
    stability_results.append(shap_vals[0].values[:, 1])
    print(" ‚úì")

# Per-token statistics across runs: stack once into (n_runs, n_tokens).
tokens = shap_vals[0].data
run_matrix = np.stack(stability_results)
variances = run_matrix.var(axis=0)
mean_values = run_matrix.mean(axis=0)
std_devs = run_matrix.std(axis=0)

# Treat the explanation as deterministic when no token varies beyond 1e-10.
max_variance = np.max(variances)
is_deterministic = max_variance < 1e-10

print(f"\nüìä RESULTADOS DE ESTABILIDAD:")
print("="*60)

print(f"\n‚úÖ VEREDICTO: {'DETERMIN√çSTICO' if is_deterministic else 'ESTOC√ÅSTICO'}")
print(f"   ‚Ä¢ Varianza m√°xima: {max_variance:.2e}")
print(f"   ‚Ä¢ {'Todos los runs dan valores id√©nticos' if is_deterministic else 'Los valores var√≠an entre runs'}")

# Five tokens with the largest absolute mean contribution, strongest first.
abs_means = np.abs(mean_values)
top_indices = np.argsort(abs_means)[-5:][::-1]

print(f"\nüéØ TOP 5 PALABRAS M√ÅS INFLUYENTES:")
print(f"\n Si todos los valores son iguales entre runs, œÉ=0 para cada palabra")
print(f"\n Quiere decir que SHAP es reproducible y confiable")
print("-"*40)
for idx in top_indices:
    # Whitespace-only tokens are shown with a visible placeholder.
    token = tokens[idx] if tokens[idx].strip() else '[SPACE]'
    impact = "‚Üí POSITIVE" if mean_values[idx] > 0 else "‚Üí NEGATIVE"
    print(f"  {token:20s}: Œº={mean_values[idx]:+.4f} {impact}")
    if not is_deterministic:
        print(f"  {'':20s}  œÉ={std_devs[idx]:.6f}")

# Show the per-run values of up to three varying tokens when variance exists.
if not is_deterministic:
    print(f"\n‚ö†Ô∏è VARIABILIDAD DETECTADA:")
    print(f"   ‚Ä¢ Algunas palabras tienen valores SHAP que var√≠an entre runs")
    varying_tokens = np.where(variances > 1e-10)[0]
    for idx in varying_tokens[:3]:
        print(f"\n  Token: '{tokens[idx]}'")
        for run_idx, run_values in enumerate(stability_results):
            print(f"    Run {run_idx+1}: {run_values[idx]:.6f}")
else:
    print(f"\n‚úÖ CONSISTENCIA PERFECTA:")
    print(f"   ‚Ä¢ Cada palabra tiene exactamente el mismo valor SHAP en todos los runs")
    print(f"   ‚Ä¢ Esto confirma que SHAP es reproducible y confiable")

# Interpretation: count tokens pushing clearly (|value| > 0.01) each way.
print(f"\nüí° INTERPRETACI√ìN DEL CASO AMBIGUO:")
print("-"*40)
positive_words = sum(1 for v in mean_values if v > 0.01)
negative_words = sum(1 for v in mean_values if v < -0.01)
if negative_words > positive_words:
    balance = 'Inclinado a NEGATIVE'
elif positive_words > negative_words:
    balance = 'Inclinado a POSITIVE'
else:
    # A tie is neither; do not report it as leaning POSITIVE.
    balance = 'Equilibrado'
print(f"  ‚Ä¢ Palabras positivas: {positive_words}")
print(f"  ‚Ä¢ Palabras negativas: {negative_words}")
print(f"  ‚Ä¢ Balance: {balance}")
print(f"  ‚Ä¢ Por eso la confianza es baja (~{pred['score']:.0%})")

In [None]:
# ============================================================
# Celda 10 - EXPORTAR RESULTADOS
# ============================================================
print("\nüíæ EXPORTANDO RESULTADOS DE SHAP")
print("="*60)

# Take the timestamp once so the metadata and the file name agree.
export_timestamp = pd.Timestamp.now()

# NaN is not valid JSON; export null when the average is not a finite number
# (e.g. when no case was analysed).
exported_avg_time = float(avg_time) if np.isfinite(avg_time) else None

# Payload: run-level metadata plus one entry per analysed case.
export_data = {
    'metadata': {
        'model': MODEL_NAME,
        'method': 'SHAP',
        'n_cases_analyzed': len(shap_results),
        'avg_computation_time': exported_avg_time,
        'timestamp': export_timestamp.isoformat()
    },
    'cases': []
}

for result in shap_results:
    case_data = {
        'text': result['text'],
        'true_label': result['true_label'],
        'category': result['category'],
        'computation_time': result['computation_time'],
        'top_positive_words': [],
        'top_negative_words': []
    }

    # SHAP values of the first (only) text, POSITIVE class (column 1).
    shap_val = result['shap_values'][0]
    tokens = shap_val.data
    values = shap_val.values[:, 1]

    # Rank once; both lists are exported strongest-first (up to five items)
    # so positive and negative words follow the same ordering convention.
    ascending = np.argsort(values)
    top_pos = [i for i in ascending[::-1] if values[i] > 0][:5]
    top_neg = [i for i in ascending if values[i] < 0][:5]

    case_data['top_positive_words'] = [
        {'token': str(tokens[i]), 'value': float(values[i])}
        for i in top_pos
    ]
    case_data['top_negative_words'] = [
        {'token': str(tokens[i]), 'value': float(values[i])}
        for i in top_neg
    ]

    export_data['cases'].append(case_data)

# Make sure the target folder exists, then write the JSON file.
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
output_path = f"{DATA_DIR}/shap_results_{export_timestamp.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(export_data, f, ensure_ascii=False, indent=2)

print(f"‚úÖ Resultados guardados en: {output_path}")

In [None]:
# ============================================================
# Celda 11 - RESUMEN Y CONCLUSIONES
# ============================================================
# Closing summary: run figures followed by fixed observations, advantages,
# limitations and next steps. Every entry is printed on its own line.
total_time = sum(r['computation_time'] for r in shap_results)
summary_lines = (
    "\nüìä RESUMEN DEL AN√ÅLISIS CON SHAP",
    "=" * 60,

    "\n‚úÖ AN√ÅLISIS COMPLETADO:",
    f"  ‚Ä¢ Casos analizados: {len(shap_results)}",
    f"  ‚Ä¢ Tiempo promedio: {avg_time:.2f} segundos/caso",
    f"  ‚Ä¢ Tiempo total: {total_time:.2f} segundos",

    "\nüîç OBSERVACIONES CLAVE:",
    "  1. SHAP proporciona valores consistentes y matem√°ticamente fundamentados",
    "  2. El tiempo de c√≥mputo es significativo (~10-30s por texto)",
    "  3. Las explicaciones son determin√≠sticas (misma entrada = misma salida)",
    "  4. Permite identificar claramente palabras positivas vs negativas",

    "\nüí° VENTAJAS DE SHAP:",
    "  ‚úì Base te√≥rica s√≥lida (teor√≠a de juegos)",
    "  ‚úì Garant√≠as matem√°ticas (aditividad, consistencia)",
    "  ‚úì Explicaciones globales y locales",
    "  ‚úì Resultados reproducibles",

    "\n‚ö†Ô∏è LIMITACIONES DE SHAP:",
    "  ‚úó Computacionalmente costoso",
    "  ‚úó Requiere acceso al modelo completo",
    "  ‚úó Puede ser dif√≠cil de interpretar para usuarios no t√©cnicos",

    "\nüöÄ PR√ìXIMOS PASOS:",
    "  ‚Üí Comparar con resultados de LIME",
    "  ‚Üí Analizar casos donde SHAP y LIME difieren",
    "  ‚Üí Evaluar trade-off velocidad vs precisi√≥n",

    "\n" + "=" * 60,
    "‚úÖ Notebook de an√°lisis SHAP completado exitosamente",
)
for summary_line in summary_lines:
    print(summary_line)