In [1]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTModel, ViTConfig, CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer, models
from transformers import AutoImageProcessor, AutoModel
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
from transformers import AutoImageProcessor, DetrModel
from transformers import SegformerForSemanticSegmentation
from PIL import Image
import requests
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import SamModel, SamProcessor
from transformers import SegformerModel
from transformers import DPTImageProcessor, DPTForDepthEstimation
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import timm


import torch, random
from tqdm import tqdm
import sys
import re

# Pick the compute device. The second GPU ('cuda:1') is preferred, as before;
# fall back to the first GPU or the CPU so that later `.to(device)` calls do
# not fail on hosts that have fewer than two GPUs.
if torch.cuda.device_count() > 1:
    gpu = 'cuda:1'
elif torch.cuda.is_available():
    gpu = 'cuda:0'
else:
    gpu = 'cpu'
device = torch.device(gpu)
if device.type == 'cuda':
    # Release cached allocator blocks left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np

class FeatureExtractor:
    """Forward-hook callable that remembers the last output of a module.

    An instance is passed to ``register_forward_hook``; after each forward
    pass of the hooked module its output is available as
    ``extracted_features``.
    """

    def __init__(self):
        # Nothing captured until the hooked module has run at least once.
        self.extracted_features = None

    def __call__(self, module, input_, output):
        """Store ``output``; ``module`` and ``input_`` are ignored."""
        self.extracted_features = output

