In [1]:
%pip -q install open_clip_torch
%pip -q install transformers
%pip -q install torch
%pip -q install pillow
%pip -q install numpy
%pip -q install torch
%pip -q install tqdm
%pip -q install opencv-python
%pip -q install imagehash
%pip -q install ffmpeg-python
%pip -q install einops
%pip -q install faiss-cpu
%pip -q install usearch
%pip -q install translate
%pip -q install googletrans
%pip -q install pillow
%pip -q install matplotlib

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you 

In [4]:
%pip install httpcore

Note: you may need to restart the kernel to use updated packages.


In [None]:
# standard lib
import os
import re
import json
import io
import base64
from typing import List, Dict, Any, Tuple

# numerical computing
import numpy as np

# Deep Learning and AI:
import torch
import open_clip

#img processing
from PIL import Image

# Visualization
import matplotlib.pyplot as plt

# Progress Tracking
from tqdm.notebook import tqdm

#indexing and searchin:
import faiss
from usearch.index import Index as UsearchIndex

# translation
import googletrans
import translate



In [None]:
class CLIPEmbedding:
    """Embed images/text with an OpenCLIP model and search them with FAISS and USearch.

    Row ``i`` of the image-embedding matrix, key ``i`` of both search indexes and
    entry ``i`` of ``global_index2image_path`` all refer to the same image, so search
    results (row indices) can be mapped back to file paths with ``get_image_paths``.
    """

    # Lower-case file extensions treated as images by ``process_image_folder``.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')

    def __init__(
        self,
        model_name: str,
        model_nick_name: str,
        device: str = None
    ):
        """Load the OpenCLIP model, its image preprocessing transform and tokenizer.

        Args:
            model_name: name passed to ``open_clip.create_model_and_transforms`` and
                ``open_clip.get_tokenizer`` (e.g. an ``hf-hub:`` identifier).
            model_nick_name: prefix for the files written by ``process_image_folder``.
            device: torch device string; defaults to ``'cuda'`` when available, else
                ``'cpu'``. If loading/moving the model raises a ``RuntimeError`` whose
                message contains ``'out of memory'``, the model is reloaded on CPU.
        """
        self.model_nick_name = model_nick_name
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')

        try:
            print(f"Attempting to load model on {self.device}")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_name)
            self.model = self.model.to(self.device)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            print("GPU out of memory. Falling back to CPU.")
            self.device = 'cpu'
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_name)
            self.model = self.model.to(self.device)

        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.faiss_index = None
        self.usearch_index = None
        # Row index -> image path. ``global_index2img_path`` is an alias of this dict.
        self.global_index2image_path = {}

    @property
    def global_index2img_path(self) -> Dict[int, str]:
        """Alias of ``global_index2image_path``; both names refer to the same dict."""
        return self.global_index2image_path

    @global_index2img_path.setter
    def global_index2img_path(self, mapping: Dict[int, str]) -> None:
        self.global_index2image_path = mapping

    def process_image_folder(
        self,
        root_dir: str,
        output_dir: str,
        batch_size: int = 32
    ):
        """Embed every image under ``root_dir`` and persist embeddings, mapping and indexes.

        Files written to ``output_dir``:
          * ``{model_nick_name}_clip_embeddings.npy``: embeddings as returned by the
            model (not normalized here),
          * ``global2imgpath.json``: row index -> image path,
          * ``{model_nick_name}_faiss.bin`` and ``{model_nick_name}_usearch.bin``.

        Images that fail to load are skipped and also left out of the mapping, so
        row ``i`` of the result always corresponds to ``global_index2image_path[i]``.

        Returns:
            The stacked float32 embedding matrix, or ``None`` if no image was found
            or none could be embedded.
        """
        os.makedirs(output_dir, exist_ok=True)

        image_paths = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(root_dir)
            for file in files
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if not image_paths:
            print(f"No image found in the given root directory: {root_dir}")
            return None

        embeddings = []
        embedded_paths = []  # only paths that produced an embedding, in row order
        for start in tqdm(range(0, len(image_paths), batch_size), desc="Processing Batches of images", unit=f'batch, size = {batch_size}'):
            batch_embeddings, kept_paths = self._embed_images(image_paths[start:start + batch_size])
            if batch_embeddings is not None:
                embeddings.append(batch_embeddings)
                embedded_paths.extend(kept_paths)

        if not embeddings:
            print("No embeddings were created.")
            return None

        all_embeddings = np.vstack(embeddings)

        # Save CLIP embeddings
        clip_file = os.path.join(output_dir, f'{self.model_nick_name}_clip_embeddings.npy')
        np.save(clip_file, all_embeddings)
        print(f"CLIP embeddings saved to {clip_file}")

        # Save the row -> path mapping (built from successfully embedded images only)
        self.global_index2image_path = dict(enumerate(embedded_paths))
        index_path_file = os.path.join(output_dir, 'global2imgpath.json')
        with open(index_path_file, 'w') as f:
            json.dump(self.global_index2image_path, f, indent=4)
        print(f"global2imgpath saved to {index_path_file}")

        # Build and save FAISS index
        self.build_faiss_index(all_embeddings)
        faiss_file = os.path.join(output_dir, f"{self.model_nick_name}_faiss.bin")
        self.save_faiss_index(faiss_file)

        # Build and save USearch index
        self.build_usearch_index(all_embeddings)
        usearch_file = os.path.join(output_dir, f"{self.model_nick_name}_usearch.bin")
        self.save_usearch_index(usearch_file)

        return all_embeddings

    def _embed_images(self, batch_paths: List[str]) -> Tuple[np.ndarray, List[str]]:
        """Embed ``batch_paths``; return ``(embeddings or None, paths actually embedded)``."""
        tensors = []
        kept_paths = []
        for img_path in batch_paths:
            try:
                # The context manager closes the file handle Image.open keeps open.
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert('RGB')
                tensors.append(self.preprocess(img).unsqueeze(0))
                kept_paths.append(img_path)
            except Exception as e:
                print(f"Error processing image {img_path}: {str(e)}")

        if not tensors:
            return None, []
        batch_tensor = torch.cat(tensors).to(self.device)
        with torch.no_grad():
            batch_embeddings = self.model.encode_image(batch_tensor)
        return batch_embeddings.cpu().numpy().astype(np.float32), kept_paths

    def process_batch(self, batch_paths):
        """Embed a list of image paths; unreadable images are skipped.

        Returns:
            A float32 array with one row per successfully loaded image, or ``None``
            if none could be loaded.
        """
        batch_embeddings, _ = self._embed_images(batch_paths)
        return batch_embeddings

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """Return a C-contiguous float32 2-D copy of ``vectors`` with unit-norm rows.

        1-D input is treated as a single row. Zero rows are left as zeros. The input
        array is never modified.
        """
        matrix = np.array(vectors, dtype=np.float32, ndmin=2)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return np.ascontiguousarray(matrix / norms, dtype=np.float32)

    def build_faiss_index(self, embeddings: np.ndarray):
        """Build an exact inner-product FAISS index over L2-normalized ``embeddings``.

        Stored vectors and queries (see ``faiss_search``) are both normalized, so the
        inner product equals cosine similarity, consistent with the USearch index.
        ``embeddings`` itself is not modified.
        """
        normalized = self._l2_normalize(embeddings)
        self.faiss_index = faiss.IndexFlatIP(normalized.shape[1])
        self.faiss_index.add(normalized)

    def save_faiss_index(self, file_path: str):
        """Write the current FAISS index to ``file_path``."""
        faiss.write_index(self.faiss_index, file_path)
        print(f"FAISS index saved to {file_path}")

    def load_faiss_index(self, file_path: str):
        """Read a FAISS index from ``file_path`` into ``self.faiss_index``."""
        self.faiss_index = faiss.read_index(file_path)
        print(f"FAISS index loaded from {file_path}")

    def build_usearch_index(self, embeddings: np.ndarray):
        """Build a cosine-metric USearch index keyed by row index of ``embeddings``."""
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        self.usearch_index = UsearchIndex(ndim=vectors.shape[1], metric='cosine')
        # Batch insert; keys are row indices so they line up with global_index2image_path.
        self.usearch_index.add(np.arange(len(vectors)), vectors)

    def save_usearch_index(self, file_path: str):
        """Write the current USearch index to ``file_path``."""
        self.usearch_index.save(file_path)
        print(f"USearch index saved to {file_path}")

    def load_usearch_index(self, file_path: str, dimension: int = None):
        """Load a cosine-metric USearch index from ``file_path``.

        Args:
            file_path: path written by ``save_usearch_index``.
            dimension: embedding dimension; when omitted it is taken from the loaded
                FAISS index.

        Raises:
            ValueError: if ``dimension`` is omitted and no FAISS index is loaded.
        """
        if dimension is None:
            if self.faiss_index is None:
                raise ValueError("Pass `dimension` or load the FAISS index before the USearch index.")
            dimension = self.faiss_index.d
        self.usearch_index = UsearchIndex(ndim=dimension, metric='cosine')
        self.usearch_index.load(file_path)
        print(f"USearch index loaded from {file_path}")

    def faiss_search(self, query_embedding: np.ndarray, k: int) -> List[Tuple[int, float]]:
        """Return up to ``k`` ``(row_index, cosine_similarity)`` pairs, best first.

        ``query_embedding`` may be 1-D or a single-row 2-D array; it is not modified.
        FAISS pads with index ``-1`` when fewer than ``k`` vectors exist; those are dropped.
        """
        query = self._l2_normalize(query_embedding)
        distances, indices = self.faiss_index.search(query, k)
        return [(int(idx), float(score)) for idx, score in zip(indices[0], distances[0]) if idx >= 0]

    def usearch_search(self, query_embedding: np.ndarray, k: int) -> List[Tuple[int, float]]:
        """Return up to ``k`` ``(row_index, cosine_distance)`` pairs; lower distance is better."""
        matches = self.usearch_index.search(query_embedding, k)
        return [(int(match.key), float(match.distance)) for match in matches]

    def text_query(self, query: str, k: int = 20) -> Tuple[List[Tuple[int, float]], List[Tuple[int, float]]]:
        """Encode ``query`` with the CLIP text encoder and search both indexes.

        Returns:
            ``(faiss_results, usearch_results)``: FAISS values are cosine similarities
            (higher is better), USearch values are cosine distances (lower is better).
        """
        with torch.no_grad():
            text_tokens = self.tokenizer([query]).to(self.device)
            query_embedding = self.model.encode_text(text_tokens).cpu().numpy().astype(np.float32)

        faiss_results = self.faiss_search(query_embedding, k)
        usearch_results = self.usearch_search(query_embedding[0], k)

        return faiss_results, usearch_results

    def image_query(self, img_data: str, k: int = 20) -> Tuple[List[Tuple[int, float]], List[Tuple[int, float]]]:
        """Search both indexes with a base64-encoded query image.

        Returns:
            Same structure as ``text_query``.
        """
        img_bytes = base64.b64decode(img_data)
        with Image.open(io.BytesIO(img_bytes)) as raw_img:
            img = raw_img.convert('RGB')

        img_preprocessed = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            query_embedding = self.model.encode_image(img_preprocessed).cpu().numpy().astype(np.float32)

        faiss_results = self.faiss_search(query_embedding, k)
        usearch_results = self.usearch_search(query_embedding[0], k)

        return faiss_results, usearch_results

    def get_image_paths(self, indices: List[int]) -> List[str]:
        """Map row indices (ints or numpy integers) to image paths."""
        return [self.global_index2image_path[int(i)] for i in indices]

    def load_indexes(self, faiss_path: str, usearch_path: str, global2imgpath_path: str):
        """Load the FAISS index, the USearch index and the row -> path mapping."""
        self.load_faiss_index(faiss_path)
        self.load_usearch_index(usearch_path)
        with open(global2imgpath_path, 'r') as f:
            raw_mapping = json.load(f)
        # JSON object keys are always strings; convert back to the int row indices
        # that the search methods return.
        self.global_index2image_path = {int(key): path for key, path in raw_mapping.items()}
        print("All indexes and mappings loaded successfully.")
        

