How can I extract features from vision? #72

Closed
mathshangw opened this issue Jun 13, 2021 · 16 comments

Comments

@mathshangw

I need to extract features from a vision transformer. How can I start?

@woctezuma

woctezuma commented Jun 13, 2021

Start with one of these:

  • https://github.com/facebookresearch/dino/blob/main/eval_linear.py

    dino/eval_linear.py

    Lines 44 to 49 in ba9edd1

    val_transform = pth_transforms.Compose([
    pth_transforms.Resize(256, interpolation=3),
    pth_transforms.CenterCrop(224),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    dino/eval_linear.py

    Lines 201 to 205 in ba9edd1

    intermediate_output = model.get_intermediate_layers(inp, n)
    output = [x[:, 0] for x in intermediate_output]
    if avgpool:
        output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
    output = torch.cat(output, dim=-1)
  • https://github.com/facebookresearch/dino/blob/main/eval_knn.py

    dino/eval_knn.py

    Lines 32 to 37 in ba9edd1

    transform = pth_transforms.Compose([
    pth_transforms.Resize(256, interpolation=3),
    pth_transforms.CenterCrop(224),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    dino/eval_knn.py

    Lines 101 to 104 in ba9edd1

    if multiscale:
        feats = utils.multi_scale(samples, model)
    else:
        feats = model(samples).clone()

    where

    dino/utils.py

    Lines 795 to 809 in ba9edd1

    def multi_scale(samples, model):
        v = None
        for s in [1, 1/2**(1/2), 1/2]:  # we use 3 different scales
            if s == 1:
                inp = samples.clone()
            else:
                inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
            feats = model(inp).clone()
            if v is None:
                v = feats
            else:
                v += feats
        v /= 3
        v /= v.norm()
        return v

    NB: multi-scale is set to False for KNN. The code is generic so that multi-scale can be set to True for image-retrieval.
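The two excerpts above (eval_linear.py and utils.multi_scale) can be sketched in a self-contained way with random data. This is an illustrative approximation, not DINO's code: `TinyBackbone` is a hypothetical stand-in for a real backbone that maps each image to one vector, and `intermediate_output` is assumed to be a list of token tensors of shape (batch, 1 + num_patches, dim), like the output of `get_intermediate_layers`. One deliberate difference: `utils.multi_scale` divides by the norm of the whole batch tensor (`v.norm()` has no `dim`), whereas this sketch L2-normalizes each image's feature separately.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_eval_features(intermediate_output, avgpool=False):
    """Concatenate the [CLS] token of each block output, as in the eval_linear.py excerpt.

    intermediate_output: list of tensors of shape (batch, 1 + num_patches, dim).
    """
    output = [x[:, 0] for x in intermediate_output]  # [CLS] token of each block
    if avgpool:
        # Average of the patch tokens of the last block
        output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
    return torch.cat(output, dim=-1)


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a backbone that returns one vector per image."""

    def __init__(self, dim=8):
        super().__init__()
        self.conv = nn.Conv2d(3, dim, kernel_size=3, padding=1)

    def forward(self, x):
        # Global average pooling makes the output size independent of the image size
        return self.conv(x).mean(dim=(2, 3))


@torch.no_grad()
def multi_scale_features(model, samples, scales=(1, 1 / 2 ** 0.5, 1 / 2)):
    """Average features over several input scales, then L2-normalize each image."""
    total = None
    for s in scales:
        if s == 1:
            inp = samples
        else:
            inp = F.interpolate(samples, scale_factor=s, mode="bilinear", align_corners=False)
        feats = model(inp)
        total = feats if total is None else total + feats
    total = total / len(scales)
    return F.normalize(total, dim=-1)


# Usage with random data: e.g. the outputs of the last 4 blocks of a ViT-S/16
tokens = [torch.randn(2, 1 + 196, 384) for _ in range(4)]
print(linear_eval_features(tokens, avgpool=True).shape)  # torch.Size([2, 1920])

model = TinyBackbone().eval()
feats = multi_scale_features(model, torch.randn(2, 3, 64, 64))
```

With `avgpool=True`, the feature size is (n + 1) × dim, here 5 × 384 = 1920.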

@mathildecaron31
Contributor

Hi @mathshangw

You can take a look at

dino/eval_knn.py

Lines 94 to 138 in ba9edd1

@torch.no_grad()
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
    metric_logger = utils.MetricLogger(delimiter=" ")
    features = None
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        if multiscale:
            feats = utils.multi_scale(samples, model)
        else:
            feats = model(samples).clone()
        # init storage feature matrix
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if use_cuda:
                features = features.cuda(non_blocking=True)
            print(f"Storing features into tensor of shape {features.shape}")
        # get indexes from all processes
        y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)
        # share features between processes
        feats_all = torch.empty(
            dist.get_world_size(),
            feats.size(0),
            feats.size(1),
            dtype=feats.dtype,
            device=feats.device,
        )
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()
        # update storage feature matrix
        if dist.get_rank() == 0:
            if use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features
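The `extract_features` function above assumes a `torch.distributed` setup (`dist.get_rank`, `all_gather`). For a single process on one machine, the same idea reduces to a plain loop that writes each batch into a preallocated matrix at the dataset indices. The sketch below is a hypothetical simplification, not code from the repository: `IndexedDataset` and `extract_features_single` are illustrative names, and the loader is assumed to yield `(sample, index)` pairs as DINO's loader does.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper that yields (sample, index) pairs."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], idx


@torch.no_grad()
def extract_features_single(model, data_loader, device="cpu"):
    """Single-process variant: no all_gather, rows are written by dataset index."""
    model.eval()
    features = None
    for samples, index in data_loader:
        feats = model(samples.to(device)).cpu()
        if features is None:
            # Allocate the (num_images, feature_dim) storage matrix once
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
        features.index_copy_(0, index, feats)
    return features


# Usage with a toy model and random "images"; shuffling shows why indices matter
images = torch.randn(10, 3, 8, 8)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
loader = DataLoader(IndexedDataset(images), batch_size=4, shuffle=True)
features = extract_features_single(model, loader)
print(features.shape)  # torch.Size([10, 16])
```

Writing by index keeps row `i` of `features` aligned with image `i` of the dataset even when the loader shuffles, which is what the gathered `index_all` achieves in the distributed version.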

Let me know if you have any questions.

@woctezuma

woctezuma commented Jul 16, 2021

Also:

  • https://github.com/facebookresearch/dino/blob/main/eval_copy_detection.py

    dino/eval_copy_detection.py

    Lines 154 to 158 in ba9edd1

    transform = pth_transforms.Compose([
    pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    dino/eval_copy_detection.py

    Lines 166 to 175 in ba9edd1

    feats = model.get_intermediate_layers(samples, n=1)[0].clone()
    cls_output_token = feats[:, 0, :] # [CLS] token
    # GeM with exponent 4 for output patch tokens
    b, h, w, d = len(samples), int(samples.shape[-2] / model.patch_embed.patch_size), int(samples.shape[-1] / model.patch_embed.patch_size), feats.shape[-1]
    feats = feats[:, 1:, :].reshape(b, h, w, d)
    feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
    feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)
    # concatenate [CLS] token and GeM pooled patch tokens
    feats = torch.cat((cls_output_token, feats), dim=1)
  • https://github.com/facebookresearch/dino/blob/main/eval_image_retrieval.py
    transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    from eval_knn import extract_features

    NB: multi-scale can be set to True for image-retrieval, but is set to False by default.
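The eval_copy_detection.py excerpt above combines the [CLS] token with GeM (generalized mean) pooling of the patch tokens, using exponent p = 4. Here is a standalone sketch of that step on random tokens. It is an illustration, not the repository's function: `cls_plus_gem` is a hypothetical name, and `h` and `w` (the patch grid size) are passed in directly instead of being computed from the image size and `model.patch_embed.patch_size`.

```python
import torch
import torch.nn.functional as F


def cls_plus_gem(tokens, h, w, p=4.0, eps=1e-6):
    """[CLS] token concatenated with GeM-pooled patch tokens.

    tokens: (batch, 1 + h * w, dim) output of the last block.
    """
    b, _, d = tokens.shape
    cls_token = tokens[:, 0, :]
    patches = tokens[:, 1:, :].reshape(b, h, w, d)
    # Clamp so that pow(1 / p) is well defined, then move to (b, d, h, w)
    patches = patches.clamp(min=eps).permute(0, 3, 1, 2)
    # GeM: (mean of x^p over the grid)^(1/p); p = 1 is average pooling
    gem = F.avg_pool2d(patches.pow(p), (h, w)).pow(1.0 / p).reshape(b, -1)
    return torch.cat((cls_token, gem), dim=1)


# A 224x224 image with 16x16 patches gives a 14x14 grid of patch tokens
tokens = torch.rand(2, 1 + 14 * 14, 384)
feats = cls_plus_gem(tokens, 14, 14)
print(feats.shape)  # torch.Size([2, 768])
```

Larger values of `p` push the pooled value towards the maximum activation over the grid, so GeM sits between average pooling (p = 1) and max pooling.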

@mathildecaron31
Contributor

@mathshangw have you been able to extract features as you wanted? If so, I'll let you close the issue. Otherwise, feel free to ask more questions and I will try to help.

@mathshangw
Author

Sorry for the late reply. I tried the code, but are these the features of the images before classification? I mean, if I need to remove the output (classification) layer to get the features, will the result be the same?

@woctezuma

woctezuma commented Oct 3, 2021


Yes, in the case of Facebook's DINO (🤗's documentation), unlike Microsoft's BEiT (🤗's documentation).

As I understand it, Mathilde froze the features and added a linear classification layer on top.
This is visible in the way the weights of the linear classifier are distributed separately from the rest of the network.
This is also suggested by the good results obtained with k-NN classification.
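The frozen-features setup described here (a trainable linear layer on top of a backbone whose weights never change) can be sketched with toy modules. This is a hypothetical illustration of the idea, not DINO's eval_linear.py: the backbone is a small stand-in network, and the data, learning rate, and number of steps are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical frozen backbone (stand-in for DINO) and a trainable linear head
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
for param in backbone.parameters():
    param.requires_grad = False  # the features stay frozen
backbone.eval()

num_classes = 5
classifier = nn.Linear(32, num_classes)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)

images = torch.randn(16, 3, 8, 8)
labels = torch.randint(0, num_classes, (16,))

frozen_before = [p.clone() for p in backbone.parameters()]
head_before = classifier.weight.detach().clone()

for _ in range(10):
    with torch.no_grad():
        feats = backbone(images)  # no graph is built through the frozen backbone
    loss = nn.functional.cross_entropy(classifier(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Only the linear layer's weights change, which is why they can be released separately from the pre-trained backbone.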

However, if you look at other works, e.g. BEiT:

  1. the architecture is slightly different (in the layer normalization) for classification compared to masked image modelling,
  2. the whole network is fine-tuned, so all of the weights have to be released again!

I think it is better to look at the performance of the network with frozen features, because fine-tuning hides the effect of the pre-training. I hope Mathilde keeps this approach in the future, or at least offers both perspectives, and I wish others would as well.


In a nutshell, if you want to use DINO, you can use the official implementation without worrying about a classification layer.

Or you can use 🤗's implementation and just extract the [CLS] token as your feature. I have checked in a GitHub Gist that you should get similar results with both methods. The only thing you have to be wary of is the pre-processing, which 🤗 modified without giving a reason.

@mathshangw
Copy link
Author

mathshangw commented Oct 3, 2021

Thanks a lot for replying. Excuse me, is this right?

PATH = 'dino_resnet50_pretrain.pth'
model = dino_resnet50(pretrained=True)
model.load_state_dict(torch.load(PATH),strict=True)

Then can I use the model on my images to extract the features?

@woctezuma

It is possible that your code is equivalent, but I would follow the README and use this:

import torch
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

@mathshangw
Copy link
Author

mathshangw commented Oct 3, 2021

I tried your code but got:
ImportError: cannot import name 'ViTFeatureExtractor'
Is this the link for ViTFeatureExtractor?
https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/feature_extraction_vit.py

@woctezuma

That is because you would need to install transformers first.

@mathshangw
Copy link
Author

mathshangw commented Oct 3, 2021


I installed it with pip3 install transformers, but that didn't solve it.

@woctezuma

woctezuma commented Oct 3, 2021

If you want a minimal example without HuggingFace:

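import requests
import torch
from PIL import Image
from torchvision import transforms as pth_transforms

def get_image(url):
    # Download an image and make sure it has 3 channels
    return Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Same pre-processing as in DINO's evaluation scripts (interpolation=3 is bicubic)
preprocess = pth_transforms.Compose([
    pth_transforms.Resize(256, interpolation=pth_transforms.InterpolationMode.BICUBIC),
    pth_transforms.CenterCrop(224),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

@torch.no_grad()
def get_features(model, image):
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    return model(preprocess(image).unsqueeze(0))

resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
resnet50.eval()  # use the stored BatchNorm statistics instead of batch statistics

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
features = get_features(resnet50, get_image(url))  # tensor of shape (1, 2048)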

@mathshangw
Copy link
Author

Thanks a lot, but excuse me, where can I find the details of the preprocess(image) method?

@woctezuma

woctezuma commented Oct 3, 2021

The details are in my minimal code snippet above: resize, center-crop, normalization. This is a very simple pre-processing.

@mathshangw
Copy link
Author


Oh, I hadn't noticed. Thanks a lot, I appreciate your help.

@fbliman

fbliman commented Mar 6, 2024

Hi, I am reopening this topic because I am a bit lost as to which features are best to extract for comparing images (looking for similar images, independent of viewpoint).

import timm
import torch

def load_dino_vit_model(weights_path):
    # Load a pre-trained DINO ViT model
    # Specify the appropriate model name and path as needed
    model_name = 'vit_small_patch16_224'  # Example model name, adjust based on actual use
    model = timm.create_model(model_name, pretrained=False, num_classes=0)  # num_classes=0 for feature extraction
    checkpoint = torch.load(weights_path, map_location='cpu')

    # Extract the 'teacher' state dictionary and remove the 'backbone.' prefix from each key
    state_dict = checkpoint['teacher']
    adapted_state_dict = {key.replace('backbone.', ''): value for key, value in state_dict.items()}

    # strict=False silently skips mismatched keys, so print the missing and unexpected ones
    msg = model.load_state_dict(adapted_state_dict, strict=False)
    print(msg)

    #model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
    model.eval()  # Set the model to evaluation mode
    if torch.cuda.is_available():
        model.cuda()
    return model
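When loading a DINO training checkpoint into a bare backbone, it can help to clean the key names more thoroughly than the single `replace` above. The helper below is a hypothetical sketch, not part of DINO or timm. It strips the `module.` (DistributedDataParallel) and `backbone.` prefixes, as DINO's own weight-loading utility does, and, as an assumption of this sketch, drops the `head.*` keys of the projection head used only during training, since a `num_classes=0` model has no matching parameters.

```python
def adapt_dino_teacher_keys(state_dict):
    """Rename DINO checkpoint keys to match a bare backbone (hypothetical helper).

    Strips the 'module.' and 'backbone.' prefixes and drops the projection-head
    weights ('head.*'), which are only used during self-supervised training.
    """
    adapted = {}
    for key, value in state_dict.items():
        key = key.replace("module.", "").replace("backbone.", "")
        if key.startswith("head."):
            continue  # DINO projection head, not part of the backbone
        adapted[key] = value
    return adapted


# Usage with placeholder values instead of tensors
example = {
    "backbone.cls_token": 1,
    "backbone.blocks.0.attn.qkv.weight": 2,
    "head.mlp.0.weight": 3,
    "module.backbone.norm.weight": 4,
}
print(adapt_dino_teacher_keys(example))
```

The result can then be passed to `model.load_state_dict(...)`; inspecting the missing and unexpected keys it returns shows whether the names actually match the chosen timm model.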

I load this model and then just run

output = model(image)

This returns a 384-dimensional (or 768-dimensional) feature. Is this feature the [CLS] token activation, or does it come from somewhere else?

If this is the case, I think it would not be ideal, as it contains positional information, which is not the best for comparing images from different viewpoints.

Also, I see that I am not using the teacher model's MLP head, which is used during training and outputs a 60k+-dimensional vector that is compared with the student branch.

So, if I would like an image feature (with pseudo-semantic information, not positional) of roughly 2k to 3k dimensions, where would be the best place to get it from?

Thanks
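Whichever feature is chosen (the [CLS] token, pooled patch tokens, or a concatenation), comparing images for similarity usually amounts to cosine similarity between L2-normalized vectors, which is also how eval_knn.py compares features. The sketch below only assumes some feature matrices of shape (num_images, dim). `rank_by_similarity` is a hypothetical name, and the random "gallery" stands in for real DINO features.

```python
import torch
import torch.nn.functional as F


def rank_by_similarity(query, gallery, k=5):
    """Return the top-k gallery scores and indices by cosine similarity.

    query: (num_queries, dim) and gallery: (num_images, dim) feature matrices.
    """
    query = F.normalize(query, dim=1)
    gallery = F.normalize(gallery, dim=1)
    similarity = query @ gallery.T  # cosine similarities in [-1, 1]
    scores, indices = similarity.topk(k, dim=1)
    return scores, indices


torch.manual_seed(0)
gallery = torch.randn(100, 384)
# Queries: slightly perturbed copies of gallery items 3 and 42
query = gallery[[3, 42]] + 0.01 * torch.randn(2, 384)
scores, indices = rank_by_similarity(query, gallery, k=3)
print(indices[:, 0])  # the perturbed copies are matched back to items 3 and 42
```

Because the vectors are normalized first, the ranking depends only on the direction of each feature, not on its magnitude.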


4 participants