In [1]:
!git clone https://github.com/ludoplayer69/videotree

Cloning into 'videotree'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects: 100% (12/12), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 12 (delta 1), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (12/12), 1.35 MiB | 5.15 MiB/s, done.
Resolving deltas: 100% (1/1), done.


In [2]:
!pip install -U bitsandbytes accelerate transformers

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting accelerate
  Downloading accelerate-1.10.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers
  Downloading transformers-4.55.0-py3-none-any.whl.metadata (39 kB)
Collecting huggingface_hub>=0.21.0 (from accelerate)
  Downloading huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<

In [5]:
from pathlib import Path
from tqdm import tqdm
import json
import gdown  # pip install gdown

DRIVE_JSON_PATH = Path('/kaggle/working/videotree/drive_ids.json')  # Update with your actual path
VIDEO_SAVE_PATH = Path('/kaggle/working/downloaded_videos')
MAX_VIDEOS = 2  # Set your preferred limit

VIDEO_SAVE_PATH.mkdir(parents=True, exist_ok=True)

def load_json(json_path):
    with open(json_path, 'r') as f:
        return json.load(f)

def download_videos(drive_json, max_downloads):
    count = 0
    pbar = tqdm(total=min(len(drive_json), max_downloads), desc="Downloading Videos")

    for uuid, drive_id in drive_json.items():
        if count >= max_downloads:
            break

        output_file = VIDEO_SAVE_PATH / f"{uuid}.mp4"
        gdown.download(id=drive_id, output=str(output_file), quiet=False)
        count += 1
        pbar.update(1)

    pbar.close()

if __name__ == "__main__":
    drive_data = load_json(DRIVE_JSON_PATH)
    download_videos(drive_data, MAX_VIDEOS)


Downloading Videos:   0%|          | 0/2 [00:00<?, ?it/s]Downloading...
From: https://drive.google.com/uc?id=1ZdZ8aUcBNzndj135bqFrxb9L816EMGp1
To: /kaggle/working/downloaded_videos/0074f737-11cb-497d-8d07-77c3a8127391.mp4

100%|██████████| 15.5M/15.5M [00:00<00:00, 221MB/s]
Downloading Videos:  50%|█████     | 1/2 [00:01<00:01,  1.74s/it]Downloading...
From: https://drive.google.com/uc?id=1bVNvPX6BNPIqcqMk-ZJr6WLsXNQybg64
To: /kaggle/working/downloaded_videos/00b9a0de-c59e-49cb-a127-6081e2fb8c8e.mp4

100%|██████████| 12.2M/12.2M [00:00<00:00, 208MB/s]
Downloading Videos: 100%|██████████| 2/2 [00:03<00:00,  1.69s/it]


# Feature Extraction

In [6]:
# ---- Globals ----
import os, sys, json, cv2, torch
from pathlib import Path
from PIL import Image
from tqdm import tqdm

# Paths
INPUT_VIDEOS = Path('/kaggle/working/downloaded_videos')
FRAMES_DIR   = Path('/kaggle/working/extracted_frames')
FEATURES_DIR = Path('/kaggle/working/extracted_features')
ANNOTATION_PATH = Path('/kaggle/input/fullset_anno.json')

FRAMES_DIR.mkdir(parents=True, exist_ok=True)
FEATURES_DIR.mkdir(parents=True, exist_ok=True)

# Params
FPS = 1
MAX_EXAMPLES = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global model state
perceptionclip_model = None
perceptionclip_preprocess = None

# ---- Utilities ----
def load_json(file_path: Path):
    with open(file_path, 'r') as f:
        return json.load(f)

def save_image_features(img_feats: torch.Tensor, name_id: str, save_folder: Path):
    torch.save(img_feats, save_folder / f"{name_id}.pt")

def _numeric_sort_key(p: Path):
    try:
        return int(p.stem)
    except ValueError:
        return p.stem

