### Dependencies

In [None]:
!git clone https://github.com/subhadarship/kmeans_pytorch
%cd kmeans_pytorch

In [None]:
!pip install --editable .
!pip install --upgrade transformers accelerate bitsandbytes

### Copying a Single Video

In [None]:
import shutil
import os
from pathlib import Path

source = '/kaggle/input/egoschema/fed08b9b-7cbf-4f96-86a0-567a96b80125.mp4'

# Make sure the working-directory target folder exists (no error if it already does).
destination_folder = Path('/kaggle/working/videos')
destination_folder.mkdir(parents=True, exist_ok=True)

# Copy the video into that folder, keeping its original file name.
destination_file = os.path.join(destination_folder, os.path.basename(source))
shutil.copy(source, destination_file)

### Extract Images

In [None]:
import cv2
from pathlib import Path
from tqdm import tqdm
import json

def load_json(fn):
    """Return the object decoded from the JSON file at path ``fn``."""
    with open(fn, 'r') as src:
        return json.load(src)

def save_json(data, fn, indent=4):
    """Serialize ``data`` to the file ``fn`` as JSON, pretty-printed with ``indent`` spaces."""
    with open(fn, 'w') as dst:
        json.dump(data, dst, indent=indent)

input_base_path = Path('/kaggle/working/videos')
output_base_path = Path('/kaggle/working/extracted_frames')
output_base_path.mkdir(parents=True, exist_ok=True)

fps = 1  # target sampling rate: frames kept per second of video

video_paths = list(input_base_path.iterdir())

for video_fp in tqdm(video_paths):
    vidcap = cv2.VideoCapture(str(video_fp))  # read videos frame by frame
    try:
        if not vidcap.isOpened():
            print(f"⚠️  Could not open {video_fp}, skipping")
            continue

        fps_ori = vidcap.get(cv2.CAP_PROP_FPS)  # original FPS reported by the container (float)
        if fps_ori <= 0:
            # Without a frame rate we cannot compute a sampling interval.
            print(f"⚠️  Unknown FPS for {video_fp}, skipping")
            continue

        # round() instead of int() so e.g. 29.97 fps samples every 30th frame, and clamp
        # to >= 1 so a video slower than `fps` never produces a zero interval
        # (which would make `count % frame_interval` divide by zero).
        frame_interval = max(1, round(fps_ori / fps))

        # video_fp.stem is the filename without the extension; one folder per video.
        output_path = output_base_path / video_fp.stem
        output_path.mkdir(parents=True, exist_ok=True)

        count = 0
        while True:
            success, image = vidcap.read()
            if not success:
                break
            if count % frame_interval == 0:
                # Frames are named by their original frame index.
                cv2.imwrite(str(output_path / f'{count}.jpg'), image)  # save frame as JPG file
            count += 1
    finally:
        # Release the decoder handle even if reading or writing fails.
        vidcap.release()

### Extract Features

In [None]:
import os
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
from transformers import CLIPImageProcessor, AutoModel, BitsAndBytesConfig

# ----------- Configuration -----------
# Vision encoder used to embed every extracted frame.
MODEL_NAME = "BAAI/EVA-CLIP-8B"
IMAGE_SIZE = 224  # NOTE(review): not referenced in this cell; resizing is done by `processor`
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Input: one sub-folder of extracted frame images per video.
BASE_PATH = Path('/kaggle/working/extracted_frames')
# Output: one `<folder name>.pt` feature file per processed frame folder.
SAVE_PATH = Path('/kaggle/working/extracted_features')
SAVE_PATH.mkdir(parents=True, exist_ok=True)

# Annotation file; only frame folders whose name is one of its keys are encoded.
JSON_PATH = Path('/kaggle/input/egoschema/fullset_anno.json')
MAX_EXAMPLES = 50  # Set limit on number of examples to process
RESUME = True  # Unused, but kept if resuming logic is added later

# ----------- Load model and processor -----------
# Image preprocessing uses the OpenAI CLIP ViT-L/14 processor.
processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
# Load the model weights in 4-bit (bitsandbytes) to reduce memory use.
quant_config = BitsAndBytesConfig(load_in_4bit=True)

# trust_remote_code=True executes modeling code downloaded from the model repository.
# NOTE(review): transformers rejects `.to()` on bitsandbytes 4-bit models with older
# bitsandbytes releases, and 4-bit weights presumably require CUDA — confirm the
# installed versions and that this cell never runs with DEVICE == "cpu".
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    quantization_config=quant_config
).to(DEVICE).eval()

# ----------- Utility Functions -----------

def save_image_features(img_feats: torch.Tensor, name_id: str, save_folder: Path):
    """Write ``img_feats`` with ``torch.save`` to ``<save_folder>/<name_id>.pt``."""
    target = save_folder / f"{name_id}.pt"
    torch.save(img_feats, target)

def load_json(file_path: Path):
    """Read ``file_path`` and return its decoded JSON content."""
    with open(file_path, 'r') as json_file:
        content = json.load(json_file)
    return content

# ----------- Main Processing Loop -----------

def process_images():
    """Encode the extracted frames of each annotated video and save one tensor per video.

    Frame folders are taken from ``BASE_PATH``. Only folders whose name is a key of the
    annotation JSON at ``JSON_PATH`` are processed, up to ``MAX_EXAMPLES`` of them.
    Inside a folder, files named ``<integer>.<ext>`` are encoded in numeric order with
    ``model.encode_image``, and the stacked features are written to
    ``SAVE_PATH/<folder name>.pt``. Folders without such frames are skipped.
    """
    json_data = load_json(JSON_PATH)
    valid_names = set(json_data.keys())

    # Filter before limiting so the progress-bar total matches the work actually done.
    example_dirs = [d for d in BASE_PATH.iterdir() if d.is_dir() and d.name in valid_names]
    example_dirs = example_dirs[:MAX_EXAMPLES]

    # torch.cuda.amp.autocast() is deprecated; torch.autocast with `enabled` only turns
    # mixed precision on when running on CUDA.
    use_amp = DEVICE == "cuda"

    for example_dir in tqdm(example_dirs, desc="Processing sets"):
        # Frames are named by integer frame index; ignore anything else so int() cannot fail.
        image_files = sorted(
            (f for f in example_dir.iterdir() if f.is_file() and f.stem.isdigit()),
            key=lambda f: int(f.stem),
        )
        if not image_files:
            # torch.stack([]) would raise; nothing to encode for this folder.
            print(f"No frames found in {example_dir}, skipping")
            continue

        feature_list = []
        for image_file in image_files:
            # The context manager closes the underlying file once preprocessing is done.
            with Image.open(image_file) as image:
                inputs = processor(images=image, return_tensors="pt").pixel_values.to(DEVICE)

            with torch.no_grad(), torch.autocast(device_type=DEVICE, enabled=use_amp):
                feature_list.append(model.encode_image(inputs))

        # NOTE(review): assumes encode_image returns a (1, D) batch per frame, so
        # stack -> (N, 1, D) and squeeze(1) -> (N, D); confirm for the model in use.
        stacked_feats = torch.stack(feature_list).squeeze(1)
        # Save on CPU so the file can also be loaded on a machine without a GPU.
        save_image_features(stacked_feats.cpu(), example_dir.name, SAVE_PATH)

# ----------- Run -----------
# In a notebook cell __name__ is "__main__", so this runs whenever the cell is executed.
if __name__ == "__main__":
    process_images()

### Adaptive Breadth Expansion

In [None]:
!git clone https://github.com/Ziyang412/VideoTree.git 
%cd /kaggle/working/VideoTree

In [None]:
!pip install groq

In [None]:
import os
from getpass import getpass

# Never hard-code API keys in a notebook that may be shared. Reuse a GROQ_API_KEY that is
# already set (e.g. from the platform's secret store) and only prompt, without echoing
# the input, when it is missing. Assigning a placeholder here would overwrite a real key.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("Enter your Groq API key: ")

In [None]:
# Cell 1: Imports and Configuration
import os
import json
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
from kmeans_pytorch import kmeans
import torch
import re
from groq import Groq

# Import your original modules
from util import *
from prompts import PromptFactory

