### Libraries

In [1]:
import torch
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModel
from transformers import ImageGPTFeatureExtractor, ImageGPTModel
from PIL import Image
import numpy as np
import os
import time
import json
import faiss

In [2]:
# Root of the training images; the indexing loop walks it as one sub-folder per category
images_folder_path = "../../1) Data_Collection/OID/Dataset/train"

# Folder that receives both the FAISS index file and the id -> path JSON mapping
output = "imagegpt_embeddings"

# Destination of the serialized FAISS index
bin_file = "/".join((output, "imageGPTIndex.bin"))

In [3]:
# Width of the vectors stored in the index; it has to equal the second dimension
# of every array passed to index.add().
# NOTE(review): presumably the hidden size of openai/imagegpt-small - confirm.
EMBEDDING_DIM = 512

# Exact (brute-force) L2-distance index shared by the cells below
index = faiss.IndexFlatL2(EMBEDDING_DIM)

In [4]:
# Load the ImageGPT feature extractor and model (memoized after the first call)
_IMAGEGPT_CHECKPOINT = 'openai/imagegpt-small'
_imagegpt_cache = {}


def load_imagegpt_model():
    """Return the ImageGPT feature extractor and model, loading them at most once.

    The function is called once per image by ``main``; without memoization the
    checkpoint would be re-instantiated for every single image. The loaded pair
    is therefore kept in a module-level cache and handed back on later calls.

    Returns:
        tuple: ``(feature_extractor, model)`` built from the
        ``openai/imagegpt-small`` checkpoint. Repeated calls return the very
        same two objects.
    """
    if _IMAGEGPT_CHECKPOINT not in _imagegpt_cache:
        feature_extractor_gpt = ImageGPTFeatureExtractor.from_pretrained(_IMAGEGPT_CHECKPOINT)
        model_gpt = ImageGPTModel.from_pretrained(_IMAGEGPT_CHECKPOINT)
        _imagegpt_cache[_IMAGEGPT_CHECKPOINT] = (feature_extractor_gpt, model_gpt)
    return _imagegpt_cache[_IMAGEGPT_CHECKPOINT]

# Preprocess the image using ImageGPT's feature extractor
def preprocess_image(feature_extractor_gpt, image_path):
    """Open an image file and run it through the ImageGPT feature extractor.

    Args:
        feature_extractor_gpt: Callable accepting ``images=`` and
            ``return_tensors=`` keyword arguments (the extractor returned by
            ``load_imagegpt_model``).
        image_path: Path of the image file to open.

    Returns:
        Whatever the feature extractor returns for the RGB-converted image,
        requested with ``return_tensors="pt"``.
    """
    # The context manager closes the underlying file even when convert() raises
    # or the file holds several frames; the converted copy stays usable afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")  # Ensure image is in RGB mode
    return feature_extractor_gpt(images=image, return_tensors="pt")

# Turn preprocessed inputs into one pooled embedding per image
def generate_embeddings(model, image_tensor):
    """Run the model without gradient tracking and mean-pool its last hidden state.

    Args:
        model: Callable invoked as ``model(**image_tensor)``; its result must
            expose a ``last_hidden_state`` tensor.
        image_tensor: Mapping of keyword arguments for the model.

    Returns:
        numpy.ndarray: ``last_hidden_state`` moved to the CPU, converted to
        NumPy and averaged over axis 1 (presumably the sequence axis - confirm).
    """
    with torch.no_grad():
        hidden_states = model(**image_tensor).last_hidden_state

    per_position = hidden_states.cpu().numpy()
    return per_position.mean(axis=1)

# Save embeddings to a .bin file
def save_embeddings_to_bin(file_path):
    """Serialize the module-level FAISS ``index`` to ``file_path``.

    The parent directory is created first when it does not exist yet, so the
    write does not fail on a fresh checkout where the output folder is missing.

    Args:
        file_path: Destination path of the index file.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:  # empty for a bare file name: nothing to create
        os.makedirs(parent_dir, exist_ok=True)
    faiss.write_index(index, file_path)


In [5]:
# Embed one image and append the result to the shared FAISS index
def main(image_path, output_path):
    """Compute the ImageGPT embedding of one image and add it to ``index``.

    Side effects: adds the embedding to the module-level ``index`` and updates
    the module-level counters ``total_time`` (seconds spent in preprocessing
    plus the forward pass) and ``total_images``.

    Args:
        image_path: Path of the image to embed.
        output_path: Accepted but not used by this function.
    """
    global total_time, total_images

    # Obtained on every call; deliberately outside the timed section.
    extractor, gpt_model = load_imagegpt_model()

    started_at = time.time()
    vectors = generate_embeddings(gpt_model, preprocess_image(extractor, image_path))
    total_time = total_time + (time.time() - started_at)
    total_images = total_images + 1

    index.add(vectors)

In [6]:
# Maps each FAISS row id (insertion order, starting at 0) to its source image path
indexes = {}
idx = 0

total_time = 0
total_images = 0

# Empty the index first: the ids below restart at 0, so vectors left over from a
# previous execution of this cell would shift every row away from its mapping.
index.reset()

# os.listdir yields entries in arbitrary order; sorting keeps row ids reproducible.
for one_category in sorted(os.listdir(images_folder_path)):
    one_category_path = os.path.join(images_folder_path, one_category)

    # Stray files next to the category folders (e.g. .DS_Store) cannot be listed.
    if not os.path.isdir(one_category_path):
        continue

    all_images_in_category_folder = sorted(
        x for x in os.listdir(one_category_path) if x.endswith("jpg")
    )

    for img in all_images_in_category_folder:

        img_path = os.path.join(one_category_path, img)

        # Same base name with a "bin" extension; main() accepts but ignores it.
        output_path = os.path.join(output, img[:-3] + "bin")

        main(img_path, output_path)

        indexes[idx] = img_path
        idx += 1

# The JSON mapping is only meaningful if every stored vector has exactly one entry.
assert index.ntotal == len(indexes), "FAISS rows and path mapping are out of sync"

save_embeddings_to_bin(bin_file)



In [7]:
# Mean seconds per embedded image; falls back to 0 when no image was processed
if total_images > 0:
    average_time = total_time / total_images
else:
    average_time = 0

print(f"Total Time: {total_time:.2f} seconds")
print(f"Average Time per Image: {average_time:.2f} seconds")

Total Time: 125.17 seconds
Average Time per Image: 2.28 seconds


In [8]:
# Persist the row-id -> image-path mapping next to the index file.
# JSON object keys are always strings, so the integer ids are written as "0", "1", ...
indices_json_path = f'{output}/indices.json'
with open(indices_json_path, 'w') as mapping_file:
    json.dump(indexes, mapping_file, indent=4)