In [None]:
import torch
import torch.nn as nn


class SingleHeadedAttention(nn.Module):

    def __init__(self, wQ, wK, wV) -> None:
        super().__init__()
        self.wQ = nn.Parameter(wQ, requires_grad=True)
        self.wK = nn.Parameter(wK, requires_grad=True)
        self.wV = nn.Parameter(wV, requires_grad=True)

    def forward(self, query, key, value) -> None:
        """lets assume dimension of Q,K,V is same i.e == N, in reality query and key should always be of same dimension
            N : number of examples 
            Q : (N, Dq)
            K : (N, Dk)
            V : (N, Dv)
            wQ :(Dq, D)
            wK :(Dk, D)  
            wV :(Dv, D) dimensions of wV need not be same as wQ, wK. Useful for applying cross attention
        """

        # IMPLEMENT OUTPUT = SOFTMAX(WQ * WK.T / d**0.5) * Values

        # (N, Dq) @ (Dq, D)  -> N * D
        key_transformed = torch.matmul(key, self.wK)
        print("keys", key_transformed)
        # (N, Dv) @ (Dv, D)  -> N * D
        value_transformed = torch.matmul(value, self.wV)
        print("values", value_transformed)
        # (N, Dq) @ (Dq, D)  -> N * D
        query_transformed = torch.matmul(query, self.wQ)
        print("queries", query_transformed)

        # (N * D) @ (D * N) -> N * N
        scores = torch.matmul(query_transformed, key_transformed.mT)  # /len(query)**0.5
        print(scores)

        softmax_scores = torch.softmax(scores, dim=-1)
 
       
        # (N * N) @ (N * D) -> N * D
        attended_values = torch.matmul(softmax_scores, value_transformed)

        return attended_values, softmax_scores


In [None]:
# Testcase
# N = 3, Dq,Dv,Dk = 4, D = 3

x = [
    [1, 0, 1, 0],  # Input 1
    [0, 2, 0, 2],  # Input 2
    [1, 1, 1, 1],  # Input 3
]

wK = [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
]
wQ = [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
]
wV = [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0],
]

wQ, wK, wV = torch.Tensor(wQ), torch.Tensor(wK), torch.Tensor(wV)

model = SingleHeadedAttention(wQ, wK, wV)


In [18]:
q, k , v = torch.Tensor(x), torch.Tensor(x), torch.Tensor(x)

model(q, k , v)

keys tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]], grad_fn=<MmBackward0>)
values tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]], grad_fn=<MmBackward0>)
queries tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]], grad_fn=<MmBackward0>)
tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]], grad_fn=<MmBackward0>)


(tensor([[1.9366, 6.6831, 1.5951],
         [2.0000, 7.9640, 0.0540],
         [1.9997, 7.7599, 0.3584]], grad_fn=<MmBackward0>),
 tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
         [6.0337e-06, 9.8201e-01, 1.7986e-02],
         [2.9539e-04, 8.8054e-01, 1.1917e-01]], grad_fn=<SoftmaxBackward0>))

In [None]:
#Lets use torch api with same input 

module = torch.nn.MultiheadAttention(embed_dim=4 , kdim=4, vdim=4 , num_heads=1, batch_first=False)

In [61]:
module(q, k , v)

(tensor([[-0.2326, -0.1976,  0.3269, -0.2430],
         [-0.2349, -0.1973,  0.3276, -0.2419],
         [-0.2378, -0.1959,  0.3273, -0.2386]], grad_fn=<SqueezeBackward1>),
 tensor([[0.3400, 0.3260, 0.3340],
         [0.3381, 0.3230, 0.3389],
         [0.3424, 0.3208, 0.3368]], grad_fn=<SqueezeBackward1>))

In [None]:
q.shape

In [56]:
module.q_proj_weight, module.v_proj_weight, module.k_proj_weight

(None, None, None)

In [44]:
list(module.parameters())

[Parameter containing:
 tensor([[ 0.4270,  0.7082,  0.7191],
         [-0.1559, -0.8517,  0.1958],
         [ 0.4888,  0.9493,  0.0645]], requires_grad=True),
 Parameter containing:
 tensor([[-0.2847, -0.6150, -0.1386, -0.1009],
         [-0.3014, -0.8046,  0.4053, -0.9241],
         [ 0.9085, -0.8638,  0.2429, -0.4422]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.8045, -0.1455, -0.4198,  0.6208],
         [ 0.0383,  0.2769, -0.9163,  0.8202],
         [ 0.0949,  0.9156,  0.8680, -0.5306]], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0146, -0.0544, -0.4000],
         [ 0.2542,  0.3828,  0.3437],
         [ 0.4194, -0.4564,  0.1957]], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0.], requires_grad=True)]