In [7]:
import os
import json
import numpy as np
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer
import faiss
import torch
import time

In [8]:
def load_model(model_name):
    """Load the embedding model registered under ``model_name``.

    Supported names:
        * ``"bert"`` -- ``bert-base-uncased`` loaded through ``transformers``.
        * ``"bge-m3"`` -- ``BAAI/bge-m3`` loaded through ``sentence_transformers``.
        * ``"all-MiniLM-L6-v2"`` -- loaded through ``sentence_transformers``.

    Args:
        model_name: One of the names listed above.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``tokenizer`` is ``None`` for
        the SentenceTransformer-backed models. ``device`` is the CUDA device
        when ``torch.cuda.is_available()`` is true, otherwise the CPU; the
        returned model has already been moved to it.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_name == "bert":
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = BertModel.from_pretrained("bert-base-uncased").to(device)
        return tokenizer, model, device
    elif model_name == "bge-m3":
        # The key previously loaded "BAAI/bge-large-en", which does not match
        # the name under which embeddings and indexes are cached.
        # NOTE: caches written under the "bge-m3" name with the old checkpoint
        # must be regenerated, otherwise queries and stored vectors come from
        # different models.
        model = SentenceTransformer("BAAI/bge-m3").to(device)
        return None, model, device
    elif model_name == "all-MiniLM-L6-v2":
        model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
        return None, model, device
    else:
        raise ValueError(f"Unsupported model: {model_name}")

In [9]:
def save_embeddings(data, model_name, output_dir):
    """Embed every caption in ``data`` and persist the vectors and metadata.

    Two files are written into ``output_dir`` (created if missing):
    ``<model_name>_embeddings.npy`` holding the stacked vectors, and
    ``<model_name>_metadata.json`` holding one record per input item in the
    same order, so row ``i`` of the array belongs to record ``i``.

    Args:
        data: Iterable of mappings; each is read for the keys ``"caption"``,
            ``"video_id"``, ``"timestamp"`` and ``"frame_image_path"``.
        model_name: Name forwarded to ``load_model``.
        output_dir: Directory receiving the two output files.
    """
    tokenizer, model, device = load_model(model_name)
    kept_fields = ("video_id", "timestamp", "frame_image_path", "caption")

    vectors = []
    records = []
    for entry in data:
        caption = entry["caption"]
        if model_name == "bert":
            # Tokenize on the host, then move the tensors to the model's device.
            encoded = tokenizer(
                caption, return_tensors="pt", padding=True, truncation=True
            ).to(device)
            with torch.no_grad():
                hidden_states = model(**encoded).last_hidden_state
                # Average over dim=1 and bring the result back to the CPU
                # as a NumPy array.
                vector = hidden_states.mean(dim=1).squeeze().cpu().numpy()
        else:
            vector = model.encode(caption, device=device)

        vectors.append(vector)
        records.append({field: entry[field] for field in kept_fields})

    os.makedirs(output_dir, exist_ok=True)
    vectors_file = os.path.join(output_dir, f"{model_name}_embeddings.npy")
    records_file = os.path.join(output_dir, f"{model_name}_metadata.json")

    np.save(vectors_file, np.array(vectors))
    with open(records_file, "w") as handle:
        json.dump(records, handle)

In [10]:
def load_embeddings(embeddings_path, metadata_path):
    """Read back a saved embedding array and its JSON metadata.

    Args:
        embeddings_path: Path of the ``.npy`` file to load with ``np.load``.
        metadata_path: Path of the JSON file holding the metadata.

    Returns:
        A ``(embeddings, metadata)`` tuple: the loaded NumPy array and the
        decoded JSON content.
    """
    vectors = np.load(embeddings_path)
    with open(metadata_path, "r") as handle:
        records = json.loads(handle.read())
    return vectors, records


def create_and_save_faiss_index(embeddings, output_path):
    """Build a flat L2 FAISS index over ``embeddings`` and write it to disk.

    Args:
        embeddings: 2-D array-like with one vector per row. It is converted
            to a C-contiguous float32 array before being added to the index.
        output_path: Destination file passed to ``faiss.write_index``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D (for example the 1-D empty
            array produced when no captions were embedded).
    """
    # FAISS consumes C-contiguous float32 matrices; normalise up front so that
    # float64 or non-contiguous input is accepted.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (n_vectors, dim), "
            f"got shape {vectors.shape}"
        )

    dim = vectors.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    faiss.write_index(index, output_path)


def load_faiss_index(input_path):
    """Load a FAISS index from ``input_path``.

    Args:
        input_path: Path of the index file, passed to ``faiss.read_index``.

    Returns:
        The index object returned by ``faiss.read_index``.
    """
    index = faiss.read_index(input_path)
    return index

In [11]:
def search(query, index, metadata, model, model_name, device, top_k=5, tokenizer=None):
    """Return the metadata records nearest to ``query`` in ``index``.

    Args:
        query: Text to embed and look up.
        index: Object exposing ``search(vectors, k)`` and returning
            ``(distances, indices)``, e.g. a FAISS index.
        metadata: Sequence of records; position ``i`` describes vector ``i``
            of ``index``.
        model: Embedding model. For ``model_name == "bert"`` it is called with
            the tokenized inputs; otherwise its ``encode`` method is used.
        model_name: Selects the embedding path (``"bert"`` or anything else).
        device: Device the inputs are moved to / passed to ``encode``.
        top_k: Number of neighbours requested from the index.
        tokenizer: Optional tokenizer for the ``"bert"`` path. When omitted,
            ``bert-base-uncased`` is loaded on every call; pass the tokenizer
            once to avoid that repeated load.

    Returns:
        A list of at most ``top_k`` metadata records, best match first.
    """
    if model_name == "bert":
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            last_hidden_state = model(**inputs).last_hidden_state
            query_embedding = torch.mean(last_hidden_state, dim=1).squeeze().cpu().numpy()
    else:
        query_embedding = model.encode(query, device=device)

    # FAISS expects a C-contiguous float32 matrix with one row per query.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)
    _distances, indices = index.search(query_embedding, top_k)

    # FAISS pads the result with -1 when it has fewer than top_k vectors.
    # Without this filter, metadata[-1] would silently return the last record
    # as if it were a real hit.
    return [metadata[idx] for idx in indices[0] if idx >= 0]

In [20]:
# Script entry point: for each configured model, make sure cached embeddings
# and a FAISS index exist (building them when missing), then run one example
# query and print the retrieved frames.
if __name__ == "__main__":
    frames_dir = os.path.join("frames")
    captions_file = os.path.join("frame_captions_0.5sec.json")
    output_dir = os.path.join("embeddings")

    # The captions contain non-ASCII text; read the JSON as UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(captions_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Other names accepted by load_model: "bert", "bge-m3".
    models = ["all-MiniLM-L6-v2"]

    for model_name in models:
        print(f"\n=== Processing with {model_name} ===")

        embeddings_path = os.path.join(output_dir, f"{model_name}_embeddings.npy")
        metadata_path = os.path.join(output_dir, f"{model_name}_metadata.json")
        index_path = os.path.join(output_dir, f"{model_name}_faiss.index")

        embeddings_regenerated = False
        if not os.path.exists(embeddings_path) or not os.path.exists(metadata_path):
            print(f"Generating embeddings for {model_name}...")
            save_embeddings(data, model_name, output_dir)
            embeddings_regenerated = True

        embeddings, metadata = load_embeddings(embeddings_path, metadata_path)

        # Rebuild the index whenever the embeddings were just regenerated: an
        # index file left over from an earlier run would otherwise be reused
        # and its row numbers would no longer match the new metadata.
        if embeddings_regenerated or not os.path.exists(index_path):
            print(f"Creating FAISS index for {model_name}...")
            create_and_save_faiss_index(embeddings, index_path)

        index = load_faiss_index(index_path)

        query = "밤의 설원에서 숲으로 둘러싸인 따뜻한 불빛이 비치는 통나무집"
        tokenizer, model, device = load_model(model_name)
        results = search(query, index, metadata, model, model_name, device)

        for res in results:
            print(f"Video ID: {res['video_id']}, Timestamp: {res['timestamp']}")
            print(f"Caption: {res['caption']}")
            print(f"Frame Path: {res['frame_image_path']}\n")



=== Processing with all-MiniLM-L6-v2 ===
Video ID: Fz9HnTVx52g, Timestamp: 0:01:52.862
Caption: 체크무늬 드레스를 입고 웃고 있는 금발 머리의 여자
Frame Path: own_dataset_video/frames/Fz9HnTVx52g_112.863.jpg

Video ID: Fz9HnTVx52g, Timestamp: 0:01:53.780
Caption: 체크무늬 드레스를 입고 웃고 있는 금발 머리의 여자
Frame Path: own_dataset_video/frames/Fz9HnTVx52g_113.780.jpg

Video ID: Fz9HnTVx52g, Timestamp: 0:01:53.322
Caption: 체크무늬 드레스를 입고 웃고 있는 금발 머리의 여자
Frame Path: own_dataset_video/frames/Fz9HnTVx52g_113.322.jpg

Video ID: n1lbpj6868o, Timestamp: 0:00:52.761
Caption: 체크무늬 양복과 넥타이를 입은 남자가 뭔가를 보고 있다
Frame Path: own_dataset_video/frames/n1lbpj6868o_52.761.jpg

Video ID: n1lbpj6868o, Timestamp: 0:00:51.843
Caption: 체크무늬 양복과 넥타이를 입은 남자가 뭔가를 보고 있다
Frame Path: own_dataset_video/frames/n1lbpj6868o_51.843.jpg

