In [92]:
from models.retiro_model import *
import numpy as np

# Linear Attention Vision Transformer

In [115]:
class RetiroAttention(nn.Module):
    def __init__(self, chan, key_dim, value_dim, heads, kernel_size,padding,stride):
        super(RetiroAttention, self).__init__()
        self.query  = nn.Conv2d(chan, key_dim * heads, kernel_size, padding=padding, stride=stride)
        self.key    = nn.Conv2d(chan, key_dim * heads, kernel_size, padding=padding, stride=stride)
        self.value  = nn.Conv2d(chan, value_dim * heads, kernel_size, padding=padding, stride=stride)

        self.attn_drop = nn.Dropout(0.0)
        self.proj = nn.Conv2d(value_dim * heads, chan, kernel_size, padding=padding)
        self.n_head = heads
    def forward(self,x):
        
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.query(x.permute(0,3,1,2)).flatten(start_dim=-2).transpose(1, 2)
        k = self.key(x.permute(0,3,1,2)).flatten(start_dim=-2).transpose(1, 2)
        v = self.value(x.permute(0,3,1,2)).flatten(start_dim=-2).transpose(1, 2)

        print(q.shape)
        B, C, T = q.size()
        print( B, C, T)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        
        print(f'Linear Attention: q {q.shape} k {k.shape} v {v.shape}')

        k = k.softmax(dim=-1)   #
        k_cumsum = k.sum(dim=[-2], keepdim=True)
        D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True)

        
        context = k.transpose(-2, -1) @ v
        y = self.attn_drop((q @ context) * D_inv + q)
        y1 = rearrange(y, 'b h n d -> b n (h d)')
        y2 =torch.einsum('bhnd->bnhd',y)
        print(f'y: {y.shape}, y1: {y1.shape}, y2: {y2.shape}')
        out = self.proj(y.reshape(B, -1, T, T))

        return out

In [97]:
class LinearAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0):
        super(LinearAttention, self).__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)

        self.n_head = n_head

        self.attn_type = 'l1'
    
    def forward(self,x):
        B, T1, C = x.size()
        
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.query(x)#.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = self.key(x)#.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x)#.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        print(f'Linear Attention: q {q.shape} k {k.shape} v {v.shape}')

        print(f'Linear Attention: q {q.view(B, T1, self.n_head, C // self.n_head).shape}')
        print(f'Linear Attention: q {q.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2).shape}')

        q = q.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)

        k = k.softmax(dim=-1)   #
        print(k.shape)
        k_cumsum = k.sum(dim=-2, keepdim=True)
        print(k_cumsum.shape)
        D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True)   
        print(D_inv.shape)

        print(f'k {k.shape}')
        print(f'k {k.transpose(-2, -1).shape}')
        context = k.transpose(-2, -1) @ v
        print(context.shape)
        y = self.attn_drop((q @ context) * D_inv + q)
        print(y.shape)

       

In [114]:
from einops import repeat, rearrange

In [94]:
input = torch.rand([4,128,64])
layer1 = LinearAttention(n_embd=64, n_head=1)
output = layer1(input)

Linear Attention: q torch.Size([4, 128, 64]) k torch.Size([4, 128, 64]) v torch.Size([4, 128, 64])
Linear Attention: q torch.Size([4, 128, 1, 64])
Linear Attention: q torch.Size([4, 1, 128, 64])
torch.Size([4, 1, 128, 64])
torch.Size([4, 1, 1, 64])
torch.Size([4, 1, 128, 1])
k torch.Size([4, 1, 128, 64])
k torch.Size([4, 1, 64, 128])
torch.Size([4, 1, 64, 64])
torch.Size([4, 1, 128, 64])


In [116]:
input = torch.rand([4,128,128,32])
layer2 = RetiroAttention(chan=32, key_dim=32, value_dim=32, heads=2, kernel_size=2,padding=0,stride=2)
output = layer2(input)
output.shape

torch.Size([4, 4096, 64])
4 4096 64
Linear Attention: q torch.Size([4, 2, 64, 2048]) k torch.Size([4, 2, 64, 2048]) v torch.Size([4, 2, 64, 2048])
y: torch.Size([4, 2, 64, 2048]), y1: torch.Size([4, 64, 4096]), y2: torch.Size([4, 64, 2, 2048])


torch.Size([4, 32, 63, 63])

In [54]:
output.shape

AttributeError: 'NoneType' object has no attribute 'shape'