In [1]:
!pip install fastapi uvicorn sentence-transformers open-clip-torch qdrant-client \
    llama-index llama-index-vector-stores-qdrant llama-index-core pyngrok -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.5/1.5 MB[0m [31m45.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m62.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m303.3/303.3 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.5/42.5 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from fastapi import FastAPI, UploadFile, File, Form, Depends, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from qdrant_client import QdrantClient, models
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex
from collections import defaultdict
import heapq
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import re
from pydantic import PrivateAttr
from llama_index.core.embeddings import BaseEmbedding
from typing import List, Optional
import open_clip
from pydantic import BaseModel
import json
from PIL import Image
import hashlib
import io
import json
from collections import defaultdict
import time

In [3]:
class CLIPEmbedding(BaseEmbedding):
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()
    _device = PrivateAttr()

    def __init__(self, model_name: str = "ViT-H-14-quickgelu", device: str = "cpu"):
        super().__init__()
        self._device = device
        self._model, _, self._preprocess = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained="dfn5b",
            device=self._device
        )
        self._tokenizer = open_clip.get_tokenizer(model_name)
        self._model = self._model.to(self._device).eval()

    def _encode_text(self, text: str) -> List[float]:
        tokens = self._tokenizer([text]).to(self._device)
        with torch.no_grad():
            emb = self._model.encode_text(tokens)
            emb = emb / emb.norm(dim=-1, keepdim=True)
        return emb[0].cpu().numpy().tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        return self._encode_text(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._encode_text(text)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _encode_image(self, image: Image.Image) -> List[float]:
        image_tensor = self._preprocess(image).unsqueeze(0).to(self._device)
        with torch.no_grad():
            emb = self._model.encode_image(image_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
        return emb[0].cpu().numpy().tolist()

    def _get_image_embedding(self, image: Image.Image) -> List[float]:
            return self._encode_image(image)

    async def _aget_image_embedding(self, image: Image.Image) -> List[float]:
        return self._get_image_embedding(image)

class CaptionEmbedding(BaseEmbedding):
    _model: SentenceTransformer = PrivateAttr()

    def __init__(self, model_name: str = "BAAI/bge-small-en", device: str = "cpu", trust_remote_code: bool = False):
        super().__init__()
        print(f"Loading model: {model_name}")
        self._model = SentenceTransformer(model_name, device=device,
                                          trust_remote_code=trust_remote_code)
        self._model = self._model.eval()

    def _get_query_embedding(self, query: str) -> List[float]:
        return self._model.encode(query, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False).tolist()

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False).tolist()

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

class Translator:
    def __init__(self, model_name: str = "VietAI/envit5-translation", device: str = 'cpu'):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

    def _clean_prefix(self, text: str) -> str:
        return re.sub(r"^(en|vi)\s*:\s*", "", text.strip(), flags=re.IGNORECASE)

    def translate(self, text: str, source_lang: str = "vi", max_length: int = 128) -> str:
        content = f"{source_lang}: {text}"
        inputs = self.tokenizer(
            content,
            return_tensors="pt",
            truncation=True,
            max_length=max_length).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=max_length)
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return self._clean_prefix(decoded)

In [15]:
# Configuration with Kaggle secrets support
import os

# Use Kaggle secrets or fallback to environment variables
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    QDRANT_URL = user_secrets.get_secret("QDRANT_URL")
    QDRANT_API_KEY = user_secrets.get_secret("QDRANT_API_KEY")
    NGROK_AUTH_TOKEN = user_secrets.get_secret("NGROK_AUTH_TOKEN")
    print("✅ Using Kaggle secrets for configuration")
except Exception as e:
    print(f"⚠️  Kaggle secrets not available: {e}")
    print("🔧 Falling back to hardcoded values (update these with your credentials)")
    # Fallback to hardcoded values - UPDATE THESE WITH YOUR ACTUAL CREDENTIALS
    QDRANT_URL = "https://09a6d049-00c4-4b77-8e95-1dcc9ea5df34.eu-west-1-0.aws.cloud.qdrant.io:6333"
    QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.-ZPZib9FxehqbTuqxsk7QdVjBQd0LlQEq7dpjF1b4PI"
    NGROK_AUTH_TOKEN = "32BGHcjWCrbF91sroT87POFHH0p_82FYWpSCf33Lf4rhkPikL"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Using device: {DEVICE}")

# Initialize Qdrant client
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)

CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# Collection names
CLIP_collection = "Image"
BGE_collection = "BGE_Caption"
GTE_collection = "GTE_Caption"

print(f"📊 Collections: {CLIP_collection}, {BGE_collection}, {GTE_collection}")

⚠️  Kaggle secrets not available: No module named 'kaggle_secrets'
🔧 Falling back to hardcoded values (update these with your credentials)
🚀 Using device: cpu
📊 Collections: Image, BGE_Caption, GTE_Caption


In [5]:
# Initialize models
print("🔧 Initializing models...")

translator = Translator(device=DEVICE)
print("✅ Translator loaded")

clip_embed_model = CLIPEmbedding(device=DEVICE)
clip_vector_store = QdrantVectorStore(client=qdrant_client,
                                      collection_name=CLIP_collection)
clip_index = VectorStoreIndex.from_vector_store(vector_store=clip_vector_store,
                                                embed_model=clip_embed_model)
print("✅ CLIP model and index loaded")

bge_embed_model = CaptionEmbedding(model_name="AITeamVN/Vietnamese_Embedding_v2", device=DEVICE)
bge_vector_store = QdrantVectorStore(client=qdrant_client,
                                     collection_name=BGE_collection)
bge_index = VectorStoreIndex.from_vector_store(vector_store=bge_vector_store,
                                               embed_model=bge_embed_model)
print("✅ BGE Vietnamese model loaded")

gte_embed_model = CaptionEmbedding(model_name="dangvantuan/vietnamese-document-embedding",
                                   device=DEVICE, trust_remote_code=True)
gte_vector_store = QdrantVectorStore(client=qdrant_client,
                                     collection_name=GTE_collection)
gte_index = VectorStoreIndex.from_vector_store(vector_store=gte_vector_store,
                                               embed_model=gte_embed_model)
print("✅ GTE Document model loaded")

print("🎉 All models initialized successfully!")

🔧 Initializing models...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/1.10M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/721 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

✅ Translator loaded


open_clip_pytorch_model.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]

✅ CLIP model and index loaded
Loading model: AITeamVN/Vietnamese_Embedding_v2


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/171 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/664 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

✅ BGE Vietnamese model loaded
Loading model: dangvantuan/vietnamese-document-embedding


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/171 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

configuration.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dangvantuan/Vietnamese_impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dangvantuan/Vietnamese_impl:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ GTE Document model loaded
🎉 All models initialized successfully!


In [6]:
# Build frame mappings for temporal search
print("📊 Building frame mappings for temporal search...")

FRAME_NAMES = []
offset = None
batch_count = 0

while True:
    result, offset = qdrant_client.scroll(
        collection_name=CLIP_collection,
        scroll_filter=None,
        with_payload=True,
        limit=1000,
        offset=offset
    )

    batch_count += 1
    print(f"  Batch {batch_count}: {len(result)} frames")

    for point in result:
        if "id" in point.payload:
            FRAME_NAMES.append(point.payload["id"])

    if offset is None:
        break

FRAME_NAMES = sorted(set(FRAME_NAMES))

VIDEO_TO_FRAMES = defaultdict(list)
for f in FRAME_NAMES:
    vid = "_".join(f.split("_")[:2])
    VIDEO_TO_FRAMES[vid].append(f)

print(f"✅ Loaded {len(FRAME_NAMES)} frame names from {len(VIDEO_TO_FRAMES)} videos")
print(f"📹 Sample videos: {list(VIDEO_TO_FRAMES.keys())[:5]}")

