In [6]:
import torch 
import torch.nn.functional as F
import torch.nn as nn
import math

import numpy as np 
import random 

seed = 73

# Seed every random number generator used in the notebook (Python's `random`,
# NumPy, PyTorch CPU and all CUDA devices) with the same value so that results
# are reproducible across runs.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)
del _seed_rng  # keep the notebook namespace free of the loop helper

In [7]:
# Build a random rank-2 tensor that stands in for the embedding of every unique
# token in the vocabulary. Here the vocabulary holds 3 tokens and the embedding
# dimension is 768 (the same dimension used by BERT).
embed_dim = 768
n_tokens = 3

# Uniform samples from [0, 1) scaled by 0.1.
embedding_words = torch.rand(n_tokens, embed_dim).mul(0.1)      # 3 x 768
embedding_words.shape

torch.Size([3, 768])

In [8]:
# Three independent linear layers (embed_dim -> embed_dim) used to project the
# embeddings into queries, keys and values. They are created in query, key,
# value order so parameter initialisation draws from the RNG in that sequence.
linear_query, linear_key, linear_value = (
    nn.Linear(embed_dim, embed_dim) for _unused in range(3)
)

# Apply each projection to the token embeddings.
query, key, value = (
    projection(embedding_words)                # each result: 3 x 768
    for projection in (linear_query, linear_key, linear_value)
)

In [9]:
def attention(query, key, value):
    """Compute scaled dot-product attention.

    Args:
        query (torch.Tensor): Query tensor; the size of its last dimension is
            used to scale the scores.
        key (torch.Tensor): Key tensor; its last two dimensions are swapped
            before it is multiplied with ``query``.
        value (torch.Tensor): Value tensor that is combined using the
            attention weights.

    Returns:
        tuple: ``(z, attn)`` where ``attn`` is the softmax over the last axis
        of the scaled query-key scores and ``z`` is ``attn`` matrix-multiplied
        by ``value``.
    """
    # Scaling factor: square root of the size of the query's last dimension.
    scale = math.sqrt(query.size(-1))

    similarity = query @ key.transpose(-2, -1)
    weights = (similarity / scale).softmax(dim=-1)

    return weights @ value, weights

In [10]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over already-projected inputs.

    The query, key and value tensors are split along their last dimension into
    ``n_heads`` equally sized heads, the module-level ``attention`` function is
    applied to every head, the heads are concatenated back together and the
    result is passed through a final linear layer (``linear_w0``).

    This module does not apply query/key/value projections itself; callers
    pass tensors that were projected beforehand.

    Args:
        n_heads (int): Number of attention heads. Must be positive and divide
            ``embed_dim`` exactly.
        embed_dim (int): Size of the last dimension of the inputs and outputs.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``embed_dim``.
    """
    def __init__(self, n_heads, embed_dim):
        super().__init__()
        if n_heads <= 0 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.embed_dim = embed_dim
        # Per-head feature size; exact because divisibility was checked above.
        self.head_dim = embed_dim // n_heads
        self.linear_w0 = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x):
        """Reshape ``(..., seq, embed_dim)`` into ``(..., n_heads, seq, head_dim)``."""
        x = x.reshape(*x.shape[:-1], self.n_heads, self.head_dim)
        # Move the head axis in front of the sequence axis so that attention
        # is computed between tokens, independently for every head.
        return x.transpose(-3, -2)

    def forward(self, query, key, value):
        """Apply multi-head attention.

        Args:
            query (torch.Tensor): Tensor of shape ``(..., q_len, embed_dim)``.
            key (torch.Tensor): Tensor of shape ``(..., kv_len, embed_dim)``.
            value (torch.Tensor): Tensor of shape ``(..., kv_len, embed_dim)``.

        Returns:
            torch.Tensor: Tensor of shape ``(..., q_len, embed_dim)``.
        """
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        z, _ = attention(query, key, value)

        # (..., n_heads, q_len, head_dim) -> (..., q_len, embed_dim): put the
        # head axis back next to the feature axis and concatenate the heads.
        # The sequence length is taken from the tensor itself rather than from
        # a notebook global.
        z = z.transpose(-3, -2)
        z = z.reshape(*z.shape[:-2], self.embed_dim)

        # Output projection owned by this module (not a global layer).
        return self.linear_w0(z)


In [11]:
# Instantiate the multi-head attention module with 12 heads over the embedding
# dimension, then show its repr as the cell output.
n_heads = 12
mha = MultiHeadAttention(n_heads, embed_dim)
mha

MultiHeadAttention(
  (linear_w0): Linear(in_features=768, out_features=768, bias=True)
)

In [14]:
# Run the multi-head attention module on the projected query/key/value tensors
# and display the shape of the result.
z_module = mha(query, key=key, value=value)
z_module.shape

torch.Size([3, 768])