In [None]:
# =============================================================================
# NOTEBOOK 3: EVALUACIÓN E INFERENCIA
# =============================================================================

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# --- SETTINGS ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Directory holding the processed dataset splits / directory holding the
# saved training artifacts.
DATA_DIR, MODELS_DIR = "data_processed", "models"

# Class index -> label name (the index is the position in the tuple).
ID2LABEL = dict(enumerate(("Bajista", "Alcista", "Neutral")))

print(f"Usando dispositivo: {DEVICE}")

# -----------------------------------------------------------------------------
# 1. REDEFINICIÓN DE CLASES
# -----------------------------------------------------------------------------
# Necesario para que PyTorch pueda cargar la estructura del modelo guardado
class AttentionPooling(nn.Module):
    """Pool a sequence of hidden states into one vector with learned attention.

    Each time step is scored by a single linear unit; the scores are
    normalised with a softmax over dimension 1 and used as weights for a
    weighted sum of the hidden states over that same dimension.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # The attribute name `attn` is part of the state-dict keys, so it must
        # stay as is for saved checkpoints to load.
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, rnn_output, mask):
        """Return the attention-weighted sum of `rnn_output` over dimension 1.

        Args:
            rnn_output: hidden states whose last dimension is `hidden_dim`
                (in this file: the batch-first output of the LSTM/GRU).
            mask: tensor broadcastable to the scores; positions equal to 0
                are treated as padding and excluded from the softmax.
        """
        raw_scores = self.attn(rnn_output).squeeze(-1)
        is_padding = mask == 0
        # A very large negative score drives the softmax weight of padded
        # positions to (numerically) zero.
        masked_scores = raw_scores.masked_fill(is_padding, -1e9)
        weights = torch.softmax(masked_scores, dim=1).unsqueeze(-1)
        return (rnn_output * weights).sum(dim=1)

class RecurrentClassifier(nn.Module):
    """Text classifier: embedding -> LSTM/GRU -> attention pooling -> linear.

    Args:
        model_type: 'lstm' selects ``nn.LSTM``; any other value selects
            ``nn.GRU``.
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding size (input size of the recurrent layer).
        hidden_dim: hidden size of the recurrent layer.
        out_dim: number of output logits.
        n_layers: number of stacked recurrent layers.
        dropout: dropout probability applied before the final linear layer
            and, when ``n_layers > 1``, between recurrent layers.
        pad_idx: token id used for padding; it is the embedding's
            ``padding_idx`` and is masked out of the attention.
    """

    def __init__(self, model_type, vocab_size, embed_dim, hidden_dim, out_dim, n_layers, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        # Both recurrent layers take the same arguments, so only the class
        # differs. Inter-layer dropout is passed only for stacked layers.
        rnn_cls = nn.LSTM if model_type == 'lstm' else nn.GRU
        self.rnn = rnn_cls(
            embed_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0,
        )
        self.attention = AttentionPooling(hidden_dim)
        self.fc = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.pad_idx = pad_idx
        self.model_type = model_type

    def forward(self, x):
        """Return the class logits for a batch of token-id sequences `x`."""
        # 1.0 for real tokens, 0.0 for padding.
        mask = (x != self.pad_idx).float()
        emb = self.embedding(x)
        # LSTM and GRU both return (output, state); only the per-step output
        # is needed, so no branching on the model type is required here.
        rnn_out, _ = self.rnn(emb)
        context = self.attention(rnn_out, mask)
        return self.fc(self.dropout(context))

# -----------------------------------------------------------------------------
# 2. CARGA DEL MODELO Y EVALUACIÓN
# -----------------------------------------------------------------------------
def load_model(path):
    """Rebuild a trained classifier from a checkpoint file.

    The checkpoint is read as a dict with the keys 'config' (model
    hyper-parameters), 'model_state' (the state dict) and 'vocab'.

    Args:
        path: location of the checkpoint file.

    Returns:
        A ``(model, vocab)`` tuple: the model, moved to ``DEVICE`` and
        switched to evaluation mode, and the vocabulary stored with it.
    """
    print(f"Cargando modelo desde: {path}")
    checkpoint = torch.load(path, map_location=DEVICE)
    conf = checkpoint['config']

    # Hyper-parameters taken verbatim from the saved configuration.
    hyper_keys = (
        'model_type',
        'vocab_size',
        'embed_dim',
        'hidden_dim',
        'n_layers',
        'dropout',
        'pad_idx',
    )
    hyper_params = {key: conf[key] for key in hyper_keys}

    # The output layer is fixed at three classes.
    model = RecurrentClassifier(out_dim=3, **hyper_params)
    model.load_state_dict(checkpoint['model_state'])
    model.to(DEVICE)
    model.eval()

    return model, checkpoint['vocab']

# Load the training history used for the plots below. The context manager
# closes the file handle as soon as it has been read.
# NOTE: unpickling can execute arbitrary code; only load history files that
# come from a trusted source.
with open(f"{MODELS_DIR}/history.pkl", "rb") as history_file:
    histories = pickle.load(history_file)

# Learning curves: one panel per metric, one train/validation pair of lines
# per entry in `histories`.
plt.figure(figsize=(12, 5))

curve_panels = (
    ("train_loss", "val_loss", "Pérdida (Loss)"),
    ("train_acc", "val_acc", "Precisión (Accuracy)"),
)
for panel_idx, (train_key, val_key, panel_title) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    for m_name, h in histories.items():
        plt.plot(h[train_key], label=f'{m_name.upper()} Train')
        plt.plot(h[val_key], '--', label=f'{m_name.upper()} Val')
    plt.title(panel_title)
    plt.xlabel("Épocas")
    plt.legend()

plt.tight_layout()
plt.show()

# -----------------------------------------------------------------------------
# 3. EVALUACIÓN EN TEST SET
# -----------------------------------------------------------------------------
# Inputs for the test-set evaluation. The LSTM checkpoint is evaluated here;
# switch the file name to 'gru_best_model.pth' to evaluate the GRU instead.
model_path = f"{MODELS_DIR}/lstm_best_model.pth"
test_csv_path = f"{DATA_DIR}/test.csv"

model, vocab = load_model(model_path)
test_df = pd.read_csv(test_csv_path)

def predict_batch(model, texts, vocab, max_len=40, batch_size=256):
    """Predict a class index for every text in `texts`.

    Each text is converted to ``str``, lower-cased and split on whitespace.
    Tokens are mapped to ids through `vocab` (unknown tokens map to
    ``vocab["<UNK>"]``) and every sequence is truncated or right-padded with
    ``vocab["<PAD>"]`` to exactly `max_len` ids.

    Args:
        model: module mapping a ``(batch, max_len)`` long tensor to
            ``(batch, n_classes)`` logits.
        texts: iterable of texts (any object accepted by ``str``).
        vocab: token -> id mapping containing "<UNK>" and "<PAD>".
        max_len: fixed sequence length fed to the model.
        batch_size: maximum number of texts per forward pass, which bounds
            memory use on large inputs. ``None`` sends everything in a
            single pass.

    Returns:
        A 1-D numpy array of predicted class indices, one per text; empty
        when `texts` is empty.

    Raises:
        ValueError: if `batch_size` is neither ``None`` nor at least 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 or None, got {batch_size!r}")

    # Look the special ids up once instead of once per token.
    unk_id = vocab["<UNK>"]
    pad_id = vocab["<PAD>"]

    encoded_list = []
    for text in texts:
        ids = [vocab.get(tok, unk_id) for tok in str(text).lower().split()][:max_len]
        ids.extend([pad_id] * (max_len - len(ids)))
        encoded_list.append(ids)

    # An empty input would otherwise reach the model as a malformed tensor.
    if not encoded_list:
        return np.empty(0, dtype=np.int64)

    # Put the input wherever the model lives, so the call also works when the
    # model is not on the global default device. Parameter-less models fall
    # back to DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    step = len(encoded_list) if batch_size is None else batch_size
    chunks = []
    with torch.no_grad():
        for start in range(0, len(encoded_list), step):
            x = torch.tensor(encoded_list[start:start + step], dtype=torch.long, device=device)
            chunks.append(torch.argmax(model(x), dim=1).cpu())
    return torch.cat(chunks).numpy()

