# Setup

In [None]:
!pip install osfclient --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet
!pip install wandb --quiet
!pip install --upgrade keras --quiet
!pip install keras-hub --quiet

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import numpy as np
import math
import keras
import keras_hub
from keras import ops

import wandb
from wandb.integration.keras import WandbMetricsLogger

import matplotlib.pyplot as plt
from google.colab import userdata

import et_util.dataset_utils as dataset_utils
import et_util.embedding_preprocessing as embed_pre
import et_util.model_layers as model_layers
from et_util import experiment_utils
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.model_analysis import plot_model_performance

In [None]:
# Export Colab secrets as environment variables for the W&B and OSF clients.
# Mapping: environment-variable name -> Colab secret name.
for env_var_name, colab_secret_name in {
    'WANDB_API_KEY': 'WANDB_API_KEY',
    'OSF_TOKEN': 'osftoken',
    'OSF_USERNAME': 'osfusername',
}.items():
    os.environ[env_var_name] = userdata.get(colab_secret_name)

In [None]:
# Show the installed Keras version (the install cell upgrades Keras).
keras.version()

In [None]:
# Run every layer in full float32 precision (no mixed-precision compute).
keras.mixed_precision.set_global_policy('float32')

# Configure W&B experiment

In [None]:
# Authenticate with Weights & Biases; uses WANDB_API_KEY from the environment
# when it is set (exported from Colab secrets above).
wandb.login()

In [None]:
# Fixed constants
# Images per subject episode: the group_by_window / batch size in
# prepare_masked_dataset and the time dimension of the model inputs.
MAX_TARGETS = 144

# Config constants
EMBEDDING_DIM = 200            # units of the tanh Dense layer producing each image embedding
RIDGE_REGULARIZATION = 0.1     # lambda added to the diagonal in the ridge-regression layer
MIN_CAL_POINTS = 8             # bounds for the random number of calibration images per
MAX_CAL_POINTS = 40            #   episode (see prepare_masked_dataset)

BACKBONE = "densenet"          # backbone name selected in create_embedding_model

AUGMENTATION = True            # presumably toggles the augmentation section below — confirm

# Optimisation settings. Presumably consumed by a warmup + decay learning-rate
# schedule defined in the training section — TODO confirm against that code.
INITIAL_LEARNING_RATE = 0.00001
LEARNING_RATE = 0.01
BATCH_SIZE = 5                 # presumably subject episodes per training batch — confirm
TRAIN_EPOCHS = 50
WARMUP_EPOCHS = 2
DECAY_EPOCHS = TRAIN_EPOCHS - WARMUP_EPOCHS
DECAY_ALPHA = 0.01


In [None]:
# Hyperparameters recorded in the W&B run config, built from the constants above.
config = dict(
    embedding_dim=EMBEDDING_DIM,
    ridge_regularization=RIDGE_REGULARIZATION,
    train_epochs=TRAIN_EPOCHS,
    min_cal_points=MIN_CAL_POINTS,
    max_cal_points=MAX_CAL_POINTS,
    backbone=BACKBONE,
    batch_size=BATCH_SIZE,
    initial_learning_rate=INITIAL_LEARNING_RATE,
    learning_rate=LEARNING_RATE,
    augmentation=AUGMENTATION,
    warmup_epochs=WARMUP_EPOCHS,
    decay_epochs=DECAY_EPOCHS,
    decay_alpha=DECAY_ALPHA,
)

In [None]:
# Start a W&B run in this project; `config` is logged as the run's hyperparameters.
run = wandb.init(
    project='eye-tracking-dense-full-data-set-single-eye',
    config=config
)

# Dataset preparation

## Download dataset from OSF

In [None]:
# Download the single-eye TFRecord archive from OSF project 6b5cd.
# Presumably authenticates with OSF_USERNAME / OSF_TOKEN set from Colab secrets above.
!osf -p 6b5cd fetch single_eye_tfrecords.tar.gz

## Process raw data records into TF Dataset

In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:
def parse(element):
    """Parse one serialized single-eye TFRecord example.

    Used as the parse function for ``dataset_utils.process_tfr_to_tfds`` on the
    records extracted from ``single_eye_tfrecords.tar.gz``. Each record holds
    MediaPipe landmarks, a serialized eye-image tensor, the image width/height,
    the gaze target (x, y) and a subject id.

    :param element: serialized ``tf.train.Example`` (scalar string tensor)
    :return: tuple ``(image, landmarks, label, subject_id)`` where
        ``image`` is the decoded ``uint8`` eye-image tensor,
        ``landmarks`` is a ``(478, 3)`` float32 tensor,
        ``label`` is a ``(2,)`` float32 tensor ``[x, y]`` and
        ``subject_id`` is a scalar int64 tensor.
    """
    # img_width / img_height are not returned, but stay in the schema so every
    # record must still contain them (FixedLenFeature without a default).
    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    landmarks = tf.io.parse_tensor(content['landmarks'], out_type=tf.float32)
    landmarks = tf.reshape(landmarks, shape=(478, 3))

    image = tf.io.parse_tensor(content['eye_img'], out_type=tf.uint8)

    # Explicit (2,) tensor; tf.data would otherwise auto-pack the Python list.
    label = tf.stack([content['x'], content['y']])
    subject_id = content['subject_id']

    return image, landmarks, label, subject_id

In [None]:
# Build the tf.data pipeline from the extracted TFRecords.
# All records go to the training split (train_split=1.0), so validation_data and
# test_data are presumably empty here — TODO confirm against dataset_utils.
# group_function keys every element by its 4th field (subject_id from `parse`),
# presumably so one subject's records never straddle splits.
train_data, validation_data, test_data = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, landmarks, coords, z: z
)

In [None]:
def rescale_coords_map(eyes, mesh, coords, id):
  """Divide the gaze coordinates by 100 (0-100 screen units -> 0-1).

  The eye image, landmark mesh and subject id pass through unchanged.
  """
  percent_scale = tf.constant([100.])
  scaled_coords = tf.divide(coords, percent_scale)
  return eyes, mesh, scaled_coords, id

In [None]:
# Rescale gaze labels to 0-1 for both splits. validation_data is presumably
# empty because val_split=0.0 in the split call above.
train_data_rescaled = train_data.map(rescale_coords_map)
validation_data_rescaled = validation_data.map(rescale_coords_map)

In [None]:
def prepare_masked_dataset(dataset, calibration_points=None, min_images=None):
    """Group a per-image dataset into fixed-size per-subject episodes with masks.

    Each output element holds up to ``MAX_TARGETS`` images of a single subject,
    padded to exactly ``MAX_TARGETS``, together with a calibration mask (images
    whose labels the model may use to fit its ridge regression) and a target
    mask (images that should count toward the loss).

    Args:
        dataset: ``tf.data.Dataset`` of ``(image, mesh, coords, subject_id)``
            elements, e.g. the output of ``rescale_coords_map``.
        calibration_points: Optional ``(n_cal, 2)`` float32 tensor. If given,
            images whose ``coords`` match one of these points (within a small
            tolerance) become calibration images. If ``None``, a random subset
            of ``MIN_CAL_POINTS``..``MAX_CAL_POINTS`` (inclusive) images is
            drawn, anew each time the dataset is iterated.
        min_images: Minimum number of images a subject window must contain to
            be kept. Defaults to ``MAX_TARGETS`` (only complete windows).

    Returns:
        ``tf.data.Dataset`` of
        ``((images, coords, cal_mask, target_mask), coords, subject_ids)`` with
        shapes ``(MAX_TARGETS, 36, 144, 1)``, ``(MAX_TARGETS, 2)``,
        ``(MAX_TARGETS,)`` int8, ``(MAX_TARGETS,)`` int8, ``(MAX_TARGETS, 2)``
        and ``(n,)`` (unpadded subject ids).
    """
    if min_images is None:
        min_images = MAX_TARGETS

    # Coordinates are 0-1 floats; exact float equality is brittle, so fixed
    # calibration points are matched within this absolute tolerance.
    coord_match_tolerance = 1e-5

    # Step 1: Group dataset by subject_id, at most MAX_TARGETS images per window.
    def group_by_subject(subject_id, ds):
        return ds.batch(batch_size=MAX_TARGETS)

    grouped_dataset = dataset.group_by_window(
        key_func=lambda img, mesh, coords, subject_id: subject_id,
        reduce_func=group_by_subject,
        window_size=MAX_TARGETS
    )

    # Step 2: Drop windows with fewer than `min_images` images.
    def filter_by_image_count(images, meshes, coords, subject_ids):
        return tf.shape(images)[0] >= min_images

    grouped_dataset = grouped_dataset.filter(filter_by_image_count)

    # Step 3: Pad each window to MAX_TARGETS and attach the masks.
    def add_masks_to_batch(images, meshes, coords, subject_ids):
        actual_batch_size = tf.shape(images)[0]
        n_padding = MAX_TARGETS - actual_batch_size

        if calibration_points is None:
            # Integer tf.random.uniform excludes maxval, hence the +1 so that
            # MAX_CAL_POINTS itself can be drawn.
            n_cal_images = tf.random.uniform(
                shape=[],
                minval=MIN_CAL_POINTS,
                maxval=MAX_CAL_POINTS + 1,
                dtype=tf.int32
            )
            # NOTE: a new random subset is drawn each time through the dataset.
            random_indices = tf.random.shuffle(tf.range(actual_batch_size))
            cal_indices = random_indices[:n_cal_images]

            # ones_like keeps the updates as long as the indices, even when the
            # window holds fewer than n_cal_images images.
            cal_mask = tf.scatter_nd(
                tf.expand_dims(cal_indices, 1),
                tf.ones_like(cal_indices, dtype=tf.int8),
                [MAX_TARGETS]
            )
        else:
            # (n, 1, 2) vs (1, n_cal, 2) -> (n, n_cal, 2) absolute differences
            coord_diff = tf.abs(
                tf.expand_dims(coords, axis=1)
                - tf.expand_dims(calibration_points, axis=0)
            )
            matches = tf.reduce_all(coord_diff <= coord_match_tolerance, axis=-1)
            point_matches = tf.reduce_any(matches, axis=1)
            cal_mask = tf.pad(tf.cast(point_matches, dtype=tf.int8), [[0, n_padding]])

        # Padded rows are neither calibration nor target points (1 = included).
        valid_mask = tf.pad(tf.ones([actual_batch_size], dtype=tf.int8), [[0, n_padding]])
        target_mask = (1 - cal_mask) * valid_mask

        padded_images = tf.pad(
            tf.reshape(images, (-1, 36, 144, 1)),
            [[0, n_padding], [0, 0], [0, 0], [0, 0]]
        )
        padded_coords = tf.pad(
            coords,
            [[0, n_padding], [0, 0]]
        )

        # Ensure all shapes are fixed
        padded_images = tf.ensure_shape(padded_images, [MAX_TARGETS, 36, 144, 1])
        padded_coords = tf.ensure_shape(padded_coords, [MAX_TARGETS, 2])
        padded_cal_mask = tf.ensure_shape(cal_mask, [MAX_TARGETS])
        padded_target_mask = tf.ensure_shape(target_mask, [MAX_TARGETS])
        return (padded_images, padded_coords, padded_cal_mask, padded_target_mask), padded_coords, subject_ids

    masked_dataset = grouped_dataset.map(
        add_masks_to_batch,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return masked_dataset

# Convert masked episodes into the (inputs, targets, sample_weight) form Keras expects.
def prepare_model_inputs(features, labels, subject_ids):
    """Restructure one masked episode for ``model.fit``.

    Args:
        features: Tuple ``(images, coords, cal_mask, target_mask)``.
        labels: Target coordinates for every image.
        subject_ids: Subject ids of the episode; not used and dropped here.

    Returns:
        Tuple ``(inputs, labels, target_mask)``: ``inputs`` is a dict keyed by
        the model's input-layer names, and ``target_mask`` serves as the
        per-point sample weights for the loss.
    """
    all_images, all_coords, calibration_mask, sample_weights = features

    model_inputs = {
        "Input_All_Images": all_images,
        "Input_All_Coords": all_coords,
        "Input_Calibration_Mask": calibration_mask,
    }
    return model_inputs, labels, sample_weights

In [None]:
# 4 x 5 grid of fixed calibration points in 0-100 screen units, listed x-major:
# x in {5, 35, 65, 95}, and for each x every y in {5, 27.5, 50, 72.5, 95}.
cal_points = tf.constant(
    [[cal_x, cal_y]
     for cal_x in (5, 35, 65, 95)
     for cal_y in (5, 27.5, 50, 72.5, 95)],
    dtype=tf.float32,
)

# The same points divided by 100, matching the output of rescale_coords_map.
scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))

In [None]:
# Per-subject masked episodes with a random calibration subset
# (no fixed calibration_points are passed).
masked_dataset = prepare_masked_dataset(train_data_rescaled)

In [None]:
# Reshape into Keras (inputs, labels, sample_weight) triples, shuffle episodes
# with a 200-element buffer and prefetch. No .batch() here: each element is a
# single subject episode (presumably batched later with BATCH_SIZE — confirm).
train_ds_for_model = masked_dataset.map(
    prepare_model_inputs,
    num_parallel_calls=tf.data.AUTOTUNE
).shuffle(200).prefetch(tf.data.AUTOTUNE)

In [None]:
# Number of subject episodes in the training pipeline. group_by_window/filter
# leave the cardinality unknown, so it must be counted by iterating once;
# Dataset.reduce counts in-graph instead of copying every episode to NumPy.
INDIVIDUALS = int(train_ds_for_model.reduce(0, lambda count, _: count + 1))

## Visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import tensorflow as tf
from IPython.display import HTML
import base64

def visualize_eye_tracking_data(dataset, interval=100, figsize=(10, 6)):
    """
    Animate one subject's eye images with a red dot at each gaze target.

    Frames are ordered row by row (by y, then by x within a row). Returns an
    HTML animation for direct display in a Colab notebook.

    Args:
        dataset: A ``tf.data.Dataset`` whose elements are
                 ``(inputs, labels, target_mask)`` as produced by
                 ``prepare_model_inputs``. ``inputs`` is a dict containing
                 ``"Input_All_Images"`` with shape ``(n, 36, 144, 1)`` and
                 ``"Input_All_Coords"`` with shape ``(n, 2)`` in 0-1 units.
                 Only the first element is used.
        interval: Time between frames in milliseconds
        figsize: Size of the figure (width, height)

    Returns:
        IPython.display.HTML object with the animation

    Raises:
        ValueError: If ``dataset`` yields no elements.

    Example usage in a Colab notebook:
    ```python
    visualize_eye_tracking_data(train_ds_for_model.take(1))
    ```
    """
    first_element = next(iter(dataset.take(1)), None)
    if first_element is None:
        raise ValueError("dataset is empty; nothing to visualize")
    inputs, _, _ = first_element

    # Convert to numpy arrays
    valid_images = inputs["Input_All_Images"].numpy()
    valid_coords = inputs["Input_All_Coords"].numpy()

    # np.lexsort uses the LAST key as the primary key: sort by y, then by x.
    sorted_indices = np.lexsort((valid_coords[:, 0], valid_coords[:, 1]))
    valid_images = valid_images[sorted_indices]
    valid_coords = valid_coords[sorted_indices]

    num_images = valid_images.shape[0]

    # Interactive mode is off only while building the animation, so the figure
    # is not auto-displayed; the context manager restores the previous mode
    # (a bare plt.ioff() would leave it off for every later cell).
    with plt.ioff():
        fig, ax = plt.subplots(figsize=figsize)
        try:
            def draw_frame(frame_num):
                ax.clear()

                # 0-100 screen space, y inverted so (0, 0) is the top-left corner
                ax.set_xlim(0, 100)
                ax.set_ylim(100, 0)

                # Get the current image and coordinates
                img = valid_images[frame_num].squeeze()
                x, y = valid_coords[frame_num]

                # Draw rectangle for the screen
                rect = plt.Rectangle((0, 0), 100, 100, fill=False, color='black', linewidth=2)
                ax.add_patch(rect)

                # Eye image in an inset at the centre of the plot, mirrored left-right
                img_display = ax.inset_axes([0.35, 0.35, 0.3, 0.3], transform=ax.transAxes)
                img_display.imshow(np.fliplr(img), cmap='gray')
                img_display.axis('off')

                # Coordinates are in 0-1 units; scale them to the 0-100 axes
                ax.scatter(x * 100, y * 100, color='red', s=150, zorder=5,
                           edgecolor='white', linewidth=1.5)

                # Crosshair through the target
                ax.axhline(y * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)
                ax.axvline(x * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)

                ax.set_title('Eye Tracking Visualization')
                ax.set_xlabel('X Coordinate')
                ax.set_ylabel('Y Coordinate')

            anim = animation.FuncAnimation(fig, draw_frame, frames=num_images, interval=interval)

            # to_jshtml renders every frame into a self-contained HTML/JS player
            html_animation = anim.to_jshtml()
        finally:
            # Always release the figure, even if rendering fails
            plt.close(fig)

    return HTML(html_animation)


In [None]:
# Preview the first episode of the (shuffled) training pipeline.
visualize_eye_tracking_data(train_ds_for_model.take(1))

# Model building

## Custom Layers


In [None]:
class SimpleTimeDistributed(keras.layers.Wrapper):
    """Apply a layer to every temporal slice of an input without a Python loop.

    The batch and time axes are merged, the wrapped layer runs once on the
    resulting ``(batch * time, ...)`` tensor, and the result is reshaped back to
    ``(batch, time, ...)``.

    Args:
        layer: a `keras.layers.Layer` instance.
    """

    def __init__(self, layer, **kwargs):
        super().__init__(layer, **kwargs)
        self.supports_masking = getattr(layer, 'supports_masking', False)

    def build(self, input_shape):
        # Validate input shape has at least 3 dimensions (batch, time, ...)
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
            raise ValueError(
                "`SimpleTimeDistributed` requires input with at least 3 dimensions"
            )

        # Build the wrapped layer with shape excluding the time dimension
        super().build((input_shape[0], *input_shape[2:]))
        self.built = True

    def compute_output_shape(self, input_shape):
        # Output shape of the wrapped layer on a single time slice...
        child_output_shape = self.layer.compute_output_shape((input_shape[0], *input_shape[2:]))
        # ...with the time dimension re-inserted after the batch dimension.
        return (child_output_shape[0], input_shape[1], *child_output_shape[1:])

    def call(self, inputs, training=None):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        time_steps = input_shape[1]

        # (batch, time, ...) -> (batch * time, ...)
        reshaped_inputs = ops.reshape(inputs, (-1, *input_shape[2:]))

        # Forward `training` only to layers whose call() accepts it (as Keras'
        # own TimeDistributed does); e.g. Dense.call has no such argument.
        call_kwargs = {}
        if getattr(self.layer, "_call_has_training_arg", True):
            call_kwargs["training"] = training
        outputs = self.layer.call(reshaped_inputs, **call_kwargs)

        # (batch * time, ...) -> (batch, time, ...)
        output_shape = ops.shape(outputs)
        return ops.reshape(outputs, (batch_size, time_steps, *output_shape[1:]))

In [None]:
# cal_points = tf.constant([
#     [5, 5],
#     [5, 27.5],
#     [5, 50],
#     [5, 72.5],
#     [5, 95],
#     [35, 5],
#     [35, 27.5],
#     [35, 50],
#     [35, 72.5],
#     [35, 95],
#     [65, 5],
#     [65, 27.5],
#     [65, 50],
#     [65, 72.5],
#     [65, 95],
#     [95, 5],
#     [95, 27.5],
#     [95, 50],
#     [95, 72.5],
#     [95, 95],
# ], dtype=tf.float32)

# scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))


In [None]:
class MaskedWeightedRidgeRegressionLayer(keras.layers.Layer):
    """
    Weighted ridge regression fitted per batch element on the calibration points.

    For each batch element the layer solves

        (X^T W X + lambda_ridge * I) kernel = X^T W Y

    where X are the point embeddings, Y their coordinates and
    W = diag(cal_mask * (w + epsilon)) with w the learned importance weights.
    Non-calibration rows are zeroed, so only calibration points enter the fit.
    Coordinates are then predicted for ALL points as ``embeddings @ kernel``.
    Which points are used for fitting is controlled only by the explicit
    ``cal_mask`` input; any Keras mask passed to ``call`` is ignored.

    Args:
        lambda_ridge (float): Regularization strength added to the diagonal.
        epsilon (float): Small constant added inside the square root of the
            weights for numerical stability.
    """
    def __init__(self, lambda_ridge, epsilon=1e-7, **kwargs):
        super(MaskedWeightedRidgeRegressionLayer, self).__init__(**kwargs)
        self.lambda_ridge = lambda_ridge
        self.epsilon = epsilon

    def call(self, inputs, mask=None):
        """
        The forward pass of the layer.

        Args:
            inputs: A list containing:
                - embeddings: (batch_size, n_points, embedding_dim)
                - coords: (batch_size, n_points, coord_dim)
                - calibration_weights: (batch_size, n_points, 1)
                - cal_mask: (batch_size, n_points), 1 for calibration points
            mask: Ignored.

        Returns:
            Predicted coordinates for every point (batch_size, n_points, coord_dim)
        """
        embeddings, coords, calibration_weights, cal_mask = inputs

        # Run the linear algebra in a single dtype; the mask may arrive as an
        # integer tensor.
        embeddings = ops.cast(embeddings, "float32")
        coords = ops.cast(coords, "float32")
        calibration_weights = ops.cast(calibration_weights, "float32")
        cal_mask = ops.cast(cal_mask, "float32")

        # (batch, n_points, 1) -> (batch, n_points)
        w = ops.squeeze(calibration_weights, axis=-1)

        # Zero the weights of non-calibration points; epsilon keeps sqrt's
        # gradient finite where a weight is exactly zero.
        w_sqrt = ops.expand_dims(ops.sqrt(w * cal_mask + self.epsilon), -1)

        # Zero non-calibration rows of X and Y, then apply sqrt(W) to both.
        cal_mask_expand = ops.expand_dims(cal_mask, -1)
        X_weighted = embeddings * cal_mask_expand * w_sqrt
        y_weighted = coords * w_sqrt * cal_mask_expand

        # Normal equations of weighted ridge regression.
        X_t = ops.transpose(X_weighted, axes=[0, 2, 1])
        X_t_X = ops.matmul(X_t, X_weighted)
        identity_matrix = ops.cast(ops.eye(ops.shape(embeddings)[-1]), "float32")
        lhs = X_t_X + self.lambda_ridge * identity_matrix
        rhs = ops.matmul(X_t, y_weighted)

        # (batch, embedding_dim, coord_dim)
        kernel = ops.linalg.solve(lhs, rhs)

        # Predict with the *unmasked* embeddings so target points get outputs too.
        return ops.matmul(embeddings, kernel)

    def compute_output_shape(self, input_shapes):
        """
        Computes the output shape of the layer.

        Args:
            input_shapes: List of input shapes (embeddings, coords, weights, cal_mask)

        Returns:
            (batch_size, n_points, coord_dim), where coord_dim is the last
            dimension of the coords input.
        """
        return (input_shapes[0][0], input_shapes[0][1], input_shapes[1][-1])

    def get_config(self):
        """
        Returns the configuration of the layer for serialization.

        Returns:
            Dictionary containing the layer configuration
        """
        config = super(MaskedWeightedRidgeRegressionLayer, self).get_config()
        config.update({
            "lambda_ridge": self.lambda_ridge,
            "epsilon": self.epsilon,
        })
        return config

In [None]:
class MaskInspectorLayer(keras.layers.Layer):
    """Debugging layer: prints the Keras mask it receives and returns inputs unchanged.

    Subclasses Keras 3 ``keras.layers.Layer`` like the other custom layers here,
    so it composes with the rest of the model. ``supports_masking = True`` lets
    an incoming mask reach ``call`` and pass through to downstream layers.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # tf.print also works inside traced graphs, unlike Python print.
        tf.print("Layer mask:", mask)
        return inputs

## Custom loss

In [None]:
def normalized_weighted_euc_dist(y_true, y_pred):
    """Aspect-ratio-weighted Euclidean distance, normalized to the screen diagonal.

    NOTE: this shadows the ``normalized_weighted_euc_dist`` imported from
    ``et_util.custom_loss`` at the top of the notebook.

    The x-coordinate is multiplied by 1.778 (16:9 aspect ratio of most laptops)
    so both axes share the y-axis scale. For coordinates in 0-1 units, distances
    are normalized so that the corner-to-corner distance is 100.
    No masking happens here: padded or calibration points must be excluded by
    the caller (presumably via the target_mask sample weights).

    :param y_true (tensor): shape (..., 2), ground-truth x and y coordinates
    :param y_pred (tensor): shape (..., 2), predicted x and y coordinates

    :returns: tensor of shape (...) with one normalized distance per point.
    """
    # Weighting treats y-axis as unit-scale and creates a 1.778 x 1 rectangle.
    x_weight = ops.convert_to_tensor([1.778, 1.0], dtype="float32")

    y_true_weighted = ops.multiply(x_weight, y_true)
    y_pred_weighted = ops.multiply(x_weight, y_pred)

    squared_diff = ops.square(y_pred_weighted - y_true_weighted)
    squared_dist = ops.sum(squared_diff, axis=-1)

    # d sqrt(s)/ds is infinite at s == 0, which turns into NaN gradients even
    # for points whose sample weight is 0; clamp s to a tiny positive floor.
    dist = ops.sqrt(ops.maximum(squared_dist, 1e-12))

    # Euclidean distance from [0,0] to [1.778, 1.00] = 2.03992.
    # Dividing by 2.03992 and multiplying by 100 equals dividing by .0203992.
    normalized_dist = ops.divide(dist, .0203992)

    return normalized_dist

## Embedding Model

### Backbones

#### DenseNet

In [None]:
def create_dense_net_backbone():
  """Build a randomly initialised KerasHub DenseNet for (36, 144, 1) eye crops.

  Three dense stacks, each repeated 4 times (``stackwise_num_repeats``).
  """
  return keras_hub.models.DenseNetBackbone(
      stackwise_num_repeats=[4, 4, 4],
      image_shape=(36, 144, 1),
  )

#### Involution

In [None]:
class Involution(keras.layers.Layer):
    """Involution layer: per-pixel generated kernels shared across channel groups.

    A small conv network generates a ``kernel_size x kernel_size`` kernel for
    every output pixel and group from the (optionally average-pooled) input.
    Each kernel is multiplied with the matching input patch and summed over
    the kernel window. The output keeps the input channel count.

    Args:
        channel: Width used to size the kernel-generation bottleneck
            (``channel // reduction_ratio`` filters).
        group_number: Number of channel groups; must divide the input channels.
        kernel_size: Spatial size of each generated kernel.
        stride: Spatial stride; output is ``ceil(H / stride) x ceil(W / stride)``.
        reduction_ratio: Bottleneck reduction for kernel generation.
        **kwargs: Standard layer arguments such as ``name``.
    """

    def __init__(
        self, channel, group_number, kernel_size, stride, reduction_ratio, **kwargs
    ):
        super().__init__(**kwargs)

        self.channel = channel
        self.group_number = group_number
        self.kernel_size = kernel_size
        self.stride = stride
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        (_, height, width, num_channels) = input_shape

        # AveragePooling2D and extract_patches both use padding="same", which
        # produces ceil(size / stride) pixels, so the reshapes must agree.
        height = -(-height // self.stride)
        width = -(-width // self.stride)

        # Average-pool the kernel-generation input when striding.
        self.stride_layer = (
            keras.layers.AveragePooling2D(
                pool_size=self.stride, strides=self.stride, padding="same"
            )
            if self.stride > 1
            else tf.identity
        )
        # Kernel generation: 1x1 bottleneck -> BN -> ReLU -> 1x1 to K*K*G.
        self.kernel_gen = keras.Sequential(
            [
                keras.layers.Conv2D(
                    filters=self.channel // self.reduction_ratio, kernel_size=1
                ),
                keras.layers.BatchNormalization(),
                keras.layers.ReLU(),
                keras.layers.Conv2D(
                    filters=self.kernel_size * self.kernel_size * self.group_number,
                    kernel_size=1,
                ),
            ]
        )
        self.kernel_reshape = keras.layers.Reshape(
            target_shape=(
                height,
                width,
                self.kernel_size * self.kernel_size,
                1,
                self.group_number,
            )
        )
        self.input_patches_reshape = keras.layers.Reshape(
            target_shape=(
                height,
                width,
                self.kernel_size * self.kernel_size,
                num_channels // self.group_number,
                self.group_number,
            )
        )
        self.output_reshape = keras.layers.Reshape(
            target_shape=(height, width, num_channels)
        )

    def call(self, x):
        # B, H, W, K*K*G
        kernel_input = self.stride_layer(x)
        kernel = self.kernel_gen(kernel_input)

        # B, H, W, K*K, 1, G
        kernel = self.kernel_reshape(kernel)

        # B, H, W, K*K*C
        input_patches = keras.ops.image.extract_patches(
            images=x,
            size=self.kernel_size,
            strides=self.stride,
            dilation_rate=1,
            padding="same",
        )

        # B, H, W, K*K, C//G, G
        input_patches = self.input_patches_reshape(input_patches)

        # Multiply-add of kernels and patches over the K*K window.
        # B, H, W, K*K, C//G, G
        output = keras.ops.multiply(kernel, input_patches)
        # B, H, W, C//G, G
        output = keras.ops.sum(output, axis=3)

        # B, H, W, C
        output = self.output_reshape(output)

        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "channel": self.channel,
            "group_number": self.group_number,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "reduction_ratio": self.reduction_ratio,
        })
        return config

In [None]:
def create_rednet_block(x, channels, kernel_size=3, stride=1, groups=8, reduction_ratio=4):
    """RedNet-style residual block: Involution -> BatchNorm -> ReLU, plus a shortcut.

    The shortcut is a strided 1x1 Conv2D projecting ``x`` to ``channels``
    channels. The involution branch keeps the input channel count, so the Add
    requires ``x`` to already have ``channels`` channels.

    Args:
        x: Input feature-map tensor.
        channels: Shortcut width, also passed as the involution's ``channel``.
        kernel_size, stride, groups, reduction_ratio: Involution settings.

    Returns:
        Output tensor of the block.
    """
    shortcut = keras.layers.Conv2D(channels, kernel_size=1, strides=stride, padding='same')(x)

    involution_layer = Involution(
        channel=channels,
        group_number=groups,
        kernel_size=kernel_size,
        stride=stride,
        reduction_ratio=reduction_ratio
    )
    branch = involution_layer(x)
    branch = keras.layers.BatchNormalization()(branch)
    branch = keras.layers.Activation('relu')(branch)

    return keras.layers.Add()([branch, shortcut])

In [None]:
def create_rednet_backbone():
    """
    Build a RedNet-style feature extractor from the custom Involution layer.

    Stem: 7x7 Conv2D(128) -> BatchNorm -> ReLU. Then two stages of three
    involution blocks (strides 1, 1, 2) at 128 and 64 channels, each followed
    by a 1x1 Conv2D halving the width. The single output is the final
    32-channel feature map at a quarter of the input resolution.
    """
    inputs = keras.layers.Input(shape=(36,144,1))

    features = keras.layers.Conv2D(128, kernel_size=7, strides=1, padding='same')(inputs)
    features = keras.layers.BatchNormalization()(features)
    features = keras.layers.Activation('relu')(features)

    for stage_channels in (128, 64):
        for block_stride in (1, 1, 2):
            features = create_rednet_block(features, stage_channels, stride=block_stride)
        # 1x1 projection to half the width: 128 -> 64, then 64 -> 32.
        features = keras.layers.Conv2D(stage_channels // 2, kernel_size=1, strides=1, padding='same')(features)

    return keras.Model(inputs=inputs, outputs=features)

In [None]:
# Build the RedNet backbone once to inspect its architecture in the next cell.
# NOTE(review): this instance is not used when BACKBONE == "densenet".
backbone = create_rednet_backbone()

In [None]:
# Print layer-by-layer output shapes and parameter counts.
backbone.summary()

In [None]:
def create_involution_backbone():
  """Two-layer involution backbone with a learned spatial-weighting pooling head.

  Maps (36, 144, 1) images to a 64-dim vector: a 1x1 sigmoid Conv2D produces
  per-pixel weights that rescale the features before global average pooling.
  """
  inputs = keras.layers.Input(shape=(36, 144, 1))

  features = keras.layers.Conv2D(filters=64, kernel_size=4, padding="same")(inputs)
  features = keras.layers.ReLU()(features)

  for involution_name in ("involution_1", "involution_2"):
    features = Involution(channel=64, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name=involution_name)(features)
    features = keras.layers.ReLU()(features)

  pixel_weights = keras.layers.Conv2D(1, 1, activation='sigmoid')(features)
  weighted_features = keras.layers.Multiply()([features, pixel_weights])
  pooled = keras.layers.GlobalAveragePooling2D()(weighted_features)

  return keras.Model(inputs=inputs, outputs=pooled)

#### EfficientNetB0

In [None]:
def create_efficientnet_backbone():
    """
    Build an embedding backbone around an untrained Keras EfficientNetB0.

    EfficientNetB0 expects 3 input channels, so the single grayscale channel of
    the (36, 144, 1) eye image is repeated three times. The returned Model
    outputs EfficientNetB0's final convolutional feature map (no top, no pooling).

    NOTE(review): keras.applications EfficientNet models include their own input
    rescaling layers, and create_embedding_model already divides by 255, so the
    pixels may be rescaled twice — confirm whether that is intended.
    """
    input_eyes = keras.layers.Input(shape=(36, 144, 1), name="Input_Eye_Image")

    # Grayscale -> 3 channels by stacking the same channel three times:
    # (batch, 36, 144, 1) -> (batch, 36, 144, 3).
    eyes_3channel = keras.layers.Concatenate(axis=-1, name="Expand_Channels")(
        [input_eyes] * 3
    )

    # weights=None trains from scratch (ImageNet features are unlikely to suit
    # eye crops); pooling=None keeps the raw feature map.
    try:
        efficientnet = keras.applications.EfficientNetB0(
            include_top=False,
            weights=None,
            input_shape=(36, 144, 3),
            pooling=None,
            name="EfficientNetB0_Backbone"
        )
    except Exception as load_error:
        # Print installation hints, then re-raise so execution still stops.
        print(f"Error loading EfficientNetB0: {load_error}")
        print("Ensure you have tensorflow installed correctly.")
        print("You might need to run: pip install --upgrade tensorflow keras")
        raise

    return keras.Model(inputs=input_eyes, outputs=efficientnet(eyes_3channel))

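Target grid (screen coordinates in 0-100 units). There are 16 columns × 9 rows = 144 positions, which matches `MAX_TARGETS`:

- Cols (x) = 5, 11, 17, 23, 29, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95
- Rows (y) = 5, 16.25, 27.5, 38.75, 50, 61.25, 72.5, 83.75, 95

The 20 fixed calibration points in `cal_points` are a subset of this grid: x ∈ {5, 35, 65, 95} and y ∈ {5, 27.5, 50, 72.5, 95}.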


### Create Embedding Model

In [None]:
def create_embedding_model(BACKBONE):
  """Build the per-image embedding model.

  Pixels are scaled by 1/255, passed through the selected backbone, flattened
  and projected to an ``EMBEDDING_DIM``-dimensional tanh embedding.

  Args:
    BACKBONE: Backbone name: "densenet", "rednet", "involution" or
      "efficientnet".

  Returns:
    keras.Model named "Eye_Image_Embedding" mapping (36, 144, 1) images to
    (EMBEDDING_DIM,) embeddings.

  Raises:
    ValueError: If ``BACKBONE`` is not one of the names above.
  """
  backbone_factories = {
      "densenet": create_dense_net_backbone,
      "rednet": create_rednet_backbone,
      "involution": create_involution_backbone,
      "efficientnet": create_efficientnet_backbone,
  }
  if BACKBONE not in backbone_factories:
    raise ValueError(
        f"Unknown BACKBONE {BACKBONE!r}; expected one of {sorted(backbone_factories)}"
    )

  image_shape = (36, 144, 1)
  input_eyes = keras.layers.Input(shape=image_shape)

  eyes_rescaled = keras.layers.Rescaling(scale=1./255)(input_eyes)

  backbone = backbone_factories[BACKBONE]()

  backbone_encoder = backbone(eyes_rescaled)
  flatten_compress = keras.layers.Flatten()(backbone_encoder)
  eye_embedding = keras.layers.Dense(units=EMBEDDING_DIM, activation="tanh")(flatten_compress)

  embedding_model = keras.Model(inputs=input_eyes, outputs=eye_embedding, name="Eye_Image_Embedding")

  return embedding_model

## Full Model

In [None]:
# Modified model to work with the new masked input approach
def create_masked_model():
    """
    Build the end-to-end calibration-based gaze model.

    Every image of a sequence (MAX_TARGETS slots) is embedded by one shared
    eye-image embedding network. A sigmoid Dense head turns each embedding
    into a single weight, and MaskedWeightedRidgeRegressionLayer receives the
    embeddings, the coordinates, those weights and the calibration mask.
    The calibration mask is an ordinary tensor input; Keras masking is not
    used anywhere in this model.

    Returns:
        keras.Model named "MaskedEyePredictionModel" with inputs
        [images (MAX_TARGETS, 36, 144, 1), coords (MAX_TARGETS, 2),
        calibration mask (MAX_TARGETS,)] and the regression layer's output.
    """
    # Input names must match the feature-dict keys the datasets produce.
    images_in = keras.layers.Input(shape=(MAX_TARGETS, 36, 144, 1), name="Input_All_Images")
    coords_in = keras.layers.Input(shape=(MAX_TARGETS, 2), name="Input_All_Coords")
    cal_mask_in = keras.layers.Input(shape=(MAX_TARGETS,), name="Input_Calibration_Mask")

    # Shared embedding network applied to every timestep.
    embedder = create_embedding_model(BACKBONE)
    embeddings = SimpleTimeDistributed(embedder, name="Image_Embeddings")(images_in)

    # One sigmoid weight per embedding (Dense acts on the last axis).
    weight_head = keras.layers.Dense(1, activation="sigmoid", name="Calibration_Weights")
    cal_weights = weight_head(embeddings)

    regression = MaskedWeightedRidgeRegressionLayer(RIDGE_REGULARIZATION, name="Regression")
    predicted_coords = regression([embeddings, coords_in, cal_weights, cal_mask_in])

    model_inputs = [images_in, coords_in, cal_mask_in]
    return keras.Model(
        inputs=model_inputs,
        outputs=predicted_coords,
        name="MaskedEyePredictionModel",
    )

# Model Training


## Augmentation

In [None]:
def get_translation_matrix(tx, ty):
    """Return the 3x3 homogeneous-coordinate matrix that shifts points by (tx, ty)."""
    top_row = [1., 0., tx]
    middle_row = [0., 1., ty]
    bottom_row = [0., 0., 1.]
    return tf.convert_to_tensor([top_row, middle_row, bottom_row], dtype=tf.float32)

def get_rotation_matrix(angle, center_x, center_y):
    """Return a 3x3 matrix rotating by `angle` radians about (center_x, center_y).

    Composed as translate(+center) @ rotate(angle) @ translate(-center).
    """
    cx = tf.cast(center_x, tf.float32)
    cy = tf.cast(center_y, tf.float32)
    c, s = tf.cos(angle), tf.sin(angle)
    rotate = tf.convert_to_tensor(
        [[c, -s, 0.],
         [s,  c, 0.],
         [0., 0., 1.]],
        dtype=tf.float32,
    )
    to_origin = get_translation_matrix(-cx, -cy)
    back_to_center = get_translation_matrix(cx, cy)
    return back_to_center @ rotate @ to_origin

def get_zoom_matrix(zx, zy, center_x, center_y):
    """Return a 3x3 matrix scaling by (zx, zy) about (center_x, center_y).

    Composed as translate(+center) @ scale(zx, zy) @ translate(-center).
    """
    cx = tf.cast(center_x, tf.float32)
    cy = tf.cast(center_y, tf.float32)
    scale = tf.convert_to_tensor(
        [[zx, 0., 0.],
         [0., zy, 0.],
         [0., 0., 1.]],
        dtype=tf.float32,
    )
    to_origin = get_translation_matrix(-cx, -cy)
    back_to_center = get_translation_matrix(cx, cy)
    return back_to_center @ scale @ to_origin

def get_flip_matrix(horizontal, center_x, center_y):
    """Creates a 3x3 horizontal flip matrix around a center point.

    Args:
        horizontal: Python bool or scalar boolean tensor. When false, the
            identity matrix is returned.
        center_x: x position of the vertical mirror axis.
        center_y: y position of the center (cancels out for a horizontal flip).

    Returns:
        A (3, 3) float32 tensor.
    """
    center_x = tf.cast(center_x, tf.float32)
    center_y = tf.cast(center_y, tf.float32)
    m_flip = tf.convert_to_tensor([
        [-1., 0., 0.],
        [ 0., 1., 0.],
        [ 0., 0., 1.]
    ], dtype=tf.float32)
    m_trans1 = get_translation_matrix(-center_x, -center_y)
    m_trans2 = get_translation_matrix(center_x, center_y)
    flipped = m_trans2 @ m_flip @ m_trans1
    # Select with tf.where rather than a Python `if`: in the augmentation
    # pipeline `horizontal` is a tensor, and a Python branch on a tensor only
    # works in graph mode when AutoGraph rewrites it. tf.where works for a
    # Python bool or a tensor, eagerly or in a graph.
    return tf.where(tf.cast(horizontal, tf.bool), flipped, tf.eye(3, dtype=tf.float32))

def matrix_to_affine_params(matrix):
    """Return the 8 projective-transform parameters of a 3x3 matrix.

    The result is [m00, m01, m02, m10, m11, m12, m20, m21]: the first two rows
    in row-major order followed by the first two entries of the last row
    (the implicit ninth entry m22 is dropped).
    """
    first_two_rows = tf.reshape(matrix[:2, :3], [-1])
    last_row_head = matrix[2, :2]
    return tf.concat([first_two_rows, last_row_head], axis=0)

In [None]:
def _stateless_uniform_scalar(seed, minval, maxval):
    """Draw one float32 scalar from U[minval, maxval) using a stateless seed."""
    return tf.random.stateless_uniform(shape=(), seed=seed, minval=minval, maxval=maxval)


def augment_sequence_keras_ops_affine(
    all_inputs,
    targets,
    mask,
    flip_probability=0.5,
    rotation_factor=0.05,
    zoom_factor_range=(0.8, 1.2),
    translation_height_factor=0.1,
    translation_width_factor=0.1,
):
    """
    Apply ONE random affine augmentation identically to every frame of a sequence.

    A single set of random parameters (horizontal flip, rotation, zoom,
    translation) is drawn per call. The resulting transform is applied to all
    frames via keras.ops.image.affine_transform, so every image of the
    sequence, calibration and target alike, gets the same geometry. Coordinates,
    the calibration mask, `targets` and `mask` are passed through unchanged.

    Args:
        all_inputs: dict with "Input_All_Images" (a (timesteps, H, W, C) image
            tensor), "Input_All_Coords" and "Input_Calibration_Mask".
        targets: passed through unchanged.
        mask: passed through unchanged.
        flip_probability: probability of a horizontal flip.
        rotation_factor: max |rotation| as a fraction of pi (0.05 -> +/-9 deg).
        zoom_factor_range: (min, max) zoom; drawn independently for H and W.
        translation_height_factor: max vertical shift as a fraction of H.
        translation_width_factor: max horizontal shift as a fraction of W.

    Returns:
        (inputs, targets, mask), where `inputs` has the same three keys and
        only "Input_All_Images" is replaced by the augmented images.

    Note:
        affine_transform interprets the matrix as mapping OUTPUT pixels to
        INPUT pixels, so the sampled zoom/shift act as their inverses. This is
        harmless for these roughly symmetric random ranges.
    """
    sequence = all_inputs["Input_All_Images"]

    sequence_shape = keras.ops.shape(sequence)
    img_height = tf.cast(sequence_shape[1], tf.float32)
    img_width = tf.cast(sequence_shape[2], tf.float32)
    center_x = img_width / 2.0
    center_y = img_height / 2.0

    # A stateful draw gives a fresh base seed on every call. The per-parameter
    # stateless seeds derived from it keep the parameters independent of each
    # other while being shared by all frames of this sequence.
    base_seed = tf.random.uniform([2], minval=0, maxval=tf.int32.max, dtype=tf.int32)

    apply_flip = _stateless_uniform_scalar(base_seed + [0, 1], 0.0, 1.0) < flip_probability

    max_angle = rotation_factor * math.pi
    rotation_angle = _stateless_uniform_scalar(base_seed + [0, 2], -max_angle, max_angle)

    zoom_min, zoom_max = zoom_factor_range
    zoom_height_factor = _stateless_uniform_scalar(base_seed + [0, 3], zoom_min, zoom_max)
    zoom_width_factor = _stateless_uniform_scalar(base_seed + [0, 4], zoom_min, zoom_max)

    random_shift_h = _stateless_uniform_scalar(
        base_seed + [0, 5], -translation_height_factor, translation_height_factor)
    random_shift_w = _stateless_uniform_scalar(
        base_seed + [0, 6], -translation_width_factor, translation_width_factor)
    shift_height = random_shift_h * img_height
    shift_width = random_shift_w * img_width

    # Read right-to-left: flip, then zoom, rotate and finally translate.
    combined_matrix = (
        get_translation_matrix(shift_width, shift_height)
        @ get_rotation_matrix(rotation_angle, center_x, center_y)
        @ get_zoom_matrix(zoom_width_factor, zoom_height_factor, center_x, center_y)
        @ get_flip_matrix(apply_flip, center_x, center_y)
    )

    # A single (8,) transform is broadcast over all frames of the sequence.
    augmented_sequence = keras.ops.image.affine_transform(
        sequence,
        matrix_to_affine_params(combined_matrix),
        interpolation="bilinear",
        fill_mode="constant",
    )

    # Photometric (brightness/contrast) augmentation is intentionally not applied.
    inputs = {
        "Input_All_Images": augmented_sequence,
        "Input_All_Coords": all_inputs["Input_All_Coords"],
        "Input_Calibration_Mask": all_inputs["Input_Calibration_Mask"],
    }

    return inputs, targets, mask

In [None]:
# NOTE: this cell rebinds train_ds_for_model, so running it twice stacks a
# second augmentation pass. Rebuild the dataset before re-running it.
if AUGMENTATION:
  train_ds_for_model = (
      train_ds_for_model
      .map(augment_sequence_keras_ops_affine, num_parallel_calls=tf.data.AUTOTUNE)
      .prefetch(tf.data.AUTOTUNE)
  )

### Visualizing augmentation

In [None]:
def plot_sequence_grid(sequence, title="Image Sequence Timesteps", figsize=None, cmap='gray'):
    """
    Plot every frame (timestep) of an image sequence in a near-square grid.

    Useful for checking visually that one augmentation was applied
    consistently across the time dimension.

    Args:
        sequence: NumPy array or TensorFlow tensor of shape (timesteps, H, W, C).
                  A 3-D (timesteps, H, W) input is treated as single-channel.
        title (str): Overall title for the figure.
        figsize (tuple, optional): Figure size (width, height) in inches. If None,
                                   it is derived from the grid size
                                   (2.5 x 2 inches per cell).
        cmap (str): Colormap for single-channel images. Defaults to 'gray'.
                    Ignored for RGB images (C=3).

    Raises:
        TypeError: if `sequence` is neither a NumPy array nor a tensor.
        ValueError: if `sequence` is not 3- or 4-dimensional.
    """
    # --- Input validation and conversion ---
    if hasattr(sequence, 'numpy'):  # eager TF tensor
        sequence = sequence.numpy()

    if not isinstance(sequence, np.ndarray):
        raise TypeError("Input 'sequence' must be a NumPy array or TensorFlow tensor.")

    if sequence.ndim == 3:  # (timesteps, H, W): add a channel axis
        sequence = sequence[..., np.newaxis]

    if sequence.ndim != 4:
        raise ValueError(
            f"Input 'sequence' must be 3- or 4-dimensional (timesteps, H, W[, C]), got shape {sequence.shape}"
        )

    num_timesteps, _, _, channels = sequence.shape

    if num_timesteps == 0:
        print("Sequence has 0 timesteps, nothing to plot.")
        return

    # --- Grid and figure size: roughly square, filled row by row ---
    ncols = math.ceil(math.sqrt(num_timesteps))
    nrows = math.ceil(num_timesteps / ncols)
    if figsize is None:
        figsize = (ncols * 2.5, nrows * 2)

    # --- Select the frames to draw and the colormap ---
    plot_cmap = None
    if channels == 1:
        plot_cmap = cmap
        frames = sequence[..., 0]
    elif channels == 3:
        frames = sequence  # imshow expects floats in [0, 1] or ints in [0, 255]
        if frames.dtype == np.float32 or frames.dtype == np.float64:
            frames = np.clip(frames, 0, 1)
    else:
        print(f"Warning: Sequence has {channels} channels. Plotting the first channel as grayscale.")
        plot_cmap = cmap
        frames = sequence[..., 0]

    # Shared intensity scale so frames are comparable with each other.
    vmin = np.min(frames)
    vmax = np.max(frames)

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for t, ax in enumerate(axes.ravel()):
        ax.axis('off')  # also hides the unused trailing cells
        if t < num_timesteps:
            ax.imshow(frames[t], cmap=plot_cmap, vmin=vmin, vmax=vmax)
            ax.set_title(f"t={t}")

    fig.suptitle(title, fontsize=16)
    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [None]:
# Visual check: all frames of one training sequence should share one transform.
for element in train_ds_for_model.take(count=1):
  plot_sequence_grid(sequence=element[0]["Input_All_Images"])

## Run the model

In [None]:
# Warm up from INITIAL_LEARNING_RATE to LEARNING_RATE over WARMUP_EPOCHS, then
# cosine-decay over DECAY_EPOCHS (floor controlled by DECAY_ALPHA).
# fit() uses `.batch(BATCH_SIZE)` without drop_remainder, so each epoch has
# ceil(N / BATCH_SIZE) optimizer steps. Floor division would end the schedule
# early and misalign warmup with epoch boundaries.
# NOTE(review): assumes INDIVIDUALS is the number of sequences in
# train_ds_for_model; confirm.
learning_rate_scheduler = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = INITIAL_LEARNING_RATE,
    decay_steps = math.ceil(INDIVIDUALS / BATCH_SIZE) * DECAY_EPOCHS,
    alpha = DECAY_ALPHA,
    warmup_target = LEARNING_RATE,
    warmup_steps = math.ceil(INDIVIDUALS / BATCH_SIZE) * WARMUP_EPOCHS
)

In [None]:
# Build the model and train it with AdamW driven by the LR schedule; the train
# step is XLA-compiled (jit_compile=True).
mask_model = create_masked_model()
mask_model.compile(
    loss=normalized_weighted_euc_dist,
    optimizer=keras.optimizers.AdamW(learning_rate=learning_rate_scheduler),
    jit_compile=True,
)
mask_model.summary()

In [None]:
# Batch the (optionally augmented) sequences and prefetch AFTER batching, so the
# next batch is prepared while the current training step runs.
mask_model.fit(
    train_ds_for_model.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE),
    epochs=TRAIN_EPOCHS,
    callbacks=[WandbMetricsLogger()],
)

# Save and export

In [None]:
# Whole model in the native Keras `.keras` archive format.
mask_model.save(filepath='full_model.keras')

In [None]:
# Weights only; Keras 3 requires the `.weights.h5` suffix for this call.
mask_model.save_weights(filepath='full_model.weights.h5')

In [None]:
# Attach both saved files to the active W&B run.
wandb.save(glob_str='full_model.keras')
wandb.save(glob_str='full_model.weights.h5')

In [None]:
# Build the masked evaluation dataset.
# NOTE(review): this uses `train_data_rescaled`, which appears to be the
# TRAINING data, yet the resulting losses are logged below as
# "final_val_loss_hist" / "final_loss_mean". If a held-out split exists, it
# should probably be used here; confirm.
masked_dataset = prepare_masked_dataset(train_data_rescaled, calibration_points=scaled_cal_points)

In [None]:
# Evaluation pipeline: map prepare_model_inputs over the masked dataset, then prefetch.
test_ds = (
    masked_dataset
    .map(prepare_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Show the structure, shapes and dtypes of the elements test_ds yields.
test_ds.element_spec

In [None]:
# Predict and collect ground truth in a SINGLE pass over the evaluation data.
# Iterating test_ds twice (once inside predict(), once for the targets) only
# lines the two up if the pipeline yields identical elements in identical
# order on every pass. Any randomness upstream (e.g. calibration-point
# sampling or shuffling; not visible here) would silently misalign them.
prediction_batches = []
target_batches = []
for batch_inputs, batch_targets, *_unused in test_ds.batch(1):
  prediction_batches.append(np.asarray(mask_model.predict_on_batch(batch_inputs)))
  target_batches.append(np.asarray(batch_targets, dtype=np.float32))

predictions = np.concatenate(prediction_batches, axis=0)
y_true = np.concatenate(target_batches, axis=0)
assert y_true.shape == predictions.shape, (y_true.shape, predictions.shape)

# Per-target loss for every evaluation sequence (one row per subject).
per_target_losses = np.zeros((predictions.shape[0], predictions.shape[1]))
for row, (subject_true, subject_pred) in enumerate(zip(y_true, predictions)):
  per_target_losses[row, :] = normalized_weighted_euc_dist(subject_true, subject_pred).numpy()

# Get mean per subject
batch_losses = np.mean(per_target_losses, axis=1)

In [None]:
# One-column W&B table of per-subject losses, rendered as a W&B histogram.
table = wandb.Table(
    columns=["Loss"],
    data=[[subject_loss] for subject_loss in batch_losses],
)
final_loss_hist = wandb.plot.histogram(
    table,
    value="Loss",
    title="Normalized Euclidean Distance",
)

In [None]:
# Mean of the per-subject losses.
final_loss_mean = batch_losses.mean()

In [None]:
# Log the histogram and the scalar mean to the active W&B run.
wandb.log({
    "final_val_loss_hist": final_loss_hist,
    "final_loss_mean": final_loss_mean,
})

In [None]:
# Local histogram of the per-subject mean losses (one value per evaluation
# sequence), drawn with the explicit Figure/Axes API.
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(batch_losses, bins=20, edgecolor='black')
ax.set_title('Histogram of Per-Subject Mean Losses')
ax.set_xlabel('Loss')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# End the W&B run, uploading any remaining logged data and saved files.
wandb.finish()

In [None]:
# Release (unassign) the Colab runtime once everything above has finished.
# All kernel state (model, datasets, variables) is lost afterwards, so this
# must remain the last cell.
from google.colab import runtime
runtime.unassign()