In [51]:
import os
import argparse
import math
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm
from einops import repeat, rearrange
from model import *

import torch
import numpy as np
import random

In [52]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches, forward_indexes, backward_indexes = self.shuffle(patches)
        print(patches.shape, self.cls_token.expand(-1, patches.shape[1], -1).shape)
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        print(patches.shape)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')
        logits = self.head(features[0])
        return logits


if __name__ == '__main__':
    shuffle = PatchShuffle(0.75)
    a = torch.rand(16, 2, 10)
    b, forward_indexes, backward_indexes = shuffle(a)
    print(b.shape)

    img = torch.rand(10, 3, 32, 32)
    model = MAE_ViT()
    features, backward_indexes = model(img)
    # encoder = MAE_Encoder()
    # decoder = MAE_Decoder()
    # features, backward_indexes = encoder(img)
    # print(forward_indexes.shape)
    # predicted_img, mask = decoder(features, backward_indexes)
    # print(predicted_img.shape)
    # loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
    # print(loss)

torch.Size([4, 2, 10])
torch.Size([64, 10, 192]) torch.Size([1, 10, 192])
torch.Size([65, 10, 192])


In [53]:
def setup_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [54]:
class Settings:
    def __init__(self):
        self.seed = 42
        self.batch_size = 128
        self.max_device_batch_size=256
        self.base_learning_rate=1e-3
        self.weight_decay=0.05
        self.total_epoch=1
        self.warmup_epoch=5
        self.pretrained_model_path =None

args = Settings()

batch_size = args.batch_size
load_batch_size = min(args.max_device_batch_size, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, load_batch_size, shuffle=False, num_workers=4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if args.pretrained_model_path is not None:
    model = torch.load(args.pretrained_model_path, map_location='cpu')
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'pretrain-cls'))
else:
    model = MAE_ViT()
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'scratch-cls'))
    
model = ViT_Classifier(model.encoder, num_classes=10).to(device)


Files already downloaded and verified
Files already downloaded and verified


In [55]:
x, y = next(iter(train_dataloader))
x = x.to(device)
logits = model(x)
print(x.shape)

torch.Size([128, 3, 32, 32])


In [56]:
image_size=32
patch_size=2
emb_dim=192
num_layer=12
num_head=3
mask_ratio=0.75

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes
    
pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
pos_embedding.shape

torch.Size([257, 1, 192])

In [57]:
torch.manual_seed(0)

first_conv = torch.nn.Conv2d(in_channels = 3, out_channels = 192, kernel_size = 2, stride = 2)
x_input = torch.rand(10, 3, 32, 32)


In [58]:
x = first_conv(x_input)
print(x.shape)
patches = rearrange(x, 'b c h w -> (h w) b c')
print(patches.shape)

torch.Size([10, 192, 16, 16])
torch.Size([256, 10, 192])


In [59]:
image_size=32
patch_size=2
emb_dim=192
num_layer=12
num_head=3
mask_ratio=0.75

shuffle = PatchShuffle(mask_ratio)
pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
pos_embedding.shape

torch.Size([256, 1, 192])

In [60]:
patches = patches + pos_embedding

In [61]:
patches.shape

torch.Size([256, 10, 192])

In [62]:
output, forward_indexes, backward_indexes = shuffle(patches)

In [63]:
output.shape

torch.Size([64, 10, 192])

In [64]:
patches.shape

torch.Size([256, 10, 192])

In [65]:
X = 10
Y = 4
Z = 6

x = torch.zeros(X,Y,Z)
pos = torch.arange(0,X*Z).view(X,1,Z)

# x =  torch.arange(0,X*Z*Y).view(X,Y,Z)
# pos = torch.zeros(X,1,Z)

x.shape, pos.shape
result = x + pos
X,Y,Z =  result.shape

result.shape

torch.Size([10, 4, 6])

In [66]:
B = 2
C = 3
H = 3
W = 3
x = torch.arange(0,B*C*H*W).view(B,C,H,W)
x

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]]],


        [[[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]],

         [[36, 37, 38],
          [39, 40, 41],
          [42, 43, 44]],

         [[45, 46, 47],
          [48, 49, 50],
          [51, 52, 53]]]])

In [67]:
test = rearrange(x, 'b c h w -> b (h w) c')

In [68]:
B = 2
C = 10
H = 4
W = 4

pos = torch.arange(0,B*C*H*W).view(B,C,H,W)
pos