In [None]:
class Translation:
    """Translate query text between languages via ``googletrans`` or ``translate``.

    Args:
        from_lang: source language code (default ``'vi'``).
        to_lang: target language code (default ``'en'``).
        mode: ``'googletrans'`` or ``'translate'``. ``'google'`` (the default) is
            accepted as an alias for ``'googletrans'``.

    Raises:
        ValueError: if ``mode`` is not a supported value.
    """

    # Mode names that select the googletrans backend.
    _GOOGLE_MODES = ('google', 'googletrans')

    def __init__(self, from_lang='vi', to_lang='en', mode='google'):
        # Normalize the mode up front so a translator is always created (the default
        # 'google' would otherwise match no backend and leave self.translator unset).
        if mode in self._GOOGLE_MODES:
            mode = 'googletrans'
        elif mode != 'translate':
            raise ValueError(f"Unsupported translation mode: {mode!r}; expected 'googletrans' or 'translate'")

        self.__mode = mode
        self.__from_lang = from_lang
        self.__to_lang = to_lang

        if mode == 'googletrans':
            self.translator = googletrans.Translator()
        else:
            self.translator = translate.Translator(from_lang=from_lang, to_lang=to_lang)

    def preprocessing(self, text):
        """Normalize text before translation (currently: lower-case it)."""
        return text.lower()

    def __call__(self, text):
        """Return the translation of ``text`` after ``preprocessing``."""
        text = self.preprocessing(text)
        if not text.strip():
            # Nothing to translate; skip the network round-trip.
            return text
        if self.__mode == 'translate':
            return self.translator.translate(text)
        # Pass the configured source language instead of relying on auto-detection.
        return self.translator.translate(text, src=self.__from_lang, dest=self.__to_lang).text

