
Accessing last layer hidden states or embeddings for models like CrossViT, RegionViT (Extractor doesn't seem to work) #221

Closed
PrithivirajDamodaran opened this issue Jun 17, 2022 · 8 comments

Comments

@PrithivirajDamodaran

PrithivirajDamodaran commented Jun 17, 2022

How can I access the last layer hidden states aka embeddings of an image from models like CrossViT and RegionViT? The extractor option works only on vanilla ViT.

Please advise.

@lucidrains
Owner

@PrithivirajDamodaran Hi Prithivida! Let me know if 4e62e5f works now

@lucidrains
Owner

regionvit can also work if you pass in a reference to the layer whose output you would like to extract:

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

# wrap the RegionViT, hooking the last block of the final stage

from vit_pytorch.extractor import Extractor
v = Extractor(model, layer = model.layers[-1][-1])

# forward pass now returns the predictions and the output of the hooked layer

img = torch.randn(1, 3, 224, 224)
logits, embeddings = v(img)

# the hooked layer returns a tuple: (local tokens, regional tokens)

embeddings  # tuple of tensors with shapes ((1, 512, 7, 7), (1, 512, 1, 1))
```
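The Extractor captures whatever the chosen layer returns during the forward pass. The general mechanism can be sketched with a plain PyTorch forward hook. `TwoPathBlock`, `ToyTwoPathNet` and `run_with_capture` are hypothetical names used only for illustration; this is a sketch of the idea, not vit_pytorch's actual Extractor code.

```python
import torch
from torch import nn

class TwoPathBlock(nn.Module):
    # hypothetical stand-in for a stage that returns (local, regional) feature maps
    def __init__(self, dim):
        super().__init__()
        self.local = nn.Conv2d(dim, dim, 1)
        self.region = nn.Conv2d(dim, dim, 1)

    def forward(self, inputs):
        local, region = inputs
        return self.local(local), self.region(region)

class ToyTwoPathNet(nn.Module):
    # hypothetical toy model: one two-path block followed by a classification head
    def __init__(self, dim = 512, num_classes = 1000):
        super().__init__()
        self.block = TwoPathBlock(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, local, region):
        local, region = self.block((local, region))
        return self.head(local.mean(dim = (2, 3)))

def run_with_capture(model, layer, *inputs):
    captured = {}

    def hook(module, args, output):
        # store whatever the hooked layer returned (here a tuple of tensors)
        captured['out'] = output

    handle = layer.register_forward_hook(hook)
    try:
        logits = model(*inputs)
    finally:
        handle.remove()  # detach the hook so later forward passes are unaffected
    return logits, captured['out']

model = ToyTwoPathNet()
logits, embeddings = run_with_capture(
    model, model.block, torch.randn(1, 512, 7, 7), torch.randn(1, 512, 1, 1)
)
print(logits.shape)                   # torch.Size([1, 1000])
print([e.shape for e in embeddings])  # [torch.Size([1, 512, 7, 7]), torch.Size([1, 512, 1, 1])]
```

Removing the hook in a `finally` block keeps the model clean even if the forward pass raises.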

@PrithivirajDamodaran
Author

Thank you, will check and close. Big fan of your work.

@PrithivirajDamodaran
Author

Works fine! So, just to be sure, for a single image the tuple below is:

(1, 512, 7, 7) - last layer embedding
(1, 512, 1, 1) - CLS embedding

Is that the right understanding?

@lucidrains
Owner

@PrithivirajDamodaran so RegionViT is a bit different from the conventional neural net in that it keeps two separate information paths and has them cross-attend to each other, iirc

what you are seeing are the outputs of those two separate paths: one is the normal network output, the other is the "regional" tokens
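To make the two outputs concrete: with a 224×224 input, the shapes above mean the local path ends as a 7×7 grid (49 tokens) and the regional path as a 1×1 grid (a single token). The minimal sketch below flattens each grid into a token sequence. Random tensors with the shapes reported in this thread stand in for real Extractor output, and `to_tokens` is a hypothetical helper.

```python
import torch

# stand-ins for the (local, regional) tuple returned by the Extractor
local_map = torch.randn(1, 512, 7, 7)
region_map = torch.randn(1, 512, 1, 1)

def to_tokens(feature_map):
    # (batch, channels, h, w) -> (batch, h * w, channels), row-major over the grid
    return feature_map.flatten(2).transpose(1, 2)

local_tokens = to_tokens(local_map)    # (1, 49, 512): one token per cell of the 7x7 grid
region_tokens = to_tokens(region_map)  # (1, 1, 512): a single regional token
```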

@lucidrains
Owner

lucidrains commented Jun 25, 2022

@PrithivirajDamodaran if you are doing anything downstream i would concat those two together for a 1024-dimensional embedding

```python
import torch
from einops import reduce
fine_embed, region_embed = embeddings  # (local, regional) outputs from the Extractor
embedding = torch.cat((reduce(fine_embed, 'b c h w -> b c', 'mean'), reduce(region_embed, 'b c h w -> b c', 'mean')), dim = -1)
```
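The einops pattern `'b c h w -> b c'` with `'mean'` averages over the spatial grid, so the same 1024-dimensional embedding can be sketched in plain PyTorch with `Tensor.mean(dim = (2, 3))`. Random tensors with the shapes from this thread stand in for `fine_embed` and `region_embed`.

```python
import torch

fine_embed = torch.randn(1, 512, 7, 7)    # stand-in for the local-path output
region_embed = torch.randn(1, 512, 1, 1)  # stand-in for the regional-path output

# average over the spatial grid (h, w), leaving (batch, channels)
fine_vec = fine_embed.mean(dim = (2, 3))
region_vec = region_embed.mean(dim = (2, 3))

# concatenate along channels: 512 + 512 = 1024 features per image
embedding = torch.cat((fine_vec, region_vec), dim = -1)
print(embedding.shape)  # torch.Size([1, 1024])
```

Because the regional map is already 1×1, averaging it simply drops the spatial axes.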

@mathshangw

Excuse me, what if I need to remove the final classification layer to get the features before classification?

@mathshangw

Is there any help, please?
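This question has no answer in the thread. One generic PyTorch approach, offered here as an assumption rather than a documented vit_pytorch feature, is to replace the final classification layer with `nn.Identity()` so the forward pass returns the features that would have been fed to the classifier. `ToyClassifier` and its `head` attribute are hypothetical; RegionViT and CrossViT name their heads differently, so inspect `print(model)` to find the right submodule first.

```python
import torch
from torch import nn

class ToyClassifier(nn.Module):
    # hypothetical model: a feature extractor followed by a linear classification head
    def __init__(self, dim = 512, num_classes = 1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size = 32, stride = 32),  # 224x224 -> 7x7 grid
            nn.AdaptiveAvgPool2d(1),                           # -> (batch, dim, 1, 1)
            nn.Flatten(),                                      # -> (batch, dim)
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, img):
        return self.head(self.features(img))

model = ToyClassifier()
model.head = nn.Identity()  # drop the classifier; forward now returns the pooled features

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 512])
```

If the head is a multi-layer `nn.Sequential`, replacing only its last `nn.Linear` keeps any preceding normalization in place.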
