# Inspected code from SViTE/vision_transformer.py around line 291
# Note: vit.gumbel is a learned Linear layer that projects each token embedding to a single scalar score.

In [5]:
import torch
import torch.nn
import torch.nn.functional as F
import torchvision
from SViTE import vision_transformer
from SViTE.vision_transformer import gumbel_softmax

In [6]:
### Copied + Modified from SViTE/vision_transformer.forward_features(...)
def vit_debug_forward_features(vit, x, tau=-1, number=197):
    """Run a ViT feature pass while printing tensor shapes along the way.

    Debug copy of ``forward_features`` from SViTE/vision_transformer.py; the
    prints expose how the token mask is built and applied.

    Args:
        vit: Model exposing ``patch_embed``, ``cls_token``, ``pos_embed``,
            ``pos_drop``, ``blocks`` and ``norm``. ``gumbel`` is only
            accessed when ``tau > 0``.
        x: Input batch; the first dimension is the batch size.
        tau: Temperature forwarded to ``gumbel_softmax``. Token masking is
            skipped entirely when ``tau <= 0`` (the default).
        number: Forwarded to ``gumbel_softmax`` as ``k`` -- presumably the
            number of tokens to keep; confirm against ``gumbel_softmax``.

    Returns:
        A tuple ``(features, l1_list)``: ``features`` is the first
        (class-token) position of the normalized output, and ``l1_list``
        collects the second value returned by each block, in block order.
    """
    batch_size = x.shape[0]
    x = vit.patch_embed(x)

    # Prepend one class token per sample, then add the positional embedding.
    cls_tokens = vit.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + vit.pos_embed
    x = vit.pos_drop(x)  # [Batch, token, dim]

    if tau > 0:
        emb_dim = x.shape[2]
        # One score per token: flatten the scorer output to [Batch, token].
        token_scores = vit.gumbel(x).reshape(batch_size, -1)
        token_mask = gumbel_softmax(
            F.log_softmax(token_scores, dim=-1), k=number, tau=tau, hard=True
        )
        print("Post Gumbel-Softmax token_mask.shape: {}".format(token_mask.shape))
        # Force the mask entry of position 0 (the class token) to 1 so it is
        # never suppressed.
        token_mask[:, 0] = 1.
        # expand() adds a leading axis of size emb_dim; the permute below
        # moves it last so the mask lines up with x's [Batch, token, dim]
        # (assuming gumbel_softmax returned a [Batch, token] tensor).
        token_mask = token_mask.expand(emb_dim, -1, -1)
        print("Post torch.expand(emb_dim,-1,-1) token_mask.shape: {}".format(token_mask.shape))
        token_mask = token_mask.permute(1, 2, 0)
        print("Post torch.permute(1,2,0) token_mask.shape: {}".format(token_mask.shape))

        # NOTE: this is a multiplicative mask, not token removal. x keeps its
        # shape, so every token still passes through all blocks; each token's
        # embedding is only scaled by its mask value (presumably 0 or 1 with
        # hard=True -- confirm in gumbel_softmax).
        x = x * token_mask
        print("Post elementwise mult x.shape: {}".format(x.shape))

    l1_list = []
    for blk in vit.blocks:
        x, l1 = blk(x)
        l1_list.append(l1)

    x = vit.norm(x)
    print("Post norm x.shape: {}".format(x.shape))
    return x[:, 0], l1_list

In [7]:
### Emulate the forward_features(...) layer from SViTE/vision_transformer with a dummy network and dummy variables
# Prefer the first GPU, but fall back to the CPU so this cell still runs on
# machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = vision_transformer.VisionTransformer()
model.to(device)
### (batch, channels, height, width) -- PyTorch's usual image layout; the
### dummy image is square, so the tensor itself is unaffected by the order.
input_image = torch.rand(1,3,224,224).to(device)

In [8]:
# To trace shapes with the instrumented copy, swap in this call instead:
#features, l1_list = vit_debug_forward_features(model, input_image, tau=1.0, number=100)
features, l1_list = model.forward_features(input_image, tau=1.0, number=100)
print("Post forward_features.shape: {}".format(features.shape))
# Feed the pooled features through the model's head; x holds the final output.
x = model.head(features)
print("Output shape: {}".format(x.shape))

x post patch_embed shape: torch.Size([1, 196, 768])
x post pos_drop, cls token, and pos_embed shape: torch.Size([1, 197, 768])
embedding dim size: 768
patch count: 197
patch score shape: torch.Size([1, 197])
gumbel_softmax params (tau, hard, k) : (1.0, True, 100)
y_soft.shape: torch.Size([1, 197])
index.shape: torch.Size([1, 100])
y_hard.shape: torch.Size([1, 197])
gumbel_softmax(...) return value shape: torch.Size([1, 197])
patch mask shape: torch.Size([1, 197])
patch mask shape post-expand: torch.Size([1, 197, 768])
elementwise patch mask and patch embedding multiply shape: torch.Size([1, 197, 768])
x output shape: torch.Size([1, 768])
Post forward_features.shape: torch.Size([1, 768])
Output shape: torch.Size([1, 1000])
