# Implementing a Self-Attention Layer

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<img src="../../assets/attention_layer.png" width="700" height="350">

In [1]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input
    sequence, so every token attends to every (unmasked) token of that sequence.

    Args:
        d_in (int): input (embedding) dimension (in a standard transformer d_in = d_model, a.k.a. hidden_size)
        d_out (int): output dimension of the query/key/value projections and of the
            returned hidden states (in a standard transformer d_out = d_model)
    """

    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out

        # linear projections for keys, queries and values, each mapping [d_in] -> [d_out]
        self.K = nn.Linear(d_in, d_out)
        self.Q = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute self-attention over ``x``.

        Args:
            x (torch.Tensor): input tensor of shape [batch_size, seq_len, d_in]
            mask (torch.Tensor, optional): boolean (or 0/1) tensor broadcastable to
                [batch_size, seq_len, seq_len]; entry (i, j) is True where query i may
                attend to key j. For causal attention pass
                ``torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))``.
                A row that masks out every key yields NaN for that position.
                Defaults to None (every token attends to every token).

        Returns:
            hidden_states (torch.Tensor): output of shape [batch_size, seq_len, d_out]
        """
        keys = self.K(x)     # [batch_size, seq_len, d_out]
        queries = self.Q(x)  # [batch_size, seq_len, d_out]
        values = self.V(x)   # [batch_size, seq_len, d_out]

        # scores[b, i, j] = <query_i, key_j>; dividing by sqrt(d_out) keeps the dot
        # products from growing with d_out and saturating the softmax
        scores = queries @ keys.transpose(-2, -1) / self.d_out ** 0.5  # [batch_size, seq_len, seq_len]
        if mask is not None:
            # disallowed positions get -inf so they receive exactly zero weight after softmax
            scores = scores.masked_fill(~mask.bool(), float("-inf"))
        self_attn = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys

        # weighted sum of the values (the context vectors in the figure above)
        hidden_states = self_attn @ values  # [batch_size, seq_len, d_out]
        return hidden_states

#### Toy Example 1

In [9]:
# First toy run: 2 sequences of 10 tokens, each token a 6-dim vector,
# projected by the attention layer to 23-dim hidden states.
batch_size, seq_len, d_in, d_out = 2, 10, 6, 23

x = torch.randn(batch_size, seq_len, d_in)   # random stand-in for token embeddings
attn = SelfAttention(d_in, d_out)
hidden_state = attn(x)                       # -> [batch_size, seq_len, d_out]

print('input x: {}, and hidden state: {}'.format(x.size(), hidden_state.size()))

input x: torch.Size([2, 10, 6]), and hidden state: torch.Size([2, 10, 23])


#### Toy Example 2 (with an Embedding Layer)

In [10]:
# Closer to a real model: integer token ids -> embedding lookup -> self-attention.
batch_size, vocab_size, d_model, seq_len = 2, 15, 10, 5

# nn.Embedding is a lookup table, so its input must be integer ids in [0, vocab_size)
x = torch.randint(0, vocab_size, (batch_size, seq_len))
embedding = nn.Embedding(vocab_size, d_model)
embedded = embedding(x)              # [batch_size, seq_len, d_model]

# input and output widths both equal d_model, as in a standard transformer layer
attn = SelfAttention(d_model, d_model)
hidden_state = attn(embedded)        # [batch_size, seq_len, d_model]

print('input x: {}, and embedded: {}, and hidden state: {}'.format(x.size(), embedded.size(), hidden_state.size()))

input x: torch.Size([2, 5]), and embedded: torch.Size([2, 5, 10]), and hidden state: torch.Size([2, 5, 10])
