# OpenAI CLIP Demo

Based on [GitHub: Minimal user-friendly demo of OpenAI's CLIP](https://github.com/vivien000/clip-demo), which is also available as the [Hugging Face CLIP demo](https://huggingface.co/spaces/vivien/clip), but heavily modified.

NOTE:

Hugging Face offers different interfaces/functions from those of the original CLIP. For example, ```CLIPProcessor``` handles both text and images, whereas the original CLIP on GitHub handles them separately (`clip.tokenize` for text and `preprocess` for images).

## References

### Open AI

* [OpenAI Github CLIP](https://github.com/openai/CLIP)

```
import torch
import clip   # <--- This is effective when cloning the git repository only
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
```

### Huggingface

* [Huggingface CLIP model](https://huggingface.co/docs/transformers/main/en/model_doc/clip)
* [CLIPProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPProcessor)

> Constructs a CLIP processor which wraps a CLIP image processor and a CLIP tokenizer into a single processor. CLIPProcessor offers all the functionalities of CLIPImageProcessor and CLIPTokenizerFast. See the __call__() and decode() for more information.

* [Huggingface Model openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)

### Tutorials

* [Quick-fire Guide to Multi-Modal ML With OpenAI’s CLIP](https://towardsdatascience.com/quick-fire-guide-to-multi-modal-ml-with-openais-clip-2dad7e398ac0)


In [3]:
!pip install -q transformers torch tensorflow > /dev/null

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.
pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.
sagemaker 2.145.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 6.3.0 which is incompatible.
sagemaker 2.145.0 requires PyYAML==5.4.1, but you have pyyaml 6.0 which is incompatible.
docker-compose 1.29.2 requires PyYAML<6,>=3.10, but you have pyyaml 6.0 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
import os
import glob
import time
import pathlib
import urllib.request
import urllib
import requests
import multiprocessing
from multiprocessing.dummy import Pool
from typing import (
    List,
    Dict,
    Callable,
    Any,
    Union,
    Optional
)

from PIL import Image, ImageFile
from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets

import pandas as pd
import numpy as np
from transformers import (
    CLIPProcessor, 
    CLIPTextModel, 
    CLIPModel, 
    logging
)
import torch

In [5]:
# NOTE(review): this replaces transformers' `logging.get_verbosity` with a
# function that always reports `logging.NOTSET`, presumably to quiet the
# library's log output — confirm the intent. If the goal is only to hide
# warnings, `logging.set_verbosity_error()` is the supported API.
logging.get_verbosity = lambda: logging.NOTSET
# Clear this cell's output area.
clear_output()

# Environment

In [6]:
# Device selection: use the GPU when CUDA is available, otherwise the CPU.
DEVICE_CPU: str = 'cpu'
DEVICE_CUDA: str = 'cuda'
DEVICE_IS_CUDA: bool = torch.cuda.is_available()
DEVICE = torch.device(DEVICE_CUDA) if DEVICE_IS_CUDA else torch.device(DEVICE_CPU)
DEVICE_TYPE: str = DEVICE.type

# Constants

In [7]:
# Parallelism and on-disk layout.
NUM_CPUS: int = multiprocessing.cpu_count()
DATA_DIR: str = "./data"
PATH_TO_UNSPLUSH_CSV: str = os.path.join(DATA_DIR, "unsplush.csv")
PATH_TO_MOVIES_CSV: str = os.path.join(DATA_DIR, "movies.csv")

# CLIP checkpoint name on the Hugging Face Hub and the embedding batch size.
MODEL_NAME: str = "openai/clip-vit-base-patch32"
BATCH_SIZE: int = 256

# Pipeline switches: re-download the images / recompute the embeddings.
DO_DOWNLOAD: bool = False
DO_EMBEDDING: bool = True

# Utility

In [8]:
def mkdir(path: str):
    """Create directory *path*, including any missing parents; no-op if it exists."""
    target = pathlib.Path(path)
    target.mkdir(exist_ok=True, parents=True)


def exists_url(url: str, timeout: Optional[float] = 10.0) -> bool:
    """Return whether *url* points to an existing resource.

    Sends an HTTP HEAD request and follows redirects, so a URL that redirects
    to a live resource counts as existing (``requests.head`` does not follow
    redirects unless asked to). A final status of 200 means the resource
    exists and 404 means it does not. Any other 4xx/5xx status raises
    ``requests.HTTPError`` via ``raise_for_status``. Remaining statuses
    return ``False``.

    Args:
        url: URL to probe.
        timeout: Seconds to wait for the server (``None`` waits forever).
            Without a timeout, a single unresponsive host can stall a scan
            over thousands of URLs.
    """
    response: requests.models.Response = requests.head(
        url, allow_redirects=True, timeout=timeout
    )
    if response.status_code not in (200, 404):
        response.raise_for_status()

    return response.status_code == 200


def fetch_url(url_filename, data_dir=DATA_DIR):
    """Download one file into *data_dir* unless it is already present.

    Args:
        url_filename: ``(url, filename)`` pair; the file is written to
            ``os.path.join(data_dir, filename)``.
        data_dir: Destination directory (presumably must already exist,
            since nothing here creates it).

    Raises:
        RuntimeError: if the server answers with an HTTP error; the original
            ``urllib.error.HTTPError`` is chained as ``__cause__``.
    """
    # Unpack outside the ``try`` so the error message can always reference
    # ``url`` and ``filename``.
    url, filename = url_filename
    path_to_image_file: str = os.path.join(data_dir, filename)
    if is_file(path_to_image_file):
        return  # already downloaded; makes re-runs cheap

    try:
        urllib.request.urlretrieve(url, path_to_image_file)
    except urllib.error.HTTPError as error:
        msg = f"fetch URL:[{url}] filename:[{filename}] failed due to {error}"
        raise RuntimeError(msg) from error

    # Brief pause after each real download to reduce the request rate.
    time.sleep(0.5)

        
def get_fetch_url(data_dir):
    """Return a one-argument callable that downloads into *data_dir*.

    ``Pool.map`` passes exactly one argument per item, so ``data_dir`` is
    bound here and each ``(url, filename)`` pair is forwarded to
    ``fetch_url``.
    """
    def _fetch_into_dir(url_filename):
        fetch_url(url_filename, data_dir)

    return _fetch_into_dir


def load_image(path_to_file, same_height=False):
    """Open an image, convert it to RGB and rescale it, keeping the aspect ratio.

    Args:
        path_to_file: Path of the image file.
        same_height: If true, scale so the height becomes 224 px; otherwise
            scale so the shorter side becomes 224 px.

    Returns:
        The resized ``PIL.Image.Image``, or ``None`` if the file does not
        exist (a message is printed instead of raising).
    """
    try:
        # ``with`` closes the underlying file; convert()/resize() return new
        # in-memory images, so the result stays valid after the block.
        with Image.open(path_to_file) as im:
            if im.mode != 'RGB':
                im = im.convert('RGB')
            reference_side = im.size[1] if same_height else min(im.size)
            ratio = 224 / reference_side
            return im.resize((int(im.size[0] * ratio), int(im.size[1] * ratio)))
    except FileNotFoundError:
        # Report the path that was actually requested.
        print(f"path: {path_to_file} does not exist.")
        return None
        
        
def is_file(path):
    """Return True when *path* names an existing regular file."""
    candidate = pathlib.Path(path)
    return candidate.is_file()

In [9]:
def remove_rows_with_non_exist_url(
    df: pd.DataFrame, exists_fn: Optional[Callable[[str], bool]] = None
) -> pd.DataFrame:
    """Drop, in place, every row whose ``path`` URL does not exist.

    Rows are addressed by index label throughout, so this works for any
    index, not only the default ``RangeIndex``. Afterwards the index is reset
    to ``0..n-1``. A progress counter is printed every 500 rows.

    Args:
        df: Frame with a ``path`` column of URLs; modified in place.
        exists_fn: Predicate ``url -> bool``; defaults to ``exists_url``.

    Returns:
        The same (mutated) *df*, for convenience.
    """
    if exists_fn is None:
        exists_fn = exists_url

    print("start remove_rows_with_non_exist_url")
    missings = []
    # Iterate (label, url) pairs; mixing labels with positional ``iloc``
    # would address the wrong rows on a non-default index.
    for position, (label, url) in enumerate(df['path'].items()):
        if not exists_fn(url):
            print(f"index:[{label}] url:[{url}] does not exist")
            missings.append(label)

        if position % 500 == 0:
            print(position)

    df.drop(labels=missings, axis=0, inplace=True)
    df.reset_index(drop=True, inplace=True)

    return df

# Data

In [9]:
# Make sure the local data directory exists before anything is downloaded into it.
mkdir(DATA_DIR)

In [14]:
# Fetch the two metadata CSVs from Google Drive, skipping any already on disk.
for path_to_csv, drive_url in (
    (PATH_TO_UNSPLUSH_CSV, 'https://drive.google.com/uc?export=download&id=1bt1O-iArKuU9LGkMV1zUPTEHZk8k7L65'),
    (PATH_TO_MOVIES_CSV, 'https://drive.google.com/uc?export=download&id=19aVnFBY-Rc0-3VErF_C7PojmWpBsb5wk'),
):
    if is_file(path_to_csv):
        continue
    urllib.request.urlretrieve(drive_url, path_to_csv)

# urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1onKr-pfWb4l6LgL-z8WDod3NMW-nIJxE', 'embeddings.npy')
# urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1KbwUkE0T8bpnHraqSzTeGGV4-TZO_CFB', 'embeddings2.npy')

Remove rows whose image URL does not exist (only when `DO_DOWNLOAD` is `True`).

In [10]:
# Load the Unsplash metadata; when re-downloading, prune rows whose image URL
# no longer resolves.
unsplush_df: pd.DataFrame = pd.read_csv(os.path.join(DATA_DIR, 'unsplush.csv'))
if DO_DOWNLOAD:
    unsplush_df = remove_rows_with_non_exist_url(df=unsplush_df)

In [11]:
# Load the movie-poster metadata; when re-downloading, prune rows whose image
# URL no longer resolves.
movies_df: pd.DataFrame = pd.read_csv(os.path.join(DATA_DIR, 'movies.csv'))
if DO_DOWNLOAD:
    movies_df = remove_rows_with_non_exist_url(df=movies_df)

In [12]:
# Integer key -> metadata frame: 0 is Unsplash, 1 is movies.
dataframes: Dict[int, pd.DataFrame] = dict(enumerate([unsplush_df, movies_df]))

In [13]:
# Number of Unsplash metadata rows; displayed as the cell output.
len(dataframes[0])

24994

In [14]:
# Persist the pruned metadata so later runs (with DO_DOWNLOAD off) read it as-is.
if DO_DOWNLOAD:
    for frame, csv_path in (
        (dataframes[0], PATH_TO_UNSPLUSH_CSV),
        (dataframes[1], PATH_TO_MOVIES_CSV),
    ):
        frame.to_csv(csv_path, index=False, encoding='utf-8')

In [15]:
# Preview the first rows of the Unsplash metadata. Index the dict directly
# rather than assigning to `movies_df`, which would silently rebind the
# movies frame to the Unsplash data.
dataframes[0][:3]

Unnamed: 0,path,tooltip,link
0,https://images.unsplash.com/uploads/1411949294...,"""Woman exploring a forest"" by Michelle Spencer",https://unsplash.com/photos/XMyPniM9LF0
1,https://images.unsplash.com/photo-141633941111...,"""Succulents in a terrarium"" by Jeff Sheldon",https://unsplash.com/photos/rDLBArZUl1c
2,https://images.unsplash.com/photo-142014251503...,"""Rural winter mountainside"" by John Price",https://unsplash.com/photos/cNDGZ2sQ3Bo


In [16]:
def download(df, data_dir: str, fetch_fn=None, max_n_parallel: Optional[int] = None):
    """Download every image referenced by ``df['path']`` with a thread pool.

    Row ``i`` is handed to *fetch_fn* as ``(df['path'].iloc[i], f"{i}.jpeg")``.
    Work is submitted in batches of *max_n_parallel* items, and a progress
    count is printed roughly every 300 images.

    One pool is created for the whole run and shut down afterwards. Creating
    a new ``Pool`` per batch without closing it leaks worker threads and
    eventually fails with ``RuntimeError: can't start new thread``.

    Args:
        df: Frame with a ``path`` column of image URLs.
        data_dir: Destination directory; used for the default *fetch_fn* and
            in the final log line.
        fetch_fn: Callable taking one ``(url, filename)`` tuple. Defaults to
            ``get_fetch_url(data_dir)``.
        max_n_parallel: Number of worker threads and batch size. Defaults to
            ``NUM_CPUS * 2``.
    """
    if fetch_fn is None:
        fetch_fn = get_fetch_url(data_dir)
    if max_n_parallel is None:
        max_n_parallel = NUM_CPUS * 2

    progress_step = 300
    next_report = progress_step
    length = len(df)
    print(f"total images:[{length}]")

    urls = df['path']
    position: int = 0
    if length > 0:  # Pool(0) is invalid, so skip the pool for an empty frame
        with Pool(min(max_n_parallel, length)) as pool:
            while position < length:
                n_parallel = min(max_n_parallel, length - position)
                url_filename_list = [
                    (urls.iloc[i], f"{i}.jpeg")
                    for i in range(position, position + n_parallel)
                ]
                pool.map(fetch_fn, url_filename_list)
                position += n_parallel

                if position >= next_report:
                    print(position)
                    next_report += progress_step

    assert position == length, f"expected position:[{position}] == length:[{length}]"
    print(f"done [{data_dir}]")

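Python `multiprocessing` thread pools that are created repeatedly without being closed can exhaust the available threads and raise "can't start new thread" (traceback below). Once the kernel's thread resources are exhausted, the instance itself may need to be restarted, not just the Python kernel.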

```
----> 5     download(df=dataframes[1], data_dir=data_dir, fetch_fn=get_fetch_url(data_dir))

<ipython-input-21-50353c4a8d6f> in download(df, data_dir, fetch_fn)
     13             for increment in range(n_parallel)
     14         ]
---> 15         _ = Pool(n_parallel).map(fetch_fn, url_filename_list)
     16         position += n_parallel
     17 

/opt/conda/lib/python3.7/multiprocessing/dummy/__init__.py in Pool(processes, initializer, initargs)
    122 def Pool(processes=None, initializer=None, initargs=()):
    123     from ..pool import ThreadPool
--> 124     return ThreadPool(processes, initializer, initargs)
    125 
    126 JoinableQueue = Queue

/opt/conda/lib/python3.7/multiprocessing/pool.py in __init__(self, processes, initializer, initargs)
    800 
    801     def __init__(self, processes=None, initializer=None, initargs=()):
--> 802         Pool.__init__(self, processes, initializer, initargs)
    803 
    804     def _setup_queues(self):

/opt/conda/lib/python3.7/multiprocessing/pool.py in __init__(self, processes, initializer, initargs, maxtasksperchild, context)
    174         self._processes = processes
    175         self._pool = []
--> 176         self._repopulate_pool()
    177 
    178         self._worker_handler = threading.Thread(

/opt/conda/lib/python3.7/multiprocessing/pool.py in _repopulate_pool(self)
    239             w.name = w.name.replace('Process', 'PoolWorker')
    240             w.daemon = True
--> 241             w.start()
    242             util.debug('added worker')
    243 

/opt/conda/lib/python3.7/multiprocessing/dummy/__init__.py in start(self)
     49         if hasattr(self._parent, '_children'):
     50             self._parent._children[self] = None
---> 51         threading.Thread.start(self)
     52 
     53     @property

/opt/conda/lib/python3.7/threading.py in start(self)
    850             _limbo[self] = self
    851         try:
--> 852             _start_new_thread(self._bootstrap, ())
    853         except Exception:
    854             with _active_limbo_lock:

RuntimeError: can't start new thread
```

In [22]:
# Download the Unsplash images into DATA_DIR/unsplush/<row>.jpeg.
if DO_DOWNLOAD:
    data_dir = os.path.join(DATA_DIR, 'unsplush')
    mkdir(data_dir)

    download(df=dataframes[0], data_dir=data_dir, fetch_fn=get_fetch_url(data_dir))

total images:[24994]
304
608
912
1200
1504
1808
2112
2400
2704
3008
3312
3600
3904
4208
4512
4800
5104
5408
5712
6000
6304
6608
6912
7200
7504
7808
8112
8400
8704
9008
9312
9600
9904
10208
10512
10800
11104
11408
11712
12000
12304
12608
12912
13200
13504
13808
14112
14400
14704
15008
15312
15600
15904
16208
16512
16800
17104
17408
17712
18000
18304
18608
18912
19200
19504
19808
20112
20400
20704
21008
21312
21600
21904
22208
22512
22800
23104
23408
23712
24000
24304
24608
24912
done [./data/unsplush]


In [16]:
# Download the movie posters into DATA_DIR/movies/<row>.jpeg.
if DO_DOWNLOAD:
    data_dir = os.path.join(DATA_DIR, 'movies')
    mkdir(data_dir)
    movie_fetcher = get_fetch_url(data_dir)
    download(df=dataframes[1], data_dir=data_dir, fetch_fn=movie_fetcher)

total images:[8170]
304
608
912
1200
1504
1808
2112
2400
2704
3008
3312
3600
3904
4208
4512
4800
5104
5408
5712
6000
6304
6608
6912
7200
7504
7808
8112
done [./data/movies]


# Model

In [17]:
# Load the pretrained CLIP weights and the matching processor (image
# preprocessing + tokenizer) from the same Hub checkpoint.
clip_checkpoint = MODEL_NAME
model = CLIPModel.from_pretrained(clip_checkpoint)
processor = CLIPProcessor.from_pretrained(clip_checkpoint)

In [18]:
# Release cached, unused GPU memory (only when CUDA is available) before
# moving the model.
if DEVICE_IS_CUDA:
    torch.cuda.empty_cache()

# Move the model's parameters to DEVICE; the returned module is echoed as the
# cell output.
model.to(DEVICE)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0): CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05, element

# Embedding

* [compute_CLIP_embeddings.ipynb](https://github.com/vivien000/clip-demo/blob/master/compute_CLIP_embeddings.ipynb)

In [19]:
# Let PIL load image files that are truncated (e.g. partially downloaded)
# instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def compute_text_embeddings(
    queries: List[str],
    model: torch.nn.Module,
    device: torch.device,
    clip_processor: Optional[Any] = None,
):
    """Encode *queries* with CLIP's text tower and L2-normalise each row.

    Args:
        queries: Text strings to embed.
        model: CLIP model exposing ``get_text_features``.
        device: Device the tokenised inputs are moved to; should match the
            model's device.
        clip_processor: Processor used for tokenisation. Defaults to the
            module-level ``processor``.

    Returns:
        Tensor with one unit-L2-norm row per query, so dot products with
        unit-norm image embeddings are cosine similarities.
    """
    proc = processor if clip_processor is None else clip_processor
    inputs = proc(text=queries, return_tensors="pt", padding=True)
    inputs = inputs.to(device)

    with torch.no_grad():
        # Run the text encoder exactly once.
        features = model.get_text_features(**inputs)

    return features / features.norm(dim=-1, keepdim=True)


def compute_image_embeddings(
    images: List[np.ndarray], model: torch.nn.Module, device: torch.device
):
    """Encode *images* with CLIP's vision tower and L2-normalise each row.

    The module-level ``processor`` prepares the images. Only the resulting
    ``pixel_values`` tensor is moved to *device*. Callers in this notebook
    pass PIL images (from ``load_image``).

    Returns:
        Tensor with one unit-norm embedding row per input image.
    """
    batch = processor(images=images, return_tensors="pt", padding=True)
    batch['pixel_values'] = batch['pixel_values'].to(device)

    with torch.no_grad():
        raw = model.get_image_features(**batch)

    row_norms = raw.norm(dim=-1, keepdim=True)
    return raw / row_norms


In [20]:
def path_to_file(position, index, directory: Optional[str] = None) -> str:
    """Return the on-disk path of image number ``position + index``.

    Images are stored as ``<directory>/<row>.jpeg``. When *directory* is
    omitted, the module-level ``image_dir`` is used (legacy behaviour).
    """
    if directory is None:
        directory = image_dir
    return os.path.join(directory, str(position + index) + '.jpeg')


def run_image_embeddings(
    df: pd.DataFrame,
    image_dir: str,
    model: torch.nn.Module,
    device: torch.device,
    path_to_embedding_file: str,
    batch_size: int = BATCH_SIZE
):
    """Embed every image of *df* with CLIP and save the matrix with ``np.save``.

    Row ``i`` of *df* is expected at ``<image_dir>/<i>.jpeg``. Images are
    processed in batches of *batch_size*. The accumulated embeddings are
    checkpointed to *path_to_embedding_file* each time another 1000 rows
    are done, and saved once more at the end.

    Raises:
        AssertionError: if *df* is empty.
        FileNotFoundError: if an expected image file is missing.
    """
    length: int = len(df)
    assert length > 0

    checkpoint_step: int = 1000
    next_checkpoint: int = checkpoint_step
    batches: List[np.ndarray] = []  # per-batch embedding arrays, concatenated lazily
    position: int = 0  # number of rows embedded so far

    while position < length:
        num_images = min(batch_size, length - position)
        # Use this call's image_dir explicitly rather than any global.
        paths = [
            path_to_file(position=position, index=index, directory=image_dir)
            for index in range(num_images)
        ]
        images = [load_image(path_to_file=path) for path in paths]
        missing = [path for path, image in zip(paths, images) if image is None]
        if missing:
            raise FileNotFoundError(f"missing image files: {missing[:5]}")

        batch_embeddings = compute_image_embeddings(images=images, model=model, device=device)
        # .cpu() is a no-op for CPU tensors and works for any accelerator.
        batches.append(batch_embeddings.detach().cpu().numpy())
        position += num_images

        if position >= next_checkpoint:
            # Collapse accumulated batches once per checkpoint instead of
            # re-stacking the whole matrix after every batch.
            batches = [np.concatenate(batches)]
            np.save(path_to_embedding_file, batches[0])
            print(position)
            next_checkpoint = (position // checkpoint_step + 1) * checkpoint_step

    image_embeddings = np.concatenate(batches)
    assert image_embeddings.shape[0] == position == length, \
        f"expected position:[{position}] == size of embeddings:[{image_embeddings.shape[0]}]"
    np.save(path_to_embedding_file, image_embeddings)
    print(f"done: {image_dir}")

In [21]:
# Compute CLIP image embeddings for each collection and save them as
# embedding_<name>.npy in the working directory.
if DO_EMBEDDING:
    for index, name in enumerate(['unsplush', 'movies']):
        df = dataframes[index]
        # Assigned at module scope on purpose: `path_to_file` can read the
        # global `image_dir`.
        image_dir: str = os.path.join(DATA_DIR, name)
        path_to_embedding_file: str = "embedding_" + name + ".npy"
        run_image_embeddings(
            df=df,
            model=model,
            device=DEVICE,
            image_dir=image_dir,
            path_to_embedding_file=path_to_embedding_file,
        )

1024
2048
3072
4096
5120
6144
7168
8192
9216
10240
11008
12032
13056
14080
15104
16128
17152
18176
19200
20224
21248
22016
23040
24064
done: ./data/unsplush
1024
2048
3072
4096
5120
6144
7168
8170
done: ./data/movies


## Load Embeddings

In [22]:
# Precomputed image embeddings and a caption suffix per collection; the keys
# match those used for the metadata frames (0 = Unsplash, 1 = movies).
embeddings = {
    key: np.load(f"embedding_{name}.npy")
    for key, name in enumerate(("unsplush", "movies"))
}

source = dict(enumerate((
    '\nSource: Unsplash',
    '\nSource: The Movie Database (TMDB)',
)))

In [23]:
# Shape of the Unsplash embedding matrix; displayed as the cell output.
embeddings[0].shape

(24994, 512)

# Image Utility

In [24]:
def get_html(url_list, height=200):
    """Render search results as a flex-wrapped HTML image gallery.

    Args:
        url_list: Iterable of ``(image_url, title, link)`` tuples. *title*
            becomes the image tooltip. When *link* is a non-empty string, the
            image is wrapped in an ``<a>`` that opens it in a new tab.
        height: Display height of every image, in pixels.

    Returns:
        The HTML markup as a string.

    Interpolated values are HTML-escaped. Titles such as ``Children's toys``
    contain quotes that would otherwise terminate the ``title='...'``
    attribute and corrupt the markup.
    """
    import html as html_lib

    parts = ["<div style='margin-top: 20px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"]
    for url, title, link in url_list:
        img = (
            f"<img title='{html_lib.escape(str(title), quote=True)}' "
            f"style='height: {height}px; margin-bottom: 10px' "
            f"src='{html_lib.escape(str(url), quote=True)}'>"
        )
        # pandas reads empty CSV cells as NaN (a float) by default, so only a
        # non-empty string counts as a link.
        if isinstance(link, str) and link:
            img = f"<a href='{html_lib.escape(link, quote=True)}' target='_blank'>" + img + "</a>"
        parts.append(img)
    parts.append("</div>")
    return "".join(parts)

# Image Search

In [25]:
# Search UI: a free-text query box, a Search button and a dataset selector,
# followed by an output area where the result gallery is rendered.
query = widgets.Text(layout=widgets.Layout(width='400px'))
dataset = widgets.Dropdown(options=['Unsplash', 'Movies'], value='Unsplash')
button = widgets.Button(description="Search")
output = widgets.Output()

controls = widgets.HBox(
    [query, button, dataset],
    layout=widgets.Layout(justify_content='center'),
)
display(controls, output)

def image_search(
    query: str,
    model: torch.nn.Module,
    device: torch.device,
    n_results: int = 15,
    dataset_index: Optional[int] = None,
):
    """Return the *n_results* images whose embeddings best match *query*.

    The query is embedded with ``compute_text_embeddings`` and scored against
    the precomputed image embeddings by dot product. Results are ordered
    from best to worst match.

    Args:
        query: Free-text search string.
        model: CLIP model used to embed the query.
        device: Device for the text-encoder inputs.
        n_results: Maximum number of results to return.
        dataset_index: Key into ``embeddings`` / ``dataframes`` / ``source``.
            When ``None``, it is read from the ``dataset`` dropdown
            (0 for 'Unsplash', otherwise 1).

    Returns:
        List of ``(image_url, tooltip, link)`` tuples, as consumed by ``get_html``.
    """
    text_embeddings = compute_text_embeddings(queries=[query], model=model, device=device)
    # .cpu() is a no-op for CPU tensors and works for any accelerator.
    text_embeddings = text_embeddings.detach().cpu().numpy()

    k = dataset_index
    if k is None:
        k = 0 if dataset.value == 'Unsplash' else 1

    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending: walk the tail backwards for best-first order.
    results = np.argsort(scores)[-1:-n_results-1:-1]
    frame = dataframes[k]
    hits = []
    for i in results:
        row = frame.iloc[i]
        hits.append((row['path'], row['tooltip'] + source[k], row['link']))
    return hits

def on_button_clicked(b):
    """Run a search for the current query and render the result gallery.

    Registered both as the button's click handler and as the dropdown's
    ``value`` observer, so *b* (the button or a change dict) is unused.
    Empty queries are ignored.
    """
    if not query.value:
        return
    results = image_search(query=query.value, model=model, device=DEVICE)
    output.clear_output()
    with output:
        display(HTML(get_html(results)))


button.on_click(on_button_clicked)
dataset.observe(on_button_clicked, names='value')

HBox(children=(Text(value='', layout=Layout(width='400px')), Button(description='Search', style=ButtonStyle())…

Output()