In [1]:
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
import omegaconf
import math
from datetime import datetime

In [2]:
x = torch.rand(4, 3, 200, 112, 112) # b, c, t, w, h

## PATCH TOKENIZATION

In [3]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=112, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = [img_size, img_size]
        patch_size = [patch_size, patch_size]
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        print(f'x shape 1: {x.shape}')
        x = self.proj(x)
        print(f'x shape 2: {x.shape}')
        W = x.size(-1)
        # x = rearrange(x, 'b c h w -> b (h w) c')
        x = rearrange(x, '(b t) c w h -> b (w h t) c', t=200) 
        return x, T, W

In [4]:
patching = PatchEmbed(img_size=112)

In [5]:
num_patches = x.shape[1]
num_patches

3

In [6]:
x, T, W = patching.forward(x)

x shape 1: torch.Size([800, 3, 112, 112])
x shape 2: torch.Size([800, 768, 7, 7])


In [7]:
print(f'x shape: {x.size()}') # ( frames x batches ), nº patches, patch_embed = (3 x 16 x 16)
print(f'T: {T}')
print(f'W: {W}')

x shape: torch.Size([4, 9800, 768])
T: 200
W: 7


## MULTIHEAD ATTENTION

In [8]:
import sys, os
sys.path.append(r'C:\Users\34609\VisualStudio\TFG\attention_zoo')  
from base_attention import BaseAttention

In [9]:
# MODEL_V1.YAML
# cfg = omegaconf.OmegaConf.create({
#     'model': {
#         'ATTENTION' : 'rela_attention'
#     }
# })

# MODEL_V2.YAML
# cfg = omegaconf.OmegaConf.create({
#     'model': {
#         'ATTENTION' : 'skyformer',
#         'accumulation': 1,
#         'num_feats': 128
#     }
# })

# MODEL_V3.YAML
# cfg = omegaconf.OmegaConf.create({
#     'model': {
#         # 'ATTENTION': 'nystromformer',
#         'ATTENTION': 'cosformer',
#         'eps': 1e-8,
#         'num_landmarks': 64,
#         'pinv_iterations': 64
#     }
# })

# MODEL_V4.YAML
cfg = omegaconf.OmegaConf.create({
    'model': {
        'model': {
            'ATTENTION': 'cosformer',
            'eps': 1e-8,
            'num_landmarks': 64,
            'pinv_iterations': 64,
            'NUM_CLASSES': 96,
            'PATCH_SIZE': 16,
            'DEPTH': 2,
            'HEADS': 4
        },
        'ATTENTION': 'cosformer'
    }
})

# MODEL_V5.YAML
# cfg = omegaconf.OmegaConf.create({
#     'model': {
#         'ATTENTION' : 'linformer',
#         'proj_feats': 64 
#     }
# })