class SingleVideoConfig:
    """Configuration for single video processing with Groq.

    All settings are plain instance attributes so ``vars(config)`` can be printed and
    stored alongside the results.
    """
    def __init__(self):
        # Paths - UPDATE THESE FOR YOUR SETUP
        self.output_base_path = './outputs'
        self.output_filename = 'single_video_groq_pipeline.json'

        # Single video settings
        self.video_id = 'fed08b9b-7cbf-4f96-86a0-567a96b80125'
        self.frame_feat_path = '/kaggle/working/extracted_features'
        # Derived from frame_feat_path + video_id (the `<video_id>.pt` layout written by the
        # feature-extraction step) so changing the video id cannot leave a stale path behind.
        self.feature_file_path = os.path.join(self.frame_feat_path, f'{self.video_id}.pt')

        # Original pipeline parameters
        self.start_from_scratch = True
        self.num_examples_to_run = 1  # Just one video
        self.prompt_type = 'cap_score'
        self.fewshot_example_path = '/kaggle/input/egoschema/few_shot_6.json'
        self.data_path = '/kaggle/input/egoschema/fullset_anno.json'  # Keep for prompt setup
        self.max_cluster_num = 32       # upper bound on k for adaptive k-means
        self.init_cluster_num = 4       # first k tried
        self.iter_threshold = 5         # stop once this many frames score 3 (highly relevant)
        self.default_adaptive_rate = 2  # k is multiplied by this after an unsuccessful round
        self.fps = 1.0
        self.num_words_in_sum = 100
        self.save_info = True
        self.save_every = 1
        self.backup_pred_path = ''
        self.disable_eval = True
        self.task = 'sum'
        self.dataset = 'egoschema'
        self.anno_path = ''
        self.captions_file_path = '/kaggle/input/egoschema/blip2_fullset.json'

        # Groq model configuration
        self.model = 'llama-3.3-70b-versatile'
        self.temperature = 0.0
        self.max_tokens = 1000

# Instantiate the run configuration and echo every setting so the run is reproducible.
config = SingleVideoConfig()
print("Configuration loaded!")
pprint(vars(config), sort_dicts=True)


# Cell 2: Initialize Groq Model
from groq import Groq

class GroqModel:
    """Wrapper for the Groq chat-completions API.

    Exposes the ``forward(system_prompt, user_prompt) -> (pred, info)`` interface used by
    the rest of the pipeline, plus an optional post-processing hook.
    """
    def __init__(self, model_name='llama-3.3-70b-versatile', temperature=0.0, max_tokens=1000):
        self.client = Groq()  # Uses GROQ_API_KEY environment variable
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.post_process_fn = None

    def set_post_process_fn(self, fn):
        """Register a callable applied to the raw response text to produce ``pred``."""
        self.post_process_fn = fn

    def forward(self, system_prompt, user_prompt):
        """Send one system + user message pair and return ``(pred, info)``.

        Args:
            system_prompt: system message content.
            user_prompt: a string, or a list of strings joined with newlines.

        Returns:
            ``(pred, info)``. ``pred`` is ``post_process_fn(response)`` when a
            post-processing function is set, otherwise the raw response text.
            ``info`` always has ``'response'``, ``'model'`` and ``'tokens_used'``.
            On any failure, ``pred`` is ``[]``, ``info['response']`` is ``''`` and
            ``info['error']`` holds the message. Errors are reported, never raised.
        """
        try:
            combined_prompt = "\n".join(user_prompt) if isinstance(user_prompt, list) else user_prompt

            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": combined_prompt},
                ],
                model=self.model_name,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            # message.content can be None; normalize so post-processing and callers that
            # slice the text do not crash.
            response = chat_completion.choices[0].message.content or ""

            # usage may be absent or None; that must not discard an otherwise valid reply.
            usage = getattr(chat_completion, 'usage', None)
            tokens_used = getattr(usage, 'total_tokens', 0) or 0

            pred = self.post_process_fn(response) if self.post_process_fn else response
        except Exception as e:  # best-effort boundary: API, network and parsing errors are reported
            print(f"Error in Groq API call: {e}")
            return [], {'response': '', 'model': self.model_name, 'tokens_used': 0, 'error': str(e)}

        info = {
            'response': response,
            'model': self.model_name,
            'tokens_used': tokens_used,
        }
        return pred, info

# Initialize Groq model
print("Initializing Groq model...")
model = GroqModel(
    model_name=config.model,
    temperature=config.temperature,
    max_tokens=config.max_tokens
)
print(f"✅ Groq model initialized: {config.model}")

# Test the model with a simple call.
# GroqModel.forward reports API failures through info['error'] instead of raising, so
# success has to be judged from the returned info, not from the absence of an exception.
try:
    test_pred, test_info = model.forward(
        "You are a helpful assistant.",
        "Say 'Hello, I am ready!' if you can understand this."
    )
    if test_info.get('error'):
        print(f"❌ Model test failed: {test_info['error']}")
        print("Please check your GROQ_API_KEY environment variable")
    else:
        print(f"✅ Model test successful: {(test_info.get('response') or '')[:50]}...")
except Exception as e:
    print(f"❌ Model test failed: {e}")
    print("Please check your GROQ_API_KEY environment variable")

# Cell 3: Utility Functions

def load_frame_captions(captions_file_path, video_id):
    """Return ``{frame_index: caption}`` for ``video_id`` from a captions JSON file.

    The per-video entry may be a list (list position = frame index) or a dict whose
    keys are integer strings; keys that are not integers are dropped. Any failure
    (missing file, bad JSON, unexpected structure) is printed and yields ``{}``.
    """
    try:
        all_captions = load_json(captions_file_path)

        if video_id not in all_captions:
            print(f"❌ Video {video_id} not found in captions file")
            return {}

        raw = all_captions[video_id]

        if isinstance(raw, list):
            print(f"✅ Found {len(raw)} frame captions (list format)")
            return dict(enumerate(raw))

        print(f"✅ Found frame captions (dict format)")
        parsed = {}
        for key, caption in raw.items():
            try:
                parsed[int(key)] = caption
            except ValueError:
                continue
        return parsed

    except Exception as e:
        print(f"❌ Error loading captions: {e}")
        return {}

def get_caption_for_frame(frame_idx, video_captions):
    """Return the caption stored for ``frame_idx``, or ``None`` if that frame has none."""
    return video_captions[frame_idx] if frame_idx in video_captions else None

def load_frame_features(video_id, features_folder, map_location="cpu"):
    """Load the frame-feature tensor saved as ``<features_folder>/<video_id>.pt``.

    Args:
        video_id: file stem of the feature file.
        features_folder: directory containing the ``.pt`` files.
        map_location: device the tensor is loaded onto (default ``"cpu"``), so features
            saved from a GPU session also load on CPU-only machines; callers move it
            to their own device afterwards.

    Returns:
        The stored tensor.

    Raises:
        FileNotFoundError: if the feature file does not exist.
    """
    filepath = os.path.join(features_folder, f"{video_id}.pt")
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Feature file not found: {filepath}")
    # weights_only=True refuses to unpickle arbitrary Python objects; the file holds a tensor.
    return torch.load(filepath, map_location=map_location, weights_only=True)

def find_closest_points_per_cluster(features, cluster_ids, cluster_centers):
    """Map each cluster id to a list holding the index of its member nearest the center.

    Nearness is Euclidean distance between a feature row and the cluster's center.
    Every id in ``range(len(cluster_centers))`` is present; clusters with no members
    map to an empty list.
    """
    result = {}
    for cid, center in enumerate(cluster_centers):
        result[cid] = []
        members = (cluster_ids == cid).nonzero(as_tuple=True)[0]
        dists = (features[members] - center).norm(dim=1)
        if dists.numel() == 0:
            continue
        result[cid].append(members[dists.argmin()].item())
    return result

def build_fewshot_examples(fewshot_path, data_path):
    """Return the few-shot examples stored at ``fewshot_path``, or ``[]``.

    ``data_path`` is accepted for interface compatibility but not used. An empty or
    missing path, an unreadable file, or invalid JSON all fall back to an empty list,
    with a message saying why.
    """
    if not fewshot_path or not os.path.exists(fewshot_path):
        print(f"Few-shot file not found: {fewshot_path}, using empty examples")
        return []
    try:
        return load_json(fewshot_path)
    except (OSError, ValueError) as e:  # json.JSONDecodeError subclasses ValueError
        print(f"Error loading few-shot examples, using empty list: {e}")
        return []

def create_dummy_item(video_id, config):
    """Build a placeholder question item for ``video_id``.

    ``config`` is accepted for interface compatibility but not read. The question,
    options, answer and ``duration`` are fixed placeholder values.
    """
    item = dict.fromkeys(('quid', 'uid', 'video_id'), video_id)
    item['question'] = 'What activities does the person perform in this video?'
    item['options'] = ['A) Cooking', 'B) Cleaning', 'C) Reading', 'D) Working', 'E) Other']
    long_options = (
        'Cooking and food preparation',
        'Cleaning and organizing',
        'Reading and studying',
        'Working on computer',
        'Other activities',
    )
    for letter, text in zip('ABCDE', long_options):
        item[f'option{letter}'] = text
    item['answer'] = 'A'
    item['duration'] = 180  # fixed placeholder; not derived from the video here
    item['narration'] = 'Video frame descriptions will be inserted here by the clustering algorithm'
    return item

