In [None]:
import os
import pathlib
import random
import numpy as np
import pickle

from tqdm.notebook import tqdm
import cv2
from PIL import Image
import matplotlib.pyplot as plt

import warnings
# Install a blanket 'ignore' filter: every warning raised after this point is
# suppressed for the rest of the kernel session.
# NOTE(review): this also hides deprecation/runtime warnings from the libraries
# used below — consider narrowing it to a specific category or module.
warnings.filterwarnings('ignore')

In [None]:
from embedding_model import EmbeddingModel

## Define initial variables

In [None]:
# Folder (relative to the notebook's working directory) holding the input images.
IMGS_FOLDER = 'flickr_images/15K Nocturna Valencia Banco Mediolanum/'
# Destination pickle file for the computed image embeddings.
# NOTE(review): the file name contains "nobg", but no background-removal step
# is visible in this notebook — confirm this is the intended output name.
EMBEDDINGS_PATH = 'flickr_images/15K_Nocturna_Valencia_Banco_Mediolanum_nobg_embeddings.pkl'

# Path object for the images folder, used to enumerate the image files.
images_path = pathlib.Path(IMGS_FOLDER)

## Read images

In [None]:
# Collect the image files that sit directly inside the images folder.
# - The extension check is case-insensitive, so files such as `photo.JPG`
#   are no longer skipped.
# - Only regular files are kept (a directory named `x.jpg` would otherwise
#   break the image-opening code further down).
# - The result is sorted because `Path.glob` yields entries in an arbitrary,
#   filesystem-dependent order; sorting makes the list reproducible.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')
imgs_list = sorted(
    p for p in images_path.glob('*')
    if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
)
print(f'Number of images: {len(imgs_list)}')


In [None]:
# Display a grid of random images

GRID_ROWS, GRID_COLS = 3, 3
fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(12, 8))

# Hide every axis up front so grid cells left without an image stay blank.
for ax in axes.flat:
    ax.axis('off')

# random.sample raises ValueError when asked for more items than the
# population holds, so cap the sample size at the number of images available.
n_preview = min(len(imgs_list), GRID_ROWS * GRID_COLS)
for ax, img_path in zip(axes.flat, random.sample(imgs_list, n_preview)):
    # The context manager releases the underlying file handle once the
    # image has been drawn.
    with Image.open(img_path) as image:
        # Shrink in place to fit within 400x400 before plotting.
        image.thumbnail((400, 400))
        ax.imshow(image)

plt.tight_layout()
plt.show()

## Load model

In [None]:
# Instantiate the EmbeddingModel (imported from the local `embedding_model`
# module) with its default constructor arguments; it is used below to encode
# the images.
embed_model = EmbeddingModel()

## Get image embeddings

In [None]:
# Compute the image embeddings batch by batch.
batch_size = 50

# Split the image list into consecutive chunks of at most `batch_size` paths.
# A concrete list is built (rather than a generator) so that tqdm can work out
# the total number of batches by itself.
batch_starts = range(0, len(imgs_list), batch_size)
batch_list = [imgs_list[start:start + batch_size] for start in batch_starts]

# Encode each batch without normalization; normalization is applied once,
# further down, on the stacked array.
batch_embeddings = []
for batch in tqdm(batch_list, unit='batch'):
    batch_embeddings.append(embed_model.encode_images(batch, normalize=False))

# Stack the per-batch results along the first axis
# (presumably one row per image — TODO confirm against EmbeddingModel).
img_embeddings_np = np.concatenate(batch_embeddings, axis=0)

# Scale every embedding in place to unit L2 norm along the last axis.
l2_norms = np.linalg.norm(img_embeddings_np, ord=2, axis=-1, keepdims=True)
img_embeddings_np /= l2_norms

print(img_embeddings_np.shape)

In [None]:
# Save the image embeddings to disk as a pickled dict: each key is an image
# file name and each value is the matching entry of `img_embeddings_np`.

# zip() stops silently at the shorter of its inputs, so check the counts first:
# a mismatch is reported instead of quietly writing an incomplete mapping.
if len(imgs_list) != len(img_embeddings_np):
    raise ValueError(
        f'Got {len(img_embeddings_np)} embeddings for {len(imgs_list)} images'
    )

embed_dict = {img_path.name: embedding_np for img_path, embedding_np in zip(imgs_list, img_embeddings_np)}

with open(EMBEDDINGS_PATH, 'wb') as handle:
    pickle.dump(embed_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)