Accessing last layer hidden states or embeddings for models like CrossViT, RegionViT (Extractor doesn't seem to work) #221
Comments
@PrithivirajDamodaran Hi Prithivida! Let me know if 4e62e5f works now
RegionViT can also work if you pass in a reference to the layer whose output you would like to extract:

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3-layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use the positional generating module. they used this for object detection for a boost in performance
)

# wrap the RegionViT with the Extractor, pointing at the layer to extract from
from vit_pytorch.extractor import Extractor
v = Extractor(model, layer = model.layers[-1][-1])

# forward pass now returns both the logits and the extracted embeddings
img = torch.randn(1, 3, 224, 224)
logits, embeddings = v(img)

# two outputs: the local (fine) feature map and the regional token
embeddings  # ((1, 512, 7, 7), (1, 512, 1, 1))
```
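Conceptually, wrapping a model to capture a chosen layer's output can be done with a plain PyTorch forward hook. A minimal standalone sketch of that pattern (a generic illustration, not vit_pytorch's exact Extractor implementation; the toy model here is made up):

```python
import torch
from torch import nn

# toy model standing in for a real backbone
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def hook(module, inputs, output):
    captured['embedding'] = output  # save the chosen layer's output

# register the hook on the layer whose output we want (here, the ReLU)
handle = model[1].register_forward_hook(hook)

logits = model(torch.randn(1, 4))
# logits.shape is (1, 2); captured['embedding'].shape is (1, 8)

handle.remove()  # detach the hook when done
```

The hook fires on every forward pass, so the layer's activations are available without modifying the model itself.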
Thank you, will check and close. Big fan of your work.
Works fine! Just to be sure: for a single image, the first element of the tuple, (1, 512, 7, 7), is the last-layer embedding. Is that the right understanding?
@PrithivirajDamodaran What you are seeing is the outputs of those two separate paths: one is the normal network output, the other is the "regional" tokens.
@PrithivirajDamodaran If you are doing anything downstream, I would concat those two together for a 1024-dimensional embedding:

```python
import torch
from einops import reduce

embedding = torch.cat((reduce(fine_embed, 'b c h w -> b c', 'mean'), reduce(region_embed, 'b c h w -> b c', 'mean')), dim = -1)
```
Excuse me, what if I need to remove the model's last layer (the classification head) to get the features before classification? Is there any help, please?
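One generic PyTorch way to do this (not specific to vit_pytorch; the head attribute's name varies between models, so inspect `print(model)` first) is to replace the classification head with `nn.Identity`, so the forward pass returns the pre-classifier features. A sketch with a made-up toy model:

```python
import torch
from torch import nn

# toy model standing in for a real classifier; `mlp_head` is an
# assumed attribute name -- check your model's source for the real one
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)   # stand-in for the real backbone
        self.mlp_head = nn.Linear(16, 10)  # classification head

    def forward(self, x):
        return self.mlp_head(self.backbone(x))

model = TinyClassifier()
model.mlp_head = nn.Identity()  # drop classification, keep features

features = model(torch.randn(1, 8))
print(features.shape)  # torch.Size([1, 16]) -- pre-head features
```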
How can I access the last-layer hidden states (aka embeddings) of an image from models like CrossViT and RegionViT? The Extractor option works only on vanilla ViT.
Please advise.