In [None]:
# Install OpenCV; later cells import it as `cv2` to decode and seek video frames.
!pip install opencv_python

In [None]:
from transformers import CLIPProcessor, CLIPModel

# Alternative checkpoints; swap one in by changing repo_id.
#repo_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
#repo_id = "laion/CLIP-ViT-g-14-laion2B-s12B-b42K" # more modern CLIP model
repo_id = "openai/clip-vit-large-patch14-336" # the CLIP model used for SD up to v1.5
# Device the model is moved to; 'mps' is PyTorch's Apple-GPU backend, so this
# cell as written assumes a machine where that backend is available.
device = 'mps'

print("loading model...")
model = CLIPModel.from_pretrained(repo_id)
print("loading preprocessor...")
processor = CLIPProcessor.from_pretrained(repo_id)

print(f"sending to {device}...")
# Convert the weights to half precision and move the model to `device`.
model.half().to(device)
print("done")


In [None]:
import PIL
import torch
import torchvision.transforms as transforms

def resize_image(fullsize_image: PIL.Image, min_edge_length: int, max_edge_length: int=None) -> PIL.Image:
    """Resize an image with torchvision's ``Resize`` transform.

    ``min_edge_length`` is passed as the transform's ``size`` and
    ``max_edge_length`` as its ``max_size``; the transformed image is returned.
    """
    resizer = transforms.Resize(size=min_edge_length, max_size=max_edge_length)
    resized_image = resizer(fullsize_image)
    return resized_image



def get_clip_image_features(image: PIL.Image) -> torch.Tensor:
    """Embed an image with the notebook's CLIP model.

    Args:
        image: the image handed to the CLIP processor (callers in this
            notebook pass a PIL image).

    Returns:
        Whatever ``model.get_image_features`` returns for the preprocessed
        image. The forward pass runs under ``torch.no_grad()``, so the result
        carries no autograd history.
    """
    # NOTE(review): `padding` and `device` are forwarded exactly as before;
    # confirm whether the image processor actually consumes them.
    preprocess_results = processor(text=None,
                                   images=image,
                                   return_tensors="pt",
                                   padding=True,
                                   device=device
                                  )
    # Follow the model's own device and dtype instead of assuming float16, so
    # this keeps working if the model is loaded without .half().
    pixel_values = preprocess_results.pixel_values.to(device=model.device, dtype=model.dtype)
    # Inference only: without no_grad every processed frame would build (and
    # keep alive) an autograd graph.
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=pixel_values)
    return image_features

def get_clip_text_features(text: str) -> torch.Tensor:
    """Embed a text prompt with the notebook's CLIP model.

    Args:
        text: the prompt; it is tokenized as a single-element batch.

    Returns:
        Whatever ``model.get_text_features`` returns for the tokenized prompt.
        The forward pass runs under ``torch.no_grad()``, so the result carries
        no autograd history.
    """
    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
    # Use the model's own device (as the image path does) rather than the
    # notebook-level `device` string, so the two cannot drift apart.
    input_ids = inputs.input_ids.to(model.device)
    with torch.no_grad():
        return model.get_text_features(input_ids=input_ids)


In [None]:

import cv2
import numpy as np
from async_video_processor import AsyncVideoProcessor


# Per-frame results for the video currently being processed, keyed by the
# frame index passed to results_func. write_clip_features rebinds this name to
# a fresh dict before each video.
accumulated_results = {}
# Module-level placeholder; write_clip_features builds its own local `out_data`
# and nothing visible in this notebook reads this global.
out_data = {}

def process_func(frame_cv_bgr):
    """Compute CLIP image features for one video frame.

    Args:
        frame_cv_bgr: a frame in OpenCV's BGR channel order.

    Returns:
        The features from ``get_clip_image_features``, detached and moved to
        the CPU.
    """
    frame_cv_rgb = cv2.cvtColor(np.array(frame_cv_bgr), cv2.COLOR_BGR2RGB)
    pil_image = PIL.Image.fromarray(frame_cv_rgb)
    # Call torchvision directly instead of the notebook-level resize_image():
    # a later cell redefines that name as (image, max_width, max_height), after
    # which resize_image(pil_image, min_edge_length=512) raises TypeError.
    downscaled_image = transforms.Resize(size=512)(pil_image)
    features = get_clip_image_features(downscaled_image)

    return features.detach().cpu()

