In [1]:
from transformers import AutoTokenizer
import torch.nn as nn
import torch


In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to ``n_heads * head_dim`` features, split into
    ``n_heads`` independent heads of width ``head_dim``, attended per head,
    concatenated again and projected back to ``d_model``.

    Args:
        n_heads: Number of attention heads.
        head_dim: Feature width of a single head (``d_k``).
        d_model: Feature width of the inputs and of the output.
        seq_len: Maximum sequence length; only used to size the causal mask.
        dropout: Dropout probability applied to the attention weights.
        use_mask: If True, position ``i`` may only attend to positions
            ``<= i`` (causal / look-ahead mask).
    """

    def __init__(self, n_heads, head_dim, d_model, seq_len, dropout=0.0, use_mask=False):
        # head_dim = d_k
        super().__init__()
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.dropout = nn.Dropout(dropout)
        self.W_Q = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_K = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_V = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.W_O = nn.Linear(n_heads * head_dim, d_model, bias=False)
        self.use_mask = use_mask
        if self.use_mask:
            # Ones strictly above the diagonal mark the "future" positions.
            # Kept as a float buffer under the same name so existing
            # state_dicts still load.
            self.register_buffer(
                "mask",
                torch.triu(
                    torch.ones(self.seq_len, self.seq_len),
                    diagonal=1
                )
            )

    def _split_heads(self, x):
        """Reshape [..., seq, n_heads * head_dim] to [..., n_heads, seq, head_dim]."""
        x = x.reshape(*x.shape[:-1], self.n_heads, self.head_dim)
        return x.transpose(-3, -2)

    def forward(self, Q, K, V):
        """Attend from the queries to the keys/values.

        Args:
            Q: Queries, ``[batch_size, q_len, d_model]``.
            K: Keys, ``[batch_size, k_len, d_model]``.
            V: Values, ``[batch_size, k_len, d_model]``.

        Returns:
            Tuple ``(result_mat, attention_score)``:
            ``result_mat`` is ``[batch_size, q_len, d_model]``;
            ``attention_score`` is the softmax attention averaged over the
            heads (before dropout), ``[batch_size, q_len, k_len]``.

        Raises:
            ValueError: If masking is enabled and a sequence is longer than
                the ``seq_len`` the mask was built for.
        """
        q = self._split_heads(self.W_Q(Q))  # => [batch_size, n_heads, q_len, head_dim]
        k = self._split_heads(self.W_K(K))  # => [batch_size, n_heads, k_len, head_dim]
        v = self._split_heads(self.W_V(V))  # => [batch_size, n_heads, k_len, head_dim]

        # Scale with a Python float: no tensor has to live on a particular
        # device and the dtype of the scores is left untouched.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if self.use_mask:
            q_len, k_len = scores.shape[-2], scores.shape[-1]
            if q_len > self.seq_len or k_len > self.seq_len:
                raise ValueError(
                    f"sequence length ({q_len}, {k_len}) exceeds the mask size {self.seq_len}"
                )
            causal = self.mask[:q_len, :k_len].bool()
            # Out-of-place fill; broadcasts over the batch and head dimensions.
            scores = scores.masked_fill(causal, float("-inf"))

        # => [batch_size, n_heads, q_len, k_len]
        attention_weights = nn.functional.softmax(scores, dim=-1)
        # Average over heads so callers keep receiving [batch_size, q_len, k_len].
        attention_score = attention_weights.mean(dim=-3)

        context = torch.matmul(self.dropout(attention_weights), v)
        # Merge the heads back: => [batch_size, q_len, n_heads * head_dim]
        context = context.transpose(-3, -2)
        context = context.reshape(*context.shape[:-2], self.n_heads * self.head_dim)

        result_mat = self.W_O(context)
        return result_mat, attention_score
        

In [3]:
# Load the pretrained tokenizer; its vocabulary size determines the number of
# rows in the token embedding table below.
tokenizer = AutoTokenizer.from_pretrained("ikit-claw-nlp/toy-llm")
# Shared hyper-parameters, also read by the later cells.
# d_model: embedding width used by both embedding tables.
d_model = 512
# n_seq_len: number of positions the positional embedding table covers.
n_seq_len = 256
# n_batch_size: batch dimension of the sample tensors built in later cells.
n_batch_size = 10
# Token embedding table: one d_model-wide vector per vocabulary entry.
# padding_idx is the id the tokenizer returns for the "<pad>" token.
# NOTE(review): assumes "<pad>" exists in this tokenizer's vocabulary - confirm.
TokenEmbeddingLayer = nn.Embedding(
    num_embeddings=tokenizer.vocab_size,
    embedding_dim = d_model,
    padding_idx=tokenizer.convert_tokens_to_ids("<pad>")
)
# Learned positional embedding table: one d_model-wide vector per position
# index in [0, n_seq_len).
PosEmbeddingLayer = nn.Embedding(
    num_embeddings = n_seq_len,
    embedding_dim = d_model
)

In [4]:
# Random token ids of shape (n_batch_size, n_seq_len) stand in for a tokenized batch.
token_ids = torch.randint(
    low=0,
    high=tokenizer.vocab_size,
    size=(n_batch_size, n_seq_len),
)
# One index per position: 0 .. n_seq_len - 1.
position_ids = torch.arange(n_seq_len)

pos_embedding = PosEmbeddingLayer(position_ids)
# The positional embeddings broadcast over the batch dimension when added.
embedding = TokenEmbeddingLayer(token_ids) + pos_embedding

In [5]:
# `torch` is already imported in the first cell; this re-import is redundant but harmless.
import torch

# Uniform random input of shape (n_batch_size, n_seq_len, d_model).
test_tensor = torch.rand(n_batch_size, n_seq_len, d_model)

In [6]:
# Instantiate the attention layer defined above. With d_model = 512 (set in an
# earlier cell), 8 heads of width 64 give a projection width of
# n_heads * head_dim = 512.
mha = MultiHeadAttention(
    n_heads = 8,
    head_dim = 64,
    d_model = d_model,
    seq_len = n_seq_len,
    dropout=0.5,
    use_mask= True 
)
# Self-attention: the same tensor is passed as queries, keys and values.
# The second return value (the attention weights) is discarded here.
test_mat, _ = mha(test_tensor,test_tensor,test_tensor)
