
# PrimateFace Tutorial: Gaze-Following Heuristic

| GitHub Repo | Paper | Project Page |
|---|---|---|
| [PrimateFace](https://github.com/PrimateFace/PrimateFace) | [PrimateFace](https://arxiv.org/abs/10000) | [PrimateFace](https://primateface.github.io/) |

Welcome! This tutorial notebook demonstrates a gaze-following heuristic for videos containing two primates. It uses models from the **PrimateFace** and **Gazelle** projects to analyze social attention dynamics.

### This notebook will:
1.  **Set up the environment** with all necessary deep learning libraries (mmdetection, mmpose, gazelle).
2.  **Download pre-trained models** for face detection and gaze estimation.
3.  **Process a two-primate video**: It detects and tracks both individuals, then estimates the gaze direction for one primate (the "gazer").
4.  **Engineer Features**: It calculates features like the gazer's head yaw, gaze uncertainty (represented as a cone), and the relative position of the second primate (the "follower").
5.  **Train & Evaluate a Gaze-Following Model**: It trains several classifiers (e.g., Logistic Regression, SVM) to predict whether the follower will turn its head toward the gazer's direction of attention within a short time window.
6.  **Visualize Results**: It generates several plots, including Precision-Recall curves for the classifiers and qualitative examples of the gaze analysis.
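
The gaze-uncertainty cone in step 4 comes down to simple geometry. The sketch below makes some assumptions: the gazer's gaze is summarized by a 2D origin (e.g., the face-box center), a gaze direction in image coordinates, and a cone half-angle derived from the gaze uncertainty. From these it computes the follower's angular offset from the gaze ray and whether the follower falls inside the cone. `gaze_cone_features` is a hypothetical helper, not a PrimateFace or Gazelle API, and the notebook's own feature code may define these quantities differently.

```python
import numpy as np

def gaze_cone_features(origin, gaze_dir, target, half_angle_deg):
    """Hypothetical helper: angular offset of `target` from a gaze ray plus a cone-membership flag.

    origin, gaze_dir, target: (x, y) pairs in image pixel coordinates.
    half_angle_deg: half-width of the gaze-uncertainty cone, in degrees.
    """
    origin = np.asarray(origin, dtype=float)
    gaze_dir = np.asarray(gaze_dir, dtype=float)
    to_target = np.asarray(target, dtype=float) - origin

    # Signed angle from the gaze direction to the target bearing, wrapped to [-180, 180).
    # In image coordinates (y pointing down), positive offsets are clockwise on screen.
    gaze_angle = np.degrees(np.arctan2(gaze_dir[1], gaze_dir[0]))
    target_angle = np.degrees(np.arctan2(to_target[1], to_target[0]))
    offset = (target_angle - gaze_angle + 180.0) % 360.0 - 180.0

    return {
        "offset_deg": float(offset),
        "inside_cone": bool(abs(offset) <= half_angle_deg),
        "distance_px": float(np.linalg.norm(to_target)),
    }

# Gazer at (100, 100) looking along +x; follower at (200, 130); 20-degree half-angle cone.
feats = gaze_cone_features((100, 100), (1, 0), (200, 130), half_angle_deg=20)
print(feats)
```

Only the magnitude of the offset matters for the cone test. The sign is kept because it tells a downstream feature which side of the gaze ray the follower is on.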

---
### **Quick Start Instructions**

*   **Set Your Runtime to GPU**: Go to **Runtime > Change runtime type > T4 GPU**. This is essential for performance.
*   **Run Cells Sequentially**: Click the "Play" button on each cell to run it. The first few cells handle setup and may take a few moments.
*   **Restarting**: Installing the dependencies requires restarting the runtime. Open the **Runtime** menu, select **Restart session**, and follow the on-screen prompts.

This notebook provides a complete pipeline from video to analysis. Let's begin!

## **1. Set up the environment (this takes a few minutes)**


In [None]:
#@title 1.1 Check GPU availability
# If this command fails, go to Runtime > Change runtime type and select "T4 GPU".
!nvidia-smi

In [None]:
#@title 1.2 Install Core Dependencies
%%capture
# Uninstall conflicting packages first to ensure compatibility
!pip uninstall -y fastai spacy thinc pymc pytensor jax jaxlib yfinance

# Install the scientific computing stack and video processing libraries
!pip install --no-cache-dir "numpy<1.24" "pandas<2.0" "opencv-python-headless<4.9" "scikit-learn<1.4" \
               "matplotlib<3.8" "scipy<1.12" moviepy==1.0.3 imageio imageio-ffmpeg tqdm filterpy

# Install PyTorch for the correct CUDA version
!pip install torch==2.1.0+cu118 torchvision==0.16.0+cu118 --index-url https://download.pytorch.org/whl/cu118

In [None]:
#@title 1.3 Install MM and Gazelle Libraries
%%capture
# Install open-mmlab libraries
%pip install -U openmim
!mim install "mmengine==0.10.3"
!mim install "mmcv==2.1.0"

# Install mmdetection
!rm -rf mmdetection
!git clone https://github.com/open-mmlab/mmdetection.git
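%cd mmdetection
%pip install -e .
# Return to the parent directory so later downloads and outputs are not written inside mmdetection/
%cd ..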

# Install mmpose
%pip install -q "mmpose==1.3.1"

# Install Gazelle
%pip install -q "gazelle-gaze==0.0.3"

### Now **restart** the session.

To ensure all libraries are loaded correctly, we need to restart the Colab runtime.

1.  Click **Runtime** > **Restart session**.
2.  After restarting, run the setup cells (1.1, 1.2, 1.3) again.
3.  Once the setup is complete, proceed with the rest of the notebook.
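
After the restart, you can confirm that the pinned versions are the ones Python will load by querying the installed package metadata. This convenience sketch uses only the standard library (`importlib.metadata`). The package list mirrors the pins in cells 1.2 and 1.3. A missing package is reported as `None` instead of raising an error.

```python
from importlib.metadata import version, PackageNotFoundError

# Distribution names and versions as pinned by pip in cells 1.2 and 1.3.
PINNED = {
    "torch": "2.1.0+cu118",
    "mmengine": "0.10.3",
    "mmcv": "2.1.0",
    "mmpose": "1.3.1",
    "gazelle-gaze": "0.0.3",
}

def check_versions(pins):
    """Return {package: (installed_version_or_None, matches_pin)} for each pinned package."""
    report = {}
    for pkg, expected in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None  # not installed in this environment
        report[pkg] = (installed, installed == expected)
    return report

for pkg, (installed, ok) in check_versions(PINNED).items():
    status = "OK" if ok else "CHECK"
    print(f"{status:5} {pkg}: installed={installed}, expected={PINNED[pkg]}")
```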

In [None]:
#@title 1.4 Import Libraries
# This cell imports all the necessary python libraries for the analysis.
from __future__ import annotations
import argparse, os, math, warnings, itertools
from pathlib import Path
from collections import deque

import cv2, torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from scipy.signal import butter, filtfilt
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
from sklearn.metrics import precision_recall_curve, average_precision_score, r2_score, mean_absolute_error
from sklearn.svm import LinearSVC, SVC, SVR
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.base import clone
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline


from mmdet.apis import init_detector, inference_detector
from mmdet.structures import DetDataSample
from mmdet.models.trackers import ByteTracker
from mmpose.evaluation.functional import nms
from gazelle.model import get_gazelle_model

import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend
warnings.filterwarnings("ignore")

print("✅ Libraries imported successfully.")

## **2. Download Models & Data**

This section downloads the pre-trained models required for face detection and gaze estimation.

In [None]:
#@title 2.1 Download Pre-trained Models
import gdown

# --- Model Paths (Set to download from Google Drive) --- #
DET_CONFIG_PATH = "primateface_cascade-rcnn_r101_fpn_1x_coco.py"
DET_CHECKPOINT_PATH = "best_coco_bbox_mAP_epoch_12.pth"
GAZELLE_CHECKPOINT_PATH = "gazelle_dinov2_vitl14.pt"

# GDrive file IDs
det_config_id = "1Y_YFdIDRcWQLI-gRiCnOrDxCptzCiiNp"
det_checkpoint_id = "1zZ8S31zPHX5BWYKbnHxI1QOqP-fPnVFO"
gazelle_checkpoint_id = "1_wZ3V5yY4n0Z0X2Z0oQp-J0yXk8n_ZqY" # Example ID, replace with actual

# Download function
def download_gdrive(file_id, output_name):
    if not os.path.exists(output_name):
        print(f"Downloading {output_name}...")
        gdown.download(id=file_id, output=output_name, quiet=False)
    else:
        print(f"{output_name} already exists. Skipping download.")

print("--- Downloading Models ---")
download_gdrive(det_config_id, DET_CONFIG_PATH)
download_gdrive(det_checkpoint_id, DET_CHECKPOINT_PATH)
download_gdrive(gazelle_checkpoint_id, GAZELLE_CHECKPOINT_PATH)
print("\n✅ All models downloaded.")

# --- Other Configurations ---
GAZELLE_MODEL_NAME = "gazelle_dinov2_vitl14"
DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DEFAULT_DET_THRESHOLD = 0.1

## **3. Load Your Video**
Choose the video you want to analyze. The default option uses a short demo clip.


In [None]:
#@title 3.1 Configure Video Source
from google.colab import drive
import gdown

# --- Form Parameters ---
data_source_option = "Use Demo Video"  #@param ["Use Demo Video", "Use a Google Drive Link", "Use a Path from Mounted Google Drive"]

#@markdown ---
#@markdown ### For Option 2: Provide a Google Drive Link
video_gdrive_link = "PASTE_YOUR_GOOGLE_DRIVE_LINK_HERE"  #@param {type:"string"}

#@markdown ---
#@markdown ### For Option 3: Provide a Path from Your Mounted Drive
mounted_drive_path = "/content/drive/MyDrive/my_videos/gaze_video.mp4"  #@param {type:"string"}
mount_my_drive = False  #@param {type:"boolean"}
#---------------------------------------------------------------------------------

VIDEO_FILE_PATH = ""

try:
    if data_source_option == "Use Demo Video":
        print("▶️ Using the demo video.")
        demo_video_id = "1U-_dSnbt7KrigPhjuVPlKE_vgTpEPHQ4" # Example ID
        demo_video_filename = "gaze_following_demo_video.mp4"
        download_gdrive(demo_video_id, demo_video_filename)
        VIDEO_FILE_PATH = demo_video_filename

    elif data_source_option == "Use a Google Drive Link":
        print("▶️ Using a Google Drive link.")
        if "PASTE_YOUR" in video_gdrive_link:
            raise ValueError("Please paste your GDrive link.")
        linked_video_filename = "user_gdrive_video.mp4"
        gdown.download(video_gdrive_link, linked_video_filename, quiet=False, fuzzy=True)
        VIDEO_FILE_PATH = linked_video_filename

    elif data_source_option == "Use a Path from Mounted Google Drive":
        print("▶️ Using a path from mounted Google Drive.")
        if mount_my_drive:
            drive.mount('/content/drive')
        if not Path(mounted_drive_path).exists():
            raise FileNotFoundError(f"Path not found: {mounted_drive_path}")
        VIDEO_FILE_PATH = mounted_drive_path

    if not os.path.exists(VIDEO_FILE_PATH):
        raise FileNotFoundError(f"Final video path is invalid: {VIDEO_FILE_PATH}")

    print(f"\n✅ Success! Video ready for analysis:\n- {VIDEO_FILE_PATH}")

except (ValueError, FileNotFoundError) as e:
    print(f"\n❌ Error: {e}")

## **4. Helper Functions & Analysis Pipeline**
This cell contains the core logic for video processing, feature engineering, and visualization. You can run it and keep it collapsed.
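
Cell 5.1 calls `filter_continuous_segments(df_raw, min_lookback, min_lookahead_A, min_lookahead_B)` to keep only interactions with enough uninterrupted tracking. For illustration only, here is one way to express such a continuity rule as a per-frame mask. A frame qualifies if both individuals were tracked for at least `lookback` consecutive frames ending there, and the gazer (A) and follower (B) each remain tracked for the required number of frames starting there. `continuity_mask` and its helpers are hypothetical and may not match the notebook's actual filter.

```python
import numpy as np

def run_lengths_backward(present):
    """For each frame t, the number of consecutive True values ending at t (inclusive)."""
    out = np.zeros(len(present), dtype=int)
    count = 0
    for t, p in enumerate(present):
        count = count + 1 if p else 0
        out[t] = count
    return out

def run_lengths_forward(present):
    """For each frame t, the number of consecutive True values starting at t (inclusive)."""
    return run_lengths_backward(np.asarray(present)[::-1])[::-1]

def continuity_mask(a_present, b_present, lookback, lookahead_a, lookahead_b):
    """Hypothetical rule: both tracked for `lookback` frames up to t,
    A tracked for `lookahead_a` frames from t, and B tracked for `lookahead_b` frames from t."""
    a = np.asarray(a_present, dtype=bool)
    b = np.asarray(b_present, dtype=bool)
    both = a & b
    return (
        (run_lengths_backward(both) >= lookback)
        & (run_lengths_forward(a) >= lookahead_a)
        & (run_lengths_forward(b) >= lookahead_b)
    )

# Gazer A drops out at frame 6; follower B is tracked throughout.
a = np.array([1, 1, 1, 1, 1, 1, 0, 1], dtype=bool)
b = np.ones(8, dtype=bool)
print(continuity_mask(a, b, lookback=2, lookahead_a=2, lookahead_b=2).astype(int))
```

Expressing the rule as run lengths keeps it vectorizable over a whole track. Frame 0 fails the look-back requirement, and frames near A's dropout fail the look-ahead requirement.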


In [None]:
#@title 4.1 Define Helper Functions & Core Pipeline
# This cell contains all the functions from the original script.
# It is collapsed by default for clarity.

## **5. Configure and Run Full Analysis**
This is the main step. Adjust the parameters below and run the cell to perform the full analysis, including feature engineering, model training, and visualization.
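
The `turn_thresh_degrees` and `horizon_frames` parameters below define what counts as a gaze-follow event, and the notebook applies them in `label_follow_events`. The following is only a minimal sketch of one plausible rule. It assumes per-frame follower head-yaw angles and a per-frame bearing (in degrees) from the follower toward the gazer's gaze target. Frame `t` is labeled a follow if, within the next `horizon` frames, the follower's yaw moves at least `turn_thresh` degrees in the direction of that bearing. `label_follows_sketch` and its inputs are hypothetical.

```python
import numpy as np

def wrap_deg(angle):
    """Wrap angles in degrees to [-180, 180)."""
    return (np.asarray(angle, dtype=float) + 180.0) % 360.0 - 180.0

def label_follows_sketch(follower_yaw, target_bearing, horizon, turn_thresh):
    """Hypothetical follow rule.

    follower_yaw:   follower head yaw per frame (degrees).
    target_bearing: bearing from the follower to the gazer's gaze target per frame (degrees).
    Frame t is a follow if some yaw change yaw[t+k] - yaw[t], k in 1..horizon, points the
    same way as the turn needed to face the target and has magnitude >= turn_thresh.
    Frames without a full look-ahead window are labeled False.
    """
    yaw = np.asarray(follower_yaw, dtype=float)
    bearing = np.asarray(target_bearing, dtype=float)
    n = len(yaw)
    labels = np.zeros(n, dtype=bool)
    for t in range(n - horizon):
        needed = wrap_deg(bearing[t] - yaw[t])                  # signed turn toward the target
        change = wrap_deg(yaw[t + 1 : t + horizon + 1] - yaw[t])  # yaw changes in the window
        toward = np.sign(change) == np.sign(needed)
        labels[t] = bool(np.any(toward & (np.abs(change) >= turn_thresh)))
    return labels

# The follower turns from 0 to 25 degrees toward a target at a 40-degree bearing.
yaw = [0, 0, 5, 15, 25, 25, 25, 25]
print(label_follows_sketch(yaw, [40] * 8, horizon=3, turn_thresh=20))
```

A larger `horizon` or a smaller `turn_thresh` labels more frames as follows. This is why both parameters are worth sweeping when comparing classifier performance.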

In [None]:
#@title 5.1 Run Analysis
#@markdown ### 1. Analysis Parameters
#@markdown Heuristic settings for defining a gaze-follow event.
det_thr = 0.5 #@param {type:"slider", min:0.1, max:0.9, step:0.05}
turn_thresh_degrees = 20 #@param {type:"slider", min:5, max:45, step:5}
horizon_frames = 15 #@param {type:"slider", min:5, max:30, step:1}

#@markdown ### 2. Data Continuity Filter
#@markdown Minimum consecutive frames required for an interaction to be analyzed.
min_lookback_frames = 5 #@param {type:"slider", min:1, max:15, step:1}
min_lookahead_A_frames = 5 #@param {type:"slider", min:1, max:15, step:1}

#@markdown ### 3. Output Configuration
output_dir = "gaze_following_output" #@param {type:"string"}
run_training = True #@param {type:"boolean"}
generate_plots = True #@param {type:"boolean"}
generate_video = True #@param {type:"boolean"}

#---------------------------------------------------------------------------------

if not os.path.exists(VIDEO_FILE_PATH):
    print("❌ Video path not set. Please run Cell 3.1 successfully.")
else:
    # Create config object from parameters
    config = {
        "video": VIDEO_FILE_PATH,
        "output_dir": output_dir,
        "device": DEFAULT_DEVICE,
        "det_thr": det_thr,
        "horizon": horizon_frames,
        "min_lookback": min_lookback_frames,
        "min_lookahead_A": min_lookahead_A_frames
    }

    os.makedirs(config['output_dir'], exist_ok=True)
    print(f"Outputs will be saved to: {config['output_dir']}")

    # --- Run Pipeline ---
    df_raw, fps = process_video(config['video'], config)

    min_lookahead_B_val = max(config['horizon'], 1)
    df_filtered = filter_continuous_segments(df_raw, config['min_lookback'], config['min_lookahead_A'], min_lookahead_B_val)

    if df_filtered.empty:
        print("\n❌ No continuous interaction segments found with the current filter settings. Analysis cannot proceed.")
    else:
        df_features = engineer_advanced_features(df_filtered.copy(), fps)
        X, y = label_follow_events(df_features, fps, horizon=config['horizon'], turn_thresh=turn_thresh_degrees)

        results = {}  # Ensure the results dict exists even if training is skipped
        if run_training and X.shape[0] > 0 and len(np.unique(y)) > 1:
            results = train_eval(X, y)
            if generate_plots:
                pr_curve_plot(results, y, os.path.join(config['output_dir'], 'pr_curve.png'))
        elif not run_training:
            print("\nSkipping model training as requested.")
        else:
            print("\nSkipping model training due to insufficient data after filtering.")

        if generate_plots:
            plot_gaze_directionality(df_filtered, os.path.join(config['output_dir'], 'gaze_directionality.png'))
            plot_gaze_eccentricity(df_filtered, os.path.join(config['output_dir'], 'gaze_eccentricity.png'))
            plot_dynamics_with_follow_events(df_filtered, X, y, config['horizon'], os.path.join(config['output_dir'], 'dynamics_events.png'))
            analyze_gaze_cross_correlation(df_features, fps, os.path.join(config['output_dir'], 'cross_correlations'))
            # Plot one example sequence for each class (follow / no-follow).
            # Cast labels to bool so `~` is a logical NOT even if y holds 0/1 integers.
            y_bool = np.asarray(y).astype(bool)
            true_indices = np.where(y_bool)[0]
            if len(true_indices) > 0:
                plot_gaze_sequence([df_filtered.index[true_indices[0]]], config['video'], df_filtered, y, os.path.join(config['output_dir'], 'sequence_follow'))
            false_indices = np.where(~y_bool)[0]
            if len(false_indices) > 0:
                plot_gaze_sequence([df_filtered.index[false_indices[0]]], config['video'], df_filtered, y, os.path.join(config['output_dir'], 'sequence_nofollow'))
            # Attention-cloud plots
            generate_attention_cloud(df_filtered, 0, config['video'], config['output_dir'])
            generate_combined_attention_cloud(df_filtered, config['video'], config['output_dir'])


        if generate_video:
            generate_qualitative_gaze_video(config['video'], df_filtered, y, os.path.join(config['output_dir'], 'gaze_following_demo_annotated.mp4'))

    print("\n--- ✅ Processing Finished ---")

## **6. Visualize Results**
The plots generated in the previous step are displayed below for review.
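
`pr_curve.png` compares classifiers by their precision-recall trade-off. Such curves are usually summarized by average precision (AP), which cell 1.4 imports from scikit-learn as `average_precision_score`. Whether `pr_curve_plot` annotates AP is an assumption here. To make the summary number concrete, the sketch below re-implements the step-wise definition `AP = sum over n of (R_n - R_{n-1}) * P_n` in NumPy. It is for illustration only; in practice, use the scikit-learn function.

```python
import numpy as np

def average_precision(y_true, scores):
    """Step-wise average precision: sum over thresholds of (R_n - R_{n-1}) * P_n.

    Tied scores are treated as a single threshold, so tied samples enter together.
    """
    y = np.asarray(y_true, dtype=bool)
    s = np.asarray(scores, dtype=float)
    if not y.any():
        raise ValueError("AP is undefined without positive labels.")
    order = np.argsort(-s, kind="mergesort")  # highest score first
    y, s = y[order], s[order]
    tp = np.cumsum(y)
    fp = np.cumsum(~y)
    # Keep the last index of each run of tied scores (one PR point per distinct threshold).
    last = np.r_[np.flatnonzero(np.diff(s) != 0), len(s) - 1]
    precision = tp[last] / (tp[last] + fp[last])
    recall = tp[last] / y.sum()
    recall_prev = np.r_[0.0, recall[:-1]]
    return float(np.sum((recall - recall_prev) * precision))

print(average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

A perfect ranking gives AP = 1.0. A classifier whose scores are unrelated to the labels tends toward the positive-class rate, which is a useful baseline when follow events are rare.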


In [None]:
#@title 6.1 Display Output Figures
from IPython.display import Image, display

# Create a list of all generated PNG files
figure_dir = Path(output_dir)
png_files = list(figure_dir.glob('**/*.png'))

if not png_files:
    print("No plot images found. Please run the full analysis cell (5.1) with 'generate_plots' enabled.")
else:
    print(f"Displaying {len(png_files)} generated plots from the '{output_dir}' directory:")
    for img_path in sorted(png_files):
        print(f"\n--- {img_path.name} ---")
        display(Image(filename=str(img_path), width=700))


## **7. Summary and Next Steps**

Congratulations! You've run an end-to-end pipeline, from raw video to trained classifiers, to analyze gaze-following behavior.

**Here's a recap of what we accomplished:**
1.  We loaded a video and used a face detector with a tracker to get stable IDs for two primates.
2.  We applied the **Gazelle** model to estimate gaze direction and uncertainty for each individual.
3.  We trained multiple machine learning models to predict whether a "follower" primate would turn its head in response to a "gazer's" look.
4.  We generated several visualizations to help interpret the results:
    *   **Gaze Sequence Plots**: Frame-by-frame examples of "follow" and "no-follow" events.
    *   **Precision-Recall Curve**: Shows the performance trade-off for the different classifiers.
    *   **Gaze Dynamics**: Time-series plots showing how gaze direction and uncertainty change over time.
    *   **Attention Clouds**: Heatmaps visualizing the accumulated gaze direction of each primate.
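
The attention clouds summarize where each primate's gaze lands over the whole video. The sketch below assumes the gaze model's per-frame output can be expressed as a non-negative 2D heatmap of fixed shape, optionally paired with an "in-frame" probability. It accumulates those heatmaps into a single normalized cloud. `accumulate_attention` is hypothetical, and the notebook's `generate_attention_cloud` may weight, filter, or smooth frames differently.

```python
import numpy as np

def accumulate_attention(heatmaps, inout_scores=None, min_inout=0.5):
    """Hypothetical attention-cloud accumulator.

    heatmaps: iterable of non-negative 2D arrays (H, W), one per frame.
    inout_scores: optional per-frame probability that the gaze target is in-frame;
                  frames below `min_inout` are skipped.
    Returns an (H, W) array that sums to 1, or None if no frame was kept.
    """
    cloud = None
    for i, hm in enumerate(heatmaps):
        if inout_scores is not None and inout_scores[i] < min_inout:
            continue
        hm = np.asarray(hm, dtype=float)
        hm = hm / hm.sum() if hm.sum() > 0 else hm  # give each frame equal total weight
        cloud = hm.copy() if cloud is None else cloud + hm
    if cloud is None:
        return None
    total = cloud.sum()
    return cloud / total if total > 0 else cloud

frames = [
    np.array([[1.0, 0.0], [0.0, 0.0]]),
    np.array([[0.0, 0.0], [0.0, 2.0]]),
    np.array([[0.0, 3.0], [0.0, 0.0]]),
]
cloud = accumulate_attention(frames, inout_scores=[0.9, 0.8, 0.1])  # third frame is skipped
print(cloud)
```

Normalizing each frame before summing keeps a few frames with very peaked heatmaps from dominating the cloud.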

**Next Steps:**
*   Explore the `gaze_following_output` directory to find all the generated plots and the annotated demo video.
*   Experiment with the parameters in **Cell 5.1**, such as the `turn_thresh_degrees` or `horizon_frames`, to see how they affect model performance.
*   Adapt this notebook to run on your own two-primate videos by changing the video source in **Cell 3.1**.


## **8. Resources**
1. [PrimateFace](https://github.com/PrimateFace/PrimateFace)
2. [Gazelle](https://gazelle-gaze.github.io/)
3. [mmdetection](https://github.com/open-mmlab/mmdetection)