In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F


## Attention

Given sequence 'x' e.g. ["hi", "my", "name", "is"]

we want to know the interaction (similarity) b/w each word with others.

this is done by projecting 'x' via a linear layer to get q, k, v 
q -> query, what is needed
k -> key, information about the current
v -> value, final weight avg given similarity(calculated)

In [7]:
b, s, d = 10, 4, 128

x = torch.randn(b,s,d)
print(x.var())
similarity = x @ x.transpose(-2, -1)
print(similarity.var())
similarity = similarity / (d **0.5)
print(similarity.var())
soft = similarity.softmax(-1)
soft[0].sum(-1)

tensor(0.9952)
tensor(3274.3716)
tensor(25.5810)


tensor([1.0000, 1.0000, 1.0000, 1.0000])

## Single headed Attention

In [42]:
class Attention(nn.Module):
    def __init__(self, emd_dim):
        super().__init__()
        self.emb_dim = emd_dim
        self.q = nn.Linear(self.emb_dim, self.emb_dim)
        self.k = nn.Linear(self.emb_dim, self.emb_dim)
        self.v = nn.Linear(self.emb_dim, self.emb_dim)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.q(x)

        similarity = (q @ k.transpose(-2, -1)) / self.emb_dim ** 0.5
        attention = similarity.softmax(-1)
        output = attention @ v

        return output
    
attn = Attention(128)
x = torch.randn(2, 64, 128)
attn(x)


tensor([[[ 0.0215, -0.1242,  0.0626,  ...,  0.1387,  0.1708,  0.0487],
         [-0.0139, -0.1821,  0.0945,  ...,  0.1462,  0.1688,  0.0929],
         [-0.0014, -0.1084,  0.0424,  ...,  0.0898,  0.1312,  0.0886],
         ...,
         [-0.0152, -0.1461,  0.0783,  ...,  0.1647,  0.1842,  0.0663],
         [-0.0184, -0.1054,  0.0434,  ...,  0.1180,  0.1456,  0.1013],
         [ 0.0248, -0.0857,  0.0691,  ...,  0.1254,  0.1747,  0.0700]],

        [[-0.2378,  0.0009,  0.1013,  ...,  0.0725,  0.1120,  0.1676],
         [-0.1747, -0.0497,  0.0817,  ...,  0.0491,  0.0399,  0.1744],
         [-0.2405, -0.0036,  0.1232,  ...,  0.0151,  0.1309,  0.1887],
         ...,
         [-0.2745, -0.0709,  0.1324,  ...,  0.0718,  0.0901,  0.1753],
         [-0.2727, -0.0808,  0.1612,  ...,  0.0930,  0.1021,  0.2083],
         [-0.2279, -0.0353,  0.1329,  ...,  0.0364,  0.1422,  0.1589]]],
       grad_fn=<UnsafeViewBackward0>)

## MHA encoded

In [None]:
class MultiHeadAttentionEncoded(nn.Module):
    def __init__(self, emb_dim, n_heads, attn_drop=0.1, proj_drop=0.0, bias=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.head_dim = self.emb_dim // n_heads

        # attn
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.k_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.v_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        ## post attn
        self.out_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.out_proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        print(x.shape)

        # B -> Batch dim of input
        # T -> Time step of sequence
        # C -> Channel i.e embedding dim
        B, T, C = x.shape

        ## 1. project 'x' to get q, k, v
        ## 2. we have (B, T, C), but we need to do multi headed on the 'C'
        ##    so, split C into (n_heads * head_dim) -> (B, T, H, C')
        ##    But for actual calculation, we want (B H T C') @ (B H C' T) -> transpose/ swap dim (1,2)
        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)
        k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)
        v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)

        # calc attn
        attn = ( (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)).softmax(-1)
        attn = self.attn_drop(attn) ## dropout some attn scores
        # Get values
        x = attn @ v

        # swap back (B H T C) to (B T H C) and combine (H C') to C
        x = x.transpose(1,2).reshape(B, T, C)

        # Project multi-headed result to have interaction b/w heads via linear layer again 
        x = self.out_proj(x)
        x = self.out_proj_drop(x)
        
        print(x.shape)

mha = MultiHeadAttentionEncoded(128, 4)
x = torch.randn(2, 64, 128)
mha(x)

torch.Size([2, 64, 128])
torch.Size([2, 64, 128])


## Causal & attention masks

In [None]:
class CausalMHAEncoded(nn.Module):
    def __init__(self, emb_dim, n_heads, attn_drop=0.1, proj_drop=0.0, bias=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.head_dim = self.emb_dim // n_heads

        # attn
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.k_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.v_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        ## post attn
        self.out_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.out_proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask):
        print(x.shape)

        # B -> Batch dim of input
        # T -> Time step of sequence
        # C -> Channel i.e embedding dim
        B, T, C = x.shape

        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)
        k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)
        v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1,2)

        # calc attn
        attn = ( (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5))

        ### before doing softmax, apply the masks  
        ### Causal mask
        ones = torch.ones((T, T), device=attn.device)
        causal_mask = torch.tril(ones)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).bool() # add fake B, H dim
        ### attn mask
        if attn_mask is not None:
            causal_mask = causal_mask.repeat(B, 1, 1, 1)
            # attn_mask -> (B, T) -> add fake H, C dim -> (B 1 1 T) then repeat in C dim -> (B H T T)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, T, 1)

            # add attn mask to causal mask
            causal_mask = causal_mask.masked_fill(~attn_mask, False)

        ## mask out the attn
        attn = attn.masked_fill(~causal_mask, float("-inf"))

        attn = attn.softmax(-1)
        attn = self.attn_drop(attn) ## dropout some attn scores
        # Get values
        x = attn @ v

        # swap back (B H T C) to (B T H C) and combine (H C') to C
        x = x.transpose(1,2).reshape(B, T, C)

        # Project multi-headed result to have interaction b/w heads via linear layer again 
        x = self.out_proj(x)
        x = self.out_proj_drop(x)
        
        print(x.shape)

attn_mask = torch.randint(0, 2, (2, 64)).bool()
print(attn_mask)
mha = CausalMHAEncoded(128, 4)
x = torch.randn(2, 64, 128)
mha(x, attn_mask)

tensor([[False, False, False,  True, False,  True,  True,  True,  True, False,
          True, False, False, False,  True, False, False,  True, False, False,
         False, False,  True,  True, False, False, False,  True,  True,  True,
         False, False,  True, False, False, False,  True,  True,  True, False,
         False,  True, False, False, False, False, False,  True, False, False,
          True,  True, False,  True, False,  True,  True, False, False, False,
         False, False,  True, False],
        [ True,  True, False,  True, False,  True, False,  True,  True, False,
         False, False, False,  True, False,  True, False,  True, False,  True,
          True, False,  True,  True,  True,  True,  True,  True, False,  True,
          True,  True, False,  True,  True, False, False,  True, False, False,
         False, False, False, False,  True, False, False,  True,  True, False,
          True,  True, False, False,  True,  True, False,  True, False, False,
         False