In [1]:
from fastbook import *
from tqdm import tqdm
from PIL import Image

In [2]:
import time  # retry back-off; don't rely on the fastbook star import for this

search_queries = [
    'dog memes',
    'animal memes',
    'cat memes',
    'funny animal memes',
    'Cute animal memes',
    'Classic animal memes',
    'Pet fail memes',
]

MAX_SEARCH_ATTEMPTS = 5  # give up on a query instead of retrying forever
RETRY_DELAY_S = 2        # seconds to wait between attempts

all_urls = []

for query in tqdm(search_queries):
    # Reset per query so a query whose every attempt fails contributes no URLs
    # (instead of re-adding the previous query's results).
    urls = []
    for _ in range(MAX_SEARCH_ATTEMPTS):
        try:
            urls = search_images_ddg(query, max_images=200)
            break
        except Exception:
            # `except Exception` rather than a bare `except:` so interrupting the
            # kernel (KeyboardInterrupt) stops the loop instead of causing a retry.
            print('refused by duckduckgo, waiting and trying again...')
            time.sleep(RETRY_DELAY_S)
    else:
        print(f'giving up on {query!r} after {MAX_SEARCH_ATTEMPTS} attempts')
    all_urls += urls

100%|██████████| 7/7 [00:36<00:00,  5.27s/it]


In [4]:
## remove duplicate urls
print(len(all_urls), len(set(all_urls)))
# dict.fromkeys drops duplicates but keeps first-seen order. Iterating a set of
# strings yields an order that changes between interpreter sessions (str hash
# randomization), which would reshuffle the idx -> file-name mapping used when
# the images are downloaded below.
urls = list(dict.fromkeys(all_urls))

1400 1038


In [5]:
## download data in parallel

from concurrent.futures import ThreadPoolExecutor
from typing import Tuple,List

def download(idx_url: Tuple[int, str]) -> None:
    """Best-effort download of a single image.

    Args:
        idx_url: an ``(idx, url)`` pair; the response is saved to
            ``./meme2-images/{idx}`` with a 2-second timeout.

    Any failure (dead link, timeout, HTTP error) is ignored so one bad URL does
    not abort the whole parallel download. Saved files are not validated here.
    """
    idx, url = idx_url
    try:
        download_url(url, f'./meme2-images/{idx}', show_progress=False, timeout=2)
    except Exception:
        # Deliberately best-effort. Catching Exception (not a bare except) still
        # lets KeyboardInterrupt / SystemExit propagate.
        pass

# One (idx, url) job per unique URL; idx becomes the saved file name.
jobs = list(enumerate(urls))

with ThreadPoolExecutor(max_workers=10) as executor:
    # list() consumes the result iterator so tqdm advances as downloads finish;
    # total= is needed because the map iterator has no len().
    list(tqdm(executor.map(download, jobs), total=len(jobs)))


1038it [00:00, 1412617.64it/s]
1038it [01:13, 14.07it/s]


In [11]:
from IPython.display import display, HTML

# Preview the first GRID_ROWS * GRID_COLS downloaded files as a grid.
# Indices whose download failed have no file and render as broken images.
GRID_ROWS, GRID_COLS = 5, 7
for row in range(GRID_ROWS):
    cells = ''.join(
        f"<td><img src='./meme2-images/{row * GRID_COLS + col}' width=150></td>"
        for col in range(GRID_COLS)
    )
    display(HTML(f"<table><tr>{cells}</tr></table>"))

In [12]:
## load encoder model
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer("clip-ViT-B-32", device=device)

In [13]:
def batch_encode(image_paths: List[str], batch_size: int = 1024):
    """Load image files and embed them with the global CLIP ``model``.

    Files that PIL cannot open or fully decode (e.g. an HTML error page saved
    by a failed download, or a truncated image) are skipped.

    Args:
        image_paths: paths of candidate image files.
        batch_size: batch size forwarded to ``model.encode``.

    Returns:
        ``(embeddings, cleaned_paths)``, where ``cleaned_paths[i]`` is the file
        that produced ``embeddings[i]``.
    """
    images = []
    cleaned_paths = []
    for path in tqdm(image_paths):
        try:
            with Image.open(path) as img:
                # Image.open is lazy. convert() forces a full decode now, so corrupt
                # files are caught here instead of crashing model.encode later. It
                # also returns an in-memory copy, so the file handle can be closed
                # rather than staying open for every image.
                images.append(img.convert('RGB'))
        except Exception:
            continue
        cleaned_paths.append(path)
    print(len(images), len(image_paths))
    return model.encode(images, batch_size=batch_size, show_progress_bar=True), cleaned_paths