📊 Building frame mappings for temporal search...
  Batch 1: 1000 frames
  Batch 2: 1000 frames
  Batch 3: 1000 frames
  Batch 4: 1000 frames
  Batch 5: 1000 frames
  Batch 6: 1000 frames
  Batch 7: 1000 frames
  Batch 8: 1000 frames
  Batch 9: 1000 frames
  Batch 10: 1000 frames
  Batch 11: 1000 frames
  Batch 12: 1000 frames
  Batch 13: 1000 frames
  Batch 14: 1000 frames
  Batch 15: 1000 frames
  Batch 16: 1000 frames
  Batch 17: 1000 frames
  Batch 18: 1000 frames
  Batch 19: 1000 frames
  Batch 20: 1000 frames
  Batch 21: 1000 frames
  Batch 22: 1000 frames
  Batch 23: 1000 frames
  Batch 24: 1000 frames
  Batch 25: 1000 frames
  Batch 26: 1000 frames
  Batch 27: 1000 frames
  Batch 28: 1000 frames
  Batch 29: 1000 frames
  Batch 30: 1000 frames
  Batch 31: 1000 frames
  Batch 32: 1000 frames
  Batch 33: 1000 frames
  Batch 34: 1000 frames
  Batch 35: 1000 frames
  Batch 36: 1000 frames
  Batch 37: 1000 frames
  Batch 38: 1000 frames
  Batch 39: 1000 frames
  Batch 40: 1000 frames


In [7]:
def retrieve(query: str, topK: int, frame_ids: Optional[List] = None,
             mode: str = "clip", caption_mode: str = "bge"):
    """
    Universal retrieve function with frame filtering support
    """
    if mode == "clip":
        embed_model = clip_embed_model
        index = clip_index
        collection_name = CLIP_collection
        query_text = translator.translate(query, source_lang="vi")
    elif mode == "vintern":
        if caption_mode == "bge":
            embed_model = bge_embed_model
            index = bge_index
            collection_name = BGE_collection
            query_text = query
        else:  # gte
            embed_model = gte_embed_model
            index = gte_index
            collection_name = GTE_collection
            query_text = query

    if frame_ids:
        # Direct query with frame filtering
        vector_query = embed_model._get_text_embedding(query_text)
        nodes = qdrant_client.query_points(
            collection_name=collection_name,
            query=vector_query,
            limit=topK,
            with_payload=True,
            query_filter=models.Filter(must=[
                models.FieldCondition(
                    key="id",
                    match=models.MatchAny(any=frame_ids)
                )
            ])
        ).points
        results = [
            {"id": node.payload["id"].strip(), "score": node.score}
            for node in nodes
        ]
    else:
        # Use index for full search
        retriever = index.as_retriever(similarity_top_k=topK)
        nodes = retriever.retrieve(query_text)
        results = [
            {"id": node.metadata.get("id", "").strip(), "score": node.score}
            for node in nodes
        ]
    return results

In [8]:
def retrieve_frame(query: str, topK: int, mode: str = "hybrid", caption_mode: str = "bge",
                   alpha: float = 0.5, frame_ids: Optional[List] = None):

    if mode == "clip":
        clip_nodes = retrieve(query, topK, frame_ids, "clip")
        results = [
            {"image": node["id"], "caption": f"{node['id']} | Score: {node['score']:.2f}"}
            for node in clip_nodes
        ]
        return results

    elif mode == "vintern":
        caption_nodes = retrieve(query, topK, frame_ids, "vintern", caption_mode)
        results = [
            {"image": node["id"], "caption": f"{node['id']} | Score: {node['score']:.2f}"}
            for node in caption_nodes
        ]
        return results

    else:  # hybrid mode
        clip_nodes = retrieve(query, topK, frame_ids, "clip")
        caption_nodes = retrieve(query, topK, frame_ids, "vintern", caption_mode)

        combined_scores = defaultdict(float)
        for node in caption_nodes:
            combined_scores[node["id"]] += node["score"] * alpha

        for node in clip_nodes:
            combined_scores[node["id"]] += node["score"] * (1 - alpha)

        top_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:topK]

        return [
            {"image": video_id.strip(), "caption": f"{video_id} | Score: {score:.2f}"}
            for video_id, score in top_results
        ]

