# Emotional Attention (Dual-Head) con **propagación afectiva**

Notebook listo para **Google Colab**. Entrena un mini *decoder* con **Atención Emocional** para **aprender**
a responder con un tono emocional derivado de la entrada (no es un clasificador). Incluye:
- Capa **Dual-Head Emotional Attention + Gating** en el bloque Transformer.
- Señales de entrenamiento **sin etiquetas**: *LM + propagación* entrada→salida y **destilación latente opcional** desde un *teacher* congelado.
- Generación condicionada por el **estado emocional inferido** del input.
- Evaluación por **alineación intrínseca** (similitud coseno entre el vector emocional de la entrada y el de la respuesta generada).

⚠️ Es una prueba de concepto (POC) mínima con un dataset muy pequeño. Reemplázalo por tus propios pares (input→respuesta) para obtener resultados reales.

In [None]:
#@title Instalación mínima sin conflictos (reinicio limpio recomendado)
import os
# Environment flags set before transformers is imported.
# NOTE(review): these look intended to stop transformers from importing TensorFlow /
# Flax / torchvision; transformers documents USE_TF / USE_FLAX / USE_TORCH, so
# confirm the pinned version actually honours these exact names.
# NOTE(review): os.environ changes are process-local and os._exit(0) at the end of
# this cell terminates this process, so cells run after the restart do NOT see them.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TRANSFORMERS_NO_TORCHVISION"] = "1"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# Remove packages that may conflict with the pinned versions installed below
# (pysentimiento itself is not imported later; only its model is loaded from the Hub).
%pip uninstall -y -q pysentimiento accelerate transformers tokenizers huggingface_hub safetensors peft

# Pinned, known-good set for this notebook (presumably pip resolves tokenizers /
# huggingface_hub as transformers dependencies — verify after install).
%pip install -q "transformers==4.41.2" "safetensors==0.4.2"

# Sanity check: report the installed transformers and Python versions.
import transformers, torch, sys
print("transformers:", transformers.__version__, "| python:", sys.version)

# Hard-exit the kernel so the runtime restarts and later cells import the freshly
# installed packages. Continue with the next cells after the restart; re-running
# this cell would reinstall and exit again. (The repeated `import os` is harmless.)
import os; os._exit(0)

In [None]:
#@title Imports
import math, random
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import AutoModel, AutoTokenizer

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device:', device)

## Dataset mini (pares input→respuesta)

In [None]:
#@title Datos y utilidades
# Closed toy vocabulary: four special tokens followed by the words used in `pairs`.
vocab = [
    '<pad>', '<bos>', '<eos>', '<sep>',
    'te', 'quiero', 'mamá', 'hola', 'cariño', 'yo', 'tambien', 'estoy', 'triste', 'lo', 'siento', 'aqui', 'estoy2',
    'no', 'odio', 'mucho', 'calma', 'hablemos', 'por', 'favor', 'feliz', 'me', 'alegra', 'porque',
]

# Word -> id (position in `vocab`) and the inverse id -> word table.
stoi = {word: idx for idx, word in enumerate(vocab)}
itos = {idx: word for word, idx in stoi.items()}


def encode(words: List[str]):
    """Return the vocabulary ids of ``words``; an unknown word raises KeyError."""
    return list(map(stoi.__getitem__, words))

# Toy (input words, response words) pairs.
pairs = [
    (['te', 'quiero', 'mamá'], ['yo', 'tambien', 'te', 'quiero']),
    (['hola', 'mamá'], ['hola', 'cariño']),
    (['estoy', 'triste'], ['lo', 'siento', 'aqui', 'estoy2']),
    (['no', 'te', 'quiero'], ['lo', 'siento', 'hablemos']),
    (['te', 'odio', 'mucho'], ['calma', 'hablemos', 'por', 'favor']),
    (['feliz'], ['me', 'alegra', 'mucho']),
]

# Reproducible shuffle, then a 70 / 30 train / validation split.
random.seed(42)
random.shuffle(pairs)
split = int(0.7 * len(pairs))
train_pairs, valid_pairs = pairs[:split], pairs[split:]