In [14]:
## embed data
import glob

# sorted() makes file (and hence embedding) order deterministic; raw glob order
# depends on the filesystem.
candidate_paths = sorted(glob.glob('./meme2-images/*'))
embeddings, image_paths = batch_encode(candidate_paths)
len(embeddings)

100%|██████████| 969/969 [00:00<00:00, 5958.95it/s]


943 969


Batches:   0%|          | 0/1 [00:00<?, ?it/s]



943

In [15]:
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

In [16]:
!docker run -d -p "6333:6333" -p "6334:6334" --name "qdrant-db" qdrant/qdrant:master

0125fbd47544a2a273aa8f684e8b2b60b7e520300051e38322a13c753b73f190


In [17]:
# Connect to the Qdrant server on localhost (the docker container started above).
client = QdrantClient('localhost')

# List the collections that currently exist on the server.
collections = client.get_collections()
collections

CollectionsResponse(collections=[])

In [18]:
# The vector size must equal the encoder's embedding width. Derive it from the
# computed embeddings instead of hard-coding 512, so swapping the encoder cannot
# silently mismatch the collection.
# NOTE(review): recreate_collection is deprecated in newer qdrant-client releases
# (collection_exists / delete_collection / create_collection) — check the installed version.
client.recreate_collection(
    collection_name='meme-images',
    vectors_config=rest.VectorParams(
        size=int(embeddings.shape[1]),
        distance=rest.Distance.COSINE,
    ),
)

True

In [19]:
import uuid

# Each embedding must line up with the image path stored as its payload.
assert len(embeddings) == len(image_paths)

# Derive point ids from the image path (uuid5) instead of uuid4, so re-running
# the notebook assigns the same id to the same file.
meme_ids = [uuid.uuid5(uuid.NAMESPACE_URL, image_path).hex for image_path in image_paths]

client.upload_collection(
    collection_name="meme-images",
    vectors=list(embeddings),
    payload=[{'image_path': image_path} for image_path in image_paths],
    ids=meme_ids,
)

In [20]:
# Number of points stored in the collection (should equal len(embeddings)).
client.count(collection_name="meme-images")

CountResult(count=943)

In [21]:

def search(query: str, limit: int = 3, width: int = 400):
    """Embed a text query with the CLIP ``model`` and show the nearest memes.

    Prints the similarity scores of the hits, then renders the matching images
    side by side in a single HTML table row.

    Args:
        query: free-text description to search for.
        limit: number of nearest images to retrieve (default 3).
        width: display width in pixels of each image (default 400).
    """
    import html

    search_result = client.search(
        collection_name="meme-images",
        # Plain list of floats rather than a numpy array for the query vector.
        query_vector=model.encode(query).tolist(),
        limit=limit,
    )
    print([data.score for data in search_result])
    # Escape paths so a quote or '&' in a file name cannot break the HTML.
    images_html = "".join(
        f"<td><img src='{html.escape(data.payload['image_path'], quote=True)}' width={width}></td>"
        for data in search_result
    )
    display(HTML(f"<table><tr>{images_html}</tr></table>"))

In [25]:
# Text-to-image query: memes resembling Pixar-style imagery.
search(query='pixar')

[0.2831471, 0.2637637, 0.2617895]


In [26]:
# Text-to-image query: memes featuring fat cats.
search(query='fat cats')

[0.29802775, 0.29693586, 0.29271647]


In [27]:
# Text-to-image query: memes featuring people rather than animals.
search(query='humans')

[0.28624472, 0.28155106, 0.2798286]


In [28]:
# Text-to-image query: a more abstract, mood-based description.
search(query='shy memes')

[0.2883101, 0.28510135, 0.27912968]
