In [1]:
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image
import mae

from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from cka import CudaCKA
from collections import defaultdict

# Run on the first GPU when available; the data batches and the CKA helper use this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Model init and (presumably) MAE's random token shuffling draw from the torch RNG;
# seed it so the CKA numbers below are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
device

device(type='cpu')

In [7]:
# ImageNet channel statistics: used to normalize inputs here and to un-normalize them for display later.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

# ImageNet-100 validation split. Set the IN100_VAL_DIR environment variable to point
# elsewhere; the default is the original shared location.
IN100_VAL_DIR = os.environ.get("IN100_VAL_DIR", "/shared/sets/datasets/vision/IN-100/val/")

ds = ImageFolder(
    IN100_VAL_DIR,
    transform=T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std)
    ])
)
# No shuffling (DataLoader default), so the first batch is the same on every run.
dl = torch.utils.data.DataLoader(ds, batch_size=5, num_workers=4)

In [13]:
# Small 1-D example for checking argsort / inverse-permutation behaviour.
n = torch.as_tensor([0.5, 0.3, 0.7, 0.1, 0.2])


In [16]:
# argsort(n) gives the sorting order; argsort of that order is each element's rank
# (the inverse permutation); a third argsort inverts again, returning the order.
order = torch.argsort(n)
ranks = torch.argsort(order)
order, ranks, torch.argsort(ranks)

(tensor([3, 4, 1, 0, 2]), tensor([3, 2, 4, 0, 1]), tensor([3, 4, 1, 0, 2]))

In [8]:
# CKA helper from the local `cka` module, constructed with the same device as the data batches.
cka = CudaCKA(device)

In [9]:
# Build the MAE ViT-B/16, place it on the same device as the input batches (x.to(device)
# below) and switch to eval mode so forward passes are deterministic.
# NOTE(review): no checkpoint is loaded anywhere in this notebook, so these are the weights
# mae_vit_base_patch16() initialises — TODO confirm whether pretrained weights are intended.
vit_mae = mae.mae_vit_base_patch16().to(device).eval()

In [None]:
# Linear CKA between the token representations of an unmasked reference pass and a
# masked pass, for every encoder block and mask ratio. Exploratory: only the first
# batch of `dl` is used (see the `break`).
block_ratio_to_cka = defaultdict(list)

# Deterministic forward passes (no-op for modules without dropout / drop-path).
vit_mae.eval()

with torch.no_grad():  # analysis only: no autograd graph needed, saves memory
    for (x, y) in dl:
        x = x.to(device)

        # Reference pass with mask_ratio=0: all patches are kept, but in shuffled order.
        tokens, mask, ids, x_blocks = vit_mae.forward_encoder(x, 0)
        x_blocks_no_cls = x_blocks[:, :, 1:, :]  # drop the cls token at position 0

        # Undo the shuffle so the reference tokens are in original patch order.
        # NOTE(review): assumes `ids` is the per-sample restore permutation (inverse of
        # the shuffle) returned by forward_encoder — TODO confirm in mae.py.
        x_blocks_ordered = x_blocks_no_cls.gather(
            dim=2,
            index=ids.unsqueeze(0).unsqueeze(-1).repeat(
                x_blocks_no_cls.shape[0], 1, 1, x_blocks_no_cls.shape[3]
            ),
        )

        # Sanity check: with mask_ratio 0.0 both passes see the same tokens in the same
        # order after re-alignment, so CKA should come out at ~1.0 for every block.
        for mask_ratio in [0.0, ]: #0.1, 0.3,]: # 0.5, 0.7, 0.9]:
            _, m_mask, m_ids, m_x_blocks = vit_mae.forward_encoder(x, mask_ratio)

            # Assuming m_ids is the restore permutation (like `ids`), its argsort is the
            # shuffle permutation: position j of the masked pass holds original patch
            # m_ids_shuffle[:, j]. Gathering the ordered reference with it (not with the
            # restore ids) puts both token sequences in the same order.
            m_ids_shuffle = torch.argsort(m_ids, dim=1)
            x_blocks_ordered_for_m = x_blocks_ordered.gather(
                dim=2,
                index=m_ids_shuffle.unsqueeze(0).unsqueeze(-1).repeat(
                    x_blocks_ordered.shape[0], 1, 1, x_blocks_ordered.shape[3]
                ),
            )

            # m_x_blocks includes the cls token, so only shape[2] - 1 patch tokens were kept.
            n_kept = m_x_blocks.shape[2] - 1
            x_blocks_ordered_for_m = x_blocks_ordered_for_m[:, :, :n_kept, :]
            # Re-add the reference pass's cls token in front.
            x_blocks_ordered_for_m = torch.cat([x_blocks[:, :, :1, :], x_blocks_ordered_for_m], dim=2)

            n_blocks, bs, nt, ts = x_blocks_ordered_for_m.shape
            assert m_x_blocks.shape == x_blocks_ordered_for_m.shape

            for block_id in range(len(m_x_blocks)):
                # Flatten (batch, tokens) into one sample axis: every token is a CKA sample.
                orig_tokens = x_blocks_ordered_for_m[block_id].reshape((bs * nt, ts))
                m_tokens = m_x_blocks[block_id].reshape((bs * nt, ts))
                block_ratio_to_cka[(block_id, mask_ratio, "linear")].append(
                    cka.linear_CKA(orig_tokens, m_tokens)
                )

        break

x.shape, x_blocks.shape, x_blocks_no_cls.shape, x_blocks_ordered.shape, x_blocks_ordered_for_m.shape, m_x_blocks.shape, block_ratio_to_cka

In [28]:
# Scratch: 70% of the 196 patch positions (a 14x14 grid for 224px input at patch size 16).
196 * 0.7

137.2