def create_relevance_prompt(tree_node, video_captions, item):
    """Build the LLM prompt asking for a 1-3 relevance score per captioned frame.

    Args:
        tree_node: frame indices of the cluster representatives.
        video_captions: mapping frame index -> caption text.
        item: question dict; ``item['question']`` and ``item['options']`` are read.

    Returns:
        ``(prompt, valid_frames)``. ``valid_frames`` lists, in ``tree_node`` order, the
        indices whose caption is truthy; only those appear in the prompt. Returns
        ``(None, [])`` when no representative frame has a caption.
    """
    captioned = []
    for idx in tree_node:
        text = get_caption_for_frame(idx, video_captions)
        if text:  # skip frames without a (non-empty) caption
            captioned.append((idx, text))

    if not captioned:
        print("❌ No captions found for any representative frames!")
        return None, []

    n_frames = len(captioned)
    print(f"📝 Using {n_frames} frames with captions out of {len(tree_node)} total")

    valid_frames = [idx for idx, _ in captioned]
    question = item['question']
    option_lines = "\n".join(f"{opt}" for opt in item['options'])
    frame_lines = "\n".join(f"Frame {idx}: {text}" for idx, text in captioned)

    prompt = f"""VIDEO QUESTION: {question}

OPTIONS:
{option_lines}

FRAME DESCRIPTIONS:
{frame_lines}

TASK: Rate each frame's relevance to answering the question on a scale of 1-3:
- 1 = Not relevant (doesn't help answer the question)
- 2 = Somewhat relevant (provides some context)  
- 3 = Highly relevant (directly helps answer the question)

Provide ONLY the relevance scores in this exact format:
frame relevance: [score1, score2, score3, ...]

You must provide exactly {n_frames} scores for the {n_frames} frames described above.

Example: frame relevance: [2, 1, 3, 2, 1]"""

    return prompt, valid_frames

# Tried in order, most specific first; the last one accepts any bracketed integer list.
# Compiled once at import instead of on every call.
_RELEVANCE_PATTERNS = tuple(
    re.compile(p, re.IGNORECASE)
    for p in (
        r"frame relevance:\s*\[([0-9,\s]+)\]",
        r"relevance:\s*\[([0-9,\s]+)\]",
        r"scores:\s*\[([0-9,\s]+)\]",
        r"relevance scores:\s*\[([0-9,\s]+)\]",
        r"\[([0-9,\s]+)\]",  # Any list of numbers
    )
)


def update_relevance_response(text):
    """Extract a list of relevance scores (each clamped to 1..3) from a model reply.

    Every pattern in ``_RELEVANCE_PATTERNS`` is tried in order, and every match of a
    pattern is scanned. The first match that contains at least one integer wins, so
    an empty bracket such as ``[ ]`` does not hide a later valid list.

    Args:
        text: raw model response; ``None`` is treated as an empty reply.

    Returns:
        The scores as ints clamped into [1, 3], or ``[]`` if none are found.
    """
    response = (text or "").strip()
    print(f"🤖 Model response: {response}")

    for pattern in _RELEVANCE_PATTERNS:
        for match in pattern.finditer(response):
            numbers = [int(tok.strip()) for tok in match.group(1).split(',') if tok.strip().isdigit()]
            if not numbers:
                continue
            # Scores outside 1-3 are clamped rather than rejected.
            relevance = [max(1, min(3, score)) for score in numbers]
            print(f"✅ Extracted relevance scores: {relevance}")
            return relevance

    print("❌ No relevance scores found in response")
    return []

# Load frame captions for this video
print(f"\n📝 Loading frame captions...")
video_captions = load_frame_captions(config.captions_file_path, config.video_id)

if not video_captions:
    print("❌ No captions loaded - pipeline will exit")
    # Raise instead of calling exit(): in IPython/Jupyter, exit() is not a reliable way
    # to abort the running cell, and the rest of the pipeline needs captions.
    raise RuntimeError(f"No frame captions available for video {config.video_id}")

print(f"📊 Available frames: {min(video_captions)}-{max(video_captions)}")
print(f"📝 Sample caption: '{next(iter(video_captions.values()))}'")

print("✅ Utility functions loaded!")

# Cell 4: Initialize Prompter

# Prompt factory from the VideoTree code. The adaptive clustering below builds its own
# relevance prompt, but the prompter is still created and passed along to it.
prompter = PromptFactory().get(config.prompt_type)
print(f"✅ Initialized prompter: {config.prompt_type}")

# From here on, each model.forward() reply is parsed into a list of relevance scores.
model.set_post_process_fn(update_relevance_response)
print("✅ Set post-processing function for relevance extraction")

# Placeholder question/options for this single video (no real annotation is used).
item = create_dummy_item(config.video_id, config)
print(f"✅ Created dummy item for video: {config.video_id}")

# Cell 5: Load Video Features

print(f"Loading features from: {config.feature_file_path}")

# Raise instead of calling exit(): in IPython/Jupyter, exit() is not a reliable way to
# abort the running cell, and everything below needs the features.
if not os.path.exists(config.feature_file_path):
    print(f"❌ ERROR: Feature file not found: {config.feature_file_path}")
    raise FileNotFoundError(f"Feature file not found: {config.feature_file_path}")

# Load the exact file that was checked above, so the check and the load cannot refer to
# different paths. map_location="cpu" lets GPU-saved features load anywhere.
frame_feats = torch.load(config.feature_file_path, map_location="cpu")
print(f"✅ Loaded features: shape {frame_feats.shape}, dtype {frame_feats.dtype}")

# Move to appropriate device. Features extracted under autocast may be float16; upcast to
# float32 so the cosine k-means distance computations are numerically stable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frame_feats = frame_feats.float().to(device)
print(f"✅ Moved features to device: {device}")

print(f"\nFeature summary:")
print(f"  - Total frames: {len(frame_feats)}")
print(f"  - Feature dimensions: {frame_feats.shape[1] if len(frame_feats.shape) > 1 else 1}")
print(f"  - Device: {frame_feats.device}")
print(f"  - Memory usage: {frame_feats.element_size() * frame_feats.nelement() / 1024 / 1024:.2f} MB")

# Cell 6: Adaptive Clustering with Groq Model

