In [None]:
# Setup: Clone repository if running on Google Colab
import os

# IMPORTANT: Set Keras backend BEFORE any keras imports (including et_util)
os.environ["KERAS_BACKEND"] = "jax"

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Clone the repository to get et_util package
    if not os.path.exists('eye-tracking-training'):
        !git clone https://github.com/jspsych/eye-tracking-training.git
    
    # Add the repository to Python path so we can import et_util
    import sys
    repo_path = '/content/eye-tracking-training'
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)


# WebGazer vs Our Model Comparison

This notebook implements a WebGazer-style ridge regression algorithm on raw pixel data and compares its performance with our deep learning-based eye tracking model on the same dataset. Our model supports multiple backbone architectures (DenseNet, ViT, or Hybrid CNN-Transformer).

## Setup and Imports

In [None]:
!pip install osfclient --quiet
!pip install plotnine --quiet
!pip install wandb --quiet
!pip install keras-hub --quiet
!pip install python-dotenv --quiet
!pip install scikit-learn --quiet

In [None]:
import os
import gc
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from scipy import stats

import keras
from keras import ops
import keras_hub

import osfclient
from osfclient.api import OSF

import wandb
from wandb.integration.keras import WandbMetricsLogger

from plotnine import ggplot, geom_point, aes, geom_line, geom_histogram, geom_boxplot, geom_ribbon, scale_y_reverse, scale_y_continuous, theme_void, scale_x_continuous, scale_color_manual, scale_fill_manual, ylab, xlab, labs, theme, theme_classic, element_text, element_blank, element_line, facet_wrap, geom_abline, geom_smooth, stat_smooth

In [None]:
# Assume Colab until the Colab-only secrets helper turns out to be missing.
IN_COLAB = True
try:
    from google.colab import userdata
except ImportError:
    IN_COLAB = False

In [None]:
import et_util.dataset_utils as dataset_utils
from et_util.dataset_utils import parse_single_eye_tfrecord_with_phase as parse, rescale_coords_map_with_phase as rescale_coords_map
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.custom_layers import (
    SimpleTimeDistributed,
    MaskedWeightedRidgeRegressionLayer,
    MaskInspectorLayer,
    ResidualBlock,
    AddPositionalEmbedding,
)
from et_util.model_analysis import plot_model_performance
from et_util.inference_utils import create_flexible_inference_model

In [None]:
# Credentials come from Colab secrets on Colab, otherwise from a local .env file.
if IN_COLAB:
    os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
    os.environ['OSF_TOKEN'] = userdata.get('osftoken')
    os.environ['OSF_USERNAME'] = userdata.get('osfusername')
else:
    # Load from a .env file using python-dotenv
    from dotenv import load_dotenv

    # load_dotenv() exports the .env entries into os.environ itself, so no
    # re-assignment is needed. Missing values are reported instead of being
    # written back as empty strings, which would look "set" to later code.
    load_dotenv()
    for credential_name in ('WANDB_API_KEY', 'OSF_TOKEN', 'OSF_USERNAME'):
        if not os.environ.get(credential_name):
            print(f"Warning: {credential_name} is not set; add it to your .env file.")

os.environ['OSF_ANALYSIS_PROJECT_ID'] = "8ecx5"

# NOTE: KERAS_BACKEND is deliberately not reassigned here. Keras reads it once,
# at import time; the first cell sets it before keras and et_util are imported.
# Changing it after those imports would not switch the backend and would only
# contradict that earlier setting.

## Constants and Configuration

In [None]:
# --- Fixed data geometry ---
# MAX_TARGETS is the per-subject window, batch and padding length used when
# the masked dataset is built.
MAX_TARGETS = 288
IMAGE_HEIGHT, IMAGE_WIDTH = 36, 144
IMAGE_PIXELS = IMAGE_HEIGHT * IMAGE_WIDTH  # 5,184 pixels

# --- Model defaults (will be overridden by W&B config) ---
EMBEDDING_DIM = 200
RIDGE_REGULARIZATION = 0.001

# --- W&B run to load (serializable model project) ---
ENTITY_NAME = "vassar-cogsci-lab"
PROJECT_NAME = "eye-tracking-serializable-model"
EXPERIMENT_ID = "mck7rj6y"

# --- Label used for our model in plots and result tables ---
OUR_MODEL_NAME = "Our Model"

