In [2]:
import torch
from torch import nn
from fvcore.nn import FlopCountAnalysis
from gpytorch.kernels.kernel import Distance
from timm.models.layers import PatchEmbed
import math
import torch.nn.functional as F

In [115]:
class GMMAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,prune_amount = 0.7, seq_len = 197, type_ = 'gmm', prune_token_amount = None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.prune_token_amount = prune_token_amount
        self.seq_len = seq_len
        self.amount = seq_len - math.floor(prune_amount*seq_len)
        self.prune_amount = prune_amount
        self.type = type_
        
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
       
        self.dist = Distance()
        if self.prune_token_amount is None:
            self.register_buffer('pi_mask', torch.ones(1, self.num_heads, seq_len, seq_len))
        else:
            self.register_buffer('pi_mask', torch.ones(1, self.num_heads, seq_len, int(seq_len*(1 - prune_token_amount))))

    def forward(self, x, x_k = None):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0,2,1,3)
        if x_k is not None:
            k = self.k(x_k).reshape(B, x_k.shape[1], self.num_heads, -1).permute(0,2,1,3)
            v = self.v(x_k).reshape(B, x_k.shape[1], self.num_heads, -1).permute(0,2,1,3)
        else:
            k = self.k(x).reshape(B, N, self.num_heads, -1).permute(0,2,1,3)
            v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0,2,1,3)

        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v = qkv[0], qkv[1], qkv[2]   
 
        ### to calculate flops
        if self.type == 'gmm':
            if self.prune_token_amount is not None:
                self.amount = self.seq_len - math.floor((self.prune_amount - self.prune_token_amount)*self.seq_len) 
            attn = (-self.scale/2.0)*self.dist._sq_dist(q[:,:, :self.amount, :], k, postprocess = False)
            # attn = torch.einsum('bhle,bhme->bhlm',q[:,:, :self.amount, :]/self.scale,k)
            attn =F.pad(attn, (0,0,self.seq_len- self.amount,0), 'constant', 0)
            print(attn.shape, self.pi_mask.shape)
            attn = self.pi_mask*torch.exp(attn)
        else:
            attn = torch.einsum('bhle,bhme->bhlm',q/self.scale,k)
            attn = torch.exp(attn)
        attn = attn / (attn.sum(dim=-1, keepdim = True) + 1e-6)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [116]:
patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=192)
num_patches = patch_embed.num_patches

In [117]:
seq_len = [197, 785, 3137]

In [125]:
seq_len = 197
num_head = 4
dim = 192
x = torch.randn(1, seq_len, dim)
prune_token_amount = 0.15

model = GMMAttention(dim=dim, num_heads=4, qkv_bias= True, seq_len = seq_len, type_= 'softmax')
flops = FlopCountAnalysis(model, x)
flop_count_softmax = flops.total()/1e9

model = GMMAttention(dim=dim, num_heads=4, qkv_bias= True, seq_len = seq_len, type_= 'gmm')
flops = FlopCountAnalysis(model, x)
flop_count_gmm = flops.total()/1e9

model = GMMAttention(dim=dim, num_heads=4, qkv_bias= True, seq_len = seq_len, type_= 'gmm', prune_token_amount=0.15)
x_k = x[:,:int(seq_len*(1 - prune_token_amount)),:]
flops = FlopCountAnalysis(model, (x, x_k))
flop_count_gmm_key = flops.total()/1e9

print(flop_count_softmax/flop_count_gmm_key)

Unsupported operator aten::div encountered 2 time(s)
Unsupported operator aten::exp encountered 1 time(s)
Unsupported operator aten::sum encountered 1 time(s)
Unsupported operator aten::add encountered 1 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::sub encountered 2 time(s)
Unsupported operator aten::pow encountered 2 time(s)
Unsupported operator aten::sum encountered 3 time(s)
Unsupported operator aten::ones_like encountered 2 time(s)
Unsupported operator aten::mul encountered 3 time(s)
Unsupported operator aten::clamp_min_ encountered 1 time(s)
Unsupported operator aten::exp encountered 1 time(s)
Unsupported operator aten::add encountered 1 time(s)
Unsupported operator aten::div encountered 1 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::sub encountered 2 time(s)
Unsupported operator aten::pow encountered 2 time(s)
Unsupported operator aten::sum encountered 3 time(s)
Unsupported operator aten::ones

torch.Size([1, 4, 197, 197]) torch.Size([1, 4, 197, 197])
torch.Size([1, 4, 197, 167]) torch.Size([1, 4, 197, 167])
1.2165730134915935


In [126]:
flop_count_softmax/flop_count_gmm

1.1308660730091684

In [127]:
flop_count_softmax/flop_count_gmm_key

1.2165730134915935

In [123]:
flop_count_gmm

0.27112016