# 1. Run Stable Diffusion to generate some images
# 2. Make a Dataset of these images, and their image embeddings
# 3. Semantic image search via Nearest Neighbor text query embeddings

## Motivating example
  * there are pictures of 
    * a chinchilla
    * a grizzly
    * a bouquet of roses
    * a couple smiling at a salad
    * a concrete brutalist building
  * I want to find all of the pictures closest to "a fuzzy cute animal"


# 0. Imports and prep.

In [1]:
subdir = "image-semantic-search"

! mkdir -p {subdir}

import torch
from diffusers import StableDiffusionPipeline
from datetime import datetime
from PIL import Image
from pathlib import Path
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
prompts = [
    "a chinchilla", "a grizzly bear", "a bouquet of roses",
    "an ethnically diverse couple smiling at a salad", "a concrete brutalist apartment building"
]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

# 1. Generate a data set.

In [2]:
for prompt in (prompts * 1):
    pipe(prompt).images[0].save(Path(subdir) / f"{prompt} {datetime.now().timestamp()}.jpg")

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

# 2. Compile this data set and compute embeddings.

In [3]:
imgfold = load_dataset("imagefolder", data_dir=Path(subdir))['train']

Downloading and preparing dataset imagefolder/default to /Users/lsb/.cache/huggingface/datasets/imagefolder/default-f8063eea6ab80210/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset imagefolder downloaded and prepared to /Users/lsb/.cache/huggingface/datasets/imagefolder/default-f8063eea6ab80210/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
imgfold = imgfold.map(lambda e: {
    "name": Path(e['image'].filename).name,
    "embedding": model.get_image_features(**processor(images=e['image'], return_tensors="pt", padding=True))[0].detach().numpy(),
})

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [5]:
imgfold.add_faiss_index(column="embedding", index_name="flat", string_factory="Flat")

  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'name', 'embedding'],
    num_rows: 5
})

# 3. Nearest Neighbor search

In [6]:
imgfold.get_nearest_examples(
    "flat",
    model.get_text_features(**tokenizer("high-rise architecture", padding=True, return_tensors="pt"))[0].detach().numpy(),
    k=3,
).examples['name']

['a concrete brutalist apartment building 1679299607.33792.jpg',
 'a chinchilla 1679298989.838047.jpg',
 'an ethnically diverse couple smiling at a salad 1679299455.227876.jpg']

In [7]:
imgfold.get_nearest_examples(
    "flat",
    model.get_text_features(**tokenizer("fuzzy, cute", padding=True, return_tensors="pt"))[0].detach().numpy(),
    k=3,
).examples['name']

['a chinchilla 1679298989.838047.jpg',
 'a concrete brutalist apartment building 1679299607.33792.jpg',
 'a bouquet of roses 1679299300.107625.jpg']

In [8]:
imgfold.get_nearest_examples(
    "flat",
    model.get_text_features(**tokenizer("flower arrangement", padding=True, return_tensors="pt"))[0].detach().numpy(),
    k=3,
).examples['name']

['a bouquet of roses 1679299300.107625.jpg',
 'an ethnically diverse couple smiling at a salad 1679299455.227876.jpg',
 'a concrete brutalist apartment building 1679299607.33792.jpg']