## Dataset Loading and Preprocessing

Reusing the exact same dataset loading pipeline from the original analysis to ensure fair comparison.

In [None]:
# Fetch the single-eye TFRecord archive from OSF project `uf2sh` into the
# current working directory (uses the osfclient CLI installed above).
!osf -p uf2sh fetch single_eye_tfrecords.tar.gz

100% 675M/675M [00:04<00:00, 160Mbytes/s]


In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:
# parse function already imported from et_util.dataset_utils at the top

In [None]:
# Load the dataset with phase-aware parsing.
# train_split=1.0 routes every record into the first returned dataset, which is
# the only one kept; the other two return values are discarded.
def _subject_key(img, phase, coords, subject_id):
    """Grouping key for the splitter: the record's subject id."""
    return subject_id


test_data, _, _ = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0, val_split=0.0, test_split=0.0,
    random_seed=12604,
    group_function=_subject_key,
)

In [None]:
# Rescale coordinates from 0-100 to 0-1 range using the imported function.
# The calibration grid is divided by 100 as well (scaled_cal_points), so that
# prepare_masked_dataset can compare the two for equality.
test_data_rescaled = test_data.map(rescale_coords_map)

In [None]:
# Calibration grid (same as original analysis): 4 x-positions by 5 y-positions,
# listed x-major, i.e. every y value for x=5 first, then x=35, and so on.
_CAL_X_POSITIONS = [5, 35, 65, 95]
_CAL_Y_POSITIONS = [5, 27.5, 50, 72.5, 95]

cal_points = tf.constant(
    [[x_pos, y_pos] for x_pos in _CAL_X_POSITIONS for y_pos in _CAL_Y_POSITIONS],
    dtype=tf.float32,
)

# The same grid divided by 100.
scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))

## WebGazer Implementation

Implementing the WebGazer algorithm: ridge regression on raw pixel intensities.

In [None]:
class WebGazerModel:
    """
    WebGazer-style baseline: ridge regression fitted directly on raw pixel intensities.

    Two sklearn ``Ridge`` estimators are held, one per gaze coordinate (x and y).
    Both are fitted on flattened calibration images and queried on flattened
    test images.
    """

    def __init__(self, lambda_ridge=0.001, normalize_pixels=True):
        """
        Args:
            lambda_ridge: ridge penalty, passed to sklearn as ``alpha``.
            normalize_pixels: when True, pixel values are cast to float32 and
                divided by 255 before fitting/predicting.
        """
        self.lambda_ridge = lambda_ridge
        self.normalize_pixels = normalize_pixels
        # One independent estimator per output axis.
        self.ridge_x, self.ridge_y = (
            Ridge(alpha=lambda_ridge, fit_intercept=True) for _ in range(2)
        )
        self.is_fitted = False

    def _preprocess_images(self, images):
        """
        Turn an image stack into a 2-D feature matrix.

        Args:
            images: numpy array of shape [n_samples, height, width] or
                [n_samples, height, width, 1]
        Returns:
            Array of shape [n_samples, height*width]; float32 scaled by 1/255
            when ``normalize_pixels`` is set, otherwise the input dtype.
        """
        # Drop a trailing singleton channel axis, if there is one.
        if images.ndim == 4 and images.shape[-1] == 1:
            images = images[..., 0]

        # One row per sample.
        flat = images.reshape(images.shape[0], -1)

        if not self.normalize_pixels:
            return flat
        return flat.astype(np.float32) / 255.0

    def fit(self, images, coordinates, sample_weights=None):
        """
        Fit both ridge estimators on calibration data.

        Args:
            images: calibration images [n_cal_samples, height, width]
            coordinates: gaze coordinates [n_cal_samples, 2]
            sample_weights: optional sample weights [n_cal_samples]
        """
        features = self._preprocess_images(images)

        # Column 0 trains the x estimator, column 1 the y estimator.
        for estimator, axis in ((self.ridge_x, 0), (self.ridge_y, 1)):
            estimator.fit(features, coordinates[:, axis], sample_weight=sample_weights)

        self.is_fitted = True

    def predict(self, images):
        """
        Predict gaze coordinates for test images.

        Args:
            images: test images [n_test_samples, height, width]
        Returns:
            predicted coordinates [n_test_samples, 2]
        Raises:
            ValueError: if ``fit`` has not been called yet.
        """
        if not self.is_fitted:
            raise ValueError("Model must be fitted before making predictions")

        features = self._preprocess_images(images)

        # x predictions form the first column, y predictions the second.
        return np.column_stack(
            [self.ridge_x.predict(features), self.ridge_y.predict(features)]
        )

    def get_params(self):
        """Return model parameters for logging/analysis."""
        return dict(
            lambda_ridge=self.lambda_ridge,
            normalize_pixels=self.normalize_pixels,
            n_features=IMAGE_PIXELS,
            model_type='WebGazer_RidgeRegression',
        )