def build_batch(pairs):
    """Encode (input words, response words) pairs into one right-padded batch.

    Each row is ``<bos> input <sep> response <eos>`` followed by '<pad>' up to
    the longest row. Returns four tensors on ``device``, each of shape
    (len(pairs), maxlen):

    * X    -- token ids
    * PM   -- bool, True at padding positions
    * INM  -- bool, True at positions before '<sep>' (i.e. <bos> + input)
    * OUTM -- bool, True at non-pad positions after '<sep>' (response + <eos>)
    """
    pad_id, sep_id = stoi['<pad>'], stoi['<sep>']
    seqs = [
        [stoi['<bos>'], *encode(inp), sep_id, *encode(out), stoi['<eos>']]
        for inp, out in pairs
    ]
    maxlen = max(map(len, seqs))

    rows, pad_rows, in_rows, out_rows = [], [], [], []
    for seq in seqs:
        n = len(seq)
        # '<sep>' is always inserted above; the fallback only guards odd inputs.
        sep_pos = seq.index(sep_id) if sep_id in seq else 1
        rows.append(seq + [pad_id] * (maxlen - n))
        pad_rows.append([t >= n for t in range(maxlen)])
        in_rows.append([t < sep_pos for t in range(maxlen)])
        out_rows.append([sep_pos < t < n and seq[t] != pad_id for t in range(maxlen)])

    X = torch.tensor(rows, device=device)
    PM = torch.tensor(pad_rows, device=device).bool()
    INM = torch.tensor(in_rows, device=device).bool()
    OUTM = torch.tensor(out_rows, device=device).bool()
    return X, PM, INM, OUTM

# Each split fits in a single padded batch: ids, pad mask, input mask, output mask.
train_batch = build_batch(train_pairs)
valid_batch = build_batch(valid_pairs)
Xtr, Mtr, INtr, OUTtr = train_batch
Xva, Mva, INva, OUTva = valid_batch

print('Train batch:', Xtr.shape, 'Valid batch:', Xva.shape)

## Teacher congelado (opcional)

In [None]:
#@title Cargar Teacher HF (opcional)
USE_DISTIL = True  # set to False to skip distillation; the teacher is then not downloaded

HF_NAME = "pysentimiento/robertuito-sentiment-analysis"
d_emo = 64

# Frozen teacher used only to produce distillation targets (see teacher_vec).
# Loading is skipped entirely when distillation is disabled.
tok_teacher = teacher = teacher_proj = None
if USE_DISTIL:
    tok_teacher = AutoTokenizer.from_pretrained(HF_NAME, use_fast=True)
    teacher = AutoModel.from_pretrained(HF_NAME).to(device).eval()
    for p in teacher.parameters():
        p.requires_grad = False

    hidden = teacher.config.hidden_size
    # NOTE: this projection is randomly initialised and never optimised (it is not
    # part of model.parameters()), so the targets are a fixed random projection of
    # the teacher's mean-pooled features. Freeze it explicitly to make that obvious.
    teacher_proj = nn.Sequential(
        nn.Linear(hidden, hidden),
        nn.GELU(),
        nn.Linear(hidden, d_emo),
    ).to(device).eval()
    for p in teacher_proj.parameters():
        p.requires_grad = False

@torch.no_grad()
def teacher_vec(words: List[str]) -> torch.Tensor:
    """Embed ``words`` with the frozen teacher and project to a d_emo vector.

    The words are joined with spaces, tokenised (truncated to 128 tokens),
    encoded by the teacher, mean-pooled over the sequence and passed through
    ``teacher_proj``. Returns the first (only) row of the projection.
    """
    batch = tok_teacher(" ".join(words), return_tensors='pt', truncation=True, max_length=128).to(device)
    pooled = teacher(**batch).last_hidden_state.mean(dim=1)
    return teacher_proj(pooled)[0]

## Bloque Emotional Attention

In [None]:
#@title Implementación del bloque dual
def _shape(x, B, T, n_heads, d_head):
    """Split channels into heads: (B, T, n_heads * d_head) -> (B, n_heads, T, d_head)."""
    return x.view(B, T, n_heads, d_head).transpose(1, 2)