# MODEL_V6.YAML
# cfg = omegaconf.OmegaConf.create({
#     'model': {
#         'ATTENTION' : 'performer',
#         'kernel_type': 'relu'
#     }
# })

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, dim, num_heads=4, num_patches=num_patches, proj_drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.attention = BaseAttention.init_att_module(cfg, in_feat=dim, out_feat=dim, n=num_patches, h=num_heads)
        self.qkv = nn.Linear(dim, dim * 3)  # (B, N, C) -> (B, N, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        
    def forward(self, x):
        B, N, C = x.shape
        print(f'x shape; {x.shape}')
        qkv = self.qkv(x)
        print(f'qkv: {self.qkv(x).shape}')
        qkv = rearrange(qkv, 'b n (c h1 c1) -> b n c h1 c1', h1=self.num_heads, c1=C//self.num_heads)
        print(f'qkv reshaped: {qkv.shape}')
        qkv = rearrange(qkv, 'b n c h1 c1 -> c b h1 n c1')
        print(f'qkv reshaped and permuted: {qkv.shape}')
        q, k, v = qkv[0], qkv[1], qkv[2]
        print(f'q: {q.shape}, k: {k.shape}, v: {v.shape}')
        output = self.attention.apply_attention(Q=q, K=k, V=v)
        return output

In [11]:
mha = MultiHeadAttention(cfg=cfg, dim=768)

In [12]:
out = mha.forward(x)

x shape; torch.Size([4, 9800, 768])
qkv: torch.Size([4, 9800, 2304])
qkv reshaped: torch.Size([4, 9800, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 4, 4, 9800, 192])
q: torch.Size([4, 4, 9800, 192]), k: torch.Size([4, 4, 9800, 192]), v: torch.Size([4, 4, 9800, 192])
Q is nan 1: False
Q is nan 2: False
m: 9800
weight index is nan: False
sin is nan: False
Q*sin is nan: False
cos is nan: False
Q*cos is nan: False
q is nan: False
k is nan: False
kv_ is nan: False
z_ is nan: False
attn output 1 is nan: False
attn output 2 is nan: False


In [13]:
print(f'Output shape: {out[0].shape}')
# print(f'Scores: {out}')

Output shape: torch.Size([4, 9800, 768])


## MLP

In [20]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        print(f'in: {in_features} / hidden: {hidden_features} / out: {out_features}')
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x 

In [21]:
mlp = MLP(in_features=768, hidden_features=4*768)

in: 768 / hidden: 3072 / out: 768


In [22]:
mlp_out = mlp.forward(out[0])

In [23]:
print(mlp_out.shape)

torch.Size([4, 9800, 768])


## ATTENTION BLOCK

In [24]:
class Block(nn.Module):
    def __init__(self, cfg, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(cfg, dim, num_heads, proj_drop, attn_drop)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [25]:
block = Block(cfg, dim=768, num_heads=4)

in: 768 / hidden: 3072 / out: 768


In [26]:
block_out = block.forward(x)

x shape; torch.Size([4, 9800, 768])
qkv: torch.Size([4, 9800, 2304])
qkv reshaped: torch.Size([4, 9800, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 4, 4, 9800, 192])
q: torch.Size([4, 4, 9800, 192]), k: torch.Size([4, 4, 9800, 192]), v: torch.Size([4, 4, 9800, 192])


In [27]:
print(block_out.shape)

torch.Size([4, 9800, 768])


In [29]:
print(torch.isnan(block_out))

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [

## MODEL

In [30]:
class Model(nn.Module):
    """
    Model class with PatchTokenization + (MuliHeadAttention + MLP) x L + MLP
    """
    def __init__(self, cfg, img_size=112, patch_size=16, in_chans=3, embed_dim=768, num_classes=97, depth=2, num_heads=4, mlp_ratio=4.,
                 proj_drop=0., attn_drop=0., norm_layer=nn.LayerNorm, num_frames=200, dropout=0., batch_size=1):
        super().__init__()
        self.depth = depth
        self.dropout = nn.Dropout(dropout)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_frames = num_frames
        self.patch_embed= PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches * self.num_frames
        
        # Positional Embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(batch_size, num_patches+1, embed_dim))
        # self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
                                       
        # Attention Blocks
        self.blocks = nn.ModuleList([
            Block(cfg, embed_dim, num_heads, mlp_ratio, proj_drop, attn_drop, act_layer=nn.GELU, norm_layer=norm_layer)
            for i in range(self.depth)])                            
        self.norm = norm_layer(embed_dim)
        
        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        x, T, W = self.patch_embed(x)
        
        # add class token
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1) # shape: (1, 1, embed) -> (batches, 1, embed)
        print(f'cls_tokens shape: {cls_tokens.shape}')
        x = torch.cat((cls_tokens, x), dim=1) # (batch, frames * patches, embed) -> (batch, frames * patches + 1, embed)
        print(f'torch cat: {x.shape}')
    
        # add positional/temporal embedding
        x = x + self.pos_embed
        print(f'x + pos_embed: {x.shape}')
    
        for block in self.blocks:
            x = block.forward(x)
        # x = rearrange(x, 'b (p f) e -> b f p e', f=self.num_frames) # (batch x frames, patches, embed) -> (batch, frames, patch, embed)
        # x = torch.mean(x, [1,2])
        x = x[:, -1]
        print(f'x shape: {x.shape}')
        x = self.head(x)
        return x               

In [31]:
model = Model(cfg)

in: 768 / hidden: 3072 / out: 768
in: 768 / hidden: 3072 / out: 768


In [32]:
params = model.named_parameters()
count = 0
for param in params:
    # print(param[0])
    count += 1
print(count)

44


In [33]:
x = torch.rand(3, 3, 200, 112, 112) # b, c, t, w, h
model_out = model(x)

x shape 1: torch.Size([600, 3, 112, 112])
x shape 2: torch.Size([600, 768, 7, 7])
cls_tokens shape: torch.Size([3, 1, 768])
torch cat: torch.Size([3, 9801, 768])
x + pos_embed: torch.Size([3, 9801, 768])
x shape; torch.Size([3, 9801, 768])
qkv: torch.Size([3, 9801, 2304])
qkv reshaped: torch.Size([3, 9801, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 3, 4, 9801, 192])
q: torch.Size([3, 4, 9801, 192]), k: torch.Size([3, 4, 9801, 192]), v: torch.Size([3, 4, 9801, 192])
x shape; torch.Size([3, 9801, 768])
qkv: torch.Size([3, 9801, 2304])
qkv reshaped: torch.Size([3, 9801, 3, 4, 192])
qkv reshaped and permuted: torch.Size([3, 3, 4, 9801, 192])
q: torch.Size([3, 4, 9801, 192]), k: torch.Size([3, 4, 9801, 192]), v: torch.Size([3, 4, 9801, 192])
x shape: torch.Size([3, 768])


In [32]:
print(model_out.size())

torch.Size([3, 97])


In [35]:
print(torch.isnan(model_out))

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, 

## Positional encoding

In [5]:
from math import sin, cos, pow

def pos_embed(
    batch_size: int,
    num_patches: int,
    embed_dim: int
) -> torch.tensor:
    pos_embed = torch.zeros(num_patches, embed_dim)
    
    for i in range(num_patches):
        for j in range(embed_dim):
            if j % 2 == 0:
                p = sin(i / pow(10000, ((2 * i) / embed_dim)))
            else:
                p = cos(i / pow(10000, ((2 * i) / embed_dim)))
            pos_embed[i][j] = p
    pos_embed = pos_embed.unsqueeze(0)
    
    return pos_embed        

In [2]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 9801):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_length, embedding_dim]`` orignally [seq, batch]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [47]:
div_term = torch.exp(torch.arange(1, 768, 2) * (-math.log(10000.0) / 768))
pe = torch.zeros(1, 9801, 768)
position = torch.arange(9801).unsqueeze(1)

