# Setup

In [1]:
!pip install osfclient --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet
# !pip install wandb --quiet
!pip install --upgrade keras --quiet
!pip install keras-hub --quiet

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [2]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import numpy as np
import keras
import keras_hub
from keras import ops

# import wandb
# from wandb.integration.keras import WandbMetricsLogger

import matplotlib.pyplot as plt
from google.colab import userdata

import et_util.dataset_utils as dataset_utils
import et_util.embedding_preprocessing as embed_pre
import et_util.model_layers as model_layers
from et_util import experiment_utils
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.model_analysis import plot_model_performance

In [3]:
os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
os.environ['OSF_TOKEN'] = userdata.get('osftoken')
os.environ['OSF_USERNAME'] = userdata.get('osfusername')

In [4]:
keras.version()

'3.9.0'

In [5]:
keras.mixed_precision.set_global_policy('mixed_float16')

# Configure W&B experiment

In [6]:
# wandb.login()

In [7]:
# Fixed constants
MAX_TARGETS = 144

# Config constants
EMBEDDING_DIM = 50
RIDGE_REGULARIZATION = 0.001
TRAIN_EPOCHS = 30
MIN_CAL_POINTS = 8
MAX_CAL_POINTS = 40
DENSE_NET_STACKWISE_NUM_REPEATS = [4,4,4]
LEARNING_RATE = 0.001
AUGMENTATION = True

In [8]:
config = {
    "embedding_dim": EMBEDDING_DIM,
    "ridge_regularization": RIDGE_REGULARIZATION,
    "train_epochs": TRAIN_EPOCHS,
    "min_cal_points": MIN_CAL_POINTS,
    "max_cal_points": MAX_CAL_POINTS,
    "dense_net_stackwise_num_repeats": DENSE_NET_STACKWISE_NUM_REPEATS,
    "learning_rate": LEARNING_RATE,
    "augmentation": AUGMENTATION,
}

In [9]:
# run = wandb.init(
#     project='eye-tracking-dense-full-data-set-single-eye',
#     config=config
# )

# Download dataset from OSF

In [10]:
!osf -p 6b5cd fetch single_eye_tfrecords.tar.gz

usage: osf fetch [-h] [-f] [-U] remote [local]
osf fetch: error: Local file single_eye_tfrecords.tar.gz already exists, not overwriting.


# Process raw data records into TF Dataset

In [11]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

mkdir: cannot create directory ‘single_eye_tfrecords’: File exists


In [12]:
def parse(element):
    """Process function that parses a tfr element in a raw dataset for process_tfr_to_tfds function.
    Gets mediapipe landmarks, raw image, image width, image height, subject id, and xy labels.
    Use for data generated with make_single_example_landmarks_and_jpg (i.e. data in
    jpg_landmarks_tfrecords.tar.gz)

    :param element: tfr element in raw dataset
    :return: image, label(x,y), landmarks, subject_id
    """

    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    landmarks = content['landmarks']
    raw_image = content['eye_img']
    width = content['img_width']
    height = content['img_height']
    depth = 3
    label = [content['x'], content['y']]
    subject_id = content['subject_id']

    landmarks = tf.io.parse_tensor(landmarks, out_type=tf.float32)
    landmarks = tf.reshape(landmarks, shape=(478, 3))

    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)

    return image, landmarks, label, subject_id

In [13]:
train_data, validation_data, test_data = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, landmarks, coords, z: z
)

In [14]:
def rescale_coords_map(eyes, mesh, coords, id):
  return eyes, mesh, tf.divide(coords, tf.constant([100.])), id

In [15]:
train_data_rescaled = train_data.map(rescale_coords_map)
validation_data_rescaled = validation_data.map(rescale_coords_map)