class MultiHeadSelfAttn(nn.Module):
    """Multi-head scaled dot-product self-attention with causal and key-padding masks.

    Masked (query, key) pairs get exactly zero weight. A query whose keys are
    *all* masked (e.g. a padded first position under the causal mask) gets an
    all-zero attention row, instead of the NaNs a softmax over ``-inf`` would
    produce, so its output reduces to the bias of the output projection.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model={d_model} must be divisible by n_heads={n_heads}")
        self.nh = n_heads
        self.dh = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.o = nn.Linear(d_model, d_model)

    def forward(self, H, key_padding_mask=None, causal=True, return_attn=False):
        """Self-attend over ``H`` of shape (B, T, d_model).

        Args:
            key_padding_mask: optional (B, T) mask, True/nonzero at keys to ignore.
            causal: if True, query t only attends to keys at positions <= t.
            return_attn: if True, also return the (B, n_heads, T, T) weights.

        Returns:
            ``(output, weights)``; ``weights`` is None unless ``return_attn``.
        """
        B, T, D = H.shape
        q = _shape(self.q(H), B, T, self.nh, self.dh)
        k = _shape(self.k(H), B, T, self.nh, self.dh)
        v = _shape(self.v(H), B, T, self.nh, self.dh)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)

        # Single boolean mask (True = blocked), broadcastable to (B, nh, T, T).
        blocked = None
        if causal:
            idx = torch.arange(T, device=H.device)
            blocked = (idx[None, :] > idx[:, None])[None, None, :, :]
        if key_padding_mask is not None:
            pad = key_padding_mask[:, None, None, :].bool()
            blocked = pad if blocked is None else (blocked | pad)

        if blocked is not None:
            # finfo.min (not -inf) keeps softmax finite even on fully-masked rows.
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)
        A = torch.softmax(scores, dim=-1)
        if blocked is not None:
            # Already ~0 for normal rows; turns fully-masked (uniform) rows into zeros.
            A = A.masked_fill(blocked, 0.0)

        O = (A @ v).transpose(1, 2).contiguous().view(B, T, D)
        O = self.o(O)
        return (O, A) if return_attn else (O, None)

class DualHeadEmoAttention(nn.Module):
    """Semantic and emotional self-attention blended by a learned sigmoid gate.

    The semantic branch attends over the hidden states ``H``; the emotional
    branch attends over ``U`` projected from d_emo to d_model. A gate ``G``
    (per position and channel) computed from both branch outputs and the
    global emotion vector ``g`` mixes them as ``(1 - G) * O_sem + G * O_emo``.
    The mix goes through dropout, an output projection, a residual connection
    with ``H`` and a LayerNorm.
    """

    def __init__(self, d_model: int, n_heads: int, d_emo: int, dropout=0.1):
        super().__init__()
        self.sem = MultiHeadSelfAttn(d_model, n_heads)
        # Shared d_emo -> d_model projection, used for both U and g.
        self.proj_u = nn.Linear(d_emo, d_model)
        self.emo = MultiHeadSelfAttn(d_model, n_heads)
        self.Wg_h = nn.Linear(d_model, d_model)
        self.Wg_e = nn.Linear(d_model, d_model)
        self.Wg_g = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, H, U, g, key_padding_mask=None, return_attn=False):
        """Mix semantic and emotional attention over a (B, T, d_model) input.

        Args:
            H: hidden states, (B, T, d_model).
            U: per-token emotion features whose last dim is d_emo.
            g: global emotion vector, (B, 1, d_emo) as passed by EmoDecoder;
               (B, d_emo) and (d_emo,) are also accepted and broadcast over T.
            key_padding_mask: optional (B, T) mask, True at padded keys.
            return_attn: if True, also return {'A_sem', 'A_emo', 'G'}.

        Returns:
            ``(out, attn)`` where ``attn`` is None unless ``return_attn``.
        """
        B, T = H.size(0), H.size(1)
        O_sem, A_sem = self.sem(H, key_padding_mask=key_padding_mask, causal=True, return_attn=True)
        U_proj = self.proj_u(U)
        O_emo, A_emo = self.emo(U_proj, key_padding_mask=key_padding_mask, causal=True, return_attn=True)
        if g.dim() == 2:
            # (B, d_emo) -> (B, 1, d_emo): without this, expand() would align the
            # batch dim with the time dim (error, or silent mis-broadcast if B == T).
            g = g.unsqueeze(1)
        g_proj = self.Wg_g(self.proj_u(g)).expand(B, T, -1)
        G = torch.sigmoid(self.Wg_h(O_sem) + self.Wg_e(O_emo) + g_proj)
        mix = (1 - G) * O_sem + G * O_emo
        out = self.norm(self.out(self.drop(mix)) + H)
        if return_attn:
            return out, {'A_sem': A_sem, 'A_emo': A_emo, 'G': G}
        return out, None

class EmoBlock(nn.Module):
    """Decoder layer: DualHeadEmoAttention followed by a residual MLP sub-layer.

    forward() returns ``(hidden, attn)``; ``attn`` is the dual-attention dict
    when ``return_attn`` is true and ``None`` otherwise.
    """

    def __init__(self, d_model=256, n_heads=8, mlp_ratio=2.0, d_emo=64):
        super().__init__()
        hidden_dim = int(mlp_ratio * d_model)
        self.dual = DualHeadEmoAttention(d_model, n_heads, d_emo)
        self.pre = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, H, U, g, key_padding_mask=None, return_attn=False):
        attended, attn = self.dual(H, U, g, key_padding_mask=key_padding_mask, return_attn=True)
        # LayerNorm -> MLP, add the residual, then a final LayerNorm.
        out = self.norm(self.mlp(self.pre(attended)) + attended)
        return out, (attn if return_attn else None)

## Modelo completo

In [None]:
#@title Decoder emocional
class EmoDecoder(nn.Module):
    """Small causal LM built from EmoBlocks, with input and output emotion heads.

    forward() returns ``(logits, g_in, g_out, attn)``:

    * logits -- (B, T, vocab_size) next-token scores
    * g_in   -- (B, d_emo) emotion vector pooled from the token+position
                embeddings at ``in_mask`` positions; it is also the global
                gating signal given to every block
    * g_out  -- (B, d_emo) emotion vector pooled from the final hidden states
                at ``out_mask`` positions
    * attn   -- the last block's attention dict if ``return_attn`` else None
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, d_emo=64, n_layers=2, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(
            [EmoBlock(d_model, n_heads, mlp_ratio=2.0, d_emo=d_emo) for _ in range(n_layers)]
        )
        self.lm_head = nn.Linear(d_model, vocab_size)

        self.emo_in_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_emo))
        self.emo_out_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_emo))
        # Per-token emotion features U consumed by each block's emotional branch.
        self.u_head = nn.Linear(d_model, d_emo)

    @staticmethod
    def _masked_mean(H, mask):
        """Mean of H (B, T, D) over positions where the bool mask (B, T) is True; zeros if none."""
        summed = H.masked_fill(~mask[..., None], 0.0).sum(1)
        count = mask.sum(1).clamp(min=1).view(-1, 1).float()
        return summed / count

    def forward(self, X, pad_mask, in_mask, out_mask, return_attn=False):
        B, T = X.shape
        if T > self.max_len:
            # Fail clearly instead of an embedding index error / CUDA device assert.
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}")
        pos = torch.arange(T, device=X.device)[None, :].expand(B, T)
        H = self.tok(X) + self.pos(pos)

        g_in = self.emo_in_head(self._masked_mean(H, in_mask)).unsqueeze(1)  # (B, 1, d_emo)

        U = self.u_head(H)
        attn_last = None
        for blk in self.blocks:
            H, attn_last = blk(H, U, g_in, key_padding_mask=pad_mask, return_attn=True)
            U = self.u_head(H)

        logits = self.lm_head(H)
        g_out = self.emo_out_head(self._masked_mean(H, out_mask))  # (B, d_emo)

        return logits, g_in.squeeze(1), g_out, (attn_last if return_attn else None)