def adaptive_clustering_with_groq(frame_feats, config, model, prompter, item, video_captions):
    """
    Perform adaptive (breadth) clustering with the Groq model for relevance prediction.

    Starting from ``config.init_cluster_num`` clusters, each round:
      1. runs cosine k-means on ``frame_feats``,
      2. takes the frame closest to each cluster center as its representative,
      3. asks ``model`` to score the captioned representatives from 1 to 3,
      4. stops once at least ``config.iter_threshold`` frames score 3; otherwise it
         multiplies the cluster count by ``config.default_adaptive_rate``.
    The loop also stops after ``config.max_cluster_num``, and when the cluster count
    exceeds the number of frames.

    Args:
        frame_feats: feature tensor with one row per frame.
        config: run configuration (cluster counts, threshold, adaptive rate).
        model: object with ``forward(system_prompt, user_prompt) -> (pred, info)``.
        prompter: not used here; kept so the signature matches the original pipeline.
        item: question dict passed to ``create_relevance_prompt``.
        video_captions: mapping frame index -> caption.

    Returns:
        ``(representative_frames, cluster_assignments, frame_relevance,
        high_relevance_count, clustering_results)`` taken from the last recorded round,
        or ``([], [], [], 0, [])`` if no round produced a result.

    Raises:
        ValueError: if ``config.default_adaptive_rate <= 1`` or
            ``config.init_cluster_num < 1``. Either would keep the retry loop on the
            same cluster count forever.
    """
    if config.default_adaptive_rate <= 1:
        raise ValueError(f"default_adaptive_rate must be > 1, got {config.default_adaptive_rate}")
    if config.init_cluster_num < 1:
        raise ValueError(f"init_cluster_num must be >= 1, got {config.init_cluster_num}")

    cluster_num = config.init_cluster_num
    device = frame_feats.device
    num_frames = len(frame_feats)

    print(f"🚀 Starting adaptive clustering with Groq model")
    print(f"📊 Feature tensor shape: {frame_feats.shape}")
    print(f"🖥️  Device: {device}")

    clustering_results = []

    while cluster_num <= config.max_cluster_num:
        # k-means needs at least as many points as clusters. Every larger attempt would
        # fail the same way, so stop instead of retrying.
        if cluster_num > num_frames:
            print(f"🛑 Cannot form {cluster_num} clusters from {num_frames} frames, stopping")
            break

        try:
            print(f"\n{'='*60}")
            print(f"🔄 CLUSTERING ATTEMPT: {cluster_num} clusters")
            print(f"{'='*60}")

            # Perform k-means clustering
            cluster_ids_x, cluster_centers = kmeans(
                X=frame_feats,
                num_clusters=cluster_num,
                distance='cosine',
                device=device
            )

            cluster_ids_x = cluster_ids_x.to(device)
            cluster_centers = cluster_centers.to(device)

            # Find representative frames
            closest_points_idx_per_cluster = find_closest_points_per_cluster(
                frame_feats, cluster_ids_x, cluster_centers
            )
            tree_node = sorted(
                idx for members in closest_points_idx_per_cluster.values() for idx in members
            )

            # The per-cluster dict always has one key per cluster, so test the flattened
            # frame list instead: it is empty only when every cluster is empty.
            if not tree_node:
                print(f"❌ No valid clusters found, increasing to {cluster_num * config.default_adaptive_rate}")
                cluster_num *= config.default_adaptive_rate
                continue

            cluster_ids_list = cluster_ids_x.tolist()

            print(f"📍 Representative frames: {tree_node}")
            print(f"📊 Unique clusters found: {len(set(cluster_ids_list))}")

            # Create custom relevance prompt instead of using original prompter
            prompt, valid_frames = create_relevance_prompt(tree_node, video_captions, item)

            if not prompt:
                print(f"❌ No valid prompt created, increasing clusters")
                cluster_num *= config.default_adaptive_rate
                continue

            print(f"📋 Created relevance prompt for {len(valid_frames)} frames")

            # Model inference with Groq
            print(f"🤖 Calling Groq model...")
            pred, info = model.forward(
                "You are an expert video analyst. Analyze frame descriptions and rate their relevance.",
                prompt
            )

            print(f"📤 Model response received")
            print(f"🎯 Extracted prediction: {pred}")

            frame_relevance = pred

            # Count high relevance frames (score = 3); only trust a score list that lines
            # up one-to-one with the frames in the prompt.
            if isinstance(frame_relevance, list) and len(frame_relevance) == len(valid_frames):
                high_relevance_frame_num = frame_relevance.count(3)
                print(f"📈 Relevance scores: {frame_relevance}")
                print(f"🎯 High relevance frames (score=3): {high_relevance_frame_num}")
            else:
                high_relevance_frame_num = 0
                print(f"⚠️  Relevance extraction failed or mismatch. Expected {len(valid_frames)}, got {len(frame_relevance) if isinstance(frame_relevance, list) else 'non-list'}")

            print(f"🎚️  Threshold: {config.iter_threshold}")

            clustering_results.append({
                'num_clusters': cluster_num,
                'actual_clusters': len(set(cluster_ids_list)),
                'representative_frames': tree_node,
                'valid_frames_with_captions': valid_frames,
                'cluster_assignments': cluster_ids_list,
                'frame_relevance': frame_relevance,
                'high_relevance_count': high_relevance_frame_num,
                'prompt': prompt,
                'model_response': info.get('response', ''),
                'tokens_used': info.get('tokens_used', 0)
            })

            # Check stopping condition
            if high_relevance_frame_num >= config.iter_threshold:
                print(f"✅ Found sufficient high-relevance frames ({high_relevance_frame_num} >= {config.iter_threshold})")
                print(f"🏁 Stopping clustering - SUCCESS!")
                break
            if cluster_num >= config.max_cluster_num:
                print(f"🛑 Reached max clusters ({config.max_cluster_num}), stopping")
                break
            print(f"📉 Not enough high-relevance frames ({high_relevance_frame_num} < {config.iter_threshold})")
            next_cluster_num = cluster_num * config.default_adaptive_rate
            print(f"🔄 Increasing clusters: {cluster_num} → {next_cluster_num}")
            cluster_num = next_cluster_num

        except Exception as e:
            print(f"❌ Clustering failed with {cluster_num} clusters: {e}")
            import traceback
            traceback.print_exc()
            cluster_num *= config.default_adaptive_rate
            continue

    print(f"\n{'='*60}")
    print(f"🏆 CLUSTERING COMPLETE!")
    print(f"{'='*60}")

    # Return the last recorded clustering result
    if clustering_results:
        final_result = clustering_results[-1]
        return (final_result['representative_frames'],
                final_result['cluster_assignments'],
                final_result['frame_relevance'],
                final_result['high_relevance_count'],
                clustering_results)
    return [], [], [], 0, []

print("✅ Adaptive clustering function loaded!")

# Cell 7: Run the Complete Pipeline

print("🚀 STARTING COMPLETE PIPELINE")
print("="*80)

# Run adaptive clustering with Groq model
(representative_frames, cluster_assignments,
 frame_relevance, high_relevance_count, all_results) = adaptive_clustering_with_groq(
    frame_feats, config, model, prompter, item, video_captions
)

print("\n" + "="*80)
print("📊 FINAL RESULTS")
print("="*80)

# Collect everything needed to reproduce / inspect the run into one JSON-serializable dict.
result = {
    'video_id': config.video_id,
    'feature_file_path': config.feature_file_path,
    'total_frames': len(frame_feats),
    'feature_dimensions': frame_feats.shape[1] if len(frame_feats.shape) > 1 else 1,
    'final_result': {
        'representative_frames': representative_frames,
        'cluster_assignments': cluster_assignments,
        'frame_relevance': frame_relevance,
        'high_relevance_count': high_relevance_count,
        'num_clusters': len(set(cluster_assignments)) if cluster_assignments else 0,
        'passed_threshold': high_relevance_count >= config.iter_threshold
    },
    'all_clustering_attempts': all_results,
    'config': {k: v for k, v in vars(config).items() if not k.startswith('_')},
    'item_data': item,
    'model_info': {
        'model_name': config.model,
        'total_tokens_used': sum(attempt.get('tokens_used', 0) for attempt in all_results)
    }
}

# Print summary
print(f"🎬 Video ID: {result['video_id']}")
print(f"📊 Total frames: {result['total_frames']}")
print(f"🔢 Feature dimensions: {result['feature_dimensions']}")

final = result['final_result']
print(f"🎯 Final clusters: {final['num_clusters']}")
print(f"📍 Representative frames: {len(final['representative_frames'])}")
print(f"📋 Frame indices: {final['representative_frames']}")
print(f"⭐ High relevance count: {final['high_relevance_count']}")
print(f"✅ Passed threshold ({config.iter_threshold}): {final['passed_threshold']}")
print(f"📈 Frame relevance scores: {final['frame_relevance']}")

# Show clustering progression
print(f"\n📈 Clustering progression:")
attempts = result['all_clustering_attempts']
for i, attempt in enumerate(attempts):
    status = "🏆 FINAL" if i == len(attempts) - 1 else "➡️  continue"
    tokens = attempt.get('tokens_used', 0)
    print(f"  Attempt {i+1}: {attempt['num_clusters']} clusters → "
          f"{attempt['high_relevance_count']} high-relevance frames ({tokens} tokens) {status}")

total_tokens = result['model_info']['total_tokens_used']
print(f"\n🔤 Total tokens used: {total_tokens}")

# all_results is empty when every clustering attempt failed; don't report success then.
if all_results:
    print(f"\n✅ Pipeline completed successfully!")
else:
    print(f"\n❌ Pipeline finished without any successful clustering attempt")
print("="*80)

# Cell 8: Save Results

print("💾 SAVING RESULTS")
print("="*50)

# Create the output directory (no-op if it already exists) and write the results.
os.makedirs(config.output_base_path, exist_ok=True)
output_path = os.path.join(config.output_base_path, config.output_filename)
save_json(result, output_path)

print(f"📁 Output directory: {config.output_base_path}")
print(f"📄 Output file: {config.output_filename}")
print(f"🔗 Full path: {output_path}")

# Only report success when the file is actually on disk.
if os.path.exists(output_path):
    file_size = os.path.getsize(output_path) / 1024  # KB
    print(f"📏 File size: {file_size:.2f} KB")
    print(f"\n✅ Results saved successfully!")
else:
    print(f"\n❌ Results file was not written: {output_path}")

# Show a short summary of what was saved
print(f"\n📋 Results summary:")
print(f"  - Video: {result['video_id']}")
print(f"  - Frames: {result['total_frames']}")
print(f"  - Clusters: {result['final_result']['num_clusters']}")
print(f"  - Representative frames: {len(result['final_result']['representative_frames'])}")
print(f"  - High relevance: {result['final_result']['high_relevance_count']}")
print(f"  - Success: {result['final_result']['passed_threshold']}")

print("="*50)
# Report the model actually configured rather than a hard-coded name.
print(f"🎉 ALL DONE! Your video has been processed with Groq + {config.model}!")

### Depth Expansion

In [None]:
# Cell 1: Depth Expansion - Imports and Functions

import numpy as np
import torch
import torch.nn.functional as F
from scipy.cluster.hierarchy import linkage, fcluster
import json
import os
from pathlib import Path
from tqdm import tqdm

