In [1]:
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
import omegaconf

In [52]:
x = torch.rand(8, 3, 30, 480, 480) # b, c, t, w, h

## PATCH TOKENIZATION

In [53]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=480, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = [img_size, img_size]
        patch_size = [patch_size, patch_size]
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        print(f'x shape 1: {x.shape}')
        x = self.proj(x)
        print(f'x shape 2: {x.shape}')
        W = x.size(-1)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, T, W

In [54]:
patching = PatchEmbed()

In [55]:
x, T, W = patching(x)

x shape 1: torch.Size([240, 3, 480, 480])
x shape 2: torch.Size([240, 768, 30, 30])


In [56]:
print(f'x shape: {x.shape}')
print(f'T: {T}')
print(f'W: {W}')

x shape: torch.Size([240, 900, 768])
T: 30
W: 30


## MULTIHEAD ATTENTION

In [7]:
import sys, os
sys.path.append(r'C:\Users\34609\VisualStudio\TFG\attention_zoo')  
from base_attention import BaseAttention

In [8]:
cfg = omegaconf.OmegaConf.create({'name' : 'vanilla_attention'})

In [57]:
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, dim, num_heads=8, proj_drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.attention = BaseAttention.init_att_module(cfg, in_feat=dim, out_feat=dim, n=dim, h=dim)
        self.qkv = nn.Linear(dim, dim * 3)  # (B, N, C) -> (B, N, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        print(f'qkv: {self.qkv(x).shape}')
        qkv = rearrange(qkv, 'b n (c h w) -> b n c h w', h=self.num_heads, w=C//self.num_heads)
        print(f'qkv reshaped: {qkv.shape}')
        qkv = rearrange(qkv, 'b n c h w -> c b h n w')
        print(f'qkv reshaped and permuted: {qkv.shape}')
        q, k, v = qkv[0], qkv[1], qkv[2]
        # output = self.attention.apply_attention(Q=q, K=k, V=v)
        # return output

In [58]:
mha = MultiHeadAttention(cfg=cfg, dim=768)

In [59]:
mha.forward(x)

qkv: torch.Size([240, 900, 2304])
qkv reshaped: torch.Size([240, 900, 3, 8, 96])
qkv reshaped and permuted: torch.Size([3, 240, 8, 900, 96])
