# Setup

In [None]:
# Colab setup: install third-party packages that are not preinstalled.
# NOTE(review): versions are unpinned, so results may drift between runs; consider pinning.
!pip install osfclient --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet
!pip install keras_cv --quiet
!pip install plotnine --quiet
!pip install wandb --quiet
!pip install albumentations --quiet

In [None]:
import os
import tensorflow as tf
import numpy as np
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow.keras as keras
import keras_cv
from plotnine import ggplot, geom_point, aes, geom_line, scale_y_reverse, theme_void, scale_color_manual
import pandas as pd
import wandb
from wandb.integration.keras import WandbMetricsLogger
import albumentations as A

import matplotlib.pyplot as plt
from google.colab import userdata



import et_util.dataset_utils as dataset_utils
import et_util.embedding_preprocessing as embed_pre
import et_util.model_layers as model_layers
from et_util import experiment_utils
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.model_analysis import plot_model_performance

In [None]:
# Export Colab secrets as environment variables (env var name -> Colab secret name)
# so W&B and the OSF command-line client can authenticate without hard-coded keys.
for _env_name, _secret_name in {
    'WANDB_API_KEY': 'WANDB_API_KEY',
    'OSF_TOKEN': 'osftoken',
    'OSF_USERNAME': 'osfusername',
}.items():
    os.environ[_env_name] = userdata.get(_secret_name)

In [None]:
# Compute in float16 with float32 variables for faster training on supported GPUs.
# The ridge-regression layers below cast their inputs back to float32 before solving.
keras.mixed_precision.set_global_policy('mixed_float16')

# Configure W&B experiment

In [None]:
# Authenticate with W&B (WANDB_API_KEY was exported from Colab secrets above).
wandb.login()

In [None]:
# Fixed constants
# Target positions per subject (16 columns x 9 rows); also the padded per-subject length.
MAX_TARGETS = 144

# Config constants (logged to W&B via `config` below)
EMBEDDING_DIM = 200  # size of each image embedding fed to the ridge regression
RIDGE_REGULARIZATION = 0.1  # ridge penalty lambda
TRAIN_EPOCHS = 30
MIN_CAL_POINTS = 8  # lower bound on calibration images sampled per subject (masked pipeline)
MAX_CAL_POINTS = 40  # upper bound for the masked pipeline; padded calibration length in create_full_model
DENSE_NET_STACKWISE_NUM_REPEATS = [4,4,4]  # passed to DenseNetBackbone(stackwise_num_repeats=...)
LEARNING_RATE = 0.001
AUGMENTATION = True  # NOTE(review): only logged to W&B; no cell in this notebook reads this flag

In [None]:
# Hyperparameters recorded with the W&B run (keys mirror the constants above).
config = dict(
    embedding_dim=EMBEDDING_DIM,
    ridge_regularization=RIDGE_REGULARIZATION,
    train_epochs=TRAIN_EPOCHS,
    min_cal_points=MIN_CAL_POINTS,
    max_cal_points=MAX_CAL_POINTS,
    dense_net_stackwise_num_repeats=DENSE_NET_STACKWISE_NUM_REPEATS,
    learning_rate=LEARNING_RATE,
    augmentation=AUGMENTATION,
)

In [None]:
# Start the W&B run that WandbMetricsLogger and the final wandb.log/save calls report to.
run = wandb.init(
    project='eye-tracking-dense-full-data-set-single-eye',
    config=config
)

# Download dataset from OSF

In [None]:
# Download the single-eye TFRecord archive from OSF project 6b5cd
# (presumably authenticated via the OSF_* environment variables exported above).
!osf -p 6b5cd fetch single_eye_tfrecords.tar.gz

# Process raw data records into TF Dataset

In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:
def parse(element):
    """Parse one serialized example from the single-eye TFRecords.

    Used as the parse function passed to ``dataset_utils.process_tfr_to_tfds``.
    Reads the serialized eye image, MediaPipe landmarks, gaze target (x, y) and
    subject id. ``img_width``/``img_height`` stay in the feature spec, so records
    must contain them, but their values are not used.

    :param element: serialized ``tf.train.Example`` from the raw TFRecord dataset
    :return: tuple ``(image, landmarks, label, subject_id)``:
        ``image`` -- uint8 tensor decoded with ``tf.io.parse_tensor`` (shape as
        serialized; the masked pipeline later reshapes it to (36, 144, 1)),
        ``landmarks`` -- float32 tensor of shape (478, 3),
        ``label`` -- ``[x, y]`` gaze target (float32, 0-100 scale before rescaling),
        ``subject_id`` -- int64 subject identifier.
    """

    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    # Landmarks and image are stored as serialized tensors.
    landmarks = tf.io.parse_tensor(content['landmarks'], out_type=tf.float32)
    landmarks = tf.reshape(landmarks, shape=(478, 3))

    image = tf.io.parse_tensor(content['eye_img'], out_type=tf.uint8)

    label = [content['x'], content['y']]
    subject_id = content['subject_id']

    return image, landmarks, label, subject_id

In [None]:
# Build per-image tf.data datasets from every TFRecord in single_eye_tfrecords/.
# All data goes to the training split (train_split=1.0), so validation_data and
# test_data are presumably empty. group_function returns the subject id (4th element),
# presumably so that any split is made per subject -- TODO confirm in et_util.dataset_utils.
train_data, validation_data, test_data = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, landmarks, coords, z: z
)

## Rescale the `x,y` coordinates to be 0-1 instead of 0-100.

In [None]:
def rescale_coords_map(eyes, mesh, coords, id):
  """Dataset map fn: rescale gaze coords from 0-100 to 0-1; other elements pass through."""
  scaled_coords = tf.divide(coords, tf.constant([100.]))
  return eyes, mesh, scaled_coords, id

In [None]:
# Rescale labels to 0-1 for both splits.
# NOTE(review): validation_data comes from val_split=0.0 above, and
# validation_data_rescaled is not used anywhere later in this notebook.
train_data_rescaled = train_data.map(rescale_coords_map)
validation_data_rescaled = validation_data.map(rescale_coords_map)

# Create the weighted regression layer

In [None]:
class MaskedWeightedRidgeRegressionLayer(keras.layers.Layer):
    """Per-subject weighted ridge regression from image embeddings to gaze coordinates.

    For each batch element, fits a ridge-regression kernel (no bias term) on the
    images selected by ``cal_mask``, weighting every calibration row by its
    ``calibration_weights`` value, then applies that kernel to the images selected
    by ``target_mask``. The layer creates no variables of its own; gradients flow
    through the closed-form solve into the embeddings and the weights.

    Inputs (list of 5 tensors, in this order):
        all_embeddings:      (batch, max_images, embedding_dim)
        all_coords:          (batch, max_images, 2)
        calibration_weights: (batch, max_images, 1) -- the last axis is squeezed
        cal_mask:            (batch, max_images), 1.0 marks calibration images
        target_mask:         (batch, max_images), 1.0 marks target images

    Output: (batch, max_images, 2) predictions; zero wherever target_mask is 0.
    """
    def __init__(self, lambda_ridge, embedding_dim, epsilon=1e-6, **kwargs):
        self.lambda_ridge = lambda_ridge
        self.epsilon = epsilon
        self.embedding_dim = embedding_dim
        super(MaskedWeightedRidgeRegressionLayer, self).__init__(**kwargs)
        # Identity matrix for the ridge penalty; created once and broadcast per call.
        self._identity = tf.eye(embedding_dim, dtype=tf.float32)

    @tf.function(jit_compile=True)
    def call(self, inputs):
        """Solve (X^T W X + (lambda_ridge + epsilon) I) K = X^T W y per batch element
        over calibration rows, then return target_embeddings @ K, masked to targets."""
        # Unpack inputs
        all_embeddings, all_coords, calibration_weights, cal_mask, target_mask = inputs

        # Cast all inputs to float32 (the mixed_float16 policy may deliver float16)
        embeddings = tf.cast(all_embeddings, tf.float32)  # (batch_size, max_images, embedding_dim)
        coords = tf.cast(all_coords, tf.float32)          # (batch_size, max_images, 2)
        w = tf.cast(calibration_weights, tf.float32)      # (batch_size, max_images, 1)
        w = tf.squeeze(w, axis=-1)                        # -> (batch_size, max_images)
        cal_mask = tf.cast(cal_mask, tf.float32)          # (batch_size, max_images)
        target_mask = tf.cast(target_mask, tf.float32)    # (batch_size, max_images)

        # Get batch size and shape constants (dynamic, so any batch size works)
        batch_size = tf.shape(embeddings)[0]
        max_images = tf.shape(embeddings)[1]

        # Create identity matrix broadcast to batch size
        I = tf.broadcast_to(self._identity, [batch_size, self.embedding_dim, self.embedding_dim])

        # Apply calibration mask to weights (non-calibration rows get weight 0)
        w = w * cal_mask

        # Square-root weights so that (sqrt(w) X)^T (sqrt(w) X) = X^T W X
        w_sqrt = tf.sqrt(w)
        w_sqrt = tf.reshape(w_sqrt, [batch_size, max_images, 1])  # For broadcasting

        # Apply calibration mask to get calibration embeddings
        cal_mask_expand = tf.reshape(cal_mask, [batch_size, max_images, 1])
        X = embeddings * cal_mask_expand  # Zero out non-calibration embeddings

        # Weight calibration embeddings and coordinates
        X_weighted = X * w_sqrt
        y_weighted = coords * w_sqrt

        # Normal-equation left-hand side: X^T W X
        X_t = tf.transpose(X_weighted, perm=[0, 2, 1])  # (batch_size, embedding_dim, max_images)
        X_t_X = tf.matmul(X_t, X_weighted)              # (batch_size, embedding_dim, embedding_dim)

        # Add regularization; with non-negative weights and lambda + epsilon > 0 the
        # system is positive definite, so it is solvable even with few calibration points.
        ridge_term = tf.multiply(I, self.lambda_ridge + self.epsilon)
        lhs = tf.add(X_t_X, ridge_term)                 # (batch_size, embedding_dim, embedding_dim)

        # Compute right-hand side: X^T W y
        rhs = tf.matmul(X_t, y_weighted)                # (batch_size, embedding_dim, 2)

        # Solve linear system with stability check.
        # NOTE(review): a single NaN anywhere in the batch's lhs zeroes the kernel for
        # every batch element; NaNs produced by solve itself are not caught.
        is_singular = tf.math.reduce_any(tf.math.is_nan(lhs))
        kernel = tf.cond(
            is_singular,
            lambda: tf.zeros([batch_size, self.embedding_dim, 2], dtype=tf.float32),
            lambda: tf.linalg.solve(lhs, rhs)
        )  # Shape: [batch_size, embedding_dim, 2]

        # Apply target mask to get target embeddings
        target_mask_expand = tf.reshape(target_mask, [batch_size, max_images, 1])
        target_embeddings = embeddings * target_mask_expand  # Zero out non-target embeddings

        # Apply regression to target embeddings
        output = tf.matmul(target_embeddings, kernel)   # (batch_size, max_images, 2)

        # Mask the output (zeros for non-target points); the masked loss relies on these zeros
        output = output * target_mask_expand

        return output

    def compute_output_shape(self, input_shapes):
        # Output shape is (batch_size, max_images, 2), taken from the embeddings input
        return (input_shapes[0][0], input_shapes[0][1], 2)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created from config
        config = super(MaskedWeightedRidgeRegressionLayer, self).get_config()
        config.update({
            "lambda_ridge": self.lambda_ridge,
            "epsilon": self.epsilon,
            "embedding_dim": self.embedding_dim
        })
        return config

# Generate dataset that has calibration points, target point, and target output

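Target grid columns (x, 0–100 scale; 16 values): 5, 11, 17, 23, 29, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95

Target grid rows (y, 0–100 scale; 9 values): 5, 16.25, 27.5, 38.75, 50, 61.25, 72.5, 83.75, 95

16 × 9 = 144 target positions, matching `MAX_TARGETS`.

Fixed calibration points: a 4 × 5 subset of this grid (x ∈ {5, 35, 65, 95}, y ∈ {5, 27.5, 50, 72.5, 95}), defined below.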

In [None]:
# Fixed calibration points: every (x, y) pair from 4 columns x 5 rows of the
# 0-100 target grid, enumerated column by column (x outer, y inner).
cal_points = tf.constant(
    [[x, y] for x in (5, 35, 65, 95) for y in (5, 27.5, 50, 72.5, 95)],
    dtype=tf.float32,
)

# The same points on the 0-1 scale used by the rescaled labels.
scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))


In [None]:
def create_embedding_model(input_shape=(36, 144, 1)):
    """Build the per-image embedding network.

    A keras_cv DenseNet backbone followed by Flatten and a tanh Dense projection
    to ``EMBEDDING_DIM`` units.

    :param input_shape: shape of a single eye image (height, width, channels).
        Defaults to the (36, 144, 1) crops used throughout this notebook; it is
        used for both the Input layer and the backbone so the two cannot diverge.
    :return: ``keras.Model`` named "Eye_Image_Embedding" mapping (batch, *input_shape)
        images to (batch, EMBEDDING_DIM) embeddings.
    """
    input_eyes = keras.layers.Input(shape=input_shape)

    # include_rescaling=False: the backbone does not rescale pixel values itself.
    # NOTE(review): images are decoded as uint8 in `parse` and no rescaling step is
    # visible in this notebook, so the network sees 0-255 values -- confirm intended.
    backbone = keras_cv.models.DenseNetBackbone(
        include_rescaling=False,
        input_shape=input_shape,
        stackwise_num_repeats=DENSE_NET_STACKWISE_NUM_REPEATS
    )
    backbone_encoder = backbone(input_eyes)

    # Flatten the feature map and project it to a bounded (tanh) embedding.
    flatten_compress = keras.layers.Flatten()(backbone_encoder)
    eye_embedding = keras.layers.Dense(units=EMBEDDING_DIM, activation="tanh")(flatten_compress)

    embedding_model = keras.Model(inputs=input_eyes, outputs=eye_embedding, name="Eye_Image_Embedding")

    return embedding_model

In [None]:
def masked_normalized_weighted_euc_dist(y_true, y_pred):
    """Masked, aspect-ratio-weighted Euclidean distance between true and predicted points.

    The x-coordinate is multiplied by 1.778 (16:9, the aspect ratio of most laptop
    screens) so x and y distances share one scale. The distance is then divided by
    203.992 (the diagonal of a 177.8 x 100 rectangle) and multiplied by 100. With
    coordinates on a 0-100 scale, a corner-to-corner error therefore scores 100. With
    the 0-1 coordinates produced by ``rescale_coords_map``, it scores 1.0.

    Points whose prediction is exactly (0, 0) are treated as masked out, because the
    ridge layers zero out calibration and padded positions. They are excluded from
    the mean and contribute zero gradient.

    :param y_true (tensor): (batch_size, seq_len, 2) ground-truth x- and y- coordinates
    :param y_pred (tensor): (batch_size, seq_len, 2) predicted x- and y- coordinates

    :returns: scalar tensor, the mean weighted, normalized distance over unmasked
        points (0 when every point is masked).
    """
    # Mask from zero-valued predictions (assumes a real prediction is never exactly (0, 0)).
    mask = tf.reduce_any(tf.not_equal(y_pred, 0), axis=-1)
    mask = tf.cast(mask, tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Weighting treats y-axis as unit-scale and creates a rectangle that's 177.8x100 units.
    x_weight = tf.constant([1.778, 1.0], dtype=tf.float32)

    # Multiply x-coordinate by 16/9 = 1.778
    y_true_weighted = tf.math.multiply(x_weight, y_true)
    y_pred_weighted = tf.math.multiply(x_weight, y_pred)

    squared_diff = tf.math.square(y_pred_weighted - y_true_weighted)
    squared_dist = tf.math.reduce_sum(squared_diff, axis=-1)

    # Padded points have y_true == y_pred == 0, so squared_dist is 0 there and the
    # gradient of sqrt is infinite. Zeroing it with the mask afterwards would give
    # 0 * inf = NaN gradients that poison every weight. Feed a harmless 1.0 into sqrt
    # for masked-out points instead; their distance is discarded by the mask anyway.
    safe_squared_dist = tf.where(mask > 0, squared_dist, tf.ones_like(squared_dist))
    dist = tf.math.sqrt(safe_squared_dist)

    # Euclidean Distance from [0,0] to [177.8, 100] = 203.992
    norm_scale = tf.constant(203.992, dtype=tf.float32)

    # Normalize to the diagonal to make loss values easier to interpret
    normalized_dist = (dist / norm_scale) * 100

    # Only valid (unmasked) points contribute
    masked_dist = normalized_dist * mask

    # Mean over non-masked points; the floor of 1 avoids division by zero
    num_valid = tf.maximum(tf.reduce_sum(mask), 1.0)
    loss = tf.reduce_sum(masked_dist) / num_valid

    return loss

In [None]:
# Aliases used by the mask-based pipeline: each subject is padded to MAX_IMAGES images,
# and between MIN_CAL_IMAGES and MAX_CAL_IMAGES of them are sampled as calibration images.
MAX_IMAGES, MIN_CAL_IMAGES, MAX_CAL_IMAGES = MAX_TARGETS, MIN_CAL_POINTS, MAX_CAL_POINTS

def prepare_masked_dataset(dataset, batch_size=200):
    """Group images by subject and randomly split each subject into calibration/target sets.

    Each output element holds the images of one subject (at most MAX_TARGETS per
    element), zero-padded to MAX_IMAGES, plus two masks. A random subset of
    MIN_CAL_IMAGES..MAX_CAL_IMAGES images (inclusive) is flagged as calibration,
    and the remaining real images are flagged as targets. The random split is
    redrawn every time the dataset is iterated, i.e. every epoch.

    Args:
        dataset: tf.data.Dataset of ``(eye_image, mesh, coords, subject_id)`` elements,
            e.g. ``train_data_rescaled``.
        batch_size: ``window_size`` for ``group_by_window``, the maximum number of
            same-subject elements gathered per window. Each window is then batched
            into chunks of at most MAX_TARGETS images.

    Returns:
        tf.data.Dataset of ``((images, coords, cal_mask, target_mask), coords)`` with
        shapes (MAX_IMAGES, 36, 144, 1), (MAX_IMAGES, 2), (MAX_IMAGES,), (MAX_IMAGES,).
    """
    # Step 1: Group dataset by subject_id and batch each subject's images
    def group_by_subject(subject_id, ds):
        return ds.batch(batch_size=MAX_TARGETS)

    grouped_dataset = dataset.group_by_window(
        key_func=lambda img, mesh, coords, subject_id: subject_id,
        reduce_func=group_by_subject,
        window_size=batch_size
    )

    # Step 2: Add random calibration/target masks and pad to a fixed size
    def add_masks_to_batch(images, meshes, coords, subject_ids):
        # Number of real images for this subject
        actual_batch_size = tf.shape(images)[0]

        # Integer tf.random.uniform excludes maxval, so +1 makes MAX_CAL_IMAGES reachable.
        n_cal_images = tf.random.uniform(
            shape=[],
            minval=MIN_CAL_IMAGES,
            maxval=MAX_CAL_IMAGES + 1,
            dtype=tf.int32
        )
        # Small subjects: never request more calibration images than exist, and
        # keep at least one target image when possible.
        n_cal_images = tf.minimum(n_cal_images, tf.maximum(actual_batch_size - 1, 0))

        random_indices = tf.random.shuffle(tf.range(actual_batch_size))
        cal_indices = random_indices[:n_cal_images]
        target_indices = random_indices[n_cal_images:]

        # Masks (1 = included, 0 = excluded); update counts come from the index
        # tensors themselves so they always match.
        cal_mask = tf.tensor_scatter_nd_update(
            tf.zeros(actual_batch_size, dtype=tf.float32),
            tf.expand_dims(cal_indices, 1),
            tf.ones_like(cal_indices, dtype=tf.float32)
        )
        target_mask = tf.tensor_scatter_nd_update(
            tf.zeros(actual_batch_size, dtype=tf.float32),
            tf.expand_dims(target_indices, 1),
            tf.ones_like(target_indices, dtype=tf.float32)
        )

        # Reshape images to expected format (batch, height, width, channels)
        reshaped_images = tf.reshape(images, (-1, 36, 144, 1))

        # Zero-pad everything to MAX_IMAGES rows
        n_pad = MAX_IMAGES - actual_batch_size
        padded_images = tf.pad(reshaped_images, [[0, n_pad], [0, 0], [0, 0], [0, 0]])
        padded_coords = tf.pad(coords, [[0, n_pad], [0, 0]])
        padded_cal_mask = tf.pad(cal_mask, [[0, n_pad]])
        padded_target_mask = tf.pad(target_mask, [[0, n_pad]])

        # Pin static shapes for Keras
        padded_images = tf.ensure_shape(padded_images, [MAX_IMAGES, 36, 144, 1])
        padded_coords = tf.ensure_shape(padded_coords, [MAX_IMAGES, 2])
        padded_cal_mask = tf.ensure_shape(padded_cal_mask, [MAX_IMAGES])
        padded_target_mask = tf.ensure_shape(padded_target_mask, [MAX_IMAGES])

        return (padded_images, padded_coords, padded_cal_mask, padded_target_mask), padded_coords

    # Cache only the deterministic, expensive grouping step. Caching *after* the
    # random mask step would freeze the first epoch's calibration split forever.
    masked_dataset = grouped_dataset.cache().map(
        add_masks_to_batch,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return masked_dataset.prefetch(tf.data.AUTOTUNE)

# Map fn for the mask-based pipeline: tuple features -> named model inputs.
def prepare_model_inputs(features, labels):
    """Convert ``(images, coords, cal_mask, target_mask)`` into the dict of named
    inputs expected by ``create_masked_model``; ``labels`` pass through unchanged.
    """
    images, coords, cal_mask, target_mask = features
    named_inputs = dict(
        Input_All_Images=images,
        Input_All_Coords=coords,
        Input_Calibration_Mask=cal_mask,
        Input_Target_Mask=target_mask,
    )
    return named_inputs, labels

# Model variant that embeds all of a subject's images at once and uses masks
# to distinguish calibration images from target images.
def create_masked_model():
    """Build the mask-based gaze model.

    Every image is embedded by one shared embedding network. A sigmoid Dense(1)
    head scores each embedding as a calibration weight, and
    MaskedWeightedRidgeRegressionLayer fits a per-subject ridge regression on the
    calibration images to predict coordinates for the target images.

    Returns:
        keras.Model "MaskedEyePredictionModel" with inputs [Input_All_Images,
        Input_All_Coords, Input_Calibration_Mask, Input_Target_Mask] and a
        (batch, MAX_IMAGES, 2) output.
    """
    shared_embedder = create_embedding_model()

    images_in = keras.layers.Input(shape=(MAX_IMAGES, 36, 144, 1), name="Input_All_Images")
    coords_in = keras.layers.Input(shape=(MAX_IMAGES, 2), name="Input_All_Coords")
    cal_mask_in = keras.layers.Input(shape=(MAX_IMAGES,), name="Input_Calibration_Mask")
    target_mask_in = keras.layers.Input(shape=(MAX_IMAGES,), name="Input_Target_Mask")

    # Same embedding weights for every image position
    embeddings = keras.layers.TimeDistributed(shared_embedder)(images_in)

    weight_head = keras.layers.Dense(1, activation="sigmoid", name="Calibration_Weights")
    point_weights = weight_head(embeddings)

    ridge_layer = MaskedWeightedRidgeRegressionLayer(RIDGE_REGULARIZATION, embedding_dim=EMBEDDING_DIM)
    predictions = ridge_layer([embeddings, coords_in, point_weights, cal_mask_in, target_mask_in])

    return keras.Model(
        inputs=[images_in, coords_in, cal_mask_in, target_mask_in],
        outputs=predictions,
        name="MaskedEyePredictionModel"
    )

# Usage example
def prepare_and_train_with_masks(lr_schedule_fn=None):
    """End-to-end example of the mask-based approach: data -> model -> training.

    Rebuilds the rescaled, masked dataset from ``train_data``, creates and compiles
    a fresh masked model, and trains it for TRAIN_EPOCHS epochs with W&B logging.

    Args:
        lr_schedule_fn: optional ``epoch -> learning rate`` function for
            ``keras.callbacks.LearningRateScheduler``. When omitted, a step schedule
            is used: LEARNING_RATE for epochs 0-14, then LEARNING_RATE * 0.1.
            Defaulting here, instead of referencing the notebook-level
            ``lr_schedule`` (defined in a later cell), keeps this function callable
            on a fresh kernel.

    Returns:
        Tuple ``(mask_model, history)``.
    """
    def _step_schedule(epoch):
        # Drop the learning rate 10x after epoch 15.
        return LEARNING_RATE if epoch < 15 else LEARNING_RATE * 0.1

    schedule = lr_schedule_fn if lr_schedule_fn is not None else _step_schedule

    # 1. Rescale labels to 0-1 (local name avoids shadowing the notebook-level variable)
    rescaled_train = train_data.map(rescale_coords_map).cache()

    # 2. Group by subject and add random calibration/target masks
    masked_dataset = prepare_masked_dataset(rescaled_train)

    # 3. Convert to the named-input dict the model expects
    train_ds_for_model = masked_dataset.map(
        prepare_model_inputs,
        num_parallel_calls=tf.data.AUTOTUNE
    ).prefetch(tf.data.AUTOTUNE)

    # 4. Build and compile; jit_compile=False matches the other compile calls in this notebook
    mask_model = create_masked_model()
    mask_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=masked_normalized_weighted_euc_dist,
        metrics=[normalized_weighted_euc_dist],
        jit_compile=False
    )

    # 5. Train on batches of 4 subject elements
    history = mask_model.fit(
        train_ds_for_model.batch(4),
        epochs=TRAIN_EPOCHS,
        callbacks=[
            WandbMetricsLogger(),
            keras.callbacks.LearningRateScheduler(schedule)
        ]
    )

    return mask_model, history

In [None]:
# Build the per-subject masked dataset from the rescaled training data.
masked_dataset = prepare_masked_dataset(train_data_rescaled)

In [None]:
# Map the masked dataset to the model's named inputs.
# NOTE(review): `.take(10)` keeps only the first 10 elements (subject batches); this looks
# like a quick smoke-test subset -- drop it for full training.
train_ds_for_model = masked_dataset.map(
        prepare_model_inputs,
        num_parallel_calls=tf.data.AUTOTUNE
    ).take(10).prefetch(tf.data.AUTOTUNE)

In [None]:
# Inspect the structure and shapes the model will receive.
train_ds_for_model.element_spec

In [None]:
# Build the mask-based model.
mask_model = create_masked_model()

In [None]:
# Print the layer summary.
mask_model.summary()

In [None]:
# Compile with the masked loss. jit_compile=False turns off model-level XLA; the ridge
# layer's call is separately decorated with @tf.function(jit_compile=True).
# NOTE(review): the metric from et_util is presumably unmasked, so it also scores
# calibration/padded points whose prediction is forced to 0 -- TODO confirm.
mask_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=masked_normalized_weighted_euc_dist,
    metrics=[normalized_weighted_euc_dist],
    jit_compile=False
)

In [None]:
# Quick training run: 10 epochs, no W&B/LR callbacks, on train_ds_for_model from the cell above.
mask_model.fit(train_ds_for_model.batch(4), epochs=10)

## Create augmentation pipeline

In [None]:
# Augmentations applied identically to every image of a subject (see
# apply_consistent_augmentations). Only a small random rotation is active:
# factor .05 is a fraction of a full turn (about +/-18 degrees), with corners filled with 0.
# Currently disabled options: RandomBrightness(factor=0.1), RandomContrast(factor=0.1),
# RandomTranslation(height_factor=0.05, width_factor=0.05).
augmentation_layers = [
    keras.layers.RandomRotation(factor=.05, fill_mode='constant', fill_value=0),
]
augmentation_model = keras.Sequential(augmentation_layers)

@tf.function
def apply_consistent_augmentations(inputs, augmentation_model):
    """Apply one random augmentation draw to all calibration and target images of a subject.

    The images are stacked along the channel axis into a single
    (1, H, W, n_cal + n_target) sample, so the augmentation layers draw one set of
    random parameters for all of them. The stack is then split back apart.

    Args:
        inputs: tuple ``(cal_imgs, cal_coords, cal_mask, target_imgs, target_mask)``.
            ``cal_imgs`` and ``target_imgs`` are assumed to be single-channel image
            stacks of shape (n, H, W, 1).
        augmentation_model: callable (e.g. ``keras.Sequential``) applied to the stacked images.

    Returns:
        Same tuple structure with ``cal_imgs`` and ``target_imgs`` augmented; the coords
        and masks pass through unchanged.
    """
    cal_imgs, cal_coords, cal_mask, target_imgs, target_mask = inputs

    # (n, H, W, 1) -> (1, H, W, n): images become channels of one sample.
    stacked_cal = tf.transpose(cal_imgs, perm=[3, 1, 2, 0])
    stacked_target = tf.transpose(target_imgs, perm=[3, 1, 2, 0])

    # Split point taken from the actual input instead of assuming MAX_CAL_POINTS.
    # Use the static size when known, which keeps static output shapes.
    n_cal_images = cal_imgs.shape[0]
    if n_cal_images is None:
        n_cal_images = tf.shape(cal_imgs)[0]

    stacked_all = tf.concat([stacked_cal, stacked_target], axis=-1)
    stacked_all_aug = augmentation_model(stacked_all)

    # (1, H, W, n) -> (n, H, W, 1) for each part.
    cal_imgs_aug = tf.transpose(stacked_all_aug[..., :n_cal_images], perm=[3, 1, 2, 0])
    target_imgs_aug = tf.transpose(stacked_all_aug[..., n_cal_images:], perm=[3, 1, 2, 0])

    return cal_imgs_aug, cal_coords, cal_mask, target_imgs_aug, target_mask

In [None]:
# Cache the rescaled per-image dataset so parsing happens only once.
t0 = train_data_rescaled.cache()

# Group by subject id (last element) in windows of up to 200 elements.
# NOTE(review): `reducer_function` and `reducer_function_fixed_pts_with_id` are not
# defined anywhere in this notebook, so this cell raises NameError on a fresh kernel.
# From how they are consumed below, t1 presumably yields (features, labels) and t2
# yields (features, labels, subject_id) -- TODO restore their definitions.
t1 = t0.group_by_window(
    key_func = lambda img, m, c, z: z,
    reduce_func = reducer_function,
    window_size = 200
)

t2 = t0.group_by_window(
    key_func = lambda img, m, c, z: z,
    reduce_func = reducer_function_fixed_pts_with_id,
    window_size = 200
)

# Apply one shared random augmentation to every image of each subject
# (always applied; the AUGMENTATION flag is not consulted).
train_ds = t1.map(
    lambda x, y: (apply_consistent_augmentations(x, augmentation_model), y),
    num_parallel_calls=tf.data.AUTOTUNE
)

# Map fn for the calibration/target (non-masked) pipeline: tuple features -> named inputs.
def prepare_batch_for_model(features, labels):
    """Convert ``(cal_imgs, cal_coords, cal_mask, target_imgs, target_mask)`` into the
    dict of named inputs expected by ``create_full_model``; ``labels`` pass through.
    """
    cal_imgs, cal_coords, cal_mask, target_imgs, target_mask = features
    named_inputs = dict(
        Input_Calibration_Eyes=cal_imgs,
        Input_Calibration_Points=cal_coords,
        Input_Calibration_Mask=cal_mask,
        Input_Target_Eyes=target_imgs,
        Input_Target_Mask=target_mask,
    )
    return named_inputs, labels

# Named inputs for create_full_model.
# NOTE: this rebinds train_ds_for_model, replacing the masked-pipeline dataset built earlier.
train_ds_for_model = train_ds.map(prepare_batch_for_model).prefetch(tf.data.AUTOTUNE)


In [None]:
# Inspect the element structure of the augmented dataset.
train_ds.element_spec

## Visualize augmentation

In [None]:
def visualize_augmentations(dataset, num_examples=5):
    """Plot every image in ``element[0][0]`` of the first dataset element in a 5-column grid.

    Args:
        dataset: tf.data.Dataset whose elements are ``(features, labels)``, where
            ``features[0]`` is an image stack (n, H, W, 1) -- e.g. the calibration
            images of ``train_ds``.
        num_examples: currently unused; kept for backward compatibility.

    Raises:
        ValueError: if ``dataset`` yields no elements.
    """
    images = None
    for element in dataset.take(1):
        images = element[0][0]  # calibration images of the first element
    if images is None:
        raise ValueError("visualize_augmentations: dataset yielded no elements")

    num_images = images.shape[0]
    grid_cols = 5
    # At least one row so plt.subplots always gets a valid grid.
    grid_rows = max(1, (num_images + grid_cols - 1) // grid_cols)

    # squeeze=False keeps `axes` 2-D even for a single row, so flat indexing always works.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 3 * grid_rows), squeeze=False)

    # axes.flat walks the grid row by row; cells beyond num_images stay blank.
    for i, ax in enumerate(axes.flat):
        if i < num_images:
            ax.imshow(images[i].numpy().squeeze(), cmap='gray')
            ax.set_title(f"Image {i + 1}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Sanity-check the augmentations on the first element of train_ds.
visualize_augmentations(train_ds)

# Construct Model



## Eye image processing

## Regression Layer

In [None]:
class WeightedRidgeRegressionLayer(keras.layers.Layer):
    """Weighted ridge regression from calibration embeddings to target predictions.

    Variant used by ``create_full_model``, where calibration and target images arrive
    as separate tensors. For each batch element, a ridge kernel (no bias term) is
    fitted on the calibration embeddings/coords, with each row weighted by
    ``weights * cal_mask``, and then applied to the target embeddings. The layer
    creates no variables of its own.

    Inputs (list of 6 tensors, in this order):
        unknown_embeddings:     (batch, n_target, embedding_dim) -- target images
        calibration_embeddings: (batch, n_calibration, embedding_dim)
        calibration_coords:     (batch, n_calibration, 2)
        weights:                (batch, n_calibration)
        cal_mask:               (batch, n_calibration), multiplied into the weights
                                (presumably 1.0 = real point, 0.0 = padding)
        target_mask:            (batch, n_target), multiplied into the output

    Output: (batch, n_target, 2), zero wherever target_mask is 0.
    """
    def __init__(self, lambda_ridge, embedding_dim, epsilon=1e-6, **kwargs):
        self.lambda_ridge = lambda_ridge
        self.epsilon = epsilon
        self.embedding_dim = embedding_dim  # Required static embedding dimension
        super(WeightedRidgeRegressionLayer, self).__init__(**kwargs)

        # Create identity matrix constant during initialization
        # This ensures it's created only once and is available immediately
        self._identity = tf.eye(embedding_dim, dtype=tf.float32)

    @tf.function(jit_compile=True)
    def call(self, inputs):
        """Solve (X^T W X + (lambda_ridge + epsilon) I) K = X^T W y per batch element,
        then return masked target_embeddings @ K."""
        unknown_embeddings, calibration_embeddings, calibration_coords, weights, cal_mask, target_mask = inputs

        # Explicitly set shapes where possible (assuming batch_size can remain dynamic)
        batch_size = tf.shape(calibration_embeddings)[0]

        # Force shapes to match the static embedding dimension
        calibration_embeddings = tf.ensure_shape(
            calibration_embeddings, [None, None, self.embedding_dim]
        )
        unknown_embeddings = tf.ensure_shape(
            unknown_embeddings, [None, None, self.embedding_dim]
        )

        # Cast all inputs to float32 (the mixed_float16 policy may deliver float16)
        X = tf.cast(calibration_embeddings, tf.float32)  # (batch_size, n_calibration, embedding_dim)
        y = tf.cast(calibration_coords, tf.float32)      # (batch_size, n_calibration, 2)
        w = tf.cast(weights, tf.float32)                 # (batch_size, n_calibration)
        cal_mask = tf.cast(cal_mask, tf.float32)         # (batch_size, n_calibration)
        target_mask = tf.cast(target_mask, tf.float32)   # (batch_size, n_target)

        # Dynamic sequence lengths
        n_calibration = tf.shape(X)[1]
        n_target = tf.shape(unknown_embeddings)[1]

        # Broadcast the identity matrix to batch size
        I = tf.broadcast_to(self._identity, [batch_size, self.embedding_dim, self.embedding_dim])

        # Masked-out calibration rows get weight 0 (X itself is not zeroed here)
        w = w * cal_mask

        # Square-root weights so that (sqrt(w) X)^T (sqrt(w) X) = X^T W X
        w_sqrt = tf.sqrt(w)
        w_sqrt = tf.reshape(w_sqrt, [batch_size, n_calibration, 1])  # Reshape for broadcasting

        X_weighted = X * w_sqrt  # (batch_size, n_calibration, embedding_dim)
        y_weighted = y * w_sqrt  # (batch_size, n_calibration, 2)

        # Normal-equation left-hand side: X^T W X
        X_t = tf.transpose(X_weighted, perm=[0, 2, 1])  # (batch_size, embedding_dim, n_calibration)
        X_t_X = tf.matmul(X_t, X_weighted)  # (batch_size, embedding_dim, embedding_dim)

        # Add regularization; epsilon keeps the diagonal strictly positive even if lambda_ridge is 0
        ridge_term = tf.multiply(I, self.lambda_ridge + self.epsilon)
        lhs = tf.add(X_t_X, ridge_term)  # (batch_size, embedding_dim, embedding_dim)

        # Compute right-hand side: X^T W y
        rhs = tf.matmul(X_t, y_weighted)  # (batch_size, embedding_dim, 2)

        # Solve linear system.
        # Fallback: if any lhs entry in the batch is NaN, return an all-zero kernel.
        # NOTE(review): this zeroes every batch element, and NaNs produced by solve
        # itself are not caught.
        is_singular = tf.math.reduce_any(tf.math.is_nan(lhs))
        kernel = tf.cond(
            is_singular,
            lambda: tf.zeros([batch_size, self.embedding_dim, 2], dtype=tf.float32),
            lambda: tf.linalg.solve(lhs, rhs)
        )  # Shape: [batch_size, embedding_dim, 2]

        # Apply regression to unknown point with explicit masking
        unknown_embeddings = tf.cast(unknown_embeddings, tf.float32)  # Shape: [batch_size, n_target, embedding_dim]
        target_mask_expanded = tf.reshape(target_mask, [batch_size, n_target, 1])
        unknown_embeddings_masked = unknown_embeddings * target_mask_expanded

        # Matrix multiplication with proper dimensions:
        # unknown_embeddings_masked: [batch_size, n_target, embedding_dim]
        # kernel:                    [batch_size, embedding_dim, 2]
        # Result:                    [batch_size, n_target, 2]
        output = tf.matmul(unknown_embeddings_masked, kernel)  # Shape: [batch_size, n_target, 2]

        # Apply final mask (the masked loss treats exact zeros as excluded points)
        output = output * target_mask_expanded

        return output

    def compute_output_shape(self, input_shapes):
        # (batch, n_target, 2), taken from the unknown (target) embeddings shape
        unknown_embeddings_shape, _, _, _, _, _ = input_shapes
        return (unknown_embeddings_shape[0], unknown_embeddings_shape[1], 2)

    def build(self, input_shapes):
        # This method is called once before the first call() to build the layer
        # We already handle initialization in __init__, but this is a good place to validate shapes
        unknown_shape, cal_shape, coords_shape, weights_shape, cal_mask_shape, target_mask_shape = input_shapes

        if unknown_shape[-1] != self.embedding_dim:
            raise ValueError(f"Unknown embeddings dimension {unknown_shape[-1]} doesn't match specified embedding_dim {self.embedding_dim}")

        if cal_shape[-1] != self.embedding_dim:
            raise ValueError(f"Calibration embeddings dimension {cal_shape[-1]} doesn't match specified embedding_dim {self.embedding_dim}")

        # We don't need to create any weights in build() since we don't have trainable parameters
        super(WeightedRidgeRegressionLayer, self).build(input_shapes)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created from config
        config = super(WeightedRidgeRegressionLayer, self).get_config()
        config.update({
            "lambda_ridge": self.lambda_ridge,
            "epsilon": self.epsilon,
            "embedding_dim": self.embedding_dim
        })
        return config

In [None]:
def create_full_model():
  """Build the calibration/target gaze model.

  Calibration and target eye images are embedded by the same (shared-weight)
  embedding network. A sigmoid Dense(1) head turns each calibration embedding into
  a regression weight, and WeightedRidgeRegressionLayer fits a per-subject ridge
  regression on the calibration points to predict the target coordinates.

  Returns:
    keras.Model "FullEyePredictionModel" with inputs [Input_Calibration_Eyes,
    Input_Calibration_Points, Input_Calibration_Mask, Input_Target_Eyes,
    Input_Target_Mask] and a (batch, MAX_TARGETS, 2) output.
  """
  embedder = create_embedding_model()

  cal_eyes_in = keras.layers.Input(shape=(MAX_CAL_POINTS, 36, 144, 1), name="Input_Calibration_Eyes")
  cal_coords_in = keras.layers.Input(shape=(MAX_CAL_POINTS, 2), name="Input_Calibration_Points")
  cal_mask_in = keras.layers.Input(shape=(MAX_CAL_POINTS,), name="Input_Calibration_Mask")
  target_eyes_in = keras.layers.Input(shape=(MAX_TARGETS, 36, 144, 1), name="Input_Target_Eyes")
  target_mask_in = keras.layers.Input(shape=(MAX_TARGETS,), name="Input_Target_Mask")

  # Both image sets go through the same embedder, so weights are shared.
  target_vecs = keras.layers.TimeDistributed(embedder)(target_eyes_in)
  cal_vecs = keras.layers.TimeDistributed(embedder)(cal_eyes_in)

  # One weight per calibration image, flattened to (batch, MAX_CAL_POINTS).
  raw_weights = keras.layers.Dense(1, activation="sigmoid", name="Calibration_Weights")(cal_vecs)
  flat_weights = keras.layers.Reshape((-1,), name="Calibration_Weights_Reshaped")(raw_weights)

  ridge_inputs = [target_vecs, cal_vecs, cal_coords_in, flat_weights, cal_mask_in, target_mask_in]
  predictions = WeightedRidgeRegressionLayer(RIDGE_REGULARIZATION, embedding_dim=EMBEDDING_DIM)(ridge_inputs)

  model_inputs = [cal_eyes_in, cal_coords_in, cal_mask_in, target_eyes_in, target_mask_in]
  return keras.Model(inputs=model_inputs, outputs=predictions, name="FullEyePredictionModel")

## Full trainable model

In [None]:
# Build the calibration/target model and print its layer summary.
full_model = create_full_model()

full_model.summary()



# Compile with the masked loss. jit_compile=False turns off model-level XLA; the ridge
# layer's call is separately decorated with @tf.function(jit_compile=True).
# NOTE(review): the metric from et_util is presumably unmasked, so it also scores
# masked-out points whose prediction is 0 -- TODO confirm.
full_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=masked_normalized_weighted_euc_dist,
    metrics=[normalized_weighted_euc_dist],  # Keep the original for comparison if needed
    jit_compile=False
)

In [None]:
def lr_schedule(epoch):
  """Step schedule: LEARNING_RATE for epochs 0-14, then LEARNING_RATE * 0.1."""
  decay = 1.0 if epoch < 15 else 0.1
  return LEARNING_RATE * decay

# Keras callback that applies lr_schedule at the start of every epoch.
lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)

# Train model

In [None]:
# Train the calibration/target model with W&B logging and the step LR schedule.
# batch(4) stacks 4 dataset elements (presumably 4 subjects) per step.
train_history = full_model.fit(
    train_ds_for_model.batch(4),
    epochs=TRAIN_EPOCHS,
    callbacks=[WandbMetricsLogger(), lr_scheduler]
)

In [None]:
# Save architecture + weights in the native Keras format.
# NOTE(review): reloading presumably needs the custom layer and loss passed via
# custom_objects, since neither is registered as serializable here.
full_model.save('full_model.keras')

In [None]:
# Also save the weights only (HDF5), e.g. to load into a freshly built create_full_model().
full_model.save_weights('full_model.weights.h5')

In [None]:
# Attach both saved files to the W&B run.
wandb.save('full_model.keras')
wandb.save('full_model.weights.h5')

In [None]:
# Per-subject evaluation of the trained model on t2.
# NOTE(review): t2 is grouped from train_data_rescaled, so these are training-set
# scores. It also depends on reducer_function_fixed_pts_with_id, which is not defined
# in this notebook; elements are assumed to be (features, labels, subject_id).
t2_processed = t2.map(
    lambda x, y, subject_id: (prepare_batch_for_model((x[0], x[1], x[2], x[3], x[4]), y)[0], y, subject_id)
).prefetch(tf.data.AUTOTUNE)

# Per-subject losses, subject ids and ground-truth labels
batch_losses = []
subject_ids = []
y_true = []

for _, labels, ids in t2_processed.batch(1).as_numpy_iterator():
    y_true.append(labels)
    subject_ids.append(ids[0])

y_true = np.array(y_true).reshape(-1, MAX_TARGETS, 2)

# this step is slower than expected. not sure what's going on.
predictions = full_model.predict(t2_processed.batch(1))

for i in range(len(predictions)):
  loss = normalized_weighted_euc_dist(y_true[i], predictions[i]).numpy()
  batch_losses.append(loss)

batch_losses = np.array(batch_losses)

# Mean per subject. This assumes normalized_weighted_euc_dist returns one distance per
# target point, making batch_losses (n_subjects, MAX_TARGETS) -- TODO confirm in et_util.
# NOTE(review): target slots zeroed by Input_Target_Mask are still included in this mean.
batch_losses = np.mean(batch_losses, axis=1)

In [None]:
# W&B table of per-subject mean scores.
final_loss_table = wandb.Table(data=[[s, l] for (s, l) in zip(subject_ids, batch_losses)], columns=["subject", "scores"])

In [None]:
# W&B histogram of the "scores" column.
final_loss_hist = wandb.plot.histogram(final_loss_table, "scores", title="Normalized Euclidean Distance")

In [None]:
# Mean over subjects of the per-subject mean error.
final_loss_mean = np.mean(batch_losses)

In [None]:
# NOTE(review): logged under "final_val_*" keys, but the scores come from t2, which is
# built from train_data_rescaled (val_split=0.0) -- these are not held-out validation scores.
wandb.log({"final_val_loss_table": final_loss_table, "final_val_loss_hist": final_loss_hist, "final_loss_mean": final_loss_mean})

In [None]:
# Local histogram of the per-subject mean errors (same values logged to W&B).
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(batch_losses, bins=20, edgecolor='black')
ax.set_title('Histogram of Batch Losses')
ax.set_xlabel('Loss')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# Mark the W&B run as finished and flush remaining logged data.
wandb.finish()