In [8]:
from transformers import AutoTokenizer
import torch.nn as nn
import torch


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to ``n_heads * head_dim`` features, split into
    ``n_heads`` independent heads, attended per head with a ``sqrt(head_dim)``
    scale, merged back together and projected to ``d_model`` by ``W_O``.

    Args:
        n_heads: Number of attention heads.
        head_dim: Feature size of each head (d_k).
        d_model: Feature size of the inputs and of the returned tensor.
        use_mask: When True, a causal mask is applied so that position ``i``
            cannot attend to positions ``j > i``.
    """

    def __init__(self, n_heads, head_dim, d_model, use_mask=False):
        # head_dim = d_k
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.W_Q = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_K = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_V = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_O = nn.Linear(n_heads * head_dim, d_model, bias=False)
        self.use_mask = use_mask

    def _split_heads(self, x):
        """Reshape [..., seq_len, n_heads * head_dim] to [..., n_heads, seq_len, head_dim]."""
        x = x.view(*x.shape[:-1], self.n_heads, self.head_dim)
        return x.transpose(-3, -2)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K`` / ``V``.

        Args:
            Q: Query input of shape [..., q_len, d_model].
            K: Key input of shape [..., k_len, d_model].
            V: Value input of shape [..., k_len, d_model].

        Returns:
            Tensor of shape [..., q_len, d_model].
        """
        # Project, then give every head its own slice of the features so the
        # similarity of one head is not mixed with the features of another.
        Q = self._split_heads(self.W_Q(Q))  # => [..., n_heads, q_len, head_dim]
        K = self._split_heads(self.W_K(K))  # => [..., n_heads, k_len, head_dim]
        V = self._split_heads(self.W_V(V))  # => [..., n_heads, k_len, head_dim]

        # sim_mat => [..., n_heads, q_len, k_len]
        sim_mat = torch.matmul(Q, K.transpose(-2, -1))
        # A Python float scale follows sim_mat's dtype and device, unlike a
        # freshly created CPU tensor.
        sim_mat = sim_mat / (self.head_dim ** 0.5)

        if self.use_mask:
            q_len, k_len = sim_mat.shape[-2:]
            # True strictly above the main diagonal, i.e. for "future" keys.
            # Built on sim_mat's device so the module also works off-CPU.
            future = torch.triu(
                torch.ones(q_len, k_len, device=sim_mat.device), diagonal=1
            ).bool()
            sim_mat = sim_mat.masked_fill(future, float("-inf"))

        # context_mat => [..., n_heads, q_len, k_len], each row sums to 1
        context_mat = nn.functional.softmax(sim_mat, dim=-1)
        result_mat = torch.matmul(context_mat, V)  # => [..., n_heads, q_len, head_dim]

        # Merge the heads back: [..., q_len, n_heads * head_dim]
        result_mat = result_mat.transpose(-3, -2).flatten(-2)
        return self.W_O(result_mat)
        

In [None]:
# Pretrained tokenizer; its vocabulary size determines the embedding table size.
tokenizer = AutoTokenizer.from_pretrained("ikit-claw-nlp/toy-llm")

# Sizes shared by the cells below: feature width, sequence length, batch size.
d_model, n_seq_len, n_batch_size = 512, 256, 10

# Token-embedding table with one d_model-wide row per vocabulary entry.
# The id the tokenizer returns for "<pad>" is registered as the padding index.
EmbeddingLayer = nn.Embedding(
    tokenizer.vocab_size,
    d_model,
    padding_idx=tokenizer.convert_tokens_to_ids("<pad>"),
)

In [None]:
# Draw a random batch of token ids from the whole vocabulary and embed it;
# the ids have shape (n_batch_size, n_seq_len).
random_token_ids = torch.randint(0, tokenizer.vocab_size, (n_batch_size, n_seq_len))
embedding = EmbeddingLayer(random_token_ids)

In [None]:
# Step-by-step multi-head masked self-attention on the random embedding batch.
d_k = 64      # per-head feature size
n_heads = 12  # number of attention heads
print("Embedding Size", embedding.shape)

# Projections to / from the concatenated head space (n_heads * d_k features).
W_Q = nn.Linear(d_model, n_heads * d_k, bias=False)
W_K = nn.Linear(d_model, n_heads * d_k, bias=False)
W_V = nn.Linear(d_model, n_heads * d_k, bias=False)
W_O = nn.Linear(n_heads * d_k, d_model, bias=False)

# Project, then split the features into heads so that each head attends
# independently: Q, V => [batch, n_heads, seq_len, d_k]
Q = W_Q(embedding).view(n_batch_size, n_seq_len, n_heads, d_k).transpose(1, 2)
# K is stored already transposed for the matmul: [batch, n_heads, d_k, seq_len]
K = W_K(embedding).view(n_batch_size, n_seq_len, n_heads, d_k).transpose(1, 2).transpose(-1, -2)
V = W_V(embedding).view(n_batch_size, n_seq_len, n_heads, d_k).transpose(1, 2)
print("Q, K, V shapes")
print(Q.shape, K.shape, V.shape)

# Per-head similarity scores: [batch, n_heads, seq_len, seq_len]
similar_mat = torch.matmul(Q, K)

# Masking: -inf strictly above the main diagonal, 0 elsewhere, so a position
# cannot attend to later positions. Broadcasts over batch and heads.
mask_matrix = torch.triu(torch.ones(n_seq_len, n_seq_len), diagonal=1)
mask_matrix = mask_matrix.masked_fill(mask_matrix.bool(), -torch.inf)
print("Masking Matrix")
print(mask_matrix)
print(mask_matrix.shape)
similar_mat = similar_mat + mask_matrix

# Scale by sqrt(d_k) -- the size of the vectors entering each dot product --
# rather than sqrt(d_model).
softmax_mat = torch.softmax(similar_mat / d_k ** 0.5, dim=-1)
print("Softmax Matrix Shape")
print(softmax_mat.shape)

print("Contextual Matrix")
result_mat = torch.matmul(softmax_mat, V)  # [batch, n_heads, seq_len, d_k]
# Merge the heads back into one feature axis: [batch, seq_len, n_heads * d_k]
result_mat = result_mat.transpose(1, 2).reshape(n_batch_size, n_seq_len, n_heads * d_k)
print(result_mat.shape)

print("Final Output")
result_mat = W_O(result_mat)  # [batch, seq_len, d_model]
print(result_mat.shape)

In [40]:
import torch
# Uniform random input in [0, 1) laid out as (batch=10, seq_len=256, features=100).
test_tensor = torch.rand(10, 256, 100)

In [61]:
# Masked self-attention smoke test: the same tensor is used as query, key and value.
mha = MultiHeadAttention(n_heads=8, head_dim=64, d_model=100, use_mask=True)
test_mat = mha(Q=test_tensor, K=test_tensor, V=test_tensor)


In [59]:
# Display the shape of the attention output computed in the previous cell.
test_mat.shape

torch.Size([10, 256, 256])