print("\nRealizando predicciones en Test Set...")
y_pred = predict_batch(model, test_df["text"].values, vocab)
y_true = test_df["label"].values

# Take the class ids and names from ID2LABEL, the single source of truth for
# the label mapping. Passing `labels` explicitly keeps the report rows and the
# matrix axes aligned with those names even when a class is missing from the
# test data.
label_ids = sorted(ID2LABEL)
label_names = [ID2LABEL[label_id] for label_id in label_ids]

print("\n" + "="*50)
print(" REPORTE DE CLASIFICACIÓN (TEST SET)")
print("="*50)
print(classification_report(y_true, y_pred, labels=label_ids, target_names=label_names))

# Confusion matrix (rows: true label, columns: predicted label).
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_names,
            yticklabels=label_names)
plt.ylabel('Etiqueta Real')
plt.xlabel('Predicción')
plt.title('Matriz de Confusión')
plt.show()

# -----------------------------------------------------------------------------
# 4. INFERENCIA INTERACTIVA
# -----------------------------------------------------------------------------
def analizar_frase(texto):
    """Classify a single sentence with the loaded model and print the result.

    Uses the module-level `model` and `vocab`. Only the predicted class is
    reported; showing a confidence would require `predict_batch` to return
    the softmax probabilities.

    Args:
        texto: the sentence to classify.
    """
    # Inference must run in evaluation mode.
    model.eval()

    clase = predict_batch(model, [texto], vocab)[0]
    nombre = ID2LABEL[clase]

    separador = "-" * 40
    salida = (
        separador,
        f"Texto: '{texto}'",
        f"Predicción: {nombre} ({clase})",
        separador,
    )
    for linea in salida:
        print(linea)

# Sample sentences run through the classifier as a quick manual check.
print("\n=== PRUEBAS EN VIVO ===")
for frase_demo in (
    "The market is crashing hard, everything is red!",
    "Apple reported amazing earnings, stock is flying to the moon",
    "The fed will announce interest rates tomorrow, market is waiting",
):
    analizar_frase(frase_demo)