In [15]:
import pandas as pd 
import numpy as np
import urllib.request
from transformers import CLIPProcessor, CLIPModel

from IPython.display import display, HTML
import ipywidgets as widgets

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
def compute_text_embeddings(list):
    inputs = processor(text=list, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs) #embeddings

urllib.request.urlretrieve('https://huggingface.co/spaces/vivien/clip/raw/main/data.csv', 'data.csv')
urllib.request.urlretrieve('https://huggingface.co/spaces/vivien/clip/raw/main/data2.csv', 'data2.csv')
urllib.request.urlretrieve('https://huggingface.co/spaces/vivien/clip/resolve/main/embeddings-vit-base-patch32.npy', 'embeddings.npy')
urllib.request.urlretrieve('https://huggingface.co/spaces/vivien/clip/resolve/main/embeddings2-vit-base-patch32.npy', 'embeddings2.npy')

('embeddings2.npy', <http.client.HTTPMessage at 0x32c4e2630>)

In [None]:
df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
for k in [0, 1]: # helps with cosine similarity
  embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}

def get_html(url_list, height=200):
    html = "<div style='margin-top: 20px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url, title, link in url_list:
        html2 = f"<img title='{title}' style='height: {height}px; margin-bottom: 10px' src='{url}'>"
        if len(link) > 0:
            html2 = f"<a href='{link}' target='_blank'>" + html2 + "</a>"
        html = html + html2
    html += "</div>"
    return html

In [17]:
query = widgets.Text(layout=widgets.Layout(width='400px'))
dataset =widgets.Dropdown(
    options=['Unsplash', 'Movies'],
    value='Unsplash'
)
button = widgets.Button(description="Search")
output = widgets.Output()

display(widgets.HBox([query, button, dataset],
                     layout=widgets.Layout(justify_content='center')),
        output)

def image_search(query, n_results=24):
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if dataset.value == 'Unsplash' else 1
    results = np.argsort((embeddings[k]@text_embeddings.T)[:, 0])[-1:-n_results-1:-1]
    return [(df[k].iloc[i]['path'],
             df[k].iloc[i]['tooltip'] + source[k],
             df[k].iloc[i]['link']) for i in results]

def on_button_clicked(b):
    if len(query.value) > 0:
        results = image_search(query.value)
        output.clear_output()
        with output:
            display(HTML(get_html(results)))

button.on_click(on_button_clicked)
dataset.observe(on_button_clicked, names='value')

HBox(children=(Text(value='', layout=Layout(width='400px')), Button(description='Search', style=ButtonStyle())â€¦

Output()