# **Self-Attention**


Self-attention is a core mechanism in deep learning models, especially for natural language processing (NLP) and computer vision tasks. It lets the model learn to focus on the most relevant parts of an input, such as a sequence of words or an image, when performing a given task.

<img src="https://miro.medium.com/v2/resize:fit:856/1*ZCFSvkKtppgew3cc7BIaug.png" width="400">

[Attention Is All You Need](https://arxiv.org/abs/1706.03762)

## Embedding an input sentence

Consider the following sentence: 'Life is short, eat dessert first'.

In [None]:
# creating an embedding of the sentence: first map each word to an integer index

In [None]:
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s
      in enumerate(sorted(sentence.replace(',', '').split()))}

print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [None]:
import torch
sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])
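The word-to-index mapping above can be sanity-checked by inverting it and decoding the indices back into words. The sketch below uses only plain Python; `inverse_dc`, `indices`, and `decoded` are hypothetical names introduced for illustration and are not part of the original notebook.

```python
# Rebuild the vocabulary exactly as in the cell above.
sentence = 'Life is short, eat dessert first'
words = sentence.replace(',', '').split()
dc = {s: i for i, s in enumerate(sorted(words))}

# Hypothetical helper: the inverse mapping from index back to word.
inverse_dc = {i: s for s, i in dc.items()}

# Encode the sentence, then decode it again.
indices = [dc[w] for w in words]
decoded = [inverse_dc[i] for i in indices]

print(indices)
print(decoded)
```

Note that `sorted` places the capitalized 'Life' before the lowercase words, which is why it receives index 0.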


Using an embedding layer to encode the inputs as real-valued vectors

In [None]:
vocab_size = 50_000

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)


tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])


6 rows, each representing one word
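An embedding layer is essentially a lookup table: calling it with token indices returns the matching rows of its weight matrix. The minimal sketch below illustrates this under assumed sizes (`vocab_size = 10`, `embed_dim = 3`) and an arbitrary seed, so its numbers will not match the output printed above.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make this sketch reproducible

# Small illustrative sizes (assumptions, not the values used above).
vocab_size, embed_dim = 10, 3
embed = torch.nn.Embedding(vocab_size, embed_dim)

token_ids = torch.tensor([0, 4, 5, 2, 1, 3])

looked_up = embed(token_ids).detach()        # forward pass of the layer
indexed = embed.weight[token_ids].detach()   # plain row indexing into the weights

print(looked_up.shape)
print(torch.equal(looked_up, indexed))
```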

## Defining the weight matrices

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fecef3e00-4c0e-4c7a-9a9f-42ac4e4ada69_366x786.png" width="300">

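Each input element $x^{(i)}$ is projected into three vectors: a query $q^{(i)}$, a key $k^{(i)}$, and a value $v^{(i)}$.

They are computed as products with the weight matrices $W_q$, $W_k$, and $W_v$:

$q^{(i)} = x^{(i)} W_q$

$k^{(i)} = x^{(i)} W_k$

$v^{(i)} = x^{(i)} W_v$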


In [None]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))
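As a side note, the same projection is often written with `torch.nn.Linear(..., bias=False)` instead of a raw parameter matrix. The sketch below is an illustration, not the notebook's method: it copies a randomly drawn `W_query` into a linear layer and checks that both forms give the same queries. The tensor `x` is a random stand-in for `embedded_sentence`, and `linear_q` is a hypothetical name.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch

d, d_q = 3, 2
x = torch.randn(6, d)          # stand-in for embedded_sentence
W_query = torch.rand(d, d_q)   # same (d, d_q) layout as in the cell above

# nn.Linear stores its weight as (out_features, in_features),
# so we copy the transpose of W_query into it.
linear_q = torch.nn.Linear(d, d_q, bias=False)
with torch.no_grad():
    linear_q.weight.copy_(W_query.T)

queries_matmul = x @ W_query
queries_linear = linear_q(x).detach()

print(torch.allclose(queries_matmul, queries_linear, atol=1e-6))
```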

## Computing the unnormalized attention weights

We will use the 2nd input element as the query.

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Ff9774001-ea9d-48bf-9857-3c911b0a279d_588x962.png" width="300">



In [None]:
x_2 = embedded_sentence[1]  # second input element (index 1)

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([2])
torch.Size([2])
torch.Size([4])


Generalizing to compute the keys and values for all inputs

In [None]:
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])


Computing the unnormalized attention weight between the query and each key

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbaf9e308-223b-429e-8527-a7b868003e8c_814x912.png" width="400">

In [None]:
omega_24 = query_2.dot(keys[4])  # score against the key at index 4 (the fifth input)
print(omega_24)

tensor(1.2903, grad_fn=<DotBackward0>)


In [None]:
omega_2 = query_2 @ keys.T
print(omega_2)

tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
       grad_fn=<SqueezeBackward4>)
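The vector of scores for one query is just a batch of dot products, and the same idea extends to all queries at once. The sketch below shows the three equivalent views, using random stand-in tensors (so the numbers differ from the outputs above); `omega_2_loop` and `omega` are names introduced here for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch

x = torch.randn(6, 3)        # stand-in embeddings: 6 tokens, 3 dimensions
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)

queries = x @ W_query
keys = x @ W_key

# Scores for the second query only, one dot product per key.
omega_2_loop = torch.stack([queries[1].dot(k) for k in keys])

# The same scores as a single vector-matrix product.
omega_2 = queries[1] @ keys.T

# All pairwise scores at once: entry (i, j) compares query i with key j.
omega = queries @ keys.T

print(omega.shape)
print(torch.allclose(omega_2_loop, omega_2, atol=1e-6))
```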


## Normalizing the attention weights

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F42da287a-18e8-45c7-860c-46e8a3a534fc_1400x798.png" width="700">

In [None]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
       grad_fn=<SoftmaxBackward0>)
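To make the scaled softmax concrete, the arithmetic can be redone by hand in plain Python, starting from the rounded scores printed for `omega_2` above. Because the inputs are rounded to four decimals, the result should land close to the printed tensor rather than match it exactly.

```python
import math

# Unnormalized scores copied from the printed omega_2 above (rounded values).
omega_2 = [-0.6004, 3.4707, -1.5023, 0.4991, 1.2903, -1.3374]
d_k = 2

# Scale by 1 / sqrt(d_k) before applying softmax.
scaled = [w / math.sqrt(d_k) for w in omega_2]

# Subtracting the maximum before exponentiating is a common
# numerical-stability trick; it does not change the result.
m = max(scaled)
exps = [math.exp(s - m) for s in scaled]
total = sum(exps)
attention_weights_2 = [e / total for e in exps]

print([round(w, 4) for w in attention_weights_2])
print(sum(attention_weights_2))
```

The weights are non-negative and sum to 1, so they can be read as how much the query attends to each input.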


## Computing the context vector

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6e1dcdeb-e096-4ff9-bdf9-3338e4efa4b4_1916x1048.png" width="700">

In [None]:
context_vector_2 = attention_weights_2 @ values
print(context_vector_2.shape)
print(context_vector_2)

torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110], grad_fn=<SqueezeBackward4>)
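The product `attention_weights_2 @ values` is a weighted sum of the value vectors, with the attention weights as coefficients. The sketch below spells out that sum with an explicit loop and compares it to the matrix form. It uses random stand-in tensors, so its numbers are unrelated to the output above.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch

values = torch.rand(6, 4)                                 # stand-in value vectors
attention_weights = torch.softmax(torch.randn(6), dim=0)  # stand-in weights, sum to 1

# Explicit weighted sum: each value vector scaled by its attention weight.
context_loop = torch.zeros(4)
for weight, value in zip(attention_weights, values):
    context_loop += weight * value

# The same computation as one vector-matrix product.
context_matmul = attention_weights @ values

print(torch.allclose(context_loop, context_matmul, atol=1e-6))
```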


Summarizing self-attention in a compact class

In [None]:
import torch.nn as nn

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized attention weights
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In [None]:
torch.manual_seed(123)

# same dimensions as in the step-by-step example: d_in=3, d_q=d_k=2, d_v=4
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2626, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)
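Each row of the result is the context vector for one token. One way to check this is to recompute a single row step by step and compare it with the batched result. The sketch below does so with a hypothetical function `self_attention` (a functional restatement of the class, introduced here for illustration) and random stand-in tensors.

```python
import torch

def self_attention(x, W_query, W_key, W_value):
    """Scaled dot-product self-attention over all rows of x at once."""
    queries = x @ W_query
    keys = x @ W_key
    values = x @ W_value
    d_k = keys.shape[-1]
    attn_weights = torch.softmax(queries @ keys.T / d_k**0.5, dim=-1)
    return attn_weights @ values, attn_weights

torch.manual_seed(0)  # arbitrary seed for this sketch
x = torch.randn(6, 3)  # stand-in for embedded_sentence
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 4)

context, attn_weights = self_attention(x, W_query, W_key, W_value)

# Recompute the context vector of the second token step by step.
query_2 = x[1] @ W_query
omega_2 = query_2 @ (x @ W_key).T
weights_2 = torch.softmax(omega_2 / 2**0.5, dim=0)
context_2 = weights_2 @ (x @ W_value)

print(context.shape)
print(torch.allclose(context[1], context_2, atol=1e-6))
print(attn_weights.sum(dim=-1))  # each row of weights should sum to 1
```

This mirrors the output above, where the second row matches the `context_vector_2` computed earlier.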
