In [1]:
import pandas as pd, numpy as np
import os
import urllib.request
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel, logging

from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets

logging.get_verbosity = lambda: logging.NOTSET
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clear_output()

def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)

In [2]:
urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1bt1O-iArKuU9LGkMV1zUPTEHZk8k7L65', '../data/data.csv')
urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=19aVnFBY-Rc0-3VErF_C7PojmWpBsb5wk', '../data/data2.csv')
urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1onKr-pfWb4l6LgL-z8WDod3NMW-nIJxE', '../data/embeddings.npy')
urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1KbwUkE0T8bpnHraqSzTeGGV4-TZO_CFB', '../data/embeddings2.npy')


('../data/embeddings2.npy', <http.client.HTTPMessage at 0x10476bdf0>)

In [3]:
df = {0: pd.read_csv('../data/data.csv'), 1: pd.read_csv('../data/data2.csv')}
embeddings = {0: np.load('../data/embeddings.npy'), 1: np.load('../data/embeddings2.npy')}
for k in [0, 1]:
  embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}

In [6]:
df[0]

Unnamed: 0,path,tooltip,link
0,https://images.unsplash.com/uploads/1411949294...,"""Woman exploring a forest"" by Michelle Spencer",https://unsplash.com/photos/XMyPniM9LF0
1,https://images.unsplash.com/photo-141633941111...,"""Succulents in a terrarium"" by Jeff Sheldon",https://unsplash.com/photos/rDLBArZUl1c
2,https://images.unsplash.com/photo-142014251503...,"""Rural winter mountainside"" by John Price",https://unsplash.com/photos/cNDGZ2sQ3Bo
3,https://images.unsplash.com/photo-141487280988...,"""Poppy seeds and flowers"" by Kris Atomic",https://unsplash.com/photos/iuZ_D1eoq9k
4,https://images.unsplash.com/photo-141700759404...,"""Silhouette near dark trees"" by Jonas Eriksson",https://unsplash.com/photos/BeD3vjQ8SI0
...,...,...,...
24995,https://images.unsplash.com/photo-159300793778...,Photo by De an Sun,https://unsplash.com/photos/c7OrOMxrurA
24996,https://images.unsplash.com/photo-159296761254...,"""Pearl earrings and seashells"" by Content Pixie",https://unsplash.com/photos/15IuQ5a0Qwg
24997,https://images.unsplash.com/photo-159299937329...,Photo by Maurits Bausenhart,https://unsplash.com/photos/w8nrcXz8pwk
24998,https://images.unsplash.com/photo-159192792878...,"""Floral truck in the streets of Rome"" by Keith...",https://unsplash.com/photos/n1jHrRhehUI


In [7]:
df[1]

Unnamed: 0,path,tooltip,link
0,http://image.tmdb.org/t/p/w780/5hNcsnMkwU2LknL...,Dilwale Dulhania Le Jayenge,https://www.themoviedb.org/movie/19404
1,http://image.tmdb.org/t/p/w780/gNBCvtYyGPbjPCT...,Dilwale Dulhania Le Jayenge,https://www.themoviedb.org/movie/19404
2,http://image.tmdb.org/t/p/w780/iNh3BivHyg5sQRP...,The Shawshank Redemption,https://www.themoviedb.org/movie/278
3,http://image.tmdb.org/t/p/w780/9Xw0I5RV2ZqNLpu...,The Shawshank Redemption,https://www.themoviedb.org/movie/278
4,http://image.tmdb.org/t/p/w780/kXfqcdQKsToO0OU...,The Shawshank Redemption,https://www.themoviedb.org/movie/278
...,...,...,...
8165,http://image.tmdb.org/t/p/w780/hdypWIqmK47ACp1...,Every Day,https://www.themoviedb.org/movie/465136
8166,http://image.tmdb.org/t/p/w780/amycp73vQvnYmQX...,Every Day,https://www.themoviedb.org/movie/465136
8167,http://image.tmdb.org/t/p/w780/jXGT06zsyhNzrLy...,Every Day,https://www.themoviedb.org/movie/465136
8168,http://image.tmdb.org/t/p/w780/87vuFOt2vMofvZe...,Every Day,https://www.themoviedb.org/movie/465136


In [9]:
embeddings.keys()

dict_keys([0, 1])

In [11]:
embeddings[0].shape

(25000, 512)

In [5]:
def get_html(url_list, height=200):
    html = "<div style='margin-top: 20px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url, title, link in url_list:
        html2 = f"<img title='{title}' style='height: {height}px; margin-bottom: 10px' src='{url}'>"
        if len(link) > 0:
            html2 = f"<a href='{link}' target='_blank'>" + html2 + "</a>"
        html = html + html2
    html += "</div>"
    return html

query = widgets.Text(layout=widgets.Layout(width='400px'))
dataset =widgets.Dropdown(
    options=['Unsplash', 'Movies'],
    value='Unsplash'
)
button = widgets.Button(description="Search")
output = widgets.Output()

display(widgets.HBox([query, button, dataset],
                     layout=widgets.Layout(justify_content='center')),
        output)

def image_search(query, n_results=24):
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if dataset.value == 'Unsplash' else 1
    results = np.argsort((embeddings[k]@text_embeddings.T)[:, 0])[-1:-n_results-1:-1]
    return [(df[k].iloc[i]['path'],
             df[k].iloc[i]['tooltip'] + source[k],
             df[k].iloc[i]['link']) for i in results]

def on_button_clicked(b):
    if len(query.value) > 0:
        results = image_search(query.value)
        output.clear_output()
        with output:
            display(HTML(get_html(results)))

button.on_click(on_button_clicked)
dataset.observe(on_button_clicked, names='value')

HBox(children=(Text(value='', layout=Layout(width='400px')), Button(description='Search', style=ButtonStyle())…

Output()