In [48]:
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
pe.size()

torch.Size([1, 9801, 768])

In [49]:
x = torch.zeros(2, 9801, 768)
a = x + pe
a.size()

torch.Size([2, 9801, 768])

In [23]:
x.size()

torch.Size([1, 9800, 768])

In [25]:
cls_token = nn.Parameter(torch.zeros(1, 1, 768))

In [26]:
cls_tokens = cls_token.expand(x.size(0), -1, -1)
cls_tokens.size()

torch.Size([1, 1, 768])

In [30]:
x = torch.cat((cls_tokens, x), dim=1)

In [31]:
pos_enc = PositionalEncoding(d_model=768, max_len=9801)
x = pos_enc.forward(x)
x.size()

torch.Size([1, 9801, 768])

## RANDOM TESTS

In [36]:
print(omegaconf.OmegaConf.to_yaml(cfg))

model:
  ATTENTION: fastformer
  use_rotary_emb: false



In [43]:
from PIL import Image
import numpy as np
a = Image.open('sample_image.jpg')

In [45]:
print(np.asarray(a).shape)

(200, 150, 3)


In [49]:
arr = np.asarray(a)

In [46]:
import torchvision.transforms as T

In [59]:
b = T.ToTensor()

In [60]:
c = b(a)

In [62]:
print(c.shape)

torch.Size([3, 200, 150])


In [7]:
start = datetime.now()
print(start)

2023-04-19 17:15:47.901014