## Entrenamiento

In [None]:
#@title Entrenar (LM + propagación + distil opcional)
# Model hyper-parameters (module-level names, reused by later cells).
d_model = 256
d_emo = 64
n_heads = 8
n_layers = 2
model = EmoDecoder(len(vocab), d_model=d_model, n_heads=n_heads, d_emo=d_emo, n_layers=n_layers)
model = model.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
def shift_targets(X, pad_id=None):
    """Return next-token targets for X, i.e. ``targets[:, t] == X[:, t + 1]``.

    The last column has no next token, so it is filled with ``pad_id``
    (default: the id of '<pad>') and is ignored by
    ``cross_entropy(..., ignore_index=<pad>)``. A bare ``torch.roll`` would
    wrap the first token ('<bos>') into that slot and train the model to
    predict '<bos>' after the final position ('<eos>' or padding).
    """
    if pad_id is None:
        pad_id = stoi['<pad>']
    targets = torch.roll(X, shifts=-1, dims=1)
    if targets.size(1) > 0:
        targets[:, -1] = pad_id
    return targets
# Next-token targets for both splits.
targets_tr = shift_targets(Xtr)
targets_va = shift_targets(Xva)

# Weights of the auxiliary losses: propagation, distillation and sparsity.
lambda_prop = 0.5
lambda_distil = 0.3
lambda_sparse = 1e-4