## Data Processing for Model Comparison

Creating functions to process the dataset for both models using identical calibration/test splits.

In [None]:
def prepare_masked_dataset(dataset, calibration_points=None, cal_phase=None):
    """Prepare masked dataset with phase-aware calibration - matching analysis notebook.

    Samples are grouped by subject into windows of up to MAX_TARGETS. Each
    window is zero-padded to exactly MAX_TARGETS rows and given two int8 masks:

    * calibration mask - 1 where the sample's coordinates exactly equal one of
      ``calibration_points`` (optionally restricted to one phase);
    * target mask - 1 for phase-2 samples that are not on a calibration point.

    Note: With the flexible inference model, tf.ensure_shape() is no longer required
    since the model can handle dynamic shapes. We still pad to MAX_TARGETS for
    consistent batching, but the model will work with any number of points.

    Args:
        dataset: tf.data.Dataset of (image, phase, coords, subject_id) elements.
        calibration_points: tensor of calibration coordinates on the same scale
            as ``coords``. Required.
        cal_phase: 1 or 2 to keep calibration samples from that phase only; any
            other value (including None) keeps matches from every phase.

    Returns:
        tf.data.Dataset yielding
        ``((images, coords, cal_mask, target_mask), coords, subject_id)``,
        all padded to MAX_TARGETS rows.

    Raises:
        ValueError: if ``calibration_points`` is None.
    """
    # Fail fast, before any tf.data pipeline is built or traced.
    if calibration_points is None:
        raise ValueError("Need to specify calibration points in test mode")

    def group_by_subject(subject_id, ds):
        return ds.batch(batch_size=MAX_TARGETS)

    grouped_dataset = dataset.group_by_window(
        key_func=lambda img, phase, coords, subject_id: subject_id,
        reduce_func=group_by_subject,
        window_size=MAX_TARGETS
    )

    def add_masks_to_batch(images, phase, coords, subject_ids):
        actual_batch_size = tf.shape(images)[0]
        pad_rows = MAX_TARGETS - actual_batch_size

        # Create phase masks
        phase1_mask = tf.cast(tf.equal(phase, 1), tf.int8)
        phase2_mask = tf.cast(tf.equal(phase, 2), tf.int8)

        # Broadcast [n, 1, 2] against [1, n_cal, 2]: a sample is a calibration
        # sample when both of its coordinates equal those of any calibration
        # point. This is exact float equality, so both sides must be produced
        # by the same rescaling.
        coords_xpand = tf.expand_dims(coords, axis=1)
        cal_xpand = tf.expand_dims(calibration_points, axis=0)
        equality = tf.equal(coords_xpand, cal_xpand)
        matches = tf.reduce_all(equality, axis=-1)
        point_matches = tf.reduce_any(matches, axis=1)
        cal_mask = tf.cast(point_matches, dtype=tf.int8)

        # Targets are computed from the unrestricted calibration mask, so a
        # phase-2 sample on a calibration point is never a target.
        target_mask = (1 - cal_mask) * phase2_mask

        if cal_phase == 1:
            cal_mask = cal_mask * phase1_mask
        elif cal_phase == 2:
            cal_mask = cal_mask * phase2_mask

        padded_images = tf.pad(
            tf.reshape(images, (-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1)),
            [[0, pad_rows], [0, 0], [0, 0], [0, 0]]
        )
        padded_coords = tf.pad(coords, [[0, pad_rows], [0, 0]])
        padded_cal_mask = tf.pad(cal_mask, [[0, pad_rows]])
        padded_target_mask = tf.pad(target_mask, [[0, pad_rows]])

        # Note: tf.ensure_shape() removed - flexible model handles dynamic shapes
        return (padded_images, padded_coords, padded_cal_mask, padded_target_mask), padded_coords, subject_ids[0]

    masked_dataset = grouped_dataset.map(
        add_masks_to_batch,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return masked_dataset

In [None]:
def prepare_model_inputs(features, labels, subject_ids):
    """
    Repackage one masked-dataset element into the gaze model's named inputs.

    Intended for tf.data.Dataset.map().

    Args:
        features: 4-tuple ``(images, coords, cal_mask, target_mask)``.
        labels: passed through unchanged.
        subject_ids: passed through unchanged.

    Returns:
        ``(inputs, labels, target_mask, subject_ids)`` where ``inputs`` maps the
        four model input names to the corresponding entries of ``features``.
    """
    input_names = (
        "Input_All_Images",
        "Input_All_Coords",
        "Input_Calibration_Mask",
        "Input_Target_Mask",
    )
    images, coords, cal_mask, target_mask = features
    inputs = dict(zip(input_names, (images, coords, cal_mask, target_mask)))

    return inputs, labels, target_mask, subject_ids


def extract_calibration_and_test_data_for_webgazer(masked_dataset, min_cal_points=10, min_test_points=96):
    """
    Extract calibration and test data from masked dataset for WebGazer evaluation.
    WebGazer uses sklearn which requires numpy arrays.

    Args:
        masked_dataset: dataset whose ``as_numpy_iterator()`` yields
            ``((images, coords, cal_mask, target_mask), labels, subject_id)``.
        min_cal_points: minimum number of calibration samples (cal_mask == 1)
            a subject needs to be kept.
        min_test_points: minimum number of test samples (target_mask == 1)
            a subject needs to be kept.

    Returns:
        List of dicts with keys ``subject_id``, ``cal_images``, ``cal_coords``,
        ``test_images`` and ``test_coords``, one per retained subject.
    """
    subject_data = []

    for features, _labels, subject_id in masked_dataset.as_numpy_iterator():
        images, coords, cal_mask, target_mask = features

        is_cal = cal_mask == 1
        is_test = target_mask == 1

        # Check the counts first so rejected subjects never copy image data.
        if np.count_nonzero(is_cal) < min_cal_points or np.count_nonzero(is_test) < min_test_points:
            continue

        subject_data.append({
            'subject_id': subject_id,
            'cal_images': images[is_cal],
            'cal_coords': coords[is_cal],
            'test_images': images[is_test],
            'test_coords': coords[is_test]
        })

    return subject_data

In [None]:
# Build the per-subject masked dataset, calibrating on Phase 1 samples.
print("Preparing masked dataset with Phase 1 calibration...")
masked_dataset_p1 = prepare_masked_dataset(
    test_data_rescaled,
    calibration_points=scaled_cal_points,
    cal_phase=1,
)

# Repackage each element into the model's named inputs. No .cache() is applied,
# so elements are produced as the dataset is iterated.
gaze_model_dataset = masked_dataset_p1.map(prepare_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)

print("Dataset prepared for streaming evaluation")

## Load Our Trained Eye Tracking Model

Loading the pre-trained gaze prediction model for comparison. This model uses a deep learning backbone (DenseNet, ViT, or Hybrid CNN-Transformer) with ridge regression for subject-specific calibration.

In [None]:
# Authenticate with W&B and look up the training run that produced the model.
wandb.login()
api = wandb.Api()
run_path = "/".join((ENTITY_NAME, PROJECT_NAME, EXPERIMENT_ID))
run = api.run(run_path)
config = run.config

# Fetch the full serialized model (new format) into the working directory.
run.file("full_model.keras").download(exist_ok=True)

# The backbone name is only used for display, so fall back to 'unknown'.
backbone_type = config.get('backbone', 'unknown')

print("Model configuration:")
print(f"Backbone: {backbone_type}")
print(f"Embedding dim: {config['embedding_dim']}")
print(f"Ridge regularization: {config['ridge_regularization']}")

In [None]:
# Load the model using flexible inference utilities
# This rewires the trained model with flexible input shapes (None instead of 288)
# All trained layers and weights are preserved
# NOTE(review): the two statements above describe et_util's helper, whose
# implementation is not visible in this notebook - confirm against et_util.
# 'full_model.keras' is the file downloaded from the W&B run in the previous cell.
gaze_model = create_flexible_inference_model('full_model.keras')

print(f"{OUR_MODEL_NAME} loaded successfully with flexible inference (backbone: {backbone_type})")
gaze_model.summary()

## Model Evaluation and Comparison

Evaluating both models on identical data splits and comparing performance.

In [None]:
def evaluate_webgazer_model(subject_data, lambda_ridge=0.001):
    """
    Evaluate WebGazer model on all subjects.

    A fresh WebGazerModel is fitted on each subject's calibration samples and
    scored on that subject's test samples with normalized_weighted_euc_dist,
    using the same single vectorized call as evaluate_gaze_model.

    Args:
        subject_data: list of dicts with keys ``subject_id``, ``cal_images``,
            ``cal_coords``, ``test_images`` and ``test_coords``.
        lambda_ridge: ridge penalty passed to WebGazerModel.

    Returns:
        DataFrame with one row per successfully evaluated subject and columns
        ``subject_id``, ``model``, ``mean_error``, ``n_cal_points``,
        ``n_test_points``.
    """
    results = []
    n_subjects = len(subject_data)

    print(f"Evaluating WebGazer model on {n_subjects} subjects...")

    for i, subject in enumerate(subject_data):
        if i % 50 == 0:
            print(f"Processing subject {i+1}/{n_subjects}")

        subject_id = subject['subject_id']

        # Initialize and train WebGazer model
        webgazer = WebGazerModel(lambda_ridge=lambda_ridge)

        try:
            # Train on calibration data
            webgazer.fit(subject['cal_images'], subject['cal_coords'])

            # Predict on test data
            predictions = webgazer.predict(subject['test_images'])

            # Score all test points in one call instead of one call per point.
            actual_array = np.asarray(subject['test_coords'], dtype=np.float32)
            pred_array = np.asarray(predictions, dtype=np.float32)
            errors = normalized_weighted_euc_dist(actual_array, pred_array)
            mean_error = float(np.mean(errors))
        except Exception as e:
            # Deliberate best-effort: report the failing subject and carry on.
            print(f"Error processing subject {subject_id}: {e}")
            continue

        results.append({
            'subject_id': subject_id,
            'model': 'WebGazer',
            'mean_error': mean_error,
            'n_cal_points': len(subject['cal_coords']),
            'n_test_points': len(subject['test_coords'])
        })

    return pd.DataFrame(results)

In [None]:
def evaluate_gaze_model(gaze_model_dataset, gaze_model, model_name, batch_size=1):
    """
    Evaluate the gaze model by iterating the tf.data pipeline one batch at a time.

    Args:
        gaze_model_dataset: dataset yielding
            ``(inputs, labels, target_mask, subject_id)`` per subject, as
            produced by prepare_model_inputs.
        gaze_model: callable Keras model taking the ``inputs`` dict.
        model_name: label written to the ``model`` column of the result.
        batch_size: number of subjects per model call.

    Returns:
        DataFrame with one row per subject that has at least 96 test points and
        columns ``subject_id``, ``model``, ``mean_error``, ``n_cal_points``,
        ``n_test_points``.
    """
    print(f"Evaluating {model_name} with batch_size={batch_size}...")
    results = []

    # Stream through dataset in batches
    batched_dataset = gaze_model_dataset.batch(batch_size)

    for batch_idx, batch_data in enumerate(batched_dataset):
        if batch_idx % 20 == 0:
            print(f"Processing batch {batch_idx}...")

        # Unpack - subject_ids now included from prepare_model_inputs
        batch_inputs, batch_labels, batch_target_masks, batch_subject_ids = batch_data

        # Run model prediction
        batch_predictions = gaze_model(batch_inputs, training=False)

        # np.asarray works for both TF eager tensors and JAX arrays. `.numpy()`
        # exists only on TF tensors, and the model output is a JAX array when
        # Keras runs on the JAX backend.
        predictions_np = np.asarray(batch_predictions)
        labels_np = np.asarray(batch_labels)
        target_masks_np = np.asarray(batch_target_masks)
        cal_masks_np = np.asarray(batch_inputs['Input_Calibration_Mask'])
        subject_ids_np = np.asarray(batch_subject_ids)

        for i in range(predictions_np.shape[0]):
            subject_id = int(subject_ids_np[i])
            is_test = target_masks_np[i] == 1

            # Extract test predictions using target mask
            test_predictions = predictions_np[i][is_test]
            test_actual = labels_np[i][is_test]

            # Skip subjects with insufficient test data
            if len(test_predictions) < 96:
                continue

            pred_array = np.array(test_predictions, dtype=np.float32)
            actual_array = np.array(test_actual, dtype=np.float32)

            # Vectorized error calculation
            errors = normalized_weighted_euc_dist(actual_array, pred_array)
            mean_error = float(np.mean(errors))
            n_cal = int(np.sum(cal_masks_np[i]))

            results.append({
                'subject_id': subject_id,
                'model': model_name,
                'mean_error': mean_error,
                'n_cal_points': n_cal,
                'n_test_points': len(test_predictions)
            })

        # Explicit cleanup after each batch
        del batch_predictions, batch_inputs, batch_labels, batch_target_masks
        del predictions_np, labels_np, target_masks_np, cal_masks_np, subject_ids_np
        if batch_idx % 10 == 0:
            gc.collect()

    print(f"Completed evaluation for {len(results)} subjects")
    return pd.DataFrame(results)

In [None]:
# Evaluate gaze model FIRST using streaming (no data loaded to RAM yet)
print("=== Evaluating Our Model (streaming from disk) ===")
our_model_results = evaluate_gaze_model(gaze_model_dataset, gaze_model, OUR_MODEL_NAME, batch_size=1)

# Clear gaze model from memory before loading WebGazer data
# NOTE: after this `del`, re-running this cell raises NameError on `gaze_model`
# until the model-loading cell has been run again.
print("\nClearing model from memory...")
del gaze_model
gc.collect()

print("\n=== Evaluating WebGazer ===")
# Now extract data for WebGazer (loads to numpy - but gaze model is cleared)
# The pipeline is rebuilt with the same arguments that produced masked_dataset_p1.
print("Extracting data for WebGazer (requires numpy arrays)...")
masked_dataset_for_webgazer = prepare_masked_dataset(test_data_rescaled, scaled_cal_points, 1)
webgazer_subject_data = extract_calibration_and_test_data_for_webgazer(masked_dataset_for_webgazer)
print(f"Prepared data for {len(webgazer_subject_data)} subjects")

# Evaluate WebGazer
webgazer_results = evaluate_webgazer_model(webgazer_subject_data, lambda_ridge=RIDGE_REGULARIZATION)

# Clear WebGazer data
del webgazer_subject_data
gc.collect()

# Combine results (long format: one row per subject per model)
all_results = pd.concat([webgazer_results, our_model_results], ignore_index=True)

print(f"\n=== Evaluation Complete ===")
print(f"WebGazer results: {len(webgazer_results)} subjects")
print(f"{OUR_MODEL_NAME} results: {len(our_model_results)} subjects")

## Results Analysis and Visualization

Analyzing the performance differences between WebGazer and our model.

In [None]:
# Summary statistics
summary_stats = all_results.groupby('model')['mean_error'].agg(['mean', 'std', 'median', 'count']).round(4)
print("Model Performance Summary:")
print(summary_stats)
print()

# Statistical significance test
webgazer_errors = webgazer_results['mean_error'].values
our_model_errors = our_model_results['mean_error'].values

# Index each model's error by subject, averaging if a subject appears in more
# than one row, so the paired arrays below are aligned by subject id and not
# by the row order of two independently sorted frames.
webgazer_by_subject = webgazer_results.groupby('subject_id')['mean_error'].mean()
our_model_by_subject = our_model_results.groupby('subject_id')['mean_error'].mean()

# Ensure we have matching subjects for paired comparison
webgazer_subjects = set(webgazer_by_subject.index)
our_model_subjects = set(our_model_by_subject.index)
common_subjects = webgazer_subjects.intersection(our_model_subjects)

print(f"Common subjects between models: {len(common_subjects)}")

if len(common_subjects) > 10:  # Need sufficient data for statistical test
    # Create paired data: same subject order on both sides
    paired_ids = sorted(common_subjects)
    webgazer_paired = webgazer_by_subject.loc[paired_ids].values
    our_model_paired = our_model_by_subject.loc[paired_ids].values
    diff = webgazer_paired - our_model_paired

    # Paired t-test
    t_stat, p_value = stats.ttest_rel(webgazer_paired, our_model_paired)

    print(f"\nPaired t-test results:")
    print(f"t-statistic: {t_stat:.4f}")
    print(f"p-value: {p_value:.6f}")
    print(f"Mean difference (WebGazer - {OUR_MODEL_NAME}): {np.mean(diff):.4f}")

    # Effect size (Cohen's d for paired samples): mean difference over the
    # sample standard deviation of the differences. ddof=1 gives the sample
    # estimate; NumPy's default (ddof=0) is the population value.
    cohens_d = np.mean(diff) / np.std(diff, ddof=1)
    print(f"Effect size (Cohen's d): {cohens_d:.4f}")

In [None]:
# Create comparison visualizations

# 1. Error distribution comparison
# The axis is labelled in percent, so convert the stored error to percent
# first, as the scatter plot does (x100). The conversion is done on a copy so
# all_results itself is left untouched.
distribution_data = all_results.assign(mean_error_pct=all_results['mean_error'] * 100)

distribution_plot = (
    ggplot(distribution_data, aes(x='mean_error_pct', fill='model'))
    + geom_histogram(alpha=0.7, position='identity', bins=30)
    + labs(
        title=f'Prediction Error Distribution: WebGazer vs {OUR_MODEL_NAME}',
        x='Mean Prediction Error (% Screen Diagonal)',
        y='Number of Subjects'
    )
    + scale_fill_manual(values=['#A23B72', '#2E86AB'], name='Model')
    + theme_classic()
    + theme(plot_title=element_text(hjust=0.5))
)

print("Distribution plot:")
distribution_plot

In [None]:
from plotnine import coord_equal
# 2. Direct comparison scatter plot (for common subjects)

# One row per subject and model (first row wins if a subject is repeated),
# then an inner join so only subjects evaluated by both models remain.
webgazer_per_subject = (
    webgazer_results.drop_duplicates('subject_id')[['subject_id', 'mean_error']]
    .rename(columns={'mean_error': 'webgazer_error'})
)
our_model_per_subject = (
    our_model_results.drop_duplicates('subject_id')[['subject_id', 'mean_error']]
    .rename(columns={'mean_error': 'our_model_error'})
)

# All three error columns are in percent (x100), so the saved CSV has a single
# unit. `difference` is WebGazer minus our model: positive means ours is better.
comparison_df = (
    webgazer_per_subject.merge(our_model_per_subject, on='subject_id', how='inner')
    .assign(
        webgazer_error=lambda d: d['webgazer_error'] * 100,
        our_model_error=lambda d: d['our_model_error'] * 100,
    )
    .assign(difference=lambda d: d['webgazer_error'] - d['our_model_error'])
)

# Scatter plot
scatter_plot = (
    ggplot(comparison_df, aes(x='our_model_error', y='webgazer_error'))
    + geom_point(alpha=0.6, size=2)
    + geom_abline(intercept=0, slope=1, linetype='dashed', color='red')
    + labs(
        title='Subject-Level Performance Comparison',
        x=f'{OUR_MODEL_NAME} Error (% Screen Diagonal)',
        y='WebGazer Error (% Screen Diagonal)',
        subtitle='Points above red line indicate our model performed better'
    )
    + coord_equal(xlim=(3, 45), ylim=(3, 45))
    + theme_classic()
    + theme(plot_title=element_text(hjust=0.5), plot_subtitle=element_text(hjust=0.5))
)

print("Scatter plot comparison:")

n_compared = len(comparison_df)
our_model_better = int((comparison_df['difference'] > 0).sum())
# Guard against an empty comparison (no subject evaluated by both models).
percent_better = (our_model_better / n_compared) * 100 if n_compared else float('nan')

print(f"{OUR_MODEL_NAME} performed better than WebGazer in {our_model_better} subjects ({percent_better:.1f}%)")

# `difference` is already in percent, so no further scaling is needed.
average_improvement = comparison_df['difference'].mean()
print(f"Average improvement over WebGazer: {average_improvement:.2f}%")

scatter_plot

## Save Results and Create Summary

In [None]:
# Paired outputs are written only when there were enough common subjects
# (more than 10).
has_enough_pairs = len(common_subjects) > 10

# Plots
distribution_plot.save('webgazer_comparison_distribution.png', width=10, height=6, dpi=150)
if has_enough_pairs:
    scatter_plot.save('webgazer_comparison_scatter.png', width=8, height=6, dpi=150)

# Result tables
all_results.to_csv('webgazer_comparison_results.csv', index=False)
if has_enough_pairs:
    comparison_df.to_csv('webgazer_paired_comparison.csv', index=False)

print("Results and plots saved successfully")