In [1]:
import math
import torch
import torch.nn as nn

In [2]:
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention (without an output projection).

    The input is projected into queries, keys and values by three independent
    ``nn.Linear(hidden_size, hidden_size)`` layers. The projections are split
    into ``n_head`` heads of size ``hidden_size // n_head``, and each head
    computes ``softmax(Q K^T / sqrt(d_head)) V``. The heads are then
    concatenated back to ``hidden_size`` features.

    Args:
        hidden_size: Feature dimension of both the input and the output.
        n_head: Number of attention heads; must evenly divide ``hidden_size``.
    """

    def __init__(self, hidden_size: int, n_head: int) -> None:
        super().__init__()
        # Validate before allocating any weights.
        assert hidden_size % n_head == 0, (
            f"hidden_size ({hidden_size}) must be divisible by n_head ({n_head})"
        )

        # Registration order (q, k, v) determines parameters()/named_parameters() order.
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)

        self.n_head = n_head
        self.hidden_size = hidden_size

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply multi-head self-attention to ``xs``.

        Args:
            xs: Input of shape ``(batch_size, sequence_len, hidden_size)``.

        Returns:
            Tensor of shape ``(batch_size, sequence_len, hidden_size)``.

        Raises:
            AssertionError: if the last dimension of ``xs`` is not ``hidden_size``.
        """
        n_head = self.n_head
        hidden_size = self.hidden_size
        n_dim = hidden_size // n_head
        batch_size, seq_len, feat_len = xs.size()
        assert feat_len == hidden_size, f"hidden dim not equal! hidden_size: {hidden_size}, feat_len: {feat_len}"

        # (batch, seq, hidden) -> (batch, n_head, seq, n_dim)
        q = self.linear_q(xs).view(batch_size, seq_len, n_head, n_dim).transpose(1, 2)
        k = self.linear_k(xs).view(batch_size, seq_len, n_head, n_dim).transpose(1, 2)
        v = self.linear_v(xs).view(batch_size, seq_len, n_head, n_dim).transpose(1, 2)

        # Logits compare queries with *keys*. The 1/sqrt(d) scale goes on the
        # logits *before* the softmax; scaling the probabilities afterwards would
        # leave rows that no longer sum to 1.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(n_dim)
        atten_weights = torch.softmax(scores, dim=-1)  # (batch, n_head, seq, seq)
        context = torch.matmul(atten_weights, v)  # (batch, n_head, seq, n_dim)

        # (batch, n_head, seq, n_dim) -> (batch, seq, hidden): concatenate heads.
        return context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
    

In [3]:
# Seed before building the model so both the weight init and the random input
# are reproducible on Restart & Run All.
torch.manual_seed(0)
attention_model = AttentionLayer(512, 8)
xs = torch.randn(32, 100, 512)  # (batch_size, sequence_len, hidden_size)
ys = attention_model(xs)
# Self-attention preserves the (batch, seq, hidden) shape.
assert ys.shape == xs.shape, f"unexpected output shape {tuple(ys.shape)}"

In [4]:
# Total number of scalar parameters across every registered submodule.
param_num = sum(p.numel() for p in attention_model.parameters())
print(param_num)

787968


In [5]:
# One line per registered parameter tensor: "<qualified name> <element count>".
for name, p in attention_model.named_parameters():
    print(f"{name} {p.numel()}")

linear_q.weight 262144
linear_q.bias 512
linear_k.weight 262144
linear_k.bias 512
linear_v.weight 262144
linear_v.bias 512
