In [1]:
%reload_ext autoreload
%autoreload 2

import torch
from PIL import Image
import matplotlib.pyplot as plt
import clip
import os
import tqdm
import numpy as np
import cv2
from dataset import Dataset
from config import KEYFRAMES
from utils import norm_vectors, get_video_name
import pandas as pd
from retrieval import ClipRetrieval
import dataclasses
from pathlib import Path
import time

In [2]:
import shutil

# Zip the whole data folder. The folder path doubles as the archive base name,
# so the archive is written next to the folder as "<folder>.zip".
p = "/".join(r"C:\Users\ADMIN\Downloads\rag_langchain_\data".split("\\"))
shutil.make_archive(base_name=p, format="zip", root_dir=p)

'C:\\Users\\ADMIN\\Downloads\\rag_langchain_\\data.zip'

In [2]:
# Instantiate the Dataset class imported from the `dataset` module. The progress
# bars printed by this cell show it loading CLIP embeddings, image paths, media
# info and keyframe maps.
dataset = Dataset()

Loading clip embeddings: 100%|███████████████████████████████████| 128/128 [00:00<00:00, 278.43it/s]
Loading image paths: 100%|███████████████████████████████████████| 128/128 [00:00<00:00, 857.63it/s]
Loading media info: 100%|███████████████████████████████████████| 128/128 [00:00<00:00, 2713.33it/s]
Loading map keyframes: 100%|█████████████████████████████████████| 128/128 [00:00<00:00, 594.75it/s]


In [3]:
# Build the retriever on top of the loaded dataset; the log lines printed by
# this cell show the CLIP model being loaded at this point.
retrive = ClipRetrieval(dataset)

INFO:root: Loading CLIP model...
INFO:root: CLIP model loaded.


In [4]:
# Time one text query end to end (search + result collection).
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump when the system clock is adjusted.
st = time.perf_counter()
retrive.search_text("a cat")
results = retrive.collect_results()
print("Time: ", time.perf_counter() - st)

Time:  0.2593512535095215


In [10]:
# Print the watch URL attached to the first collected result.
print(results[0].watch_url)

https://youtube.com/embed/p6h043fMCUA


In [None]:
# Fetch the dataset items for the indexes stored in the retriever's last search
# result, then inspect the first one.
tmp = dataset.get_items(retrive.search_result["indexes"])
tmp[0]

In [None]:
# Inspect the `pts_time` field of the first result.
# NOTE(review): presumably the keyframe timestamp(s) inside the video — confirm
# the type and unit against the result class.
results[0].pts_time

In [None]:
# Open one keyframe of the first result as RGB and display it.
# NOTE(review): `keyframe` is indexed with [4], so it is presumably a sequence
# of image paths with at least five entries — confirm against the result class.
img = Image.open(results[0].keyframe[4]).convert("RGB")
plt.imshow(img)

In [None]:
# Group keyframes by the video they belong to: {video_name: [keyframe, ...]}.
# NOTE(review): `collect_results` is not assigned in any visible cell (only the
# method `retrive.collect_results()` is called), so this cell relies on leftover
# kernel state — confirm where the dict holding "keyframes" comes from.
collector = {}
for keyframe in collect_results["keyframes"]:
    video_name = get_video_name(keyframe)
    # setdefault creates the list the first time a video is seen, so a single
    # pass is enough and get_video_name runs once per keyframe.
    collector.setdefault(video_name, []).append(keyframe)
collector

In [None]:
# Number of distinct video names that the keyframes were grouped under.
len(collector)

In [None]:
# Show the ungrouped "keyframes" entry that the grouping cell iterates over.
collect_results["keyframes"]

In [None]:
# Downscale every keyframe JPEG under the keyframes folder to 320x180.
# WARNING: destructive — each file is overwritten in place.
p = r"C:\Users\ADMIN\Downloads\rag_langchain_\data\keyframes".replace("\\", "/")
path = list(Path(p).rglob("*.jpg"))

target_size = (320, 180)  # cv2.resize takes dsize as (width, height)
failed = []
for img_path in tqdm.tqdm(path):
    # IMREAD_UNCHANGED (== -1) keeps the stored channel layout and bit depth.
    img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        # imread signals an unreadable/corrupt file with None instead of raising;
        # passing that to cv2.resize would abort the run half-way through.
        failed.append(img_path)
        continue
    # img.shape is (height, width[, channels]).
    if (img.shape[1], img.shape[0]) == target_size:
        # Already resized: skip so that re-running the cell does not put the
        # file through another lossy JPEG re-encode.
        continue
    rz = cv2.resize(img, dsize=target_size, interpolation=cv2.INTER_LINEAR)
    if not cv2.imwrite(str(img_path), rz):
        failed.append(img_path)

if failed:
    print(f"{len(failed)} file(s) could not be read or written:", *failed, sep="\n")

In [None]:
# Show five keyframes side by side, each titled with its similarity score.
# Relies on `data`, `idx` and `sim` already existing in the kernel.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))

for i, ax in enumerate(axes):
    keyframe_path = data.keyframes[idx[i]]
    img = Image.open(keyframe_path).convert("RGB")
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(f"Sim: {sim[idx[i]]:.4f}")


# Extract CLIP embeddings

In [None]:
# Load the "RN50" CLIP checkpoint; clip.load returns the model together with a
# `preprocess` callable that the keyframe dataset applies to every image.
# NOTE(review): this import rebinds the name `clip`, which the first cell
# already bound with `import clip`.
from clip import clip
# NOTE(review): no device is passed, so the model ends up on clip.load's default
# device — confirm that inputs are moved to the same device before inference.
model, preprocess = clip.load("RN50")
# Put the model in evaluation mode before extracting features.
model = model.eval()

In [None]:
import warnings

