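# Multimodal search using CLIP

This notebook loads a DiffusionDB image dataset into a LanceDB table, embeds text queries with OpenAI's CLIP model, and serves vector, keyword (full-text), and SQL search over the images through a small Gradio app.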

![mmclip](https://www.researchgate.net/publication/363808556/figure/fig2/AS:11431281086053770@1664048343869/Architectures-of-the-designed-machine-learning-approaches-with-OpenAI-CLIP-model.jpg)

## Install dependencies

In [2]:
!pip install --quiet -U lancedb
!pip install --quiet gradio transformers torch torchvision duckdb
!pip install pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985

Collecting tantivy@ git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
  Cloning https://github.com/quickwit-oss/tantivy-py to c:\users\kaush\appdata\local\temp\pip-install-ltmzhzqb\tantivy_343fb804edcc44cbbdc4da3de36f142a
  Resolved https://github.com/quickwit-oss/tantivy-py to commit a47fcfb3a6ad3fa2fca76513bd52d840ff15c596
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'error'


  Running command git clone --filter=blob:none --quiet https://github.com/quickwit-oss/tantivy-py 'C:\Users\kaush\AppData\Local\Temp\pip-install-ltmzhzqb\tantivy_343fb804edcc44cbbdc4da3de36f142a'
  error: subprocess-exited-with-error
  
  Preparing metadata (pyproject.toml) did not run successfully.
  exit code: 1
  
  [6 lines of output]
  
  Cargo, the Rust package manager, is not installed or is not on PATH.
  This package requires Rust and Cargo to compile extensions. Install it through
  the system's package manager or via https://rustup.rs/
  
  Checking for Rust toolchain....
  [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

Encountered error while generating package metadata.

See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.


## First run setup: Download data and pre-process


In [2]:
import io
import PIL
import duckdb
import lancedb

In [3]:
!wget https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
!tar -xvf diffusiondb_lance.tar.gz
!mv diffusiondb_test rawdata.lance

--2023-07-05 14:41:47--  https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
Resolving eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)... 52.218.224.17, 52.92.192.98, 52.92.131.34, ...
Connecting to eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)|52.218.224.17|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6121107645 (5.7G) [application/x-gzip]
Saving to: ‘diffusiondb_lance.tar.gz’


2023-07-05 15:11:05 (3.32 MB/s) - ‘diffusiondb_lance.tar.gz’ saved [6121107645/6121107645]

x diffusiondb_test/
x diffusiondb_test/_versions/
x diffusiondb_test/_latest.manifest
x diffusiondb_test/data/
x diffusiondb_test/data/138fc0d8-a806-4b10-84f8-00dc381afdad.lance
x diffusiondb_test/_versions/1.manifest


## Create / Open LanceDB Table

In [3]:
from pathlib import Path

import pyarrow.compute as pc
import lance

# Open the LanceDB table if it was already built; otherwise build it once from
# the downloaded Lance dataset and add a full-text-search index on the prompts.
db = lancedb.connect("~/datasets/demo")
if "diffusiondb" in db.table_names():
    tbl = db.open_table("diffusiondb")
else:
    # The download cell renames the extracted folder to "rawdata.lance", but if
    # that cell was run more than once the dataset ends up nested one level
    # deeper. Accept either layout instead of hard-coding the nested one.
    raw_root = Path("rawdata.lance")
    nested_root = raw_root / "diffusiondb_test"
    dataset_path = nested_root if nested_root.exists() else raw_root

    # First data processing and full-text-search index
    data = lance.dataset(dataset_path.as_posix()).to_table()
    # remove null prompts
    tbl = db.create_table("diffusiondb", data.filter(~pc.field("prompt").is_null()), mode="overwrite")
    tbl.create_fts_index(["prompt"])

## Create CLIP embedding function for the text

In [9]:
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"

# Loaded once at cell execution; the first run downloads the weights.
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)


def embed_func(query):
    """Embed a single text query with CLIP's text encoder.

    Args:
        query: The text to embed.

    Returns:
        A 1-D NumPy array holding the text features of ``query``.
    """
    # truncation=True clips over-long queries to the tokenizer's maximum length
    # instead of letting the model fail on inputs longer than its text context.
    inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    text_features = model.get_text_features(**inputs)
    # A batch of one was encoded; return its only row.
    return text_features.detach().numpy()[0]

In [11]:
# Only a cell's last expression is rendered automatically, so show the schema explicitly.
display(tbl.schema)
# Read just the first five rows instead of converting the whole table
# (image bytes included) to pandas only to preview it.
tbl.to_lance().head(5).to_pandas()

Unnamed: 0,prompt,seed,step,cfg,sampler,width,height,timestamp,image_nsfw,prompt_nsfw,vector,image
0,"a renaissance portrait of dwayne johnson, art ...",2480545905,50,16.0,k_euler_ancestral,512,768,2022-08-20 05:28:00,0.163488,0.000793,"[0.22208574, 0.045346797, 0.3416304, 0.6416262...",b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...
1,"portrait of a dancing eagle woman, beautiful b...",2250159284,50,9.0,k_lms,512,640,2022-08-20 05:28:00,0.27665,0.00309,"[0.23513708, 0.23905377, -0.25548398, 0.15406,...",b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...
2,"epic 3 d, become legend shiji! gpu mecha contr...",4292948605,50,7.0,k_lms,512,768,2022-08-20 05:28:00,0.090421,0.000533,"[0.13876896, 0.02688741, -0.73428893, -0.00962...",b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...
3,an airbrush painting of cyber war machine scen...,2374713726,50,12.0,k_lms,512,768,2022-08-20 05:29:00,0.078309,0.000597,"[0.44222537, -0.16692133, 0.16401242, 0.335270...",b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...
4,concept art of a silent hill monster. painted ...,2320897141,50,6.0,k_lms,640,512,2022-08-20 05:29:00,0.086802,0.083516,"[0.21429706, -0.18471082, -0.30426037, 0.42390...",b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...



## Search functions for Gradio

In [6]:
def find_image_vectors(query):
    """Run a vector search for ``query`` and return results plus a code snippet.

    Args:
        query: Free-text query; it is embedded with ``embed_func``.

    Returns:
        A ``(images, code)`` tuple: the ``_extract`` output for the 9 nearest
        rows, and a Python snippet (as a string) that reproduces the search.
    """
    emb = embed_func(query)
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        # !r quotes/escapes the query so the snippet stays valid Python even
        # when the query itself contains quote characters.
        f"embedding = embed_func({query!r})\n"
        # Show the same call that is actually executed below.
        "tbl.search(embedding).limit(9).to_pandas()"
    )
    return (_extract(tbl.search(emb).limit(9).to_pandas()), code)

def find_image_keywords(query):
    """Run a keyword search for ``query`` and return results plus a code snippet.

    Args:
        query: Keyword text passed to ``tbl.search`` as a string.

    Returns:
        A ``(images, code)`` tuple: the ``_extract`` output for up to 9 matching
        rows, and a Python snippet (as a string) that reproduces the search.
    """
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        # !r quotes/escapes the query so the snippet stays valid Python even
        # when the query contains quote characters; the method shown matches
        # the call executed below.
        f"tbl.search({query!r}).limit(9).to_pandas()"
    )
    return (_extract(tbl.search(query).limit(9).to_pandas()), code)

def find_image_sql(query):
    """Run a SQL query against the table with DuckDB and return results plus code.

    Args:
        query: SQL text that refers to the table as ``diffusiondb``. The
            result must contain the ``image`` and ``prompt`` columns that
            ``_extract`` reads.

    Returns:
        A ``(images, code)`` tuple: the ``_extract`` output for the result rows,
        and a Python snippet (as a string) that reproduces the query.
    """
    code = (
        "import lancedb\n"
        "import duckdb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        "diffusiondb = tbl.to_lance()\n"
        # !r keeps the snippet valid Python when the SQL contains quotes,
        # e.g. WHERE sampler = 'k_lms'.
        f"duckdb.sql({query!r}).df()"
    )
    # Not dead code: the SQL text names this local variable as its table, and
    # DuckDB looks it up from the calling Python scope.
    diffusiondb = tbl.to_lance()
    # NOTE(review): the SQL is executed exactly as typed in the UI; acceptable for a
    # local demo, but do not expose this app publicly as-is.
    # DuckDB relations convert to pandas with .df(); they have no .to_pandas().
    return (_extract(duckdb.sql(query).df()), code)

def _extract(df):
    """Turn a result frame into ``(image, caption)`` pairs.

    Args:
        df: DataFrame with an ``image`` column of encoded image bytes and a
            ``prompt`` column.

    Returns:
        A list of ``(PIL image, prompt)`` tuples, one per row, in row order.
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it here
    # so this function does not rely on another library having loaded it first.
    from PIL import Image

    return [
        (Image.open(io.BytesIO(image_bytes)), prompt)
        for image_bytes, prompt in zip(df["image"], df["prompt"])
    ]

## Set up the Gradio interface

In [7]:
import gradio as gr


# One tab per search mode; all three write to the same code panel and gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Embeddings"):
            vector_query = gr.Textbox(value="portraits of a person", show_label=False)
            b1 = gr.Button("Submit")
        with gr.Tab("Keywords"):
            keyword_query = gr.Textbox(value="ninja turtle", show_label=False)
            b2 = gr.Button("Submit")
        with gr.Tab("SQL"):
            sql_query = gr.Textbox(value="SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9", show_label=False)
            b3 = gr.Button("Submit")
    with gr.Row():
        code = gr.Code(label="Code", language="python")
    with gr.Row():
        # Layout options are passed to the constructor: the chained .style(...)
        # call was deprecated and later removed from Gradio, and the install
        # cell does not pin a Gradio version.
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=3,
            rows=3,
            object_fit="contain",
            height="auto",
        )

    # Each handler returns (images, code), matching outputs=[gallery, code].
    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])
    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])
    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])

demo.launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


