Is this the right way to do inference? #2

Closed
Suhail opened this issue Apr 17, 2023 · 22 comments
Labels: documentation (Improvements or additions to documentation)

@Suhail

Suhail commented Apr 17, 2023

I presume I don't need Normalize?

[screenshot: CleanShot 2023-04-17 at 12 39 54]

@Esbenthorius

Esbenthorius commented Apr 17, 2023

Not sure if it's correct, but I hope it helps.

import torch
from PIL import Image
import torchvision.transforms as T
import hubconf

dinov2_vits14 = hubconf.dinov2_vits14()

img = Image.open('meta_dog.png')

transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])

img = transform(img)[:3].unsqueeze(0)

with torch.no_grad():
    features = dinov2_vits14(img, return_patches=True)[0]

print(features.shape)
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

pca = PCA(n_components=3)
pca.fit(features)

pca_features = pca.transform(features)
pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
pca_features = pca_features * 255

plt.imshow(pca_features.reshape(16, 16, 3).astype(np.uint8))
plt.savefig('meta_dog_features.png')

In dinov2/models/vision_transformer.py line 290 add

def forward(self, *args, is_training=False, return_patches=False, **kwargs):
    ret = self.forward_features(*args, **kwargs)
    if is_training:
        return ret
    elif return_patches:
        return ret["x_norm_patchtokens"]
    else:
        return self.head(ret["x_norm_clstoken"])

input:
meta_dog

visualized features:

meta_dog_features

@patricklabatut
Contributor

@Suhail To generate features from the pretrained backbones, just use a transform similar to the standard one used for evaluating on image classification, with the typical ImageNet normalization mean and std (see what's used in the code). Also, as noted in the model card, the model can use image sizes that are multiples of the patch size.

@Suhail
Author

Suhail commented Apr 17, 2023

> @Suhail To generate features from the pretrained backbones, just use a transform similar to the standard one used for evaluating on image classification, with the typical ImageNet normalization mean and std (see what's used in the code). Also, as noted in the model card, the model can use image sizes that are multiples of the patch size.

Thanks! This is what I used:

image_transforms = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

Let me know if that's wrong though.

@jjennings955

> (quoting @Esbenthorius's code example above)

I found this helpful, but I would say instead of needing to modify the forward function, you can just do dino.forward_features(x)["x_norm_patchtokens"] yourself directly.
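
For reference, a minimal sketch of that approach (assuming the torch.hub entry point and an img tensor already prepared with the transform above):

import torch

# Load the ViT-S/14 backbone from torch.hub (downloads weights on first use).
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model.eval()

with torch.no_grad():
    out = model.forward_features(img)          # img: (1, 3, 224, 224) tensor
    patch_tokens = out["x_norm_patchtokens"]   # (1, 256, 384) for a 224x224 input, patch size 14
    cls_token = out["x_norm_clstoken"]         # (1, 384)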

@TimDarcet
Contributor

TimDarcet commented Apr 18, 2023

> (quoting the transform exchange between @patricklabatut and @Suhail above)

What you are doing is correct. What you get with the forward method is the CLS token. If you'd like the patch tokens, you can use forward_features, as noted by @jjennings955

@Suhail
Author

Suhail commented Apr 18, 2023

> (quoting the exchange above, including @TimDarcet's reply)

I think what I want is an embedding like CLIP that contains the features/understanding of the image. Is that what I'd get from forward_features?

@woctezuma

woctezuma commented Apr 18, 2023

If this is like DINO, either of the two features could be used as an image embedding.


Edit: You can see here how it is done in knn.py and log_regression.py, by simply calling model(samples).float():

features_rank = model(samples).float()

See:

dinov2/dinov2/eval/knn.py

Lines 260 to 264 in fc49f49

logger.info("Extracting features for train set...")
train_features, train_labels = extract_features(
    model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
)
logger.info(f"Train features created, shape {train_features.shape}.")

train_features, train_labels = extract_features(
    model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
)

dinov2/dinov2/eval/utils.py

Lines 114 to 122 in fc49f49

def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    metric_logger = MetricLogger(delimiter=" ")
    features, all_labels = None, None
    for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        labels_rank = labels_rank.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        features_rank = model(samples).float()

@woctezuma

woctezuma commented Apr 18, 2023

Please note that linear.py adopts a different approach.

features = self.feature_model.get_intermediate_layers(
    images, self.n_last_blocks, return_class_token=True
)

See:

n_last_blocks_list = [1, 4]
n_last_blocks = max(n_last_blocks_list)
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())

def forward(self, images):
    with torch.inference_mode():
        with self.autocast_ctx():
            features = self.feature_model.get_intermediate_layers(
                images, self.n_last_blocks, return_class_token=True
            )
    return features


It was also the case with DINO:

You could also do fancier stuff, e.g. "concatenate [CLS] token and GeM pooled patch tokens", as with DINO's copy detection.
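
For illustration, a rough sketch of that fancier option, assuming the forward_features output discussed above; the gem_pool helper is a generic GeM implementation written here, not code from this repo:

import torch

def gem_pool(patch_tokens, p=4.0, eps=1e-6):
    # Generalized mean (GeM) pooling over the patch dimension.
    # patch_tokens: (batch, num_patches, dim) -> (batch, dim)
    return patch_tokens.clamp(min=eps).pow(p).mean(dim=1).pow(1.0 / p)

with torch.no_grad():
    out = model.forward_features(img)                        # model and img as in the earlier snippets
    cls_token = out["x_norm_clstoken"]                       # (batch, dim)
    gem_patches = gem_pool(out["x_norm_patchtokens"])        # (batch, dim)
    embedding = torch.cat([cls_token, gem_patches], dim=1)   # (batch, 2 * dim)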

@patricklabatut patricklabatut added the documentation Improvements or additions to documentation label Apr 18, 2023
@patricklabatut patricklabatut self-assigned this Apr 18, 2023
@Elsword016

Elsword016 commented Apr 19, 2023

How about this??

import torch
from PIL import Image
from torchvision import transforms

img = Image.open('')  # path to your image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(img)
input_batch = input_tensor.unsqueeze(0).cuda()
with torch.no_grad():
    output = dinov2_vits14.get_intermediate_layers(input_batch)

The output is a tuple of intermediate feature maps. You can then select which features you want from the tuple and try k-means, etc.
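
As a rough sketch of that last step (the reshape assumes a 224x224 input with patch size 14, i.e. a 16x16 patch grid; the number of clusters is arbitrary):

from sklearn.cluster import KMeans

# Each element of the tuple has shape (1, num_patches, dim); take the first one.
patch_feats = output[0].squeeze(0).cpu().numpy()   # (256, 384) for ViT-S/14 at 224x224

kmeans = KMeans(n_clusters=4, n_init=10).fit(patch_feats)
segmentation = kmeans.labels_.reshape(16, 16)      # coarse cluster map over the patch grid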

@woctezuma

woctezuma commented Apr 19, 2023

Yes, get_intermediate_layers() allows different approaches. This is similar to what is done in linear.py as mentioned above.

You could also use GeM pooled patch tokens with this output, as in eval_copy_detection.py for DINO (v1).

@Suhail
Author

Suhail commented Apr 21, 2023

Sounds like this is all I need to do to get a features embedding: dino_emb = dinov2_vitg14(t_img.unsqueeze(0))
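
A minimal end-to-end sketch of that call for later readers (t_img is assumed to be a (3, 224, 224) tensor produced by a transform like the one above):

import torch

dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14').eval()

with torch.no_grad():
    dino_emb = dinov2_vitg14(t_img.unsqueeze(0))   # (1, 1536) image-level (CLS token) embedding for ViT-g/14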

@patricklabatut
Contributor

Closing as this seems resolved (and using #53 to keep track of documentation needs on feature extraction).

@aaiguy

aaiguy commented May 3, 2023

Hello, how do I train a nearest-neighbors model on embeddings extracted with the DINOv2 model from images in different class folders, and then retrieve the most similar image for a query image?
I tried the approach below using sklearn's NearestNeighbors:

import torch
from sklearn.neighbors import NearestNeighbors 
import pickle
from PIL import Image
import torchvision.transforms as T
import os 
# import hubconf
import tqdm
from tqdm import tqdm_notebook
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
print('device:',device)
# dinov2_vits14 = hubconf.dinov2_vits14()
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
dinov2_vits14.to(device)
def extract_features(filename):
    img = Image.open(filename)

    transform = T.Compose([
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    img = transform(img)[:3].unsqueeze(0)

    with torch.no_grad():
        features = dinov2_vits14(img.to('cuda'))[0]

    # move back to CPU before converting to numpy
    return features.cpu().numpy()

extensions = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def get_file_list(root_dir):
    file_list = []
    for root, directories, filenames in os.walk(root_dir):
        for filename in filenames:
            if any(ext in filename for ext in extensions):
                filepath = os.path.join(root, filename)
                if os.path.exists(filepath):
                    file_list.append(filepath)
                else:
                    print(filepath)
    return file_list

# path to your dataset
root_dir = 'image_folder' 
filenames = sorted(get_file_list(root_dir))
print('Total files :', len(filenames))
feature_list = []
for i in tqdm.tqdm(range(len(filenames))):
    feature_list.append(extract_features(filenames[i]))
pickle.dump(feature_list,open('dino-all-feature-list.pickle','wb'))
pickle.dump(filenames,open('dino-all-filenames.pickle','wb'))
neighbors = NearestNeighbors(n_neighbors=5, algorithm='brute',metric='euclidean').fit(feature_list)
# Save the model to a file
with open('dino-all-neighbors2.pkl', 'wb') as f:
    pickle.dump(neighbors, f)

With the above DINOv2-based setup I get around 70% accuracy on the test data when retrieving images of the same class. Is there a way to improve my approach and increase the accuracy?

@woctezuma

woctezuma commented May 3, 2023

First, for k-NN classification, have a look at knn.py.

Second, after a quick look at your code, I would suggest trying a different metric, e.g. cosine instead of euclidean.
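
For example, a hedged sketch of that change (feature_list here is the list of per-image embeddings built in your script):

from sklearn.neighbors import NearestNeighbors

# Same index as before, but with cosine distance instead of euclidean.
neighbors = NearestNeighbors(n_neighbors=5, algorithm='brute', metric='cosine').fit(feature_list)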

Third, I believe you should use a different image pre-processing (cf. transform in your code). Copy the one used for DINOv2.

For further questions, I would suggest creating a separate GitHub issue.

@aaiguy

aaiguy commented May 3, 2023

Hey thanks, I will look into it.
Where can I find the transform used for DINOv2?

@woctezuma

woctezuma commented May 3, 2023

> Hey thanks, I will look into it. Where can I find the transform used for DINOv2?

It is mentioned above: #2 (comment)

transforms_list = [
    transforms.Resize(resize_size, interpolation=interpolation),
    transforms.CenterCrop(crop_size),
    MaybeToTensor(),
    make_normalize_transform(mean=mean, std=std),
]

It is similar to what you did but some values may differ, e.g.:

  • resizing to 256 resolution before center-cropping at 224 resolution,

resize_size: int = 256,
interpolation=transforms.InterpolationMode.BICUBIC,
crop_size: int = 224,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,

  • normalizing with different mean and std.

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
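
Put together, an equivalent torchvision transform would look roughly like this (a sketch based on the defaults above, not code copied from the repo):

import torchvision.transforms as T

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])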

@aaiguy

aaiguy commented May 5, 2023

> (quoting @Esbenthorius's code example above)

How do I visualize features like this?
I tried the following:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA


test_img = r"image.png"

features = extract_features_new(test_img)

pca = PCA(n_components=3)
pca.fit(features)

pca_features = pca.transform(features)
pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
pca_features = pca_features * 255

plt.imshow(pca_features.reshape(16, 16, 3).astype(np.uint8))


With this I'm getting an error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[133], line 11
      8 features = extract_features_new(test_img)
     10 pca = PCA(n_components=3)
---> 11 pca.fit(features)
     13 pca_features = pca.transform(features)
     14 pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
     
     
ValueError: Expected 2D array, got 1D array instead:
array=[ 0.48167408 -2.6765716  -1.8200531  ... -2.971799    1.1348227
 -1.9918671 ].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.    

The feature shape is 1024; how would I fix this?

@woctezuma

woctezuma commented May 5, 2023

@XiaominLi1997

> @Suhail To generate features from the pretrained backbones, just use a transform similar to the standard one used for evaluating on image classification, with the typical ImageNet normalization mean and std (see what's used in the code). Also, as noted in the model card, the model can use image sizes that are multiples of the patch size.

Hi, it seems I can get a feature embedding of [1, 256, 384] for an image; after reshaping it to [1, 16, 16, 384] I can visualize the features. But how can I get a feature map with a larger resolution? I want finer information such as texture.

@purnasai

Hi @XiaominLi1997, use larger models:

  • feat_dim = 384 # vits14
  • feat_dim = 768 # vitb14
  • feat_dim = 1024 # vitl14
  • feat_dim = 1536 # vitg14

So you can use ViT-g/14, and also increase the input image size in multiples of 14, e.g. 518 px (i.e. 14-pixel patches × 37 patches per side), as in the sketch below.
Hope this helps.
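
A minimal sketch of that suggestion (the file name is a placeholder; shapes are for ViT-g/14 at a 518-pixel input, i.e. a 37x37 patch grid):

import torch
import torchvision.transforms as T
from PIL import Image

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14').cuda().eval()

transform = T.Compose([
    T.Resize(518, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(518),                     # 518 = 14 * 37
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

img = transform(Image.open('your_image.png').convert('RGB')).unsqueeze(0).cuda()

with torch.no_grad():
    patch_tokens = model.forward_features(img)["x_norm_patchtokens"]   # (1, 1369, 1536)

feature_map = patch_tokens.reshape(1, 37, 37, 1536)   # higher-resolution feature map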

@ydove0324

> T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

Why do you need to center with (0.485, 0.456, 0.406)? Is this mentioned anywhere?

@charchit7

@ydove0324 This is the standard ImageNet mean used for training. It's common practice.