def cosine_loss(a, b):
    """Mean of ``1 - cos(a, b)`` computed along the last dimension.

    0 for vectors pointing the same way, 1 for orthogonal, 2 for opposite.
    """
    cos = (F.normalize(a, dim=-1) * F.normalize(b, dim=-1)).sum(dim=-1)
    return (1 - cos).mean()

n_epochs, eval_every = 50, 10

# Distillation targets. The teacher is frozen, teacher_vec runs under no_grad and
# teacher_proj is never optimised, so the targets are the same every epoch:
# compute them once instead of 2 * len(train_pairs) teacher forwards per epoch.
if USE_DISTIL:
    g_in_hat = torch.stack([teacher_vec(inp) for inp, _ in train_pairs], dim=0).to(device)
    g_out_hat = torch.stack([teacher_vec(out) for _, out in train_pairs], dim=0).to(device)

for ep in range(n_epochs):
    model.train(); opt.zero_grad()
    logits, g_in, g_out, attn = model(Xtr, Mtr, INtr, OUTtr, return_attn=True)
    L_lm = F.cross_entropy(logits.view(-1, logits.size(-1)), targets_tr.view(-1), ignore_index=stoi['<pad>'])
    L_prop = cosine_loss(g_in, g_out)

    if USE_DISTIL:
        L_distil = F.mse_loss(g_in, g_in_hat) + F.mse_loss(g_out, g_out_hat)
    else:
        L_distil = torch.tensor(0.0, device=device)

    # Sparsity prior on the emotional attention. |A|.mean() over softmax rows is the
    # constant 1/T (each row sums to 1), so it contributed no gradient. The mean row
    # entropy is instead minimised by peaked (sparse) attention distributions.
    A_emo = attn['A_emo']
    L_sparse = -(A_emo * A_emo.clamp_min(1e-9).log()).sum(-1).mean()

    loss = L_lm + lambda_prop*L_prop + lambda_distil*L_distil + lambda_sparse*L_sparse
    loss.backward(); opt.step()

    if (ep + 1) % eval_every == 0:
        model.eval()
        with torch.no_grad():
            logits_v, gi_v, go_v, _ = model(Xva, Mva, INva, OUTva, return_attn=False)
            L_lm_v = F.cross_entropy(logits_v.view(-1, logits_v.size(-1)),
                                     targets_va.view(-1),
                                     ignore_index=stoi['<pad>']).item()
            L_prop_v = cosine_loss(gi_v, go_v).item()
            ppl = math.exp(min(10, L_lm_v))
        print(f"Ep{ep+1:02d} | L={loss.item():.3f} | LM={L_lm.item():.3f} | PROP={L_prop.item():.3f} "
              f"| DISTIL={L_distil.item():.3f} | VAL_LM={L_lm_v:.3f} | VAL_PROP={L_prop_v:.3f} | VAL_PPL={ppl:.2f}")

