# OpenAI CLIP Demo

This notebook is based on the [minimal, user-friendly demo of OpenAI's CLIP on GitHub](https://github.com/vivien000/clip-demo), which is also available as a [Hugging Face CLIP demo](https://huggingface.co/spaces/vivien/clip).

* [Huggingface CLIP model](https://huggingface.co/docs/transformers/main/en/model_doc/clip)
* [Huggingface Model openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)

In [None]:
!pip install -q transformers torch tensorflow > /dev/null

In [None]:
import os
import glob
import time
import pathlib
import urllib.request
import urllib
import multiprocessing
from multiprocessing.dummy import Pool
from typing import (
    List,
    Dict,
    Callable,
    Any,
    Union,
    Optional
)

from PIL import Image, ImageFile
from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets

import pandas as pd
import numpy as np
from transformers import (
    CLIPProcessor, 
    CLIPTextModel, 
    CLIPModel, 
    logging
)
import torch

In [None]:
# NOTE: `logging` here is `transformers.logging` (imported above), not the
# standard-library module.
# Replace its `get_verbosity` function with a stub that always reports NOTSET.
# NOTE(review): presumably meant to quieten transformers' log output — confirm
# the actual effect.
logging.get_verbosity = lambda: logging.NOTSET
# Clear the output of the current notebook cell.
clear_output()

# Environment

In [None]:
# Device selection: use the GPU when CUDA is usable, otherwise the CPU.
DEVICE_CPU: str = 'cpu'
DEVICE_CUDA: str = 'cuda'
DEVICE_IS_CUDA: bool = torch.cuda.is_available()
if DEVICE_IS_CUDA:
    DEVICE = torch.device(DEVICE_CUDA)
else:
    DEVICE = torch.device(DEVICE_CPU)
# Device kind as a plain string ('cpu' or 'cuda').
DEVICE_TYPE: str = DEVICE.type

# Constants

In [None]:
# Hugging Face model identifier used for both the model and its processor.
MODEL_NAME: str = "openai/clip-vit-base-patch32"

# Root directory for the downloaded CSV files and images.
DATA_DIR = "./data"

# When True, the image embeddings are (re)computed by the embedding cell.
RUN_EMBEDDING: bool = True

# CPU count as reported by multiprocessing; used to size the download pool.
NUM_CPUS: int = multiprocessing.cpu_count()

# Utility

In [None]:
def fetch_url(url_filename, data_dir=DATA_DIR):
    """Download one URL into ``data_dir`` on a best-effort basis.

    Args:
        url_filename: ``(url, filename)`` pair. The resource at ``url`` is
            stored as ``filename`` inside ``data_dir``.
        data_dir: Destination directory (must already exist).

    Returns:
        None. A failed download is reported with ``print`` and otherwise
        ignored, so one bad URL does not abort a batch of downloads.
    """
    url, filename = url_filename
    try:
        urllib.request.urlretrieve(url, os.path.join(data_dir, filename))
    except urllib.error.URLError as error:
        # URLError is the base class of HTTPError, so this covers both HTTP
        # error statuses and connection-level failures (DNS, refused, ...).
        print(f"fetch URL:[{url}] filename:[{filename}] failed due to {error}")

def get_fetch_url(data_dir):
    """Return a one-argument downloader bound to ``data_dir``.

    The returned callable takes a ``(url, filename)`` pair and forwards it to
    ``fetch_url`` together with ``data_dir``. It returns None.
    """
    def _fetch_into_dir(pair):
        fetch_url(pair, data_dir)

    return _fetch_into_dir

def load_image(path_to_file, same_height=False):
    """Open an image, convert it to RGB and rescale it to a 224-pixel side.

    Args:
        path_to_file: Path of the image file to open.
        same_height: When True, scale so the height becomes 224 pixels.
            Otherwise scale so the shorter side becomes 224 pixels.
            The aspect ratio is preserved in both cases.

    Returns:
        The resized RGB ``PIL.Image.Image``, or None when ``path_to_file``
        does not exist (a message is printed in that case).
    """
    try:
        # The context manager releases the underlying file handle;
        # convert() and resize() return images that no longer depend on it.
        with Image.open(path_to_file) as source:
            im = source if source.mode == 'RGB' else source.convert('RGB')
            width, height = im.size
            reference = height if same_height else min(width, height)
            ratio = 224 / reference
            return im.resize((int(width * ratio), int(height * ratio)))
    except FileNotFoundError:
        print(f"path: {path_to_file} does not exist.")
        return None
        
        
def is_file(path, data_dir=DATA_DIR):
    """Tell whether ``path``, taken relative to ``data_dir``, is a regular file."""
    candidate = pathlib.Path(data_dir) / path
    return candidate.is_file()

# Data

In [None]:
# Create the data directory (and any missing parents); do nothing if it already exists.
pathlib.Path(DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Download the two metadata CSV files into DATA_DIR. Later cells read the
# 'path' column (image URL) and the 'tooltip' / 'link' columns from them.
urllib.request.urlretrieve(
    'https://drive.google.com/uc?export=download&id=1bt1O-iArKuU9LGkMV1zUPTEHZk8k7L65', 
    os.path.join(DATA_DIR, 'unsplush.csv')
)
urllib.request.urlretrieve(
    'https://drive.google.com/uc?export=download&id=19aVnFBY-Rc0-3VErF_C7PojmWpBsb5wk', 
    os.path.join(DATA_DIR, 'movies.csv')
)
# Disabled downloads of .npy embedding files (presumably precomputed — TODO
# confirm); this notebook computes the embeddings itself further down.
# urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1onKr-pfWb4l6LgL-z8WDod3NMW-nIJxE', 'embeddings.npy')
# urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1KbwUkE0T8bpnHraqSzTeGGV4-TZO_CFB', 'embeddings2.npy')

In [None]:
def download(df, data_dir:str = DATA_DIR, fetch_fn=None):
    """Download every image listed in ``df`` using a pool of worker threads.

    Row ``i`` of ``df`` is fetched from ``df['path']`` and handed to
    ``fetch_fn`` as ``(url, "<i>.jpeg")``, so file names follow row positions.

    Args:
        df: DataFrame with a ``'path'`` column holding the image URLs.
        data_dir: Destination directory. Only used to build the default
            ``fetch_fn``; a caller-supplied ``fetch_fn`` decides where it writes.
        fetch_fn: Callable taking one ``(url, filename)`` pair. Defaults to
            ``get_fetch_url(data_dir)``.

    Returns:
        None. Progress is printed roughly every 200 images.
    """
    if fetch_fn is None:
        fetch_fn = get_fetch_url(data_dir)

    max_n_parallel = NUM_CPUS * 3
    latency = 1  # idle seconds between batches, to throttle the download rate
    report_every = 200
    next_report = report_every
    length = len(df)
    print(f"total images:[{length}]")
    if length == 0:
        return

    urls = df['path'].tolist()
    position: int = 0
    # One pool for the whole run; the context manager tears the worker
    # threads down instead of leaking a fresh pool per batch.
    with Pool(min(max_n_parallel, length)) as pool:
        while position < length:
            stop = min(position + max_n_parallel, length)
            url_filename_list = [
                (urls[index], str(index) + '.jpeg')
                for index in range(position, stop)
            ]
            pool.map(fetch_fn, url_filename_list)
            position = stop

            if position >= next_report:
                print(position)
                next_report += report_every

            # No need to idle once the last batch has been fetched.
            if position < length:
                time.sleep(latency)

In [None]:
# Download the images of each dataset into DATA_DIR/<name>/.
# NOTE(review): only 'movies' is enabled here, while the embedding cell below
# iterates over both 'unsplush' and 'movies' — confirm this is intended.
# for name in ['unsplush', 'movies']:
for name in ['movies']:
    # Per-dataset image directory.
    data_dir: str = os.path.join(DATA_DIR, name)
    pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True)
    
    # The CSV's 'path' column holds the image URLs consumed by download().
    df = pd.read_csv(os.path.join(DATA_DIR, f'{name}.csv'))
    download(df=df, data_dir=data_dir, fetch_fn=get_fetch_url(data_dir))

# Model

In [None]:
# Load the pretrained CLIP model and its matching processor by name.
# The processor is what turns raw text / images into model inputs in the
# embedding functions below.
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", from_tf=True)
model = CLIPModel.from_pretrained(MODEL_NAME)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

In [None]:
# Release cached, unused GPU memory before moving the model (CUDA only).
if DEVICE_IS_CUDA:
    torch.cuda.empty_cache()

# Move the model's parameters to the selected device.
model.to(DEVICE)

# Embedding

* [compute_CLIP_embeddings.ipynb](https://github.com/vivien000/clip-demo/blob/master/compute_CLIP_embeddings.ipynb)

In [None]:
# Let PIL load image files whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def compute_text_embeddings(
    queries: List[str], model: torch.nn.Module, device: torch.device
):
    """Embed text queries with the CLIP text tower.

    Args:
        queries: Texts to embed.
        model: Model exposing ``get_text_features``; expected to already be
            on ``device``.
        device: Device the tokenised inputs are moved to.

    Returns:
        Whatever ``model.get_text_features`` returns for the tokenised
        queries, computed without gradient tracking.

    Uses the module-level ``processor`` for tokenisation.
    """
    inputs = processor(text=queries, return_tensors="pt", padding=True).to(device)
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        return model.get_text_features(**inputs)


def compute_image_embeddings(
    images: List[np.ndarray], model: torch.nn.Module, device: torch.device
):
    """Embed a batch of images with the CLIP vision tower.

    Args:
        images: Images accepted by the module-level ``processor``.
            NOTE(review): callers in this notebook pass the PIL images
            returned by ``load_image`` rather than arrays — confirm the
            annotation.
        model: Model exposing ``get_image_features``; expected to already be
            on ``device``.
        device: Device the pixel tensor is moved to.

    Returns:
        Whatever ``model.get_image_features`` returns for the batch,
        computed without gradient tracking.
    """
    processed = processor(images=images, return_tensors="pt", padding=True)
    processed['pixel_values'] = processed['pixel_values'].to(device)

    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        return model.get_image_features(**processed)

In [None]:
def run_image_embeddings(
    image_dir:str, 
    model: torch.nn.Module, 
    device: torch.device,
    path_to_embedding_file:str, 
):
    """Embed the ``<index>.jpeg`` images of a directory and save them as .npy.

    Images are processed in batches of 32, in increasing index order. The
    stacked embeddings are written to ``path_to_embedding_file`` about every
    10 batches (checkpoint) and once more at the end.

    Args:
        image_dir: Directory holding images named ``0.jpeg``, ``1.jpeg``, ...
        model: CLIP model passed through to ``compute_image_embeddings``.
        device: Device passed through to ``compute_image_embeddings``.
        path_to_embedding_file: Destination path for ``np.save``.

    Raises:
        AssertionError: If the directory has no ``*.jpeg`` file, or if a
            whole batch of indices has no file on disk.

    NOTE(review): indices are scanned from 0 up to the number of files found,
    and missing indices are skipped. With gaps in the numbering, embedding
    rows no longer line up with the rows of the source CSV and the highest
    indices are never reached — confirm downloads are complete.
    """
    num_files: int = len(glob.glob(os.path.join(image_dir, '*.jpeg')))
    assert num_files > 0, f"no '*.jpeg' files found in [{image_dir}]"

    batch_size: int = 32
    checkpoint_interval: int = batch_size * 10
    next_checkpoint: int = checkpoint_interval
    # Collect per-batch arrays and stack on demand rather than re-copying
    # the whole matrix after every batch.
    batches: List[np.ndarray] = []
    num_embedded: int = 0

    for start in range(0, num_files, batch_size):
        names = [str(start + index) + '.jpeg' for index in range(batch_size)]
        # Look for each file inside image_dir (not the default data root).
        images = [
            load_image(path_to_file=os.path.join(image_dir, name))
            for name in names
            if is_file(name, data_dir=image_dir)
        ]
        assert len(images) > 0, \
            f"no image found for indices [{start}, {start + batch_size}) in [{image_dir}]"

        batch_embeddings = compute_image_embeddings(images=images, model=model, device=device)
        # .cpu() is a no-op for tensors that are already on the CPU.
        batch_array = batch_embeddings.detach().cpu().numpy()
        batches.append(batch_array)
        num_embedded += batch_array.shape[0]

        position = start + batch_size
        assert position >= num_embedded, \
            f"expected position:[{position}] >= size of embeddings:[{num_embedded}]"

        # Save the current embeddings
        if position >= next_checkpoint:
            next_checkpoint += checkpoint_interval
            np.save(path_to_embedding_file, np.vstack(batches))
            print(position)

    np.save(path_to_embedding_file, np.vstack(batches))

In [None]:
# Compute and save the image embeddings of each dataset (skipped when
# RUN_EMBEDDING is False). Each result goes to embedding_<name>.npy in the
# current working directory, where the "Load Embeddings" cell reads it.
if RUN_EMBEDDING:
    for name in ['unsplush', 'movies']:
        image_dir: str = os.path.join(DATA_DIR, name)
        path_to_embedding_file: str = f"embedding_{name}.npy"
        run_image_embeddings(
            image_dir=image_dir, 
            path_to_embedding_file=path_to_embedding_file, 
            model=model,
            device=DEVICE
        )

## Load Embeddings

In [None]:
# Metadata per dataset, keyed by dataset index (0: Unsplash, 1: movies).
# The CSV files were downloaded into DATA_DIR, so read them from there.
dataframes: Dict[int, pd.DataFrame] = {
    0: pd.read_csv(os.path.join(DATA_DIR, 'unsplush.csv')),
    1: pd.read_csv(os.path.join(DATA_DIR, 'movies.csv')),
}
# Saved image embeddings per dataset, keyed like `dataframes`.
embeddings = {
    0: np.load("embedding_unsplush.npy"), 
    1: np.load("embedding_movies.npy")
}

# Scale every row to unit L2 norm.
for k in [0, 1]:
    row_norms = np.sqrt((embeddings[k] ** 2).sum(axis=1, keepdims=True))
    embeddings[k] = embeddings[k] / row_norms

# Attribution text per dataset index; image_search appends it to each
# result's tooltip.
source = {
    0: '\nSource: Unsplash', 
    1: '\nSource: The Movie Database (TMDB)'
}

# Image Utility

In [None]:
def get_html(url_list, height=200):
    """Render search results as a wrapping flex row of images.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source, ``title`` the tooltip text. A non-empty string
            ``link`` wraps the image in an anchor that opens in a new tab;
            an empty or non-string ``link`` leaves the image unlinked.
        height: Image height in pixels.

    Returns:
        The HTML markup as a single string.
    """
    # Local import keeps this block self-contained.
    from html import escape

    parts = [
        "<div style='margin-top: 20px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        # Escape every interpolated value: attributes are single-quoted, so an
        # apostrophe in a title would otherwise end the attribute early.
        tag = (
            f"<img title='{escape(str(title))}' "
            f"style='height: {height}px; margin-bottom: 10px' "
            f"src='{escape(str(url))}'>"
        )
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    return "".join(parts)

# Image Search

In [None]:
# Quick check of the search. Run this cell only after the cell below that
# defines `image_search` and the `dataset` dropdown it reads.
results = image_search(query="lion", model=model, device=DEVICE)
results[5:]

In [None]:
# Search UI: a text box for the query, a dataset selector, a search button,
# and an output area that receives the rendered results.
query = widgets.Text(layout=widgets.Layout(width='400px'))
dataset =widgets.Dropdown(
    options=['Unsplash', 'Movies'],
    value='Unsplash'
)
button = widgets.Button(description="Search")
output = widgets.Output()

# Show the controls centred on one row, with the output area underneath.
display(
    widgets.HBox(
        [query, button, dataset],
        layout=widgets.Layout(justify_content='center')
    ),
    output
)

def image_search(
    query: str, model: torch.nn.Module, device:torch.device, n_results: int = 12
):
    """Return the images of the selected dataset that best match ``query``.

    The dataset is taken from the module-level ``dataset`` dropdown
    ('Unsplash' selects index 0, anything else index 1). Images are ranked by
    the dot product between their stored embedding and the query embedding.

    Args:
        query: Free-text search query.
        model: CLIP model passed to ``compute_text_embeddings``.
        device: Device passed to ``compute_text_embeddings``.
        n_results: Maximum number of results to return.

    Returns:
        List of ``(path, tooltip, link)`` tuples, best match first, where
        ``tooltip`` has the dataset's attribution text appended.
    """
    text_embeddings = compute_text_embeddings(queries=[query], model=model, device=device)
    # .cpu() is a no-op for tensors that are already on the CPU.
    text_embeddings = text_embeddings.detach().cpu().numpy()

    k = 0 if dataset.value == 'Unsplash' else 1
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending: reverse it, then keep the first n_results.
    best = np.argsort(scores)[::-1][:n_results]

    # Look rows up in the per-dataset metadata table.
    catalog = dataframes[k]
    rows = (catalog.iloc[i] for i in best)
    return [
        (row['path'], row['tooltip'] + source[k], row['link'])
        for row in rows
    ]

def on_button_clicked(b):
    """Run a search for the current query text and show the results.

    Does nothing when the query box is empty. ``b`` is the argument supplied
    by the widget callback machinery and is not used.
    """
    text = query.value
    if not len(text) > 0:
        return

    matches = image_search(query=text, model=model, device=DEVICE)
    markup = get_html(matches)
    # Replace the previous results with the new ones.
    output.clear_output()
    with output:
        display(HTML(markup))

# Run the search when the button is clicked, and again whenever the selected
# dataset changes.
button.on_click(on_button_clicked)
dataset.observe(on_button_clicked, names='value')