<a href="https://colab.research.google.com/github/lucasleao03/GSI073---Topicos-Especiais-de-Inteligencia-Artificial-LLMs-Large-Language-Models-/blob/main/GSI073_aula0_luong_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preparação dos dados

Esta tarefa consiste em inverter sequências de caracteres. Exemplo: **aabcd** em **dcbaa**.


In [None]:
import torch
import torch.nn as nn
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Vocabulary: the four letters plus ' ' (used for padding and as the start token).
chars = list("abcd ")
vocab = dict(zip(chars, range(len(chars))))  # character -> integer id
inv_vocab = dict(enumerate(chars))           # integer id -> character
vocab_size = len(vocab)

def encode(s):
    """Map a string to a 1-D ``torch.long`` tensor of vocabulary ids.

    Raises ``KeyError`` for characters outside the vocabulary.
    """
    ids = list(map(vocab.__getitem__, s))
    return torch.tensor(ids, dtype=torch.long)

def decode(t):
    """Map an iterable of token ids (ints or one-element tensors) back to a string."""
    pieces = [inv_vocab[int(token)] for token in t]
    return "".join(pieces)

def random_seq(n=5):
    """Return a random string of ``n`` characters.

    Characters are drawn uniformly from the vocabulary without its last
    symbol (the ' ' padding token).
    """
    alphabet = chars[:-1]
    picks = [random.choice(alphabet) for _ in range(n)]
    return "".join(picks)

# Generate the dataset: 50,000 random input strings paired with their reversal.
# Each target is prefixed with ' ', the start-of-sequence token that predict()
# feeds the decoder first. With teacher forcing the decoder reads
# tgt[:, :-1] = [' ', r0, ..., r(n-2)] and is scored on tgt[:, 1:] = [r0, ..., r(n-1)].
# That way every reversed character is a training target, and the start token
# is seen during training exactly as it is used at inference.
sos = ' '
pairs = [(encode(s), encode(sos + s[::-1])) for s in [random_seq() for _ in range(50000)]]

max_len = max(len(x) for x, _ in pairs)      # longest input sequence (default pad length)
tgt_max_len = max(len(y) for _, y in pairs)  # longest target (includes the start token)


def pad(x, length=None):
    """Right-pad a 1-D token tensor with the ' ' id.

    Args:
        x: 1-D tensor of token ids.
        length: target length; defaults to the global ``max_len``.

    Returns:
        A tensor of ``max(len(x), length)`` ids. Sequences that are already
        long enough are returned unchanged.
    """
    if length is None:
        length = max_len
    # Explicit dtype=torch.long: an empty torch.tensor([]) defaults to float32,
    # and torch.cat would then promote unpadded sequences to float, which
    # nn.Embedding rejects.
    filler = torch.full((max(length - len(x), 0),), vocab[' '],
                        dtype=torch.long, device=x.device)
    return torch.cat([x, filler], dim=0)


inputs = torch.stack([pad(x) for x, _ in pairs])                # (N, max_len)
targets = torch.stack([pad(y, tgt_max_len) for _, y in pairs])  # (N, tgt_max_len)

train_ds = torch.utils.data.TensorDataset(inputs, targets)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Veja um par

In [None]:
# Show one (input ids, target ids) pair.
sample_pair = pairs[1]
print(sample_pair)

# Definição do modelo Seq2Seq com GRU

