In [1]:
from fastbook import *
from tqdm import tqdm
from PIL import Image

In [2]:
# `time` is not imported explicitly anywhere else in this notebook (it presumably
# leaks in through the fastbook star import), so import it where it is used.
import time

# Search phrases sent to DuckDuckGo image search; results are pooled into all_urls.
search_queries = [
    'dog memes',
    'animal memes',
    'cat memes',
    'funny animal memes',
    'Cute animal memes',
    'Classic animal memes',
    'Pet fail memes'
    ]

MAX_SEARCH_RETRIES = 5    # stop retrying a query after this many failed attempts
RETRY_DELAY_SECONDS = 2   # pause between attempts

all_urls = []

for query in tqdm(search_queries):
    for attempt in range(1, MAX_SEARCH_RETRIES + 1):
        try:
            urls = search_images_ddg(query, max_images=200)
        except Exception as err:
            # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still works.
            print(f'refused by duckduckgo, waiting and trying again... '
                  f'(attempt {attempt}/{MAX_SEARCH_RETRIES}: {err!r})')
            time.sleep(RETRY_DELAY_SECONDS)
        else:
            all_urls += urls
            break
    else:
        # Every attempt failed: skip this query instead of looping forever.
        print(f'giving up on {query!r} after {MAX_SEARCH_RETRIES} attempts')

100%|██████████| 7/7 [00:36<00:00,  5.21s/it]


In [3]:
## remove duplicate urls
# dict.fromkeys keeps the first occurrence of each URL in encounter order, so the
# position of a URL (later used as its file name) does not depend on set ordering.
urls = list(dict.fromkeys(all_urls))
print(len(all_urls), len(urls))

1400 1058


In [4]:
## download data in parallel

from concurrent.futures import ThreadPoolExecutor
from typing import Tuple,List

def download(idx_url: Tuple[int, str], dest_dir: str = './meme2-images') -> None:
    """Best-effort download of one image into `dest_dir`.

    Args:
        idx_url: an `(index, url)` pair; the index is used as the file name.
        dest_dir: directory the file is written to; created if it is missing.

    Returns:
        None. A failed download is skipped silently so that one bad URL does
        not abort the whole batch.
    """
    import os  # local import keeps this cell independent of the star import

    idx, url = idx_url
    # Without the directory every download would fail and be swallowed below,
    # so create it up front; a failure here is a setup error and is allowed to surface.
    os.makedirs(dest_dir, exist_ok=True)
    try:
        download_url(url, os.path.join(dest_dir, str(idx)), show_progress=False, timeout=2)
    except Exception:
        # Dead links, timeouts and non-image responses are expected; skip them.
        pass

# One (index, url) job per URL; the index becomes the downloaded file's name.
jobs = list(enumerate(urls))

with ThreadPoolExecutor(max_workers=10) as executor:
    # Iterating the results advances the bar; total= gives tqdm a real
    # percentage instead of a bare iteration counter.
    list(tqdm(executor.map(download, jobs), total=len(jobs)))


1058it [00:00, 2235553.47it/s]
1058it [01:12, 14.58it/s]


In [21]:
from IPython.display import display, HTML

# Preview the first five downloads in a single table row. <td> elements are only
# valid inside a <table><tr>, which is also how `search` renders its results.
preview_cells = "".join(
    f"<td><img src='./meme2-images/{idx}' width=150></td>"
    for idx in range(5)
)
display(HTML(f"<table><tr>{preview_cells}</tr></table>"))

In [7]:
## load encoder model
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch

# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs (more slowly) on machines without CUDA.
encoder_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer("clip-ViT-B-32", device=encoder_device)