In [None]:
def display_result(indices, embedder, k=20, n_cols=3):
    """Show up to ``k`` result images in a grid, ranked in the order of ``indices``.

    Args:
        indices: result row indices, best first; resolved to file paths through
            ``embedder.get_image_paths``.
        embedder: object exposing ``get_image_paths(list_of_indices) -> list[str]``.
        k: maximum number of images to show.
        n_cols: number of grid columns.
    """
    k = min(k, len(indices))
    if k <= 0:
        # plt.subplots(0, n) would raise; there is simply nothing to draw.
        print("No results to display.")
        return

    n_rows = (k + n_cols - 1) // n_cols
    # squeeze=False always returns a 2-D array of Axes, whatever n_rows/n_cols are.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 7 * n_rows), squeeze=False)

    img_paths = embedder.get_image_paths(list(indices[:k]))
    for rank, (ax, img_path) in enumerate(zip(axes.flat, img_paths), start=1):
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(f"Rank: {rank}", fontsize=12)
        ax.set_xlabel(os.path.basename(img_path), fontsize=10, wrap=True)
        # Hide ticks and frame only: axis('off') would also hide the filename xlabel.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Blank out unused cells in the last row.
    for ax in axes.flat[k:]:
        ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.3, wspace=0.1)
    plt.show()
    plt.close(fig)

