# Self-Attention

- multi-head self-attention
- equation

attn = softmax(q @ k.t)
result = attn @ v 

mixed_q = softmax( (4,10) @ (10,4) ) @ v 

- engineering 

head 마다 w 를 따로 만들어주었음  
nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])

In [11]:
# code ref : https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

import torch
import torch.nn as nn

class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                print("q shape", q.shape)
                print("k shape", k.shape)
                print("v shape", v.shape)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                
                print("attention shape", attention.shape)

                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [12]:
a_ = torch.randn((1,4,10))

In [13]:
a_.shape

torch.Size([1, 4, 10])

In [14]:
MSA = MyMSA(d=10,n_heads=1)

In [15]:
MSA.forward(a_)

q shape torch.Size([4, 10])
k shape torch.Size([4, 10])
v shape torch.Size([4, 10])
attention shape torch.Size([4, 4])


tensor([[[ 0.2914, -0.3380, -0.1037, -0.3091,  0.2320,  0.3522,  0.0466,
          -0.5097,  0.3001, -0.1115],
         [ 0.3803, -0.2434, -0.2613, -0.2943,  0.1425,  0.3373,  0.0014,
          -0.6945,  0.4082, -0.0787],
         [-0.0247, -0.5951,  0.1316, -0.2992,  0.3529,  0.3853,  0.1303,
          -0.1537,  0.1337, -0.1963],
         [-0.0714, -0.6603,  0.4050, -0.3395,  0.4705,  0.4381,  0.1973,
           0.1237, -0.0513, -0.2108]]], grad_fn=<CatBackward0>)