In [1]:
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
import omegaconf

In [8]:
x = torch.rand(1, 3, 30, 240, 240) # b, c, t, w, h

## PATCH TOKENIZATION

In [3]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=240, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = [img_size, img_size]
        patch_size = [patch_size, patch_size]
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        print(f'x shape 1: {x.shape}')
        x = self.proj(x)
        print(f'x shape 2: {x.shape}')
        W = x.size(-1)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, T, W

In [4]:
patching = PatchEmbed(img_size=480)

In [5]:
x, T, W = patching(x)

x shape 1: torch.Size([30, 3, 240, 240])
x shape 2: torch.Size([30, 768, 15, 15])


In [6]:
print(f'x shape: {x.size()}') # ( frames x batches ), nº patches, patch_embed = (3 x 16 x 16)
print(f'T: {T}')
print(f'W: {W}')

x shape: torch.Size([30, 225, 768])
T: 30
W: 15


## MULTIHEAD ATTENTION

In [14]:
import sys, os
sys.path.append(r'C:\Users\34609\VisualStudio\TFG\attention_zoo')  
from base_attention import BaseAttention

In [15]:
cfg = omegaconf.OmegaConf.create({'name' : 'vanilla_attention'})

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, dim, num_heads=8, proj_drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.attention = BaseAttention.init_att_module(cfg, in_feat=dim, out_feat=dim, n=dim, h=dim)
        self.qkv = nn.Linear(dim, dim * 3)  # (B, N, C) -> (B, N, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        print(f'qkv: {self.qkv(x).shape}')
        qkv = rearrange(qkv, 'b n (c h w) -> b n c h w', h=self.num_heads, w=C//self.num_heads)
        print(f'qkv reshaped: {qkv.shape}')
        qkv = rearrange(qkv, 'b n c h w -> c b h n w')
        print(f'qkv reshaped and permuted: {qkv.shape}')
        q, k, v = qkv[0], qkv[1], qkv[2]
        output = self.attention.apply_attention(Q=q, K=k, V=v)
        return output

In [17]:
mha = MultiHeadAttention(cfg=cfg, dim=768)

In [18]:
out = mha.forward(x)

qkv: torch.Size([30, 225, 2304])
qkv reshaped: torch.Size([30, 225, 3, 8, 96])
qkv reshaped and permuted: torch.Size([3, 30, 8, 225, 96])


In [19]:
print(f'Output shape: {out[0].shape}')
print(f'Scores: {out[1]}')

Output shape: torch.Size([30, 225, 768])
Scores: None


## MLP

In [20]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        print(f'in: {in_features} / hidden: {hidden_features} / out: {out_features}')
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x 

In [21]:
mlp = MLP(in_features=768, hidden_features=4*768)

in: 768 / hidden: 3072 / out: 768


In [22]:
mlp_out = mlp.forward(out[0])

In [23]:
print(mlp_out.shape)

torch.Size([30, 225, 768])


## ATTENTION BLOCK

In [24]:
class Block(nn.Module):
    def __init__(self, cfg, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(cfg, dim, num_heads, proj_drop, attn_drop)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [25]:
block = Block(cfg, dim=768, num_heads=4)

in: 768 / hidden: 3072 / out: 768


In [26]:
block_out = block.forward(x)

qkv: torch.Size([30, 225, 2304])
qkv reshaped: torch.Size([30, 225, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 30, 4, 225, 192])


In [27]:
print(block_out.shape)

torch.Size([30, 225, 768])


## MODEL

In [82]:
class Model(nn.Module):
    """
    Model class with PatchTokenization + (MuliHeadAttention + MLP) x L + MLP
    """
    def __init__(self, cfg, img_size=240, patch_size=16, in_chans=3, embed_dim=768, num_classes=97, depth=2, num_heads=4, mlp_ratio=4.,
                 proj_drop=0., attn_drop=0., norm_layer=nn.LayerNorm, num_frames=30, dropout=0.):
        super().__init__()
        self.depth = depth
        self.dropout = nn.Dropout(dropout)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_frames = num_frames
        self.patch_embed= PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # Positional Embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(num_frames, num_patches+1, embed_dim))
        # self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
                                       
        # Attention Blocks
        self.blocks = nn.ModuleList([
            Block(cfg, embed_dim, num_heads, mlp_ratio, proj_drop, attn_drop, act_layer=nn.GELU, norm_layer=norm_layer)
            for i in range(self.depth)])                            
        self.norm = norm_layer(embed_dim)
        
        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        x, T, W = self.patch_embed(x)
        
        # add class token
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1) # (1, 1, embed) -> (30, 1, embed)
        x = torch.cat((cls_tokens, x), dim=1) # (batch x frames, patches, embed) -> (batch x frames, patches + 1, embed)
    
        # add positional/temporal embedding
        x = x + self.pos_embed
        
        for block in self.blocks:
            x = block.forward(x)
        x = rearrange(x, '(b f) p e -> b f p e', f=self.num_frames) # (batch x frames, patches, embed) -> (batch, frames, patch, embed)
        x = torch.mean(x, [1,2])
        x = self.head(x)
        return x               

In [83]:
model = Model(cfg)

in: 768 / hidden: 3072 / out: 768
in: 768 / hidden: 3072 / out: 768


In [85]:
model_out = model(x)

x shape 1: torch.Size([30, 3, 240, 240])
x shape 2: torch.Size([30, 768, 15, 15])
qkv: torch.Size([30, 226, 2304])
qkv reshaped: torch.Size([30, 226, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 30, 4, 226, 192])
qkv: torch.Size([30, 226, 2304])
qkv reshaped: torch.Size([30, 226, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 30, 4, 226, 192])


In [86]:
print(model_out.size())

torch.Size([1, 97])
