In [3]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
from einops import rearrange

### 1. Regular Attention

In [37]:

"""Multi-Head Layers for Transformer Encoder"""
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_hidden, num_heads, attention_dropout):
        super(MultiHeadAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"
        
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.d_k = d_hidden // num_heads # dimension of each head
        self.dropout = nn.Dropout(attention_dropout)
        
        self.W_q = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_k = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_v = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_o = nn.Linear(d_hidden, d_hidden, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        print(f"[Attention Scores]: {attn_scores.shape} \n {attn_scores} \n")
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = self.dropout(torch.softmax(attn_scores, dim=-1))
        print(f"[Attention Probs]: {attn_probs.shape} \n {attn_probs} \n")

        output = torch.matmul(attn_probs, V)
        print(f"[Attention Output]: {output.shape} \n {output} \n")

        return output, attn_probs
    
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def forward(self, x, mask=None):
        q = self.split_head(self.W_q(x)) # (B, num_heads, seq_length, d_k)
        k = self.split_head(self.W_k(x))
        v = self.split_head(self.W_v(x))
        
        attn_output, _ = self.scaled_dot_product_attention(q, k, v, mask) # (B, num_heads, seq_length, d_k)
        output = self.W_o(self.combine_heads(attn_output)) # (B, seq_length, d_hidden)
        print(f"[Final Output]: {output.shape} \n {output} \n")
        return output


In [39]:
d_hidden = 8 
num_heads = 1
seq_length = 6
batch_size = 1
dropout = 0.0 

# x = torch.Tensor(
#     [
#         [
#             [ 1,  2,  3,  4,  5,  6,  7,  8],
#             [ 9, 10, 11, 12, 13, 14, 15, 16],
#             [17, 18, 19, 20, 21, 22, 23, 24],
#             [25, 26, 27, 28, 29, 30, 31, 32],
#             [33, 34, 35, 36, 37, 38, 39, 40],
#             [41, 42, 43, 44, 45, 46, 47, 48],
#         ]
#     ]
# ).to(torch.float32)

# x = torch.randint(0, 10, (batch_size, seq_length, d_hidden)).to(torch.float32)
x = torch.randn(batch_size, seq_length, d_hidden).to(torch.float32)

print(f"Input: {x.shape} \n {x} \n")

attention = MultiHeadAttention(d_hidden=d_hidden, num_heads=num_heads, attention_dropout=dropout)
attention.W_k.weight.data.fill_(1.0)
attention.W_q.weight.data.fill_(1.0)
attention.W_v.weight.data.fill_(1.0)
attention.W_o.weight.data.fill_(1.0)
out = attention(x)
print()


Input: torch.Size([1, 6, 8]) 
 tensor([[[-0.8775, -2.1730,  0.7890,  0.5044, -0.9170,  0.6572, -0.0923,
           0.3994],
         [-1.8758, -0.8411,  1.3538, -2.2346, -0.8308,  2.4109,  1.1988,
          -0.5029],
         [ 0.6618,  1.5232,  1.5176, -0.3011, -0.5179, -1.2367,  1.1713,
           0.5475],
         [ 1.4358, -1.6689, -0.1532,  0.3160,  0.5740, -0.6418,  0.3470,
           1.1666],
         [-1.0268, -0.3807,  0.1988,  1.6391,  0.0255,  0.5719, -0.3235,
          -0.5822],
         [-0.5843, -0.5342, -0.5527,  0.3054, -0.0530, -0.1742, -1.4902,
          -0.5062]]]) 

[Attention Scores]: torch.Size([1, 1, 6, 6]) 
 tensor([[[[  8.2686,   6.3918, -16.2770,  -6.6518,  -0.5907,  17.3586],
          [  6.3918,   4.9409, -12.5823,  -5.1419,  -0.4566,  13.4184],
          [-16.2770, -12.5823,  32.0414,  13.0941,   1.1628, -34.1707],
          [ -6.6518,  -5.1419,  13.0941,   5.3511,   0.4752, -13.9643],
          [ -0.5907,  -0.4566,   1.1628,   0.4752,   0.0422,  -1.2401],


## IMPORTANT ConvNN-Attention == KvT

### 2. Convolutional Nearest Neighbor Attention

In [46]:

class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_k = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_v = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_o = nn.Linear(d_hidden, d_hidden, bias=False)
        self.dropout = nn.Dropout(attention_dropout)

        self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else (d_hidden // num_heads)
        self.out_channels = (d_hidden // num_heads) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
            bias = False, 
            groups=self.in_channels
        )

        # Utility Variables 
        self.INF = 1.1
        self.NEG_INF = -0.1 

    """K, Q, V projection functions"""
    def split_head(self, x): ## K, Q, V
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
    def batch_combine(self, x):  ## K, Q, V
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)

    """Output projection function"""
    
    # def batch_split(self, x): ## O
    #     x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
    #     return x.permute(0, 1, 3, 2).contiguous()
        
    # def combine_heads(self, x): ## O
    #     batch_size, _, seq_length, d_k = x.size()
    #     return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 

    def batch_split(self, x):
        if self.num_heads == 1:
            return x.unsqueeze(1)  # Just add head dimension [B, 1, seq_len, dim]
        else:
            x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
            return x.permute(0, 1, 3, 2).contiguous()

    def combine_heads(self, x):
        if self.num_heads == 1:
            return x.squeeze(1)  # Just remove head dimension
        else:
            batch_size, _, seq_length, d_k = x.size()
            return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden)
        
    def forward(self, x):
        k = self.batch_combine(self.split_head(self.W_k(x)))
        v = self.batch_combine(self.split_head(self.W_v(x)))
        q = self.batch_combine(self.split_head(self.W_q(x)))
        print(f"[q shape]: {q.shape} \n {q} \n")
        print(f"[q transpose shape]: {q.transpose(1, 2).shape} \n {q.transpose(1, 2)} \n")
        print(f"[k shape]: {k.shape} \n {k} \n")
        print(f"[v shape]: {v.shape} \n {v} \n")

        

        similarity_matrix = self._calculate_attention_matrix(k, q)
        print(f"[Attention Score]: {similarity_matrix.shape} \n {similarity_matrix} \n")

        # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)

        prime = self._prime(v, similarity_matrix, self.K, self.maximum)

        # 4. Conv1d Layer
        x = self.conv(prime)  

        print(f"[After Conv1d]: {x.shape} \n {x} \n")

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 
        print(f"[After Dropout + Permute]: {x.shape} \n {x} \n")

        # # 6. Final Linear Projection
        x = self.W_o(self.combine_heads(self.batch_split(x)))
        # x = self.W_o(x)
        print(f"[After projection]: {x.shape} \n {x} \n")

        print(f"[After projection]: {x.shape} \n {x} \n")
        
        return x       
    def _calculate_attention_matrix(self, K, Q):
        attn_score = torch.matmul(K.transpose(1, 2), Q) / self.d_k**0.5
        return attn_score

    def _prime(self, v, qk, K, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=True)
        print(f"[Top-{K} Indices]: {topk_indices.shape} \n {topk_indices} \n")
        print(f"[Top-{K} Values]: {topk_values.shape} \n {topk_values} \n")

        topk_values = torch.softmax(topk_values, dim=-1)
        print(f"[After Softmax Top-K Values]: {topk_values.shape} \n {topk_values} \n")
        
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        
        prime = topk_values_exp * prime 
        
        prime = prime.view(b, c, -1)
        print(f"[Prime]: {prime.shape} \n {prime} \n")

        return prime


In [48]:
d_hidden = 8 
num_heads = 1
seq_length = 6
batch_size = 1
dropout = 0.0 

# x = torch.Tensor(
#     [
#         [
#             [ 1,  2,  3,  4,  5,  6,  7,  8],
#             [ 9, 10, 11, 12, 13, 14, 15, 16],
#             [17, 18, 19, 20, 21, 22, 23, 24],
#             [25, 26, 27, 28, 29, 30, 31, 32],
#             [33, 34, 35, 36, 37, 38, 39, 40],
#             [41, 42, 43, 44, 45, 46, 47, 48],
#         ]
#     ]
# ).to(torch.float32)

print(f"[Input x]: {x.shape} \n {x} \n")

convnn = MultiHeadConvNNAttention(d_hidden=d_hidden, 
                                  num_heads=num_heads, 
                                  attention_dropout=dropout, 
                                  K=3,
                                  sampling_type='all',
                                  num_samples=-1,
                                  sample_padding=0,
                                  magnitude_type='cosine',
                                  coordinate_encoding=False,
                                  seq_length=6
                                  )

convnn.W_k.weight.data.fill_(1.0)
convnn.W_q.weight.data.fill_(1.0)
convnn.W_v.weight.data.fill_(1.0)
convnn.W_o.weight.data.fill_(1.0)
convnn.conv.weight.data.fill_(1.0)
out = convnn(x)




[Input x]: torch.Size([1, 6, 8]) 
 tensor([[[-0.8775, -2.1730,  0.7890,  0.5044, -0.9170,  0.6572, -0.0923,
           0.3994],
         [-1.8758, -0.8411,  1.3538, -2.2346, -0.8308,  2.4109,  1.1988,
          -0.5029],
         [ 0.6618,  1.5232,  1.5176, -0.3011, -0.5179, -1.2367,  1.1713,
           0.5475],
         [ 1.4358, -1.6689, -0.1532,  0.3160,  0.5740, -0.6418,  0.3470,
           1.1666],
         [-1.0268, -0.3807,  0.1988,  1.6391,  0.0255,  0.5719, -0.3235,
          -0.5822],
         [-0.5843, -0.5342, -0.5527,  0.3054, -0.0530, -0.1742, -1.4902,
          -0.5062]]]) 

[q shape]: torch.Size([1, 8, 6]) 
 tensor([[[-1.7098, -1.3217,  3.3658,  1.3755,  0.1221, -3.5894],
         [-1.7098, -1.3217,  3.3658,  1.3755,  0.1221, -3.5894],
         [-1.7098, -1.3217,  3.3658,  1.3755,  0.1221, -3.5894],
         [-1.7098, -1.3217,  3.3658,  1.3755,  0.1221, -3.5894],
         [-1.7098, -1.3217,  3.3658,  1.3755,  0.1221, -3.5894],
         [-1.7098, -1.3217,  3.3658,  1.375

### 3. KvT Attention

In [49]:

class MultiHeadKvtAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,topk=100):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
    

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.proj_drop = nn.Dropout(proj_drop)
        self.topk = topk

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        print(f"[q shape]: {q.shape} \n {q} \n")
        print(f"[k shape]: {k.shape} \n {k} \n")
        print(f"[v shape]: {v.shape} \n {v} \n")
                
        attn = (q @ k.transpose(-2, -1)) * self.scale
        print(f"[Attention Score]: {attn.shape} \n {attn} \n")
        
        # the core code block
        mask=torch.zeros(B,self.num_heads,N,N,device=x.device,requires_grad=False)
        print(f"[Attention mask initialization]: {mask.shape} \n {mask} \n")

        index=torch.topk(attn,k=self.topk,dim=-1,largest=True)[1]
        print(f"[Top-k indices]: {index.shape} \n {index} \n")

        mask.scatter_(-1,index,1.)
        print(f"[Attention mask after scatter]: {mask.shape} \n {mask} \n")
        
        attn=torch.where(mask>0,attn,torch.full_like(attn,float('-inf')))
        print(f"[Attention Score after masking]: {attn.shape} \n {attn} \n")

        # end of the core code block
        attn = attn.softmax(dim=-1)
        print(f"[Attention Probability]: {attn.shape} \n {attn} \n")
        attn = self.attn_drop(attn)

        x = (attn @ v)
        print(f"[x after attn @ v]: {x.shape} \n {x} \n")

        x = x.transpose(1, 2).reshape(B, N, C)
        print(f"[x after transpose and reshape]: {x.shape} \n {x} \n")

        x = self.proj(x)
        print(f"[After projection]: {x.shape} \n {x} \n")
        x = self.proj_drop(x)
        return x

In [50]:
d_hidden = 8 
num_heads = 1
seq_length = 6
batch_size = 1
dropout = 0.0 

# x = torch.Tensor(
#     [
#         [
#             [ 1,  2,  3,  4,  5,  6,  7,  8],
#             [ 9, 10, 11, 12, 13, 14, 15, 16],
#             [17, 18, 19, 20, 21, 22, 23, 24],
#             [25, 26, 27, 28, 29, 30, 31, 32],
#             [33, 34, 35, 36, 37, 38, 39, 40],
#             [41, 42, 43, 44, 45, 46, 47, 48],
#         ]
#     ]
# ).to(torch.float32)

print(f"[Input x]: {x.shape} \n {x} \n")

kvt = MultiHeadKvtAttention(d_hidden,
                               num_heads, 
                               qkv_bias=False,
                               qk_scale=None,
                               attn_drop=dropout,
                               proj_drop=dropout,
                               topk=3)

kvt.qkv.weight.data.fill_(1.0)
kvt.proj.weight.data.fill_(1.0)
out = kvt(x)




[Input x]: torch.Size([1, 6, 8]) 
 tensor([[[-0.8775, -2.1730,  0.7890,  0.5044, -0.9170,  0.6572, -0.0923,
           0.3994],
         [-1.8758, -0.8411,  1.3538, -2.2346, -0.8308,  2.4109,  1.1988,
          -0.5029],
         [ 0.6618,  1.5232,  1.5176, -0.3011, -0.5179, -1.2367,  1.1713,
           0.5475],
         [ 1.4358, -1.6689, -0.1532,  0.3160,  0.5740, -0.6418,  0.3470,
           1.1666],
         [-1.0268, -0.3807,  0.1988,  1.6391,  0.0255,  0.5719, -0.3235,
          -0.5822],
         [-0.5843, -0.5342, -0.5527,  0.3054, -0.0530, -0.1742, -1.4902,
          -0.5062]]]) 

[q shape]: torch.Size([1, 1, 6, 8]) 
 tensor([[[[-1.7098, -1.7098, -1.7098, -1.7098, -1.7098, -1.7098, -1.7098,
           -1.7098],
          [-1.3217, -1.3217, -1.3217, -1.3217, -1.3217, -1.3217, -1.3217,
           -1.3217],
          [ 3.3658,  3.3658,  3.3658,  3.3658,  3.3658,  3.3658,  3.3658,
            3.3658],
          [ 1.3755,  1.3755,  1.3755,  1.3755,  1.3755,  1.3755,  1.3755,
      