In [None]:
class Encoder(nn.Module):
    """GRU encoder over embedded token ids.

    ``forward(x)`` takes a (B, S) tensor of ids and returns the GRU's
    ``(outputs, h)``:

    - ``outputs`` is (B, S, hidden_size): every per-step state, which the
      attention layer needs.
    - ``h`` is the final hidden state, (1, B, hidden_size) for this
      single-layer GRU.
    """

    def __init__(self, vocab_size, emb_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.gru = nn.GRU(emb_size, hidden_size, batch_first=True)

    def forward(self, x):
        embedded = self.embed(x)
        # Keep all per-step states (for attention) plus the final hidden state.
        all_states, final_hidden = self.gru(embedded)
        return all_states, final_hidden

In [None]:
class LuongAttention(nn.Module):
    """Parameter-free Luong "dot" attention.

    Scores every encoder state h_s against the current decoder state h_t
    with a dot product and normalises the scores with a softmax over the
    source positions. It returns the resulting weighted sum of encoder states.
    """

    def __init__(self):
        super().__init__()

    def forward(self, decoder_hidden, encoder_outputs):
        """
        Args:
            decoder_hidden: (B, 1, H) current decoder output.
            encoder_outputs: (B, S, H) all encoder states.

        Returns:
            context: (B, 1, H) attention-weighted sum of encoder states.
            attn_weights: (B, 1, S) softmax weights over source positions.
        """
        keys = encoder_outputs.transpose(1, 2)   # (B, H, S)
        scores = decoder_hidden @ keys           # (B, 1, S): h_t · h_s
        weights = scores.softmax(dim=-1)         # sums to 1 over the source steps
        context = weights @ encoder_outputs      # (B, 1, H)
        return context, weights

In [None]:
class Decoder(nn.Module):
    """GRU decoder with Luong dot attention and a "concat" output layer.

    At each step the GRU consumes one embedded token and attends over the
    encoder states with its output. The concatenation [gru_output; context]
    is then projected to vocabulary logits.
    """

    def __init__(self, vocab_size, emb_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.gru = nn.GRU(emb_size, hidden_size, batch_first=True)
        self.attn = LuongAttention()
        # The output layer sees the GRU output and the attention context side by side.
        self.fc = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, x, h, encoder_outputs):
        """Run the decoder over every position of ``x`` (teacher forcing).

        Args:
            x: (B, T) token ids fed as decoder input.
            h: (1, B, H) initial hidden state (e.g. the encoder's final state).
            encoder_outputs: (B, S, H) encoder states to attend over.

        Returns:
            logits: (B, T, V), one row of vocabulary scores per input position.
            hidden: (1, B, H) hidden state after the last step.
        """
        embedded = self.embed(x)  # (B, T, E)
        state = h
        step_logits = []
        for pos in range(embedded.shape[1]):
            # Step one token at a time so attention uses the freshly updated state.
            step_out, state = self.gru(embedded[:, pos:pos + 1], state)  # (B, 1, H)
            context, _ = self.attn(step_out, encoder_outputs)           # (B, 1, H)
            features = torch.cat((step_out, context), dim=-1)            # (B, 1, 2H)
            step_logits.append(self.fc(features))                        # (B, 1, V)
        return torch.cat(step_logits, dim=1), state


In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper used for teacher-forced training.

    ``forward(src, tgt)`` encodes ``src`` and runs the decoder on ``tgt``
    without its last token, starting from the encoder's final hidden state.
    It returns the decoder logits, one position per decoder input token.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt):
        memory, state = self.encoder(src)
        decoder_input = tgt[:, :-1]  # drop the final target token (shift right)
        logits, _final_state = self.decoder(decoder_input, state, memory)
        return logits

# Código para usar o modelo treinado: inferência

In [None]:
def decode_step(decoder, token, h, encoder_outputs):
    """Run one greedy decoding step.

    Args:
        decoder: callable ``decoder(token, h, encoder_outputs) -> (logits, h)``.
        token: (B, 1) ids of the previous token.
        h: decoder hidden state, (1, B, H).
        encoder_outputs: (B, S, H) encoder states.

    Returns:
        ``(next_token, h)``: the (B, 1) argmax of the last logits and the
        updated hidden state.
    """
    step_logits, new_h = decoder(token, h, encoder_outputs)    # (B, 1, V)
    last_scores = step_logits[:, -1]                           # (B, V)
    chosen = torch.argmax(last_scores, dim=-1).unsqueeze(-1)   # (B, 1)
    return chosen, new_h