In [9]:
def retrieve_from_image(contents: bytes, topK: int):
    """
    Image-based search using CLIP embeddings
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    vector_query = clip_embed_model._get_image_embedding(image)

    clip_nodes = qdrant_client.query_points(
        collection_name=CLIP_collection,
        query=vector_query,
        limit=topK,
        with_payload=True
    ).points

    results = [
        {
            "image": node.payload.get("id", "").strip(),
            "caption": f"{node.payload.get('id', '')} | Score: {node.score:.2f}"
        }
        for node in clip_nodes
    ]
    return results

In [10]:
def get_vid(image_name: str) -> str:
    """Extract video ID from frame name"""
    parts = image_name.split("_", 2)
    return f"{parts[0]}_{parts[1]}"

def temporal_search(events: List[str], topK: int = 100,
                    mode: str = "hybrid", caption_mode: str = "bge",
                    alpha: float = 0.5, search_mode: str = "progressive"):

    print(f"🔍 Starting {search_mode} temporal search with {len(events)} events...")

    if search_mode == "progressive":
        # Original progressive filtering approach
        frame_ids = None
        final_results = []

        for i, event in enumerate(events):
            print(f"  Event {i+1}: {event[:50]}...")

            results = retrieve_frame(query=event, topK=topK, mode=mode,
                                     caption_mode=caption_mode, alpha=alpha, frame_ids=frame_ids)
            final_results.append(results)

            # Extract video IDs to narrow search space
            video_ids = {"_".join(item['image'].split("_")[:2]) for item in results}
            print(f"    → Found {len(results)} results from {len(video_ids)} videos")

            # Update frame_ids for next iteration
            frame_ids = [f for vid in video_ids for f in VIDEO_TO_FRAMES[vid]]
            print(f"    → Narrowed search space to {len(frame_ids)} frames")

        print(f"✅ Progressive temporal search completed!")
        return final_results

    else:  # consolidated mode
        frame_ids = None
        final_results = []
        score_pattern = re.compile(r"Score:\s*([\d.]+)")

        for event in events:
            results = retrieve_frame(query=event, topK=topK, mode=mode,
                                     caption_mode=caption_mode, alpha=alpha, frame_ids=frame_ids)
            final_results.append(results)

            video_ids = {get_vid(item['image']) for item in results}

            frame_ids = [f for vid in video_ids for f in VIDEO_TO_FRAMES[vid]]

        video_sets = [{get_vid(item['image']) for item in results}
                       for results in final_results]
        common_videos = set.intersection(*video_sets) if video_sets else set()

        video_event_frames = {vid: {} for vid in common_videos}
        for event_idx, event_results in enumerate(final_results):
            event_name = events[event_idx]
            for item in event_results:
                vid = get_vid(item['image'])
                if vid not in common_videos:
                    continue

                match = score_pattern.search(item["caption"])
                score = float(match.group(1)) if match else 0.0

                current = video_event_frames[vid].get(event_name)
                if current is None or score > current["score"]:
                    video_event_frames[vid][event_name] = {
                        "frame": item["image"],
                        "caption": item["caption"],
                        "score": score
                    }

        results_list = []
        for vid, events_dict in video_event_frames.items():
            frames   = [e["frame"] for e in events_dict.values()]
            captions = [e["caption"] for e in events_dict.values()]
            total_score = sum(e["score"] for e in events_dict.values())

            results_list.append({
                "image": frames,
                "caption": captions,
                "score": total_score
            })

        results_list.sort(key=lambda x: x["score"], reverse=True)

        print(f"✅ Consolidated temporal search completed! Returning {len(results_list)} results")
        return results_list

In [11]:
def visualize_video_results(results, root_dir="/kaggle/input/aic-batch-1/keyframes/keyframes", top_k=10):
    """
    Enhanced visualization function for TRAKE mode temporal search results
    Displays video timelines with frames from each event in sequence

    Args:
        results: Results from temporal_search with search_mode="consolidated"
        root_dir: Root directory containing keyframe images
        top_k: Number of top video timelines to display
    """
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    from matplotlib.patches import Rectangle
    import numpy as np
    import os

    if not results:
        print("❌ No results to visualize")
        return

    # Handle both progressive and consolidated result formats
    if isinstance(results[0], dict) and 'video_id' in results[0]:
        # Consolidated format - video timelines
        video_results = results[:top_k]
        print(f"🎬 Visualizing {len(video_results)} video timelines for TRAKE mode")

        fig, axes = plt.subplots(len(video_results), len(video_results[0]['image']),
                                figsize=(4 * len(video_results[0]['image']), 3 * len(video_results)))
        fig.suptitle(f'TRAKE Mode: Top {len(video_results)} Video Event Sequences', fontsize=16, fontweight='bold')

        for video_idx, video_info in enumerate(video_results):
            frames = video_info["image"]
            video_id = video_info["video_id"]
            total_score = video_info["score"]

            for event_idx, frame_name in enumerate(frames):
                if len(video_results) == 1:
                    ax = axes[event_idx] if len(frames) > 1 else axes
                else:
                    ax = axes[video_idx, event_idx] if len(frames) > 1 else axes[video_idx]

                # Construct image path
                img_path = os.path.join(root_dir, frame_name + '.jpg')

                try:
                    if os.path.exists(img_path):
                        img = mpimg.imread(img_path)
                        ax.imshow(img)
                    else:
                        # Create placeholder if image not found
                        ax.text(0.5, 0.5, f'Image not found:\\n{frame_name}',
                               ha='center', va='center', transform=ax.transAxes,
                               bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

                except Exception as e:
                    ax.text(0.5, 0.5, f'Error loading:\\n{frame_name}',
                           ha='center', va='center', transform=ax.transAxes,
                           bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))

                # Add frame info and event marker
                ax.set_title(f'Event {event_idx + 1}\\n{frame_name}', fontsize=10)
                ax.axis('off')

                # Add colored border for event sequence
                colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']
                border_color = colors[event_idx % len(colors)]
                rect = Rectangle((0, 0), 1, 1, transform=ax.transAxes,
                               linewidth=3, edgecolor=border_color, facecolor='none')
                ax.add_patch(rect)

            # Add video info as row label
            if len(video_results) > 1:
                axes[video_idx, 0].text(-0.1, 0.5,
                                       f'{video_id}\\nScore: {total_score:.3f}',
                                       ha='right', va='center', transform=axes[video_idx, 0].transAxes,
                                       fontsize=12, fontweight='bold',
                                       bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"))

        plt.tight_layout()
        plt.show()



    else:
        # Progressive format - convert to simple grid visualization
        print(f"📋 Visualizing progressive search results ({len(results)} events)")

        # Take top frames from final event for visualization
        final_event_results = results[-1][:top_k] if results else []

        if not final_event_results:
            print("❌ No final results to visualize")
            return

        cols = min(5, len(final_event_results))
        rows = (len(final_event_results) + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
        fig.suptitle(f'Progressive Search: Final Results ({len(results)} events)', fontsize=14)

        if rows == 1 and cols == 1:
            axes = [axes]
        elif rows == 1:
            axes = [axes]
        else:
            axes = axes.flatten() if rows > 1 else axes

        for idx, result in enumerate(final_event_results):
            if idx >= len(axes):
                break

            ax = axes[idx]
            frame_name = result['image']
            img_path = os.path.join(root_dir, frame_name + '.jpg')

            try:
                if os.path.exists(img_path):
                    img = mpimg.imread(img_path)
                    ax.imshow(img)
                else:
                    ax.text(0.5, 0.5, f'Image not found:\\n{frame_name}',
                           ha='center', va='center', transform=ax.transAxes)
            except Exception as e:
                ax.text(0.5, 0.5, f'Error loading:\\n{frame_name}',
                       ha='center', va='center', transform=ax.transAxes)

            ax.set_title(f'{frame_name}\\n{result["caption"].split("|")[1]}', fontsize=9)
            ax.axis('off')

        # Hide unused subplots
        for idx in range(len(final_event_results), len(axes)):
            axes[idx].set_visible(False)

        plt.tight_layout()
        plt.show()

print("✅ Video visualization function")

✅ Video visualization function


In [12]:
# FastAPI Application
app = FastAPI(title="Video Event Retrieval API v2.0",
              description="Enhanced multimodal search with temporal capabilities")

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_SETTINGS["allow_origins"],
    allow_credentials=CORS_SETTINGS["allow_credentials"],
    allow_methods=CORS_SETTINGS["allow_methods"],
    allow_headers=CORS_SETTINGS["allow_headers"],
)

router = APIRouter()

@router.post("/search")
async def api_search(
    query: Optional[str] = Form(None),
    topK: int = Form(...),
    mode: str = Form("hybrid"),
    caption_mode: str = Form("bge"),
    alpha: float = Form(0.5),
    file: UploadFile = File(None)
):
    """
    Enhanced search API with caption mode support
    - mode: hybrid, clip, vintern, image
    - caption_mode: bge, gte (for vintern and hybrid modes)
    - alpha: text/visual balance for hybrid mode (0.1-0.9)
    """
    start_time = time.time()

    try:
        if mode == "image":
            if file is None:
                return {"error": "No file uploaded for image mode"}
            contents = await file.read()
            results = retrieve_from_image(contents=contents, topK=topK)
            search_info = f"IMAGE search"
        else:
            if query is None or query.strip() == "":
                return {"error": "No query provided for text mode"}
            results = retrieve_frame(query=query, topK=topK, mode=mode,
                                    caption_mode=caption_mode, alpha=alpha)
            search_info = f"{mode.upper()} mode with {caption_mode.upper()} model"

        duration = time.time() - start_time

        return {
            "results": results,
            "search_info": {
                "mode": mode,
                "caption_mode": caption_mode if mode in ["hybrid", "vintern"] else None,
                "alpha": alpha if mode == "hybrid" else None,
                "duration": round(duration, 3),
                "count": len(results),
                "description": search_info
            }
        }
    except Exception as e:
        return {"error": f"Search failed: {str(e)}"}

@router.post("/temporal_search")
async def api_temporal_search(
    events: str = Form(...),  # JSON string of event list
    topK: int = Form(100),
    mode: str = Form("hybrid"),
    caption_mode: str = Form("bge"),
    alpha: float = Form(0.5),
    search_mode: str = Form("progressive")  # New parameter: "progressive" or "consolidated"
):
    """
    Enhanced Temporal search API for TRAKE mode
    - events: JSON array of sequential event descriptions
    - search_mode: "progressive" (frontend compatible) or "consolidated" (TRAKE visualization)
    - Returns different result formats based on search_mode
    """
    start_time = time.time()

    try:
        events_list = json.loads(events)
        if not isinstance(events_list, list) or len(events_list) == 0:
            return {"error": "Events must be a non-empty list"}

        # Filter out empty events
        valid_events = [e.strip() for e in events_list if e.strip()]
        if len(valid_events) == 0:
            return {"error": "No valid events provided"}

        results = temporal_search(events=valid_events, topK=topK, mode=mode,
                                 caption_mode=caption_mode, alpha=alpha,
                                 search_mode=search_mode)

        duration = time.time() - start_time

        # Different response format based on search mode
        if search_mode == "consolidated":
            final_count = len(results)
            result_type = "video_timelines"
        else:  # progressive
            final_count = len(results[-1]) if results else 0
            result_type = "progressive_events"

        return {
            "results": results,
            "search_info": {
                "mode": mode,
                "caption_mode": caption_mode,
                "alpha": alpha if mode == "hybrid" else None,
                "search_mode": search_mode,
                "duration": round(duration, 3),
                "events_processed": len(valid_events),
                "final_count": final_count,
                "result_type": result_type,
                "description": f"{search_mode.title()} temporal search through {len(valid_events)} events"
            }
        }
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format for events"}
    except Exception as e:
        return {"error": f"Temporal search failed: {str(e)}"}

@router.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "device": DEVICE,
        "models_loaded": {
            "clip": clip_embed_model is not None,
            "bge": bge_embed_model is not None,
            "gte": gte_embed_model is not None,
            "translator": translator is not None
        },
        "collections": {
            "clip": CLIP_collection,
            "bge": BGE_collection,
            "gte": GTE_collection
        },
        "frame_count": len(FRAME_NAMES),
        "video_count": len(VIDEO_TO_FRAMES),
        "supported_search_modes": ["progressive", "consolidated"]
    }

app.include_router(router)
print("✅ FastAPI application configured with consolidated mode support")

✅ FastAPI application configured with consolidated mode support


In [17]:
# Server setup with ngrok
import os, time, threading, socket
from pyngrok import ngrok
import uvicorn

PORT = 8000
HOST = "0.0.0.0"

# Set ngrok auth token
if NGROK_AUTH_TOKEN and NGROK_AUTH_TOKEN != "YOUR_NGROK_TOKEN_HERE":
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
    print("✅ Ngrok auth token set")
else:
    print("⚠️  NGROK_AUTH_TOKEN not configured. Please update with your token.")
    print("   Get your token from: https://dashboard.ngrok.com/get-started/your-authtoken")

def is_port_in_use(port: int, host="127.0.0.1") -> bool:
    """Check if a local TCP port is already in use."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex((host, port)) == 0

