In [6]:
import sys
from math import ceil
from typing import Tuple

import omegaconf
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from omegaconf import DictConfig
from torch import Tensor
from torch import einsum
from torch import nn

In [7]:
# Local checkouts that provide the `attentions` package and the `utils` module.
# NOTE(review): these are machine-specific absolute paths; adjust them for your environment.
ATTENTION_ZOO_DIR = r'C:\Users\34609\VisualStudio\TFG\attention_zoo'
PROJECT_DIR = r'C:\Users\34609\VisualStudio\TFG'

# Guard each append so that re-running this cell does not keep adding duplicate sys.path entries.
if ATTENTION_ZOO_DIR not in sys.path:
    sys.path.append(ATTENTION_ZOO_DIR)
from attentions.abstract_attention import AbstractAttention

if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)
from utils import iterative_inv

In [35]:
# Sanity check: print the repr of `iterative_inv` to confirm the name is bound after the import above.
print(iterative_inv)

<function iterative_inv at 0x000002705443B5B0>


In [17]:
# Settings shared by both nesting levels of the config. NystromformerAttention reads
# `eps`, `num_landmarks` and `pinv_iterations` from `cfg.model`.
_nystrom_params = {
    'ATTENTION': 'nystromformer',
    'eps': 1e-8,
    'num_landmarks': 64,
    'pinv_iterations': 64,
    'NUM_CLASSES': 96,
    'PATCH_SIZE': 16,
    'DEPTH': 2,
    'HEADS': 4,
}

# The same settings appear directly under `model` and once more under `model.model`.
cfg = omegaconf.OmegaConf.create({
    'model': {
        'model': dict(_nystrom_params),
        **_nystrom_params,
    }
})

