# NLP Lab - T3: **Attention** (Bahdanau) on Seq2Seq - Master IA

This notebook corresponds to **Tutorial 3 (T3)**: adding an **attention** mechanism to a Seq2Seq model.
We use a **BiLSTM encoder** and an **LSTM decoder with additive (Bahdanau) attention**.

---
## 🎯 Learning objectives
- Explain the **bottleneck** of plain Seq2Seq (a single context vector)
- Implement **additive attention** (Bahdanau)
- Visualize the **attention weights** (heatmap)
- Compare qualitatively and quantitatively with T2

---
## 🧠 Intuition
With attention, the decoder computes a fresh context vector at every decoding step:
- an alignment score between the decoder state and each encoder state
- softmax → attention weights
- weighted sum of the encoder states → context
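
The three steps above can be sketched with plain Python lists. This is only an illustration: the encoder states and alignment scores below are hypothetical toy values, and the helper names `softmax` and `weighted_sum` are not part of the notebook.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_sum(weights, vectors):
    """Context vector: sum_i weights[i] * vectors[i], computed per dimension."""
    dim = len(vectors[0])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]

# Hypothetical toy values: 3 encoder states of dimension 2 and their alignment scores.
enc_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [0.1, 2.0, 0.3]

alpha = softmax(scores)                     # attention weights, they sum to 1
context = weighted_sum(alpha, enc_states)   # weighted sum of the encoder states

print("alpha  :", [round(a, 3) for a in alpha])
print("context:", [round(c, 3) for c in context])
```

The position with the highest score receives the largest weight, so the context is dominated by that encoder state.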

---
## 🧩 Problem studied
We reuse the **sequence reversal** task:
`[1, 5, 7, 3] → [3, 7, 5, 1]`

This task makes attention easy to interpret: the model has to "point" at the right source position.

---
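
For the reversal task, the idealized alignment is an anti-diagonal: output step `t` should look at source position `len(src) - 1 - t`. The sketch below builds that hard 0/1 matrix for illustration only. The function name `ideal_reverse_alignment` is hypothetical, a trained model produces soft weights rather than exact zeros and ones, and the extra EOS step is left out.

```python
def ideal_reverse_alignment(src_len):
    """Row t = output step, column i = source position; 1.0 marks the position to attend to."""
    return [[1.0 if i == src_len - 1 - t else 0.0 for i in range(src_len)]
            for t in range(src_len)]

src = [1, 5, 7, 3]
align = ideal_reverse_alignment(len(src))

# Reading the source through the alignment reproduces the reversed sequence.
decoded = [src[row.index(1.0)] for row in align]

for row in align:
    print(row)
print("decoded:", decoded)
```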


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import matplotlib.pyplot as plt

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


## 1) Parameters and vocabulary

In [None]:

V = 20
MIN_LEN, MAX_LEN = 3, 12

TRAIN_SIZE = 9000
VALID_SIZE = 1000
TEST_SIZE  = 1000

BATCH_SIZE = 64
EMBED_DIM  = 64
HIDDEN_DIM = 128  # decoder hidden size (also used per direction in the encoder)

EPOCHS = 10
LR = 1e-3
TEACHER_FORCING_RATIO = 0.7

PAD = 0
SOS = V + 1
EOS = V + 2
VOCAB_SIZE = V + 3


## 2) Dataset + padding

In [None]:

def generate_pair():
    L = random.randint(MIN_LEN, MAX_LEN)
    src = [random.randint(1, V) for _ in range(L)]
    tgt = [SOS] + list(reversed(src)) + [EOS]
    return src, tgt

class ReverseDataset(Dataset):
    def __init__(self, n):
        self.data = [generate_pair() for _ in range(n)]
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i]

def pad(seqs, pad_value=PAD):
    m = max(len(s) for s in seqs)
    return torch.tensor([s+[pad_value]*(m-len(s)) for s in seqs], dtype=torch.long)

def collate(batch):
    src = pad([b[0] for b in batch])
    tgt = pad([b[1] for b in batch])
    tgt_in  = tgt[:, :-1]
    tgt_out = tgt[:, 1:]
    return src, tgt_in, tgt_out

train_loader = DataLoader(ReverseDataset(TRAIN_SIZE), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(ReverseDataset(VALID_SIZE), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)
test_loader  = DataLoader(ReverseDataset(TEST_SIZE),  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)


## 3) BiLSTM encoder

Attention needs **all** the encoder states, not only the final one.
The encoder returns:
- `enc_outputs`: the states at every time step (dimension `2*HIDDEN_DIM` because the LSTM is bidirectional)
- the decoder's initial state `(h0, c0)`, obtained by concatenating the final forward and backward states and projecting them to `HIDDEN_DIM` (followed by a `tanh`).


In [None]:

class BiEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc_h = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_c = nn.Linear(2*hidden_dim, hidden_dim)

    def forward(self, src):
        emb = self.emb(src)                          # [B,T,E]
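        # No packing here: PAD steps also pass through the LSTM, so the final states of
        # shorter sequences include them. The attention mask hides PAD positions later on.
        enc_outputs, (h, c) = self.lstm(emb)         # enc_outputs [B,T,2H], h,c [2,B,H]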
        h_cat = torch.cat([h[0], h[1]], dim=1)        # [B,2H]
        c_cat = torch.cat([c[0], c[1]], dim=1)        # [B,2H]
        h0 = torch.tanh(self.fc_h(h_cat)).unsqueeze(0) # [1,B,H]
        c0 = torch.tanh(self.fc_c(c_cat)).unsqueeze(0) # [1,B,H]
        return enc_outputs, (h0, c0)


## 4) Additive attention (Bahdanau)

Alignment score: `e_{t,i} = v^T tanh(W_s s_t + W_h h_i)`

Then:
- `alpha_t = softmax(e_t)`, taken over source positions `i` (with a mask so PAD positions are ignored)
- `context_t = Σ_i alpha_{t,i} * h_i`
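
To make the formula concrete, here is a minimal pure-Python sketch of the additive score followed by a masked softmax. The matrices `W_s`, `W_h` and the vector `v` are hypothetical toy parameters rather than trained values, and the helpers `matvec`, `additive_scores` and `masked_softmax` are illustrative names. The sketch sets masked weights to exactly zero, which has the same effect as filling the masked scores with a very large negative value before the softmax.

```python
import math

def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

def additive_scores(s_t, enc_states, W_s, W_h, v):
    """e_i = v^T tanh(W_s s_t + W_h h_i) for every encoder state h_i."""
    proj_s = matvec(W_s, s_t)
    scores = []
    for h_i in enc_states:
        proj_h = matvec(W_h, h_i)
        hidden = [math.tanh(a + b) for a, b in zip(proj_s, proj_h)]
        scores.append(sum(v_k * z_k for v_k, z_k in zip(v, hidden)))
    return scores

def masked_softmax(scores, mask):
    """Softmax restricted to positions where mask == 1; masked positions get weight 0."""
    top = max(s for s, m in zip(scores, mask) if m)
    exps = [math.exp(s - top) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy parameters (not trained values).
W_s = [[0.5, -0.2], [0.1, 0.3]]
W_h = [[0.4, 0.0], [0.0, 0.4]]
v = [1.0, -1.0]

s_t = [0.2, 0.7]                                   # decoder state
enc_states = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # last position is padding
mask = [1, 1, 0]

e = additive_scores(s_t, enc_states, W_s, W_h, v)
alpha = masked_softmax(e, mask)
context = [sum(a * h[d] for a, h in zip(alpha, enc_states)) for d in range(2)]

print("scores :", [round(x, 3) for x in e])
print("alpha  :", [round(x, 3) for x in alpha])
print("context:", [round(x, 3) for x in context])
```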


In [None]:

class BahdanauAttention(nn.Module):
    def __init__(self, dec_hidden, enc_hidden2):
        super().__init__()
        self.Ws = nn.Linear(dec_hidden, dec_hidden, bias=False)
        self.Wh = nn.Linear(enc_hidden2, dec_hidden, bias=False)
        self.v  = nn.Linear(dec_hidden, 1, bias=False)

    def forward(self, s_t, enc_outputs, src_mask):
        # s_t: [B,Hdec], enc_outputs: [B,Tsrc,2Henc], src_mask: [B,Tsrc]
        s = self.Ws(s_t).unsqueeze(1)               # [B,1,H]
        h = self.Wh(enc_outputs)                    # [B,T,H]
        e = self.v(torch.tanh(s + h)).squeeze(-1)   # [B,T]
        e = e.masked_fill(src_mask == 0, -1e9)
        alpha = torch.softmax(e, dim=1)             # [B,T]
        context = torch.bmm(alpha.unsqueeze(1), enc_outputs).squeeze(1)  # [B,2Henc]
        return context, alpha


## 5) Decoder with attention

At each step:
1. Compute the context via attention, using the previous decoder state
2. Concatenate the current token embedding with the context
3. LSTM → new decoder state
4. Predict from the concatenation of the new decoder state and the context
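
The feature sizes implied by these four steps can be checked with a bit of arithmetic. The sketch below reuses the values from the parameter cell (`EMBED_DIM = 64`, `HIDDEN_DIM = 128`). The helper `decoder_step_dims` and the name `ENC_OUT_DIM` are hypothetical and exist only for this illustration.

```python
# Values taken from the hyperparameter cell of this notebook.
EMBED_DIM = 64
HIDDEN_DIM = 128               # decoder hidden size, also the per-direction encoder size
ENC_OUT_DIM = 2 * HIDDEN_DIM   # the BiLSTM concatenates forward and backward states

def decoder_step_dims(embed_dim, dec_hidden, enc_out_dim):
    """Return the feature sizes seen at each stage of one decoding step."""
    return {
        "context": enc_out_dim,                  # weighted sum of encoder states
        "lstm_input": embed_dim + enc_out_dim,   # concat(embedding, context)
        "lstm_output": dec_hidden,               # new decoder state
        "fc_input": dec_hidden + enc_out_dim,    # concat(decoder state, context)
    }

dims = decoder_step_dims(EMBED_DIM, HIDDEN_DIM, ENC_OUT_DIM)
for name, size in dims.items():
    print(f"{name:12s} {size}")
```

These sizes are the ones the decoder's `nn.LSTM` and final `nn.Linear` layers must be built with.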


In [None]:

class AttnDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, dec_hidden, enc_hidden2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD)
        self.attn = BahdanauAttention(dec_hidden, enc_hidden2)
        self.lstm = nn.LSTM(embed_dim + enc_hidden2, dec_hidden, batch_first=True)
        self.fc   = nn.Linear(dec_hidden + enc_hidden2, vocab_size)

    def forward(self, x_tok, state, enc_outputs, src_mask):
        # x_tok: [B,1], state: (h,c) each [1,B,Hdec]
        h, c = state
        s_t = h.squeeze(0)  # [B,Hdec]

        context, alpha = self.attn(s_t, enc_outputs, src_mask)  # context [B,2Henc]
        emb = self.emb(x_tok)                                   # [B,1,E]
        lstm_in = torch.cat([emb.squeeze(1), context], dim=1).unsqueeze(1)  # [B,1,E+2Henc]

        out, (h, c) = self.lstm(lstm_in, (h, c))                 # out [B,1,Hdec]
        s_out = out.squeeze(1)                                   # [B,Hdec]

        logits = self.fc(torch.cat([s_out, context], dim=1)).unsqueeze(1)  # [B,1,V]
        return logits, (h, c), alpha


## 6) Seq2Seq model with attention

In [None]:

class AttnSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt_in, tf_ratio=0.7):
        B, Tsrc = src.shape
        _, Ttgt = tgt_in.shape

        src_mask = (src != PAD).long()

        enc_outputs, state = self.encoder(src)  # enc_outputs [B,Tsrc,2H], state ([1,B,H],[1,B,H])
        logits_all = torch.zeros(B, Ttgt, VOCAB_SIZE, device=src.device)
        attn_all   = torch.zeros(B, Ttgt, Tsrc, device=src.device)

        x = tgt_in[:, 0].unsqueeze(1)  # SOS
        for t in range(Ttgt):
            step_logits, state, alpha = self.decoder(x, state, enc_outputs, src_mask)
            logits_all[:, t:t+1, :] = step_logits
            attn_all[:, t, :] = alpha

            pred = step_logits.argmax(-1)
            if t + 1 < Ttgt:
                x = tgt_in[:, t+1].unsqueeze(1) if random.random() < tf_ratio else pred

        return logits_all, attn_all


## 7) Training / evaluation

In [None]:

model = AttnSeq2Seq(
    BiEncoder(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM),
    AttnDecoder(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, enc_hidden2=2*HIDDEN_DIM)
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

def token_acc(logits, targets):
    pred = logits.argmax(-1)
    mask = targets != PAD
    correct = (pred == targets) & mask
    return correct.sum().item() / mask.sum().item()

@torch.no_grad()
def exact_match(logits, targets):
    # A sequence counts as correct only if every non-PAD target position is predicted correctly.
    # Predictions at padded positions are ignored: the loss never trains them, so they are arbitrary.
    pred = logits.argmax(-1)
    mask = targets != PAD
    ok = ((pred == targets) | ~mask).all(dim=1)
    return ok.float().mean().item()

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    total_loss, total_em = 0.0, 0.0

    for src, tgt_in, tgt_out in loader:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

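        # Track gradients only in training; validation uses free-running decoding (tf_ratio=0).
        with torch.set_grad_enabled(train):
            if train:
                optimizer.zero_grad()
            tf_ratio = TEACHER_FORCING_RATIO if train else 0.0
            logits, _ = model(src, tgt_in, tf_ratio=tf_ratio)

            B, T, C = logits.shape  # C = number of classes (named C to avoid shadowing the global V)
            loss = criterion(logits.reshape(B*T, C), tgt_out.reshape(B*T))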

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        total_loss += loss.item()
        total_em   += exact_match(logits, tgt_out)

    n = len(loader)
    return total_loss/n, total_em/n

hist = {"tr_loss":[], "va_loss":[], "tr_em":[], "va_em":[]}

for e in range(1, EPOCHS+1):
    tr_loss, tr_em = run_epoch(train_loader, train=True)
    va_loss, va_em = run_epoch(valid_loader, train=False)

    hist["tr_loss"].append(tr_loss); hist["va_loss"].append(va_loss)
    hist["tr_em"].append(tr_em);     hist["va_em"].append(va_em)

    print(f"Epoch {e:02d} | train loss {tr_loss:.4f} EM {tr_em:.3f} | valid loss {va_loss:.4f} EM {va_em:.3f}")


## 8) Curves

In [None]:

plt.figure()
plt.plot(hist["tr_loss"], label="train loss")
plt.plot(hist["va_loss"], label="valid loss")
plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.show()

plt.figure()
plt.plot(hist["tr_em"], label="train exact match")
plt.plot(hist["va_em"], label="valid exact match")
plt.xlabel("epoch"); plt.ylabel("exact match"); plt.legend(); plt.show()


## 9) Greedy inference + attention

In [None]:

@torch.no_grad()
def greedy_decode_with_attention(model, src_seq, max_len=30):
    model.eval()
    src = torch.tensor([src_seq], dtype=torch.long, device=device)  # [1,Tsrc]
    src_mask = (src != PAD).long()
    enc_outputs, state = model.encoder(src)

    x = torch.tensor([[SOS]], dtype=torch.long, device=device)
    out_tokens = []
    attn_weights = []

    for _ in range(max_len):
        logits, state, alpha = model.decoder(x, state, enc_outputs, src_mask)
        pred = logits.argmax(-1)
        tok = pred.item()
        out_tokens.append(tok)
        attn_weights.append(alpha.squeeze(0).detach().cpu().numpy())

        x = pred
        if tok == EOS:
            break

    return out_tokens, np.stack(attn_weights, axis=0)

example_src = [1, 5, 7, 3]
pred_tokens, attn = greedy_decode_with_attention(model, example_src, max_len=20)
print("src :", example_src)
print("pred:", pred_tokens)
print("gold:", list(reversed(example_src)) + [EOS])
print("attn shape:", attn.shape)


## 10) Attention weight heatmap

In [None]:

def show_attention(attn, src_seq, pred_seq):
    src_labels = [str(x) for x in src_seq]
    tgt_labels = [str(x) for x in pred_seq]

    plt.figure(figsize=(max(6, len(src_labels)), max(4, len(tgt_labels)*0.6)))
    plt.imshow(attn, aspect="auto")
    plt.colorbar()
    plt.xticks(range(len(src_labels)), src_labels)
    plt.yticks(range(len(tgt_labels)), tgt_labels)
    plt.xlabel("Source tokens")
    plt.ylabel("Predicted tokens")
    plt.title("Attention weights (Bahdanau)")
    plt.show()

# Pass every predicted token so that each heatmap row keeps its own label.
show_attention(attn, example_src, pred_tokens)


## 11) Final test

In [None]:

@torch.no_grad()
def evaluate_test(model, loader):
    model.eval()
    total_em = 0.0
    for src, tgt_in, tgt_out in loader:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
        logits, _ = model(src, tgt_in, tf_ratio=0.0)
        total_em += exact_match(logits, tgt_out)
    return total_em / len(loader)

test_em = evaluate_test(model, test_loader)
print(f"TEST exact-match: {test_em:.3f}")


---
## 12) Questions to hand in

1. Define the bottleneck in T1/T2.
2. Write out the steps of the attention computation (score → softmax → context).
3. Interpret a heatmap: what does a **row** (output time step) mean? A **column** (source position)?
4. Why does attention help on long sequences?
5. How does it differ from the Transformer's **self-attention** (T4)?
---