def run_server():
    """Run FastAPI server in background thread"""
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")

# Start server only if not already running
if not is_port_in_use(PORT):
    print(f"🚀 Starting FastAPI server on {HOST}:{PORT}")
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    time.sleep(3)  # Wait for server startup
    print("✅ Server started successfully")
else:
    print(f"🔁 Server already running on http://localhost:{PORT}")

# Setup ngrok tunnel
try:
    # Clean up existing tunnels
    for t in ngrok.get_tunnels():
        addr = (t.config or {}).get("addr", "")
        if str(PORT) in addr:
            try:
                ngrok.disconnect(t.public_url)
            except Exception:
                pass

    # Kill all tunnels if too many
    if len(ngrok.get_tunnels()) >= 3:
        ngrok.kill()

    # Create new tunnel
    tunnel = ngrok.connect(addr=PORT, proto="http", bind_tls=True)
    PUBLIC_URL = tunnel.public_url

    print("\n" + "="*60)
    print("🌐 BACKEND READY!")
    print(f"📡 Public URL: {PUBLIC_URL}")
    print(f"📖 API Docs: {PUBLIC_URL}/docs")
    print(f"🏥 Health Check: {PUBLIC_URL}/health")
    print(f"💻 Local URL: http://localhost:{PORT}")
    print("\n🎯 COPY THE PUBLIC URL TO YOUR FRONTEND!")
    print("="*60)

    # Save to global for later use
    globals()["PUBLIC_URL"] = PUBLIC_URL