# ---- Frame extraction ----
def extract_frames(videos: list[Path] | None = None, fps: int = FPS):
    video_iter = videos if videos is not None else list(INPUT_VIDEOS.iterdir())
    for video_fp in tqdm(video_iter, desc="Extracting frames"):
        out_dir = FRAMES_DIR / video_fp.stem
        out_dir.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_fp))
        if not cap.isOpened():
            print(f"[WARN] Could not open video: {video_fp}")
            continue

        fps_ori = cap.get(cv2.CAP_PROP_FPS)
        try:
            fps_ori = int(fps_ori) if fps_ori and fps_ori > 0 else 1
        except Exception:
            fps_ori = 1
        frame_interval = max(1, int(fps_ori // max(1, fps)))

        count = 0
        success, img = cap.read()
        while success:
            if count % frame_interval == 0:
                cv2.imwrite(str(out_dir / f"{count}.jpg"), img)
            success, img = cap.read()
            count += 1
        cap.release()

# ---- Model loading (singleton) ----
def ensure_model_loaded(model_name: str = 'PE-Core-B16-224', force_reload: bool = False):
    global perceptionclip_model, perceptionclip_preprocess

    if perceptionclip_model is not None and perceptionclip_preprocess is not None and not force_reload:
        return perceptionclip_model, perceptionclip_preprocess

    # One-time setup
    if not Path('perception_models').exists():
        os.system('git clone https://github.com/facebookresearch/perception_models.git')
        os.system('pip install -q decord ftfy')

    sys.path.append('./perception_models')
    cwd = os.getcwd()
    os.chdir('./perception_models')

    import core.vision_encoder.pe as pe
    import core.vision_encoder.transforms as transforms

    model = pe.CLIP.from_config(model_name, pretrained=True).to(DEVICE).eval()
    preprocess = transforms.get_image_transform(model.image_size)

    os.chdir(cwd)

    perceptionclip_model = model
    perceptionclip_preprocess = preprocess
    return perceptionclip_model, perceptionclip_preprocess

# ---- Feature extraction ----
@torch.inference_mode()
def extract_features_for_dir(example_dir: Path):
    if not example_dir.exists():
        print(f"[WARN] Frames dir not found: {example_dir}")
        return

    model, preprocess = ensure_model_loaded()

    image_files = sorted(list(example_dir.iterdir()), key=_numeric_sort_key)
    feats_list = []

    # Use autocast only when CUDA is available
    # use_cuda = torch.cuda.is_available()
    # autocast_ctx = torch.cuda.amp.autocast if use_cuda else torch.autocast
    # autocast_kwargs = ({'dtype': torch.float16} if not use_cuda else {})

    # with (autocast_ctx(device_type='cuda') if not autocast_kwargs else autocast_ctx(**autocast_kwargs)):
    #     for img_fp in image_files:
    #         img = Image.open(img_fp).convert('RGB')
    #         inp = perceptionclip_preprocess(img).unsqueeze(0).to(DEVICE)
    #         feat = perceptionclip_model.encode_image(inp)
    #         feats_list.append(feat)

    # refactored
    for image_file in image_files:

        inputs = preprocess(Image.open(image_file)).unsqueeze(0).to(DEVICE)
    
    
        with torch.no_grad(), torch.cuda.amp.autocast():
            feats = model.encode_image(inputs)
            feats_list.append(feats)
    
        if len(feats_list) == 0:
            print(f"[WARN] No frames in {example_dir}")
            return

    stacked = torch.cat(feats_list, dim=0)  # [T, D]
    save_image_features(stacked, example_dir.name, FEATURES_DIR)

def extract_features_from_all(MAX: int = MAX_EXAMPLES, filter_by_json: Path | None = ANNOTATION_PATH):
    valid_names = None
    if filter_by_json and Path(filter_by_json).exists():
        valid_names = set(load_json(filter_by_json).keys())

    dirs = [d for d in FRAMES_DIR.iterdir() if d.is_dir()]
    processed = 0
    for d in tqdm(dirs, desc="Extracting features"):
        if processed >= MAX:
            break
        if valid_names is not None and d.name not in valid_names:
            continue
        extract_features_for_dir(d)
        processed += 1

# ---- Single-video convenience ----
def infer_single_video(video_fp: Path):
    # 1) Extract frames for just this video
    extract_frames(videos=[video_fp], fps=FPS)
    # 2) Extract features for the corresponding frames directory
    example_dir = FRAMES_DIR / video_fp.stem
    extract_features_for_dir(example_dir)

# ---- Pipeline entrypoints ----
def run_pipeline(process_all: bool = True, video_list: list[Path] | None = None):
    """
    - If process_all=True, process every video under INPUT_VIDEOS.
    - If process_all=False, expects video_list (list of Paths) and processes only those.
    """
    ensure_model_loaded()  # loaded once, reused thereafter

    # Extract frames
    if process_all:
        extract_frames()
    else:
        assert video_list is not None and len(video_list) > 0, "Provide video_list when process_all=False."
        extract_frames(videos=video_list)

    # Extract features
    if process_all:
        extract_features_from_all()
    else:
        for v in video_list:
            extract_features_for_dir(FRAMES_DIR / v.stem)

if __name__ == "__main__":
    # Load once and reuse across calls
    ensure_model_loaded()
    # Example: process everything under INPUT_VIDEOS
    run_pipeline(process_all=True)
    # Example: infer a single video repeatedly without reloading:
    # infer_single_video(Path('/kaggle/working/downloaded_videos/your_uuid.mp4'))


Missing keys for loading model: []
Unexpected keys for loading model: []


Extracting frames: 100%|██████████| 2/2 [00:06<00:00,  3.44s/it]
  with torch.no_grad(), torch.cuda.amp.autocast():
Extracting features: 100%|██████████| 2/2 [00:06<00:00,  3.32s/it]


In [7]:
# 📁 Navigate to working directory
%cd /kaggle/working/

# 📘 Load JSON data
import json
ques_path_json = '/kaggle/input/egoschema/fullset_anno.json'
with open(ques_path_json, "r") as f:
    data = json.load(f)  # Dictionary of video_id ➝ question/options

# 📂 Get list of video IDs (without .mp4 extension)
import os
video_folder = "/kaggle/working/downloaded_videos"
video_ids = [
    os.path.splitext(f)[0]
    for f in os.listdir(video_folder)
    if f.endswith(".mp4")
]
print(f"Found {len(video_ids)} videos.")

# ✍️ Generate question prompts for available videos
import string
prompts = {}

for vid in video_ids:
    if vid not in data:
        continue  # Skip videos without annotation

    option_keys = sorted(
        [k for k in data[vid] if k.startswith("option ")],
        key=lambda x: int(x.split()[1])
    )

    prompt = f"Question:\n{data[vid]['question']}\n\nOptions:\n"
    for letter, key in zip(string.ascii_uppercase, option_keys):
        prompt += f"{letter}. {data[vid][key]}\n"
    prompt += "\nPlease choose the most appropriate answer (A–E)."

    prompts[vid] = prompt

# 📦 Final dictionary of all questions
questions = prompts

# 🗑️ Delete 'outputs' folder if it exists
import shutil
outputs_folder = "/kaggle/working/outputs"
if os.path.exists(outputs_folder):
    shutil.rmtree(outputs_folder)
    print("Outputs folder and contents deleted successfully.")
else:
    print("Outputs folder does not exist.")


/kaggle/working
Found 2 videos.
Outputs folder does not exist.


In [8]:
!git clone https://github.com/subhadarship/kmeans_pytorch

Cloning into 'kmeans_pytorch'...
remote: Enumerating objects: 422, done.[K
remote: Counting objects: 100% (142/142), done.[K
remote: Compressing objects: 100% (82/82), done.[K
remote: Total 422 (delta 64), reused 130 (delta 58), pack-reused 280 (from 1)[K
Receiving objects: 100% (422/422), 1.05 MiB | 22.81 MiB/s, done.
Resolving deltas: 100% (184/184), done.


In [9]:
!git clone https://github.com/Ziyang412/VideoTree.git

Cloning into 'VideoTree'...
remote: Enumerating objects: 102, done.[K
remote: Counting objects: 100% (102/102), done.[K
remote: Compressing objects: 100% (83/83), done.[K
remote: Total 102 (delta 43), reused 53 (delta 17), pack-reused 0 (from 0)[K
Receiving objects: 100% (102/102), 3.76 MiB | 27.51 MiB/s, done.
Resolving deltas: 100% (43/43), done.


In [10]:
!pip install groq

Collecting groq
  Downloading groq-0.31.0-py3-none-any.whl.metadata (16 kB)
Downloading groq-0.31.0-py3-none-any.whl (131 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.4/131.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: groq
Successfully installed groq-0.31.0


In [11]:
from glob import glob

In [12]:
glob('/kaggle/working/*')

['/kaggle/working/VideoTree',
 '/kaggle/working/kmeans_pytorch',
 '/kaggle/working/perception_models',
 '/kaggle/working/extracted_frames',
 '/kaggle/working/videotree',
 '/kaggle/working/downloaded_videos',
 '/kaggle/working/extracted_features']

In [13]:
%cd /kaggle/working/videotree
!git clone https://github.com/Ziyang412/VideoTree.git
%cd /kaggle/working/videotree/VideoTree
# Import your original modules
from util import *

/kaggle/working/videotree
Cloning into 'VideoTree'...
remote: Enumerating objects: 102, done.[K
remote: Counting objects: 100% (102/102), done.[K
remote: Compressing objects: 100% (83/83), done.[K
remote: Total 102 (delta 43), reused 53 (delta 17), pack-reused 0 (from 0)[K
Receiving objects: 100% (102/102), 3.76 MiB | 30.57 MiB/s, done.
Resolving deltas: 100% (43/43), done.
/kaggle/working/videotree/VideoTree


In [None]:
import os
print("GROQ_API_KEY")
os.environ["GROQ_API_KEY"] = " "

# Note: Assuming video_ids and questions variables are already defined
# video_ids = [list of video IDs]
# questions = {video_id: "Question text with options", ...}

# Cell 1: Imports and Configuration
import os
import json
from pathlib import Path
from tqdm import tqdm
import time
from datetime import datetime

%cd /kaggle/working/kmeans_pytorch
from kmeans_pytorch import kmeans
import torch
import re
from groq import Groq


%cd /kaggle/working

class MultiVideoConfig:
    """Configuration for multi-video processing with Groq"""
    def __init__(self):
        # Paths
        self.output_base_path = './outputs'
        self.output_filename_template = 'video_{}_groq_pipeline.json'
        self.batch_summary_filename = 'batch_processing_summary.json'
        
        # Video processing settings
        self.frame_feat_path = '/kaggle/working/extracted_features'
        self.video_ids = video_ids
        self.questions = questions  # Use existing questions variable
        self.captions_file_path = '/kaggle/input/egoschema/blip2_fullset.json'
        
        # Clustering parameters
        self.max_cluster_num = 32
        self.init_cluster_num = 4
        self.iter_threshold = 5
        self.default_adaptive_rate = 2
        
        # Groq model configuration
        self.model = 'llama-3.1-8b-instant'
        self.temperature = 0.0
        self.max_tokens = 1000
        
        # Batch processing settings
        self.skip_existing = True
        self.delay_between_videos = 1.0
        self.max_retries = 3
        self.save_intermediate = True

# Initialize config
config = MultiVideoConfig()
print("Multi-Video Configuration loaded!")
print(f"📹 Videos to process: {len(config.video_ids)}")
print(f"❓ Questions available: {len(config.questions)}")

# Validate that all videos have questions
missing_questions = [vid for vid in config.video_ids if vid not in config.questions]
if missing_questions:
    print(f"⚠️  WARNING: Missing questions for videos: {missing_questions}")
else:
    print("✅ All videos have corresponding questions")


# Cell 2: Initialize Groq Model
class GroqModel:
    """Wrapper for Groq API"""
    def __init__(self, model_name='llama-3.1-8b-instant', temperature=0.0, max_tokens=1000):
        self.client = Groq()
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        
    def forward(self, system_prompt, user_prompt):
        """Forward pass"""
        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                model=self.model_name,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            
            response = chat_completion.choices[0].message.content
            
            info = {
                'response': response,
                'model': self.model_name,
                'tokens_used': chat_completion.usage.total_tokens if hasattr(chat_completion, 'usage') else 0
            }
            
            return response, info
            
        except Exception as e:
            print(f"Error in Groq API call: {e}")
            return "", {'response': '', 'error': str(e)}

# Initialize Groq model
print("Initializing Groq model...")
model = GroqModel(
    model_name=config.model,
    temperature=config.temperature,
    max_tokens=config.max_tokens
)
print(f"✅ Groq model initialized: {config.model}")


# Cell 3: Utility Functions
def load_frame_captions(captions_file_path, video_id):
    """Load frame captions for a specific video from JSON file"""
    try:
        captions_data = load_json(captions_file_path)
        
        if video_id in captions_data:
            video_captions_raw = captions_data[video_id]
            
            # Handle if it's a list instead of dict
            if isinstance(video_captions_raw, list):
                frame_captions = {}
                for i, caption in enumerate(video_captions_raw):
                    frame_captions[i] = caption
                return frame_captions
            else:
                # Handle dict format
                frame_captions = {}
                for key, caption in video_captions_raw.items():
                    try:
                        frame_idx = int(key)
                        frame_captions[frame_idx] = caption
                    except ValueError:
                        continue
                return frame_captions
        else:
            return {}
            
    except Exception as e:
        print(f"❌ Error loading captions for {video_id}: {e}")
        return {}

def load_frame_features(video_id, features_folder):
    """Load frame features for a video"""
    filename = f"{video_id}.pt"
    filepath = os.path.join(features_folder, filename)
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Feature file not found: {filepath}")
    return torch.load(filepath)

def find_closest_points_per_cluster(features, cluster_ids, cluster_centers):
    """Find closest points to cluster centers"""
    closest_points_idx_per_cluster = {cluster_id: [] for cluster_id in range(len(cluster_centers))}
    
    for cluster_id in range(len(cluster_centers)):
        indices_in_cluster = torch.where(cluster_ids == cluster_id)[0]
        points_in_cluster = features[indices_in_cluster]
        distances = torch.norm(points_in_cluster - cluster_centers[cluster_id], dim=1)
        if distances.numel() > 0:
            closest_idx_in_cluster = torch.argmin(distances).item()
            closest_global_idx = indices_in_cluster[closest_idx_in_cluster].item()
            closest_points_idx_per_cluster[cluster_id].append(closest_global_idx)
    
    return closest_points_idx_per_cluster

def parse_question_text(question_text):
    """Parse the question text to extract question and options"""
    # lines = question_text.strip().split('\n')
    lines = question_text.replace('\\n', '\n').strip().split('\n')
    
    # Find the question line
    question = ""
    options = []
    
    for line in lines:
        line = line.strip()
        if line.startswith("Question:"):
            question = line.replace("Question:", "").strip()
        elif line.startswith("Options:"):
            continue
        elif line and any(line.startswith(f"{opt}.") for opt in ["A", "B", "C", "D", "E"]):
            options.append(line)
        elif line.startswith("Please choose"):
            break
    
    return question, options

def create_relevance_prompt(tree_node, video_captions, video_id, questions_dict):
    """Create a focused prompt for relevance scoring using video-specific question"""
    frame_descriptions = []
    valid_frames = []
    
    for frame_idx in tree_node:
        if frame_idx in video_captions and video_captions[frame_idx]:
            frame_descriptions.append(f"Frame {frame_idx}: {video_captions[frame_idx]}")
            valid_frames.append(frame_idx)
    
    if not frame_descriptions:
        return None, []

    # Get the specific question for this video
    if video_id not in questions_dict:
        print(f"❌ No question found for video {video_id}")
        return None, []
    
    question_text = questions_dict[video_id]
    # question = parse_question_text(question_text)
    
    if not question_text:
        print(f"❌ Could not parse question for video {video_id}")
        return None, []

    # Create the prompt
    prompt = f"""VIDEO QUESTION: {question_text}

FRAME DESCRIPTIONS:
{chr(10).join(frame_descriptions)}

TASK: Rate each frame's relevance to answering the question on a scale of 1-3:
- 1 = Not relevant (doesn't help answer the question)
- 2 = Somewhat relevant (provides some context)  
- 3 = Highly relevant (directly helps answer the question)

Provide ONLY the relevance scores in this exact format:
frame relevance: [score1, score2, score3, ...]

You must provide exactly {len(frame_descriptions)} scores for the {len(frame_descriptions)} frames described above.

Example: frame relevance: [2, 1, 3, 2, 1]"""

    return prompt, valid_frames

def extract_relevance_scores(text):
    """Extract relevance scores from model response"""
    response = text.strip()
    
    patterns = [
        r"frame relevance:\s*\[([0-9,\s]+)\]",
        r"relevance:\s*\[([0-9,\s]+)\]", 
        r"scores:\s*\[([0-9,\s]+)\]",
        r"relevance scores:\s*\[([0-9,\s]+)\]",
        r"\[([0-9,\s]+)\]"
    ]
    
    for pattern in patterns:
        relevance_match = re.search(pattern, response, re.IGNORECASE)
        if relevance_match:
            try:
                numbers_str = relevance_match.group(1)
                relevance = [int(x.strip()) for x in numbers_str.split(',') if x.strip().isdigit()]
                relevance = [max(1, min(3, score)) for score in relevance]
                return relevance
            except Exception as e:
                continue
    
    return []

# Batch processing utilities
def check_existing_result(video_id, config):
    """Check if result already exists for this video"""
    output_filename = config.output_filename_template.format(video_id)
    output_path = os.path.join(config.output_base_path, output_filename)
    return os.path.exists(output_path)

def save_batch_summary(batch_results, config):
    """Save summary of batch processing"""
    summary = {
        'processing_date': datetime.now().isoformat(),
        'total_videos': len(config.video_ids),
        'processed_videos': len([r for r in batch_results if r['status'] == 'success']),
        'failed_videos': len([r for r in batch_results if r['status'] == 'error']),
        'skipped_videos': len([r for r in batch_results if r['status'] == 'skipped']),
        'total_tokens_used': sum(r.get('tokens_used', 0) for r in batch_results),
        'results': batch_results,
        'config': {k: v for k, v in vars(config).items() if not k.startswith('_') and k not in ['video_ids', 'questions']}
    }
    
    summary_path = os.path.join(config.output_base_path, config.batch_summary_filename)
    save_json(summary, summary_path)
    return summary_path

# Ensure output directory exists
makedir(config.output_base_path)
print("✅ Utility functions loaded!")


# Cell 4: Adaptive Clustering Function
def adaptive_clustering_with_groq(frame_feats, config, model, video_id, video_captions):
    """Perform adaptive clustering with Groq model for relevance prediction"""
    cluster_num = config.init_cluster_num
    device = frame_feats.device
    
    clustering_results = []
    
    while cluster_num <= config.max_cluster_num:
        try:
            # Perform k-means clustering
            cluster_ids_x, cluster_centers = kmeans(
                X=frame_feats, 
                num_clusters=cluster_num, 
                distance='cosine', 
                device=device
            )
            
            cluster_ids_x = cluster_ids_x.to(device)
            cluster_centers = cluster_centers.to(device)
            
            # Find representative frames
            closest_points_idx_per_cluster = find_closest_points_per_cluster(
                frame_feats, cluster_ids_x, cluster_centers
            )
            
            if not closest_points_idx_per_cluster:
                cluster_num *= config.default_adaptive_rate
                continue
            
            # Get representative frame indices
            tree_node = sorted([value for sublist in closest_points_idx_per_cluster.values() for value in sublist])
            cluster_ids_list = cluster_ids_x.tolist()
            
            # Create relevance prompt using video-specific question
            prompt, valid_frames = create_relevance_prompt(tree_node, video_captions, video_id, config.questions)
            
            if not prompt:
                cluster_num *= config.default_adaptive_rate
                continue
            
            # Model inference with Groq
            response, info = model.forward(
                "You are an expert video analyst. Analyze frame descriptions and rate their relevance to answering the specific question.", 
                prompt
            )
            
            # Extract frame relevance
            frame_relevance = extract_relevance_scores(response)
            
            # Count high relevance frames (score = 3)
            if isinstance(frame_relevance, list) and len(frame_relevance) == len(valid_frames):
                high_relevance_frame_num = frame_relevance.count(3)
            else:
                high_relevance_frame_num = 0
            
            # Store clustering result
            clustering_result = {
                'num_clusters': cluster_num,
                'actual_clusters': len(set(cluster_ids_list)),
                'representative_frames': tree_node,
                'valid_frames_with_captions': valid_frames,
                'cluster_assignments': cluster_ids_list,
                'frame_relevance': frame_relevance,
                'high_relevance_count': high_relevance_frame_num,
                'prompt': prompt,
                'model_response': info.get('response', ''),
                'tokens_used': info.get('tokens_used', 0)
            }
            clustering_results.append(clustering_result)
            
            # Check stopping condition
            if high_relevance_frame_num < config.iter_threshold:
                if cluster_num < config.max_cluster_num:
                    next_cluster_num = cluster_num * config.default_adaptive_rate
                    cluster_num = next_cluster_num
                else:
                    break
            else:
                break
                
        except Exception as e:
            print(f"❌ Clustering failed with {cluster_num} clusters: {e}")
            cluster_num *= config.default_adaptive_rate
            continue
    
    # Return the final successful clustering result
    if clustering_results:
        final_result = clustering_results[-1]
        return (final_result['representative_frames'], 
                final_result['cluster_assignments'], 
                final_result['frame_relevance'],
                final_result['high_relevance_count'],
                clustering_results)
    else:
        return [], [], [], 0, []


# Cell 5: Single Video Processing Function
def process_single_video(video_id, config, model, device):
    """Process a single video and return results"""
    print(f"\n🎬 PROCESSING VIDEO: {video_id}")
    print("="*60)
    
    start_time = time.time()
    
    try:
        # Check if feature file exists
        feature_file_path = os.path.join(config.frame_feat_path, f"{video_id}.pt")
        if not os.path.exists(feature_file_path):
            return {
                'video_id': video_id,
                'status': 'error',
                'error': f'Feature file not found: {feature_file_path}',
                'processing_time': time.time() - start_time
            }
        
        # Check if question exists for this video
        if video_id not in config.questions:
            return {
                'video_id': video_id,
                'status': 'error',
                'error': f'No question found for video: {video_id}',
                'processing_time': time.time() - start_time
            }
        
        # Load frame features
        print(f"📊 Loading features...")
        frame_feats = load_frame_features(video_id, config.frame_feat_path)
        frame_feats = frame_feats.to(device)
        print(f"✅ Loaded features: shape {frame_feats.shape}")
        
        # Load frame captions
        print(f"📝 Loading captions...")
        video_captions = load_frame_captions(config.captions_file_path, video_id)
        if not video_captions:
            return {
                'video_id': video_id,
                'status': 'error',
                'error': 'No captions found for video',
                'processing_time': time.time() - start_time
            }
        print(f"✅ Loaded {len(video_captions)} frame captions")
        
        # Show the question for this video
        question_text = config.questions[video_id]
        question, options = parse_question_text(question_text)
        print(f"❓ Question: {question[:100]}...")
        
        # Run adaptive clustering
        print(f"🔄 Running adaptive clustering...")
        (representative_frames, cluster_assignments, 
         frame_relevance, high_relevance_count, all_results) = adaptive_clustering_with_groq(
            frame_feats, config, model, video_id, video_captions
        )
        
        # Create result
        result = {
            'video_id': video_id,
            'status': 'success',
            'processing_time': time.time() - start_time,
            'feature_file_path': feature_file_path,
            'total_frames': len(frame_feats),
            'feature_dimensions': frame_feats.shape[1] if len(frame_feats.shape) > 1 else 1,
            'question_text': question_text,
            'parsed_question': question,
            'parsed_options': options,
            'final_result': {
                'representative_frames': representative_frames,
                'cluster_assignments': cluster_assignments,
                'frame_relevance': frame_relevance,
                'high_relevance_count': high_relevance_count,
                'num_clusters': len(set(cluster_assignments)) if cluster_assignments else 0,
                'passed_threshold': high_relevance_count >= config.iter_threshold
            },
            'all_clustering_attempts': all_results,
            'tokens_used': sum(attempt.get('tokens_used', 0) for attempt in all_results)
        }
        
        # Print summary
        final = result['final_result']
        print(f"✅ SUCCESS - Clusters: {final['num_clusters']}, "
              f"High relevance: {final['high_relevance_count']}, "
              f"Passed: {final['passed_threshold']}, "
              f"Tokens: {result['tokens_used']}, "
              f"Time: {result['processing_time']:.1f}s")
        
        return result
        
    except Exception as e:
        print(f"❌ ERROR processing {video_id}: {e}")
        import traceback
        traceback.print_exc()
        return {
            'video_id': video_id,
            'status': 'error',
            'error': str(e),
            'processing_time': time.time() - start_time
        }


# Cell 6: Main Batch Processing Loop
def run_batch_processing():
    """Run the complete batch processing pipeline"""
    print("🚀 STARTING BATCH PROCESSING")
    print("="*80)
    print(f"📹 Total videos to process: {len(config.video_ids)}")
    print("="*80)
    
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_results = []
    
    # Process each video
    for i, video_id in enumerate(config.video_ids, 1):
        print(f"\n🎯 VIDEO {i}/{len(config.video_ids)}: {video_id}")
        
        # Check if we should skip existing results
        if config.skip_existing and check_existing_result(video_id, config):
            print(f"⏭️  SKIPPED - Result already exists")
            batch_results.append({
                'video_id': video_id,
                'status': 'skipped',
                'reason': 'Result already exists'
            })
            continue
        
        # Process video with retries
        retry_count = 0
        result = None
        
        while retry_count < config.max_retries:
            try:
                result = process_single_video(video_id, config, model, device)
                if result['status'] == 'success':
                    break
                else:
                    retry_count += 1
                    if retry_count < config.max_retries:
                        print(f"🔄 Retrying ({retry_count + 1}/{config.max_retries})...")
                        time.sleep(config.delay_between_videos)
            except Exception as e:
                retry_count += 1
                if retry_count >= config.max_retries:
                    result = {
                        'video_id': video_id,
                        'status': 'error',
                        'error': f'Failed after {config.max_retries} retries: {str(e)}'
                    }
        
        batch_results.append(result)
        
        # Save individual result
        if config.save_intermediate and result['status'] == 'success':
            output_filename = config.output_filename_template.format(video_id)
            output_path = os.path.join(config.output_base_path, output_filename)
            save_json(result, output_path)
            print(f"💾 Saved: {output_filename}")
        
        # Rate limiting delay
        if i < len(config.video_ids):  # Don't delay after last video
            time.sleep(config.delay_between_videos)
    
    # Save batch summary
    print(f"\n📊 BATCH PROCESSING COMPLETE!")
    print("="*80)
    
    summary_path = save_batch_summary(batch_results, config)
    
    # Print final statistics
    successful = len([r for r in batch_results if r['status'] == 'success'])
    failed = len([r for r in batch_results if r['status'] == 'error'])
    skipped = len([r for r in batch_results if r['status'] == 'skipped'])
    total_tokens = sum(r.get('tokens_used', 0) for r in batch_results)
    
    print(f"✅ Successful: {successful}")
    print(f"❌ Failed: {failed}")
    print(f"⏭️  Skipped: {skipped}")
    print(f"🔤 Total tokens: {total_tokens}")
    print(f"💾 Summary saved: {summary_path}")
    print("="*80)
    
    return batch_results

# Run the batch processing
batch_results = run_batch_processing()


# Cell 1: Depth Expansion - Imports and Functions

import numpy as np
import torch
import torch.nn.functional as F
from scipy.cluster.hierarchy import linkage, fcluster
import json
import os
from pathlib import Path
from tqdm import tqdm

def hierarchical_clustering_with_external_primary(video_features, cluster_ids, relevance_scores, num_subclusters=5, num_subsubclusters=5):
    """
    Perform hierarchical clustering based on relevance scores:
    - Score 1: Keep only primary cluster
    - Score 2: Split into subclusters  
    - Score 3: Split into sub-subclusters
    """
    clusters = {i: {} for i in range(0, max(cluster_ids)+1)}

    for cluster_id in set(cluster_ids):
        primary_indices = [i for i, x in enumerate(cluster_ids) if x == cluster_id]

        if cluster_id < len(relevance_scores):
            score = relevance_scores[cluster_id]
        else:
            score = 3

        if len(primary_indices) < 2:
            clusters[cluster_id] = primary_indices
            continue

        sub_features = video_features[primary_indices]

        if score == 1:
            # Low relevance: keep as single cluster
            clusters[cluster_id] = primary_indices
            continue

        # Create subclusters
        linked_sub = linkage(sub_features, method='ward')
        sub_cluster_labels = fcluster(linked_sub, num_subclusters, criterion='maxclust')
        sub_cluster_labels = sub_cluster_labels - 1

        if score == 2:
            # Medium relevance: split into subclusters
            clusters[cluster_id] = {i: [primary_indices[j] for j in np.where(sub_cluster_labels == i)[0]] for i in range(0, num_subclusters)}
            continue

        # High relevance (score == 3): split into sub-subclusters
        clusters[cluster_id] = {}
        for subcluster_id in range(0, num_subclusters):
            sub_indices = np.where(sub_cluster_labels == subcluster_id)[0]
            if len(sub_indices) < 2:
                continue

            subsub_features = sub_features[sub_indices]
            linked_subsub = linkage(subsub_features, method='ward')
            subsub_cluster_labels = fcluster(linked_subsub, num_subsubclusters, criterion='maxclust')
            subsub_cluster_labels = subsub_cluster_labels - 1

            clusters[cluster_id][subcluster_id] = {}
            for subsubcluster_id in range(0, num_subsubclusters):
                final_indices = sub_indices[np.where(subsub_cluster_labels == subsubcluster_id)[0]]
                original_indices = [primary_indices[i] for i in final_indices]
                clusters[cluster_id][subcluster_id][subsubcluster_id] = original_indices

    return clusters

def cosine_similarity(points, centroid):
    """Calculate cosine similarity between points and centroid."""
    points_normalized = F.normalize(points, dim=1)
    centroid_normalized = F.normalize(centroid.unsqueeze(0), dim=1)
    return 1 - torch.mm(points_normalized, centroid_normalized.T).squeeze()

def find_closest_points_in_temporal_order_subsub(x, clusters, relevance_scores):
    """Find representative frames from hierarchical clusters in temporal order."""
    closest_points_indices = []

    for cluster_id, cluster_data in clusters.items():
        if cluster_id < len(relevance_scores):
            relevance = relevance_scores[cluster_id]
        else:
            relevance = 3

        if isinstance(cluster_data, list):  # Primary cluster directly
            cluster_data = np.array(cluster_data)
            if cluster_data.size == 0:
                continue
            points_in_cluster = x[torch.tensor(cluster_data, dtype=torch.long)]
            cluster_centroid = points_in_cluster.mean(dim=0)
            distances = cosine_similarity(points_in_cluster, cluster_centroid)
            if distances.numel() > 0:
                closest_idx = torch.argmin(distances).item()
                closest_points_indices.append(int(cluster_data[closest_idx]))

        elif isinstance(cluster_data, dict):  # Handle subclusters and sub-subclusters
            if relevance == 1:
                # Only take representative frame for primary cluster
                primary_indices = []
                for subcluster_data in cluster_data.values():
                    if isinstance(subcluster_data, dict):
                        for sub_data in subcluster_data.values():
                            if len(sub_data) > 0:
                                primary_indices.extend(sub_data)
                    elif isinstance(subcluster_data, list) and len(subcluster_data) > 0:
                        primary_indices.extend(subcluster_data)

                if primary_indices:
                    primary_indices = np.array(primary_indices)
                    primary_points = x[torch.tensor(primary_indices, dtype=torch.long)]
                    primary_centroid = primary_points.mean(dim=0)
                    primary_distances = cosine_similarity(primary_points, primary_centroid)
                    if primary_distances.numel() > 0:
                        closest_primary_idx = torch.argmin(primary_distances).item()
                        closest_points_indices.append(int(primary_indices[closest_primary_idx]))
                continue

            elif relevance == 2 or relevance == 3:
                # Include primary cluster representative
                primary_indices = []
                for subcluster_data in cluster_data.values():
                    if isinstance(subcluster_data, dict):
                        for sub_data in subcluster_data.values():
                            if len(sub_data) > 0:
                                primary_indices.extend(sub_data)
                    elif isinstance(subcluster_data, list) and len(subcluster_data) > 0:
                        primary_indices.extend(subcluster_data)

                if primary_indices:
                    primary_indices = np.array(primary_indices)
                    primary_points = x[torch.tensor(primary_indices, dtype=torch.long)]
                    primary_centroid = primary_points.mean(dim=0)
                    primary_distances = cosine_similarity(primary_points, primary_centroid)
                    if primary_distances.numel() > 0:
                        closest_primary_idx = torch.argmin(primary_distances).item()
                        closest_points_indices.append(int(primary_indices[closest_primary_idx]))

                # Process subclusters/sub-subclusters
                for subcluster_id, subclusters in cluster_data.items():
                    if isinstance(subclusters, dict):  # Sub-subclusters
                        for subsubcluster_id, indices in subclusters.items():
                            if len(indices) == 0:
                                continue
                            indices_tensor = torch.tensor(indices, dtype=torch.long)
                            points_in_subsubcluster = x[indices_tensor]
                            subsubcluster_centroid = points_in_subsubcluster.mean(dim=0)
                            distances = cosine_similarity(points_in_subsubcluster, subsubcluster_centroid)
                            if distances.numel() > 0:
                                closest_idx_in_subsubcluster = torch.argmin(distances).item()
                                closest_global_idx = indices[closest_idx_in_subsubcluster]
                                closest_points_indices.append(int(closest_global_idx))

                    elif isinstance(subclusters, list):
                        subclusters = np.array(subclusters)
                        if subclusters.size == 0:
                            continue
                        points_in_subcluster = x[torch.tensor(subclusters, dtype=torch.long)]
                        subcluster_centroid = points_in_subcluster.mean(dim=0)
                        distances = cosine_similarity(points_in_subcluster, subcluster_centroid)
                        if distances.numel() > 0:
                            closest_idx = torch.argmin(distances).item()
                            closest_points_indices.append(int(subclusters[closest_idx]))

    closest_points_indices.sort()  # Ensure temporal order
    return closest_points_indices

def load_image_features(name_ids, save_folder):
    """Load image features from a .pt file."""
    filename = f"{name_ids}.pt"
    filepath = os.path.join(save_folder, filename)
    img_feats = torch.load(filepath)
    return img_feats

def load_json(fn):
    """Load JSON file."""
    with open(fn, 'r') as f:
        data = json.load(f)
    return data

def save_json(data, fn, indent=4):
    """Save data to JSON file."""
    with open(fn, 'w') as f:
        json.dump(data, f, indent=indent)

print("✅ Depth expansion functions loaded!")


import os
import json
from glob import glob

# Configuration for depth expansion
class DepthExpansionConfig:
    def __init__(self, video_id, json_path):
        # Input paths
        self.save_folder = '/kaggle/working/extracted_features'  # Where your .pt files are
        self.video_id = video_id  # Extracted from filename
        
        # Previous results from Groq pipeline
        self.groq_results_path = json_path
        
        # Output paths
        self.output_base_path = './outputs'
        self.output_filename = f'depth_expansion_{video_id}.json'
        
        # Hierarchical clustering parameters
        self.num_subclusters = 4
        self.num_subsubclusters = 4

# Find all Groq pipeline results
json_files = glob("./outputs/*_groq_pipeline.json")

if not json_files:
    print("❌ No Groq pipeline result JSON files found in ./outputs/")
else:
    print(f"📂 Found {len(json_files)} Groq result files")

# Loop through each file
for json_path in json_files:
    # Extract video_id from filename
    base_name = os.path.basename(json_path)
    video_id = base_name.replace("video_", "").replace("_groq_pipeline.json", "")
    
    config = DepthExpansionConfig(video_id, json_path)
    print("\n" + "="*60)
    print(f"📋 Depth Expansion Configuration for {video_id}:")
    print(f"  - Groq results: {config.groq_results_path}")
    print(f"  - Output: {config.output_filename}")
    
    # Load previous Groq pipeline results
    try:
        with open(config.groq_results_path, 'r') as f:
            groq_results = json.load(f)
        
        print("✅ Groq results loaded successfully!")
        
        # Extract needed data
        cluster_assignments = groq_results['final_result']['cluster_assignments']
        frame_relevance = groq_results['final_result']['frame_relevance'] 
        representative_frames = groq_results['final_result']['representative_frames']
        
        print(f"📊 Data extracted:")
        print(f"  - Total frames: {groq_results['total_frames']}")
        print(f"  - Clusters: {len(set(cluster_assignments))}")
        print(f"  - Representative frames: {len(representative_frames)}")
        print(f"  - Frame relevance: {frame_relevance}")
        print(f"  - Cluster assignments: {cluster_assignments[:10]}...")  # Show first 10
        
    except FileNotFoundError:
        print(f"❌ Error: Could not find Groq results at {config.groq_results_path}")
        continue
    except Exception as e:
        print(f"❌ Error loading Groq results: {e}")
        continue

    # Check if feature file exists
    feature_path = os.path.join(config.save_folder, f"{config.video_id}.pt")
    if os.path.exists(feature_path):
        print(f"✅ Feature file found: {feature_path}")
    else:
        print(f"❌ Feature file not found: {feature_path}")

/kaggle/working/kmeans_pytorch
/kaggle/working
Multi-Video Configuration loaded!
📹 Videos to process: 2
❓ Questions available: 2
✅ All videos have corresponding questions
Initializing Groq model...
✅ Groq model initialized: llama-3.1-8b-instant
✅ Utility functions loaded!
🚀 STARTING BATCH PROCESSING
📹 Total videos to process: 2

🎯 VIDEO 1/2: 0074f737-11cb-497d-8d07-77c3a8127391

🎬 PROCESSING VIDEO: 0074f737-11cb-497d-8d07-77c3a8127391
📊 Loading features...
✅ Loaded features: shape torch.Size([180, 1024])
📝 Loading captions...
✅ Loaded 180 frame captions
❓ Question: ...
🔄 Running adaptive clustering...
running k-means on cuda:0..


[running kmeans]: 8it [00:00, 25.03it/s, center_shift=0.000000, iteration=8, tol=0.000100]  


running k-means on cuda:0..


[running kmeans]: 10it [00:00, 431.27it/s, center_shift=0.000000, iteration=10, tol=0.000100]


✅ SUCCESS - Clusters: 8, High relevance: 6, Passed: True, Tokens: 741, Time: 1.9s
💾 Saved: video_0074f737-11cb-497d-8d07-77c3a8127391_groq_pipeline.json

🎯 VIDEO 2/2: 00b9a0de-c59e-49cb-a127-6081e2fb8c8e

🎬 PROCESSING VIDEO: 00b9a0de-c59e-49cb-a127-6081e2fb8c8e
📊 Loading features...
✅ Loaded features: shape torch.Size([180, 1024])
📝 Loading captions...
✅ Loaded 180 frame captions
❓ Question: ...
🔄 Running adaptive clustering...
running k-means on cuda:0..


[running kmeans]: 9it [00:00, 500.52it/s, center_shift=0.000000, iteration=9, tol=0.000100] 


running k-means on cuda:0..


[running kmeans]: 9it [00:00, 438.27it/s, center_shift=0.000000, iteration=9, tol=0.000100]


running k-means on cuda:0..


[running kmeans]: 11it [00:00, 311.44it/s, center_shift=0.000000, iteration=11, tol=0.000100]


running k-means on cuda:0..


[running kmeans]: 3it [00:00, 187.46it/s, center_shift=0.000000, iteration=3, tol=0.000100]


✅ SUCCESS - Clusters: 32, High relevance: 0, Passed: False, Tokens: 3373, Time: 4.5s
💾 Saved: video_00b9a0de-c59e-49cb-a127-6081e2fb8c8e_groq_pipeline.json

📊 BATCH PROCESSING COMPLETE!
✅ Successful: 2
❌ Failed: 0
⏭️  Skipped: 0
🔤 Total tokens: 4114
💾 Summary saved: ./outputs/batch_processing_summary.json
✅ Depth expansion functions loaded!
📂 Found 2 Groq result files

📋 Depth Expansion Configuration for 00b9a0de-c59e-49cb-a127-6081e2fb8c8e:
  - Groq results: ./outputs/video_00b9a0de-c59e-49cb-a127-6081e2fb8c8e_groq_pipeline.json
  - Output: depth_expansion_00b9a0de-c59e-49cb-a127-6081e2fb8c8e.json
✅ Groq results loaded successfully!
📊 Data extracted:
  - Total frames: 180
  - Clusters: 32
  - Representative frames: 32
  - Frame relevance: [1, 1, 1, 1, 1]
  - Cluster assignments: [10, 31, 31, 31, 31, 31, 31, 23, 19, 31]...
✅ Feature file found: /kaggle/working/extracted_features/00b9a0de-c59e-49cb-a127-6081e2fb8c8e.pt

📋 Depth Expansion Configuration for 0074f737-11cb-497d-8d07-77c3a81

# Depth Expansion

In [15]:
for vid in video_ids:
    print("\n" + "="*60)
    print(f"🚀 STARTING DEPTH EXPANSION for video: {vid}")
    print("="*60)

    # Setup config for current video
    groq_path = f'./outputs/video_{vid}_groq_pipeline.json'
    
    # config = DepthExpansionConfig()
    config = DepthExpansionConfig(vid, groq_path)  # Pass args here
    config.video_id = vid
    config.groq_results_path = f'./outputs/video_{vid}_groq_pipeline.json'
    config.output_filename = f'depth_expansion_{vid}.json'

    # Load Groq results JSON for this video
    try:
        with open(config.groq_results_path, 'r') as f:
            groq_results = json.load(f)
        
        cluster_assignments = groq_results['final_result']['cluster_assignments']
        frame_relevance = groq_results['final_result']['frame_relevance'] 
        representative_frames = groq_results['final_result']['representative_frames']

        print(f"✅ Loaded Groq results for video {vid}")
    except FileNotFoundError:
        print(f"❌ Groq results not found for video {vid}, skipping...")
        continue
    except Exception as e:
        print(f"❌ Error loading Groq results for video {vid}: {e}, skipping...")
        continue

    # Load features for this video
    print(f"📂 Loading features for {vid}...")
    try:
        img_feats = load_image_features(vid, config.save_folder)
        img_feats = img_feats.cpu()  # For scipy
        print(f"✅ Features loaded: shape {img_feats.shape}")
    except Exception as e:
        print(f"❌ Failed to load features for {vid}: {e}, skipping...")
        continue

    # Process relevance scores
    print(f"📊 Processing relevance scores...")
    if isinstance(frame_relevance, list) and len(frame_relevance) > 0:
        print(f"✅ Using extracted relevance scores: {frame_relevance}")
        relevance_scores = frame_relevance
    else:
        print("⚠️  No relevance scores found, using default scores...")
        num_clusters = len(set(cluster_assignments))
        relevance_scores = [2] * num_clusters  # Medium relevance

    print(f"📈 Relevance scores: {relevance_scores}")
    print(f"🎯 Cluster assignments: {len(cluster_assignments)} total assignments")

    # Perform hierarchical clustering
    print(f"\n🔄 Performing hierarchical clustering...")
    print(f"  - Primary clusters: {len(set(cluster_assignments))}")
    print(f"  - Subclusters per cluster: {config.num_subclusters}")
    print(f"  - Sub-subclusters per subcluster: {config.num_subsubclusters}")

    clusters_info = hierarchical_clustering_with_external_primary(
        img_feats,
        cluster_assignments,
        relevance_scores,
        num_subclusters=config.num_subclusters,
        num_subsubclusters=config.num_subsubclusters
    )

    print(f"✅ Hierarchical clustering complete!")
    print(f"📊 Clusters info type: {type(clusters_info)}")

    # Find representative points in temporal order
    print(f"\n🎯 Finding representative points in temporal order...")
    closest_points_temporal = find_closest_points_in_temporal_order_subsub(
        img_feats,
        clusters_info,
        relevance_scores
    )

    print(f"✅ Representative points found!")
    print(f"📍 Number of representative points: {len(closest_points_temporal)}")
    print(f"🕐 Temporal order: {closest_points_temporal}")

    # Compare with original representative frames
    print(f"\n📊 COMPARISON:")
    print(f"  Original (width expansion): {representative_frames}")
    print(f"  New (depth expansion): {closest_points_temporal}")
    print(f"  Original count: {len(representative_frames)}")
    print(f"  New count: {len(closest_points_temporal)}")

    # Calculate expansion ratio
    if len(representative_frames) > 0:
        expansion_ratio = len(closest_points_temporal) / len(representative_frames)
        print(f"  📈 Expansion ratio: {expansion_ratio:.2f}x")
    else:
        expansion_ratio = 0
        print(f"  ❌ No original frames to compare")

    print("="*60)



🚀 STARTING DEPTH EXPANSION for video: 0074f737-11cb-497d-8d07-77c3a8127391
✅ Loaded Groq results for video 0074f737-11cb-497d-8d07-77c3a8127391
📂 Loading features for 0074f737-11cb-497d-8d07-77c3a8127391...
✅ Features loaded: shape torch.Size([180, 1024])
📊 Processing relevance scores...
✅ Using extracted relevance scores: [2, 3, 3, 3, 3, 3, 3, 1]
📈 Relevance scores: [2, 3, 3, 3, 3, 3, 3, 1]
🎯 Cluster assignments: 180 total assignments

🔄 Performing hierarchical clustering...
  - Primary clusters: 8
  - Subclusters per cluster: 4
  - Sub-subclusters per subcluster: 4
✅ Hierarchical clustering complete!
📊 Clusters info type: <class 'dict'>

🎯 Finding representative points in temporal order...
✅ Representative points found!
📍 Number of representative points: 86
🕐 Temporal order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 20, 22, 30, 31, 34, 40, 41, 53, 56, 62, 63, 64, 65, 67, 68, 70, 72, 74, 75, 79, 80, 83, 85, 87, 87, 88, 89, 90, 91, 92, 93, 102, 105, 107, 111, 114, 115, 116, 1

In [16]:
for vid in video_ids:
    print("\n" + "="*60)
    print(f"🚀 STARTING DEPTH EXPANSION for video: {vid}")
    print("="*60)
    
    groq_path = f'./outputs/video_{vid}_groq_pipeline.json'
    config = DepthExpansionConfig(vid, groq_path)
    config.output_filename = f'depth_expansion_{vid}.json'  # unique output filename
    
    # -- Cell 2 part to load groq results (assumed you did this outside or add here) --
    try:
        with open(config.groq_results_path, 'r') as f:
            groq_results = json.load(f)
        cluster_assignments = groq_results['final_result']['cluster_assignments']
        frame_relevance = groq_results['final_result']['frame_relevance'] 
        representative_frames = groq_results['final_result']['representative_frames']
    except Exception as e:
        print(f"Error loading Groq results for {vid}: {e}")
        continue  # skip to next video
    
    # -- Cell 3: Run Depth Expansion --
    print(f"📂 Loading features for {config.video_id}...")
    img_feats = load_image_features(config.video_id, config.save_folder)
    img_feats = img_feats.cpu()
    print(f"✅ Features loaded: shape {img_feats.shape}")

    print(f"📊 Processing relevance scores...")
    if isinstance(frame_relevance, list) and len(frame_relevance) > 0:
        relevance_scores = frame_relevance
    else:
        num_clusters = len(set(cluster_assignments))
        relevance_scores = [2] * num_clusters

    clusters_info = hierarchical_clustering_with_external_primary(
        img_feats, 
        cluster_assignments, 
        relevance_scores,
        num_subclusters=config.num_subclusters,
        num_subsubclusters=config.num_subsubclusters
    )

    closest_points_temporal = find_closest_points_in_temporal_order_subsub(
        img_feats, 
        clusters_info, 
        relevance_scores
    )
    
    # -- Cell 4: Save Results and Analysis --
    depth_results = {
        "video_id": config.video_id,
        "input_data": {
            "total_frames": len(img_feats),
            "feature_dimensions": img_feats.shape[1] if len(img_feats.shape) > 1 else 1,
            "original_clusters": len(set(cluster_assignments)),
            "original_representative_frames": representative_frames,
            "relevance_scores": relevance_scores,
            "cluster_assignments": cluster_assignments
        },
        "depth_expansion_config": {
            "num_subclusters": config.num_subclusters,
            "num_subsubclusters": config.num_subsubclusters
        },
        "results": {
            "hierarchical_clusters": len(clusters_info),
            "final_representative_frames": closest_points_temporal,
            "expansion_ratio": len(closest_points_temporal) / len(representative_frames) if len(representative_frames) > 0 else 0,
            "total_representative_frames": len(closest_points_temporal)
        },
        "analysis": {
            "relevance_distribution": {
                "score_1": relevance_scores.count(1) if isinstance(relevance_scores, list) else 0,
                "score_2": relevance_scores.count(2) if isinstance(relevance_scores, list) else 0,
                "score_3": relevance_scores.count(3) if isinstance(relevance_scores, list) else 0
            },
            "frame_coverage": {
                "min_frame": min(closest_points_temporal) if closest_points_temporal else 0,
                "max_frame": max(closest_points_temporal) if closest_points_temporal else 0,
                "frame_span": max(closest_points_temporal) - min(closest_points_temporal) + 1 if closest_points_temporal else 0
            }
        }
    }

    output_path = os.path.join(config.output_base_path, config.output_filename)
    save_json(depth_results, output_path)

    print("💾 SAVING DEPTH EXPANSION RESULTS")
    print("="*50)
    print(f"📁 Output path: {output_path}")

    if os.path.exists(output_path):
        file_size = os.path.getsize(output_path) / 1024
        print(f"📏 File size: {file_size:.2f} KB")
    print(f"✅ Results saved successfully!")

    # Print detailed analysis summary
    print(f"\n📊 DEPTH EXPANSION ANALYSIS")
    print("="*50)

    results = depth_results["results"]
    analysis = depth_results["analysis"]

    print(f"🎬 Video: {depth_results['video_id']}")
    print(f"📊 Total frames: {depth_results['input_data']['total_frames']}")
    print(f"🔢 Feature dimensions: {depth_results['input_data']['feature_dimensions']}")

    print(f"\n📈 Expansion Results:")
    print(f"  - Original clusters: {depth_results['input_data']['original_clusters']}")
    print(f"  - Original representative frames: {len(depth_results['input_data']['original_representative_frames'])}")
    print(f"  - New representative frames: {results['total_representative_frames']}")
    print(f"  - Expansion ratio: {results['expansion_ratio']:.2f}x")

    print(f"\n⭐ Relevance Distribution:")
    rel_dist = analysis["relevance_distribution"]
    print(f"  - Score 1 (Low): {rel_dist['score_1']} clusters")
    print(f"  - Score 2 (Medium): {rel_dist['score_2']} clusters")
    print(f"  - Score 3 (High): {rel_dist['score_3']} clusters")

    print(f"\n🎯 Frame Coverage:")
    coverage = analysis["frame_coverage"]
    print(f"  - Frame range: {coverage['min_frame']} → {coverage['max_frame']}")
    print(f"  - Frame span: {coverage['frame_span']} frames")

    print(f"\n📍 Representative Frames:")
    print(f"  Original: {depth_results['input_data']['original_representative_frames']}")
    print(f"  Expanded: {results['final_representative_frames']}")

    print("="*50)
    print("🎉 DEPTH EXPANSION COMPLETE!")

    print(f"\n📋 SUMMARY:")
    if results['expansion_ratio'] > 1:
        print(f"✅ Successfully expanded from {len(representative_frames)} to {results['total_representative_frames']} frames")
        print(f"📈 {results['expansion_ratio']:.1f}x more detailed frame selection!")
    elif results['expansion_ratio'] == 1:
        print(f"➡️  Same number of frames, but hierarchically organized")
    else:
        print(f"📉 Fewer frames selected: {results['total_representative_frames']} vs {len(representative_frames)}")

    print(f"🎬 Your video now has {results['total_representative_frames']} key representative frames!")



🚀 STARTING DEPTH EXPANSION for video: 0074f737-11cb-497d-8d07-77c3a8127391
📂 Loading features for 0074f737-11cb-497d-8d07-77c3a8127391...
✅ Features loaded: shape torch.Size([180, 1024])
📊 Processing relevance scores...
💾 SAVING DEPTH EXPANSION RESULTS
📁 Output path: ./outputs/depth_expansion_0074f737-11cb-497d-8d07-77c3a8127391.json
📏 File size: 5.12 KB
✅ Results saved successfully!

📊 DEPTH EXPANSION ANALYSIS
🎬 Video: 0074f737-11cb-497d-8d07-77c3a8127391
📊 Total frames: 180
🔢 Feature dimensions: 1024

📈 Expansion Results:
  - Original clusters: 8
  - Original representative frames: 8
  - New representative frames: 86
  - Expansion ratio: 10.75x

⭐ Relevance Distribution:
  - Score 1 (Low): 1 clusters
  - Score 2 (Medium): 1 clusters
  - Score 3 (High): 6 clusters

🎯 Frame Coverage:
  - Frame range: 0 → 177
  - Frame span: 178 frames

📍 Representative Frames:
  Original: [7, 20, 54, 87, 126, 141, 152, 169]
  Expanded: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 20, 22, 30, 31,

# VLM

In [17]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

# -------- Globals --------
vl_model = None
vl_processor = None
vl_device = None

VL_MODEL_ID = "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit"

def ensure_vl_model_loaded(
    model_id: str = VL_MODEL_ID,
    gpu_index: int = 0,
    device_map: str | None = "auto",  # "auto" spreads layers, None keeps manual device handling
    torch_dtype: torch.dtype | None = None,  # Keep None for 4-bit; HF sets dtype appropriately
) -> tuple[AutoModelForVision2Seq, AutoProcessor, str]:
    """
    Loads Qwen2.5-VL once and returns (model, processor, device).
    Subsequent calls return the already-loaded singleton.
    """
    global vl_model, vl_processor, vl_device
    if vl_model is not None and vl_processor is not None and vl_device is not None:
        return vl_model, vl_processor, vl_device

    # Resolve target device
    if torch.cuda.is_available():
        vl_device = f"cuda:{gpu_index}"
    else:
        vl_device = "cpu"

    # Load processor
    vl_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    # Load model (4-bit quantized). Prefer device_map="auto" for correct placement.
    load_kwargs = dict(trust_remote_code=True, low_cpu_mem_usage=True)
    if device_map is not None:
        load_kwargs["device_map"] = device_map
    if torch_dtype is not None:
        load_kwargs["torch_dtype"] = torch_dtype

    vl_model = AutoModelForVision2Seq.from_pretrained(model_id, **load_kwargs)
    vl_model.eval()

    # If not using device_map, move the whole model to the resolved device
    if device_map in (None, "none"):
        vl_model = vl_model.to(vl_device)

    return vl_model, vl_processor, vl_device

def unload_vl_model():
    """Optional: free the singleton to reclaim memory."""
    global vl_model, vl_processor, vl_device
    vl_model = None
    vl_processor = None
    vl_device = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


2025-08-09 18:08:38.111926: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754762918.333670      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754762918.399070      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [18]:
model, processor, device = ensure_vl_model_loaded(
    model_id="unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
    gpu_index=0,          # set to 1 if you intend to use cuda:1
    device_map="auto"     # or None to force .to(device)
)

preprocessor_config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


chat_template.json: 0.00B [00:00, ?B/s]



config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/5.97G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]

In [None]:
for video_id in video_ids:

    video_path = f"/kaggle/working/downloaded_videos/{video_id}.mp4"
    json_path = f'/kaggle/working/outputs/depth_expansion_{video_id}.json'
    
    question = questions[video_id]
    
    # Step 1: Extract frames at 1 FPS
    def extract_frames_at_1fps(video_path):
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration_sec = total_frames / fps
        print(f"🎞️ Video FPS: {fps}, total frames: {total_frames}, duration: {duration_sec:.2f}s")
    
        frames = {}
        frame_idx = 0
        saved_idx = 0
    
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % int(fps) == 0:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames[saved_idx] = Image.fromarray(rgb)
                saved_idx += 1
            frame_idx += 1
    
        cap.release()
        print(f"✅ Extracted {saved_idx} frames at ~1 FPS")
        return frames
    
    # Step 2: Load important frame indices from JSON
    with open(json_path, 'r') as f:
        data = json.load(f)
    important_indices = data['results']['final_representative_frames']
    print(f"📌 Important frames from JSON: {important_indices}")
    
    # Step 3: Match with extracted frames
    all_frames = extract_frames_at_1fps(video_path)
    important_images = []
    missing = []
    
    for idx in important_indices:
        if idx in all_frames:
            important_images.append(all_frames[idx])
        else:
            missing.append(idx)
    
    print(f"🖼️ Found {len(important_images)} important frames")
    if missing:
        print(f"⚠️ Missing frames (not in 1FPS output): {missing}")
    
    # Step 4: Helper - batching
    def chunk_list(lst, size):
        for i in range(0, len(lst), size):
            yield lst[i:i + size]
    
    # Step 5: Run inference in batches
    batch_size = 10  # tweak this if OOM persists
    answers = []
    
    for i, batch in enumerate(chunk_list(important_images, batch_size)):
        print(f"🔄 Running batch {i+1} with {len(batch)} images")
    
        content = [{"type": "image", "image": img} for img in batch]
        content.append({"type": "text", "text": question})
        messages = [{"role": "user", "content": content}]
    
        try:
            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(device)
    
            outputs = model.generate(**inputs, max_new_tokens=50)  # Reduce tokens for memory
            answer = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:])
            answers.append(answer.strip())
    
            torch.cuda.empty_cache()  # Free up memory
    
        except torch.cuda.OutOfMemoryError:
            print("❌ OOM in batch. Skipping or lowering batch size.")
            torch.cuda.empty_cache()
            continue


    
    # Step 6: Combine and print final answer
    final_answer = "\n".join(answers)

    print(f"\n🎥 Video ID: {video_id}")
    print(f"❓ Question: {question}")
    print("\n=== 🧠 Answers ===")
    print(final_answer)
    print("\n\n\n")

📌 Important frames from JSON: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 20, 22, 30, 31, 34, 40, 41, 53, 56, 62, 63, 64, 65, 67, 68, 70, 72, 74, 75, 79, 80, 83, 85, 87, 87, 88, 89, 90, 91, 92, 93, 102, 105, 107, 111, 114, 115, 116, 117, 118, 119, 121, 122, 124, 126, 126, 127, 128, 131, 140, 141, 141, 149, 151, 152, 153, 154, 155, 155, 156, 158, 160, 161, 162, 165, 168, 169, 169, 174, 175, 177]
🎞️ Video FPS: 30.0, total frames: 5400, duration: 180.00s
✅ Extracted 180 frames at ~1 FPS
🖼️ Found 86 important frames
🔄 Running batch 1 with 10 images
🔄 Running batch 2 with 10 images
🔄 Running batch 3 with 10 images
🔄 Running batch 4 with 10 images
🔄 Running batch 5 with 10 images
🔄 Running batch 6 with 10 images
🔄 Running batch 7 with 10 images
🔄 Running batch 8 with 10 images
🔄 Running batch 9 with 6 images

🎥 Video ID: 0074f737-11cb-497d-8d07-77c3a8127391
❓ Question: Question:
Taking into account all the actions performed by c, what can you deduce about the primary objective and foc