In [27]:
# Size of the masked pass's per-block token tensor.
m_x_blocks.size()

torch.Size([13, 5, 138, 768])

In [31]:
# Size of the re-aligned reference tokens; must equal m_x_blocks' size.
x_blocks_ordered_for_m.size()

torch.Size([13, 5, 138, 768])

In [21]:
# argsort works along the last dimension by default, i.e. independently per row.
torch.tensor([[2, 4, 3, 5, 1], [9, 4, 2, 10, 3]]).argsort(dim=-1)

tensor([[4, 0, 2, 1, 3],
        [2, 4, 1, 0, 3]])

In [17]:
# Per-sample permutation indices from the masked pass.
m_ids.size()

torch.Size([5, 196])

In [None]:
# Reference-pass tokens without the cls token.
x_blocks_no_cls.size()

In [None]:
# Shape of a token-gather index: ids repeated along a new last axis of size equal to the
# feature dimension. For the 4-D (blocks, batch, tokens, features) x_blocks_no_cls that is
# the LAST axis (shape[-1]); shape[2] is the token count, not the feature size.
ids.unsqueeze(-1).repeat(1, 1, x_blocks_no_cls.shape[-1]).shape

In [None]:
# Permutation indices from the most recent forward_encoder call.
ids.size()

In [None]:
# Toy per-sample token permutations: batch of 2, 3 tokens each.
ids_tryout = torch.tensor([[2, 0, 1], [1, 2, 0]], dtype=torch.long)

In [None]:
# Toy tensor laid out as (batch=2, n_tokens=3, token_dim=4), values 0..23.
batch, n_tokens, token_dim = 2, 3, 4
x_tryout = torch.arange(batch * n_tokens * token_dim).view(batch, n_tokens, token_dim)
x_tryout

In [None]:
# Display the toy permutation indices.
ids_tryout

In [None]:
# (batch, n_tokens, token_dim) of the toy tensor.
x_tryout.size()

In [None]:
# ids_tryout.repeat([1,1,1,4])

In [None]:
# Reorder tokens per sample: out[b, t, :] = x_tryout[b, ids_tryout[b, t], :].
# gather needs an index of the same rank as the input, so broadcast ids across the feature axis.
token_index = ids_tryout.unsqueeze(-1).expand(-1, -1, x_tryout.size(2))
torch.gather(x_tryout, 1, token_index)

In [None]:
# Advanced-indexing equivalent of the gather above. The previous index pair
# ([0, 1], [[2, 0, 1]]) has shapes (2,) and (1, 3), which cannot broadcast (IndexError).
# A (batch, 1) row index broadcasts against ids_tryout's (batch, n_tokens) instead.
batch_idx = torch.arange(x_tryout.shape[0]).unsqueeze(1)
x_tryout[batch_idx, ids_tryout]

In [None]:
# ids with a leading singleton (block) axis, as used when building the 4-D gather index.
ids[None].size()

In [None]:
# Re-check the permutation index size.
ids.size()

In [None]:
# x_blocks is indexed as a 4-D tensor elsewhere, and torch.stack rejects a bare Tensor
# (it needs a sequence of tensors). list() splits along dim 0, so this works whether
# forward_encoder returns a stacked tensor or a list of per-block tensors.
torch.stack(list(x_blocks)).shape

In [None]:
# Encoder pass with a 0.25 mask ratio. NOTE: this rebinds tokens/mask/ids/x_blocks,
# overwriting the values left behind by the CKA loop cell.
tokens, mask, ids, x_blocks = vit_mae.forward_encoder(x, 0.25)
x.size(), ids.size(), tokens.size()

In [None]:
# Decode the visible tokens from the masked encoder pass above.
# NOTE(review): assumes forward_decoder(x, ids_restore) as in the reference MAE, and that
# `ids` (3rd output of forward_encoder) is that restore permutation — TODO confirm in mae.py.
pred = vit_mae.forward_decoder(tokens, ids)

In [None]:
# Scratch: apply the model's patch_embed module to the batch to inspect its output.
pe = vit_mae.patch_embed(x)



In [None]:
# Run patchify on the batch and inspect the resulting size.
p = vit_mae.patchify(x)
p.size()

In [None]:
# Inverse of the cell above: unpatchify p; compare the size with x's to check the round trip.
up = vit_mae.unpatchify(p)
up.size()

In [None]:
# Scratch: a random per-row permutation (argsort of uniform noise) and its inverse.
N, L = 16, 196
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
# argsort of a permutation is its inverse, so gathering ids_shuffle with ids_restore
# must give 0..L-1 in every row.
ids_restore = torch.argsort(ids_shuffle, dim=1)
assert torch.equal(
    torch.gather(ids_shuffle, 1, ids_restore),
    torch.arange(L, device=x.device).expand(N, L),
)

# Show both (previously only the last expression, ids_restore, was displayed).
ids_shuffle, ids_restore

In [None]:
# Show the first image of the patchify -> unpatchify round trip.
# Undo the ImageNet normalization applied by the dataset transform, then clip to [0, 1]:
# imshow expects floats in that range and would otherwise clip them with a warning.
# .detach().cpu() lets this work for tensors on a GPU or carrying autograd history.
img_hwc = up[0].detach().cpu().permute(1, 2, 0).numpy() * imagenet_std + imagenet_mean
plt.imshow(np.clip(img_hwc, 0.0, 1.0));

In [None]:
# Re-create the model (this replaces the earlier instance with freshly initialised weights).
# Keep it on the data's device and in eval mode, consistent with how batches are prepared.
# NOTE(review): no checkpoint is loaded in this notebook — TODO confirm whether pretrained
# weights are intended.
vit_mae = mae.mae_vit_base_patch16().to(device).eval()

In [None]:
vi