In [24]:
# Based on the blog post: https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a

import torch
import torch.nn as nn


class SingleHeadedAttention(nn.Module):
    """Single-head dot-product attention with explicit, learnable projection matrices.

    Computes ``softmax(Q K^T [/ sqrt(D)]) @ V`` where ``Q = query @ wQ``,
    ``K = key @ wK`` and ``V = value @ wV``.

    The worked example in "Illustrated: Self-Attention" omits the 1/sqrt(D)
    scaling for simplicity, so it is off by default. Pass ``scaled=True`` to get
    standard scaled dot-product attention.

    Args:
        wQ: query projection matrix, shape (Dq, D).
        wK: key projection matrix, shape (Dk, D). Must project to the same D as wQ.
        wV: value projection matrix, shape (Dv, Dout). Dout need not equal D.
        scaled: if True, divide the raw scores by sqrt(D), where D is the
            projected key dimension.
    """

    def __init__(self, wQ, wK, wV, scaled: bool = False) -> None:
        super().__init__()
        self.wQ = nn.Parameter(wQ, requires_grad=True)
        self.wK = nn.Parameter(wK, requires_grad=True)
        self.wV = nn.Parameter(wV, requires_grad=True)
        self.scaled = scaled

    def forward(self, query, key, value) -> "tuple[torch.Tensor, torch.Tensor]":
        """Attend from ``query`` over ``key``/``value``.

        Shapes (leading batch dims, if any, are broadcast by ``torch.matmul``):
            query : (Nq, Dq)
            key   : (Nk, Dk)
            value : (Nk, Dv)   -- key and value must have the same length Nk;
                                  Nq may differ from Nk (cross attention)

        Returns:
            attended_values : (Nq, Dout)  attention-weighted sum of projected values
            softmax_scores  : (Nq, Nk)    attention weights; each row sums to 1
        """
        # Project inputs: (N, D_in) @ (D_in, D_out) -> (N, D_out)
        query_transformed = torch.matmul(query, self.wQ)   # (Nq, D)
        key_transformed = torch.matmul(key, self.wK)       # (Nk, D)
        value_transformed = torch.matmul(value, self.wV)   # (Nk, Dout)

        # (Nq, D) @ (D, Nk) -> (Nq, Nk); .mT transposes the last two dims,
        # so this also works with leading batch dimensions.
        scores = torch.matmul(query_transformed, key_transformed.mT)
        if self.scaled:
            # Scale by sqrt of the projected feature dim D (not the sequence
            # length) to keep score variance independent of D.
            scores = scores / key_transformed.shape[-1] ** 0.5

        # Normalise over the key axis so each query's weights sum to 1.
        softmax_scores = torch.softmax(scores, dim=-1)

        # (Nq, Nk) @ (Nk, Dout) -> (Nq, Dout)
        attended_values = torch.matmul(softmax_scores, value_transformed)

        return attended_values, softmax_scores


In [25]:
# Worked example from the blog post:
# N = 3 inputs, input dim Dq = Dk = Dv = 4, projection dim D = 3

x = [
    [1, 0, 1, 0],  # Input 1
    [0, 2, 0, 2],  # Input 2
    [1, 1, 1, 1],  # Input 3
]

# Projection weights (4 x 3 each), built directly as float tensors so the
# names are never rebound from nested lists to tensors.
wQ = torch.Tensor([
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
])
wK = torch.Tensor([
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
])
wV = torch.Tensor([
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0],
])

model = SingleHeadedAttention(wQ, wK, wV)


In [26]:
# Self-attention: query, key and value all come from the same input x,
# each as its own independent float tensor.
q, k, v = (torch.Tensor(x) for _ in range(3))

model(q, k, v)

(tensor([[1.9366, 6.6831, 1.5951],
         [2.0000, 7.9640, 0.0540],
         [1.9997, 7.7599, 0.3584]], grad_fn=<MmBackward0>),
 tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
         [6.0337e-06, 9.8201e-01, 1.7986e-02],
         [2.9539e-04, 8.8054e-01, 1.1917e-01]], grad_fn=<SoftmaxBackward0>))