## Compute image feature vectors

In [1]:
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel, CLIPTextModel
import torch
import glob
import pandas as pd
from  tqdm import tqdm
from pathlib import Path
# from transformers import CLIPFeatureExtractor, 

### All folders under this path are searched recursively

In [10]:
# Root directory of the image collection to embed.
# path = Path('images') / 'personal'
path = Path('images') / 'philippines'

In [11]:
# Run configuration: batch size and the Hugging Face id of the CLIP checkpoint to load.
BS = 128  # batch size; changing it does not seem to make much difference
model_id = "openai/clip-vit-base-patch32"   # preconfigured with image size = 224: https://huggingface.co/openai/clip-vit-base-patch32/blob/main/preprocessor_config.json
# Alternative checkpoint (uncomment to use instead of the one above):
# model_id = "openai/clip-vit-large-patch14-336"  # preconfigured with image size = 336: https://huggingface.co/openai/clip-vit-large-patch14-336/blob/main/preprocessor_config.json

In [5]:
# Load the preprocessing pipeline and the CLIP weights for `model_id`,
# placing the model on the GPU when CUDA is available and on the CPU otherwise.
processor = CLIPProcessor.from_pretrained(model_id)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
model = CLIPModel.from_pretrained(model_id).to(device)

### Compute the embeddings of all images

In [6]:
def get_images_feats(images):
    """Return L2-normalized CLIP image embeddings for `images`.

    Args:
        images: Whatever ``processor(images=...)`` accepts; presumably a PIL
            image or a list of PIL images, as produced elsewhere in this
            notebook.

    Returns:
        A CPU tensor holding the image features returned by
        ``model.get_image_features``, scaled to unit L2 norm along the last
        dimension. The tensor carries no autograd history.
    """
    input_images = processor(text=None, images=images, return_tensors="pt", padding=True).to(device)
    # Run the forward pass under no_grad so the autograd graph is never built.
    # Calling .detach() on the output only drops the graph after its
    # intermediate buffers have already been allocated, which does not help
    # with running out of memory.
    with torch.no_grad():
        output_images_features = model.get_image_features(**input_images)
        images_embeds = output_images_features / output_images_features.norm(p=2, dim=-1, keepdim=True)  # normalized features
    return images_embeds.cpu()

In [7]:
def get_image_feats(image):
    """Return the L2-normalized CLIP image embedding for `image`.

    Args:
        image: Whatever ``processor(images=...)`` accepts; in this notebook a
            single PIL image opened with ``Image.open``.

    Returns:
        A CPU tensor holding the image features returned by
        ``model.get_image_features``, scaled to unit L2 norm along the last
        dimension. The tensor carries no autograd history.
    """
    input_image = processor(text=None, images=image, return_tensors="pt", padding=True).to(device)
    # Disable gradient tracking for the forward pass itself: detaching the
    # result afterwards would still build (and briefly hold) the autograd
    # graph for every call.
    with torch.no_grad():
        output_image_features = model.get_image_features(**input_image)
        image_embeds = output_image_features / output_image_features.norm(p=2, dim=-1, keepdim=True)  # normalized features
    return image_embeds.cpu()

In [12]:
# For BS = 1 only [For testing only]
# Walk every file under `path` (recursively), embed each readable image one at
# a time and collect the results in `df` with columns 'path' and 'features'.
feats = []
paths = []
for fn in tqdm(path.rglob('*.*')):
    try:
        # Image.open is lazy: it only reads the header and keeps the file open.
        # The context manager closes the file handle, and load() forces the
        # decode first so the pixel data stays usable after the `with` block
        # and truncated/corrupt files are reported here instead of aborting
        # the run inside get_image_feats.
        with Image.open(fn) as image:
            image.load()
    except Exception:
        # Deliberately not a bare `except:` -- that would also swallow
        # KeyboardInterrupt/SystemExit and make this long loop hard to stop.
        print(f'Failed to open {fn}')
        continue
    paths.append(fn)
    feats.append(get_image_feats(image))

df = pd.DataFrame(zip(paths, feats), columns=['path', 'features'])

2335it [03:26, 10.51it/s]

Failed to open images/philippines/images_world/flickr_Blue-backed_parrot-68.jpg


5430it [08:01, 15.70it/s]

Failed to open images/philippines/images_world/flickr_Pagong_-177.jpg


9963it [14:34, 19.77it/s]

Failed to open images/philippines/images_world/flickr_Crocodylus_mindorensis-166.jpg


11482it [16:55, 19.15it/s]

Failed to open images/philippines/images_world/flickr_Southeast_Asian_box_turtle-234.jpg


21173it [30:57, 16.90it/s]

Failed to open images/philippines/images_world/flickr_Philippine_crocodile-183.jpg


24974it [36:25, 16.85it/s]

Failed to open images/philippines/images_world/flickr_Southeast_Asian_box_turtle-158.jpg


25487it [37:08, 14.70it/s]

Failed to open images/philippines/images_world/flickr_Southeast_Asian_box_turtle-58.jpg


25536it [37:12, 15.11it/s]

Failed to open images/philippines/images_world/flickr_Southeast_Asian_box_turtle-302.jpg


28745it [41:54, 15.92it/s]

Failed to open images/philippines/images_world/flickr_Malayan_box_turtle-261.jpg


29582it [43:15, 16.58it/s]

Failed to open images/philippines/images_world/flickr_Mago_-60.jpg


31742it [46:24, 14.07it/s]

Failed to open images/philippines/images_world/flickr_Binturong-23.jpg


34409it [50:15,  5.54it/s]

Failed to open images/philippines/images_world/flickr_Green_turtle-453.jpg


37014it [54:00, 10.55it/s]

Failed to open images/philippines/images_world/flickr_Mago_-180.jpg


37964it [55:22, 20.33it/s]

Failed to open images/philippines/images_world/flickr_Blue-naped_parrot-242.jpg


42142it [1:01:08, 12.23it/s]

Failed to open images/philippines/images_world/flickr_Manis_culionensis-2.jpg


46849it [1:08:02, 17.95it/s]

Failed to open images/philippines/images_world/flickr_Southeast_Asian_box_turtle-161.jpg


56988it [1:22:28, 11.52it/s]


### Same as above, but processing the images in batches of BS [in my tests the speed was the same]

In [21]:
# Disabled (commented-out) batched variant of the single-image loop: it collects
# up to BS opened images, embeds each full batch with one call, and finally
# embeds whatever is left over before building the same `df`.
# NOTE(review): it passes a list to get_image_feats; get_images_feats has the
# same body, so either name works if this cell is re-enabled.
# feats=[]
# paths = []
# images = []
# idx = 0
# for fn in tqdm(path.rglob('*.*')):
#     try:
#         image = Image.open(fn)
#     except:
#         print(f'Failed to open {fn}')
#         continue
#     images.append(image)
#     paths.append(fn)
#     idx += 1
#     if idx == BS:
#         feats.extend(get_image_feats(images))
#         images = []
#         idx = 0
# if len(images)>0:
#         feats.extend(get_image_feats(images))
#         images = []
# df = pd.DataFrame(zip(paths, feats), columns=['path', 'features'])

511it [00:34, 14.98it/s]


KeyboardInterrupt: 

In [13]:
# Persist the embeddings to a pickle file in the current working directory.
# The file name is `path` with its separators replaced by '-'
# (images/philippines -> images-philippines.pickle).
# as_posix() always yields '/' separators, so the replacement also works on
# Windows, where str(path) would contain backslashes instead.
df.to_pickle(path.as_posix().replace('/', '-') + '.pickle')