In [None]:
# =============================================================================
# NOTEBOOK 2: ARQUITECTURA Y ENTRENAMIENTO
# =============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import pickle
import os

# --- CONFIGURATION ---
# Prefer the GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory the processed vocabulary and CSV splits are read from.
DATA_DIR = "data_processed"
# Directory checkpoints and the training history are written to.
MODELS_DIR = "models"

# Create the models directory if it does not exist yet
os.makedirs(MODELS_DIR, exist_ok=True)

# Hyperparameters
MAX_LEN = 40         # fixed token length every text is padded/truncated to
EMBEDDING_DIM = 100  # size of each token embedding vector
HIDDEN_DIM = 128     # hidden state size of the recurrent layers
N_LAYERS = 2         # number of stacked recurrent layers
DROPOUT = 0.3        # dropout probability handed to the model
BATCH_SIZE = 64      # samples per DataLoader batch
LR = 1e-3            # optimizer learning rate
EPOCHS = 8           # training epochs per model

print(f"Usando dispositivo: {DEVICE}")

# -----------------------------------------------------------------------------
# 1. CARGA DE DATOS Y VOCABULARIO
# -----------------------------------------------------------------------------
print("Cargando datos procesados...")
# SECURITY NOTE: unpickling can execute arbitrary code. Only load a vocab.pkl
# that you generated yourself; never point this at an untrusted file.
# The context manager closes the file handle deterministically instead of
# leaving it open until garbage collection.
with open(f"{DATA_DIR}/vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)
train_df = pd.read_csv(f"{DATA_DIR}/train.csv")
val_df = pd.read_csv(f"{DATA_DIR}/val.csv")

# -----------------------------------------------------------------------------
# 2. DATASET Y DATALOADERS
# -----------------------------------------------------------------------------
def encode_text(text, vocab, max_len):
    """Turn a text into a fixed-length array of token ids.

    The text is stringified, lower-cased and split on whitespace. Tokens
    missing from ``vocab`` map to the ``"<UNK>"`` id. The result is padded on
    the right with the ``"<PAD>"`` id, or truncated, to exactly ``max_len``
    entries.

    Args:
        text: Value to encode; converted with ``str()`` first.
        vocab: Mapping from token to integer id, containing ``"<UNK>"`` and
            ``"<PAD>"``.
        max_len: Length of the returned array.

    Returns:
        A 1-D ``numpy`` array of dtype ``int64``.
    """
    token_ids = []
    for word in str(text).lower().split():
        token_ids.append(vocab.get(word, vocab["<UNK>"]))

    missing = max_len - len(token_ids)
    if missing > 0:
        # Too short: fill the tail with the padding id.
        token_ids.extend([vocab["<PAD>"]] * missing)
    else:
        # Long enough (or exact): keep only the leading max_len ids.
        token_ids = token_ids[:max_len]

    return np.array(token_ids, dtype=np.int64)

class FinancialTweetsDataset(Dataset):
    """Dataset that encodes one text per access and pairs it with its label.

    Args:
        df: DataFrame providing a "text" column and a "label" column.
        vocab: Token-to-id mapping forwarded to ``encode_text``.
        max_len: Fixed sequence length forwarded to ``encode_text``.
    """

    def __init__(self, df, vocab, max_len):
        # Materialise both columns as plain Python lists up front so that
        # per-item access does not go through pandas.
        text_column = df["text"].astype(str)
        label_column = df["label"].astype(int)
        self.texts = text_column.tolist()
        self.labels = label_column.tolist()
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for sample ``idx`` as long tensors."""
        encoded = encode_text(self.texts[idx], self.vocab, self.max_len)
        label = self.labels[idx]
        features = torch.tensor(encoded, dtype=torch.long)
        target = torch.tensor(label, dtype=torch.long)
        return features, target

# Wrap each split in a dataset that encodes texts on access.
train_ds = FinancialTweetsDataset(df=train_df, vocab=vocab, max_len=MAX_LEN)
val_ds = FinancialTweetsDataset(df=val_df, vocab=vocab, max_len=MAX_LEN)

# Only the training split is shuffled; validation keeps the file order.
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE, shuffle=False)

print("Dataloaders listos.")

# -----------------------------------------------------------------------------
# 3. DEFINICIÓN DEL MODELO (RNN + ATENCIÓN)
# -----------------------------------------------------------------------------
class AttentionPooling(nn.Module):
    """Collapse a sequence of RNN states into one vector via learned attention.

    A single linear layer assigns a scalar score to every time step. Padded
    steps are pushed to a very negative score before the softmax, so they
    receive effectively zero weight.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar relevance score per time step.
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, rnn_output, mask):
        """Return the attention-weighted sum of ``rnn_output`` over the sequence.

        Args:
            rnn_output: Tensor of shape (batch, seq, hidden).
            mask: Tensor broadcastable to (batch, seq); entries equal to 0
                mark padding positions.

        Returns:
            Tensor of shape (batch, hidden).
        """
        raw_scores = self.attn(rnn_output).squeeze(-1)
        is_padding = mask == 0
        # A large finite negative (rather than -inf) keeps the softmax finite
        # even for a row that consists of padding only.
        masked_scores = raw_scores.masked_fill(is_padding, -1e9)
        weights = torch.softmax(masked_scores, dim=1)
        # Broadcast the weights over the hidden axis and sum out the sequence.
        weighted_states = rnn_output * weights.unsqueeze(-1)
        return torch.sum(weighted_states, dim=1)

