In [15]:
import torch
from torch import nn
from torchinfo import summary
import numpy as np
import math

In [25]:
class ScaledDotProductAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input with three learnable matrices ``W_q``, ``W_k`` and ``W_v``
    (each of shape ``(embedding_length, qkv_vec_length)``) and returns
    ``softmax(Q @ K^T / sqrt(qkv_vec_length)) @ V``.

    Args:
        embedding_length: length of a single input embedding, i.e. the last
            dimension of ``x``; becomes the row count of each ``W_(qkv)`` matrix.
        qkv_vec_length: length of the query/key/value vectors; becomes the
            column count of each ``W_(qkv)`` matrix and the last dimension of
            the output.
    """

    def __init__(self, embedding_length, qkv_vec_length):
        super().__init__()

        # nn.Parameter already marks the tensor as requiring grad, so requires_grad
        # is not passed to randn. Weights are drawn from an unscaled N(0, 1), as in
        # the original notebook (draw order q -> k -> v is kept for reproducibility).
        self.W_q = nn.Parameter(torch.randn(embedding_length, qkv_vec_length))  # query weights
        self.W_k = nn.Parameter(torch.randn(embedding_length, qkv_vec_length))  # key weights
        self.W_v = nn.Parameter(torch.randn(embedding_length, qkv_vec_length))  # value weights

        # Normalise each query's scores over the key axis (last dim of Q @ K^T).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Let every position of ``x`` attend over every position of ``x``.

        Args:
            x: input embeddings of shape ``(..., seq_len, embedding_length)``,
                e.g. ``(batch, seq_len, embedding_length)``.
            mask: optional boolean (or 0/1) tensor broadcastable to
                ``(..., seq_len, seq_len)``. ``True``/non-zero means the query may
                attend to that key; ``False``/zero positions receive zero weight.
                A query whose keys are all masked produces NaN. ``None`` (the
                default) attends everywhere, matching the original behaviour.

        Returns:
            Tensor of shape ``(..., seq_len, qkv_vec_length)`` holding the
            attention-weighted values.
        """
        Q = x @ self.W_q
        K = x @ self.W_k
        V = x @ self.W_v

        # (..., L, d) @ (..., d, L) -> (..., L, L): one score per (query, key) pair.
        attn_scores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(d_k) so the score magnitude does not grow with the vector length.
        attn_scores = attn_scores / math.sqrt(K.shape[-1])

        if mask is not None:
            # -inf becomes exactly 0 after softmax, removing masked keys from the sum.
            attn_scores = attn_scores.masked_fill(~mask.bool(), float("-inf"))

        attn_weights = self.softmax(attn_scores)
        return attn_weights @ V

In [17]:
# Toy input: a batch of 1 sequence with 3 tokens, each a 4-dimensional embedding.
x_test = torch.tensor(
    [[[1.0, 0.0, 1.0, 0.0],
      [0.0, 2.0, 0.0, 2.0],
      [1.0, 1.0, 1.0, 1.0]]]
)
x_test

tensor([[[1., 0., 1., 0.],
         [0., 2., 0., 2.],
         [1., 1., 1., 1.]]])

In [26]:
# Seed the RNG so the randomly initialised W_q / W_k / W_v are reproducible on Restart & Run All.
torch.manual_seed(0)
# 30-dim input embeddings projected to 20-dim query/key/value vectors.
att = ScaledDotProductAttention(embedding_length=30, qkv_vec_length=20)

In [27]:
# Summary input shape: (batch, seq_len, embedding_length). The last dimension is read
# from the module's W_q (rows == embedding_length) so it cannot drift from the
# value passed to the constructor.
# NOTE(review): seq_len=197 is presumably ViT's 196 patch tokens + 1 [CLS] token — confirm.
summary(att, (1, 197, att.W_q.shape[0]))

Layer (type:depth-idx)                   Output Shape              Param #
ScaledDotProductAttention                [1, 197, 20]              1,800
├─Softmax: 1-1                           [1, 197, 197]             --
Total params: 1,800
Trainable params: 1,800
Non-trainable params: 0
Total mult-adds (M): 0
Input size (MB): 0.02
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.02