In [16]:
def prepare_masked_dataset(dataset):
    """
    Creates a dataset where each batch contains all images for a subject,
    and a random subset is marked as calibration images via a mask.

    Args:
        dataset: The base TF dataset containing eye images and coordinates
        batch_size: Number of images to include in each batch

    Returns:
        A TF dataset with masked images ready for model training
    """
    # Step 1: Group dataset by subject_id and batch all images
    def group_by_subject(subject_id, ds):
        return ds.batch(batch_size=MAX_TARGETS)

    grouped_dataset = dataset.group_by_window(
        key_func=lambda img, mesh, coords, subject_id: subject_id,
        reduce_func=group_by_subject,
        window_size=MAX_TARGETS
    )

    # Step 2: Filter out subjects with fewer than 72 images
    def filter_by_image_count(images, meshes, coords, subject_ids):
        return tf.shape(images)[0] >= 144

    grouped_dataset = grouped_dataset.filter(filter_by_image_count)

    # Step 3: Transform each batch to include masks
    def add_masks_to_batch(images, meshes, coords, subject_ids):

        actual_batch_size = tf.shape(images)[0]

        cal_mask = tf.zeros(MAX_TARGETS, dtype=tf.int8)
        target_mask = tf.zeros(MAX_TARGETS, dtype=tf.int8)

        # Determine how many calibration images to use (random between min and max)
        n_cal_images = tf.random.uniform(
            shape=[],
            minval=MIN_CAL_POINTS,
            maxval=MAX_CAL_POINTS,
            dtype=tf.int32
        )

        # Create random indices for calibration images
        # NOTE: These will be different each time through the dataset
        random_indices = tf.random.shuffle(tf.range(actual_batch_size))
        cal_indices = random_indices[:n_cal_images]

        # Create masks (1 = included, 0 = excluded)
        cal_mask = tf.scatter_nd(
            tf.expand_dims(cal_indices, 1),
            tf.ones(n_cal_images, dtype=tf.int8),
            [MAX_TARGETS]
        )
        target_mask = 1 - cal_mask

        # Pad everything to fixed size

        padded_images = tf.pad(
            tf.reshape(images, (-1, 36, 144, 1)),
            [[0, MAX_TARGETS - actual_batch_size], [0, 0], [0, 0], [0, 0]]
        )
        padded_coords = tf.pad(
            coords,
            [[0, MAX_TARGETS - actual_batch_size], [0, 0]]
        )

        # Ensure all shapes are fixed
        padded_images = tf.ensure_shape(padded_images, [MAX_TARGETS, 36, 144, 1])
        padded_coords = tf.ensure_shape(padded_coords, [MAX_TARGETS, 2])
        padded_cal_mask = tf.ensure_shape(cal_mask, [MAX_TARGETS])
        padded_target_mask = tf.ensure_shape(target_mask, [MAX_TARGETS])

        return (padded_images, padded_coords, padded_cal_mask, padded_target_mask), padded_coords

    # Apply the transformation
    masked_dataset = grouped_dataset.map(
        lambda imgs, meshes, coords, subj_ids: add_masks_to_batch(imgs, meshes, coords, subj_ids),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return masked_dataset

# Modified model input preparation function
def prepare_model_inputs(features, labels):
    """
    Restructure the inputs for the model, using the mask-based approach.

    Args:
        features: Tuple of (images, coords, cal_mask, target_mask)
        labels: The target coordinates

    Returns:
        Dictionary of inputs for the model and the target labels
    """
    images, coords, cal_mask, target_mask = features

    # Use masking to create the model inputs
    inputs = {
        "Input_All_Images": images,                   # All images (both cal and target)
        "Input_All_Coords": coords,                   # All coordinates
        "Input_Calibration_Mask": cal_mask,           # Mask indicating calibration images
        "Input_Target_Mask": target_mask              # Mask indicating target images
    }

    return inputs, labels

In [17]:
masked_dataset = prepare_masked_dataset(train_data_rescaled)

In [18]:
train_ds_for_model = masked_dataset.map(
    prepare_model_inputs,
    num_parallel_calls=tf.data.AUTOTUNE
).shuffle(200).prefetch(tf.data.AUTOTUNE)

## Visualization

In [19]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import tensorflow as tf
from IPython.display import HTML
import base64

def visualize_eye_tracking_data(dataset, interval=100, figsize=(10, 6)):
    """
    Visualize eye tracking data by animating through a subject's images and displaying
    where they are looking with a red dot. Returns an HTML animation for direct display
    in a Colab notebook.

    Args:
        dataset: A TensorFlow dataset containing a single batch of one subject's data
                Expected format: ({"Input_All_Images": images, "Input_All_Coords": coords,
                                 "Input_Calibration_Mask": cal_mask, "Input_Target_Mask": target_mask}, labels)
                OR: Your model's dataset with one subject batch
        interval: Time between frames in milliseconds
        figsize: Size of the figure (width, height)

    Returns:
        IPython.display.HTML object with the animation

    Example usage in a Colab notebook:
    ```python
    # For a dataset from a single subject
    subject_dataset = train_ds_for_model.take(1)

    # Or you can create a dataset for just one subject (example from your notebook):
    # This line extracts data for one subject from the complete dataset
    single_subject = train_ds_for_model.filter(
        lambda inputs, labels: tf.equal(inputs["Input_Subject_ID"][0], subject_id)
    ).take(1)

    # Display the animation directly in the notebook
    from IPython.display import display
    animation = visualize_eye_tracking_data(single_subject)
    display(animation)
    ```
    """
    # Extract data from the dataset
    for inputs, labels in dataset.take(1):
        images = inputs["Input_All_Images"]
        coords = inputs["Input_All_Coords"]

        # Convert to numpy arrays
        valid_images = images.numpy()
        valid_coords = coords.numpy()

        # Sort by X coordinate, then by Y coordinate
        sorted_indices = np.lexsort((valid_coords[:, 0], valid_coords[:, 1]))
        valid_images = valid_images[sorted_indices]
        valid_coords = valid_coords[sorted_indices]

        num_images = valid_images.shape[0]

    # Set up the figure
    plt.ioff()  # Turn off interactive mode to avoid displaying during generation
    fig, ax = plt.subplots(figsize=figsize)

    # Create a function that draws each frame
    def draw_frame(frame_num):
        ax.clear()

        # Set up the 16:9 coordinate space
        ax.set_xlim(0, 100)
        ax.set_ylim(100, 0)  # 16:9 aspect ratio

        # Get the current image and coordinates
        img = valid_images[frame_num].squeeze()
        x, y = valid_coords[frame_num]

        # Draw rectangle for the screen
        rect = plt.Rectangle((0, 0), 100, 100, fill=False, color='black', linewidth=2)
        ax.add_patch(rect)

        # Add image display in the top right corner
        img_display = ax.inset_axes([0.35, 0.35, 0.3, 0.3], transform=ax.transAxes)
        img_display.imshow(np.fliplr(img), cmap='gray')
        img_display.axis('off')

        # Draw red dot at the coordinate - make it bigger for visibility
        ax.scatter(x * 100, y * 100, color='red', s=150, zorder=5,
                  edgecolor='white', linewidth=1.5)

        # Draw crosshair
        ax.axhline(y * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)
        ax.axvline(x * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)

        # Set title
        ax.set_title('Eye Tracking Visualization')

        # Set axis labels
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')

    # Create the animation
    anim = animation.FuncAnimation(fig, draw_frame, frames=num_images, interval=interval)

  # Use animation's to_jshtml method which directly generates HTML
    html_animation = anim.to_jshtml()

    # Clean up
    plt.close(fig)

    # Create an HTML object that can be displayed in the notebook
    return HTML(html_animation)


In [20]:
visualize_eye_tracking_data(train_ds_for_model.skip(1000).take(1))

# Create custom layers

In [21]:
class SimpleTimeDistributed(keras.layers.Wrapper):
    """A simplified version of TimeDistributed that applies a layer to every temporal slice of an input.

    This implementation avoids for loops by using reshape operations to apply the wrapped layer
    to all time steps at once.

    Args:
        layer: a `keras.layers.Layer` instance.
    """

    def __init__(self, layer, **kwargs):
        super().__init__(layer, **kwargs)
        self.supports_masking = getattr(layer, 'supports_masking', False)

    def build(self, input_shape):
        # Validate input shape has at least 3 dimensions (batch, time, ...)
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
            raise ValueError(
                "`SimpleTimeDistributed` requires input with at least 3 dimensions"
            )

        # Build the wrapped layer with shape excluding the time dimension
        super().build((input_shape[0], *input_shape[2:]))
        self.built = True

    def compute_output_shape(self, input_shape):
        # Get output shape by applying the layer to a single time slice
        child_output_shape = self.layer.compute_output_shape((input_shape[0], *input_shape[2:]))
        # Include time dimension in the result
        return (child_output_shape[0], input_shape[1], *child_output_shape[1:])

    def call(self, inputs, training=None):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        time_steps = input_shape[1]

        # Reshape inputs to combine batch and time dimensions: (batch*time, ...)
        reshaped_inputs = ops.reshape(inputs, (-1, *input_shape[2:]))

        # Apply the layer to all time steps at once
        outputs = self.layer.call(reshaped_inputs, training=training)

        # Get output dimensions
        output_shape = ops.shape(outputs)

        # Reshape back to include the separate batch and time dimensions: (batch, time, ...)
        return ops.reshape(outputs, (batch_size, time_steps, *output_shape[1:]))

In [22]:
# cal_points = tf.constant([
#     [5, 5],
#     [5, 27.5],
#     [5, 50],
#     [5, 72.5],
#     [5, 95],
#     [35, 5],
#     [35, 27.5],
#     [35, 50],
#     [35, 72.5],
#     [35, 95],
#     [65, 5],
#     [65, 27.5],
#     [65, 50],
#     [65, 72.5],
#     [65, 95],
#     [95, 5],
#     [95, 27.5],
#     [95, 50],
#     [95, 72.5],
#     [95, 95],
# ], dtype=tf.float32)

# scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))


In [23]:
class MaskedWeightedRidgeRegressionLayer(keras.layers.Layer):
    """
    A custom layer that performs weighted ridge regression with proper masking support.

    This layer takes embeddings, coordinates, weights, and calibration mask as explicit inputs,
    while using Keras' masking system to handle target masking. This separation allows more
    precise control over calibration points while leveraging Keras' built-in mask propagation
    for target predictions.

    Args:
        lambda_ridge (float): Regularization parameter for ridge regression
        embedding_dim (int): Dimension of the embedding vectors
        epsilon (float): Small constant for numerical stability
    """
    def __init__(self, lambda_ridge, **kwargs):
        self.lambda_ridge = lambda_ridge
        super(MaskedWeightedRidgeRegressionLayer, self).__init__(**kwargs)

    def call(self, inputs, mask=None):
        """
        The forward pass of the layer.

        Args:
            inputs: A list containing:
                - embeddings: Embeddings for all points (batch_size, n_points, embedding_dim)
                - coords: Coordinates for all points (batch_size, n_points, 2)
                - calibration_weights: Importance weights (batch_size, n_points, 1)
                - cal_mask: Mask for calibration points (batch_size, n_points) [EXPLICIT]

        Returns:
            Predicted coordinates for the target points (batch_size, n_points, 2)
        """
        # Unpack inputs
        embeddings, coords, calibration_weights, cal_mask  = inputs

        embeddings = ops.cast(embeddings, "float32")
        coords = ops.cast(coords, "float32")
        calibration_weights = ops.cast(calibration_weights, "float32")
        cal_mask = ops.cast(cal_mask, "float32")

        # reshape weights to (batch, calibration)
        w = ops.squeeze(calibration_weights, axis=-1)

        # Pre-compute masked weights for calibration points
        w_masked = w * cal_mask
        w_sqrt = ops.sqrt(w_masked)
        w_sqrt = ops.expand_dims(w_sqrt, -1)  # More efficient than reshape

        # Apply calibration mask to embeddings
        cal_mask_expand = ops.expand_dims(cal_mask, -1)
        X = embeddings * cal_mask_expand

        # Weight calibration embeddings and coordinates
        X_weighted = X * w_sqrt
        y_weighted = coords * w_sqrt

        # Matrix operations
        X_t = ops.transpose(X_weighted, axes=[0, 2, 1])
        X_t_X = ops.matmul(X_t, X_weighted)

        # Add regularization
        lhs = X_t_X + self.lambda_ridge * ops.eye(embeddings.shape[-1])

        # Compute RHS
        rhs = ops.matmul(X_t, y_weighted)

        # Solve the system
        kernel = ops.linalg.solve(lhs, rhs)

        # Apply regression
        output = ops.matmul(embeddings, kernel)

        return output

    def compute_output_shape(self, input_shapes):
        """
        Computes the output shape of the layer.

        Args:
            input_shapes: List of input shapes

        Returns:
            Output shape tuple
        """
        return (input_shapes[0][0], input_shapes[0][1], 2)

    def get_config(self):
        """
        Returns the configuration of the layer for serialization.

        Returns:
            Dictionary containing the layer configuration
        """
        config = super(MaskedWeightedRidgeRegressionLayer, self).get_config()
        config.update({
            "lambda_ridge": self.lambda_ridge,
        })
        return config

In [24]:
def normalized_weighted_euc_dist(y_true, y_pred):
    """Custom loss function that calculates a weighted Euclidean distance between two sets of points,
    respecting masks for padded sequences.

    Weighting multiplies the x-coordinate of each input by 1.778 (derived from the 16:9 aspect ratio of most laptops)
    to match the scale of the y-axis. For interpretability, the distances are normalized to the diagonal
    such that the maximum distance between two points (corner to corner) is 100.

    :param y_true (tensor): A tensor of shape (seq_len, 2) containing ground-truth x- and y- coordinates
    :param y_pred (tensor): A tensor of shape (seq_len, 2) containing predicted x- and y- coordinates

    :returns: A tensor of shape (1,) with the weighted and normalized euclidean distances between valid points.
    """
    y_true = ops.cast(y_true, "float32")
    y_pred = ops.cast(y_pred, "float32")

    # Weighting treats y-axis as unit-scale and creates a rectangle that's 177.8x100 units.
    x_weight = ops.convert_to_tensor([1.778, 1.0], dtype="float32")

    # Multiply x-coordinate by 16/9 = 1.778
    y_true_weighted = ops.multiply(x_weight, y_true)
    y_pred_weighted = ops.multiply(x_weight, y_pred)

    # Calculate Euclidean distance with weighted coordinates
    squared_diff = ops.square(y_pred_weighted - y_true_weighted)
    squared_dist = ops.sum(squared_diff, axis=-1)
    dist = ops.sqrt(squared_dist)

    # Euclidean Distance from [0,0] to [1.778, 1.00] = 2.03992
    # Divide by 2.03992, mult by 100 is same as / .0203992
    # Normalizes loss values to the diagonal-- makes loss easier to interpret
    normalized_dist = ops.divide(dist, .0203992)

    return normalized_dist

# Model Training


### Setup Stuff

Cols = 5, 11, 17, 23, 29, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95

Rows = 5, 16.25, 27.5, 38.75, 50, 61.25, 72.5, 83.75, 95  


Fixed Points as calibration

In [25]:
def create_embedding_model():
  image_shape = (36, 144, 1)
  input_eyes = keras.layers.Input(shape=image_shape)

  eyes_rescaled = keras.layers.Rescaling(scale=1./255)(input_eyes)

  # Continue with the backbone
  backbone = keras_hub.models.DenseNetBackbone(
      stackwise_num_repeats=DENSE_NET_STACKWISE_NUM_REPEATS,
      image_shape=image_shape,
  )
  backbone_encoder = backbone(eyes_rescaled)

  # backbone = keras.Sequential([
  #     keras.layers.Flatten(),
  #     keras.layers.Dense(10, activation="relu")
  # ])

  # backbone_encoder = backbone(input_eyes)

  # backbone = keras_hub.models.MiTBackbone(
  #   image_shape=(36,144,1),
  #   layerwise_depths=[2,2,2,2],
  #   num_layers=4,
  #   layerwise_num_heads=[1,2,5,8],
  #   layerwise_sr_ratios=[8,4,2,1],
  #   max_drop_path_rate=0.1,
  #   layerwise_patch_sizes=[7,3,3,3],
  #   layerwise_strides=[4,2,2,2],
  #   hidden_dims=[32,64,160,256]
  # )
  # backbone_encoder = backbone(input_eyes)

  flatten_compress = keras.layers.Flatten()(backbone_encoder)
  eye_embedding = keras.layers.Dense(units=EMBEDDING_DIM, activation="tanh")(flatten_compress)

  embedding_model = keras.Model(inputs=input_eyes, outputs=eye_embedding, name="Eye_Image_Embedding")

  return embedding_model

In [26]:
e = create_embedding_model()

e.summary()

In [27]:
class MaskInspectorLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MaskInspectorLayer, self).__init__(**kwargs)
        self.supports_masking = True
    def call(self, inputs, mask=None):
      # Print mask information
      tf.print("Layer mask:", mask)
      return inputs