tensor([[[[  0,   1,   2,   3],
          [  4,   5,   6,   7],
          [  8,   9,  10,  11],
          [ 12,  13,  14,  15]],

         [[ 16,  17,  18,  19],
          [ 20,  21,  22,  23],
          [ 24,  25,  26,  27],
          [ 28,  29,  30,  31]],

         [[ 32,  33,  34,  35],
          [ 36,  37,  38,  39],
          [ 40,  41,  42,  43],
          [ 44,  45,  46,  47]],

         [[ 48,  49,  50,  51],
          [ 52,  53,  54,  55],
          [ 56,  57,  58,  59],
          [ 60,  61,  62,  63]],

         [[ 64,  65,  66,  67],
          [ 68,  69,  70,  71],
          [ 72,  73,  74,  75],
          [ 76,  77,  78,  79]],

         [[ 80,  81,  82,  83],
          [ 84,  85,  86,  87],
          [ 88,  89,  90,  91],
          [ 92,  93,  94,  95]],

         [[ 96,  97,  98,  99],
          [100, 101, 102, 103],
          [104, 105, 106, 107],
          [108, 109, 110, 111]],

         [[112, 113, 114, 115],
          [116, 117, 118, 119],
          [120, 121, 122, 

In [69]:
patches = rearrange(pos, 'b c h w -> (h w) b c')
output, forward_indexes, backward_indexes = shuffle(patches)
output

tensor([[[ 14,  30,  46,  62,  78,  94, 110, 126, 142, 158],
         [166, 182, 198, 214, 230, 246, 262, 278, 294, 310]],

        [[  3,  19,  35,  51,  67,  83,  99, 115, 131, 147],
         [170, 186, 202, 218, 234, 250, 266, 282, 298, 314]],

        [[ 10,  26,  42,  58,  74,  90, 106, 122, 138, 154],
         [162, 178, 194, 210, 226, 242, 258, 274, 290, 306]],

        [[  9,  25,  41,  57,  73,  89, 105, 121, 137, 153],
         [167, 183, 199, 215, 231, 247, 263, 279, 295, 311]]])

In [70]:
cls_token = torch.nn.Parameter(torch.ones(1, 2, 10))

In [71]:
test_array = torch.cat([cls_token, output], dim=0)

patches = rearrange(test_array, 't b c -> b t c')


In [72]:
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5 
        torch.manual_seed(0)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        
        # print(B, N, 3, self.num_heads, C // self.num_heads)
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        print(3, B,  self.num_heads, N, C // self.num_heads)
        # qkv = self.qkv(x).view(3, B,  self.num_heads, N, C // self.num_heads)

        qkv = self.qkv(x)
        return qkv
        # q, k, v = qkv[0], qkv[1], qkv[2]   

        # attn = (q @ k.transpose(-2, -1)) * self.scale
        # attn = attn.softmax(dim=-1)
        # attn = self.attn_drop(attn)

        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # x = self.proj(x)
        # x = self.proj_drop(x)
        return x


attn = Attention(10,2)
attn(patches).shape

3 2 2 5 5


torch.Size([2, 5, 30])

In [73]:
# DIM 2, 5, 30


# OPTION 1
# 2, 5, 30 -> DIM 2 5 3 2 5 -> 3, 2, 2, 5, 5 ( 2, 0, 3, 1, 4 )

# OPTION 2
#  2, 5, 30 ->  DIM 3 2 2 5 5

In [74]:
A = 2
B = 5
C = 30

x = torch.arange(0,A*B*C)

x.view(2,5,3,2,5).permute(2,0,3,1,4)

tensor([[[[[  0,   1,   2,   3,   4],
           [ 30,  31,  32,  33,  34],
           [ 60,  61,  62,  63,  64],
           [ 90,  91,  92,  93,  94],
           [120, 121, 122, 123, 124]],

          [[  5,   6,   7,   8,   9],
           [ 35,  36,  37,  38,  39],
           [ 65,  66,  67,  68,  69],
           [ 95,  96,  97,  98,  99],
           [125, 126, 127, 128, 129]]],


         [[[150, 151, 152, 153, 154],
           [180, 181, 182, 183, 184],
           [210, 211, 212, 213, 214],
           [240, 241, 242, 243, 244],
           [270, 271, 272, 273, 274]],

          [[155, 156, 157, 158, 159],
           [185, 186, 187, 188, 189],
           [215, 216, 217, 218, 219],
           [245, 246, 247, 248, 249],
           [275, 276, 277, 278, 279]]]],



        [[[[ 10,  11,  12,  13,  14],
           [ 40,  41,  42,  43,  44],
           [ 70,  71,  72,  73,  74],
           [100, 101, 102, 103, 104],
           [130, 131, 132, 133, 134]],

          [[ 15,  16,  17,  18,  1

In [75]:
x.view(3,2,2,5,5)

tensor([[[[[  0,   1,   2,   3,   4],
           [  5,   6,   7,   8,   9],
           [ 10,  11,  12,  13,  14],
           [ 15,  16,  17,  18,  19],
           [ 20,  21,  22,  23,  24]],

          [[ 25,  26,  27,  28,  29],
           [ 30,  31,  32,  33,  34],
           [ 35,  36,  37,  38,  39],
           [ 40,  41,  42,  43,  44],
           [ 45,  46,  47,  48,  49]]],


         [[[ 50,  51,  52,  53,  54],
           [ 55,  56,  57,  58,  59],
           [ 60,  61,  62,  63,  64],
           [ 65,  66,  67,  68,  69],
           [ 70,  71,  72,  73,  74]],

          [[ 75,  76,  77,  78,  79],
           [ 80,  81,  82,  83,  84],
           [ 85,  86,  87,  88,  89],
           [ 90,  91,  92,  93,  94],
           [ 95,  96,  97,  98,  99]]]],



        [[[[100, 101, 102, 103, 104],
           [105, 106, 107, 108, 109],
           [110, 111, 112, 113, 114],
           [115, 116, 117, 118, 119],
           [120, 121, 122, 123, 124]],

          [[125, 126, 127, 128, 12