In [2]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

## Self Attention (Intuition)

In [62]:
torch.manual_seed(42)

a = torch.tril(torch.ones((3,3))).float()
a_norm = a / a.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
print(a)
print(b)
print(a_norm)
print('Soma dos valores')
c = a @ b
print(c)
#
print('Média dos valores')
c = a @ b
print(c)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
Soma dos valores
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
Média dos valores
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [35]:
B = 1
T = 5
C = 2
triangular = torch.ones((T,T), dtype=torch.long)
print(triangular)
triangular = torch.tril(triangular)
print(triangular)

x = torch.randint(0,4, (B, T, C), dtype=torch.long) # batches, time steps, channels(encoding every token)
x = x.transpose(1, 2)
print(x.shape, triangular.shape)
print(x @ triangular) # (B, T, C) x (T,T) = (B, T, T)


tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])
torch.Size([1, 2, 5]) torch.Size([5, 5])
tensor([[[5, 4, 3, 3, 1],
         [9, 7, 6, 6, 3]]])


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [7]:
x[0]

tensor([[5, 5],
        [8, 6],
        [7, 7],
        [2, 9],
        [2, 5],
        [7, 1],
        [6, 3],
        [6, 3]])

### Self-Attention Most Simple Code

In [39]:
torch.manual_seed(1337)

# 4 batches of tokens, sequences of 8 tokens where each token is represented by 32 values(encoded).
B, T, C = 4, 8, 32
input = torch.randn(B,T,C)
print(f'Input shape: {input.shape}')

head_size = 16

key_layer = nn.Linear(C, head_size, bias=False)
query_layer = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
print(f'query and key shape: {list(query_layer.parameters())[0].T.shape}')
# para cada timestamp, aplica a camada e produz uma query, com base na codificacao do proprio token, independentemente
q = query_layer(input)  # (B, T, head_size)
k = key_layer(input) # (B, T, head_size)
print(f'q and k shape: {q.shape}')
# afinidade entre query and key
weights_affinity = q @ k.transpose(-1, -2) # (B, T, head_size) # (B, head_size, T)
# as operacoes a seguir garantem que nao consideremos os tokens futuros a um instante t
# e tambem normalizam os pesos de cada um dos tokens anteriores,
#### --> caso não houvesse a necessidade de somente considerar tokens passados,
#### --> ou seja, podemos considerar os tokens do futuro, anteriores e o atual
#### --> ex: classificacao de sentimento. poderiamos remover a linha de masking
#### --> nesse caso teriamos um encoder
#### --> encoder: não possui masking.
#### --> decoder: possui masking
#### -->    decoder: chamado assim porque temos um formato auto-regressivo, considerando tokens do passado somente e fazendo inferencia para futuro.


# for every token(timestap) we have afinity of it to every other token (T, T)
print(f'weights_affinity shape: {weights_affinity.shape}')
# remove future information (masking)
tril = torch.tril(torch.ones(T, T)) # matriz triangular
# substitui 0 por -inf
weights_affinity = weights_affinity.masked_fill(tril == 0, float('-inf'))
# normaliza com softmax
weights_affinity = F.softmax(weights_affinity, dim=-1)
weights_affinity[0]



# aplica a afinidade no input
# output = weights_affinity @ input
 # aplica uma camada de pesos no input, antes de aplicar as afinidades
output = weights_affinity @ value(input)
# as matrizes query e key sao usadas para adiquirir o contexto dos tokens anteriores
# e a relacao com o token atual. A matriz value nao depende dos tokens anteriores
# ela guarda informacao do token atual, algum tipo de 'significado' especifico do token
print(f'output shape: {output.shape}')
# attention serve como um mecanismo de comunicação entre os tokens anteriores e o token atual. 


Input shape: torch.Size([4, 8, 32])
query and key shape: torch.Size([32, 16])
q and k shape: torch.Size([4, 8, 16])
weights_affinity shape: torch.Size([4, 8, 8])
output shape: torch.Size([4, 8, 16])


## SelfAttention Module (Reusable in our Nets)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [51]:
%%writefile self_attention.py
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embedding_dim, head_size, block_size, masked=False):
        super(SelfAttention, self).__init__()
        #
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.masked = masked
        # masked attention (decoder)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    
    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x) # (B, T, C) @ (C, H) ->  (B, T, H)
        k = self.query(x) # (B, T, C) @ (C, H) ->  (B, T, H)
        a_weights = q @ k.transpose(-1, -2) # (B, T, H) @ (B, T, H) -> (B, T, T) (attention scores)
        # attention_weights shape is CxC -> representing weights from each C token to all others C
        a_weights = a_weights / (C**(0.5))# scale factor (scaled dot product attention - sqrt(C))
        if self.masked:
            # uses the masking for dont communicate with the future tokens(replace by -inf to apply softmax)
            a_weights = a_weights.masked_fill(self.tril == 0, float('-inf'))
        #
        a_weights = torch.softmax(a_weights, dim=-1) # normalize in the Channels(embedding dim) dimension
        y = a_weights @ self.value(x) # (B, T, T) @ (B, T, C) -> (B, T, C)
        return y


Writing self_attention.py


In [46]:
batch_size = 4
seq_len = 8
embedding_dim = 2
head_size = 7

attention_layer = SelfAttention(embedding_dim, head_size, seq_len, masked=True)

x = torch.randn(batch_size, seq_len, embedding_dim)
print('x:', x.shape)
y = attention_layer(x)
print('y:', y.shape)

x: torch.Size([4, 8, 2])
torch.Size([4, 8, 7]) = torch.Size([4, 8, 8]) @ torch.Size([4, 8, 7])
y: torch.Size([4, 8, 7])
