In [None]:
!pip install transformers pillow torchvision torch safetensors pillow
!CFLAGS="-mavx -DWARN(a)=(a)" pip install nmslib

In [None]:
!python --version
!pip list

In [None]:
!pip uninstall transformers

In [None]:
!pip install transformers --no-cache-dir

# Load the CLIP model

In [None]:
# torch is needed for the device probe below, so import it in this cell rather than
# relying on a later cell having been run first.
import torch
from transformers import CLIPProcessor, CLIPModel

# Alternative checkpoints; un-comment exactly one repo_id.
#repo_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
#repo_id = "laion/CLIP-ViT-g-14-laion2B-s12B-b42K" # more modern CLIP model
repo_id = "openai/clip-vit-large-patch14-336" # the CLIP model used for SD up to v1.5

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print("loading model...")
model = CLIPModel.from_pretrained(repo_id)
print("loading preprocessor...")
processor = CLIPProcessor.from_pretrained(repo_id)

print(f"sending to {device}...")
if device == 'cuda':
    # Half precision is only switched on for CUDA.
    model = model.half()
model.to(device)
print("done")


## helper code to load images and create CLIP feature vectors

In [None]:
import io
import PIL
from torchvision import transforms
import torch
from typing import Union


def get_clip_image_features(image: PIL.Image) -> torch.Tensor:
    """Run the CLIP image encoder on one image or a list of images.

    Args:
        image: A PIL image, or a list of PIL images (the batch form is what
            `get_images_feature_vecs` passes in).

    Returns:
        The tensor produced by `model.get_image_features`, located on the
        model's device and detached from autograd.
    """
    preprocess_results = processor(text=None,
                                   images=image,
                                   return_tensors="pt",
                                   padding=True,
                                   device=device
                                  )
    # The model may have been converted with .half(); match both its device and
    # its dtype so the forward pass does not fail on a float32/float16 mismatch.
    pixel_values = preprocess_results.pixel_values.to(device=model.device, dtype=model.dtype)
    # Pure inference: skip building an autograd graph for every image.
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=pixel_values)
    return image_features


def resize_image(fullsize_image: PIL.Image, min_edge_length: int, max_edge_length: int=None) -> PIL.Image:
    """Resize an image with torchvision's `Resize` transform.

    Args:
        fullsize_image: The image to scale.
        min_edge_length: Forwarded as `size` to `transforms.Resize`.
        max_edge_length: Forwarded as `max_size` to `transforms.Resize`;
            None leaves it unset.

    Returns:
        The resized image.
    """
    return transforms.Resize(size=min_edge_length, max_size=max_edge_length)(fullsize_image)


def get_images_feature_vecs(images) -> torch.Tensor:
    """Compute CLIP feature vectors for one or more images.

    Args:
        images: A single path (str), a single encoded image (bytes), or an
            iterable mixing both. Bytes are wrapped in an in-memory file;
            everything else is handed to `PIL.Image.open` as-is.

    Returns:
        The result of `get_clip_image_features` for the whole batch.
    """
    # A lone str or bytes value is treated as a batch of one.
    if type(images) in (str, bytes):
        images = [images]

    resized_batch = []
    for item in images:
        source = io.BytesIO(item) if type(item) is bytes else item
        opened = PIL.Image.open(source)
        resized_batch.append(resize_image(opened, min_edge_length=512, max_edge_length=768))

    return get_clip_image_features(resized_batch)


# `gather_clip_features_recursive`: crawl a folder and collect CLIP vectors

Starting at `root_path`, `gather_clip_features_recursive` calculates a CLIP feature vector for every 
encountered image. When all directories have been processed, a tuple of `hashes, per_image_features` 
is returned.
*  `hashes` contains sha256 hashes for each image, indexed by image path relative to `root_path`.
*  `per_image_features` contains a CLIP feature vector for each image, also indexed by image path 
    relative to `root_path`.

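If `existing_hashes` is passed, CLIP feature analysis is skipped for any image whose hash is already 
present in `existing_hashes`. With `check_hashes=False`, an image whose relative path already appears 
in `existing_hashes` is skipped without being read or hashed. Skipped images do not appear in the 
returned dictionaries.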


In [None]:
import os
import traceback
from tqdm.notebook import tqdm
import hashlib
from safetensors.torch import safe_open

def gather_clip_features_recursive(root_path: str, existing_hashes: dict=None, check_hashes: bool=True) -> tuple[dict,dict]:
    """
    Walk `root_path` and compute a CLIP feature vector for every new or changed image.

    Args:
        root_path: Directory to crawl recursively.
        existing_hashes: Optional mapping of relative image path -> sha256 hex
            digest from a previous run. None is treated as an empty mapping.
        check_hashes: When True, every image file is read and hashed, and only
            skipped if its digest equals the stored one. When False, any image
            whose relative path is already in `existing_hashes` is skipped
            without being read.

    Returns:
        `(hashes, per_image_features)`, both keyed by "./"-prefixed,
        forward-slash paths relative to `root_path`. Only images processed in
        this call are included. Features are float16 numpy arrays.

    Files that cannot be read or decoded are reported and skipped. On
    KeyboardInterrupt or any other exception the partial result collected so
    far is returned instead of raising.
    """
    image_extensions = {".png", ".jpg", ".jpeg"}
    # Normalise the optional argument once so the lookups below never hit None.
    known_hashes = existing_hashes if existing_hashes is not None else {}
    per_image_features = {}
    hashes = {}
    try:
        for directory, _, filenames in os.walk(root_path):
            # Compare extensions case-insensitively so e.g. "photo.JPG" is picked up too.
            image_filenames = [f for f in filenames
                               if os.path.splitext(f)[1].lower() in image_extensions]
            if not image_filenames:
                continue
            for filename in tqdm(image_filenames, desc=directory):
                image_path = os.path.join(directory, filename)
                relative_path = "./" + os.path.relpath(image_path, root_path).replace('\\', '/')
                if not check_hashes and relative_path in known_hashes:
                    # Fast path: trust the existing entry without touching the file.
                    continue
                try:
                    with open(image_path, 'rb') as image_file_handle:
                        image_file_bytes = image_file_handle.read() # read entire file as bytes
                    shasum = hashlib.sha256(image_file_bytes).hexdigest()

                    # The stored hashes are keyed by relative path, not by bare filename.
                    if known_hashes.get(relative_path) == shasum:
                        continue

                    features = get_images_feature_vecs(image_file_bytes).squeeze(0)
                    per_image_features[relative_path] = features.detach().cpu().half().numpy()
                    hashes[relative_path] = shasum
                except (IOError, OSError) as e:
                    print(f"Error loading {image_path}: {e}")
                    continue
    except KeyboardInterrupt:
        print(f"interrupted, returning what we have collected so far (= vectors for {len(hashes)} images)")
    except Exception:
        traceback.print_exc()
        print(f"an exception occurred, saving what we have collected so far (= vectors for {len(hashes)} images)")

    return hashes, per_image_features


    

### File management

In [None]:
import pickle

def get_clip_features_pickle_path(root_path):
    """Return the path of the CLIP feature cache file located directly inside `root_path`."""
    pickle_filename = "__clip-features.pickle"
    return os.path.join(root_path, pickle_filename)

def load_clip_features(root_path):
    """Load the cached hashes and CLIP features stored under `root_path`.

    Returns:
        `(hashes, per_image_features)` read from the cache file, or two empty
        dicts (after printing a notice) when no cache file exists.

    Raises:
        ImportError: If the cache was written for a different `repo_id` than
            the one currently set in the notebook.
    """
    data_path = get_clip_features_pickle_path(root_path)
    if not os.path.exists(data_path):
        print(f"no pickle file at {data_path}")
        return {}, {}

    # NOTE: unpickling can execute arbitrary code; only load cache files you created yourself.
    with open(data_path, 'rb') as f:
        existing_dict = pickle.load(f)

    # Vectors from different CLIP checkpoints are not comparable, so refuse to mix them.
    existing_repo_id = existing_dict.get("repo_id")
    if repo_id != existing_repo_id:
        raise ImportError(f"repo id mismatch. saved is {existing_repo_id}, running is {repo_id}")

    return existing_dict["hashes"], existing_dict["clip_features"]

def save_clip_features(root_path, hashes, per_image_features, repo_id):
    """Pickle the hashes and CLIP features, tagged with `repo_id`, into the cache file under `root_path`.

    The cache file is overwritten if it already exists.
    """
    data_path = get_clip_features_pickle_path(root_path)
    print("saving to ", root_path)
    payload = dict(repo_id=repo_id, hashes=hashes, clip_features=per_image_features)
    with open(data_path, 'wb') as f:
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        print("saved")

def update_clip_features_recursive(root_path, check_hashes=True):
    """
    Crawl root_path and all subdirectories, creating or updating CLIP feature vectors as required.

    Args:
        root_path: Directory that holds the images and the feature cache file.
        check_hashes: Forwarded to `gather_clip_features_recursive`.

    Returns:
        `(hashes, per_image_features)`: the cached entries merged with the
        newly computed ones. The cache file is rewritten only when at least
        one new vector was produced.
    """
    print("updating features for images in", root_path)
    try:
        hashes, per_image_features = load_clip_features(root_path)
    except EOFError as e:
        # An EOFError from the loader means the cache could not be read; start over from empty.
        print(f"error loading clip feature for {root_path}: {e}")
        hashes = {}
        per_image_features = {}
    new_hashes, new_per_image_features = gather_clip_features_recursive(root_path, hashes, check_hashes)
    if len(new_hashes)>0:
        hashes.update(new_hashes)
        per_image_features.update(new_per_image_features)
        save_clip_features(root_path, hashes, per_image_features, repo_id)
        print(f"wrote vectors for {len(hashes)} images ({len(new_hashes)} new) to {root_path}")
    else:
        print(f"already up to date with {len(hashes)} images")
    return hashes, per_image_features

# Create/update CLIP features for images

This is slow and expensive.

In [None]:
# this is slow and expensive
# check_hashes=False skips images whose relative path is already in the cache
# without reading or hashing the file.
root_path = './images'
hashes, per_image_features = update_clip_features_recursive(root_path, check_hashes=False)

# Load existing CLIP features for images


In [None]:
# Read the cached hashes and feature vectors for root_path without recomputing anything;
# both dicts are empty when no cache file exists.
root_path = './images'
hashes, per_image_features = load_clip_features(root_path)
print(f"{len(hashes)} features loaded")


# Make a UI for searching

In [None]:
import nmslib
import torch

def build_index(all_features_tensor: torch.Tensor) -> nmslib:
    """Build an nmslib HNSW index (cosine similarity space) over the given feature rows.

    Args:
        all_features_tensor: Tensor with one feature vector per row. It may
            live on any device and use any floating dtype.

    Returns:
        The index object created by `nmslib.init`, populated and built.
    """
    print("building index for all_features_tensor of shape", all_features_tensor.shape)
    index = nmslib.init(method='hnsw', space='cosinesimil')
    # Hand nmslib an explicit float32 numpy array instead of relying on an
    # implicit conversion of a (possibly half-precision) torch tensor.
    data_points = all_features_tensor.detach().cpu().float().numpy()
    index.addDataPointBatch(data_points)
    index.createIndex({'post': 2}, print_progress=True)
    return index

print("merging images features into a single tensor...")
# One row per cached image, in the dict's iteration order; the index ids follow that order.
all_features_tensor = torch.stack(
    [torch.tensor(vec, device=device) for vec in tqdm(per_image_features.values())]
)
print("building knn index...")
index = build_index(all_features_tensor)


In [None]:
import numpy as np

def get_filename(path):
    """Return the last component of `path` (empty if `path` ends with a separator)."""
    return os.path.basename(path)

def get_closest_images(features: np.array, k=4):
    """Query the notebook-level `index` for the `k` nearest neighbours of `features`.

    Returns:
        `(ids, distances)` exactly as returned by `index.knnQuery`.
    """
    neighbours = index.knnQuery(features, k=k)
    ids, distances = neighbours
    print("knnquery got ids", ids)
    return ids, distances

def log_closest(ids, distances, remove_directory=False):
    """Print the distances and image paths for a k-NN result, omitting the first hit.

    Args:
        ids: Index ids as returned by `get_closest_images`.
        distances: Distances matching `ids`.
        remove_directory: When True, print bare filenames instead of relative paths.
    """
    # Index ids are positions in per_image_features' iteration order, the same
    # order in which the index was populated.
    all_image_paths = list(per_image_features.keys())
    nearest_paths = [all_image_paths[i] for i in ids]
    if remove_directory:
        nearest_paths = [get_filename(p) for p in nearest_paths]
    # Skip entry 0: for an image-to-image query the closest hit is the query image itself.
    print(f"  {distances[1:]}")
    print(f"  {nearest_paths[1:]}")


## image 2 image search

In [None]:

# For every cached image, look up and print its nearest neighbours.
for (path,features) in per_image_features.items():
    print(f"{get_filename(path)}:")
    # The cached features are float16 numpy arrays, not tensors, so there is
    # nothing to detach; convert to float32 for the index query.
    ids, distances = get_closest_images(np.asarray(features, dtype=np.float32))
    log_closest(ids, distances)


## text 2 image search

In [None]:
def get_text_features(text: str) -> torch.Tensor:
    """Encode a text prompt with the CLIP text encoder.

    Args:
        text: The query string.

    Returns:
        The tensor produced by `model.get_text_features` for the single
        prompt, located on the model's device and detached from autograd.
    """
    # truncation=True keeps over-long prompts within the tokenizer's maximum
    # length instead of failing inside the model.
    text_inputs = processor(text=[text], images=None, return_tensors="pt",
                            padding=True, truncation=True)
    input_ids = text_inputs.input_ids.to(model.device)
    # Pure inference: skip building an autograd graph for every query.
    with torch.no_grad():
        return model.get_text_features(input_ids=input_ids)
    


In [None]:
import asyncio

class Timer:
    """One-shot asyncio timer that invokes `callback` after `timeout` seconds.

    `start()` schedules the wait via `asyncio.ensure_future`; `cancel()`
    aborts it if the callback has not run yet.
    """

    def __init__(self, timeout, callback):
        self._timeout = timeout
        self._callback = callback
        # No task exists until start() is called; cancel() must tolerate that.
        self._task = None

    async def _job(self):
        await asyncio.sleep(self._timeout)
        self._callback()

    def start(self):
        """Schedule the delayed callback."""
        self._task = asyncio.ensure_future(self._job())

    def cancel(self):
        """Cancel the pending callback; a no-op when the timer was never started."""
        if self._task is not None:
            self._task.cancel()

def debounce(wait):
    """Decorator factory: delay each call to the wrapped function by `wait` seconds.

    A new call made before the delay has elapsed cancels the pending one, so
    only the most recent call in a burst actually runs. The wrapper itself
    returns None immediately.
    """
    def decorator(fn):
        pending = None

        def debounced(*args, **kwargs):
            nonlocal pending
            # Drop whatever invocation is still waiting to fire.
            if pending is not None:
                pending.cancel()
            pending = Timer(wait, lambda: fn(*args, **kwargs))
            pending.start()

        return debounced
    return decorator

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, fixed
import PIL

import io
import subprocess
import traceback

def open_in_finder(path):
    """Print `path`, then run the external command `open -R <path>` on it."""
    print("opening", path)
    reveal_command = ["open", "-R", path]
    subprocess.call(reveal_command)

def divide_chunks(l, n):
    """Yield successive slices of `l`, each holding at most `n` items."""
    yield from (l[start:start + n] for start in range(0, len(l), n))

def load_and_resize_image(path, target_width) -> PIL.Image:
    """Open the image at `path` and scale it to `target_width`, keeping its aspect ratio.

    Args:
        path: Location of the image file.
        target_width: Width of the returned image in pixels.

    Returns:
        A new, resized image. The source file is closed before returning.
    """
    with PIL.Image.open(path) as image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        # Never let rounding collapse a very wide image to zero height.
        target_height = max(1, round(target_width / aspect_ratio))
        return image.resize((target_width, target_height))

        
@debounce(2.0)
def display_closest(ids, distances, load_more_cb, go_back_cb, in_widget, offset=0):
    """Render thumbnails for the given index ids into `in_widget`.

    Because of the `debounce` decorator, a call returns immediately and the
    body runs 2 seconds later unless another call arrives first.

    Args:
        ids: Index ids; each is a position in `per_image_features`' key order.
        distances: Distances matching `ids` (not displayed at present).
        load_more_cb: Click handler for the "next" button.
        go_back_cb: Click handler for the "prev" button.
        in_widget: Container widget whose children are replaced.
        offset: Rank of the first displayed result, used for the button labels.
    """
    in_widget.children = []
    print("displaying", ids)
    all_image_paths = list(per_image_features.keys())
    nearest_paths = [os.path.join(root_path, all_image_paths[i]) for i in ids]
    print("-> displaying", nearest_paths)

    width = 300
    images: list[bytes] = []
    aspect_ratios: list[float] = []
    for path in nearest_paths:
        try:
            resized = load_and_resize_image(path, width)
        except Exception:
            # Show a black placeholder instead of aborting the whole page.
            traceback.print_exc()
            resized = PIL.Image.new('RGB', (100, 100))

        aspect_ratios.append(resized.width/resized.height)
        byte_arr = io.BytesIO()
        # JPEG cannot store alpha or palette data, so normalise to RGB first.
        resized.convert('RGB').save(byte_arr, format='jpeg')
        images.append(byte_arr.getvalue())

    # The bytes are JPEG-encoded, so the widget format has to say so as well.
    image_widgets = [widgets.Image(value=image, format='jpeg', width=width, height=width/aspect_ratio)
                     for (image, aspect_ratio) in zip(images, aspect_ratios)]
    for image_widget in image_widgets:
        image_widget.layout.object_fit='contain'

    button_widgets = [widgets.Button(description='open') for _ in nearest_paths]
    for i, p in enumerate(nearest_paths):
        # Bind p as a default argument so every button keeps its own path.
        button_widgets[i].on_click(lambda b, p=p: open_in_finder(p))

    # One column per result: thumbnail, path label, "open" button.
    widget_lists = [(image_widgets[i],
                    widgets.Label(value=p),
                    button_widgets[i])
                    for i,p in enumerate(nearest_paths)]

    vboxes = [widgets.VBox(w) for w in widget_lists]

    row_length = 5
    chunks = list(divide_chunks(vboxes, row_length))

    hboxes = [widgets.HBox(c) for c in chunks]

    go_back_button = widgets.Button(description=f"prev ({offset+1-len(nearest_paths)})")
    go_back_button.on_click(go_back_cb)

    load_more_button = widgets.Button(description=f"next ({offset+1+len(nearest_paths)})")
    load_more_button.on_click(load_more_cb)

    # Navigation buttons appear both above and below the result grid.
    in_widget.children = [widgets.VBox([go_back_button, load_more_button] +
                                       hboxes +
                                       [go_back_button, load_more_button])]
    #display(in_widget)
    

def find_nearest_images_for_text(text: str, k, offset, in_widget):
    """Search the index with a text prompt and show one page of results.

    Args:
        text: The query string.
        k: Number of results per page.
        offset: Rank of the first result to show; negative values are treated as 0.
        in_widget: Container widget handed to `display_closest`.

    Returns:
        Always True.
    """
    # Paging backwards from the first page must not produce a negative offset.
    offset = max(0, offset)
    features = get_text_features(text)
    # float() so a half-precision model still yields a float32 query vector.
    query_vec = np.asarray(features.detach().cpu().float())
    ids, distances = get_closest_images(query_vec, k=offset+k)
    next_offset = offset+k
    prev_offset = max(0, offset-k)
    def load_more(b):
        find_nearest_images_for_text(text, k, next_offset, in_widget)
    def go_back(b):
        find_nearest_images_for_text(text, k, prev_offset, in_widget)
    try:
        # Slice by offset rather than taking the last k entries: when the index
        # returns fewer than offset+k hits, ids[-k:] would repeat earlier results.
        page_ids = ids[offset:offset+k]
        page_distances = distances[offset:offset+k]
        display_closest(page_ids, page_distances, load_more_cb=load_more, go_back_cb=go_back, in_widget=in_widget, offset=offset)
    except Exception:
        traceback.print_exc()
    return True


#open_in_finder(root_path)


In [None]:
# Container that display_closest fills with result thumbnails.
images_hbox = widgets.HBox([])
display(images_hbox)

# Text box plus a results-per-page slider; the container and the start offset are fixed.
interact(
    find_nearest_images_for_text,
    text='',
    k=widgets.IntSlider(value=10, min=1, max=100, layout={'width':'600px'}),
    in_widget=fixed(images_hbox),
    offset=fixed(0),
)
                  


#find_nearest_images_for_text("test", k=10, offset=20, in_widget=images_hbox)



## benchmark mps performance (spoiler: it's 2x faster)

In [None]:
# benchmark.py
import time
class benchmark(object):
    """
    Context manager that prints how long its body took.

    usage:
        with benchmark(log_string):
            code_to_benchmark

    After the block finishes, `elapsed` holds the measured duration in
    seconds. Exceptions raised inside the block are not suppressed.
    """
    def __init__(self,name):
        self.name = name
        self.start = None
        self.elapsed = None
    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        self.start = time.perf_counter()
        # Return self so `with benchmark(...) as b:` can read b.elapsed afterwards.
        return self
    def __exit__(self,ty,val,tb):
        self.elapsed = time.perf_counter() - self.start
        print("%s : %0.3f seconds" % (self.name, self.elapsed))
        return False

In [None]:

def run_benchmark():
    """Time CLIP feature extraction on two test images, batched and one at a time.

    Reads the notebook-level `device` for the log labels and expects
    "test-image.webp" and "test-image-2.jpeg" in the working directory.
    """
    print("warming up...")
    # write_image_feature_vec is not defined in this notebook; the feature
    # extraction entry point is get_images_feature_vecs.
    get_images_feature_vecs(["test-image-2.jpeg"])
    print("running benchmark...")
    with benchmark(device + 'batched'):
        for _ in range(3):
            get_images_feature_vecs(["test-image.webp", "test-image-2.jpeg"])
    with benchmark(device + 'unbatched'):
        for _ in range(3):
            get_images_feature_vecs(["test-image.webp"])
            get_images_feature_vecs(["test-image-2.jpeg"])

# Run the benchmark on MPS (when present) and then on the CPU. `device` stays a
# notebook-level name because run_benchmark reads it for its log labels.
for device in ('mps', 'cpu'):
    if device == 'mps' and not torch.backends.mps.is_available():
        # Moving the model to an unavailable backend would raise.
        print("skipping mps benchmark: backend not available")
        continue
    print("moving model to",device)
    model.to(device)
    run_benchmark()