# Generation and scoring cells below assume dropout is disabled, regardless of
# whether the last epoch happened to run the evaluation branch.
model.eval()

## Generación condicionada por g_in (entrada)

In [None]:
#@title Generate con top-k y EMA
def sample_topk(logits_row, k=5, temp=0.9):
    """Sample one token id from the ``k`` most probable entries of ``logits_row``.

    Args:
        logits_row: 1-D tensor of unnormalised scores over the vocabulary.
        k: number of candidates; clamped to [1, vocabulary size] so small
           vocabularies (or k <= 0) do not crash ``torch.topk``.
        temp: softmax temperature; must be > 0.

    Returns:
        The sampled token id as a Python int.

    Raises:
        ValueError: if ``temp`` is not positive (division would give inf/NaN).
    """
    if temp <= 0:
        raise ValueError(f"temp must be > 0, got {temp}")
    k = max(1, min(int(k), logits_row.size(-1)))
    probs = F.softmax(logits_row / temp, dim=-1)
    top = torch.topk(probs, k)
    # multinomial accepts unnormalised non-negative weights.
    choice = torch.multinomial(top.values, 1)
    return top.indices[choice].item()

@torch.no_grad()
def infer_g_in_from_words(words: List[str]):
    """Return the decoder's input-emotion vector for ``words`` with shape (1, 1, d_emo).

    Builds the prompt ``<bos> words <sep>`` and flags every position before
    '<sep>' as input (no padding, no output positions).
    """
    prompt = [stoi['<bos>'], *encode(words), stoi['<sep>']]
    X = torch.tensor([prompt], device=device)
    no_pad = torch.zeros_like(X).bool()
    in_mask = no_pad.clone()
    in_mask[:, :len(prompt) - 1] = True
    out_mask = no_pad.clone()
    _, g_in, _, _ = model(X, no_pad, in_mask, out_mask, return_attn=True)
    return g_in[:, None, :]

@torch.no_grad()
def generate_response(input_words: List[str], max_new=12, k=5, temp=0.9, ema_alpha=0.15):
    """Autoregressively sample a response to ``input_words`` with top-k sampling.

    The prompt is ``<bos> input <sep>``. Tokens are appended one at a time
    until '<eos>' is sampled or ``max_new`` tokens have been added. Returns
    the full id list (prompt + generated tokens).

    ``ema_alpha`` is kept for backward compatibility but has no effect:
    EmoDecoder.forward derives its emotional conditioning from ``in_mask``
    and has no input for an external state. The previous EMA of ``g_out``
    cost one extra forward pass per step (plus one to seed it) and was then
    discarded without influencing the output.
    TODO: to make an EMA meaningful, give EmoDecoder.forward an optional
    g override and pass the smoothed state in here.
    """
    sep_id, eos_id = stoi['<sep>'], stoi['<eos>']
    prompt = [stoi['<bos>']] + encode(input_words) + [sep_id]
    # First '<sep>' in the prompt; generated tokens are appended after it.
    sep_pos = prompt.index(sep_id)
    X = torch.tensor([prompt], device=device)

    for _ in range(max_new):
        pad = torch.zeros_like(X, dtype=torch.bool)
        in_mask = torch.zeros_like(pad)
        in_mask[:, :sep_pos] = True
        out_mask = torch.zeros_like(pad)
        out_mask[:, sep_pos + 1:] = True
        logits, _, _, _ = model(X, pad, in_mask, out_mask, return_attn=False)
        next_id = sample_topk(logits[0, -1], k=k, temp=temp)
        X = torch.cat([X, torch.tensor([[next_id]], device=device)], dim=1)
        if next_id == eos_id:
            break
    return X[0].tolist()

def decode(ids):
    """Join token ids into a space-separated string, dropping '<pad>' and '<bos>'."""
    hidden = ['<pad>', '<bos>']
    words = [itos[int(i)] for i in ids]
    return ' '.join(w for w in words if w not in hidden)

# Sample one response per validation input and show it next to the prompt.
for inp_words, _target in valid_pairs:
    generated = generate_response(inp_words, max_new=10, k=5, temp=0.9)
    print('IN :', ' '.join(inp_words))
    print('OUT:', decode(generated))

