In [51]:
import datetime as dt
import itertools
import pathlib
import uuid

import numpy as np
import torch
import transformers
from PIL import Image

import qdrant_client as qc
import qdrant_client.models

In [None]:
!wget https://storage.googleapis.com/ads-dataset/subfolder-0.zip
!wget https://storage.googleapis.com/ads-dataset/subfolder-1.zip
!unzip subfolder-0.zip
!unzip subfolder-1.zip
!mkdir -p ./data/ads/images
!mv .0/*.jpg ./data/ads/images
!mv ./1/*.jpg ./data/ads/images
!rmdir ./0
!rmdir ./1
!rm subfolder-0.zip subfolder-1.zip

In [2]:
# Model identifiers passed to transformers: one model captions images, the other embeds the captions.
captioning_model_name, embedding_model_name = (
    'Salesforce/blip-image-captioning-base',
    'sentence-transformers/all-MiniLM-L6-v2',
)

In [4]:
# Text side: tokenizer + encoder used to turn generated captions into embeddings.
embedding_tokenizer = transformers.AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = transformers.AutoModel.from_pretrained(embedding_model_name)
# Image side: pipeline that produces a caption for each input image.
captioning_pipeline = transformers.pipeline(task='image-to-text', model=captioning_model_name)

In [46]:
# Connect to the Qdrant server and (re)create the collection that receives the caption embeddings.
qdrant_url = 'localhost:6333'
qdrant_collection = 'advert_captions'
# Expected to match the width of the vectors produced by the embedding model below — TODO confirm if the model changes.
embedding_size = 384
qdrant = qc.QdrantClient(qdrant_url)

# Vectors are compared with cosine distance.
caption_vectors_config = qc.models.VectorParams(
    size=embedding_size,
    distance=qc.models.Distance.COSINE,
)
qdrant.recreate_collection(
    collection_name=qdrant_collection,
    vectors_config=caption_vectors_config,
)

True

In [47]:
# Build paths from components so they resolve on both Windows and POSIX: a backslash
# inside a string literal is not a path separator on POSIX. This is the same
# ./data/ads/images directory that the download cell fills.
ads_data_dir = pathlib.Path('data') / 'ads'
img_dir = ads_data_dir / 'images'
# Per-image caption embeddings are also cached here as .npy files.
np_embeddings_dir = ads_data_dir / 'minilm_embeddings'

np_embeddings_dir.mkdir(exist_ok=True, parents=True)

In [48]:
def batched(iterable, batch_size):
    """Yield successive tuples of up to ``batch_size`` items from ``iterable``.

    The final tuple may be shorter when the items don't divide evenly. Behaves like
    ``itertools.batched`` (Python 3.12+), including rejecting a non-positive size.

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A generator of tuples.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised at call time, not on first iteration).
    """
    if batch_size < 1:
        # islice(it, 0) returns () immediately, which would silently end iteration
        # and make the caller's loop process nothing.
        raise ValueError(f'batch_size must be at least 1, got {batch_size!r}')

    def _batches(iterator):
        # Each islice call pulls the next batch_size items; an empty tuple means exhaustion.
        while batch := tuple(itertools.islice(iterator, batch_size)):
            yield batch

    return _batches(iter(iterable))

In [49]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padded positions.

    Args:
        model_output: Output of the embedding model; element 0 holds the per-token
            embeddings (presumably shaped (batch, seq_len, hidden) — confirm).
        attention_mask: Mask shaped like the token embeddings without their last
            dimension; 1 marks a real token, 0 marks padding.

    Returns:
        Tensor with one mean embedding per input sequence (the sequence axis is summed out).
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension so it weights every feature.
    token_weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * token_weights).sum(dim=1)
    # Lower bound keeps an all-padding row from dividing by zero.
    token_counts = token_weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_counts


In [57]:
# Caption every image, embed each caption, then store the vectors in Qdrant and as .npy files.
# glob() order depends on the filesystem; sorting makes the batch composition reproducible.
img_paths = sorted(img_dir.glob('*.jpg'))

total_batches = 0
for img_paths_batch in batched(img_paths, batch_size=4):
    imgs_batch = []
    for img_path in img_paths_batch:
        # `with` releases the file handle deterministically; convert() returns a new image
        # that stays usable after the source image is closed.
        with Image.open(img_path) as img:
            imgs_batch.append(img.convert('RGB'))

    with torch.no_grad():
        gen_captions_batch = captioning_pipeline(imgs_batch)
        captions_batch = [c[0]['generated_text'] for c in gen_captions_batch]
        caption_tokens_batch = embedding_tokenizer(captions_batch, padding=True, truncation=True, return_tensors='pt')
        embeddings_batch = embedding_model(**caption_tokens_batch)
        embeddings_batch = mean_pooling(embeddings_batch, caption_tokens_batch['attention_mask'])
        # Unit-length rows (L2 normalisation along the feature axis).
        embeddings_batch = torch.nn.functional.normalize(embeddings_batch, p=2, dim=1)

    qdpoints = []
    for img_path, caption, embedding in zip(img_paths_batch, captions_batch, embeddings_batch):
        embedding_np = embedding.cpu().detach().numpy()
        qdpoint = qc.models.PointStruct(
            # Deterministic id derived from the file name: re-running this cell overwrites
            # the same points via upsert instead of inserting duplicates.
            id=str(uuid.uuid5(uuid.NAMESPACE_URL, img_path.name)),
            # Plain Python floats serialise cleanly over both REST and gRPC.
            vector=embedding_np.tolist(),
            payload={'image': img_path.name,
                     # Path objects are not JSON-native; store the path as a string.
                     'location': str(img_path),
                     'caption': caption,
                     'created': str(dt.datetime.now())},
        )
        np.save(np_embeddings_dir / f'{img_path.stem}.npy', embedding_np)
        qdpoints.append(qdpoint)

    qdrant.upsert(
        collection_name=qdrant_collection,
        points=qdpoints
    )

    total_batches += 1
    if total_batches % 100 == 0:
        print('completed embeddings for', total_batches, 'batches')




completed embeddings for 100 batches
completed embeddings for 200 batches
completed embeddings for 300 batches
completed embeddings for 400 batches
completed embeddings for 500 batches
completed embeddings for 600 batches
completed embeddings for 700 batches
completed embeddings for 800 batches
completed embeddings for 900 batches
completed embeddings for 1000 batches
completed embeddings for 1100 batches
completed embeddings for 1200 batches
completed embeddings for 1300 batches
completed embeddings for 1400 batches
completed embeddings for 1500 batches
completed embeddings for 1600 batches
completed embeddings for 1700 batches
completed embeddings for 1800 batches
completed embeddings for 1900 batches
completed embeddings for 2000 batches
completed embeddings for 2100 batches
completed embeddings for 2200 batches
completed embeddings for 2300 batches
completed embeddings for 2400 batches
completed embeddings for 2500 batches
completed embeddings for 2600 batches
completed embeddings 