In [8]:
def batch_encode(image_paths: List[str]):
    """Open the images at `image_paths` and embed them with the module-level `model`.

    Args:
        image_paths: paths of candidate image files.

    Returns:
        A tuple `(embeddings, cleaned_paths)`: the result of `model.encode` on the
        images that could be decoded, and the paths of those images in the same order.
        Files that cannot be opened or decoded are skipped.
    """
    images = []
    cleaned_paths = []
    for path in tqdm(image_paths):
        try:
            # Image.open is lazy: it only reads the header. convert() forces a full
            # decode, so truncated/corrupt files are rejected here rather than
            # crashing model.encode, and the `with` block releases the file handle.
            # NOTE(review): assumes the CLIP preprocessor converts to RGB anyway,
            # so embeddings of valid images are unchanged; confirm.
            with Image.open(path) as img:
                decoded = img.convert('RGB')
        except Exception:
            # Not an image (HTML error page, partial download, ...): skip it.
            continue
        images.append(decoded)
        cleaned_paths.append(path)
    # Usable images vs. candidates, as a quick sanity check.
    print(len(images), len(image_paths))
    return model.encode(images, batch_size=1024, show_progress_bar=True), cleaned_paths


In [10]:
## embed data
# Every file in the download folder is a candidate; batch_encode returns the
# embeddings together with the subset of paths it managed to open.
candidate_paths = glob.glob('./meme2-images/*')
embeddings, image_paths = batch_encode(candidate_paths)
len(embeddings)

100%|██████████| 1024/1024 [00:00<00:00, 5187.58it/s]


985 1024


Batches:   0%|          | 0/1 [00:00<?, ?it/s]



985

In [11]:
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

In [12]:
# Connect to the Qdrant server on localhost and list the collections it already holds.
client = QdrantClient('localhost')
collections = client.get_collections()
collections

CollectionsResponse(collections=[CollectionDescription(name='meme-images')])

In [13]:
# Size the collection from the vectors that were actually produced instead of a
# hard-coded 512, so swapping the encoder cannot create a dimension mismatch.
assert len(embeddings) > 0, "no embeddings to index"
client.recreate_collection(
    collection_name='meme-images',
    vectors_config=rest.VectorParams(
        size=len(embeddings[0]),
        distance=rest.Distance.COSINE,
    )
)

True

In [14]:
import uuid

# Each vector must line up with exactly one payload entry.
assert len(embeddings) == len(image_paths)

# Random hex ids, one per embedding.
meme_ids = [uuid.uuid4().hex for _ in embeddings]
payloads = [{'image_path': stored_path} for stored_path in image_paths]

client.upload_collection(
    collection_name="meme-images",
    vectors=list(embeddings),
    payload=payloads,
    ids=meme_ids,
)

In [15]:
# Sanity check: ask Qdrant how many points the "meme-images" collection holds.
client.count("meme-images")

CountResult(count=985)

In [16]:

def search(query: str, limit: int = 3):
    """Search the "meme-images" collection with a text query and show the hits.

    The query is embedded with the module-level `model`, the closest `limit`
    points are fetched from Qdrant, their scores are printed and the images
    referenced by each hit's `image_path` payload are rendered in one table row.

    Args:
        query: free-text search phrase.
        limit: maximum number of results to fetch and display (default 3).
    """
    search_result = client.search(
        collection_name="meme-images",
        query_vector=model.encode(query),
        limit=limit,
    )
    print([data.score for data in search_result])
    output_images = [data.payload['image_path'] for data in search_result]
    images_html = "".join(
        f"<td><img src='{path}' width=400></td>"
        for path in output_images
    )
    display(HTML(f"<table><tr>{images_html}</tr></table>"))

In [17]:
# Text query: prints the scores of the top hits and renders the matching images.
search('dogs')

[0.3084375, 0.30787122, 0.30299708]


In [18]:
# Same lookup for a different animal, to compare against the 'dogs' results.
search('cats')

[0.30072117, 0.2999292, 0.29837263]


In [19]:
# Query for a subject that none of the search phrases used for scraping mention.
search('humans')

[0.28624472, 0.28158134, 0.28155106]


In [29]:
# Query by mood rather than by animal.
search('shy memes')

[0.29159823, 0.2883101, 0.28807914]
