# PrimateFace Tutorial: Automated Video Timestamping

| GitHub | Paper | Website |
|---|---|---|
| [Code](https://github.com/KordingLab/PrimateFace) | [Preprint](https://www.biorxiv.org/content/10.1101/2025.08.12.669927v2) | [Project](https://primateface.studio/) |

Welcome! This tutorial notebook walks through using a **PrimateFace** detection model to automatically find and timestamp primate faces in videos, creating a 'visibility baseline' for behavioral coding.

### This notebook will:
1. **Set up the environment** to run _mmdetection_ face detection models trained on PrimateFace data.
2. **Download the PrimateFace model** (configuration file and checkpoint).
3. **Load video data** from a local path or a Google Drive link (default).
4. **Calibrate Detection Confidence**: Use an interactive tool to find the optimal face detection sensitivity for your video.
5. **Analyze Video**: Apply this setting to process the entire video and log all face detections.
6. **Visualize Results**: Create a timeline of face visibility.
7. **Export to BORIS**: Automatically convert the results into a `.boris` project file with "face present" events pre-coded.

---
### **Quick Start Instructions**

*   **Set Your Runtime to GPU**: Go to **Runtime > Change runtime type > T4 GPU**. This is essential for performance.
*   **Run Cells Sequentially**: Click the "Play" button on each cell to run it. The first few cells handle setup and may take a few moments. Do not use 'Run All', or you will skip over the interactive threshold picker.
*   **Restarting**: Installing the dependencies requires restarting the runtime. You can do this by clicking the "Runtime" menu and selecting "Restart session".

This is a demo notebook. Extending it to your own videos requires setting up the environment, GPU access, and paths to your own data.

Let's begin!

## **1. Set up the environment (this will take a couple of minutes)**

In [None]:
#@title Check GPU availability. If not available, change Runtime type to 'T4 GPU'
!nvidia-smi

Tue Aug 19 20:40:55 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
%%capture
#@title Install PyTorch, mmdet, and mmpose with strict numpy pinning
# 1. Clean slate
!pip uninstall -y numpy xtcocotools pycocotools fastai spacy thinc pymc pytensor jax jaxlib yfinance

# 2. Install numpy FIRST with --force-reinstall
!pip install --force-reinstall --no-deps numpy==1.26.4

# 3. Install torch stack (won't touch numpy)
!pip install torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0+cu118 --index-url https://download.pytorch.org/whl/cu118

# 4. Install other deps with numpy constraint
!pip install --no-cache-dir "opencv-python-headless<4.9" moviepy==1.0.3 imageio imageio-ffmpeg "numpy==1.26.4"

# 5. Install MMDet/MMPose deps with constraints
!pip install -U openmim "numpy==1.26.4"
!mim install "mmengine==0.10.3" "numpy==1.26.4"
!mim install "mmcv==2.1.0" "numpy==1.26.4"

# 6. Clone and install mmpose
!rm -rf mmpose
!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
!pip install -e . --no-deps
!pip install -r requirements.txt "numpy==1.26.4"
%cd ..

# 7. Install mmdet with constraint
!pip install mmdet==3.3.0 "numpy==1.26.4"

# 8. Build xtcocotools from source against correct numpy
!pip install cython
!pip install --no-binary :all: xtcocotools


In [None]:
%%capture
#@title 1.3 Install detection dependencies
%pip install -U openmim
!mim install "mmengine==0.10.3"
!mim install "mmcv==2.1.0"

# Install mmdetection
!rm -rf mmdetection
!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection

%pip install -e .
%pip install -q lap
%pip install -q mmpose

### Now '**restart**' the session.

To use the newly installed packages, we need to restart the Colab runtime.

1.  Click **Runtime** > **Restart session**
2.  Rerun the setup cells above.
3.  Continue with the rest of the notebook.
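
After the restart, it can help to confirm that the pinned versions from the install cell are the ones actually loaded. The following is a minimal sketch, not part of the original notebook: `EXPECTED_VERSIONS` and `check_pinned_versions` are hypothetical names, and the distribution names are assumed to match the names used with `pip`/`mim` above.

```python
from importlib import metadata

# Pins copied from the install cell above. The keys are assumed to be the
# distribution names that pip/mim registered.
EXPECTED_VERSIONS = {
    "numpy": "1.26.4",
    "torch": "2.1.0+cu118",
    "mmengine": "0.10.3",
    "mmcv": "2.1.0",
    "mmdet": "3.3.0",
}

def check_pinned_versions(expected: dict) -> dict:
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # Package is not installed in this runtime
        report[name] = (installed, installed == wanted)
    return report

for name, (installed, ok) in check_pinned_versions(EXPECTED_VERSIONS).items():
    status = "OK" if ok else "MISMATCH"
    print(f"{name:10s} expected {EXPECTED_VERSIONS[name]:12s} found {installed} [{status}]")
```

A `MISMATCH` on `numpy` is the most likely sign that a later install step replaced the pinned version.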

In [None]:
#@title 1.4 Import the necessary libraries for timestamping analysis.
import os
import json
import cv2
import gdown
import time
from pathlib import Path
import argparse
import torch
import numpy as np
from tqdm.notebook import tqdm
import copy
from typing import List, Dict, Optional, Any, Tuple
import random
import datetime
import tempfile

# --- MMDetection & MMPose ---
# These are the core libraries for running the deep learning models.
from mmdet.apis import inference_detector, init_detector
from mmengine.structures import InstanceData
from mmdet.structures import DetDataSample
from mmpose.evaluation.functional import nms # Using nms from mmpose as in gradio script
from mmdet.models.trackers import ByteTracker # Import ByteTracker
from mmengine.registry import MODELS as MMENGINE_MODELS # If building tracker from dict

# --- Visualization ---
import matplotlib
import matplotlib.pyplot as plt
from moviepy.editor import VideoFileClip

# --- Interactive Widgets ---
from ipywidgets import interact, FloatSlider
from IPython.display import display, clear_output
_INTERACTIVE_WIDGETS_AVAILABLE = True

# --- Warnings ---
# Suppress common warnings from the detection libraries for a cleaner output.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)



## **2. Download PrimateFace and set up helper functions**
Here, we'll download one of the PrimateFace face detection models, which includes:
 1. the model configuration file and
 2. the model checkpoint (AKA weights).

We'll download these files directly from a Google Drive link.


In [None]:
#@title 2.1 Download PrimateFace cascade-rcnn detection model


# --- Action Required: Paste your Google Drive shareable links here ---
CONFIG_GDRIVE_LINK = "https://drive.google.com/file/d/1Y_YFdIDRcWQLI-gRiCnOrDxCptzCiiNp/view?usp=drive_link"
WEIGHTS_GDRIVE_LINK = "https://drive.google.com/file/d/1zZ8S31zPHX5BWYKbnHxI1QOqP-fPnVFO/view?usp=drive_link"

# --- Define local filenames for the downloaded files ---
downloaded_config_path = "downloaded_model_config.py"
downloaded_weights_path = "downloaded_model_weights.pth"

# --- Download function ---
def download_file_from_gdrive(url, output_path):
    """Uses the gdown library to download a file from a Google Drive link."""
    print(f"Attempting to download from Google Drive to {output_path}...")
    try:
        gdown.download(url, output_path, quiet=False, fuzzy=True)
        if os.path.exists(output_path):
            print(f"Successfully downloaded: {output_path}\n")
            return True
        else:
            print(f"Download failed. File not found at {output_path}. Check your share link and permissions.\n")
            return False
    except Exception as e:
        print(f"An error occurred during download: {e}\n")
        return False

# --- Execute the downloads ---
config_download_success = download_file_from_gdrive(CONFIG_GDRIVE_LINK, downloaded_config_path)
weights_download_success = download_file_from_gdrive(WEIGHTS_GDRIVE_LINK, downloaded_weights_path)

# --- IMPORTANT: Update the main configuration paths ---
# This tells the rest of the notebook to use the files we just downloaded.
# A path is left as None if its download failed, so the summary below still prints.
MMDET_CONFIG_PATH = downloaded_config_path if config_download_success else None
MMDET_CHECKPOINT_PATH = downloaded_weights_path if weights_download_success else None


MMDET_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEFAULT_NMS_THRESHOLD = 0.3


print("--- Path Update Summary ---")
print(f"Model Config Path is now set to: {MMDET_CONFIG_PATH}")
print(f"Model Checkpoint Path is now set to: {MMDET_CHECKPOINT_PATH}")
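
`DEFAULT_NMS_THRESHOLD` is an IoU (intersection over union) cut-off: when two boxes overlap more than this, only the higher-scoring one is kept. The sketch below illustrates the idea with a plain-Python greedy NMS. It is a simplified illustration, not the `mmpose` implementation the notebook imports, and `iou` and `greedy_nms` are hypothetical helper names.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def greedy_nms(boxes, scores, iou_threshold):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap too much with any kept box
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

# Two heavily overlapping boxes (IoU of about 0.68) and one separate box
boxes = [[10, 10, 110, 110], [20, 20, 120, 120], [300, 300, 400, 400]]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores, 0.3))  # [0, 2]
```

With the threshold at 0.3, the second box is suppressed as a duplicate of the first. Raising the threshold keeps more overlapping boxes, which may matter when faces are close together.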

In [None]:
#@title Setup for multi-animal tracker

# ByteTracker configuration
# These parameters can be tuned for different tracking performance.
DEFAULT_TRACKER_CONFIG = dict(
    motion=dict(type='KalmanFilter'),
    obj_score_thrs=dict(high=0.6, low=0.1),
    init_track_thr=0.7,
    weight_iou_with_det_scores=True,
    match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
    num_frames_retain=30
)

# Threshold for a detection to be considered by the tracker
# This can be lower than the final confidence threshold.
DEFAULT_TRACK_CONF_THRESHOLD = 0.4

# The maximum number of individuals to assign stable IDs to.
DEFAULT_NUM_TRACKED_IDS = 5

# Colors for visualizing different track IDs (OpenCV BGR order)
ID_COLORS = {
    1: (0, 255, 0),   # Green
    2: (0, 0, 255),   # Red
    3: (255, 0, 0),   # Blue
    4: (255, 255, 0), # Cyan
    5: (0, 255, 255), # Yellow
    6: (255, 0, 255)  # Magenta
}
DEFAULT_COLOR = (128, 128, 128) # Gray for unassigned/overflow
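
To make `obj_score_thrs` more concrete: ByteTrack-style trackers associate detections with existing tracks in two stages, using high-scoring detections first and lower-scoring ones second. The sketch below shows only that score split, using the `high=0.6` and `low=0.1` values from the config above. `split_by_score` is a hypothetical helper, and the actual `ByteTracker` does considerably more (motion prediction and IoU matching).

```python
def split_by_score(scores, high, low):
    """Split detection indices into first-stage and second-stage candidates."""
    # First association: confident detections
    first_stage = [i for i, s in enumerate(scores) if s > high]
    # Second association: weaker detections that may still continue a track
    second_stage = [i for i, s in enumerate(scores) if low < s <= high]
    return first_stage, second_stage

scores = [0.92, 0.65, 0.45, 0.30, 0.05]
first, second = split_by_score(scores, high=0.6, low=0.1)
print(first)   # [0, 1]
print(second)  # [2, 3]
```

Note that in this notebook detections are also filtered by the calibrated confidence threshold before they reach the tracker, so scores below that threshold never enter either stage.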


In [None]:
#@title Helper functions for analysis, visualization, and BORIS export.

mmdet_model = None

def get_model_class_names(model) -> tuple:
    """Robustly gets class names from a model, handling str, list, or tuple."""
    meta = model.dataset_meta
    classes = meta.get('classes', None)
    if classes is None:
        return ()
    if isinstance(classes, str):
        # If it's a single string, wrap it in a tuple
        return (classes,)
    return tuple(classes)

# --- Model Loading ---
def load_mmdetection_model(config_path: str, checkpoint_path: str, device: str):
    """Loads the MMDetection model into memory."""
    global mmdet_model
    if mmdet_model is not None:
        print("MMDetection model already loaded.")
        return mmdet_model
    config_file, checkpoint_file = Path(config_path), Path(checkpoint_path)
    if not config_file.exists() or not checkpoint_file.exists():
        raise FileNotFoundError(f"Config or checkpoint not found. Check paths: {config_file}, {checkpoint_file}")
    try:
        print(f"Loading MMDetection model on device: {device}...")
        model = init_detector(str(config_file), str(checkpoint_file), device=device)
        print("MMDetection model loaded successfully.")
        mmdet_model = model
        return mmdet_model
    except Exception as e:
        print(f"Error loading MMDetection model: {e}")
        raise
# --- Video Path Utility ---
def get_video_paths(input_path: str) -> List[str]:
    """Finds all video files in a given path (file or directory)."""
    input_p = Path(input_path)
    video_paths = []
    video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".webm", ".mts"]
    if input_p.is_file() and input_p.suffix.lower() in video_extensions:
        video_paths.append(str(input_p))
    elif input_p.is_dir():
        for item in input_p.iterdir():
            if item.is_file() and item.suffix.lower() in video_extensions:
                video_paths.append(str(item))
    if not video_paths:
        raise FileNotFoundError(f"No video files found at {input_path}")
    return video_paths

# --- Detection Processing ---
def filter_and_nms_raw_detections(mm_results_raw: DetDataSample, model_classes: List, confidence_threshold: float, nms_threshold: float, device: str) -> List[Dict[str, Any]]:
    """Filters raw model predictions by confidence and applies Non-Maximum Suppression."""
    processed_detections = []
    if not hasattr(mm_results_raw, 'pred_instances') or len(mm_results_raw.pred_instances) == 0:
        return []
    pred_instances = mm_results_raw.pred_instances
    keep_conf = pred_instances.scores >= confidence_threshold
    final_instances = pred_instances[keep_conf]
    if len(final_instances) == 0:
        return []

    nms_input = np.hstack((final_instances.bboxes.cpu().numpy(), final_instances.scores.cpu().numpy()[:, np.newaxis]))
    keep_nms_indices = nms(nms_input, nms_threshold)

    for i in keep_nms_indices:
        label_idx = int(final_instances.labels[i].item())
        processed_detections.append({
            "bbox": final_instances.bboxes[i].cpu().numpy().tolist(),
            "score": float(final_instances.scores[i].item()),
            "label": model_classes[label_idx]
        })
    return processed_detections

def create_downsampled_video(input_path: str, target_height: int) -> Optional[str]:
    """
    Creates a temporary, downsampled version of a video file.

    Args:
        input_path: Path to the original video.
        target_height: The desired height of the new video (e.g., 720).
                       Width is scaled automatically to maintain aspect ratio.

    Returns:
        The file path to the temporary, downsampled video, or None if an error occurs.
    """
    if target_height <= 0:
        return None

    try:
        print(f"Downsampling '{Path(input_path).name}' to {target_height}p height...")
        # Create a temporary file. With delete=False it persists after close,
        # so the caller is responsible for deleting it when finished.
        temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        temp_video_path = temp_video_file.name

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            print(f"Error: Could not open original video {input_path}")
            return None

        # Get original video properties
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # If video is already smaller, no need to downsample
        if original_height <= target_height:
            print("Video is already at or below target resolution. No downsampling needed.")
            temp_video_file.close() # It was created, so close it
            os.unlink(temp_video_path) # and delete it
            return None

        # Calculate new width to maintain aspect ratio
        aspect_ratio = original_width / original_height
        new_width = int(target_height * aspect_ratio)

        # Ensure width is even, as required by many video codecs
        if new_width % 2 != 0:
            new_width += 1

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(temp_video_path, fourcc, fps, (new_width, target_height))
        if not writer.isOpened():
            print("Error: Could not initialize VideoWriter.")
            cap.release()  # Free the input video before bailing out
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in tqdm(range(total_frames), desc="Resizing frames"):
            ret, frame = cap.read()
            if not ret:
                break
            resized_frame = cv2.resize(frame, (new_width, target_height), interpolation=cv2.INTER_AREA)
            writer.write(resized_frame)

        cap.release()
        writer.release()
        print("Downsampling complete.")
        return temp_video_path

    except Exception as e:
        print(f"An error occurred during downsampling: {e}")
        return None


# --- Visualization ---
def visualize_frame_with_detections(frame: np.ndarray, detections: List[Dict[str, Any]], thickness: int = 2):
    """Draws bounding boxes on a frame, colored by stable ID."""
    vis_frame = frame.copy()
    for det in detections:
        x1, y1, x2, y2 = map(int, det["bbox"])
        score = det["score"]
        label = det["label"]
        stable_id = det.get("stable_id")

        color = ID_COLORS.get(stable_id, DEFAULT_COLOR) if stable_id is not None else (0, 255, 0)

        id_text = f"ID:{stable_id} " if stable_id is not None else ""

        cv2.rectangle(vis_frame, (x1, y1), (x2, y2), color, thickness)

        label_text = f"{id_text}{label}: {score:.2f}"

        label_y_pos = y1 - 10 if y1 > 20 else y1 + 15
        cv2.putText(vis_frame, label_text, (x1, label_y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
    return vis_frame

# --- Output Saving ---
def save_results_to_json(all_video_data: Dict[str, Any], output_filepath: str):
    """Saves the detailed analysis to a JSON file."""
    Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
    with open(output_filepath, 'w') as f:
        json.dump(all_video_data, f, indent=4)
    print(f"Results successfully saved to {output_filepath}")


def save_boris_project(master_results: Dict[str, Any], output_dir: str):
    """Converts analysis results with tracking into BORIS project files."""
    if not master_results.get("videos"): return

    behavior_code = "face present"
    behavior_conf = {"1": {"key": "1", "code": behavior_code, "type": "STATE", "description": "Auto-detected", "modifiers": {}, "category": "presence"}}

    for video_data in master_results["videos"]:
        video_path = Path(video_data["video_filepath"])
        fps, total_frames = video_data["fps"], video_data["total_frames"]
        timestamps_by_id = video_data.get("timestamps_summary_by_id", {})

        # Determine which subjects are present in this video
        present_sids = [sid for sid, ts in timestamps_by_id.items() if ts]
        if not present_sids:
            print(f"No tracked detections in {video_path.name}, skipping BORIS file.")
            continue

        # Define subjects based on detected IDs for this video
        subjects_conf = {str(i): {"key": str(i+1), "name": f"primate_{sid}", "description": ""} for i, sid in enumerate(present_sids)}
        subject_name_map = {sid: f"primate_{sid}" for sid in present_sids}

        all_events = []
        for sid, subject_name in subject_name_map.items():
            timestamps = timestamps_by_id.get(sid, [])
            if not timestamps: continue

            events_for_subject, in_event = [], False
            frame_duration = 1.0 / fps
            for i, ts in enumerate(timestamps):
                if not in_event:
                    events_for_subject.append([ts, subject_name, behavior_code, "", "", "START"])
                    in_event = True

                is_last = (i == len(timestamps) - 1)
                if is_last or (timestamps[i+1] - ts) > (frame_duration * 1.5):
                    events_for_subject.append([ts + frame_duration, subject_name, behavior_code, "", "", "STOP"])
                    in_event = False
            all_events.extend(events_for_subject)

        all_events.sort(key=lambda x: (x[0], 0 if x[5] == "START" else 1))

        video_path_str = str(video_path)
        observation_data = {
            "date": datetime.datetime.now().isoformat(), "description": "Auto-generated from script", "type": "MEDIA",
            "file": {"1": [video_path_str]},
            "media_info": {"length": {video_path_str: total_frames / fps}, "fps": {video_path_str: fps}, "hasVideo": {video_path_str: True}, "hasAudio": {video_path_str: True}, "offset": {"1": 0.0}},
            "time offset": 0.0, "events": all_events, "independent_variables": {}, "close_behaviors_between_videos": False,
        }

        boris_project = {
            "project_name": video_path.stem, "project_date": datetime.datetime.now().isoformat(),
            "subjects_conf": subjects_conf, "behaviors_conf": behavior_conf,
            "behavioral_categories": ["presence"], "observations": {video_path.stem: observation_data},
            "project_format_version": "7.0", "time_format": "s"
        }

        output_path = Path(output_dir) / f"{video_path.stem}.boris"
        with open(output_path, "w") as f:
            json.dump(boris_project, f, indent=4)
        print(f"BORIS project file for tracked subjects created at: {output_path}")
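
The START/STOP logic in `save_boris_project` can be read as grouping per-frame timestamps into continuous intervals: a new interval begins whenever the gap to the next timestamp exceeds 1.5 frame durations. The sketch below isolates that step. `timestamps_to_intervals` is a hypothetical helper written for illustration, and the sample timestamps are made up.

```python
def timestamps_to_intervals(timestamps, fps, gap_factor=1.5):
    """Group sorted timestamps (seconds) into (start, stop) intervals."""
    frame_duration = 1.0 / fps
    intervals = []
    start = None
    for i, ts in enumerate(timestamps):
        if start is None:
            start = ts  # Corresponds to a START event
        is_last = i == len(timestamps) - 1
        if is_last or (timestamps[i + 1] - ts) > frame_duration * gap_factor:
            # Corresponds to a STOP event, one frame after the last detection
            intervals.append((start, ts + frame_duration))
            start = None
    return intervals

# Illustrative data: a face visible for three frames, then again for two
fps = 10.0
timestamps = [0.0, 0.1, 0.2, 1.0, 1.1]
intervals = timestamps_to_intervals(timestamps, fps)
print(intervals)
```

This yields two intervals, roughly 0.0 to 0.3 s and 1.0 to 1.2 s (up to floating-point rounding), which is what BORIS would show as two "face present" state events.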


In [None]:
#@title Video analysis function

def analyze_video(video_path: str,
                  original_video_path: str,
                  model,
                  model_classes: List,
                  confidence_thresh: float,
                  nms_thresh: float, # Kept for signature consistency, but not used in tracking logic
                  device: str,
                  tracker_config: dict,
                  track_thr: float,
                  num_tracked_ids: int,
                  inference_fps: Optional[int] = None) -> Dict[str, Any]:
    """
    Processes a video, performing detection and tracking with performance profiling,
    and returns detailed analysis data.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened(): return {}

    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps == 0 or total_frames == 0: return {}

    tracker = ByteTracker(**tracker_config)
    stable_id_map = {sid: None for sid in range(1, num_tracked_ids + 1)}

    frame_step = 1
    if inference_fps and 0 < inference_fps < video_fps:
        frame_step = int(round(video_fps / inference_fps))

    all_frame_detections = []
    present_timestamps_by_id = {sid: [] for sid in range(1, num_tracked_ids + 1)}

    # --- Profiling Timers ---
    time_inference, time_tracking, time_postprocess = 0, 0, 0
    loop_count = 0

    print(f"Processing video with tracking: {Path(video_path).name}")
    for frame_idx in tqdm(range(0, total_frames, frame_step)):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret: break

        current_time_sec = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0

        # --- 1. Time Inference ---
        t0 = time.time()
        with torch.no_grad(): # Ensure no gradients are computed
            mm_results_raw = inference_detector(model, frame)
            pred_instances = mm_results_raw.pred_instances
            pred_instances = pred_instances[pred_instances.scores > confidence_thresh]
        if torch.cuda.is_available():
            torch.cuda.synchronize() # Wait for GPU to finish so the timing is accurate
        t1 = time.time()
        time_inference += (t1 - t0)

        # --- 2. Time Tracking ---
        t0 = time.time()
        det_data_sample = DetDataSample(pred_instances=pred_instances, metainfo=dict(frame_id=frame_idx))
        tracked_instances = tracker.track(det_data_sample)
        t1 = time.time()
        time_tracking += (t1 - t0)

        # --- 3. Time Post-processing (Optimized) ---
        t0 = time.time()
        final_frame_detections = []
        if len(tracked_instances) > 0 and 'instances_id' in tracked_instances:

            # --- Vectorized Filtering and Sorting on GPU ---
            valid_mask = tracked_instances.scores >= track_thr

            current_bt_ids = tracked_instances.instances_id[valid_mask]
            current_scores = tracked_instances.scores[valid_mask]
            current_bboxes = tracked_instances.bboxes[valid_mask]
            current_labels = tracked_instances.labels[valid_mask]

            # Sort by score on GPU
            sort_indices = torch.argsort(current_scores, descending=True)
            current_bt_ids = current_bt_ids[sort_indices]
            current_scores = current_scores[sort_indices]
            current_bboxes = current_bboxes[sort_indices]
            current_labels = current_labels[sort_indices]

            new_stable_id_map = {sid: None for sid in range(1, num_tracked_ids + 1)}

            # Lists to accumulate final tensor data
            final_bboxes_list, final_scores_list, final_labels_list, final_stable_ids_list = [], [], [], []

            # --- ID Assignment (Hybrid Approach) ---
            # Part A: Maintain existing tracks
            assigned_bt_ids = set()
            for stable_id, prev_bt_id in stable_id_map.items():
                if prev_bt_id is not None:
                    # Search on the GPU tensor
                    match_indices = torch.where(current_bt_ids == prev_bt_id)[0]
                    if len(match_indices) > 0:
                        idx = match_indices[0]
                        new_stable_id_map[stable_id] = prev_bt_id
                        assigned_bt_ids.add(prev_bt_id)

                        final_bboxes_list.append(current_bboxes[idx])
                        final_scores_list.append(current_scores[idx])
                        final_labels_list.append(current_labels[idx])
                        final_stable_ids_list.append(stable_id)

            # Part B: Assign new tracks
            if len(assigned_bt_ids) < num_tracked_ids:
                for i in range(len(current_bt_ids)):
                    if len(assigned_bt_ids) >= num_tracked_ids: break

                    bt_id = current_bt_ids[i].item() # .item() is fast for single element
                    if bt_id not in assigned_bt_ids:
                        for stable_id_slot in range(1, num_tracked_ids + 1):
                            if new_stable_id_map[stable_id_slot] is None:
                                new_stable_id_map[stable_id_slot] = bt_id
                                assigned_bt_ids.add(bt_id)

                                final_bboxes_list.append(current_bboxes[i])
                                final_scores_list.append(current_scores[i])
                                final_labels_list.append(current_labels[i])
                                final_stable_ids_list.append(stable_id_slot)
                                break
            stable_id_map = new_stable_id_map

            # --- Final Conversion to CPU ---
            if final_bboxes_list:
                # Stack all tensors at once
                final_bboxes_gpu = torch.stack(final_bboxes_list)
                final_scores_gpu = torch.stack(final_scores_list)
                final_labels_gpu = torch.stack(final_labels_list)
                # Convert Python list of ints to a tensor
                final_stable_ids_gpu = torch.tensor(final_stable_ids_list, device=device)

                # Move to CPU in one batch
                final_bboxes_np = final_bboxes_gpu.cpu().numpy()
                final_scores_np = final_scores_gpu.cpu().numpy()
                final_labels_np = final_labels_gpu.cpu().numpy()
                final_stable_ids_np = final_stable_ids_gpu.cpu().numpy()

                # Create final dict list (this loop is unavoidable but now runs on pre-processed CPU data)
                for i in range(len(final_bboxes_np)):
                    final_frame_detections.append({
                        "bbox": final_bboxes_np[i].tolist(),
                        "score": float(final_scores_np[i]),
                        "label": model_classes[int(final_labels_np[i])],
                        "stable_id": int(final_stable_ids_np[i])
                    })

        all_frame_detections.append({
            "frame_index": frame_idx, "timestamp_sec": current_time_sec, "detections": final_frame_detections
        })

        if final_frame_detections:
            end_time_sec = current_time_sec + (frame_step / video_fps)
            time_block = np.arange(current_time_sec, end_time_sec, 1.0 / video_fps)
            for det in final_frame_detections:
                sid = det.get("stable_id")
                if sid in present_timestamps_by_id:
                    present_timestamps_by_id[sid].extend(time_block)

        t1 = time.time()
        time_postprocess += (t1 - t0)
        loop_count += 1

    cap.release()

    # --- Print Performance Breakdown ---
    if loop_count > 0:
        print("\n" + "="*50)
        print("--- Performance Breakdown ---")
        print(f"Frames processed: {loop_count}")
        print(f"Avg Inference Time:       {time_inference / loop_count * 1000:.2f} ms per frame")
        print(f"Avg Tracking Time:        {time_tracking / loop_count * 1000:.2f} ms per frame")
        print(f"Avg Post-processing Time: {time_postprocess / loop_count * 1000:.2f} ms per frame")
        print(f"Total Avg Time per Frame: {(time_inference + time_tracking + time_postprocess) / loop_count * 1000:.2f} ms")
        print("="*50 + "\n")

    for sid in present_timestamps_by_id:
        present_timestamps_by_id[sid] = sorted(set(present_timestamps_by_id[sid]))

    return {
        "video_filepath": original_video_path,
        "processed_video_path": video_path,
        "fps": video_fps,
        "total_frames": total_frames,
        "detailed_detections": all_frame_detections,
        "timestamps_summary_by_id": present_timestamps_by_id
    }
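
The `inference_fps` option trades temporal resolution for speed: only every `frame_step`-th frame is analyzed, and each detection is then credited to the skipped frames that follow it. The sketch below reproduces that arithmetic in isolation. `frame_step_for` and `covered_timestamps` are hypothetical helpers, and the second one uses a simple range where `analyze_video` uses `np.arange`.

```python
def frame_step_for(video_fps, inference_fps=None):
    """Number of video frames to advance between analyzed frames."""
    if inference_fps and 0 < inference_fps < video_fps:
        return int(round(video_fps / inference_fps))
    return 1  # Analyze every frame

def covered_timestamps(frame_time_sec, frame_step, video_fps):
    """Timestamps one analyzed frame stands in for (itself plus skipped frames)."""
    return [frame_time_sec + k / video_fps for k in range(frame_step)]

step = frame_step_for(30.0, 10)
print(step)  # 3
print(covered_timestamps(2.0, step, 30.0))
```

For a 30 fps video analyzed at 10 fps, each analyzed frame covers three timestamps, so a face that appears only in a skipped frame would be missed, and one that disappears between analyzed frames would be slightly over-counted.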

In [None]:
#@title Helper class for visibility timeline visualization.

class VideoAnalysisVisualizer:
    def __init__(self, original_video_path: str, analysis_data: dict, output_dir: str, dpi: int = 300, thickness: int = 2):
        self.original_video_path = original_video_path
        self.analysis_data = analysis_data
        self.output_dir = Path(output_dir)
        self.dpi = dpi
        self.thickness = thickness
        self.video_stem = Path(original_video_path).stem
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.clip = None
        try:
            # Always open the original, high-quality video for stills
            self.clip = VideoFileClip(original_video_path)
        except Exception as e:
            print(f"Error opening video {original_video_path} with MoviePy: {e}. Visualizations will not be generated.")

    def _get_video_dimensions(self, video_path: str) -> Tuple[int, int]:
        """Helper to get (width, height) of a video."""
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened(): return (0, 0)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cap.release()
            return (width, height)
        except Exception:
            return (0, 0)

    def plot_timeline(
        self,
        target_still_timestamp_sec: Optional[float] = None,
        fig_filename_prefix: str = "timeline",
        raster_cmap: str = "viridis"
    ):
        if not self.clip: return
        detailed_detections = self.analysis_data.get("detailed_detections")
        if not detailed_detections:
            print(f"No detection data available for {self.original_video_path}. Skipping timeline.")
            return

        # --- Bounding Box Rescaling Logic ---
        original_dims = self._get_video_dimensions(self.original_video_path)
        processed_path = self.analysis_data.get("processed_video_path", self.original_video_path)
        processed_dims = self._get_video_dimensions(processed_path)

        scale_x, scale_y = 1.0, 1.0
        if all(d > 0 for d in original_dims) and all(d > 0 for d in processed_dims) and original_dims != processed_dims:
            scale_x = original_dims[0] / processed_dims[0]
            scale_y = original_dims[1] / processed_dims[1]
            print(f"Rescaling bboxes for visualization. X-Factor: {scale_x:.2f}, Y-Factor: {scale_y:.2f}")

        # Find all unique stable IDs to create the raster plot rows
        all_sids = sorted(list(set(det['stable_id'] for fd in detailed_detections for det in fd.get('detections', []) if 'stable_id' in det)))

        if not all_sids:
            raster_row_labels = ["Detection"]
            visibility_matrix = np.zeros((1, len(detailed_detections)))
            for frame_idx, frame_data in enumerate(detailed_detections):
                if frame_data.get("detections"):
                    visibility_matrix[0, frame_idx] = 1
        else:
            raster_row_labels = [f"ID {sid}" for sid in all_sids]
            sid_to_row_idx = {sid: i for i, sid in enumerate(all_sids)}
            visibility_matrix = np.zeros((len(all_sids), len(detailed_detections)))
            for frame_idx, frame_data in enumerate(detailed_detections):
                for det in frame_data.get("detections", []):
                    sid = det.get("stable_id")
                    if sid in sid_to_row_idx:
                        visibility_matrix[sid_to_row_idx[sid], frame_idx] = 1

        # --- Choose Still Frame ---
        frames_with_detections = [fd for fd in detailed_detections if fd.get("detections")]
        if not frames_with_detections:
            print("No frames with any detections found. Skipping timeline figure.")
            return

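        if target_still_timestamp_sec is not None:
            # Pick the detection frame closest to the requested timestamp
            chosen_frame_data = min(frames_with_detections, key=lambda fd: abs(fd["timestamp_sec"] - target_still_timestamp_sec))
        else:
            # No timestamp requested: pick a random frame that has detections
            chosen_frame_data = random.choice(frames_with_detections)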

        still_timestamp = chosen_frame_data["timestamp_sec"]
        still_detections = chosen_frame_data.get("detections", [])

        try:
            still_frame_rgb = self.clip.get_frame(still_timestamp)
        except Exception as e:
            print(f"Error extracting still frame: {e}. Skipping timeline.")
            return

        timestamps_axis = np.array([fd["timestamp_sec"] for fd in detailed_detections])

        # --- Apply scaling to the detections for the still frame ---
        rescaled_detections = []
        for det in still_detections:
            rescaled_det = det.copy()
            x1, y1, x2, y2 = rescaled_det["bbox"]
            rescaled_det["bbox"] = [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]
            rescaled_detections.append(rescaled_det)

        # --- Plotting ---
        still_bgr_with_boxes = visualize_frame_with_detections(
            cv2.cvtColor(still_frame_rgb, cv2.COLOR_RGB2BGR),
            rescaled_detections,
            thickness=self.thickness
        )
        still_rgb_with_boxes = cv2.cvtColor(still_bgr_with_boxes, cv2.COLOR_BGR2RGB)

        fig_height_ratio = len(raster_row_labels) or 1
        fig = plt.figure(figsize=(10, 4 + fig_height_ratio * 0.5))
        gs = fig.add_gridspec(2, 1, height_ratios=[3, fig_height_ratio], hspace=0.15)

        ax0 = fig.add_subplot(gs[0])
        ax0.imshow(still_rgb_with_boxes)
        ax0.axis("off")
        ax0.set_title(f"Frame at {still_timestamp:.2f}s", fontsize=10)

        ax1 = fig.add_subplot(gs[1])
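        # The y-extent runs from len-0.5 (bottom) to -0.5 (top) so that matrix row i is centered on y-tick i
        ax1.imshow(visibility_matrix, aspect="auto", cmap=raster_cmap, interpolation='nearest', extent=[timestamps_axis[0], timestamps_axis[-1], len(raster_row_labels)-0.5, -0.5])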
        ax1.set_yticks(range(len(raster_row_labels)))
        ax1.set_yticklabels(raster_row_labels)
        ax1.set_xlabel("Time (s)")
        ax1.set_ylabel("Tracked ID" if all_sids else "Detections")
        ax1.axvline(x=still_timestamp, color='red', linestyle='--', linewidth=1.5)

        plt.tight_layout(pad=0.5)
        fig_path = self.output_dir / f"{self.video_stem}_ID_timeline.png"
        pdf_fig_path = self.output_dir / f"{self.video_stem}_ID_timeline.pdf"

        try:
            plt.savefig(fig_path, dpi=self.dpi, bbox_inches='tight')
            plt.savefig(pdf_fig_path, dpi=self.dpi, bbox_inches='tight', format='pdf')
            print(f"ID timeline figure saved to: {fig_path} and {pdf_fig_path}")
        except Exception as e:
            print(f"Error saving timeline figure: {e}")

        plt.show()
        plt.close(fig)

    def close(self):
        """Closes the video file clip."""
        if self.clip:
            self.clip.close()
            self.clip = None


## **3. Load Your Video Data**

Now it's time to choose the video(s) you want to analyze. You have three easy options.

**Please choose one option in the form cell below:**

*   **Option 1: Use the Demo Video (Default)**
    *   This is the simplest way to get started. We'll automatically download a sample video of ring-tailed lemurs for you. Just run the next cell without changing anything.

*   **Option 2: Use Your Own Video via a Google Drive Link**
    *   If you have a single video on Google Drive, this is a great option.
    *   **How-to**: In Google Drive, right-click your video, select **Share > Share**, and set "General access" to **"Anyone with the link"**. Copy the link and paste it into the `video_gdrive_link` field below.

*   **Option 3: Use a Folder from Your Mounted Google Drive**
    *   This is the best option for analyzing multiple videos at once.
    *   **How-to**: First, check the `mount_my_drive` box in the form. You'll be asked to authorize Colab to access your Google Drive. Then, update the `mounted_drive_path` to point to the folder containing your videos (e.g., `/content/drive/MyDrive/MyPrimateVideos`).

**After selecting your option, run the cell to load the data.**
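
For Option 3, the path can point to a single video or to a folder of videos. The sketch below shows one way such a path could be expanded into a list of files. `list_video_files` and the extension set are hypothetical and only illustrate the idea; the notebook's own `get_video_paths` helper may behave differently.

```python
from pathlib import Path
from typing import List

# Assumed set of extensions; adjust to the formats you actually use.
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv"}


def list_video_files(input_path: str) -> List[str]:
    """Return video paths for a single file or a folder (hypothetical helper)."""
    path = Path(input_path)
    if not path.exists():
        raise FileNotFoundError(f"Path does not exist: {input_path}")
    if path.is_file():
        if path.suffix.lower() not in VIDEO_EXTENSIONS:
            raise ValueError(f"Not a recognized video file: {path.name}")
        return [str(path)]
    # Folder: keep only files with a video extension, sorted for a stable order.
    videos = sorted(
        str(p) for p in path.iterdir()
        if p.is_file() and p.suffix.lower() in VIDEO_EXTENSIONS
    )
    if not videos:
        raise ValueError(f"No video files found in: {input_path}")
    return videos
```

Raising `FileNotFoundError` or `ValueError` here matches the exceptions that the form cell below catches and reports.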

In [None]:
#@title 3.1 Configure Video Source
#@markdown 1. Choose your data source from the dropdown menu.
#@markdown 2. Fill in the corresponding fields below if needed.
#@markdown 3. Run the cell.

from google.colab import drive
from pathlib import Path
import gdown
import os

# --- Form Parameters ---
data_source_option = "Use Demo Video"  #@param ["Use Demo Video", "Use a Google Drive Link", "Use a Path from Mounted Google Drive"]

#@markdown ---
#@markdown ### For Option 2: Provide a Google Drive Link
#@markdown *If using this option, paste your shareable link here.*
video_gdrive_link = "PASTE_YOUR_GOOGLE_DRIVE_LINK_HERE"  #@param {type:"string"}

#@markdown ---
#@markdown ### For Option 3: Provide a Path from Your Mounted Drive
#@markdown *If using this option, set the path to your video file or folder.*
mounted_drive_path = "/content/drive/MyDrive/PrimateFace/demo-input-videos"  #@param {type:"string"}
#@markdown *Check this box to mount your Google Drive (required for Option 3).*
mount_my_drive = True  #@param {type:"boolean"}
#---------------------------------------------------------------------------------

# Initialize a global path variable
VIDEO_INPUT_PATH = ""

# --- Logic to handle the user's choice ---
try:
    if data_source_option == "Use Demo Video":
        print("▶️ Option 1: Using the demo video.")
        demo_video_link = "https://drive.google.com/file/d/1VHLn_zdkwaYrtgjgl_Fz_jTCYVGIu5Il/view?usp=drive_link"
        demo_video_filename = "demo_cayo_mac.mp4"
        print(f"Downloading demo video to '{demo_video_filename}'...")
        gdown.download(demo_video_link, demo_video_filename, quiet=False, fuzzy=True)
        if not os.path.exists(demo_video_filename):
            raise FileNotFoundError("Demo video download failed.")
        VIDEO_INPUT_PATH = demo_video_filename

    elif data_source_option == "Use a Google Drive Link":
        print("▶️ Option 2: Using a Google Drive link.")
        if "PASTE_YOUR" in video_gdrive_link:
            raise ValueError("Please paste your Google Drive link into the 'video_gdrive_link' field.")

        linked_video_filename = "user_gdrive_video.mp4"
        print(f"Downloading video from link to '{linked_video_filename}'...")
        gdown.download(video_gdrive_link, linked_video_filename, quiet=False, fuzzy=True)
        if not os.path.exists(linked_video_filename):
            raise FileNotFoundError("Video download from your link failed. Please check the URL and sharing permissions.")
        VIDEO_INPUT_PATH = linked_video_filename

    elif data_source_option == "Use a Path from Mounted Google Drive":
        print("▶️ Option 3: Using a path from your mounted Google Drive.")
        if mount_my_drive:
            print("Mounting Google Drive...")
            drive.mount('/content/drive')
            print("Drive mounted successfully.")
        else:
            print("Assuming Google Drive is already mounted. If you see an error, check the 'mount_my_drive' box.")

        if not Path(mounted_drive_path).exists():
            raise FileNotFoundError(f"The path '{mounted_drive_path}' does not exist in your Google Drive. Please check the path.")
        VIDEO_INPUT_PATH = mounted_drive_path

    # --- Final Validation ---
    video_files = get_video_paths(VIDEO_INPUT_PATH)
    print("\n✅ Success! The following video(s) are ready for analysis:")
    for vf in video_files:
        print(f" - {Path(vf).name}")

except (ValueError, FileNotFoundError) as e:
    print(f"\n❌ Error: {e}")
    print("Please correct the settings in the form and run the cell again.")
    video_files = [] # Ensure it's empty on failure

## **4. Configure analysis & pre-process videos**
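
Downsampling to a target height shrinks the frame, so boxes detected on the processed video live in a smaller coordinate system and must be scaled up before they are drawn on the original frame. The sketch below illustrates that arithmetic. `downsampled_size` and `rescale_bbox` are hypothetical names, and forcing an even width is an assumption (many video encoders expect even dimensions), not necessarily what `create_downsampled_video` does.

```python
from typing import List, Tuple


def downsampled_size(width: int, height: int, target_height: int) -> Tuple[int, int]:
    """Return (new_width, new_height) while keeping the aspect ratio."""
    if target_height <= 0 or target_height >= height:
        return width, height  # downsampling disabled or not needed
    new_width = int(round(width * target_height / height))
    new_width -= new_width % 2  # assumption: force an even width
    return new_width, target_height


def rescale_bbox(bbox: List[float], processed: Tuple[int, int], original: Tuple[int, int]) -> List[float]:
    """Map an [x1, y1, x2, y2] box from processed to original pixel coordinates."""
    scale_x = original[0] / processed[0]
    scale_y = original[1] / processed[1]
    x1, y1, x2, y2 = bbox
    return [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]


original = (1920, 1080)
processed = downsampled_size(*original, target_height=360)
print(processed)  # (640, 360)
print(rescale_bbox([100, 50, 200, 150], processed, original))  # [300.0, 150.0, 600.0, 450.0]
```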



In [None]:
#@title 4.1 Pre-process Videos for Faster Analysis (downsampling is optional)
#@markdown Run this cell even if you skip downsampling: it builds the job list used by the later steps.
#@markdown Downsampling happens once per video. Set a lower height (e.g., 360) to speed up analysis, or **set to 0 to disable.**
#@markdown Note that downsampling may result in less accurate face detections.

target_processing_height = 360 #@param {type:"integer"}

#---------------------------------------------------------------------------------

analysis_jobs = []

if 'video_files' in locals() and video_files:
    print("--- Pre-processing Videos ---")
    with tqdm(total=len(video_files), desc="Processing Videos") as pbar:
        for original_path in video_files:
            job = {
                'original_path': original_path,
                'processing_path': original_path, # Default to original
                'temp_file_to_delete': None
            }
            if target_processing_height > 0:
                temp_path = create_downsampled_video(original_path, target_processing_height)
                if temp_path:
                    job['processing_path'] = temp_path
                    job['temp_file_to_delete'] = temp_path

            analysis_jobs.append(job)
            pbar.set_description(f"Processed '{Path(original_path).name}'")
            pbar.update(1)

    print("\n✅ Pre-processing complete. Ready for interactive threshold selection.")
else:
    print("❌ No video files loaded. Please run the 'Configure Video Source' cell successfully.")

## **5. Calibrate the Detection Confidence**

This tool helps you choose the best confidence threshold for your videos. A **high threshold** is strict (fewer, but more accurate detections), while a **low threshold** is lenient (more detections, but potentially more errors).

**Your Goal:** Move the slider to find the "sweet spot" where green boxes correctly identify most faces without picking up too much background noise.
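
To make the trade-off concrete, here is a minimal sketch of what a confidence threshold does to a list of detections. The detections are made up for illustration, and `filter_by_confidence` is a hypothetical helper; the notebook's `filter_and_nms_raw_detections` also applies non-maximum suppression, which is not shown here.

```python
from typing import Dict, List


def filter_by_confidence(detections: List[Dict], threshold: float) -> List[Dict]:
    """Keep only detections whose score is at or above the threshold."""
    return [d for d in detections if d["score"] >= threshold]


# Made-up detections for one frame (illustration only).
example = [
    {"bbox": [10, 10, 60, 60], "score": 0.92},
    {"bbox": [200, 40, 260, 110], "score": 0.55},
    {"bbox": [300, 300, 320, 330], "score": 0.12},
]

for thr in (0.1, 0.4, 0.75):
    kept = filter_by_confidence(example, thr)
    print(f"threshold={thr:.2f} -> {len(kept)} detection(s)")
```

With these example scores, the three thresholds keep 3, 2, and 1 detections respectively: raising the threshold removes the weakest boxes first.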

In [None]:
#@title 5.1 Run me to explore confidence thresholds!
num_imgs_to_display = 7 #@param {type: "integer"}
bbox_thickness = 4 #@param {type: "integer"}

# This global variable records the slider's most recent value.
# Note: the form in step 6 sets its own value, so enter your chosen threshold there as well.
final_confidence_threshold = 0.4  # Default value

def interactive_threshold_explorer(
    jobs_to_sample: List[Dict],
    model: Any,
    model_classes: List[str],
    nms_thresh: float,
    device: str,
    num_imgs: int,
    bbox_thick: int
):
    print(f"Selecting up to {num_imgs} random frames for interactive testing...")

    sample_data = []
    samples_per_video = int(np.ceil(num_imgs / len(jobs_to_sample)))

    with tqdm(total=len(jobs_to_sample), desc="Sampling frames") as pbar:
        for job in jobs_to_sample:
            # Use the pre-processed video path for sampling
            processing_path = job['processing_path']

            # Open the video with cv2.VideoCapture for robust reading
            cap = cv2.VideoCapture(processing_path)
            if not cap.isOpened():
                print(f"Warning: Could not open {processing_path} with OpenCV, skipping.")
                pbar.update(1)
                continue

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0

            if duration <= 1.0:
                pbar.update(1)
                cap.release()
                continue

            possible_ts = np.arange(0.5, duration - 0.5, 0.5)
            timestamps_to_sample = random.sample(list(possible_ts), k=min(samples_per_video, len(possible_ts)))

            for t in timestamps_to_sample:
                if len(sample_data) >= num_imgs: break

                # Seek to the frame at the sampled timestamp and read it
                frame_index_to_get = int(t * fps)
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index_to_get)
                ret, frame_bgr = cap.read()
                if not ret: continue

                with torch.no_grad():
                    raw_preds = inference_detector(model, frame_bgr)

                sample_data.append({
                    "display_frame_bgr": frame_bgr, # Pass BGR directly
                    "raw_preds": raw_preds,
                    "source_video": Path(job['original_path']).name,
                    "timestamp": t
                })

            pbar.update(1)
            cap.release()  # Release the capture object once this video is sampled
            if len(sample_data) >= num_imgs: break

    if not sample_data:
        print("Error: Could not extract any sample frames. Please check that your video files are not corrupt and are in a common format (e.g., MP4).")
        return

    # This is the function that the slider will call
    def update_plot(confidence_thr):
        global final_confidence_threshold
        final_confidence_threshold = confidence_thr # Update global variable

        clear_output(wait=True)
        cols = int(np.ceil(np.sqrt(len(sample_data))))
        rows = int(np.ceil(len(sample_data) / cols))
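        # squeeze=False always returns a 2D array of axes, even when only one image is shown
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4.5, rows * 4.5), squeeze=False)
        axes = axes.ravel()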

        for i, sample in enumerate(sample_data):
            filtered_dets = filter_and_nms_raw_detections(sample["raw_preds"], model_classes, confidence_thr, nms_thresh, device)

            # Draw the boxes directly on the BGR frame
            frame_with_boxes = visualize_frame_with_detections(sample["display_frame_bgr"], filtered_dets, thickness=bbox_thick)

            avg_conf_text = ""
            if filtered_dets:
                avg_conf = np.mean([d['score'] for d in filtered_dets])
                avg_conf_text = f" | Avg Conf: {avg_conf:.2f}"

            title_text = f"{sample['source_video']}\n@{sample['timestamp']:.1f}s | {len(filtered_dets)} dets{avg_conf_text}"

            axes[i].imshow(cv2.cvtColor(frame_with_boxes, cv2.COLOR_BGR2RGB)) # Convert to RGB only for plotting

            axes[i].set_title(title_text, fontsize=9)
            axes[i].axis('off')

        for j in range(len(sample_data), len(axes)):
            axes[j].axis('off')

        plt.tight_layout()
        plt.show()

    # Create and display the interactive widget
    interact(
        update_plot,
        confidence_thr=FloatSlider(min=0.05, max=0.99, step=0.01, value=0.4, description='Confidence Thr.', continuous_update=False, readout_format='.2f', layout={'width': '500px'})
    )

# --- Main execution logic for this cell ---
if 'analysis_jobs' in locals() and analysis_jobs:
    # Load the model if it hasn't been loaded yet
    if 'mmdet_model' not in globals() or mmdet_model is None:
        mmdet_model = load_mmdetection_model(MMDET_CONFIG_PATH, MMDET_CHECKPOINT_PATH, MMDET_DEVICE)

    model_class_names = get_model_class_names(mmdet_model)
    print(f"Model classes found: {model_class_names}")

    interactive_threshold_explorer(
        jobs_to_sample=analysis_jobs,
        model=mmdet_model,
        model_classes=model_class_names,
        nms_thresh=DEFAULT_NMS_THRESHOLD,
        device=MMDET_DEVICE,
        num_imgs=num_imgs_to_display,
        bbox_thick=bbox_thickness
    )
else:
    print("❌ Analysis jobs not created. Please run the '4.1 Pre-process Videos' cell first.")

## **6. Run Full Analysis**
**Action Required**

This cell runs the complete analysis on all your videos using the parameters you set below. It saves a detailed JSON file, the `.boris` project file(s), and (if enabled) a timeline figure for each video to the output directory.

**Set your desired parameters using the interactive form below.**
*   **`final_confidence_threshold`**: The value you determined in the interactive step.
*   **`analysis_mode`**: Choose "Fast Inference (skips frames)" to speed up the analysis, or "Full Video (accurate, slow)" to process every frame.
*   **`inference_fps`**: Used only with fast inference. This is how many frames per second the model will analyze. A lower number is faster.
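
As a rough illustration of how `inference_fps` reduces the workload, the sketch below converts a target inference rate into a frame stride. `frame_indices_to_analyze` is a hypothetical helper, and the rounding rule is an assumption; the notebook's `analyze_video` may select frames differently.

```python
from typing import List, Optional


def frame_indices_to_analyze(video_fps: float, total_frames: int,
                             inference_fps: Optional[float] = None) -> List[int]:
    """Return the frame indices that would be passed to the detector."""
    if inference_fps is None or inference_fps >= video_fps:
        stride = 1  # full-video mode: analyze every frame
    else:
        # Assumption: skip a fixed number of frames between analyzed frames.
        stride = max(1, round(video_fps / inference_fps))
    return list(range(0, total_frames, stride))


# A 10-second clip at 30 fps, analyzed at 5 fps.
indices = frame_indices_to_analyze(video_fps=30.0, total_frames=300, inference_fps=5)
print(len(indices))  # 50
print(indices[:3])   # [0, 6, 12]
```

In this example the detector runs on 50 frames instead of 300, at the cost of a coarser time resolution of about 0.2 s between analyzed frames.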

In [None]:
#@title 6.1 Run video inference
#@markdown This cell runs the final analysis with the confidence threshold and settings entered below. Copy the threshold you found with the slider in step 5 into the field here.

#@markdown ---
#@markdown ### 1. Analysis & Output Configuration
#@markdown Set this to the value you chose using the interactive slider.
final_confidence_threshold = 0.75 #@param {type:"slider", min:0.05, max:0.99, step:0.01}

#@markdown Select the analysis speed. "Fast" is less accurate but much faster.
analysis_mode = "Fast Inference (skips frames)" #@param ["Full Video (accurate, slow)", "Fast Inference (skips frames)"]

#@markdown **Note:** The FPS slider is only used if you select "Fast Inference".
inference_fps = 5 #@param {type:"slider", min:1, max:30, step:1}

#@markdown Set the directory and filename for your results.
OUTPUT_DIR = "analysis_results" #@param {type:"string"}
OUTPUT_JSON_FILENAME = "video_face_analysis_results.json" #@param {type:"string"}

#@markdown Select which output files you want to generate.
generate_json_file = True #@param {type:"boolean"}
generate_boris_file = True #@param {type:"boolean"}
generate_timeline_figure = True #@param {type:"boolean"}
#---------------------------------------------------------------------------------

# --- Setup from Parameters ---
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
output_json_path = Path(OUTPUT_DIR) / OUTPUT_JSON_FILENAME

if "Fast Inference" in analysis_mode:
    inference_fps_to_use = inference_fps
else:
    inference_fps_to_use = None

# Create the master results dictionary, storing the chosen parameters for reproducibility
master_results = {
    "parameters": {
        "confidence_threshold": final_confidence_threshold,
        "nms_threshold": DEFAULT_NMS_THRESHOLD,
        "track_confidence_threshold": DEFAULT_TRACK_CONF_THRESHOLD,
        "num_tracked_ids": DEFAULT_NUM_TRACKED_IDS,
        "inference_fps": inference_fps_to_use,
        "target_processing_height": target_processing_height
    },
    "videos": []
}

# --- Main Processing Loop ---
if 'analysis_jobs' in locals() and analysis_jobs:
    if 'mmdet_model' not in globals() or mmdet_model is None:
        print("❌ Error: MMDetection model is not loaded. Please run the model loading cell first.")
    else:
        for job in analysis_jobs:
            original_video_path = job['original_path']
            processing_video_path = job['processing_path']
            temp_file_to_delete = job['temp_file_to_delete']

            print("\n" + "="*50)
            print(f"Analyzing Video: {Path(original_video_path).name}")
            if temp_file_to_delete:
                print(f"--> Using temporary downsampled version: {Path(processing_video_path).name}")

            video_analysis_data = analyze_video(
                video_path=processing_video_path,
                original_video_path=original_video_path,
                model=mmdet_model,
                model_classes=model_class_names,
                confidence_thresh=final_confidence_threshold,
                nms_thresh=DEFAULT_NMS_THRESHOLD,
                device=MMDET_DEVICE,
                inference_fps=inference_fps_to_use,
                tracker_config=DEFAULT_TRACKER_CONFIG,
                track_thr=DEFAULT_TRACK_CONF_THRESHOLD,
                num_tracked_ids=DEFAULT_NUM_TRACKED_IDS
            )

            if video_analysis_data:
                master_results["videos"].append(video_analysis_data)
                if generate_timeline_figure:
                    print(f"\nGenerating timeline figure for {Path(original_video_path).name}...")
                    visualizer = VideoAnalysisVisualizer(
                        original_video_path,
                        video_analysis_data,
                        OUTPUT_DIR,
                        thickness=bbox_thickness
                    )
                    visualizer.plot_timeline()
                    visualizer.close()
            else:
                print(f"Analysis failed or produced no data for {original_video_path}")

            # Clean up the temporary file for this job after it's processed
            if temp_file_to_delete:
                print(f"Cleaning up temporary file: {temp_file_to_delete}")
                try:
                    os.unlink(temp_file_to_delete)
                except OSError as e:
                    print(f"  Error removing temp file: {e}")

        # --- Final Save Data Outputs ---
        if master_results["videos"]:
            print("\n--- Saving Output Files ---")
            if generate_json_file:
                save_results_to_json(master_results, str(output_json_path))
            if generate_boris_file:
                save_boris_project(master_results, OUTPUT_DIR)
        else:
            print("No videos were successfully processed. No output files saved.")
else:
    print("❌ No analysis jobs found. Please run the data loading and pre-processing cells first.")

print("\n--- ✅ Processing Finished ---")

## **7. Preview the Output Files**

Let's take a look at the files we just created in the `analysis_results` directory. This will give you a clear idea of what each file contains.
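
Beyond previewing the files, the JSON results can be summarized directly. The sketch below computes, for each video, the fraction of analyzed frames that contain at least one detection. It assumes the layout used elsewhere in this notebook (a `videos` list whose entries hold `video_filepath` and `detailed_detections`); `summarize_visibility` is a hypothetical helper and the data is made up for illustration.

```python
from typing import Dict


def summarize_visibility(results: Dict) -> Dict[str, float]:
    """Fraction of analyzed frames with at least one detection, per video."""
    summary = {}
    for video in results.get("videos", []):
        frames = video.get("detailed_detections", [])
        if not frames:
            continue  # nothing was analyzed for this video
        with_face = sum(1 for f in frames if f.get("detections"))
        summary[video.get("video_filepath", "unknown")] = with_face / len(frames)
    return summary


# Made-up results in the same shape (illustration only).
example_results = {
    "videos": [{
        "video_filepath": "demo.mp4",
        "detailed_detections": [
            {"timestamp_sec": 0.0, "detections": []},
            {"timestamp_sec": 0.2, "detections": [{"bbox": [10, 10, 60, 60], "score": 0.9}]},
            {"timestamp_sec": 0.4, "detections": [{"bbox": [12, 11, 61, 62], "score": 0.8}]},
            {"timestamp_sec": 0.6, "detections": []},
        ],
    }]
}

print(summarize_visibility(example_results))  # {'demo.mp4': 0.5}
```

To run it on real output, load the saved file with `json.load` and pass the resulting dictionary to the function.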

In [None]:
#@title Display files

import json
from rich.console import Console
from rich.table import Table
from rich.syntax import Syntax
from rich.panel import Panel

# Create a console object for printing
console = Console()

# --- Preview the Detailed Analysis JSON file ---
if generate_json_file and 'output_json_path' in locals() and output_json_path.exists():

    panel_title = f"Preview of Detailed Analysis File: [bold cyan]{output_json_path.name}[/bold cyan]"

    with open(output_json_path, 'r') as f:
        analysis_data = json.load(f)

    # --- Parameters Table ---
    param_table = Table(title="Analysis Parameters", show_header=True, header_style="bold magenta")
    param_table.add_column("Parameter", style="dim")
    param_table.add_column("Value")

    for key, val in analysis_data.get("parameters", {}).items():
        param_table.add_row(key, str(val))

    # --- Example Detection Panel ---
    first_video = analysis_data["videos"][0] if analysis_data.get("videos") else None
    first_detection_json = "{}"
    if first_video:
        first_detection = next((d for d in first_video.get('detailed_detections', []) if d['detections']), None)
        if first_detection:
            first_detection_json = json.dumps(first_detection, indent=2)

    json_preview = Syntax(first_detection_json, "json", theme="monokai", line_numbers=True)

    # Render the analysis preview in a panel
    console.print(Panel.fit(param_table, title=panel_title, border_style="green"))
    console.print(Panel(json_preview, title="[bold green]Example of First Detected Frame Data[/bold green]", border_style="green"))


# --- Preview the BORIS Project file ---
if generate_boris_file and 'master_results' in locals() and master_results.get("videos"):
    first_video_path = Path(master_results["videos"][0]["video_filepath"])
    boris_file_path = Path(OUTPUT_DIR) / f"{first_video_path.stem}.boris"

    if boris_file_path.exists():
        panel_title = f"Preview of BORIS Project File: [bold cyan]{boris_file_path.name}[/bold cyan]"

        with open(boris_file_path, 'r') as f:
            boris_data = json.load(f)

        # Use the first observation if there is one; otherwise show an empty table
        observations = boris_data.get('observations', {})
        obs_key = next(iter(observations), None)
        events = observations[obs_key].get('events', []) if obs_key is not None else []

        # --- Events Table ---
        events_table = Table(title="First 10 'face present' Events", show_header=True, header_style="bold magenta")
        events_table.add_column("Timestamp", justify="right")
        events_table.add_column("Behavior")
        events_table.add_column("Event Type", style="dim")

        if events:
            for event in events[:10]:
                events_table.add_row(f"{event[0]:.2f}s", event[2], event[5])

        # Render the BORIS preview in a panel
        console.print(Panel.fit(events_table, title=panel_title, border_style="blue"))



## **8. Summary and Next Steps**

Congratulations! You have successfully completed the automated video timestamping process.

**Here's a recap of what we accomplished:**
1.  We loaded a pre-trained `PrimateFace` detection model.
2.  We used an interactive widget to visually determine the best detection **confidence threshold** for your specific video(s).
3.  We ran a full analysis using your chosen threshold to identify every analyzed frame containing a face (in fast-inference mode, only the sampled frames are analyzed).
4.  We generated two key output files, which are now saved in the `analysis_results` directory:
    *   `video_face_analysis_results.json`: A detailed data file containing the bounding box, score, and timestamp for every single detection.
    *   `[your_video_name].boris`: A BORIS project file, ready for manual annotation.

**Your next step is to open the generated `.boris` file in the BORIS application.** You will find that the timeline already contains a "face present" state event marking the periods where a face was detected. If BORIS cannot find the video file(s), update their location, typically via **Observation > Edit observation**.

You can now use this as a foundation to add your own, more detailed behavioral codes, saving valuable annotation time.
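
A state event is a pair of start and stop times, so per-frame detections have to be merged into intervals at some point. The sketch below shows one simple way to do that. `visibility_intervals` and its `max_gap_sec` parameter are hypothetical; the notebook's `save_boris_project` may use a different rule.

```python
from typing import List, Tuple


def visibility_intervals(timestamps: List[float], visible: List[bool],
                         max_gap_sec: float = 0.5) -> List[Tuple[float, float]]:
    """Merge visible frames into (start, stop) intervals.

    Visible frames separated by at most max_gap_sec are joined, which also
    bridges short dropouts where the detector missed a frame or two.
    """
    intervals = []
    start = last = None
    for t, is_visible in zip(timestamps, visible):
        if not is_visible:
            continue
        if start is None:
            start = last = t          # open the first interval
        elif t - last <= max_gap_sec:
            last = t                  # extend the current interval
        else:
            intervals.append((start, last))
            start = last = t          # gap too long: open a new interval
    if start is not None:
        intervals.append((start, last))
    return intervals


timestamps = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
visible = [False, True, True, True, False, False, False, False, True, True, False]
print(visibility_intervals(timestamps, visible))  # [(0.2, 0.6), (1.6, 1.8)]
```

A larger `max_gap_sec` produces fewer, longer events; a smaller one follows the detector more literally.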

## **9. Resources**
1. [PrimateFace](https://github.com/KordingLab/PrimateFace)
2. [mmdetection](https://github.com/open-mmlab/mmdetection)
3. [roboflow](https://roboflow.com/)