def hierarchical_clustering_with_external_primary(video_features, cluster_ids, relevance_scores, num_subclusters=5, num_subsubclusters=5):
    """
    Perform hierarchical clustering based on relevance scores:
    - Score 1: Keep only primary cluster
    - Score 2: Split into subclusters
    - Score 3: Split into sub-subclusters

    Subclusters are built with Ward linkage (``scipy.cluster.hierarchy``), cut into at
    most ``num_subclusters`` / ``num_subsubclusters`` groups.

    Args:
        video_features: per-frame features indexable by a list of row indices (NumPy
            array or torch tensor; tensors are moved to CPU NumPy because SciPy needs
            host arrays).
        cluster_ids: primary cluster id of every frame (list, array or tensor of ints).
        relevance_scores: relevance score per primary cluster id. Clusters without a
            score are treated as score 3.
        num_subclusters: maximum number of subclusters for score 2/3 clusters.
        num_subsubclusters: maximum number of sub-subclusters per subcluster of a
            score 3 cluster.

    Returns:
        dict mapping every id in ``range(max(cluster_ids) + 1)`` to one of:
          - a list of frame indices (score 1, or a cluster with fewer than 2 frames),
          - ``{sub_id: [frame indices]}`` (score 2),
          - ``{sub_id: {subsub_id: [frame indices]}}`` (score 3),
          - ``{}`` for ids that have no frames.
        Returns ``{}`` when ``cluster_ids`` is empty.
    """
    # Tensors hash by identity, so set() over a tensor would not deduplicate ids.
    if hasattr(cluster_ids, 'tolist'):
        cluster_ids = cluster_ids.tolist()
    cluster_ids = list(cluster_ids)
    if not cluster_ids:
        return {}

    if isinstance(video_features, torch.Tensor):
        video_features = video_features.detach().float().cpu().numpy()
    video_features = np.asarray(video_features)

    clusters = {i: {} for i in range(0, max(cluster_ids) + 1)}

    for cluster_id in sorted(set(cluster_ids)):
        primary_indices = [i for i, x in enumerate(cluster_ids) if x == cluster_id]

        score = relevance_scores[cluster_id] if cluster_id < len(relevance_scores) else 3

        # linkage needs at least 2 observations.
        if len(primary_indices) < 2 or score == 1:
            # Low relevance (or too small to split): keep as a single cluster.
            clusters[cluster_id] = primary_indices
            continue

        sub_features = video_features[primary_indices]

        # Create subclusters; fcluster labels start at 1, shift them to start at 0.
        linked_sub = linkage(sub_features, method='ward')
        sub_cluster_labels = fcluster(linked_sub, num_subclusters, criterion='maxclust') - 1

        if score == 2:
            # Medium relevance: split into subclusters
            clusters[cluster_id] = {
                i: [primary_indices[j] for j in np.where(sub_cluster_labels == i)[0]]
                for i in range(0, num_subclusters)
            }
            continue

        # High relevance (score == 3): split every subcluster into sub-subclusters
        clusters[cluster_id] = {}
        for subcluster_id in range(0, num_subclusters):
            sub_indices = np.where(sub_cluster_labels == subcluster_id)[0]
            if len(sub_indices) == 0:
                continue
            if len(sub_indices) == 1:
                # Too small for linkage: keep the lone frame as its own sub-subcluster
                # instead of silently dropping it from the hierarchy.
                clusters[cluster_id][subcluster_id] = {0: [primary_indices[sub_indices[0]]]}
                continue

            subsub_features = sub_features[sub_indices]
            linked_subsub = linkage(subsub_features, method='ward')
            subsub_cluster_labels = fcluster(linked_subsub, num_subsubclusters, criterion='maxclust') - 1

            clusters[cluster_id][subcluster_id] = {}
            for subsubcluster_id in range(0, num_subsubclusters):
                final_indices = sub_indices[np.where(subsub_cluster_labels == subsubcluster_id)[0]]
                clusters[cluster_id][subcluster_id][subsubcluster_id] = [primary_indices[i] for i in final_indices]

    return clusters

def cosine_similarity(points, centroid):
    """Return the cosine *distance* (1 - cosine similarity) from each point to ``centroid``.

    Despite the name, smaller values mean more similar; callers pick the closest point
    with ``argmin``.

    Args:
        points: 2-D tensor, one point per row.
        centroid: 1-D tensor with the same length as a row of ``points``.

    Returns:
        1-D tensor with one distance per row of ``points``. A single point gives
        shape ``(1,)``, not a 0-d tensor.
    """
    points_normalized = F.normalize(points, dim=1)
    centroid_normalized = F.normalize(centroid.unsqueeze(0), dim=1)
    # (N, D) @ (D, 1) -> (N, 1); squeeze only the last axis so N == 1 stays 1-D.
    return 1 - torch.mm(points_normalized, centroid_normalized.T).squeeze(1)

def _closest_to_centroid(x, indices):
    """Return the member of ``indices`` whose row in ``x`` is nearest the members' mean, or None if empty.

    Nearness is measured with ``cosine_similarity`` (a cosine distance, smaller is closer).
    """
    indices = np.asarray(indices, dtype=np.int64)
    if indices.size == 0:
        return None
    points = x[torch.as_tensor(indices, dtype=torch.long)]
    distances = cosine_similarity(points, points.mean(dim=0))
    if distances.numel() == 0:
        return None
    return int(indices[torch.argmin(distances).item()])


def _all_members(cluster_data):
    """Flatten a subcluster dict whose values are index lists or dicts of index lists."""
    members = []
    for sub in cluster_data.values():
        if isinstance(sub, dict):
            for leaf in sub.values():
                members.extend(leaf)
        elif isinstance(sub, list):
            members.extend(sub)
    return members


def find_closest_points_in_temporal_order_subsub(x, clusters, relevance_scores):
    """Pick representative frame indices from hierarchical clusters, in temporal order.

    Args:
        x: Feature tensor whose rows are selected with integer index tensors.
        clusters: ``{cluster_id: data}``. ``data`` is either a list of frame
            indices, or a dict of subclusters whose values are index lists or
            dicts of index lists (sub-subclusters).
        relevance_scores: Sequence indexed by ``cluster_id``. Ids beyond its
            length are treated as relevance 3.

    Selection rules:
        * list-valued cluster: the frame nearest its centroid;
        * dict-valued cluster, relevance 1: one frame for the whole cluster;
        * dict-valued cluster, relevance 2 or 3: one frame for the whole
          cluster plus one per non-empty subcluster / sub-subcluster;
        * dict-valued cluster with any other relevance: skipped (as before).

    Returns:
        Sorted list of unique frame indices. A whole-cluster representative
        can coincide with the representative of one of its sub(-sub)clusters.
        Such frames used to be returned twice, which made downstream steps
        caption the same frame twice and over-count the expansion.
    """
    selected = []

    def _add(indices):
        idx = _closest_to_centroid(x, indices)
        if idx is not None:
            selected.append(idx)

    for cluster_id, cluster_data in clusters.items():
        relevance = relevance_scores[cluster_id] if cluster_id < len(relevance_scores) else 3

        if isinstance(cluster_data, list):  # primary cluster kept whole
            _add(cluster_data)
            continue
        if not isinstance(cluster_data, dict) or relevance not in (1, 2, 3):
            continue

        # Every scored dict cluster contributes a representative of the whole primary cluster.
        _add(_all_members(cluster_data))
        if relevance == 1:
            continue

        # Relevance 2/3: one extra representative per subcluster or sub-subcluster.
        for sub in cluster_data.values():
            if isinstance(sub, dict):
                for leaf in sub.values():
                    _add(leaf)
            elif isinstance(sub, list):
                _add(sub)

    # set() drops duplicate picks; sorted() restores temporal (index) order.
    return sorted(set(selected))

def load_image_features(name_ids, save_folder, map_location=None):
    """Load the object saved as ``<save_folder>/<name_ids>.pt``.

    Args:
        name_ids: Video id, used as the file stem.
        save_folder: Directory containing the ``.pt`` files.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour (tensors go back to the device they were
            saved from). Pass ``"cpu"`` to load GPU-saved features on a
            CPU-only machine.

    Returns:
        The stored object. Callers in this notebook treat it as a 2-D tensor,
        assumed ``(num_frames, feature_dim)`` — confirm against the
        feature-extraction cell.

    Note:
        ``torch.load`` unpickles the file; only load feature files you
        produced yourself.
    """
    filepath = os.path.join(save_folder, f"{name_ids}.pt")
    return torch.load(filepath, map_location=map_location)