def predict(model, seq, max_len=10):
    """Greedily decode ``seq`` with a trained Seq2Seq model.

    Encodes the padded input once, starts the decoder from the ' ' token,
    and feeds each argmax prediction back in for ``max_len`` steps. Leaves
    the model in eval mode.

    Returns:
        The decoded string of ``max_len`` characters.
    """
    model.eval()
    with torch.no_grad():
        src = pad(encode(seq)).unsqueeze(0).to(device, dtype=torch.long)  # (1, S)
        memory, state = model.encoder(src)

        # ' ' serves as the start-of-sequence token.
        current = torch.full((1, 1), vocab[' '], dtype=torch.long, device=device)

        produced = []
        for _ in range(max_len):
            current, state = decode_step(model.decoder, current, state, memory)
            produced.append(current.item())

    return decode(produced)


In [None]:
def compara(seq_ori, seq_inv):
    """Return the fraction of positions where ``seq_inv`` reverses ``seq_ori``.

    Character ``i`` of ``seq_ori`` is compared with character
    ``len(seq_ori) - 1 - i`` of ``seq_inv``, so only the first
    ``len(seq_ori)`` characters of ``seq_inv`` matter. Positions missing
    from a too-short ``seq_inv`` count as mismatches instead of raising
    ``IndexError``.

    Returns:
        A float in [0, 1]. Returns 0.0 for an empty ``seq_ori`` instead of
        dividing by zero.
    """
    n = len(seq_ori)
    if n == 0:
        return 0.0
    hits = sum(
        1
        for i, ch in enumerate(seq_ori)
        if n - 1 - i < len(seq_inv) and ch == seq_inv[n - 1 - i]
    )
    return hits / n

# Preparação para treino

In [None]:
# Model hyper-parameters.
emb_size, hidden_size = 32, 64

# Build the encoder, then the decoder; wrap them and move the model to the device.
encoder = Encoder(vocab_size, emb_size, hidden_size)
decoder = Decoder(vocab_size, emb_size, hidden_size)
model = Seq2Seq(encoder, decoder).to(device)

# ' ' is the padding symbol, so those target positions are excluded from the loss.
pad_index = vocab[' ']
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_index)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Execução do treino

In [None]:
for epoch in range(10):
    model.train()
    total_loss = 0
    for xb, yb in train_dl:
        # Cast batches to long explicitly: nn.Embedding expects integer ids.
        xb = xb.to(device, dtype=torch.long)
        yb = yb.to(device, dtype=torch.long)
        opt.zero_grad()
        # Teacher forcing: Seq2Seq feeds yb[:, :-1] to the decoder, so the
        # logits at step t are compared with the next target token yb[:, t + 1].
        logits = model(xb, yb)  # (B, T-1, V)
        flat_logits = logits.reshape(-1, vocab_size)
        flat_targets = yb[:, 1:].reshape(-1)
        loss = loss_fn(flat_logits, flat_targets)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dl)
    print(f"Epoch {epoch+1}: loss={avg_loss:.4f}")

# Vamos testar

In [None]:
# Evaluate greedy decoding on fresh random sequences. Prints per-sequence
# accuracy (fraction of correctly reversed positions) and the overall mean.
n_testes = 100  # number of random test sequences
acertos = []

for _ in range(n_testes):
    s = random_seq()
    pred = predict(model, s, max_len=len(s))
    acc = compara(s, pred)
    acertos.append(acc)
    print(f"{s} -> {pred} --- Acertos = {acc*100:.2f}%")

# Overall mean accuracy.
media_geral = sum(acertos) / len(acertos)
print(f"\n🎯 Média geral de acertos: {media_geral*100:.2f}%")