def save_embed(model_name, dataset):
    
    if model_name == "convnext":
        processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
        model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224-22k").to(device)
    elif model_name == "dinov2":
        vision_model_name = "facebook/dinov2-large"
        processor = AutoImageProcessor.from_pretrained(vision_model_name)
        model = AutoModel.from_pretrained(vision_model_name).to(device)
    elif model_name == "clip":
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    elif model_name == "allroberta":
        language_model_name = "all-roberta-large-v1"
        language_model = SentenceTransformer(language_model_name).to(device)
    elif model_name == "detr_resnet_50_encoder":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrModel.from_pretrained("facebook/detr-resnet-50").to(device)
        
        extractor = FeatureExtractor()
        model.encoder.register_forward_hook(extractor)
    elif model_name == "detr_resnet_50_decoder":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrModel.from_pretrained("facebook/detr-resnet-50").to(device)
        
        extractor = FeatureExtractor()
        model.decoder.register_forward_hook(extractor)
    elif model_name == "detr_resnet_50_backbone":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrModel.from_pretrained("facebook/detr-resnet-50").to(device)
        
        extractor = FeatureExtractor()
        model.backbone.conv_encoder.register_forward_hook(extractor)
    elif model_name == "detr_resnet_101_backbone":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-101")
        model = DetrModel.from_pretrained("facebook/detr-resnet-101").to(device)
        
        extractor = FeatureExtractor()
        model.backbone.conv_encoder.register_forward_hook(extractor)
    elif model_name == "detr_resnet_101_encoder":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-101")
        model = DetrModel.from_pretrained("facebook/detr-resnet-101").to(device)
        
        extractor = FeatureExtractor()
        model.encoder.register_forward_hook(extractor)
    elif model_name == "detr_resnet_101_decoder":
        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-101")
        model = DetrModel.from_pretrained("facebook/detr-resnet-101").to(device)
        
        extractor = FeatureExtractor()
        model.decoder.register_forward_hook(extractor)
    elif model_name == "sam":
        model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        
        extractor = FeatureExtractor()
        model.vision_encoder.layers[31].register_forward_hook(extractor)
    elif model_name == "sam_embed":
        model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        
        extractor = FeatureExtractor()
        model.shared_image_embedding.register_forward_hook(extractor)
    elif model_name == "segformer":
        processor = AutoImageProcessor.from_pretrained("nvidia/mit-b0")
        model = SegformerModel.from_pretrained("nvidia/mit-b0").to(device)

        extractor = FeatureExtractor()
        model.encoder.register_forward_hook(extractor)
    elif model_name == "segformer_segment":
        processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(device)

        extractor = FeatureExtractor()
        model.segformer.encoder.register_forward_hook(extractor)
    elif model_name == "dpt":
        processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)

        extractor = FeatureExtractor()
        model.dpt.encoder.register_forward_hook(extractor)
    elif model_name == "resnet101":
        feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
        model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101").to(device)

        extractor = FeatureExtractor()
        model.resnet.encoder.register_forward_hook(extractor)
    elif model_name == "vit":
        model = timm.create_model(
            'vit_base_patch16_384.augreg_in1k',
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        ).to(device)
        model = model.eval()
        
        # get model specific transforms (normalization, resize)
        data_config = timm.data.resolve_model_data_config(model)
        transform = timm.data.create_transform(**data_config, is_training=False)
        



    if dataset=="coco":
        cap = dset.CocoCaptions(root = '/datasets/coco_2024-01-04_1601/val2017',
                        annFile = '/datasets/coco_2024-01-04_1601/annotations/captions_val2017.json')#,
                        #transform=transforms.Compose([
                            # transforms.Resize((256,256)), 
                            # transforms.RandomResizedCrop((224,224)), 
                            #transforms.PILToTensor()]),)
                            # target_transform=select_first_k_captions)
    elif dataset=="nocaps":
        cap = dset.CocoCaptions(root = '/shared/group/openimages/validation',
                        annFile = 'nocaps_val_4500_captions.json',
                        transform=transforms.Compose([
                            #transforms.Resize((256,256)), 
                            #transforms.RandomResizedCrop((224,224)), 
                            transforms.PILToTensor()]),)
        

    if model_name == "dinov2":
        image_representations = []
    
        for img, target in tqdm(cap):          
            inputs = processor(images=img, return_tensors="pt")
            inputs = inputs.to(device)
            outputs = model(**inputs)
            image_representation = outputs.last_hidden_state.mean(dim=1).detach().cpu()[0]
            image_representations.append(image_representation)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif "resnet101" == model_name:
        image_representations = []
        for img, target in tqdm(cap):
            inputs = feature_extractor(img, return_tensors="pt")
            inputs = inputs.to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features.last_hidden_state.reshape((1, 2048, -1)).mean(dim=-1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif "detr_resnet_50_backbone" == model_name or "detr_resnet_101_backbone" == model_name:
        image_representations = []
        for img, target in tqdm(cap):
            inputs = image_processor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features[-1][0].reshape((1, 2048, -1)).mean(dim=-1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif "detr_resnet" in model_name:
        image_representations = []
    
        for img, target in tqdm(cap):
            inputs = image_processor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features.last_hidden_state.mean(dim=1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "vit":
        image_representations = []
    
        for img, target in tqdm(cap):
            input = transform(img).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(input)[0]  # output is (batch_size, num_features) shaped tensor
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "sam":
        image_representations = []
    
        for img, target in tqdm(cap):
            inputs = processor(img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features[0].reshape(1, -1, 1280).mean(dim = 1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "sam_embed":
        image_representations = []
    
        for img, target in tqdm(cap):
            inputs = processor(img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features.reshape(-1, 256).mean(dim = 0).detach().cpu()
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "segformer" or model_name == "segformer_segment":
        image_representations = []
    
        for img, target in tqdm(cap):
            inputs = processor(img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features.last_hidden_state.reshape((1, 256, -1)).mean(-1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "dpt":
        image_representations = []
    
        for img, target in tqdm(cap):
            inputs = processor(img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            output = extractor.extracted_features.last_hidden_state.mean(dim=1).detach().cpu()[0]
            image_representations.append(output)
        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "convnext":
        print("here")
        image_representations = []
    
        for img, target in tqdm(cap):
            
            inputs = processor(images=img, return_tensors="pt")
            inputs = inputs.to(device)
            outputs = model.convnext(**inputs)
            #print(outputs.last_hidden_state.shape)
            image_representation = outputs.last_hidden_state.reshape(1, 1024, -1).mean(2).detach().cpu()[0]
            image_representation = image_representation / np.linalg.norm(image_representation, axis=0, keepdims=True)
            image_representations.append(image_representation)

        image_representations_tensor = torch.stack(image_representations)
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "clip":
        image_representations = []
        text_representations = []
        
        for img, target in tqdm(cap):
            
            inputs = processor(text=target, images=img, return_tensors="pt", padding=True)
            inputs = inputs.to(device)
            outputs = model(**inputs)
            text_representation = outputs.text_embeds.detach().cpu().squeeze()
            text_representation = text_representation.mean(dim=0).squeeze()
            image_representation = outputs.image_embeds.detach().cpu().squeeze()
        
            text_representations.append(text_representation)
            image_representations.append(image_representation)

        text_representations_tensor = torch.stack(text_representations)
        image_representations_tensor = torch.stack(image_representations)

        torch.save(text_representations_tensor, f'{dataset}_{model_name}_text.pt')
        torch.save(image_representations_tensor, f'{dataset}_{model_name}_img.pt')
    elif model_name == "allroberta":
        text_representations = []

        for img, target in tqdm(cap):
            output = language_model.encode(target)
            text_representation = torch.Tensor(output)
            text_representation = text_representation.mean(dim=0)
            text_representations.append(text_representation)
            
        text_representations_tensor = torch.stack(text_representations)
        torch.save(text_representations_tensor, f'{dataset}_{model_name}_text.pt')

In [4]:
# Extract and save embeddings for every (model, dataset) pair listed here.
for model, dataset in [("vit", "coco")]:
    save_embed(model, dataset)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


100%|███████████████████████████████████████████████████████████████████████████████| 5000/5000 [01:40<00:00, 49.75it/s]


In [11]:
# Extract and save embeddings for every (model, dataset) pair listed here.
for model, dataset in [("detr_resnet_101_backbone", "coco")]:
    save_embed(model, dataset)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


100%|███████████████████████████████████████████████████████████████████████████████| 5000/5000 [08:49<00:00,  9.45it/s]


In [10]:
# Extract and save ConvNeXt embeddings, first for coco and then for nocaps.
for model, dataset in [("convnext", "coco"), ("convnext", "nocaps")]:
    save_embed(model, dataset)

loading annotations into memory...
Done (t=0.06s)
creating index...
index created!
here


100%|███████████████████████████████████████████████████████████████████████████████| 5000/5000 [02:37<00:00, 31.69it/s]


loading annotations into memory...
Done (t=0.11s)
creating index...
index created!
here


100%|███████████████████████████████████████████████████████████████████████████████| 4500/4500 [03:28<00:00, 21.60it/s]


In [30]:
# NOTE(review): save_embed as defined in this notebook has no return statement,
# so `de` is None when this cell is re-run. The recorded output (stopped at 0%)
# presumably comes from a temporary debugging version that returned the model
# outputs early -- confirm, then remove this cell.
de = save_embed("convnext", "coco")

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
here


  0%|                                                                                          | 0/5000 [00:00<?, ?it/s]


In [33]:
# Debug inspection of `de` (recorded output: torch.Size([1, 1024, 7, 7])).
# NOTE(review): this only works if `de` holds model outputs; confirm where `de`
# is supposed to come from on a fresh kernel.
de.last_hidden_state.shape

torch.Size([1, 1024, 7, 7])

In [26]:
# Load the ConvNeXt classification model onto the device, then its matching
# image processor from the same checkpoint.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224-22k").to(device)
processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")

In [28]:
# Display the module tree of the ConvNeXt backbone wrapped by the classifier.
model.convnext

ConvNextModel(
  (embeddings): ConvNextEmbeddings(
    (patch_embeddings): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (layernorm): ConvNextLayerNorm()
  )
  (encoder): ConvNextEncoder(
    (stages): ModuleList(
      (0): ConvNextStage(
        (downsampling_layer): Identity()
        (layers): Sequential(
          (0): ConvNextLayer(
            (dwconv): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (layernorm): ConvNextLayerNorm()
            (pwconv1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELUActivation()
            (pwconv2): Linear(in_features=512, out_features=128, bias=True)
            (drop_path): Identity()
          )
          (1): ConvNextLayer(
            (dwconv): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (layernorm): ConvNextLayerNorm()
            (pwconv1): Linear(in_features=128, out_features=512, bias=True)
            (ac

In [7]:
# Load the same ConvNeXt checkpoint through the Auto* classes: the model is
# moved to the device, the image processor is loaded from the same name.
vision_model_name = "facebook/convnext-base-224-22k"
model = AutoModel.from_pretrained(vision_model_name).to(device)
processor = AutoImageProcessor.from_pretrained(vision_model_name)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [8]:
# Captioning dataset built from the nocaps annotation file over the OpenImages
# validation images. Images are only converted to tensors here; no resizing or
# cropping is applied.
cap = dset.CocoCaptions(
    root='/shared/group/openimages/validation',
    annFile='nocaps_val_4500_captions.json',
    transform=transforms.Compose([transforms.PILToTensor()]),
)

loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


In [15]:
# One pooled vector per image: the last hidden state is reshaped to
# (1, 1024, -1) and averaged over its last axis.
image_representations = []

# Inference only: without no_grad every forward pass builds an autograd graph
# that is then thrown away.
with torch.no_grad():
    for img, target in tqdm(cap):
        inputs = processor(images=img, return_tensors="pt")
        inputs = inputs.to(device)
        outputs = model(**inputs)
        image_representation = outputs.last_hidden_state.reshape(1, 1024, -1).mean(2).detach().cpu()[0]
        image_representations.append(image_representation)

100%|███████████████████████████████████████████████████████████████████████████████| 4500/4500 [03:05<00:00, 24.25it/s]


In [14]:
# Shape of the most recently computed image vector (recorded output: torch.Size([1024])).
image_representation.shape

torch.Size([1024])

In [16]:
# Stack the per-image vectors along a new first axis and write them to disk.
image_representations_tensor = torch.stack(image_representations, dim=0)
torch.save(image_representations_tensor, 'convnext_nocaps_val.pt')

In [17]:
# COCO val2017 captioning dataset. Images are only converted to tensors here;
# no resizing or cropping is applied and no target_transform is set.
cap = dset.CocoCaptions(
    root='/shared/group/coco/val2017',
    annFile='/shared/group/coco/annotations/captions_val2017.json',
    transform=transforms.Compose([transforms.PILToTensor()]),
)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [18]:
# One pooled vector per image: the last hidden state is reshaped to
# (1, 1024, -1) and averaged over its last axis.
image_representations = []

# Inference only: without no_grad every forward pass builds an autograd graph
# that is then thrown away.
with torch.no_grad():
    for img, target in tqdm(cap):
        inputs = processor(images=img, return_tensors="pt")
        inputs = inputs.to(device)
        outputs = model(**inputs)
        image_representation = outputs.last_hidden_state.reshape(1, 1024, -1).mean(2).detach().cpu()[0]
        image_representations.append(image_representation)

100%|███████████████████████████████████████████████████████████████████████████████| 5000/5000 [02:25<00:00, 34.32it/s]


In [19]:
# Stack the per-image vectors along a new first axis and write them to disk.
image_representations_tensor = torch.stack(image_representations, dim=0)
torch.save(image_representations_tensor, 'convnext_coco_val.pt')

In [38]:
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model onto the device, then the processor from the same checkpoint.
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [39]:
# COCO val2017 captioning dataset. Images are only converted to tensors here;
# no resizing or cropping is applied and no target_transform is set.
cap = dset.CocoCaptions(
    root='/shared/group/coco/val2017',
    annFile='/shared/group/coco/annotations/captions_val2017.json',
    transform=transforms.Compose([transforms.PILToTensor()]),
)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [51]:
# One image vector and one caption vector (mean over the item's captions) per item.
image_representations = []
text_representations = []

# Inference only: without no_grad every forward pass builds an autograd graph
# that is then thrown away.
with torch.no_grad():
    for img, target in tqdm(cap):
        inputs = processor(text=target, images=img, return_tensors="pt", padding=True)
        inputs = inputs.to(device)
        outputs = model(**inputs)
        # Average over the caption axis directly. Squeezing first would turn
        # the embeddings of a single-caption item into a 1-D tensor, and the
        # mean would then collapse it to a scalar.
        text_representation = outputs.text_embeds.detach().cpu().mean(dim=0)
        image_representation = outputs.image_embeds.detach().cpu()[0]

        text_representations.append(text_representation)
        image_representations.append(image_representation)

100%|███████████████████████████████████████████████████████████████████████████████| 5000/5000 [04:42<00:00, 17.70it/s]


In [52]:
# Stack the per-item vectors along a new first axis ...
text_representations_tensor = torch.stack(text_representations, dim=0)
image_representations_tensor = torch.stack(image_representations, dim=0)

# ... and write the caption and image embeddings to separate files.
torch.save(text_representations_tensor, 'coco_val_clip_vit_large_patch14_text.pt')
torch.save(image_representations_tensor, 'coco_val_clip_vit_large_patch14_img.pt')

In [53]:
# Captioning dataset built from the nocaps annotation file over the OpenImages
# validation images. Images are only converted to tensors here; no resizing or
# cropping is applied.
cap = dset.CocoCaptions(
    root='/shared/group/openimages/validation',
    annFile='nocaps_val_4500_captions.json',
    transform=transforms.Compose([transforms.PILToTensor()]),
)

loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


In [54]:
# One image vector and one caption vector (mean over the item's captions) per item.
image_representations = []
text_representations = []

# Inference only: without no_grad every forward pass builds an autograd graph
# that is then thrown away.
with torch.no_grad():
    for img, target in tqdm(cap):
        inputs = processor(text=target, images=img, return_tensors="pt", padding=True)
        inputs = inputs.to(device)
        outputs = model(**inputs)
        # Average over the caption axis directly. Squeezing first would turn
        # the embeddings of a single-caption item into a 1-D tensor, and the
        # mean would then collapse it to a scalar.
        text_representation = outputs.text_embeds.detach().cpu().mean(dim=0)
        image_representation = outputs.image_embeds.detach().cpu()[0]

        text_representations.append(text_representation)
        image_representations.append(image_representation)

100%|███████████████████████████████████████████████████████████████████████████████| 4500/4500 [05:52<00:00, 12.77it/s]


In [55]:
# Stack the per-item vectors along a new first axis ...
text_representations_tensor = torch.stack(text_representations, dim=0)
image_representations_tensor = torch.stack(image_representations, dim=0)

# ... and write the caption and image embeddings to separate files.
torch.save(text_representations_tensor, 'nocap_val_clip_vit_large_patch14_text.pt')
torch.save(image_representations_tensor, 'nocap_val_clip_vit_large_patch14_img.pt')

In [61]:
# Sanity check on the stacked caption embeddings (recorded output: torch.Size([4500, 768])).
text_representations_tensor.shape

torch.Size([4500, 768])

In [62]:
# Sanity check on the stacked image embeddings (recorded output: torch.Size([4500, 768])).
image_representations_tensor.shape

torch.Size([4500, 768])

In [63]:
# Text embeddings of the last processed item (recorded output: torch.Size([10, 768]));
# presumably one row per caption of that item -- confirm.
outputs.text_embeds.shape

torch.Size([10, 768])

In [64]:
# Last hidden state of the vision model output for the last processed item
# (recorded output: torch.Size([1, 257, 1024])).
outputs['vision_model_output'].last_hidden_state.shape

torch.Size([1, 257, 1024])