except Exception as e:
    print(f"❌ Ngrok tunnel failed: {e}")
    print(f"🔧 Server still available locally: http://localhost:{PORT}")
    print("💡 Try restarting the kernel or checking your ngrok auth token")

✅ Ngrok auth token set
🔁 Server already running on http://localhost:8000

🌐 BACKEND READY!
📡 Public URL: https://b3825d97ce05.ngrok-free.app
📖 API Docs: https://b3825d97ce05.ngrok-free.app/docs
🏥 Health Check: https://b3825d97ce05.ngrok-free.app/health
💻 Local URL: http://localhost:8000

🎯 COPY THE PUBLIC URL TO YOUR FRONTEND!


In [14]:
# Test the enhanced API
import requests

def test_api_endpoints(use_public_url=True):
    """Test all API endpoints with the new features"""

    if use_public_url and 'PUBLIC_URL' in globals():
        base_url = globals()['PUBLIC_URL']
        print(f"🔗 Testing public URL: {base_url}")
    else:
        base_url = "http://localhost:8000"
        print(f"🔗 Testing local URL: {base_url}")

    # Test health endpoint
    print("\n1. 🏥 Testing health endpoint...")
    try:
        response = requests.get(f"{base_url}/health", timeout=10)
        if response.status_code == 200:
            health = response.json()
            print(f"   ✅ Health check passed")
            print(f"   📊 {health['frame_count']} frames from {health['video_count']} videos")
            print(f"   🤖 Models: {health['models_loaded']}")
        else:
            print(f"   ❌ Health check failed: {response.status_code}")
    except Exception as e:
        print(f"   ❌ Health check error: {e}")
        return

    # Test search with BGE
    print("\n2. 🔍 Testing hybrid search with BGE...")
    test_search(base_url, "người đang nấu ăn", mode="hybrid", caption_mode="bge")

    # Test search with GTE
    print("\n3. 🔍 Testing hybrid search with GTE...")
    test_search(base_url, "người đang nấu ăn", mode="hybrid", caption_mode="gte")

    # Test vintern only with GTE
    print("\n4. 📝 Testing vintern search with GTE...")
    test_search(base_url, "cảnh đẹp thiên nhiên", mode="vintern", caption_mode="gte")

    # Test temporal search
    print("\n5. ⏰ Testing temporal search...")
    test_temporal_search(base_url, [
        "Một người  đang cắt đôi ổ bánh mì có rắc mè rồi đem nướng trên chảo. Hãy lấy khoảnh khắc chiếc dao cắt qua hoàn toàn chiếc bánh.",
        "Sau đó người này rắc bột lên những miếng thịt, trong quá trình này người đầu bếp lật những miếng thịt để rắc bột đều hai mặt. Hãy lấy khoảnh khắc đầu tiên người đầu bếp này buông tay khỏi miếng thịt sau khi lật miếng thịt đầu tiên.",
        "Các miếng thịt sau đó được đem đi áp chảo cùng với bơ (3 ngang 1 dọc theo chiều của camera). Hãy lấy khoảnh khắc đầu tiên người đầu bếp cầm vào chảo để nhấc lên đảo bơ đều xung quanh"
    ])