In [None]:
def get_attention_matrix(model, seq, verbose=False):
    """Return the attention of the first decoding step over the input ``seq``.

    Mirrors the first step of ``predict``:

    1. Encode the padded input.
    2. Run one decoder GRU step on the ' ' start token, starting from the
       encoder's final hidden state.
    3. Score that output against every encoder state with a dot product
       (the same computation as ``LuongAttention``).

    Args:
        model: Seq2Seq model whose ``encoder`` and ``decoder`` are used.
        seq: input string made of vocabulary characters.
        verbose: if True, print the raw scores and the softmax weights.

    Returns:
        attn_scores: (1, S) raw scores before the softmax.
        attn_weights: (1, S) softmax-normalised weights over the S input positions.
    """
    model.eval()
    with torch.no_grad():
        # Cast to long explicitly (as predict() does) so nn.Embedding always
        # receives integer ids, whatever dtype pad() hands back.
        src = pad(encode(seq)).unsqueeze(0).to(device, dtype=torch.long)  # (1, S)

        # Encoder: all per-step states plus the final hidden state.
        encoder_outputs, h = model.encoder(src)  # encoder_outputs: (1, S, H)

        # Start token, same as in predict().
        token = torch.tensor([[vocab[' ']]], dtype=torch.long, device=device)  # (1, 1)

        # One decoder GRU step.
        embedded = model.decoder.embed(token)      # (1, 1, E)
        out_t, _ = model.decoder.gru(embedded, h)  # out_t: (1, 1, H)

        # Dot-product scores: (1, 1, H) x (1, H, S) -> (1, 1, S)
        attn_scores = torch.bmm(out_t, encoder_outputs.transpose(1, 2))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Drop the leading batch dimension -> (1, S)
        attn_scores = attn_scores.squeeze(0)
        attn_weights = attn_weights.squeeze(0)

        if verbose:
            print("Scores antes do softmax:")
            print(attn_scores.cpu())

            print("\nPesos de atenção (softmax):")
            print(attn_weights.cpu())

    return attn_scores, attn_weights


In [None]:
# Plot the first decoder step's attention over the (padded) input characters.
seq = "abcd"
scores, weights = get_attention_matrix(model, seq)

plt.imshow(weights.cpu().numpy(), aspect='auto')
plt.colorbar()
# Label each column with its source character; ' ' positions are padding.
padded_src = decode(pad(encode(seq)))
plt.xticks(range(len(padded_src)), ['<pad>' if ch == ' ' else ch for ch in padded_src])
plt.yticks([])
plt.xlabel("Posição na entrada")
plt.title("Mapa de Atenção")
plt.show()


# Exercício
Compare o resultado do encoder-decoder com atenção com o do encoder-decoder sem atenção.