In [28]:
# Modified model to work with the new masked input approach
def create_masked_model():
    """
    Create a model that uses masks to distinguish calibration and target images.

    This implementation:
    - Uses explicit calibration mask as a regular input
    - Integrates target masking with Keras' masking system
    - Properly propagates the target mask through the network
    """

    # Define inputs
    input_all_images = keras.layers.Input(
        shape=(MAX_TARGETS, 36, 144, 1),
        name="Input_All_Images"
    )

    input_all_coords = keras.layers.Input(
        shape=(MAX_TARGETS, 2),
        name="Input_All_Coords"
    )

    # Calibration mask is a regular input (not using Keras masking)
    input_cal_mask = keras.layers.Input(
        shape=(MAX_TARGETS,),
        name="Input_Calibration_Mask",
    )

    # Target mask gets proper Keras masking
    input_target_mask = keras.layers.Input(
        shape=(MAX_TARGETS,),
        name="Input_Target_Mask",
    )

    # Create the embedding model
    embedding_model = create_embedding_model()

    # Apply the embedding model to all images
    all_embeddings = SimpleTimeDistributed(embedding_model, name="Image_Embeddings")(input_all_images)

    # Calculate importance weights for calibration points
    calibration_weights = keras.layers.Dense(
        1,
        activation="sigmoid",
        name="Calibration_Weights"
    )(all_embeddings)

    ridge = MaskedWeightedRidgeRegressionLayer(
        RIDGE_REGULARIZATION,
        name="Regression"
    )(
        # Inputs to the regression layer
        [
            all_embeddings,           # Embeddings for all points
            input_all_coords,         # Coordinates for all points
            calibration_weights,      # Weights for calibration importance
            input_cal_mask,            # Explicit calibration mask
        ],
    )


    reshaped_target_mask = keras.layers.Reshape((-1,1), name="Reshaped_Target_Mask")(input_target_mask)
    mask = keras.layers.Masking(mask_value=0, name="Mask")(reshaped_target_mask)
    output = keras.layers.Multiply(name="Masked_Predictions")([ridge, mask])
    # mask_inspector = MaskInspectorLayer(name="Mask_Inspector")
    # output = mask_inspector(output)

    # Create the full model with proper masking connections
    full_model = keras.Model(
        inputs=[
            input_all_images,
            input_all_coords,
            input_cal_mask,
            input_target_mask
        ],
        outputs=ridge,
        name="MaskedEyePredictionModel"
    )

    return full_model