# torch's Dataset is imported under an alias so it does not overwrite the
# project `Dataset` class imported at the top of the notebook.
from torch.utils.data import DataLoader, Dataset as TorchDataset


class KeyframeDataset(TorchDataset):
    """Keyframe images of one folder, served in numeric file-name order.

    Args:
        folder_path: Directory whose files are named "<integer>.<ext>".
        transform: Callable applied to each opened PIL image. When omitted,
            the notebook-level `preprocess` returned by `clip.load` is used.

    Attributes:
        list_kf: File names inside `folder_path`, sorted by their integer stem.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform

        numbered, skipped = [], []
        for name in os.listdir(folder_path):
            try:
                numbered.append((int(name.split(".")[0]), name))
            except ValueError:
                # Stray entries such as "Thumbs.db" have no frame number; they
                # are skipped instead of crashing the sort.
                skipped.append(name)
        if skipped:
            warnings.warn(
                f"Ignoring {len(skipped)} non-keyframe file(s) in {folder_path}: {skipped}"
            )

        # Stable sort on the frame number only, so ties keep listing order.
        numbered.sort(key=lambda pair: pair[0])
        self.list_kf = [name for _, name in numbered]

    def __len__(self):
        """Return the number of keyframes found in the folder."""
        return len(self.list_kf)

    def __getitem__(self, idx):
        """Load keyframe `idx` from disk and return it transformed."""
        img_path = os.path.join(self.folder_path, self.list_kf[idx])
        transform = self.transform if self.transform is not None else preprocess
        # The context manager releases the file handle as soon as the transform
        # has consumed the pixel data.
        with Image.open(img_path) as img:
            return transform(img)

In [None]:
# Root folder of the extracted keyframes. The count below is the number of
# entries directly inside it; the extraction cell treats each entry as one
# folder of keyframes.
folder_path = r"C:\Users\ADMIN\Downloads\rag_langchain_\data\keyframes".replace('\\', '/')
len(os.listdir(folder_path))

In [None]:
# Encode every keyframe folder with CLIP and save one "<folder>.npy" per folder.
# Folders that already have an output file are skipped, so the cell can be
# re-run to resume an interrupted extraction.
out_dir = r"C:\Users\ADMIN\Downloads\rag_langchain_\data\clip_embs".replace("\\", "/")
os.makedirs(out_dir, exist_ok=True)  # the output folder may not exist on a first run

# Names of folders that are already done. splitext strips only the trailing
# ".npy", so folder names that themselves contain dots are still recognised.
exist_name = {os.path.splitext(name)[0] for name in os.listdir(out_dir)}

# Inputs must sit on the same device as the model weights; otherwise
# encode_image fails with a device mismatch when the model was loaded on GPU.
device = next(model.parameters()).device

for folder in sorted(os.listdir(folder_path)):
    if folder in exist_name:
        print(f"Skip folder: {folder}")
        continue
    folder_path_full = os.path.join(folder_path, folder).replace("\\", "/")
    if not os.path.isdir(folder_path_full):
        continue  # stray file sitting next to the keyframe folders
    data = KeyframeDataset(folder_path_full)
    # shuffle=False keeps the saved rows in the dataset's keyframe order.
    dataloader = DataLoader(data, batch_size=32, shuffle=False)

    batch_features = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc=f"Folder {folder}"):
            features = model.encode_image(batch.to(device))
            batch_features.append(features.cpu().numpy())

    # Stack the per-batch arrays along the first axis; an empty folder still
    # produces an (empty) file so it is not retried on the next run.
    all_features = np.concatenate(batch_features) if batch_features else np.array([])
    np.save(os.path.join(out_dir, folder + ".npy"), all_features)

print("-"*10, "DONE", "-"*10)

In [None]:
def norm_embeddings(embeddings):
    """Scale embeddings to unit L2 norm.

    Args:
        embeddings: A 1-D vector, or a 2-D array with one embedding per row.

    Returns:
        The input divided by its L2 norm (per row for 2-D input). All-zero
        vectors are returned as zeros instead of NaN.
    """
    if embeddings.ndim == 1:
        norm = np.linalg.norm(embeddings)
        # A zero vector has no direction; dividing by 1 leaves it at zero
        # rather than producing NaN from 0/0.
        return embeddings / (norm if norm > 0 else 1.0)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # `norms` is a fresh array, so patching it in place is safe and keeps dtype.
    norms[norms == 0] = 1
    return embeddings / norms

In [None]:
# Embed the text query with CLIP and scale it to unit length.
text = "a cat"
# Tokens must sit on the same device as the model weights.
tokens = clip.tokenize([text]).to(next(model.parameters()).device)
with torch.no_grad():
    text_features = model.encode_text(tokens)[0]
# Normalise with torch: no `norm` helper is defined or imported in this
# notebook, and keeping a tensor is what the later `embs @ text_features`
# followed by a torch softmax expects.
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

In [None]:
# Score every stored embedding against the text query (softmax over the dot
# products), then order the indices from highest to lowest score.
# Relies on `embs` already existing in the kernel.
scores = embs @ text_features
sim = torch.softmax(scores, dim=-1).cpu().numpy()
sort_ids = np.argsort(-sim)

In [None]:
# Plot the five best-ranked images with their rank and score in the title.
# Relies on `imgs`, `sort_ids` and `sim` already existing in the kernel.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))

for i, ax in enumerate(axes):
    # permute(1, 2, 0) reorders the tensor axes for imshow —
    # presumably channel-first to channel-last; confirm against how `imgs` is built.
    picture = imgs[sort_ids[i]].permute(1, 2, 0)
    ax.imshow(picture)
    ax.axis('off')
    ax.set_title(f"Top {i+1} - {sim[sort_ids[i]]:.4f}")

plt.tight_layout()
plt.show()