In [None]:
def get_sorted_query_files(root_dir: str) -> List[str]:
    """Return every ``.txt`` file under ``root_dir``, ordered by its ``p<N>`` number.

    ``query-p12.txt`` sorts as 12 (numerically, so p2 comes before p10). Files
    without a ``p<N>.txt`` suffix sort as 0. Ties are broken by full path so the
    result does not depend on ``os.walk`` traversal order.
    """
    query_files = [
        os.path.join(dir_root, filename)
        for dir_root, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.lower().endswith('.txt')
    ]

    def sort_key(filepath: str) -> Tuple[int, str]:
        # IGNORECASE to match the case-insensitive '.txt' filter above (e.g. 'p3.TXT').
        match = re.search(r'p(\d+)\.txt$', os.path.basename(filepath), re.IGNORECASE)
        return (int(match.group(1)) if match else 0, filepath)

    return sorted(query_files, key=sort_key)

def read_queries(file_paths: List[str]) -> List[str]:
    """Read each UTF-8 text file and return its whitespace-stripped contents, in order."""
    def _load_stripped(path: str) -> str:
        with open(path, 'r', encoding='utf-8') as handle:
            return handle.read().strip()

    return [_load_stripped(path) for path in file_paths]


In [None]:
# --- Inputs ---
root_directory_query = '/kaggle/input/query-test'   # folder searched for *.txt query files
image_root_directory = "/kaggle/input/aic-video"    # folder searched recursively for images

