# Multimodal Search

## Dataset Import + Embeddings computation

In [16]:
# https://huggingface.co/datasets/detection-datasets/coco
import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, CvtForImageClassification


dataset = load_dataset("detection-datasets/coco")
# dataset = load_dataset("beans")

# https://huggingface.co/blog/image-similarity

from transformers import AutoFeatureExtractor, AutoModel

# model_ckpt = "nateraw/vit-base-beans"
# extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
# model = AutoModel.from_pretrained(model_ckpt)

extractor = AutoFeatureExtractor.from_pretrained('microsoft/cvt-21-384')
model = AutoModel.from_pretrained('microsoft/cvt-21-384')
device = "mps"


Found cached dataset parquet (/Users/baptiste/.cache/huggingface/datasets/detection-datasets___parquet/detection-datasets--coco-64ef6d5414f6b8df/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 2/2 [00:00<00:00, 10.84it/s]
Some weights of the model checkpoint at microsoft/cvt-21-384 were not used when initializing CvtModel: ['layernorm.bias', 'classifier.bias', 'layernorm.weight', 'classifier.weight']
- This IS expected if you are initializing CvtModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CvtModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[0.229, 0.224, 0.225]

In [35]:
image = dataset["train"][0]["image"]
image = extractor(images=image, return_tensors="pt")["pixel_values"].to(device)
print(image.shape)

embeddings = model(image)
print(embeddings.last_hidden_state[:, 0].shape)



torch.Size([1, 3, 384, 384])
torch.Size([1, 24, 24])


In [20]:
import torchvision.transforms as T
import os

# Data transformation chain.
transformation_chain = T.Compose(
    [
        # use https://huggingface.co/microsoft/cvt-21-384 
        T.Resize(384),
        T.CenterCrop(384),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
        T
    ]
)

def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        images = torch.stack([transformation_chain(image) for image in images]).to(device)
        print(images.shape)
        # bap i think create a stack to apply the transformation chain
        # image_batch_transformed = torch.stack(
        #     [transformation_chain(image) for image in images]
        # )
        # new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            embeddings = model(images).last_hidden_state[:, 0].cpu()

            # embeddings = model(**new_batch)
            # embeddings = embeddings.last_hidden_state.cpu()
            # print(embeddings.keys())
            # print(embeddings.last_hidden_state.shape)
            # 
        return {"embeddings": embeddings}

    return pp


batch_size = 16
extract_fn = extract_embeddings(model.to(device))

save_path = "./data/coco_embeddings"

# if coco embedding dataset is not on disk, do map
if not os.path.exists(save_path):
    dataset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)
    dataset_emb.save_to_disk(save_path)
else:
    dataset_emb = load_dataset(save_path)

Map:   0%|          | 0/117266 [00:00<?, ? examples/s]

torch.Size([16, 3, 384, 384])


                                                                 

RuntimeError: output with shape [1, 384, 384] doesn't match the broadcast shape [3, 384, 384]

In [14]:
# https://huggingface.co/course/chapter5/6?fw=pt#using-faiss-for-efficient-similarity-search

dataset_emb.add_faiss_index(column="embeddings")
question_embedding = dataset_emb.tensor(dataset_emb["embeddings"][0]).cpu().detach().numpy()
# print(torch.tensor(question_embedding))



scores, samples = dataset_emb.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

print(samples["image_file_path"])

100%|██████████| 1/1 [00:00<00:00, 1185.17it/s]

['/Users/baptiste/.cache/huggingface/datasets/downloads/extracted/fb3c1511c735d9d1ddc5c6e6a082fa39ab9b92605bd82d8dbc292c1a0dffb1a5/train/bean_rust/bean_rust_train.214.jpg', '/Users/baptiste/.cache/huggingface/datasets/downloads/extracted/fb3c1511c735d9d1ddc5c6e6a082fa39ab9b92605bd82d8dbc292c1a0dffb1a5/train/bean_rust/bean_rust_train.216.jpg', '/Users/baptiste/.cache/huggingface/datasets/downloads/extracted/fb3c1511c735d9d1ddc5c6e6a082fa39ab9b92605bd82d8dbc292c1a0dffb1a5/train/bean_rust/bean_rust_train.47.jpg', '/Users/baptiste/.cache/huggingface/datasets/downloads/extracted/fb3c1511c735d9d1ddc5c6e6a082fa39ab9b92605bd82d8dbc292c1a0dffb1a5/train/bean_rust/bean_rust_train.334.jpg', '/Users/baptiste/.cache/huggingface/datasets/downloads/extracted/fb3c1511c735d9d1ddc5c6e6a082fa39ab9b92605bd82d8dbc292c1a0dffb1a5/train/bean_rust/bean_rust_train.79.jpg']





## Pix2Pix

In [None]:
# https://huggingface.co/timbrooks/instruct-pix2pix

import PIL
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe.to("mps")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

url = "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg"
def download_image(url):
    image = PIL.Image.open(requests.get(url, stream=True).raw)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
image = download_image(url)

prompt = "turn him into cyborg"
images = pipe(prompt, image=image, num_inference_steps=10, image_guidance_scale=1).images
images[0]

  from .autonotebook import tqdm as notebook_tqdm
Fetching 15 files: 100%|██████████| 15/15 [00:00<00:00, 219980.98it/s]


KeyboardInterrupt: 

## UI

In [None]:
import gradio as gr

def image_classifier(inp):
    return {'cat': 0.3, 'dog': 0.7}

# 3 images return the modified image, the most similiar to the input image, and the most similiar to the modified images
# todo make this return n most similiar images
demo = gr.Interface(fn=image_classifier, inputs=["image", "text"], outputs=["image", "image", "image"])
demo.launch()



Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