In [None]:
"""Comparando os modelos com atenção e sem atenção com os seguintes parâmetros:
número de dados de treino = 50000,
emb_size = 32,
hidden_size = 64,
número de épocas = 20,
número de dados de testes = 100

Modelo com atenção:
bdacb -> bcadb --- Acertos = 100.00%
bbdad -> dadbd --- Acertos = 80.00%
ccaab -> baaca --- Acertos = 80.00%
bbdba -> abdbb --- Acertos = 100.00%
bbdbd -> dbdbd --- Acertos = 80.00%
cdcda -> adcdc --- Acertos = 100.00%
acaaa -> aaaac --- Acertos = 60.00%
acaab -> baaaa --- Acertos = 80.00%
ddaad -> dadad --- Acertos = 60.00%
adbab -> babda --- Acertos = 100.00%
dcaab -> baaad --- Acertos = 80.00%
bdbcb -> bcbdb --- Acertos = 100.00%
dddcd -> ddcdd --- Acertos = 60.00%
adcab -> badac --- Acertos = 40.00%
bcbdb -> bdbcb --- Acertos = 100.00%
caddb -> bddac --- Acertos = 100.00%
cddba -> abddc --- Acertos = 100.00%
bdaad -> dadab --- Acertos = 60.00%
abcbc -> cbcba --- Acertos = 100.00%
bcbad -> dabcb --- Acertos = 100.00%
cbcca -> accbc --- Acertos = 100.00%
ddbcd -> dcdbd --- Acertos = 60.00%
cadad -> daddc --- Acertos = 80.00%
cbacd -> dcacc --- Acertos = 80.00%
dbbda -> adbdd --- Acertos = 80.00%
dbcda -> adcdd --- Acertos = 80.00%
bacdc -> cdcbd --- Acertos = 60.00%
bdabc -> cbadb --- Acertos = 100.00%
bbddc -> cddbd --- Acertos = 80.00%
cbadc -> cdabc --- Acertos = 100.00%
abdaa -> aadba --- Acertos = 100.00%
dcbdd -> ddbdc --- Acertos = 60.00%
bbcbb -> bbcbb --- Acertos = 100.00%
bcaba -> ababc --- Acertos = 60.00%
bbcdb -> bdcbb --- Acertos = 100.00%
bdcab -> bacdb --- Acertos = 100.00%
dbbab -> babbd --- Acertos = 100.00%
adcaa -> aacda --- Acertos = 100.00%
cacdc -> cdcca --- Acertos = 60.00%
dbbcb -> bcbbd --- Acertos = 100.00%
cbccd -> dccdc --- Acertos = 80.00%
badba -> abdab --- Acertos = 100.00%
ddcac -> cadcd --- Acertos = 60.00%
abdac -> cadab --- Acertos = 60.00%
addcd -> ddcdd --- Acertos = 40.00%
daddc -> cdddd --- Acertos = 80.00%
ccbab -> babcc --- Acertos = 100.00%
aadba -> abdaa --- Acertos = 100.00%
adabd -> dbadd --- Acertos = 80.00%
bdbab -> babdb --- Acertos = 100.00%
aadbc -> cbdaa --- Acertos = 100.00%
bdbbc -> cbbdb --- Acertos = 100.00%
ddccc -> ccdcd --- Acertos = 60.00%
dbacd -> dcdab --- Acertos = 40.00%
dacba -> abcad --- Acertos = 100.00%
acabb -> babca --- Acertos = 60.00%

🎯 Média geral de acertos: 79.40%


Modelo sem atenção:
bcddb -> dbbcd --- Acertos = 20.00%
bccbc -> ccbdc --- Acertos = 20.00%
cccaa -> acacd --- Acertos = 40.00%
aadbc -> cdbcd --- Acertos = 20.00%
bbccc -> cccbd --- Acertos = 80.00%
cdacc -> ccdcd --- Acertos = 40.00%
baccd -> dcccd --- Acertos = 60.00%
ccacd -> dcccd --- Acertos = 60.00%
cadbc -> cdbcd --- Acertos = 20.00%
cdaac -> caddc --- Acertos = 80.00%
bbdba -> abddc --- Acertos = 60.00%
dbbad -> ddabc --- Acertos = 40.00%
dbccd -> dccdb --- Acertos = 60.00%
acbab -> babcd --- Acertos = 80.00%
bddbc -> cdbdc --- Acertos = 40.00%
bbbdd -> ddbdc --- Acertos = 60.00%
dbdbb -> bbddc --- Acertos = 60.00%
adcdc -> cddcc --- Acertos = 40.00%
cbdbd -> ddbcb --- Acertos = 20.00%
dbaca -> accdd --- Acertos = 60.00%
caaac -> cadac --- Acertos = 80.00%
cacbc -> ccbdc --- Acertos = 40.00%
cdbca -> cadbb --- Acertos = 0.00%
dbbdb -> bdbdc --- Acertos = 60.00%
cacab -> bacdc --- Acertos = 80.00%
bbccc -> cccbd --- Acertos = 80.00%
cbdac -> cdadc --- Acertos = 40.00%
bddcd -> dcddc --- Acertos = 80.00%
bacbc -> ccbdb --- Acertos = 40.00%
acaaa -> aacad --- Acertos = 40.00%
dcdbb -> bdbcb --- Acertos = 40.00%
babbd -> dbbdc --- Acertos = 60.00%
ddaac -> caddd --- Acertos = 80.00%
cabac -> caddc --- Acertos = 60.00%
dddbd -> ddbcd --- Acertos = 40.00%
aabaa -> aaabc --- Acertos = 40.00%
cbbad -> ddacb --- Acertos = 20.00%
ccabb -> bbcdb --- Acertos = 40.00%
cacbc -> ccbdc --- Acertos = 40.00%
aaaad -> dadca --- Acertos = 60.00%
bdcdc -> cdcdd --- Acertos = 80.00%
cadca -> cadda --- Acertos = 20.00%
bddab -> bdaad --- Acertos = 20.00%
abdad -> ddadc --- Acertos = 20.00%
ddacd -> dcdcd --- Acertos = 60.00%
adabd -> ddbba --- Acertos = 40.00%
bbbcb -> bcbcd --- Acertos = 60.00%
addaa -> adaac --- Acertos = 20.00%
dcbbb -> bbdcb --- Acertos = 60.00%
cccac -> ccdaa --- Acertos = 20.00%
dbbab -> babdb --- Acertos = 60.00%
dadbc -> cdbdc --- Acertos = 20.00%
cadcd -> dcdcd --- Acertos = 60.00%
dccbc -> ccbdc --- Acertos = 20.00%
cccca -> caccd --- Acertos = 40.00%
bdacc -> ccddc --- Acertos = 60.00%
cbacd -> dccdc --- Acertos = 60.00%
abcaa -> acaab --- Acertos = 20.00%
bbbad -> ddabb --- Acertos = 60.00%
abdba -> abddc --- Acertos = 60.00%
aabcc -> ccdcb --- Acertos = 40.00%
baddd -> ddddc --- Acertos = 60.00%
bdadc -> dcdca --- Acertos = 0.00%
adcca -> cadcc --- Acertos = 0.00%
bacbb -> bbcdc --- Acertos = 60.00%
bdcdc -> cdcdd --- Acertos = 80.00%
ddbdc -> cddbd --- Acertos = 60.00%
bbddb -> dbbdb --- Acertos = 20.00%
ddbad -> ddadc --- Acertos = 40.00%
acdaa -> aadcc --- Acertos = 80.00%
bddcd -> dcddc --- Acertos = 80.00%
cbdba -> abdcd --- Acertos = 60.00%
cbbbc -> cbdcb --- Acertos = 40.00%
ccdab -> bdacc --- Acertos = 60.00%
dddac -> cdadd --- Acertos = 60.00%
bcdab -> bdaca --- Acertos = 40.00%
dbdba -> abddc --- Acertos = 60.00%
babab -> babbd --- Acertos = 60.00%
cdcbb -> bbcdd --- Acertos = 80.00%
cbaab -> baadc --- Acertos = 80.00%
bdadb -> dbbad --- Acertos = 0.00%
babab -> babbd --- Acertos = 60.00%
ccdcb -> cbdcb --- Acertos = 40.00%
daccc -> cccdd --- Acertos = 80.00%
abdbc -> cdbbc --- Acertos = 40.00%
cbccb -> cbbcd --- Acertos = 0.00%
ccaab -> abbcc --- Acertos = 40.00%
bcbad -> ddcab --- Acertos = 40.00%
dddac -> cdadd --- Acertos = 60.00%
abddc -> dccdd --- Acertos = 0.00%
bdccb -> cbdcb --- Acertos = 20.00%
dabda -> daaba --- Acertos = 0.00%
ddddc -> dcdcd --- Acertos = 40.00%
dbcdd -> ddcdc --- Acertos = 60.00%
bcaba -> abcac --- Acertos = 40.00%
bdbdd -> dddbc --- Acertos = 40.00%
cdbab -> badbc --- Acertos = 60.00%
abaaa -> aaada --- Acertos = 80.00%
bcccb -> cbcdb --- Acertos = 40.00%
bcabd -> dbcdd --- Acertos = 40.00%

🎯 Média geral de acertos: 46.80%
"""





In [None]:
"""Comparando os modelos com atenção e sem atenção com os seguintes parâmetros:
número de dados de treino = 50000,
emb_size = 32,
hidden_size = 64,
número de épocas = 20,
número de dados de testes = 10000

Modelo com atenção:
🎯 Média geral de acertos: 77.32%

Modelo sem atenção:
🎯 Média geral de acertos: 57.28%
"""