In [30]:
class NystromformerAttention(AbstractAttention):
    """Nystrom-style approximation of softmax attention using landmark tokens.

    The sequence is front-padded with zeros to a multiple of ``num_landmarks``,
    landmarks are built by averaging consecutive groups of tokens, and the
    attention output is computed as
    ``(attn1 @ pinv(attn2)) @ (attn3 @ V)`` where the pseudo-inverse comes from
    ``iterative_inv``.

    Hyperparameters read from ``hpars.model``: ``eps``, ``num_landmarks`` and
    ``pinv_iterations``.
    """

    def __init__(self, hpars: DictConfig, n: int, h: int, in_feat: int, out_feat: int) -> None:
        super().__init__(n=n, h=h, in_feat=in_feat, out_feat=out_feat)
        self.model_params = hpars.model

        # Both attributes hold the un-padded sequence length seen by the most
        # recent `pad_input` call (two names are kept for backward compatibility).
        self.n_orig = None
        self.original_dim = None
        self.eps = self.model_params.eps
        self.num_landmarks = self.model_params.num_landmarks
        self.pinv_iterations = self.model_params.pinv_iterations

    @staticmethod
    def _report_nans(label: str, *tensors: Tensor) -> None:
        """Print, for each tensor, whether it contains any NaN (diagnostic output)."""
        for tensor in tensors:
            isnan = torch.isnan(tensor).any()
            print(f'{label}: {isnan}')

    # This is an optional function that if overwritten, it adds the necessary padding
    def pad_input(self, x: Tensor) -> Tensor:
        """Front-pad the second-to-last axis of ``x`` with zeros to a multiple of ``num_landmarks``.

        Args:
            x: Tensor whose second-to-last axis is the sequence axis
               (``apply_attention`` passes tensors of shape ``(b, h, n, d_head)``).

        Returns:
            ``x`` unchanged when the sequence length is already divisible by
            ``num_landmarks``; otherwise a new, zero-padded tensor.

        Side effects:
            Stores the un-padded sequence length in ``self.original_dim`` and ``self.n_orig``.
        """
        seq_len = x.shape[-2]
        remainder = seq_len % self.num_landmarks

        self.original_dim = seq_len
        self.n_orig = seq_len

        if remainder > 0:
            padding = self.num_landmarks - remainder
            # F.pad lists pads from the last axis backwards: (0, 0) leaves the
            # feature axis alone, (padding, 0) pads the sequence axis at the front.
            x = F.pad(x, (0, 0, padding, 0), value=0)

        return x

    def apply_attention(self, Q: Tensor, K: Tensor, V: Tensor, debug: bool = False, mask=None) -> Tuple[Tensor, Tensor]:
        """Compute landmark-based approximate attention.

        Args:
            Q, K, V: Tensors of shape ``(b, h, n, d_head)``. They are not modified in place.
            debug: When True, also return the reconstructed attention matrix
                ``attn1 @ pinv(attn2) @ attn3`` (computed over the padded sequence).
            mask: Optional ``(b, n)`` mask; truthy entries mark positions to keep.

        Returns:
            ``(out, attn)`` where ``out`` has shape ``(b, n, h * d_head)`` and
            ``attn`` is ``None`` unless ``debug`` is True.
        """
        b, h, n, d_head = Q.shape
        m, iters, eps = self.num_landmarks, self.pinv_iterations, self.eps

        # If necessary, add padding to the embeddings to be divisible
        Q, K, V = (self.pad_input(t) for t in (Q, K, V))
        n_padded = Q.shape[-2]

        self._report_nans('Nans 1', Q, K, V)

        # set masked positions to 0 in queries, keys, values
        if mask is not None:
            # `~` below requires a boolean mask.
            mask = mask.to(torch.bool)
            # Q/K/V were padded at the front, so the mask must be padded the same
            # way (with False) or the shapes no longer line up.
            mask_padding = n_padded - mask.shape[-1]
            if mask_padding > 0:
                mask = F.pad(mask, (mask_padding, 0), value=False)
            mask = rearrange(mask, 'b n -> b () n')
            Q, K, V = (t * mask[..., None] for t in (Q, K, V))

        # Out-of-place scaling: an in-place `*=` would write into the caller's
        # tensor (and into K/V too if the same tensor was passed for them).
        Q = Q * (d_head ** -0.5)

        # generate landmarks by sum reduction, and then calculate mean using the mask
        l = n_padded // m  # tokens per landmark; equals ceil(n / m) after padding
        print(f'l: {l}')
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(Q, landmark_einops_eq, 'sum', l=l)
        k_landmarks = reduce(K, landmark_einops_eq, 'sum', l=l)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean
        divisor = l
        if mask is not None:
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l=l)
            # eps avoids division by zero for landmarks made only of masked tokens.
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # masked mean (if mask exists)
        q_landmarks = q_landmarks / divisor
        k_landmarks = k_landmarks / divisor

        # similarities
        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, Q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, K)

        # masking
        if mask is not None:
            mask_value = -torch.finfo(Q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values
        attn1, attn2, attn3 = (sim.softmax(dim=-1) for sim in (sim1, sim2, sim3))

        self._report_nans('Nans 2', attn1, attn2, attn3)

        attn2_inv = iterative_inv(attn2, iters)
        self._report_nans('Nans 3', attn2_inv)

        out = (attn1 @ attn2_inv) @ (attn3 @ V)

        # Merge the multiple heads into one
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = out[:, -n:, :]  # Select only last n to get rid of the padded ones

        attn = (attn1 @ attn2_inv @ attn3) if debug else None
        return out, attn


In [31]:
# Random test input of shape (4, 9800, 768), drawn by torch.rand; no seed is set in the visible cells.
x = torch.rand(4, 9800, 768)

In [32]:
class MultiHeadAttention(nn.Module):
    """Project the input to per-head Q/K/V and run them through NystromformerAttention.

    Args:
        cfg: Config object forwarded to ``NystromformerAttention`` as ``hpars``.
        dim: Embedding size ``C`` of the input; must be divisible by ``num_heads``
            for the head split in ``forward`` to work.
        num_heads: Number of attention heads.
        num_patches: Sequence length forwarded to the attention module as ``n``.
        proj_drop: Dropout probability for ``self.proj_drop``.
        attn_drop: Dropout probability for ``self.attn_drop``.

    NOTE(review): ``proj``, ``proj_drop`` and ``attn_drop`` are constructed but
    never applied in ``forward`` -- confirm whether that is intended.
    """

    def __init__(self, cfg, dim, num_heads=4, num_patches=9800, proj_drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.attention = NystromformerAttention(cfg, in_feat=dim, out_feat=dim, n=num_patches, h=num_heads)
        self.qkv = nn.Linear(dim, dim * 3)  # (B, N, C) -> (B, N, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        """Run attention on ``x`` of shape ``(B, N, C)``.

        Returns:
            The value returned by ``NystromformerAttention.apply_attention``:
            a tuple ``(out, attn)`` where ``attn`` is ``None`` because ``debug``
            is left at its default.
        """
        _, _, C = x.shape
        print(f'x shape; {x.shape}')
        qkv = self.qkv(x)
        # Reuse the projection computed above instead of running the linear layer a second time.
        print(f'qkv: {qkv.shape}')
        # Split the fused projection into (q/k/v, head, head_dim).
        qkv = rearrange(qkv, 'b n (c h1 c1) -> b n c h1 c1', h1=self.num_heads, c1=C // self.num_heads)
        print(f'qkv reshaped: {qkv.shape}')
        # Move the q/k/v axis to the front and the head axis ahead of the sequence axis.
        qkv = rearrange(qkv, 'b n c h1 c1 -> c b h1 n c1')
        print(f'qkv reshaped and permuted: {qkv.shape}')
        q, k, v = qkv[0], qkv[1], qkv[2]
        print(f'q: {q.shape}, k: {k.shape}, v: {v.shape}')
        output = self.attention.apply_attention(Q=q, K=k, V=v)
        return output

In [33]:
# Build the attention wrapper with embedding size 768; the other arguments keep their defaults (4 heads, 9800 patches, no dropout).
att = MultiHeadAttention(cfg, 768)

In [34]:
# Call the module itself rather than `.forward` so that nn.Module.__call__ (and any registered hooks) runs.
# `out` is the tuple returned by `apply_attention`: (attention output, None).
out = att(x)

x shape; torch.Size([4, 9800, 768])
qkv: torch.Size([4, 9800, 2304])
qkv reshaped: torch.Size([4, 9800, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 4, 4, 9800, 192])
q: torch.Size([4, 4, 9800, 192]), k: torch.Size([4, 4, 9800, 192]), v: torch.Size([4, 4, 9800, 192])
Nans 1: False
Nans 1: False
Nans 1: False
l: 154
Nans 2: False
Nans 2: False
Nans 2: False
Nans 3: False
