In [1]:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
import requests
import os

In [2]:
# Load GIT model and processor
processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

embedder = SentenceTransformer('all-MiniLM-L6-v2')
K = 500

def generate_caption(image):
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption

def fetch_image(url):
    image = Image.open(requests.get(url, stream=True).raw)
    return image

def get_image_urls(img_ids, coco):
    img_urls = []
    for img_id in img_ids:
        img = coco.loadImgs(img_id)[0]
        img_urls.append(img['coco_url'])
    return img_urls

def embed_text(text,embedder):
    return embedder.encode([text])

def set_coco_object(data_dir):
    ann_file = os.path.join(data_dir, 'annotations/captions_val2017.json')
    coco = COCO(ann_file)
    return coco



In [3]:
coco = set_coco_object('coco-images')
GT_embeddings = np.load('data/coco/coco_embeddings_' + str(K) +'_averaged.npy')
GT_img_ids = [int(line.strip()) for line in open('data/coco/coco_image_ids_' + str(K) +'.txt')]

image_urls = get_image_urls(GT_img_ids, coco)
  
print(len(image_urls))  
print(image_urls)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
500
['http://images.cocodataset.org/val2017/000000397133.jpg', 'http://images.cocodataset.org/val2017/000000037777.jpg', 'http://images.cocodataset.org/val2017/000000252219.jpg', 'http://images.cocodataset.org/val2017/000000087038.jpg', 'http://images.cocodataset.org/val2017/000000174482.jpg', 'http://images.cocodataset.org/val2017/000000403385.jpg', 'http://images.cocodataset.org/val2017/000000006818.jpg', 'http://images.cocodataset.org/val2017/000000480985.jpg', 'http://images.cocodataset.org/val2017/000000458054.jpg', 'http://images.cocodataset.org/val2017/000000331352.jpg', 'http://images.cocodataset.org/val2017/000000296649.jpg', 'http://images.cocodataset.org/val2017/000000386912.jpg', 'http://images.cocodataset.org/val2017/000000502136.jpg', 'http://images.cocodataset.org/val2017/000000491497.jpg', 'http://images.cocodataset.org/val2017/000000184791.jpg', 'http://images.cocodataset.org/val2017/000

In [4]:
generated_captions = []
generated_caption_embeddings = []
for img_url in image_urls:
    image = fetch_image(img_url)
    caption = generate_caption(image)
    generated_captions.append(caption)
    generated_caption_embeddings.append(embed_text(caption,embedder))
    
generated_caption_embeddings = np.array(generated_caption_embeddings).squeeze()



In [5]:
np.save('data/git/git_caption_embeddings_' + str(K) +'.npy', generated_caption_embeddings)

with open('data/git/git_captions_' + str(K) +'.txt', 'w') as f:
    for caption in generated_captions:
        f.write(caption + '\n')

In [6]:
print(generated_caption_embeddings.shape)

(500, 384)