class RecurrentClassifier(nn.Module):
    """Text classifier: embedding -> stacked LSTM/GRU -> attention pooling -> linear.

    Args:
        model_type: ``'lstm'`` selects ``nn.LSTM``; any other value selects
            ``nn.GRU``.
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding vector size.
        hidden_dim: Hidden state size of the recurrent layers.
        out_dim: Number of output logits.
        n_layers: Number of stacked recurrent layers.
        dropout: Dropout probability, applied between recurrent layers (only
            when ``n_layers > 1``) and before the final linear layer.
        pad_idx: Token id treated as padding by the embedding and the mask.
    """

    def __init__(self, model_type, vocab_size, embed_dim, hidden_dim, out_dim, n_layers, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        # Pick the recurrent cell once; both share the same constructor shape.
        rnn_cls = nn.LSTM if model_type == 'lstm' else nn.GRU
        # Inter-layer dropout only applies when there is more than one layer.
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.rnn = rnn_cls(
            embed_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        self.attention = AttentionPooling(hidden_dim)
        self.fc = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.pad_idx = pad_idx
        self.model_type = model_type

    def forward(self, x):
        """Return class logits for a batch of token-id sequences ``x``."""
        # 1.0 for real tokens, 0.0 for padding positions.
        padding_mask = (x != self.pad_idx).float()
        embedded = self.embedding(x)

        # LSTM and GRU both return (outputs, final_state); only the per-step
        # outputs are needed for attention pooling.
        rnn_out, _ = self.rnn(embedded)

        pooled = self.attention(rnn_out, padding_mask)
        return self.fc(self.dropout(pooled))

# -----------------------------------------------------------------------------
# 4. BUCLE DE ENTRENAMIENTO
# -----------------------------------------------------------------------------
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(logits, targets)``.

    Returns:
        Tuple ``(average loss per sample, accuracy)`` over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        # Scale by batch size so the final division yields a per-sample
        # average (assumes `criterion` returns a batch mean).
        loss_sum += batch_loss.item() * batch_size
        predictions = outputs.argmax(1)
        n_correct += (predictions == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion):
    """Measure loss and accuracy over ``loader`` without updating the model.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(logits, targets)``.

    Returns:
        Tuple ``(average loss per sample, accuracy)`` over all samples seen.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            batch_size = inputs.size(0)

            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            # Scale by batch size so the final division yields a per-sample
            # average (assumes `criterion` returns a batch mean).
            loss_sum += batch_loss.item() * batch_size
            predictions = outputs.argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

# -----------------------------------------------------------------------------
# 5. EJECUCIÓN DEL ENTRENAMIENTO
# -----------------------------------------------------------------------------
# Number of target classes. Defined once so the loss weights, the classifier
# head and the checkpoint metadata cannot drift apart.
NUM_CLASSES = 3

# Uniform class weights: every class contributes equally to the loss.
# Replace these with per-class values if the label distribution is imbalanced.
class_weights = torch.tensor([1.0] * NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Per-model training curves, keyed by model type.
histories = {}

for m_type in ['lstm', 'gru']:
    print(f"\n{'='*40}")
    print(f" Entrenando Modelo: {m_type.upper()}")
    print(f"{'='*40}")

    # Build the constructor arguments once and reuse the same dict for the
    # checkpoint, so the saved config always matches the trained model.
    model_config = {
        'model_type': m_type,
        'vocab_size': len(vocab),
        'embed_dim': EMBEDDING_DIM,
        'hidden_dim': HIDDEN_DIM,
        'n_layers': N_LAYERS,
        'dropout': DROPOUT,
        'pad_idx': vocab["<PAD>"],
    }
    model = RecurrentClassifier(out_dim=NUM_CLASSES, **model_config).to(DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    # Halve the learning rate every 3 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    best_acc = 0.0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for ep in range(EPOCHS):
        tl, ta = train_epoch(model, train_loader, optimizer, criterion)
        vl, va = evaluate(model, val_loader, criterion)
        scheduler.step()

        history['train_loss'].append(tl)
        history['val_loss'].append(vl)
        history['train_acc'].append(ta)
        history['val_acc'].append(va)

        print(f"Epoch {ep+1}/{EPOCHS} | Train Loss: {tl:.4f} Acc: {ta:.4f} | Val Loss: {vl:.4f} Acc: {va:.4f}")

        # Checkpoint only when validation accuracy strictly improves.
        if va > best_acc:
            best_acc = va
            state = {
                'model_state': model.state_dict(),
                'config': model_config,
                # Needed to rebuild the classifier head from the checkpoint.
                # NOTE(review): kept outside 'config' so any loader that
                # unpacks 'config' next to an explicit out_dim keeps working;
                # confirm against the loading code.
                'out_dim': NUM_CLASSES,
                'vocab': vocab
            }
            save_path = f"{MODELS_DIR}/{m_type}_best_model.pth"
            torch.save(state, save_path)
            print(f" --> Nuevo récord! Modelo guardado en {save_path}")

    histories[m_type] = history

# Persist the full training history for later plotting.
with open(f"{MODELS_DIR}/history.pkl", "wb") as f:
    pickle.dump(histories, f)

print("\n¡Entrenamiento finalizado para ambos modelos!")