## Evaluación intrínseca (coseno entrada→salida)

In [None]:
#@title Alineación emocional
@torch.no_grad()
def emo_alignment_score(input_words: List[str], gen_ids: List[int]):
    """Cosine similarity between the emotion of the input and of the generated response.

    ``g_in`` comes from ``infer_g_in_from_words`` (the prompt alone). ``g_out``
    is pooled over the response span: after '<sep>' (or from the middle of the
    sequence if there is no '<sep>') up to and including the first '<eos>'.
    Returns a float in [-1, 1].
    """
    g_in = infer_g_in_from_words(input_words).reshape(-1)  # (d_emo,)
    toks = list(gen_ids)
    sep_id, eos_id = stoi['<sep>'], stoi['<eos>']
    if sep_id in toks:
        sep_pos = toks.index(sep_id)
        start = sep_pos + 1
    else:
        sep_pos = start = max(1, len(toks) // 2)
    end = toks.index(eos_id) + 1 if eos_id in toks else len(toks)

    X = torch.tensor([toks[:end]], device=device)
    pad = torch.zeros_like(X, dtype=torch.bool)
    # Flag the prompt as input (as build_batch does): the decoder's internal g_in
    # gates every block, and an all-False in_mask would condition the response on
    # a pooled zero vector instead of on the prompt as during training.
    in_mask = torch.zeros_like(pad)
    in_mask[:, :sep_pos] = True
    out_mask = torch.zeros_like(pad)
    out_mask[:, start:end] = True
    _, _, g_out, _ = model(X, pad, in_mask, out_mask, return_attn=False)

    gi = F.normalize(g_in, dim=-1)
    # g_out is (1, d_emo); flatten so torch.dot gets two 1-D tensors (it raises otherwise).
    go = F.normalize(g_out.reshape(-1), dim=-1)
    return float(torch.dot(gi, go).clamp(-1, 1))

# Generate a response per validation input and score its emotional alignment.
rows, scores = [], []
for inp, _ in valid_pairs:
    gen_ids = generate_response(inp, max_new=12, k=5, temp=0.9)
    score = emo_alignment_score(inp, gen_ids)
    rows.append((inp, decode(gen_ids), round(score, 3)))
    scores.append(score)

mean_cos = float(torch.tensor(scores).mean())
print('Alineación media (coseno):', round(mean_cos, 3))
for inp_words, out_text, cos in rows:
    print('IN :', ' '.join(inp_words))
    print('OUT:', out_text)
    print('COS:', cos)

## Visualización de la atención emocional

In [None]:
#@title Heatmaps
@torch.no_grad()
def visualize_attention(example_idx=0):
    """Plot the last block's semantic and emotional attention for one validation pair.

    Weights are averaged over heads; both axes are labelled with the tokens of
    ``<bos> input <sep> response <eos>``.
    """
    inp, out = valid_pairs[example_idx]
    ids = [stoi['<bos>'], *encode(inp), stoi['<sep>'], *encode(out), stoi['<eos>']]
    sep_pos = ids.index(stoi['<sep>'])

    X = torch.tensor([ids], device=device)
    pad = torch.zeros_like(X).bool()
    in_mask = torch.zeros_like(X).bool()
    in_mask[:, :sep_pos] = True
    out_mask = torch.zeros_like(X).bool()
    out_mask[:, sep_pos + 1:] = True
    _, _, _, attn = model(X, pad, in_mask, out_mask, return_attn=True)

    toks = [itos[i] for i in ids]
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = (('A_sem', 'Atención SEM'), ('A_emo', 'Atención EMO'))
    for ax, (key, title) in zip(axes, panels):
        heat = attn[key][0].mean(0).cpu()
        im = ax.imshow(heat, aspect='auto')
        ax.set_title(title)
        ax.set_xticks(range(len(toks)))
        ax.set_xticklabels(toks, rotation=45, ha='right')
        ax.set_yticks(range(len(toks)))
        ax.set_yticklabels(toks)
        plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()


visualize_attention(0)