### Build stuff

In [33]:
mask_model = create_masked_model()

In [34]:
mask_model.summary()

In [35]:
mask_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=normalized_weighted_euc_dist,
    jit_compile=False
)

In [37]:
mask_model.fit(train_ds_for_model.batch(1), epochs=8)

Epoch 1/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 91ms/step - loss: 5.8228
Epoch 2/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 91ms/step - loss: 5.6911
Epoch 3/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 91ms/step - loss: 5.4798
Epoch 4/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 91ms/step - loss: 5.4987
Epoch 5/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 90ms/step - loss: 5.4561
Epoch 6/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 91ms/step - loss: 5.4443
Epoch 7/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 90ms/step - loss: 5.2526
Epoch 8/8
[1m1564/1564[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m202s[0m 91ms/step - loss: 5.3230


<keras.src.callbacks.history.History at 0x780ad6c324d0>

In [None]:
r = mask_model.predict(train_ds_for_model.take(1).batch(1))
#y = sub_model.predict(train_ds_for_model.take(1).batch(1))

In [None]:
r

# Test the new layers


In [None]:
import jax.numpy as jnp

def test_masked_normalized_weighted_euc_dist():
    """Test suite for masked_normalized_weighted_euc_dist function"""

    print("Running tests for masked_normalized_weighted_euc_dist:")

    # Test 1: Basic functionality - identical inputs should have zero distance
    y_true = jnp.array([[[1.0, 2.0], [3.0, 4.0]]])
    y_pred = jnp.array([[[1.0, 2.0], [3.0, 4.0]]])
    loss = masked_normalized_weighted_euc_dist(y_true, y_pred)
    print(f"Test 1 - Identical inputs: {float(loss):.6f}")
    assert np.isclose(float(loss), 0.0), "Identical inputs should have zero distance"

    # Test 2: Masking - verify padded points are ignored
    y_true = jnp.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
    y_pred = jnp.array([[[1.1, 2.1], [0.0, 0.0], [5.1, 6.1]]])  # Middle point is padded
    loss = masked_normalized_weighted_euc_dist(y_true, y_pred)
    print(f"Test 2 - Masked sequence: {float(loss):.6f}")

    # Count non-padded elements (should be 2)
    mask = jnp.any(jnp.not_equal(y_pred, 0), axis=-1).astype(jnp.float32)
    num_valid = float(jnp.sum(mask))
    print(f"  Number of valid (non-padded) points: {num_valid}")
    assert num_valid == 2, "Expected 2 non-padded points"

    # Test 3: Batch processing with mixed padding
    y_true = jnp.array([
        [[1.0, 2.0], [3.0, 4.0]],  # Sample 1
        [[5.0, 6.0], [7.0, 8.0]]   # Sample 2
    ])
    y_pred = jnp.array([
        [[1.0, 2.0], [0.0, 0.0]],  # Sample 1 with second point padded
        [[5.1, 6.1], [7.1, 8.1]]   # Sample 2 with no padding
    ])

    loss = masked_normalized_weighted_euc_dist(y_true, y_pred)
    print(f"Test 3 - Batch handling: {float(loss):.6f}")

    # Count non-padded elements (should be 3: 1 from first batch, 2 from second)
    mask = jnp.any(jnp.not_equal(y_pred, 0), axis=-1).astype(jnp.float32)
    num_valid = float(jnp.sum(mask))
    print(f"  Number of valid (non-padded) points: {num_valid}")
    assert num_valid == 3, "Expected 3 non-padded points in batch"

    # Test 4: Edge case - all points padded (should not divide by zero)
    y_true = jnp.array([[[1.0, 2.0], [3.0, 4.0]]])
    y_pred = jnp.array([[[0.0, 0.0], [0.0, 0.0]]])  # All padded
    loss = masked_normalized_weighted_euc_dist(y_true, y_pred)
    print(f"Test 4 - All padded: {float(loss):.6f}")
    assert np.isclose(float(loss), 0.0), "All padded points should result in zero loss"

    # Test 5: Verify mask application by comparing with manual calculation
    y_true = jnp.array([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    ])
    y_pred = jnp.array([
        [[1.1, 2.1], [0.0, 0.0], [5.1, 6.1]]  # Middle point padded
    ])

    # Manual calculation to verify masking
    manual_mask = jnp.array([[1.0, 0.0, 1.0]])  # Only first and third points are valid

    # Apply function with automatic masking
    auto_masked_loss = masked_normalized_weighted_euc_dist(y_true, y_pred)

    # Manual calculation of distance for comparison
    y_true_float = y_true.astype(jnp.float32)
    y_pred_float = y_pred.astype(jnp.float32)

    x_weight = jnp.array([1.778, 1.0], dtype=jnp.float32)
    y_true_weighted = y_true_float * x_weight
    y_pred_weighted = y_pred_float * x_weight

    squared_diff = jnp.square(y_pred_weighted - y_true_weighted)
    squared_dist = jnp.sum(squared_diff, axis=-1)
    dist = jnp.sqrt(squared_dist)

    normalized_dist = (dist / jnp.array(203.992, dtype=jnp.float32)) * 100

    # Apply manual mask and calculate mean
    masked_dist = normalized_dist * manual_mask
    manual_loss = jnp.sum(masked_dist) / jnp.sum(manual_mask)

    print(f"Test 5 - Masking verification:")
    print(f"  Auto-masked loss: {float(auto_masked_loss):.6f}")
    print(f"  Manual masked loss: {float(manual_loss):.6f}")

    assert np.isclose(float(auto_masked_loss), float(manual_loss), rtol=1e-5), \
        "Auto-masking should match manual masking calculation"

    print("All tests passed successfully!")

In [None]:
test_masked_normalized_weighted_euc_dist()

In [None]:
import jax.numpy as jnp

def test_simple_time_distributed():
    batch_size = 5
    time_steps = 4
    feature_dim = 5
    output_dim = 3

    # Use a Lambda layer that adds 1 to each element
    wrapped_layer = keras.layers.Lambda(lambda x: x + 1)
    time_distributed_layer = SimpleTimeDistributed(wrapped_layer)

    # Create structured input
    inputs = tf.constant(np.arange(batch_size * time_steps * feature_dim).reshape(batch_size, time_steps, feature_dim), dtype=tf.float32)

    # Apply the layer
    outputs = time_distributed_layer(inputs)

    # Expected output
    expected_outputs = inputs + 1

    # Check if shape is preserved
    shape_correct = outputs.shape == inputs.shape
    print("Shape preserved:", shape_correct)

    # Check if elements are in correct order
    values_correct = np.array_equal(jnp.array(outputs), jnp.array(expected_outputs))
    print("Order preserved:", values_correct)

    # Print results
    print("\nInput:\n", inputs.numpy())
    print("\nOutput:\n", jnp.array(outputs))
    print("\nExpected Output:\n", expected_outputs.numpy())

# Run the test
test_simple_time_distributed()

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import unittest

class TestSimpleTimeDistributed(unittest.TestCase):

    def test_basic_functionality(self):
        """Test the layer with a simple Dense layer."""
        # Create a model with SimpleTimeDistributed
        inputs = keras.Input(shape=(3, 4))  # (batch, time, features)
        dense_layer = keras.layers.Dense(5, activation='relu')
        x = SimpleTimeDistributed(dense_layer)(inputs)
        model = keras.Model(inputs, x)

        # Generate sample input
        batch_size = 2
        sample_input = np.random.random((batch_size, 3, 4))

        # Get the output
        output = model.predict(sample_input)

        # Verify output shape
        self.assertEqual(output.shape, (batch_size, 3, 5))

        # Manually apply the dense layer to each time step and verify results
        manual_output = np.zeros((batch_size, 3, 5))
        for b in range(batch_size):
            for t in range(3):
                manual_output[b, t] = dense_layer(np.expand_dims(sample_input[b, t], 0)).numpy()

        # Check outputs are the same
        np.testing.assert_allclose(output, manual_output, rtol=1e-5)

    def test_shape_handling(self):
        """Test that the layer correctly handles different input shapes."""
        # Test with different input shapes
        shapes = [
            (3, 5),      # (time, features)
            (3, 5, 6),   # (time, height, width)
            (3, 5, 6, 2) # (time, height, width, channels)
        ]

        for shape in shapes:
            # Create a model with SimpleTimeDistributed
            inputs = keras.Input(shape=shape)

            # Choose appropriate layer based on input dimensions
            if len(shape) == 2:
                layer = keras.layers.Dense(4)
                expected_output_shape = (None, shape[0], 4)
            elif len(shape) == 3:
                layer = keras.layers.Conv1D(4, 3, padding='same')
                expected_output_shape = (None, shape[0], shape[1], 4)
            else:
                layer = keras.layers.Conv2D(4, 3, padding='same')
                expected_output_shape = (None, shape[0], shape[1], shape[2], 4)

            x = SimpleTimeDistributed(layer)(inputs)
            model = keras.Model(inputs, x)

            # Verify output shape
            self.assertEqual(model.output_shape, expected_output_shape)

    def test_complex_layer(self):
        """Test with a more complex layer like Conv2D."""
        # Create a model with SimpleTimeDistributed wrapping Conv2D
        inputs = keras.Input(shape=(10, 28, 28, 3))  # (batch, time, height, width, channels)
        conv_layer = keras.layers.Conv2D(16, 3, padding='same', activation='relu')
        x = SimpleTimeDistributed(conv_layer)(inputs)
        model = keras.Model(inputs, x)

        # Generate sample input
        batch_size = 2
        sample_input = np.random.random((batch_size, 10, 28, 28, 3))

        # Get the output
        output = model.predict(sample_input)

        # Verify output shape
        self.assertEqual(output.shape, (batch_size, 10, 28, 28, 16))

        # Manually apply the conv layer to each time step and verify results
        manual_output = np.zeros((batch_size, 10, 28, 28, 16))
        for b in range(batch_size):
            for t in range(10):
                manual_output[b, t] = conv_layer(np.expand_dims(sample_input[b, t], 0)).numpy()

        # Check outputs are the same
        np.testing.assert_allclose(output, manual_output, rtol=1e-5)

    def test_training_mode(self):
        """Test that training mode is correctly passed to the wrapped layer."""
        # Create a dropout layer which behaves differently in training vs. inference
        dropout_layer = keras.layers.Dropout(0.5)

        # Create a model with SimpleTimeDistributed
        inputs = keras.Input(shape=(5, 10))
        x = SimpleTimeDistributed(dropout_layer)(inputs)
        model = keras.Model(inputs, x)

        # Generate sample input
        sample_input = np.ones((1, 5, 10))

        # Output during training (with dropout applied) should not be equal to the input
        output_training = model(sample_input, training=True).numpy()
        # Since dropout is random, we can't guarantee outputs will differ, but it's highly likely
        # Instead of directly comparing, we'll check if any element was dropped (set to 0)
        self.assertTrue(np.any(output_training == 0))

        # Output during inference (no dropout) should be equal to the input
        output_inference = model(sample_input, training=False).numpy()
        np.testing.assert_array_equal(output_inference, sample_input)

    def test_error_handling(self):
        """Test that appropriate errors are raised for invalid inputs."""
        # Test with invalid input shape (less than 3 dimensions)
        with self.assertRaises(ValueError):
            layer = SimpleTimeDistributed(keras.layers.Dense(10))
            layer.build((None, 10))  # Missing time dimension

    def test_variable_time_steps(self):
        """Test that the layer correctly handles variable time steps."""
        # Create a model with variable time steps
        inputs = keras.Input(shape=(None, 10))  # Variable time steps
        x = SimpleTimeDistributed(keras.layers.Dense(5))(inputs)
        model = keras.Model(inputs, x)

        # Test with different sequence lengths
        for seq_len in [3, 5, 7]:
            sample_input = np.random.random((2, seq_len, 10))
            output = model.predict(sample_input)
            self.assertEqual(output.shape, (2, seq_len, 5))

    def test_multiple_layer_stacking(self):
        """Test that the layer works when stacked with other layers."""
        # Create a model with multiple SimpleTimeDistributed layers
        inputs = keras.Input(shape=(10, 20))
        x = SimpleTimeDistributed(keras.layers.Dense(15))(inputs)
        x = SimpleTimeDistributed(keras.layers.Dense(10))(x)
        x = SimpleTimeDistributed(keras.layers.Dense(5))(x)
        model = keras.Model(inputs, x)

        # Generate sample input
        sample_input = np.random.random((2, 10, 20))

        # Get the output
        output = model.predict(sample_input)

        # Verify output shape
        self.assertEqual(output.shape, (2, 10, 5))

    def test_masking_support_inheritance(self):
        """Test that SimpleTimeDistributed correctly inherits masking support."""
        # Create layers with and without masking support
        lstm_layer = keras.layers.LSTM(10, return_sequences=True)  # Supports masking
        dense_layer = keras.layers.Dense(10)  # Does not support masking by default

        # Check that SimpleTimeDistributed inherits masking support correctly
        time_distributed_lstm = SimpleTimeDistributed(lstm_layer)
        time_distributed_dense = SimpleTimeDistributed(dense_layer)

        self.assertTrue(time_distributed_lstm.supports_masking)
        self.assertFalse(time_distributed_dense.supports_masking)

    def test_comparison_with_keras_time_distributed(self):
        """Compare outputs with Keras's built-in TimeDistributed."""
        # Create a dense layer with fixed weights for deterministic comparison
        dense_layer = keras.layers.Dense(10,
                                        kernel_initializer='ones',
                                        bias_initializer='zeros',
                                        use_bias=True)

        # Create a model with SimpleTimeDistributed
        inputs1 = keras.Input(shape=(5, 8))
        x1 = SimpleTimeDistributed(dense_layer)(inputs1)
        model1 = keras.Model(inputs1, x1)

        # Create a model with Keras TimeDistributed
        # Here we'll need to clone the dense layer to avoid weight sharing
        dense_layer_clone = keras.layers.Dense(10,
                                             kernel_initializer='ones',
                                             bias_initializer='zeros',
                                             use_bias=True)
        inputs2 = keras.Input(shape=(5, 8))
        x2 = keras.layers.TimeDistributed(dense_layer_clone)(inputs2)
        model2 = keras.Model(inputs2, x2)

        # Ensure weights are the same
        dense_layer_clone.set_weights(dense_layer.get_weights())

        # Generate sample input
        sample_input = np.random.random((2, 5, 8))

        # Get outputs
        output1 = model1.predict(sample_input)
        output2 = model2.predict(sample_input)

        # Check outputs are the same
        np.testing.assert_allclose(output1, output2, rtol=1e-5)

    def test_weights_training(self):
        """Test that the weights of the wrapped layer are updated correctly during training."""
        # Create a dense layer with specific initializers for easy tracking
        dense_layer = keras.layers.Dense(1,
                                        kernel_initializer='ones',
                                        bias_initializer='zeros')

        # Create a model with SimpleTimeDistributed
        inputs = keras.Input(shape=(3, 2))
        x = SimpleTimeDistributed(dense_layer)(inputs)
        model = keras.Model(inputs, x)

        # Compile the model
        model.compile(optimizer='sgd', loss='mse')

        # Get initial weights
        initial_weights = [w.copy() for w in dense_layer.get_weights()]

        # Generate some training data
        x_train = np.ones((10, 3, 2))
        y_train = np.zeros((10, 3, 1))  # Target: opposite of what initial weights would produce

        # Train the model for a few epochs
        model.fit(x_train, y_train, epochs=5, verbose=0)

        # Get updated weights
        updated_weights = dense_layer.get_weights()

        # Check that weights have been updated
        for i in range(len(initial_weights)):
            self.assertFalse(np.array_equal(initial_weights[i], updated_weights[i]))

## Create augmentation pipeline

In [None]:
augmentation_layers = [
    #keras.layers.RandomBrightness(factor=0.1),
    #keras.layers.RandomContrast(factor=0.1),
    keras.layers.RandomRotation(factor=.05, fill_mode='constant', fill_value=0),
    #keras.layers.RandomTranslation(height_factor=0.05, width_factor=0.05)
]

augmentation_model = keras.Sequential(augmentation_layers)

@tf.function
def apply_consistent_augmentations(inputs, augmentation_model):
    """Apply consistent augmentations to calibration and target images.

    Args:
        inputs: Tuple of (cal_imgs, cal_coords, cal_mask, target_imgs, target_mask)
        augmentation_fn: Function created by create_augmentation_fn

    Returns:
        Tuple with augmented images in the same structure
    """
    cal_imgs, cal_coords, cal_mask, target_imgs, target_mask = inputs

    merged_cal_images = tf.transpose(cal_imgs, perm=[3, 1, 2, 0])
    merged_target_images = tf.transpose(target_imgs, perm=[3, 1, 2, 0])

    n_cal_images = MAX_CAL_POINTS
    n_target_images = MAX_TARGETS

    merged_all_images = tf.concat([merged_cal_images, merged_target_images], axis=-1)

    merged_all_images_aug = augmentation_model(merged_all_images)

    merged_cal_imgs_aug = merged_all_images_aug[..., :n_cal_images]
    merged_target_imgs_aug = merged_all_images_aug[..., n_cal_images:]

    cal_imgs_aug = tf.transpose(merged_cal_imgs_aug, perm=[3, 1, 2, 0])
    target_imgs_aug = tf.transpose(merged_target_imgs_aug, perm=[3, 1, 2, 0])

    return cal_imgs_aug, cal_coords, cal_mask, target_imgs_aug, target_mask

In [None]:
t0 = train_data_rescaled.cache()

t1 = t0.group_by_window(
    key_func = lambda img, m, c, z: z,
    reduce_func = reducer_function,
    window_size = 200
)

t2 = t0.group_by_window(
    key_func = lambda img, m, c, z: z,
    reduce_func = reducer_function_fixed_pts_with_id,
    window_size = 200
)

# Apply to dataset - ensure consistent shapes
train_ds = t1.map(
    lambda x, y: (apply_consistent_augmentations(x, augmentation_model), y),
    num_parallel_calls=tf.data.AUTOTUNE
)

# Update the batching step to use a custom padded_batch function
def prepare_batch_for_model(features, labels):
    """Ensure inputs are correctly formatted for the model"""
    cal_imgs, cal_coords, cal_mask, target_imgs, target_mask = features

    # Prepare inputs as a dictionary matching the model's expected inputs
    inputs = {
        "Input_Calibration_Eyes": cal_imgs,
        "Input_Calibration_Points": cal_coords,
        "Input_Calibration_Mask": cal_mask,
        "Input_Target_Eyes": target_imgs,
        "Input_Target_Mask": target_mask
    }

    return inputs, labels

# Use the updated version with proper input preparation
train_ds_for_model = train_ds.map(prepare_batch_for_model).prefetch(tf.data.AUTOTUNE)


In [None]:
train_ds.element_spec

## Visualize augmentation

In [None]:
def visualize_augmentations(dataset, num_examples=5):
    """Visualize augmentations from a tf.data.Dataset."""

    # Get a sample element from the dataset
    for element in dataset.take(1):
        # Access the calibration images from the element
        images = element[0][0]  # Access all calibration images

    # Calculate grid dimensions
    num_images = images.shape[0]  # Get the number of images in the batch
    grid_cols = 5  # Number of columns in the grid
    grid_rows = (num_images + grid_cols - 1) // grid_cols  # Calculate number of rows

    # Create a figure to display results
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 3 * grid_rows))  # Create subplots in a grid

    # Display all images
    for i in range(num_images):
        row = i // grid_cols
        col = i % grid_cols
        axes[row, col].imshow(images[i].numpy().squeeze(), cmap='gray')  # Display image
        axes[row, col].set_title(f"Image {i + 1}")  # Set title for each subplot
        axes[row, col].axis('off')  # Turn off axis

    # Hide empty subplots if any
    for i in range(num_images, grid_rows * grid_cols):
        row = i // grid_cols
        col = i % grid_cols
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.show()

# Use this to check your augmentations
visualize_augmentations(train_ds)

# Save and export

In [None]:
full_model.save('full_model.keras')

In [None]:
full_model.save_weights('full_model.weights.h5')

In [None]:
wandb.save('full_model.keras')
wandb.save('full_model.weights.h5')

In [None]:
# Update the prediction code
t2_processed = t2.map(
    lambda x, y, id: (prepare_batch_for_model((x[0], x[1], x[2], x[3], x[4]), y)[0], y, id)
).prefetch(tf.data.AUTOTUNE)

# Initialize a list to store the batch losses
batch_losses = []
subject_ids = []
y_true = []

for e in t2_processed.batch(1).as_numpy_iterator():
    inputs = e[0]
    y_true.append(e[1])
    subject_ids.append(e[2][0])

y_true = np.array(y_true).reshape(-1, 144, 2)

# this step is slower than expected. not sure what's going on.
predictions = full_model.predict(t2_processed.batch(1))

for i in range(len(predictions)):
  loss = normalized_weighted_euc_dist(y_true[i], predictions[i]).numpy()
  batch_losses.append(loss)

batch_losses = np.array(batch_losses)

# Get mean per subject
batch_losses = np.mean(batch_losses, axis=1)

In [None]:
final_loss_table = wandb.Table(data=[[s, l] for (s, l) in zip(subject_ids, batch_losses)], columns=["subject", "scores"])

In [None]:
final_loss_hist = wandb.plot.histogram(final_loss_table, "scores", title="Normalized Euclidean Distance")

In [None]:
final_loss_mean = np.mean(batch_losses)

In [None]:
wandb.log({"final_val_loss_table": final_loss_table, "final_val_loss_hist": final_loss_hist, "final_loss_mean": final_loss_mean})

In [None]:
plt.figure(figsize=(8, 6))
plt.hist(batch_losses, bins=20, edgecolor='black')
plt.title('Histogram of Batch Losses')
plt.xlabel('Loss')
plt.ylabel('Frequency')
plt.show()

In [None]:
wandb.finish()