def test_search(base_url, query, mode="hybrid", caption_mode="bge", topK=5):
    """Test search endpoint"""
    try:
        data = {
            "query": query,
            "topK": topK,
            "mode": mode,
            "caption_mode": caption_mode,
            "alpha": 0.6
        }

        response = requests.post(f"{base_url}/search", data=data, timeout=30)

        if response.status_code == 200:
            result = response.json()
            search_info = result.get("search_info", {})
            print(f"   ✅ Search successful: {search_info.get('description')}")
            print(f"   ⏱️ Duration: {search_info.get('duration')}s")
            print(f"   📊 Results: {len(result['results'])}")

            # Show top results
            for i, res in enumerate(result['results'][:3]):
                print(f"      {i+1}. {res['caption']}")
        else:
            print(f"   ❌ Search failed: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"   ❌ Search error: {e}")

def test_temporal_search(base_url, events, topK=20):
    """Test temporal search endpoint"""
    try:
        data = {
            "events": json.dumps(events),
            "topK": topK,
            "mode": "hybrid",
            "caption_mode": "gte",
            "alpha": 0.7
        }

        response = requests.post(f"{base_url}/temporal_search", data=data, timeout=60)

        if response.status_code == 200:
            result = response.json()
            search_info = result.get("search_info", {})
            print(f"   ✅ Temporal search successful: {search_info.get('description')}")
            print(f"   ⏱️ Duration: {search_info.get('duration')}s")
            print(f"   📊 Events processed: {search_info.get('events_processed')}")
            print(f"   🎯 Final results: {search_info.get('final_count')}")

            # Show progression
            for i, event_results in enumerate(result['results']):
                print(f"      Event {i+1}: {len(event_results)} results")
                for j, res in enumerate(event_results[:2]):
                    print(f"        → {res['caption']}")
        else:
            print(f"   ❌ Temporal search failed: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"   ❌ Temporal search error: {e}")

# Run tests
print("🧪 Testing Enhanced API...")
test_api_endpoints()

🧪 Testing Enhanced API...
🔗 Testing local URL: http://localhost:8000

1. 🏥 Testing health endpoint...
INFO:     127.0.0.1:35126 - "GET /health HTTP/1.1" 200 OK
   ✅ Health check passed
   📊 96545 frames from 866 videos
   🤖 Models: {'clip': True, 'bge': True, 'gte': True, 'translator': True}

2. 🔍 Testing hybrid search with BGE...
INFO:     127.0.0.1:35140 - "POST /search HTTP/1.1" 200 OK
   ✅ Search successful: HYBRID mode with BGE model
   ⏱️ Duration: 8.53s
   📊 Results: 5
      1. L26_V272_5537 | Score: 0.42
      2. L26_V452_4878 | Score: 0.42
      3. L26_V339_5215 | Score: 0.41

3. 🔍 Testing hybrid search with GTE...
INFO:     127.0.0.1:35148 - "POST /search HTTP/1.1" 200 OK
   ✅ Search successful: HYBRID mode with GTE model
   ⏱️ Duration: 2.003s
   📊 Results: 5
      1. L27_V003_6294 | Score: 0.38
      2. L26_V367_5958 | Score: 0.38
      3. L26_V484_3594 | Score: 0.37

4. 📝 Testing vintern search with GTE...
INFO:     127.0.0.1:44490 - "POST /search HTTP/1.1" 200 OK
   ✅ Sea