def results_func(frame_index, data):
    """Store ``data`` under ``frame_index`` in the global ``accumulated_results``."""
    accumulated_results.update({frame_index: data})



In [None]:
import time

def test_tqdm_manual(total: int = 1000, delay: float = 0.01):
    """Smoke-test a manually advanced tqdm progress bar.

    Args:
        total: number of single-step updates to perform.
        delay: seconds to sleep after each update.
    """
    # Imported locally because the notebook-level tqdm import lives in a later
    # cell; without this the function raises NameError on a fresh kernel.
    from tqdm.notebook import tqdm

    pbar = tqdm(range(total))
    try:
        for _ in range(total):
            pbar.update(1)
            time.sleep(delay)
    finally:
        # Close the bar even if the loop is interrupted, so it does not linger.
        pbar.close()

# Run the progress-bar smoke test: 1000 updates with a 0.01 s sleep each, so about ten seconds.
test_tqdm_manual()

In [None]:
import pickle
import os
from tqdm.notebook import tqdm

async def write_clip_features(root_path):
    """Extract CLIP features for every video under ``root_path`` and pickle them.

    Walks ``root_path`` recursively. For each file with a recognised video
    extension, frames are processed through ``AsyncVideoProcessor`` (using
    ``process_func`` / ``results_func``) and the collected results are written
    next to the video as ``<video>.clip-features.pickle``. Videos that already
    have such a pickle are skipped.

    The pickle holds a dict with the keys ``'type'``, ``'fps'``, ``'features'``
    and ``'frameIncrement'``.

    Args:
        root_path: directory to search for videos.
    """
    global accumulated_results

    # Loop-invariant settings, defined once instead of per directory/video.
    video_extensions = (".mp4",)
    process_fps = 0.5
    first_frame_to_process = 0

    print("walking", root_path)
    for directory, _, filenames in os.walk(root_path):
        print('directory:', directory)
        # Compare extensions case-insensitively so e.g. "CLIP.MP4" is found too.
        video_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in video_extensions]
        if not video_filenames:
            continue
        for filename in tqdm(video_filenames, desc=directory):
            video_path = os.path.join(directory, filename)
            pickle_path = video_path + ".clip-features.pickle"
            if os.path.exists(pickle_path):
                print("not overwriting existing", pickle_path)
                continue

            # Fresh dict per video; results_func writes into this global.
            accumulated_results = {}

            def write_results_func(video, partial:bool):
                # Only the final (non-partial) call persists anything.
                # NOTE(review): when AsyncVideoProcessor passes partial=True is
                # not visible here — confirm against its implementation.
                if partial:
                    return
                out_data = {
                    'type': 'clip features ' + repo_id,
                    'fps': video.get(cv2.CAP_PROP_FPS),
                    'features': accumulated_results,
                    # Bound below, after the processor is constructed; this
                    # closure only runs later, during run().
                    'frameIncrement': frame_increment
                }

                # Write to a temporary file and rename it into place, so an
                # interrupted dump cannot leave a truncated pickle that the
                # "not overwriting existing" check would then skip forever.
                temp_path = pickle_path + ".tmp"
                with open(temp_path, 'wb') as handle:
                    pickle.dump(out_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
                os.replace(temp_path, pickle_path)
                print("cumulative detection count:",str(len(accumulated_results)))

            async_video_processor = AsyncVideoProcessor(video_path, process_func, results_func, write_results_func, first_frame_to_process, process_fps)
            frame_increment = async_video_processor.frameIncrement
            await async_video_processor.run()


In [None]:
# Generate feature pickles for the videos under ./videos; videos that already have one are skipped.
await write_clip_features("./videos")

In [None]:
import numpy as np
from dataclasses import dataclass, field
import nmslib

@dataclass
class ClipFeatureDataCollection:
    """CLIP features of several videos behind a single nearest-neighbour index.

    Every feature added gets a global id: the features of each video occupy a
    contiguous id range starting at that video's entry in ``id_offsets``.
    ``rebuild_index`` must be called after adding data and before querying.
    """
    # ClipFeatureData is referenced via string annotations because that class
    # is defined further down in this cell; a bare name would raise NameError
    # when this class body is evaluated on a fresh kernel.
    paths: list[str] = field(default_factory=list)
    feature_datas: list["ClipFeatureData"] = field(default_factory=list)
    id_offsets: list[int] = field(default_factory=list)
    current_id_offset: int = 0
    # Plain class attribute (not a dataclass field); rebuild_index() sets the
    # per-instance value.
    index = None

    def add_feature_data(self, path: str, fd: "ClipFeatureData"):
        """Register one video's features under ``path``.

        The index is not updated; call ``rebuild_index`` afterwards.
        """
        self.paths.append(path)
        self.feature_datas.append(fd)
        self.id_offsets.append(self.current_id_offset)
        self.current_id_offset += len(fd.features)

    def rebuild_index(self):
        """Rebuild the brute-force index over all registered features.

        Features are concatenated video by video, and within a video in
        ascending key order, which is the ordering the id lookup relies on.
        """
        self.index = None

        # as_tensor accepts both arrays and tensors and, unlike torch.tensor,
        # does not copy (or warn) when the value already is a tensor.
        all_features = [torch.as_tensor(cfd.features[key]) for cfd in self.feature_datas for key in sorted(cfd.features.keys())]
        all_features_tensor = torch.cat(all_features)
        self.index = build_bruteforce_index(all_features_tensor)

    def _get_fd_index_for_id(self, id) -> int:
        """Return the position of the video whose id range contains ``id``."""
        for index in range(1,len(self.id_offsets)):
            if id < self.id_offsets[index]:
                # gone too far -> step back
                return index-1
        # last one
        return len(self.id_offsets)-1

    def get_path_and_frame_for_id(self, id) -> tuple[str, int]:
        """Translate a global feature id into ``(video path, frame number)``."""
        index = self._get_fd_index_for_id(id)
        fd = self.feature_datas[index]
        feature_index = id - self.id_offsets[index]
        # NOTE(review): assumes the n-th stored feature of a video belongs to
        # frame n * frame_increment (no gaps, starting at frame 0) — confirm.
        frame_number = fd.frame_increment * feature_index
        path = self.paths[index]
        return path,frame_number


    def get_closest_frames(self, features: torch.Tensor, k=4) -> list[tuple[str, int]]:
        """Return ``(video path, frame number)`` for the ``k`` nearest features.

        Raises:
            RuntimeError: if ``rebuild_index`` has not been called yet.
        """
        if self.index is None:
            # Clearer than the AttributeError that None.knnQuery would raise.
            raise RuntimeError("index has not been built; call rebuild_index() first")
        ids, distances = self.index.knnQuery(features, k=k)
        return [self.get_path_and_frame_for_id(i) for i in ids]
        


@dataclass
class ClipFeatureData:
    """CLIP features extracted from a single video."""
    # Stored features, looked up by frame number in get_nearest_features.
    features: dict[int, np.array]
    # Frame rate recorded when the features were written.
    fps: float
    # Spacing, in frames, between consecutive stored features.
    frame_increment: int

    def get_nearest_features(self, frame:int) -> np.array:
        """Return the stored features closest to ``frame``.

        Candidates are the multiple of ``frame_increment`` at or below
        ``frame`` plus its two neighbours. The following multiple wins when it
        is strictly closer and stored; the preceding one is used when the
        aligned frame itself is not stored; otherwise the aligned frame is
        returned. The chosen frame is printed.

        Raises:
            ValueError: if none of the candidates is stored.
        """
        step = self.frame_increment
        stored = self.features

        at_frame = (frame // step) * step
        after_frame = at_frame + step
        before_frame = at_frame - step

        at_gap = abs(frame - at_frame)
        after_gap = abs(frame - after_frame)
        before_gap = abs(frame - before_frame)

        has_at = at_frame in stored

        # Checked in priority order: (label to print, frame, usable?).
        choices = (
            ("next:", after_frame, after_gap < at_gap and after_frame in stored),
            # may need to catch valid frames after the last frame we've captured
            ("prev:", before_frame, before_frame in stored and (before_gap < at_gap or not has_at)),
            ("at:", at_frame, has_at),
        )
        for label, candidate, usable in choices:
            if usable:
                print(label, candidate)
                return stored[candidate]

        raise ValueError(f"frame {frame} not available (max {max(self.features.keys())})")
        
def load_clip_features(video_path):
    """Load the feature pickle stored next to ``video_path``.

    Reads ``<video_path>.clip-features.pickle`` and wraps its ``'features'``,
    ``'fps'`` and ``'frameIncrement'`` entries in a ``ClipFeatureData``.
    """
    pickle_path = video_path + ".clip-features.pickle"
    # NOTE(review): pickle.load can execute arbitrary code; only open pickle
    # files from a trusted source.
    with open(pickle_path, 'rb') as pickle_file:
        payload = pickle.load(pickle_file)
    return ClipFeatureData(
        features=payload['features'],
        fps=payload['fps'],
        frame_increment=payload['frameIncrement'],
    )

def build_fancy_index(all_features_tensor) -> nmslib:
    """Build an index over the features via ``build_index`` with ``method='hnsw'``."""
    hnsw_index = build_index(all_features_tensor, method='hnsw')
    return hnsw_index
    
def build_bruteforce_index(all_features_tensor: torch.Tensor) -> nmslib:
    """Build an index over the features via ``build_index`` with ``method='brute_force'``."""
    exhaustive_index = build_index(all_features_tensor, method='brute_force')
    return exhaustive_index

def build_index(all_features_tensor: torch.Tensor, method:str = 'brute_force') -> nmslib:
    """Create an nmslib cosine-similarity index over the given features.

    Args:
        all_features_tensor: tensor holding the feature vectors to index.
        method: nmslib method name, e.g. ``'brute_force'`` or ``'hnsw'``.

    Returns:
        The index, with all data points added and ``createIndex`` called.
    """
    print("building index for all_features_tensor of shape", all_features_tensor.shape)
    index = nmslib.init(method=method, space='cosinesimil')
    # Hand nmslib an explicit float32 numpy array instead of relying on an
    # implicit conversion of a (possibly float16) torch tensor.
    data_points = all_features_tensor.detach().cpu().to(torch.float32).numpy()
    index.addDataPointBatch(data_points)
    opts = {}
    if method == 'hnsw':
        # Fixed: this used to assign to the undefined name `args`, so every
        # hnsw build failed with NameError before reaching createIndex.
        opts['post'] = 2
    index.createIndex(opts, print_progress=True)
    return index

def get_closest_images(index: nmslib, features: np.array, k=4):
    """Query ``index`` for the ``k`` nearest neighbours of ``features``.

    Returns the ``(ids, distances)`` pair produced by ``index.knnQuery``.
    """
    neighbours = index.knnQuery(features, k=k)
    ids, distances = neighbours
    return ids, distances


In [None]:
import torch
video_path = "./videos/hotline-bling.mp4"
data = load_clip_features(video_path)
# Concatenate in ascending key order (as ClipFeatureDataCollection.rebuild_index
# does) rather than dict insertion order: display_nearest_frames maps index id i
# to frame `frame_increment * i`, which only holds if row i is the i-th frame.
# as_tensor avoids re-wrapping values that already are tensors.
all_features_tensor = torch.cat([torch.as_tensor(data.features[frame]) for frame in sorted(data.features)])
index = build_bruteforce_index(all_features_tensor)

In [None]:
import random
import PIL
import numpy as np
import cv2

def load_video(video_path):
    """Open a video with OpenCV and determine its frame count.

    Args:
        video_path: path of the video file.

    Returns:
        ``(video, total_frames)``: the opened ``cv2.VideoCapture`` (rewound to
        the start) and the frame count as an ``int``.

    Raises:
        ValueError: if the video cannot be opened.
    """
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        # Do not leave a dangling capture object behind on failure.
        video.release()
        raise ValueError("cannot open " + video_path)

    # Jump to the end, read the frame position there, then rewind.
    video.set(cv2.CAP_PROP_POS_AVI_RATIO,1)
    # OpenCV reports property values as floats; callers use this as an integer
    # bound (e.g. for the random module), so convert it here.
    total_frames = int(video.get(cv2.CAP_PROP_POS_FRAMES))
    video.set(cv2.CAP_PROP_POS_AVI_RATIO,0)

    return video,total_frames


def get_frame_pil(video, frame_number):
    """Seek ``video`` to ``frame_number`` and return that frame as a PIL image.

    Raises:
        RuntimeError: if the frame cannot be read.
    """
    video.set(cv2.CAP_PROP_POS_FRAMES,frame_number)
    success, bgr_pixels = video.read()
    if success:
        # OpenCV delivers BGR; convert to RGB before handing the data to PIL.
        rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
        return PIL.Image.fromarray(rgb_pixels)

    raise RuntimeError(f"unable to read frame {frame_number}")
    
def resize_image(image, max_width, max_height) -> "PIL.Image.Image":
    """Scale ``image`` to fit inside ``max_width`` x ``max_height``.

    The aspect ratio is kept and the image is scaled (up or down) until one
    side reaches its limit.

    NOTE: an earlier cell defines another ``resize_image`` with the signature
    ``(fullsize_image, min_edge_length, max_edge_length)``; running this cell
    replaces it.

    Args:
        image: object exposing ``size`` as ``(width, height)`` and a
            ``resize(size)`` method, e.g. a PIL image.
        max_width: maximum width of the result in pixels.
        max_height: maximum height of the result in pixels.

    Returns:
        The result of ``image.resize`` with the computed size.
    """
    width, height = image.size
    # The tighter of the two constraints determines the scale factor.
    factor = min(max_width / width, max_height / height)
    # Clamp to one pixel so extreme aspect ratios cannot truncate a side to 0.
    new_size = (max(1, int(width * factor)), max(1, int(height * factor)))
    return image.resize(new_size)
    
def display_video_frame(video, frame_number, max_width, max_height):
    """Show frame ``frame_number`` of ``video``, resized to fit the given bounds."""
    frame_image = get_frame_pil(video, frame_number)
    thumbnail = resize_image(frame_image, max_width, max_height)
    display(thumbnail)
    
def display_nearest_frames(data:ClipFeatureData, index: nmslib, video, search_features, k=4):
    """Display the ``k`` frames of ``video`` whose features are nearest to ``search_features``.

    Each index id is turned into a frame number by multiplying it with
    ``data.frame_increment``; every frame is shown at up to 300x300 pixels.
    """
    ids, _distances = get_closest_images(index, search_features, k=k)
    print(f"found ids: {ids} -> frames: {[data.frame_increment*i for i in ids]}")
    step = data.frame_increment
    for neighbour_id in ids:
        display_video_frame(video, step * neighbour_id, 300, 300)

def get_combined_weighted_features(feature_tensors: [torch.Tensor], weights: [float]) -> torch.Tensor:
    """Blend feature tensors into their weighted average.

    Each tensor is multiplied by the weight at the same position, the scaled
    tensors are concatenated (their shape is printed), summed over dim 0 and
    divided by the sum of all weights.
    """
    scaled_tensors = []
    for position, tensor in enumerate(feature_tensors):
        scaled_tensors.append(tensor * weights[position])
    stacked = torch.cat(scaled_tensors)
    print(stacked.shape)
    weight_total = sum(weights)
    return stacked.sum(dim=0) / weight_total
    

In [None]:
video, total_frames = load_video(video_path)
search_string = "a man on a staircase"
# The text embedding does not depend on the sampled frame, so compute it once
# instead of on every loop iteration.
text_features = get_clip_text_features(search_string).detach().cpu()

for i in range(3):
    # randrange excludes its upper bound; randint(0, total_frames) could pick
    # frame `total_frames`, one past the last valid frame. int() because the
    # frame count may arrive as a float, which the random module rejects.
    source_frame = random.randrange(int(total_frames))
    print("source:")
    display(resize_image(get_frame_pil(video, source_frame), 300, 300))
    source_frame_features = data.get_nearest_features(source_frame)

    # Blend: 10% the sampled frame's features, 90% the text prompt's.
    search_features = get_combined_weighted_features([source_frame_features, text_features],
                                                    [0.1,0.9])

    print("nearest:")
    display_nearest_frames(data, index, video, search_features, k=4)


In [None]:
# Videos whose feature pickles are combined into one searchable collection.
video_paths = [ "./videos/hotline-bling.mp4", "./videos/b.mp4" ]
fpc = ClipFeatureDataCollection()
for video_path in video_paths:
    # Note: this rebinds the notebook-level `video_path` and `data` names that
    # earlier cells also use.
    data = load_clip_features(video_path)
    fpc.add_feature_data(video_path, data)
# Build the nearest-neighbour index once, after all videos have been added.
fpc.rebuild_index()

In [None]:
search_string = "a man on a staircase"
text_features = get_clip_text_features(search_string).detach().cpu()
for video_path, frame_number in fpc.get_closest_frames(text_features, k=10):
    print("video_path: ", video_path)
    video, _ = load_video(video_path)
    try:
        display_video_frame(video, frame_number, 300, 300)
    finally:
        # A capture is opened for every hit; release it so file handles and
        # decoder resources do not pile up across results.
        video.release()