# --- Model and outputs ---
model_name = 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-378'
output_dir = os.path.join('/kaggle/working', 'DFN5B_CLIP_ViT_H_14_378')
batch_size = 16

In [None]:
embedder = CLIPEmbedding(model_name=model_name, model_nick_name='DFN5B_CLIP_ViT_H_14_378')
print(f"\nProcessing images for model: {model_name}")
embeddings = embedder.process_image_folder(image_root_directory, output_dir, batch_size)
if embeddings is None:
    # `return` is a SyntaxError at notebook top level; raising stops "Run All" here.
    raise RuntimeError(f"Processing images failed for model {model_name}. Please check the image folder path and content.")
# One embedding row per successfully processed image.
print(f"Processed {embeddings.shape[0]} images")
print(f"Embedding shape: {embeddings.shape}")

Attempting to load model on cpu


  checkpoint = torch.load(checkpoint_path, map_location=map_location)



Processing images for model: hf-hub:apple/DFN5B-CLIP-ViT-H-14-378


OSError: [Errno 30] Read-only file system: '/kaggle'

In [None]:
translator = Translation(from_lang='vi', to_lang='en', mode='translate')
# Query folder is configured as `root_directory_query` in the configuration cell.
sorted_query_files = get_sorted_query_files(root_directory_query)
queries = read_queries(sorted_query_files)
translated_queries = [translator(query) for query in tqdm(queries, desc="Translating queries")]

In [None]:
TOP_K = 5  # number of results retrieved and displayed per query and backend

print("\nPerforming text queries")
for query, translated_query in tqdm(zip(queries, translated_queries), total=len(queries), desc="Processing queries"):
    print(f"\nOriginal Query: {query}")
    print(f"Translated Query: {translated_query}")

    faiss_results, usearch_results = embedder.text_query(translated_query, k=TOP_K)

    # FAISS (IndexFlatIP) returns similarities (higher is better); USearch with
    # metric='cosine' returns distances (lower is better), so label them differently.
    for backend, results, value_label in (
        ("FAISS", faiss_results, "Score"),
        ("USearch", usearch_results, "Distance"),
    ):
        result_indices = [idx for idx, _ in results]
        result_paths = embedder.get_image_paths(result_indices)

        print(f"\n{backend} Results:")
        for path, (_, value) in zip(result_paths, results):
            print(f"Image: {path}, {value_label}: {value}")

        print(f"\n{backend} Results (Visual):")
        display_result(result_indices, embedder, k=TOP_K)