def load_json(fn):
    """Read and return the JSON document stored at ``fn``.

    The file is decoded as UTF-8 (the encoding JSON mandates) rather than
    with the platform's locale-dependent default.
    """
    with open(fn, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_json(data, fn, indent=4):
    """Serialise ``data`` as UTF-8 JSON to ``fn``, creating parent directories if needed.

    Args:
        data: JSON-serialisable object.
        fn: Destination path (str or path-like). Missing parent directories
            are created, so relative targets such as ``./outputs/...`` work
            on a fresh run.
        indent: Passed through to ``json.dump``.
    """
    parent = os.path.dirname(fn)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(fn, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=indent)

# Confirmation that the helper definitions in this cell executed without error.
print("✅ Depth expansion functions loaded!")

In [None]:
# Cell 2: Load Previous Results and Setup Paths

# Configuration for depth expansion
class DepthExpansionConfig:
    """Paths and hyper-parameters for the depth-expansion step.

    Every setting can be overridden by keyword. ``DepthExpansionConfig()``
    with no arguments yields exactly the previously hard-coded values.
    """

    def __init__(
        self,
        save_folder='/kaggle/working/extracted_features',
        video_id='fed08b9b-7cbf-4f96-86a0-567a96b80125',
        groq_results_path='./outputs/single_video_groq_pipeline.json',
        output_base_path='./outputs',
        output_filename='depth_expansion_results.json',
        num_subclusters=4,
        num_subsubclusters=4,
    ):
        # Input paths: directory of per-video .pt feature files, and the file stem to load.
        self.save_folder = save_folder
        self.video_id = video_id

        # JSON written by the earlier Groq pipeline (cluster assignments, relevance, frames).
        self.groq_results_path = groq_results_path

        # Output location; the report is written to output_base_path/output_filename.
        self.output_base_path = output_base_path
        self.output_filename = output_filename

        # Hierarchical clustering parameters (maxclust targets for each refinement level).
        self.num_subclusters = num_subclusters
        self.num_subsubclusters = num_subsubclusters

config = DepthExpansionConfig()
print("📋 Depth Expansion Configuration:")
print(f"  - Video ID: {config.video_id}")
print(f"  - Features folder: {config.save_folder}")
print(f"  - Groq results: {config.groq_results_path}")
print(f"  - Output: {config.output_filename}")

# Load previous Groq pipeline results.
# Failures are reported and then re-raised: the later cells need the variables
# defined below. A notebook would otherwise hit a confusing NameError, or
# silently reuse stale values left over from a previous run.
print(f"\n📂 Loading previous Groq results...")
try:
    with open(config.groq_results_path, 'r', encoding='utf-8') as f:
        groq_results = json.load(f)
    
    print("✅ Groq results loaded successfully!")
    
    # Extract needed data
    cluster_assignments = groq_results['final_result']['cluster_assignments']
    frame_relevance = groq_results['final_result']['frame_relevance'] 
    representative_frames = groq_results['final_result']['representative_frames']
    
    print(f"📊 Data extracted:")
    print(f"  - Total frames: {groq_results['total_frames']}")
    print(f"  - Clusters: {len(set(cluster_assignments))}")
    print(f"  - Representative frames: {len(representative_frames)}")
    print(f"  - Frame relevance: {frame_relevance}")
    print(f"  - Cluster assignments: {cluster_assignments[:10]}...")  # Show first 10
    
except FileNotFoundError:
    print(f"❌ Error: Could not find Groq results at {config.groq_results_path}")
    print("Please run the Groq pipeline first!")
    raise
except Exception as e:
    print(f"❌ Error loading Groq results: {e}")
    raise

# Check if feature file exists (informational only; Cell 3 loads it)
feature_path = os.path.join(config.save_folder, f"{config.video_id}.pt")
if os.path.exists(feature_path):
    print(f"✅ Feature file found: {feature_path}")
else:
    print(f"❌ Feature file not found: {feature_path}")

print("\n" + "="*60)

In [None]:
# Cell 3: Run Depth Expansion
# Loads this video's frame features, chooses one relevance score per primary
# cluster, refines the primary clusters from Cell 2 hierarchically, and selects
# representative frames. The resulting indices (closest_points_temporal) are
# row positions in img_feats, returned in ascending order.
# Relies on globals from Cell 2: config, frame_relevance, cluster_assignments,
# representative_frames.

print("🚀 STARTING DEPTH EXPANSION")
print("="*60)

# Load video features
print(f"📂 Loading features for {config.video_id}...")
img_feats = load_image_features(config.video_id, config.save_folder)
img_feats = img_feats.cpu()  # Move to CPU for scipy operations
print(f"✅ Features loaded: shape {img_feats.shape}")

# Use frame relevance as cluster relevance scores
# Map frame relevance to cluster relevance
# NOTE(review): the list is used unchanged and indexed by primary cluster id
# downstream; presumably the Groq pipeline emits one score per cluster in
# cluster-id order — confirm.
print(f"📊 Processing relevance scores...")

# Check if we have frame relevance scores
if isinstance(frame_relevance, list) and len(frame_relevance) > 0:
    print(f"✅ Using extracted relevance scores: {frame_relevance}")
    relevance_scores = frame_relevance
else:
    print("⚠️  No relevance scores found, using default scores...")
    # Create default relevance scores (all medium relevance)
    # assumes cluster ids are 0..k-1 so k scores cover every id — TODO confirm
    num_clusters = len(set(cluster_assignments))
    relevance_scores = [2] * num_clusters  # Default to medium relevance

print(f"📈 Relevance scores: {relevance_scores}")
print(f"🎯 Cluster assignments: {len(cluster_assignments)} total assignments")

# Perform hierarchical clustering with external primary clusters
# (the primary clusters come from cluster_assignments; only sub/sub-sub levels are computed here)
print(f"\n🔄 Performing hierarchical clustering...")
print(f"  - Primary clusters: {len(set(cluster_assignments))}")
print(f"  - Subclusters per cluster: {config.num_subclusters}")
print(f"  - Sub-subclusters per subcluster: {config.num_subsubclusters}")

clusters_info = hierarchical_clustering_with_external_primary(
    img_feats, 
    cluster_assignments, 
    relevance_scores,
    num_subclusters=config.num_subclusters,
    num_subsubclusters=config.num_subsubclusters
)

print(f"✅ Hierarchical clustering complete!")
print(f"📊 Clusters info type: {type(clusters_info)}")

# Find representative points in temporal order
# (one frame per cluster / subcluster / sub-subcluster depending on relevance, sorted by index)
print(f"\n🎯 Finding representative points in temporal order...")
closest_points_temporal = find_closest_points_in_temporal_order_subsub(
    img_feats, 
    clusters_info, 
    relevance_scores
)

print(f"✅ Representative points found!")
print(f"📍 Number of representative points: {len(closest_points_temporal)}")
print(f"🕐 Temporal order: {closest_points_temporal}")

# Compare with original representative frames
print(f"\n📊 COMPARISON:")
print(f"  Original (width expansion): {representative_frames}")
print(f"  New (depth expansion): {closest_points_temporal}")
print(f"  Original count: {len(representative_frames)}")
print(f"  New count: {len(closest_points_temporal)}")

# Calculate expansion ratio (new frame count / original frame count; 0 when there were none)
if len(representative_frames) > 0:
    expansion_ratio = len(closest_points_temporal) / len(representative_frames)
    print(f"  📈 Expansion ratio: {expansion_ratio:.2f}x")
else:
    expansion_ratio = 0
    print(f"  ❌ No original frames to compare")

print("="*60)

In [None]:
# Cell 4: Save Results and Analysis
# Bundles the inputs, configuration and outputs of Cells 2-3 into one JSON
# report, writes it to config.output_base_path/config.output_filename and
# prints a summary.
# Relies on globals from earlier cells: config, img_feats, cluster_assignments,
# representative_frames, relevance_scores, clusters_info, closest_points_temporal.

# Create comprehensive results
depth_results = {
    "video_id": config.video_id,
    "input_data": {
        "total_frames": len(img_feats),
        "feature_dimensions": img_feats.shape[1] if len(img_feats.shape) > 1 else 1,
        "original_clusters": len(set(cluster_assignments)),
        "original_representative_frames": representative_frames,
        "relevance_scores": relevance_scores,
        "cluster_assignments": cluster_assignments
    },
    "depth_expansion_config": {
        "num_subclusters": config.num_subclusters,
        "num_subsubclusters": config.num_subsubclusters
    },
    "results": {
        "hierarchical_clusters": len(clusters_info),
        "final_representative_frames": closest_points_temporal,
        "expansion_ratio": len(closest_points_temporal) / len(representative_frames) if len(representative_frames) > 0 else 0,
        "total_representative_frames": len(closest_points_temporal)
    },
    "analysis": {
        "relevance_distribution": {
            "score_1": relevance_scores.count(1) if isinstance(relevance_scores, list) else 0,
            "score_2": relevance_scores.count(2) if isinstance(relevance_scores, list) else 0,
            "score_3": relevance_scores.count(3) if isinstance(relevance_scores, list) else 0
        },
        "frame_coverage": {
            "min_frame": min(closest_points_temporal) if closest_points_temporal else 0,
            "max_frame": max(closest_points_temporal) if closest_points_temporal else 0,
            "frame_span": max(closest_points_temporal) - min(closest_points_temporal) + 1 if closest_points_temporal else 0
        }
    }
}

# Save results. Create the output directory first so a fresh run does not
# fail when ./outputs does not exist yet.
output_path = os.path.join(config.output_base_path, config.output_filename)
if config.output_base_path:
    os.makedirs(config.output_base_path, exist_ok=True)
save_json(depth_results, output_path)

print("💾 SAVING DEPTH EXPANSION RESULTS")
print("="*50)
print(f"📁 Output path: {output_path}")

# Show file size
if os.path.exists(output_path):
    file_size = os.path.getsize(output_path) / 1024  # KB
    print(f"📏 File size: {file_size:.2f} KB")

print(f"✅ Results saved successfully!")

# Print detailed analysis
print(f"\n📊 DEPTH EXPANSION ANALYSIS")
print("="*50)

results = depth_results["results"]
analysis = depth_results["analysis"]

print(f"🎬 Video: {depth_results['video_id']}")
print(f"📊 Total frames: {depth_results['input_data']['total_frames']}")
print(f"🔢 Feature dimensions: {depth_results['input_data']['feature_dimensions']}")

print(f"\n📈 Expansion Results:")
print(f"  - Original clusters: {depth_results['input_data']['original_clusters']}")
print(f"  - Original representative frames: {len(depth_results['input_data']['original_representative_frames'])}")
print(f"  - New representative frames: {results['total_representative_frames']}")
print(f"  - Expansion ratio: {results['expansion_ratio']:.2f}x")

print(f"\n⭐ Relevance Distribution:")
rel_dist = analysis["relevance_distribution"]
print(f"  - Score 1 (Low): {rel_dist['score_1']} clusters")
print(f"  - Score 2 (Medium): {rel_dist['score_2']} clusters")
print(f"  - Score 3 (High): {rel_dist['score_3']} clusters")

print(f"\n🎯 Frame Coverage:")
coverage = analysis["frame_coverage"]
print(f"  - Frame range: {coverage['min_frame']} → {coverage['max_frame']}")
print(f"  - Frame span: {coverage['frame_span']} frames")

print(f"\n📍 Representative Frames:")
print(f"  Original: {depth_results['input_data']['original_representative_frames']}")
print(f"  Expanded: {results['final_representative_frames']}")

print("="*50)
print("🎉 DEPTH EXPANSION COMPLETE!")

# Quick summary
print(f"\n📋 SUMMARY:")
if results['expansion_ratio'] > 1:
    print(f"✅ Successfully expanded from {len(representative_frames)} to {results['total_representative_frames']} frames")
    print(f"📈 {results['expansion_ratio']:.1f}x more detailed frame selection!")
elif results['expansion_ratio'] == 1:
    print(f"➡️  Same number of frames, but hierarchically organized")
else:
    print(f"📉 Fewer frames selected: {results['total_representative_frames']} vs {len(representative_frames)}")

print(f"🎬 Your video now has {results['total_representative_frames']} key representative frames!")

### LLM Reasoning

In [None]:
import json
import os
import base64
import cv2
from groq import Groq

class SimpleVideoQA:
    """Caption the depth-expansion frames with a Groq vision model, then answer a question from the captions.

    The frame indices in the depth-expansion results are positions in the
    image-feature sequence (``total_frames`` there is ``len(img_feats)``).
    They are not raw video frame numbers. The extraction cell samples at
    ``fps = 1`` and keeps every ``int(int(video_fps) / fps)``-th raw frame.
    ``extract_frame`` therefore maps an index back with the same formula
    before seeking. This assumes features were computed one per saved frame,
    in extraction order — confirm against the feature-extraction cell. Pass
    ``sample_fps=None`` to treat indices as raw frame numbers (old behaviour).
    """

    def __init__(self, sample_fps=1):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.vision_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        self.text_model = "llama-3.3-70b-versatile"
        # Rate (frames per second) used when frames were extracted for feature computation.
        self.sample_fps = sample_fps

    @staticmethod
    def _to_video_frame(frame_num, video_fps, sample_fps):
        """Map a sampled-frame index to the raw video frame it was taken from.

        Mirrors the extraction cell: ``frame_interval = int(int(fps_ori) / fps)``,
        and the k-th saved frame is raw frame ``k * frame_interval``. An unknown
        fps (0) falls back to an interval of 1.
        """
        if sample_fps is None:
            return int(frame_num)
        interval = int(int(video_fps) / sample_fps) if sample_fps > 0 else 0
        return int(frame_num) * max(interval, 1)

    def load_frames(self):
        """Load the frames from depth expansion results; returns ``(frame_indices, video_id)``."""
        with open('/kaggle/working/VideoTree/outputs/depth_expansion_results.json', 'r', encoding='utf-8') as f:
            data = json.load(f)

        frames = data['results']['final_representative_frames']
        video_id = data['video_id']

        print(f"📍 Loaded {len(frames)} frames for video {video_id}")
        if frames:  # min()/max() raise on an empty list
            print(f"📊 Frame range: {min(frames)} → {max(frames)}")
        return frames, video_id

    def extract_frame(self, video_path, frame_num):
        """Save the frame for sampled index ``frame_num`` to a temp JPEG; return its path or None."""
        cap = cv2.VideoCapture(video_path)
        try:
            video_frame = self._to_video_frame(frame_num, cap.get(cv2.CAP_PROP_FPS), self.sample_fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, video_frame)
            ret, frame = cap.read()
        finally:
            cap.release()

        if ret:
            temp_path = f"./temp_frame_{frame_num}.jpg"
            if cv2.imwrite(temp_path, frame):
                return temp_path
        return None

    def caption_frame(self, frame_path, frame_num):
        """Caption one frame with Groq vision; the temp file is always removed.

        Returns the caption text, or ``"[Caption failed]"`` on any error. The
        caller already prefixes ``"Frame N: "``, so the placeholder carries no
        prefix of its own.
        """
        try:
            # Encode image
            with open(frame_path, "rb") as f:
                base64_image = base64.b64encode(f.read()).decode('utf-8')

            # Get caption
            response = self.client.chat.completions.create(
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe what the person is doing in this frame."},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                    ]
                }],
                model=self.vision_model,
                max_tokens=100
            )
            return response.choices[0].message.content.strip()

        except Exception as e:
            print(f"Error captioning frame {frame_num}: {e}")
            return "[Caption failed]"
        finally:
            # Cleanup on both success and failure
            if os.path.exists(frame_path):
                os.remove(frame_path)

    def caption_all_frames(self, frames, video_path):
        """Caption every index in ``frames``; failures produce a placeholder entry instead of aborting."""
        print(f"🎬 Captioning {len(frames)} frames...")

        captions = []
        for i, frame_num in enumerate(frames):
            print(f"  Frame {i+1}/{len(frames)} (frame {frame_num})", end='\r')

            frame_path = self.extract_frame(video_path, frame_num)
            if frame_path:
                caption = self.caption_frame(frame_path, frame_num)
                captions.append(f"Frame {frame_num}: {caption}")
            else:
                captions.append(f"Frame {frame_num}: [Extraction failed]")

        print(f"\n✅ Captioned {len(captions)} frames")
        return captions

    def answer_question(self, captions, question):
        """Send captions + question to the text LLM; returns ``(answer, total_tokens)``."""
        print(f"🤖 Answering question...")

        # Join all captions
        all_captions = "\n".join(captions)

        # Create prompt
        prompt = f"""Here are descriptions of frames from a video:

{all_captions}

Question: {question}

Answer the question based on what you see in these frame descriptions."""

        # Get answer
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are analyzing video frames to answer questions."},
                {"role": "user", "content": prompt}
            ],
            model=self.text_model,
            max_tokens=300
        )

        answer = response.choices[0].message.content.strip()
        # usage may be absent or None depending on the response; report 0 tokens then.
        usage = getattr(response, 'usage', None)
        tokens = usage.total_tokens if usage is not None else 0

        print(f"✅ Answer generated ({tokens} tokens)")
        return answer, tokens

    def save_results(self, frames, captions, question, answer, tokens):
        """Write frames, captions, question and answer to the outputs JSON."""
        results = {
            'total_frames': len(frames),
            'frame_numbers': frames,
            'captions': captions,
            'question': question,
            'answer': answer,
            'tokens_used': tokens
        }

        os.makedirs('/kaggle/working/VideoTree/outputs', exist_ok=True)
        with open('/kaggle/working/VideoTree/outputs/simple_qa_results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print("💾 Results saved to /kaggle/working/VideoTree/outputs/simple_qa_results.json")

def run_simple_qa(video_path, question):
    """Drive the caption-then-answer pipeline on the depth-expansion frames and print a summary.

    Steps: load the selected frame indices, caption each frame, ask the text
    model ``question`` over the captions, persist everything, then report.
    """
    separator = "=" * 40

    print("🚀 SIMPLE VIDEO QA PIPELINE")
    print(separator)

    pipeline = SimpleVideoQA()

    selected_frames, _video_id = pipeline.load_frames()
    frame_captions = pipeline.caption_all_frames(selected_frames, video_path)
    answer, tokens = pipeline.answer_question(frame_captions, question)
    pipeline.save_results(selected_frames, frame_captions, question, answer, tokens)

    summary_lines = (
        f"\n🎉 DONE!",
        separator,
        f"📊 Frames: {len(selected_frames)}",
        f"❓ Question: {question}",
        f"💬 Answer: {answer}",
        f"🔤 Tokens: {tokens}",
    )
    for line in summary_lines:
        print(line)

# Usage
if __name__ == "__main__":
    # Single hard-coded EgoSchema clip and question for a quick end-to-end run.
    video_path = "/kaggle/input/egoschema/fed08b9b-7cbf-4f96-86a0-567a96b80125.mp4"
    question = "What is the person doing in this video?"

    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
    else:
        run_simple_qa(video_path, question)

### Uniform Sampling - 85 frames

In [None]:
import json
import os
import base64
import cv2
import numpy as np
from groq import Groq

class UniformSamplingQA:
    """Baseline: caption uniformly spaced frames with Groq vision, then answer a question from the captions.

    Frame indices are interpreted like the depth-expansion indices: positions
    in the sequence sampled at ``sample_fps``. The extraction cell uses
    ``fps = 1`` and keeps every ``int(int(video_fps) / fps)``-th raw frame.
    They are mapped back to raw video frames before seeking. The default
    ``total_frames=180`` used by the driver is presumably the 1-fps frame
    count of a 3-minute clip — confirm. Pass ``sample_fps=None`` to treat
    indices as raw frame numbers (old behaviour).
    """

    def __init__(self, sample_fps=1):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.vision_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        self.text_model = "llama-3.3-70b-versatile"
        # Rate (frames per second) at which the indexed frame sequence was sampled.
        self.sample_fps = sample_fps

    @staticmethod
    def _to_video_frame(frame_num, video_fps, sample_fps):
        """Map a sampled-frame index to the raw video frame it was taken from.

        Mirrors the extraction cell: ``frame_interval = int(int(fps_ori) / fps)``,
        and the k-th saved frame is raw frame ``k * frame_interval``. An unknown
        fps (0) falls back to an interval of 1.
        """
        if sample_fps is None:
            return int(frame_num)
        interval = int(int(video_fps) / sample_fps) if sample_fps > 0 else 0
        return int(frame_num) * max(interval, 1)

    def get_uniform_frames(self, total_frames, num_samples=85):
        """Return ``num_samples`` evenly spaced indices in ``[0, total_frames)``.

        If ``num_samples >= total_frames`` every index is returned (step 1).
        Non-positive inputs yield an empty list. Previously the first case
        crashed with ``UnboundLocalError`` because ``step`` was never assigned.
        """
        if total_frames <= 0 or num_samples <= 0:
            frames = []
            step = 0.0
        elif num_samples >= total_frames:
            frames = list(range(total_frames))
            step = 1.0
        else:
            # Calculate step size for uniform sampling; clamp so we never exceed the last frame
            step = total_frames / num_samples
            frames = [min(int(i * step), total_frames - 1) for i in range(num_samples)]

        print(f"📍 Generated {len(frames)} uniform frames from {total_frames} total frames")
        if frames:  # min()/max() raise on an empty list
            print(f"📊 Frame range: {min(frames)} → {max(frames)}")
        print(f"📏 Step size: ~{step:.2f}")

        return frames

    def extract_frame(self, video_path, frame_num):
        """Save the frame for sampled index ``frame_num`` to a temp JPEG; return its path or None."""
        cap = cv2.VideoCapture(video_path)
        try:
            video_frame = self._to_video_frame(frame_num, cap.get(cv2.CAP_PROP_FPS), self.sample_fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, video_frame)
            ret, frame = cap.read()
        finally:
            cap.release()

        if ret:
            temp_path = f"./temp_uniform_frame_{frame_num}.jpg"
            if cv2.imwrite(temp_path, frame):
                return temp_path
        return None

    def caption_frame(self, frame_path, frame_num):
        """Caption one frame with Groq vision; the temp file is always removed.

        Returns the caption text, or ``"[Caption failed]"`` on any error. The
        caller already prefixes ``"Frame N: "``, so the placeholder carries no
        prefix of its own.
        """
        try:
            # Encode image
            with open(frame_path, "rb") as f:
                base64_image = base64.b64encode(f.read()).decode('utf-8')

            # Get caption
            response = self.client.chat.completions.create(
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe what the person is doing in this frame."},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                    ]
                }],
                model=self.vision_model,
                max_tokens=100
            )
            return response.choices[0].message.content.strip()

        except Exception as e:
            print(f"Error captioning frame {frame_num}: {e}")
            return "[Caption failed]"
        finally:
            # Cleanup on both success and failure
            if os.path.exists(frame_path):
                os.remove(frame_path)

    def caption_all_frames(self, frames, video_path):
        """Caption every index in ``frames``; failures produce a placeholder entry instead of aborting."""
        print(f"🎬 Captioning {len(frames)} uniformly sampled frames...")

        captions = []
        for i, frame_num in enumerate(frames):
            print(f"  Frame {i+1}/{len(frames)} (frame {frame_num})", end='\r')

            frame_path = self.extract_frame(video_path, frame_num)
            if frame_path:
                caption = self.caption_frame(frame_path, frame_num)
                captions.append(f"Frame {frame_num}: {caption}")
            else:
                captions.append(f"Frame {frame_num}: [Extraction failed]")

        print(f"\n✅ Captioned {len(captions)} frames")
        return captions

    def answer_question(self, captions, question):
        """Send captions + question to the text LLM; returns ``(answer, total_tokens)``."""
        print(f"🤖 Answering question...")

        # Join all captions
        all_captions = "\n".join(captions)

        # Create prompt
        prompt = f"""Here are descriptions of uniformly sampled frames from a video:

{all_captions}

Question: {question}

Answer the question based on what you see in these frame descriptions."""

        # Get answer
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are analyzing uniformly sampled video frames to answer questions."},
                {"role": "user", "content": prompt}
            ],
            model=self.text_model,
            max_tokens=300
        )

        answer = response.choices[0].message.content.strip()
        # usage may be absent or None depending on the response; report 0 tokens then.
        usage = getattr(response, 'usage', None)
        tokens = usage.total_tokens if usage is not None else 0

        print(f"✅ Answer generated ({tokens} tokens)")
        return answer, tokens

    def save_results(self, frames, captions, question, answer, tokens, total_frames):
        """Write the sampling metadata, captions, question and answer to the outputs JSON."""
        results = {
            'sampling_method': 'uniform',
            'total_video_frames': total_frames,
            'sampled_frames': len(frames),
            'frame_numbers': frames,
            'captions': captions,
            'question': question,
            'answer': answer,
            'tokens_used': tokens,
            'sampling_info': {
                # Guard against an empty sample (division by zero / min() of empty list)
                'step_size': total_frames / len(frames) if frames else 0,
                'coverage': f"{min(frames)}-{max(frames)}" if frames else ""
            }
        }

        os.makedirs('/kaggle/working/VideoTree/outputs', exist_ok=True)
        with open('/kaggle/working/VideoTree/outputs/uniform_sampling_qa_results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print("💾 Results saved to /kaggle/working/VideoTree/outputs/uniform_sampling_qa_results.json")

def run_uniform_sampling_qa(video_path, question, total_frames=180, num_samples=85):
    """Baseline driver: caption ``num_samples`` evenly spaced frames, answer ``question``, save and report.

    ``total_frames`` is the length of the index space the uniform samples are
    drawn from.
    """
    rule = "=" * 45

    print("🚀 UNIFORM SAMPLING VIDEO QA PIPELINE")
    print(rule)

    runner = UniformSamplingQA()

    sampled = runner.get_uniform_frames(total_frames, num_samples)
    frame_captions = runner.caption_all_frames(sampled, video_path)
    answer, tokens = runner.answer_question(frame_captions, question)
    runner.save_results(sampled, frame_captions, question, answer, tokens, total_frames)

    # Final report, printed line by line.
    print(f"\n🎉 UNIFORM SAMPLING DONE!")
    print(rule)
    print(f"📊 Total frames: {total_frames}")
    print(f"📍 Sampled frames: {len(sampled)}")
    print(f"📏 Step size: ~{total_frames/len(sampled):.2f}")
    print(f"❓ Question: {question}")
    print(f"💬 Answer: {answer}")
    print(f"🔤 Tokens: {tokens}")

# Usage
if __name__ == "__main__":
    # Same clip and question as the depth-expansion QA, for a like-for-like comparison.
    video_path = "/kaggle/input/egoschema/fed08b9b-7cbf-4f96-86a0-567a96b80125.mp4"
    question = "What is the person doing in this video?"

    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
    else:
        run_uniform_sampling_qa(video_path, question, total_frames=180, num_samples=85)