# Multimodal search using CLIP

In [3]:
!pip install --quiet -U lancedb
!pip install --quiet gradio transformers torch torchvision duckdb
! pip install pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985

## First-run setup: download and pre-process the data


In [1]:
import io
import PIL
import duckdb
import lancedb

In [3]:
!wget https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
!tar -xvf diffusiondb_lance.tar.gz
!mv diffusiondb_test rawdata.lance

--2023-07-05 14:41:47--  https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
Resolving eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)... 52.218.224.17, 52.92.192.98, 52.92.131.34, ...
Connecting to eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)|52.218.224.17|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6121107645 (5.7G) [application/x-gzip]
Saving to: ‘diffusiondb_lance.tar.gz’


2023-07-05 15:11:05 (3.32 MB/s) - ‘diffusiondb_lance.tar.gz’ saved [6121107645/6121107645]

x diffusiondb_test/
x diffusiondb_test/_versions/
x diffusiondb_test/_latest.manifest
x diffusiondb_test/data/
x diffusiondb_test/data/138fc0d8-a806-4b10-84f8-00dc381afdad.lance
x diffusiondb_test/_versions/1.manifest


## Create / Open LanceDB Table

In [3]:
import pyarrow.compute as pc
import lance

# Connect to the local LanceDB database directory.
db = lancedb.connect("~/datasets/demo")

if "diffusiondb" not in db.table_names():
    # First run: read the raw Lance dataset into an in-memory table ...
    data = lance.dataset("rawdata.lance").to_table()
    # ... keep only the rows whose prompt is not null, and store them as the table.
    tbl = db.create_table(
        "diffusiondb",
        data.filter(~pc.field("prompt").is_null()),
        mode="overwrite",
    )
    # Build the full-text-search index over the prompt column.
    tbl.create_fts_index(["prompt"])
else:
    # The table already exists from an earlier run: just open it.
    tbl = db.open_table("diffusiondb")

## Create CLIP embedding function for the text

In [4]:
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

# Checkpoint identifier passed to every `from_pretrained` call below, so the
# tokenizer, model and processor are all loaded from the same checkpoint.
MODEL_ID = "openai/clip-vit-base-patch32"

# Tokenizer that `embed_func` uses to turn a query string into model inputs.
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
# CLIP model; `embed_func` calls its `get_text_features`.
model = CLIPModel.from_pretrained(MODEL_ID)
# Processor for the same checkpoint; not referenced by the cells visible here.
processor = CLIPProcessor.from_pretrained(MODEL_ID)

def embed_func(query):
    """Embed a text query with the CLIP text encoder.

    Args:
        query: The query string to embed.

    Returns:
        The text features of `query` as a NumPy array (the single row of the
        one-item batch that is sent through the model).
    """
    # truncation=True clips over-long queries to the tokenizer's maximum
    # length; CLIP's text encoder has a fixed context size, so without it a
    # long query would reach the model unclipped and fail there.
    inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    text_features = model.get_text_features(**inputs)
    # Detach from the autograd graph before converting; [0] unwraps the batch.
    return text_features.detach().numpy()[0]

  from .autonotebook import tqdm as notebook_tqdm



## Search functions for Gradio

In [5]:
def find_image_vectors(query):
    """Search the table by embedding similarity and show how to reproduce it.

    Args:
        query: Free-text query; it is embedded with `embed_func`.

    Returns:
        A `(results, code)` tuple: the output of `_extract` for the first 9
        matches, and a Python snippet (as a string) that performs the same
        search.
    """
    emb = embed_func(query)
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        # !r renders the query as a valid Python string literal even when it
        # contains quotes or backslashes; plain queries render exactly as before.
        f"embedding = embed_func({query!r})\n"
        "tbl.search(embedding).limit(9).to_df()"
    )
    return (_extract(tbl.search(emb).limit(9).to_df()), code)

def find_image_keywords(query):
    """Search the table with a keyword query and show how to reproduce it.

    Args:
        query: Keyword string; it is passed to `tbl.search` unchanged.

    Returns:
        A `(results, code)` tuple: the output of `_extract` for the first 9
        matches, and a Python snippet (as a string) that performs the same
        search.
    """
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        # !r renders the query as a valid Python string literal even when it
        # contains quotes or backslashes; plain queries render exactly as before.
        f"tbl.search({query!r}).limit(9).to_df()"
    )
    return (_extract(tbl.search(query).limit(9).to_df()), code)

def find_image_sql(query):
    """Run a SQL query over the table with DuckDB and show how to reproduce it.

    Args:
        query: SQL text that refers to the table as `diffusiondb`.

    Returns:
        A `(results, code)` tuple: the output of `_extract` for the query
        result, and a Python snippet (as a string) that runs the same query.

    Note:
        The SQL is executed exactly as typed by the user. That is fine for a
        local demo, but do not expose this handler to untrusted users.
    """
    code = (
        "import lancedb\n"
        "import duckdb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        "diffusiondb = tbl.to_lance()\n"
        # !r renders the SQL as a valid Python string literal even when it
        # contains single quotes, which SQL string constants normally do.
        f"duckdb.sql({query!r}).to_df()"
    )
    # Not a dead variable: DuckDB resolves the table name `diffusiondb` in the
    # SQL against the Python variable of the same name in this scope, so the
    # local must exist and keep this exact name.
    diffusiondb = tbl.to_lance()
    return (_extract(duckdb.sql(query).to_df()), code)

def _extract(df):
    image_col = "image"
    return [(PIL.Image.open(io.BytesIO(row[image_col])), row["prompt"]) for _, row in df.iterrows()]

## Set up Gradio interface

In [6]:
import gradio as gr


# Three query tabs (embeddings / keywords / SQL) share one code panel and one
# image gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Embeddings"):
            vector_query = gr.Textbox(value="portraits of a person", show_label=False)
            b1 = gr.Button("Submit")
        with gr.Tab("Keywords"):
            keyword_query = gr.Textbox(value="ninja turtle", show_label=False)
            b2 = gr.Button("Submit")
        with gr.Tab("SQL"):
            sql_query = gr.Textbox(value="SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9", show_label=False)
            b3 = gr.Button("Submit")
    with gr.Row():
        code = gr.Code(label="Code", language="python")
    with gr.Row():
        # Layout options are passed to the constructor: the chained
        # `Gallery.style(...)` helper was deprecated and later removed from
        # Gradio, and this notebook installs Gradio without a version pin.
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=3,
            rows=3,
            object_fit="contain",
            height="auto",
        )

    # Each search function returns (gallery items, code snippet), in the same
    # order as `outputs`.
    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])
    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])
    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])

demo.launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


