# Setup

In [None]:
!pip install osfclient --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet
!pip install wandb --quiet
!pip install --upgrade keras --quiet
!pip install keras-hub --quiet

In [None]:
# ============================================================================
# Backend Configuration - Must be set before importing Keras
# ============================================================================
import os
os.environ["KERAS_BACKEND"] = "jax"  # JAX backend for TPU compatibility

# ============================================================================
# Core ML/Deep Learning Libraries
# ============================================================================
import tensorflow as tf  # Used for data pipeline (tf.data, tf.io)
import numpy as np
import math
import keras
import keras_hub
from keras import ops  # Backend-agnostic operations

# ============================================================================
# Experiment Tracking
# ============================================================================
import wandb
from wandb.integration.keras import WandbMetricsLogger

# ============================================================================
# Visualization
# ============================================================================
import matplotlib.pyplot as plt

# ============================================================================
# Google Colab Utilities
# ============================================================================
from google.colab import userdata

# ============================================================================
# Project-Specific Utilities
# ============================================================================
import et_util.dataset_utils as dataset_utils
import et_util.embedding_preprocessing as embed_pre
import et_util.model_layers as model_layers
from et_util import experiment_utils
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.model_analysis import plot_model_performance

In [None]:
# Copy the Colab user secrets into environment variables, in this order.
# Presumably WANDB_API_KEY is picked up by wandb.login() and the OSF_* values
# by the `osf` command-line client used further down - confirm.
for env_name, secret_name in (
    ('WANDB_API_KEY', 'WANDB_API_KEY'),
    ('OSF_TOKEN', 'osftoken'),
    ('OSF_USERNAME', 'osfusername'),
):
    os.environ[env_name] = userdata.get(secret_name)

In [None]:
# Show the installed Keras version (the setup cell upgrades Keras via pip).
keras.version()

In [None]:
# Set the global Keras dtype policy used by layers created after this call.
# NOTE(review): the policy name is 'bfloat16', not 'mixed_bfloat16'; confirm that
# bfloat16 variables (not only bfloat16 computation) are intended.
keras.mixed_precision.set_global_policy('bfloat16')

# Configure W&B experiment

In [None]:
# Log this session in to Weights & Biases before a run is created.
wandb.login()

In [None]:
# =============================================================================
# EXPERIMENT CONFIGURATION
# =============================================================================
# All values below are module-level constants read by the data pipeline, the
# model builders and the W&B config dict defined later in the notebook.

# -----------------------------------------------------------------------------
# Dataset Parameters
# -----------------------------------------------------------------------------
MAX_TARGETS = 144  # Images per subject batch; used as window, batch and pad size

# -----------------------------------------------------------------------------
# Model Architecture
# -----------------------------------------------------------------------------
EMBEDDING_DIM = 200  # Size of embedding vector
RIDGE_REGULARIZATION = 0.1  # Ridge regression lambda
MIN_CAL_POINTS = 8  # Minimum calibration images (inclusive lower bound)
MAX_CAL_POINTS = 40  # Upper bound for calibration images; passed as `maxval`
                     # to tf.random.uniform, which excludes it for integers

BACKBONE = "densenet"  # Options: "densenet", "vit"

# -----------------------------------------------------------------------------
# DenseNet Configuration (when BACKBONE="densenet")
# -----------------------------------------------------------------------------
DENSENET_STACKWISE_NUM_REPEATS = [4, 4, 4]  # Number of dense blocks per stack
                                             # [4,4,4] = 3 dense blocks with 4 layers each

# -----------------------------------------------------------------------------
# Vision Transformer Configuration (when BACKBONE="vit")
# -----------------------------------------------------------------------------
VIT_PATCH_SIZE = 4  # Patch size (4x4 -> 9 x 36 = 324 patches from a 36x144 image)
VIT_HIDDEN_DIM = 256  # Hidden dimension (TPU-friendly)
VIT_NUM_LAYERS = 6  # Number of transformer layers
VIT_NUM_HEADS = 8  # Number of attention heads
VIT_MLP_DIM = 1024  # MLP dimension in transformer blocks
VIT_DROPOUT = 0.1  # Dropout rate
VIT_ATTENTION_DROPOUT = 0.0  # Attention dropout rate
VIT_USE_CLS_TOKEN = True  # Use CLS token for global representation

# -----------------------------------------------------------------------------
# Data Augmentation
# -----------------------------------------------------------------------------
AUGMENTATION = True  # Enable/disable augmentation

# Augmentation strategy
PER_IMAGE_AUGMENTATION = True  # If True, each image gets random augmentation
                               # If False, same augmentation per sequence

# Blur augmentation
BLUR_PROB = 0.3  # Probability of applying blur
BLUR_SIGMA_RANGE = (0.5, 1.5)  # Gaussian blur sigma range

# Color augmentations
BRIGHTNESS_RANGE = 0.3  # Brightness adjustment range
CONTRAST_RANGE = (0.7, 1.4)  # Contrast adjustment range
GAMMA_RANGE = (0.8, 1.2)  # Gamma correction range

# Noise augmentation
NOISE_PROB = 0.4  # Probability of adding noise
NOISE_STD_RANGE = (0.01, 0.05)  # Noise standard deviation range

# -----------------------------------------------------------------------------
# Training Hyperparameters
# -----------------------------------------------------------------------------
BATCH_SIZE = 5  # Number of subjects per batch
TRAIN_EPOCHS = 50  # Total training epochs

# Learning rate schedule
INITIAL_LEARNING_RATE = 0.00001  # Starting learning rate (warmup)
LEARNING_RATE = 0.01  # Target learning rate (after warmup)
WARMUP_EPOCHS = 2  # Number of warmup epochs
DECAY_EPOCHS = TRAIN_EPOCHS - WARMUP_EPOCHS  # Cosine decay epochs (all epochs after warmup)
DECAY_ALPHA = 0.01  # Final learning rate multiplier

In [None]:
# Hyperparameters recorded with the W&B run: the shared settings first, then
# the settings of whichever backbone is selected.
config = dict(
    embedding_dim=EMBEDDING_DIM,
    ridge_regularization=RIDGE_REGULARIZATION,
    train_epochs=TRAIN_EPOCHS,
    min_cal_points=MIN_CAL_POINTS,
    max_cal_points=MAX_CAL_POINTS,
    backbone=BACKBONE,
    batch_size=BATCH_SIZE,
    initial_learning_rate=INITIAL_LEARNING_RATE,
    learning_rate=LEARNING_RATE,
    augmentation=AUGMENTATION,
    per_image_augmentation=PER_IMAGE_AUGMENTATION,
    blur_prob=BLUR_PROB,
    blur_sigma_range=BLUR_SIGMA_RANGE,
    brightness_range=BRIGHTNESS_RANGE,
    contrast_range=CONTRAST_RANGE,
    gamma_range=GAMMA_RANGE,
    noise_prob=NOISE_PROB,
    noise_std_range=NOISE_STD_RANGE,
    warmup_epochs=WARMUP_EPOCHS,
    decay_epochs=DECAY_EPOCHS,
    decay_alpha=DECAY_ALPHA,
)

# DenseNet-specific entry, only when DenseNet is the selected backbone
if BACKBONE == "densenet":
    config["densenet_stackwise_num_repeats"] = DENSENET_STACKWISE_NUM_REPEATS

# ViT-specific entries, only when ViT is the selected backbone
if BACKBONE == "vit":
    config.update(
        vit_patch_size=VIT_PATCH_SIZE,
        vit_hidden_dim=VIT_HIDDEN_DIM,
        vit_num_layers=VIT_NUM_LAYERS,
        vit_num_heads=VIT_NUM_HEADS,
        vit_mlp_dim=VIT_MLP_DIM,
        vit_dropout=VIT_DROPOUT,
        vit_attention_dropout=VIT_ATTENTION_DROPOUT,
        vit_use_cls_token=VIT_USE_CLS_TOKEN,
    )

In [None]:
# Create the W&B run for this experiment and attach the hyperparameter dict to it.
run = wandb.init(
    project='eye-tracking-dense-full-data-set-single-eye',
    config=config
)

# Dataset preparation

## Download dataset from OSF

In [None]:
!osf -p 6b5cd fetch single_eye_tfrecords.tar.gz

## Process raw data records into TF Dataset

In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:
def parse(element):
    """Parse one serialized TFRecord example for ``process_tfr_to_tfds``.

    Decodes the eye image, the MediaPipe landmarks, the gaze label and the
    subject id. Use for data generated with
    make_single_example_landmarks_and_jpg (i.e. data in
    jpg_landmarks_tfrecords.tar.gz).

    Only TensorFlow ops are used here because this function runs inside a
    ``tf.data`` pipeline, independently of the Keras backend selected above.

    :param element: tfr element in raw dataset
    :return: tuple ``(image, landmarks, label, subject_id)`` where ``image`` is
        the uint8 tensor stored in ``eye_img``, ``landmarks`` is a float32
        tensor reshaped to (478, 3), ``label`` is ``[x, y]`` and ``subject_id``
        is an int64 scalar
    """

    # img_width / img_height stay in the spec so that records lacking them are
    # still rejected by the parser, although their values are not used below.
    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    label = [content['x'], content['y']]
    subject_id = content['subject_id']

    # Landmarks and image are stored as serialized tensors
    landmarks = tf.io.parse_tensor(content['landmarks'], out_type=tf.float32)
    landmarks = tf.reshape(landmarks, (478, 3))

    image = tf.io.parse_tensor(content['eye_img'], out_type=tf.uint8)

    return image, landmarks, label, subject_id

In [None]:
# Build the datasets from the extracted TFRecords using the `parse` function.
# Every record is assigned to the training split (train_split=1.0); the
# validation and test splits are requested with a share of 0.0.
# group_function returns the 4th element of a parsed record (the subject id),
# presumably so that splitting happens per subject - confirm in et_util.
train_data, validation_data, test_data = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, landmarks, coords, z: z
)

In [None]:
def rescale_coords_map(eyes, mesh, coords, id):
    """Divide the gaze coordinates by 100; pass every other element through unchanged."""
    scaled_coords = coords / 100.0
    return eyes, mesh, scaled_coords, id

In [None]:
# Apply rescale_coords_map (coords / 100) to every element of both splits.
train_data_rescaled = train_data.map(rescale_coords_map)
validation_data_rescaled = validation_data.map(rescale_coords_map)

In [None]:
def prepare_masked_dataset(dataset, calibration_points=None):
    """Group a per-image dataset by subject and attach calibration/target masks.

    All callbacks below run inside ``tf.data`` and therefore use TensorFlow ops
    only, so the pipeline does not depend on the Keras backend selected for
    the model.

    Args:
        dataset: Dataset of ``(image, mesh, coords, subject_id)`` elements.
        calibration_points: Optional array-like of shape (k, 2). When given,
            an image is a calibration image iff its coordinates equal one of
            these points exactly. When None, a random subset of images is
            chosen per subject each time the dataset is iterated; its size is
            drawn from [MIN_CAL_POINTS, MAX_CAL_POINTS) (upper bound excluded).

    Returns:
        Dataset of ``((images, coords, cal_mask, target_mask), coords,
        subject_ids)`` where the first four tensors have a fixed leading
        dimension of MAX_TARGETS and the masks are int8 (1 = included).
    """

    # Step 1: Group dataset by subject_id and batch all images
    def group_by_subject(subject_id, ds):
        return ds.batch(batch_size=MAX_TARGETS)

    grouped_dataset = dataset.group_by_window(
        key_func=lambda img, mesh, coords, subject_id: subject_id,
        reduce_func=group_by_subject,
        window_size=MAX_TARGETS
    )

    # Step 2: Keep only subject batches that contain MAX_TARGETS images
    def filter_by_image_count(images, meshes, coords, subject_ids):
        return tf.shape(images)[0] >= MAX_TARGETS

    grouped_dataset = grouped_dataset.filter(filter_by_image_count)

    # Step 3: Transform each batch to include masks
    def add_masks_to_batch(images, meshes, coords, subject_ids):

        # Dynamic (runtime) number of images in this subject batch
        actual_batch_size = tf.shape(images)[0]

        if calibration_points is None:
            # Number of calibration images; `maxval` is exclusive for integers
            n_cal_images = tf.random.uniform(
                shape=[],
                minval=MIN_CAL_POINTS,
                maxval=MAX_CAL_POINTS,
                dtype="int32"
            )
            # Create random indices for calibration images
            # NOTE: These will be different each time through the dataset
            random_indices = tf.random.shuffle(tf.range(actual_batch_size))
            cal_indices = random_indices[:n_cal_images]

            # Create masks (1 = included, 0 = excluded)
            cal_mask = tf.scatter_nd(
                tf.expand_dims(cal_indices, 1),
                tf.ones(n_cal_images, dtype="int8"),
                [MAX_TARGETS]
            )
        else:
            # Compare every image coordinate with every fixed calibration
            # point; the cast keeps tf.equal from failing on a dtype mismatch
            # (e.g. a float64 NumPy array against float32 coordinates).
            fixed_points = tf.cast(calibration_points, coords.dtype)
            coords_xpand = tf.expand_dims(coords, axis=1)
            cal_xpand = tf.expand_dims(fixed_points, axis=0)
            matches = tf.reduce_all(tf.equal(coords_xpand, cal_xpand), axis=-1)
            cal_mask = tf.cast(tf.reduce_any(matches, axis=1), "int8")

        # Every image that is not a calibration image is a target
        target_mask = 1 - cal_mask

        # Pad everything to fixed size
        padded_images = tf.pad(
            tf.reshape(images, (-1, 36, 144, 1)),
            [[0, MAX_TARGETS - actual_batch_size], [0, 0], [0, 0], [0, 0]]
        )
        padded_coords = tf.pad(
            coords,
            [[0, MAX_TARGETS - actual_batch_size], [0, 0]]
        )

        # Ensure all shapes are fixed
        padded_images = tf.ensure_shape(padded_images, [MAX_TARGETS, 36, 144, 1])
        padded_coords = tf.ensure_shape(padded_coords, [MAX_TARGETS, 2])
        padded_cal_mask = tf.ensure_shape(cal_mask, [MAX_TARGETS])
        padded_target_mask = tf.ensure_shape(target_mask, [MAX_TARGETS])
        return (padded_images, padded_coords, padded_cal_mask, padded_target_mask), padded_coords, subject_ids

    # Apply the transformation
    masked_dataset = grouped_dataset.map(
        add_masks_to_batch,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return masked_dataset

# Modified model input preparation function
def prepare_model_inputs(features, labels, subject_ids):
    """
    Turn one masked subject batch into the (inputs, labels, sample_weights)
    triple consumed during training.

    Args:
        features: Tuple of (images, coords, cal_mask, target_mask)
        labels: The target coordinates
        subject_ids: Subject ids of the batch; accepted but not forwarded

    Returns:
        Tuple of (inputs dict for the model, labels, target_mask); the target
        mask is used as sample weights for the loss function
    """
    images, coords, cal_mask, target_mask = features

    # The target mask is deliberately not a model input; it is returned separately
    model_inputs = {}
    model_inputs["Input_All_Images"] = images          # all images (cal and target)
    model_inputs["Input_All_Coords"] = coords          # all coordinates
    model_inputs["Input_Calibration_Mask"] = cal_mask  # marks the calibration images

    return model_inputs, labels, target_mask

In [None]:
# Fixed calibration grid: 4 x-positions crossed with 5 y-positions, listed with
# x varying slowest. Dividing by 100 mirrors the coords / 100 rescaling that is
# applied to the dataset coordinates.
cal_points = np.array(
    [[x, y] for x in (5, 35, 65, 95) for y in (5, 27.5, 50, 72.5, 95)],
    dtype=np.float32,
)

scaled_cal_points = cal_points / 100.0

In [None]:
# No calibration_points are passed, so calibration images are chosen at random
# for each subject batch.
masked_dataset = prepare_masked_dataset(train_data_rescaled)

In [None]:
# Convert each subject batch to (inputs, labels, sample_weights), then shuffle
# with a 200-element buffer and prefetch.
train_ds_for_model = masked_dataset.map(
    prepare_model_inputs,
    num_parallel_calls=tf.data.AUTOTUNE
).shuffle(200).prefetch(tf.data.AUTOTUNE)

In [None]:
# Count the subject batches by iterating the prepared dataset once.
INDIVIDUALS = sum(1 for _ in train_ds_for_model.as_numpy_iterator())

## Visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import tensorflow as tf
from IPython.display import HTML
import base64

def visualize_eye_tracking_data(dataset, interval=100, figsize=(10, 6)):
    """
    Animate one subject's eye images together with the gaze coordinate of each
    image (drawn as a red dot). Returns an HTML animation for direct display
    in a Colab notebook.

    Args:
        dataset: Dataset yielding ``(inputs, labels, sample_weights)`` elements,
                where ``inputs`` is a dict containing "Input_All_Images" and
                "Input_All_Coords". Only the first element is used.
        interval: Time between frames in milliseconds
        figsize: Size of the figure (width, height)

    Returns:
        IPython.display.HTML object with the animation

    Raises:
        ValueError: If the dataset yields no element.

    Example usage in a Colab notebook:
    ```python
    from IPython.display import display
    display(visualize_eye_tracking_data(train_ds_for_model.take(1)))
    ```
    """
    # Extract the first element; fail early with a clear message when there is none
    first_element = next(iter(dataset.take(1)), None)
    if first_element is None:
        raise ValueError("dataset is empty: there is nothing to visualize")
    inputs, _labels, _sample_weights = first_element

    # Convert to numpy arrays
    valid_images = inputs["Input_All_Images"].numpy()
    valid_coords = inputs["Input_All_Coords"].numpy()

    # np.lexsort treats the LAST key as the primary one: frames are ordered by
    # Y coordinate first, ties are broken by X coordinate.
    sorted_indices = np.lexsort((valid_coords[:, 0], valid_coords[:, 1]))
    valid_images = valid_images[sorted_indices]
    valid_coords = valid_coords[sorted_indices]

    num_images = valid_images.shape[0]

    # Set up the figure
    plt.ioff()  # Turn off interactive mode to avoid displaying during generation
    fig, ax = plt.subplots(figsize=figsize)

    # Create a function that draws each frame
    def draw_frame(frame_num):
        ax.clear()

        # 0-100 coordinate space on both axes; the Y axis is inverted so that
        # (0, 0) is the top-left corner
        ax.set_xlim(0, 100)
        ax.set_ylim(100, 0)

        # Get the current image and coordinates
        img = valid_images[frame_num].squeeze()
        x, y = valid_coords[frame_num]

        # Draw rectangle for the screen
        rect = plt.Rectangle((0, 0), 100, 100, fill=False, color='black', linewidth=2)
        ax.add_patch(rect)

        # Show the (horizontally mirrored) eye image in the centre of the plot
        img_display = ax.inset_axes([0.35, 0.35, 0.3, 0.3], transform=ax.transAxes)
        img_display.imshow(np.fliplr(img), cmap='gray')
        img_display.axis('off')

        # Coordinates are multiplied by 100 to map them onto the 0-100 axes.
        # Draw red dot at the coordinate - make it bigger for visibility
        ax.scatter(x * 100, y * 100, color='red', s=150, zorder=5,
                  edgecolor='white', linewidth=1.5)

        # Draw crosshair
        ax.axhline(y * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)
        ax.axvline(x * 100, color='gray', linestyle='--', alpha=0.5, zorder=1)

        # Set title
        ax.set_title('Eye Tracking Visualization')

        # Set axis labels
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')

    # Create the animation
    anim = animation.FuncAnimation(fig, draw_frame, frames=num_images, interval=interval)

    # Use animation's to_jshtml method which directly generates HTML
    html_animation = anim.to_jshtml()

    # Clean up
    plt.close(fig)

    # Create an HTML object that can be displayed in the notebook
    return HTML(html_animation)

In [None]:
# Animate one subject batch taken from the training dataset.
visualize_eye_tracking_data(train_ds_for_model.take(1))

# Model building

## Custom Layers


In [None]:
class SimpleTimeDistributed(keras.layers.Wrapper):
    """Apply a wrapped layer to every temporal slice of a (batch, time, ...) input.

    Instead of looping over time steps, batch and time are folded into a single
    leading axis, the wrapped layer is run once, and the result is unfolded again.

    Args:
        layer: a `keras.layers.Layer` instance.
    """

    def __init__(self, layer, **kwargs):
        super().__init__(layer, **kwargs)
        # Mirror the wrapped layer's masking capability (False when it has none)
        self.supports_masking = getattr(layer, 'supports_masking', False)

    def build(self, input_shape):
        # The input needs at least (batch, time, feature...) dimensions
        shape_is_sequence = isinstance(input_shape, (tuple, list))
        if not (shape_is_sequence and len(input_shape) >= 3):
            raise ValueError(
                "`SimpleTimeDistributed` requires input with at least 3 dimensions"
            )

        # The wrapped layer is built for the shape without the time axis
        per_step_shape = (input_shape[0], *input_shape[2:])
        super().build(per_step_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        # Output shape of the wrapped layer for one time slice ...
        per_step_shape = (input_shape[0], *input_shape[2:])
        inner_shape = self.layer.compute_output_shape(per_step_shape)
        # ... with the time axis re-inserted after the batch axis
        return (inner_shape[0], input_shape[1], *inner_shape[1:])

    def call(self, inputs, training=None):
        full_shape = ops.shape(inputs)
        n_batch, n_steps = full_shape[0], full_shape[1]

        # Fold batch and time together: (batch*time, ...)
        folded = ops.reshape(inputs, (-1, *full_shape[2:]))

        # One pass of the wrapped layer covers every time step
        folded_out = self.layer.call(folded, training=training)

        # Unfold to (batch, time, ...) using the wrapped layer's output dims
        trailing_dims = ops.shape(folded_out)[1:]
        return ops.reshape(folded_out, (n_batch, n_steps, *trailing_dims))

In [None]:
class MaskedWeightedRidgeRegressionLayer(keras.layers.Layer):
    """
    Weighted ridge regression from embeddings to coordinates, fitted on the
    calibration points only and then applied to every point.

    The calibration mask is an explicit input. The ``mask`` argument of
    ``call`` is accepted but not used in the computation.

    Args:
        lambda_ridge (float): Regularization parameter for ridge regression
        epsilon (float): Small constant added inside the sqrt of the weights
            for numerical stability
    """
    def __init__(self, lambda_ridge, epsilon=1e-7, **kwargs):
        self.lambda_ridge = lambda_ridge
        self.epsilon = epsilon
        super().__init__(**kwargs)

    def call(self, inputs, mask=None):
        """
        Fit the ridge regression per batch element and predict all points.

        Args:
            inputs: A list containing, in this order:
                - embeddings: Embeddings for all points (batch_size, n_points, embedding_dim)
                - coords: Coordinates for all points (batch_size, n_points, 2)
                - calibration_weights: Importance weights (batch_size, n_points, 1)
                - cal_mask: Mask for calibration points (batch_size, n_points)

        Returns:
            Predicted coordinates for all points (batch_size, n_points, 2)
        """
        # Unpack the four inputs, forcing float32 (important for JIT)
        embeddings, coords, calibration_weights, cal_mask = (
            ops.cast(tensor, "float32") for tensor in inputs
        )

        # Weights as (batch, n_points), zeroed for non-calibration points
        point_weights = ops.squeeze(calibration_weights, axis=-1)
        masked_weights = point_weights * cal_mask
        # epsilon keeps the sqrt numerically stable at zero weight
        sqrt_weights = ops.expand_dims(ops.sqrt(masked_weights + self.epsilon), -1)

        # Only calibration rows contribute to the fit
        mask_column = ops.expand_dims(cal_mask, -1)
        design = embeddings * mask_column

        # Weighted design matrix and weighted (masked) targets
        design_weighted = design * sqrt_weights
        targets_weighted = coords * sqrt_weights * mask_column

        # Normal equations: (X^T X + lambda * I) K = X^T y
        design_weighted_t = ops.transpose(design_weighted, axes=[0, 2, 1])
        gram = ops.matmul(design_weighted_t, design_weighted)
        identity = ops.cast(ops.eye(ops.shape(embeddings)[-1]), "float32")
        lhs = gram + self.lambda_ridge * identity
        rhs = ops.matmul(design_weighted_t, targets_weighted)

        kernel = ops.linalg.solve(lhs, rhs)

        # Predict with the unmasked embeddings so every point gets an output
        return ops.matmul(embeddings, kernel)

    def compute_output_shape(self, input_shapes):
        """
        Return (batch_size, n_points, 2), taken from the embeddings' shape.

        Args:
            input_shapes: List of input shapes

        Returns:
            Output shape tuple
        """
        embeddings_shape = input_shapes[0]
        return (embeddings_shape[0], embeddings_shape[1], 2)

    def get_config(self):
        """
        Return the layer configuration for serialization.

        Returns:
            Base layer config extended with ``lambda_ridge`` and ``epsilon``
        """
        return {
            **super().get_config(),
            "lambda_ridge": self.lambda_ridge,
            "epsilon": self.epsilon,
        }

In [None]:
class MaskInspectorLayer(keras.layers.Layer):
    """Debugging pass-through layer: prints the mask it receives, returns its input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # Print mask information; the inputs are returned untouched
        print("Layer mask:", mask)
        return inputs

## Custom loss

## Embedding Model

### Backbones

#### DenseNet

In [None]:
def create_dense_net_backbone():
    """
    Build the keras_hub DenseNet backbone used for the eye images.

    - Input: 36x144x1 images
    - Stack layout: taken from DENSENET_STACKWISE_NUM_REPEATS

    Returns:
        DenseNet backbone model
    """
    backbone_kwargs = {
        "stackwise_num_repeats": DENSENET_STACKWISE_NUM_REPEATS,
        "image_shape": (36, 144, 1),
    }
    return keras_hub.models.DenseNetBackbone(**backbone_kwargs)

#### Involution

#### EfficientNetB0

In [None]:
class _ViTTokenEmbedding(keras.layers.Layer):
    """Prepend an optional learnable CLS token and add learned position embeddings.

    Input:  (batch, num_patches, hidden_dim)
    Output: (batch, num_patches + 1, hidden_dim) with a CLS token,
            (batch, num_patches, hidden_dim) without.

    The token is broadcast over the batch at call time, so the layer works
    with an unknown (symbolic) batch dimension.
    """

    def __init__(self, use_cls_token=True, **kwargs):
        super().__init__(**kwargs)
        self.use_cls_token = bool(use_cls_token)

    def build(self, input_shape):
        num_patches, hidden_dim = input_shape[1], input_shape[2]
        if self.use_cls_token:
            self.cls_token = self.add_weight(
                shape=(1, 1, hidden_dim),
                initializer="random_normal",
                name="cls_token",
            )
        self.position_embedding = self.add_weight(
            shape=(1, num_patches + int(self.use_cls_token), hidden_dim),
            initializer=keras.initializers.RandomNormal(stddev=0.02),
            name="position_embedding",
        )

    def call(self, patches):
        tokens = patches
        if self.use_cls_token:
            # zeros_like(one slice per sample) + token gives one copy of the
            # token per sample without needing the batch size as a number
            cls_tokens = keras.ops.zeros_like(patches[:, :1, :]) + keras.ops.cast(
                self.cls_token, patches.dtype
            )
            tokens = keras.ops.concatenate([cls_tokens, patches], axis=1)
        return tokens + keras.ops.cast(self.position_embedding, tokens.dtype)

    def compute_output_shape(self, input_shape):
        batch, num_patches, hidden_dim = input_shape
        return (batch, num_patches + int(self.use_cls_token), hidden_dim)

    def get_config(self):
        config = super().get_config()
        config.update({"use_cls_token": self.use_cls_token})
        return config


def create_vit_backbone():
    """
    Creates a Vision Transformer backbone for processing eye images.

    Architecture (all sizes come from the VIT_* constants):
    - Input: 36x144x1 images
    - Patches: VIT_PATCH_SIZE x VIT_PATCH_SIZE, embedded with a strided Conv2D
      (with the default patch size 4 this gives 9 x 36 = 324 patches)
    - Optional CLS token (VIT_USE_CLS_TOKEN) plus learned position embeddings
    - VIT_NUM_LAYERS pre-norm transformer blocks
    - Output: (batch, num_tokens, VIT_HIDDEN_DIM), where num_tokens is
      num_patches + 1 with a CLS token and num_patches without

    Note: Uses Conv2D-based patching so the non-square image is handled
    directly, without padding it to a square.

    Returns:
        keras.Model named "ViT_Backbone"
    """
    image_shape = (36, 144, 1)
    inputs = keras.layers.Input(shape=image_shape)

    # Patch embedding using Conv2D (handles non-square images)
    patch_embeddings = keras.layers.Conv2D(
        filters=VIT_HIDDEN_DIM,
        kernel_size=VIT_PATCH_SIZE,
        strides=VIT_PATCH_SIZE,
        padding="valid",
        activation=None,
        name="patch_embedding"
    )(inputs)
    # Shape: (batch, patches_h, patches_w, hidden_dim)

    # Reshape the patch grid to a sequence
    num_patches_h = image_shape[0] // VIT_PATCH_SIZE
    num_patches_w = image_shape[1] // VIT_PATCH_SIZE
    num_patches = num_patches_h * num_patches_w

    patches_reshaped = keras.layers.Reshape(
        target_shape=(num_patches, VIT_HIDDEN_DIM),
        name="flatten_patches"
    )(patch_embeddings)
    # Shape: (batch, num_patches, hidden_dim)

    # CLS token + position embeddings live in one layer so that their weights
    # are tracked by the model and the batch dimension may stay unknown
    x = _ViTTokenEmbedding(
        use_cls_token=VIT_USE_CLS_TOKEN,
        name="token_embedding"
    )(patches_reshaped)
    # Shape: (batch, num_tokens, hidden_dim)

    # Apply Transformer Encoder
    for i in range(VIT_NUM_LAYERS):
        # Pre-norm architecture
        x_norm = keras.layers.LayerNormalization(epsilon=1e-6, name=f"layer_norm_{i}_1")(x)

        # Multi-head self-attention
        attn_output = keras.layers.MultiHeadAttention(
            num_heads=VIT_NUM_HEADS,
            key_dim=VIT_HIDDEN_DIM // VIT_NUM_HEADS,
            dropout=VIT_ATTENTION_DROPOUT,
            name=f"attention_{i}"
        )(x_norm, x_norm)

        attn_output = keras.layers.Dropout(VIT_DROPOUT, name=f"attn_dropout_{i}")(attn_output)
        x = keras.layers.Add(name=f"attn_add_{i}")([x, attn_output])

        # MLP block
        x_norm = keras.layers.LayerNormalization(epsilon=1e-6, name=f"layer_norm_{i}_2")(x)

        mlp = keras.Sequential([
            keras.layers.Dense(VIT_MLP_DIM, activation="gelu", name=f"mlp_{i}_dense_1"),
            keras.layers.Dropout(VIT_DROPOUT, name=f"mlp_{i}_dropout_1"),
            keras.layers.Dense(VIT_HIDDEN_DIM, name=f"mlp_{i}_dense_2"),
            keras.layers.Dropout(VIT_DROPOUT, name=f"mlp_{i}_dropout_2"),
        ], name=f"mlp_{i}")

        mlp_output = mlp(x_norm)
        x = keras.layers.Add(name=f"mlp_add_{i}")([x, mlp_output])

    # Final layer norm
    encoder_output = keras.layers.LayerNormalization(epsilon=1e-6, name="final_layer_norm")(x)
    # Shape: (batch, num_tokens, hidden_dim)

    model = keras.Model(inputs=inputs, outputs=encoder_output, name="ViT_Backbone")

    return model

#### Vision Transformer

### Create Embedding Model

In [None]:
def create_embedding_model(BACKBONE):
    """
    Build the model that maps one eye image to an embedding vector.

    Pipeline: rescale pixels by 1/255 -> selected backbone -> flatten ->
    Dense(EMBEDDING_DIM, tanh).

    Args:
        BACKBONE: Name of the backbone to use, "densenet" or "vit".

    Returns:
        keras.Model named "Eye_Image_Embedding" with input shape (36, 144, 1)
        and output shape (EMBEDDING_DIM,).

    Raises:
        ValueError: If BACKBONE is not one of the supported names.
    """
    # Validate first so that an unsupported name fails with a clear message
    # instead of an unbound `backbone` variable further down.
    if BACKBONE not in ("densenet", "vit"):
        raise ValueError(
            f"Unknown backbone {BACKBONE!r}; expected 'densenet' or 'vit'."
        )

    image_shape = (36, 144, 1)
    input_eyes = keras.layers.Input(shape=image_shape)

    eyes_rescaled = keras.layers.Rescaling(scale=1./255)(input_eyes)

    # Previously tried alternatives (RedNet, EfficientNet, MiT and a plain
    # Flatten+Dense stack) are not wired in here.
    if BACKBONE == "densenet":
        backbone = create_dense_net_backbone()
    else:
        backbone = create_vit_backbone()

    backbone_encoder = backbone(eyes_rescaled)
    flatten_compress = keras.layers.Flatten()(backbone_encoder)
    eye_embedding = keras.layers.Dense(units=EMBEDDING_DIM, activation="tanh")(flatten_compress)

    embedding_model = keras.Model(inputs=input_eyes, outputs=eye_embedding, name="Eye_Image_Embedding")

    return embedding_model

## Full Model

# Model Training


## Augmentation

In [None]:
def get_translation_matrix(tx, ty):
    """Return the 3x3 float32 matrix that translates by (tx, ty)."""
    rows = [
        [1., 0., tx],
        [0., 1., ty],
        [0., 0., 1.],
    ]
    return ops.convert_to_tensor(rows, dtype="float32")

def get_rotation_matrix(angle, center_x, center_y):
    """Return the 3x3 float32 matrix that rotates by `angle` about (center_x, center_y)."""
    cx = ops.cast(center_x, "float32")
    cy = ops.cast(center_y, "float32")
    cosine, sine = ops.cos(angle), ops.sin(angle)
    rotation = ops.convert_to_tensor(
        [
            [cosine, -sine, 0.],
            [sine, cosine, 0.],
            [0., 0., 1.],
        ],
        dtype="float32",
    )
    # Move the centre to the origin, rotate, then move it back
    to_origin = get_translation_matrix(-cx, -cy)
    from_origin = get_translation_matrix(cx, cy)
    return from_origin @ rotation @ to_origin

def get_zoom_matrix(zx, zy, center_x, center_y):
    """Return the 3x3 float32 matrix that scales by (zx, zy) about (center_x, center_y)."""
    cx = ops.cast(center_x, "float32")
    cy = ops.cast(center_y, "float32")
    scaling = ops.convert_to_tensor(
        [
            [zx, 0., 0.],
            [0., zy, 0.],
            [0., 0., 1.],
        ],
        dtype="float32",
    )
    # Move the centre to the origin, scale, then move it back
    to_origin = get_translation_matrix(-cx, -cy)
    from_origin = get_translation_matrix(cx, cy)
    return from_origin @ scaling @ to_origin

def get_flip_matrix(horizontal, center_x, center_y):
    """Return a 3x3 homogeneous matrix for an optional horizontal flip.

    When `horizontal` is falsy the identity matrix is returned. Otherwise the
    x axis is mirrored about the pivot (center_x, center_y).
    """
    # Guard clause kept in this exact form: the caller may pass a boolean
    # tensor, in which case the branch is resolved by the tracing machinery.
    if not horizontal:
        return ops.eye(3, dtype="float32")

    pivot_x = ops.cast(center_x, "float32")
    pivot_y = ops.cast(center_y, "float32")

    to_origin = get_translation_matrix(-pivot_x, -pivot_y)
    from_origin = get_translation_matrix(pivot_x, pivot_y)

    mirror_x = ops.convert_to_tensor(
        [
            [-1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
        ],
        dtype="float32",
    )
    return from_origin @ mirror_x @ to_origin

def matrix_to_affine_params(matrix):
    """Flatten a 3x3 transform into its first 8 entries in row-major order.

    The bottom-right element matrix[2, 2] is dropped; the result is a 1-D
    tensor [m00, m01, m02, m10, m11, m12, m20, m21].
    """
    # Flat index i maps to (row, col) = (i // 3, i % 3); stop before index 8.
    entries = [matrix[i // 3, i % 3] for i in range(8)]
    return ops.stack(entries)

In [None]:
def apply_affine_augmentation(all_inputs, targets, mask):
    """
    Apply one random geometric augmentation to every frame of a sequence.

    A single set of random parameters is drawn per call, so all frames of the
    sequence receive the same transform. The transform is the composition of a
    horizontal flip (p=0.5), a zoom (0.9-1.1 per axis), a rotation
    (+/- 0.05*pi radians) and a translation (+/- 10% of height/width), followed
    by a random perspective warp (each corner moved by up to 10%).

    Args:
        all_inputs: Image sequence tensor. Axis 0 is treated as the frame
            axis, axis 1 as height and axis 2 as width.
        targets: Unused; accepted so the signature matches the dataset element.
        mask: Unused; accepted so the signature matches the dataset element.

    Returns:
        The augmented image sequence tensor.
    """
    sequence = all_inputs

    sequence_shape = keras.ops.shape(sequence)
    img_height_f = ops.cast(sequence_shape[1], "float32")
    img_width_f = ops.cast(sequence_shape[2], "float32")
    center_x = img_width_f / 2.0
    center_y = img_height_f / 2.0

    # One stateful draw gives each call its own base seed; every parameter
    # below is then derived from it with a distinct offset.
    base_seed = tf.random.uniform([2], minval=0, maxval=2147483647, dtype="int32")

    def draw_uniform(offset, minval, maxval, shape=()):
        """Draw uniform values from the stream `base_seed + [0, offset]`."""
        return tf.random.stateless_uniform(
            shape=shape, seed=base_seed + [0, offset], minval=minval, maxval=maxval
        )

    # 1. Flip decision
    apply_flip = draw_uniform(1, 0.0, 1.0) < 0.5

    # 2. Rotation angle (radians)
    rotation_factor = 0.05
    max_angle = rotation_factor * math.pi
    rotation_angle = draw_uniform(2, -max_angle, max_angle)

    # 3. Zoom factors
    zoom_factor_range = (0.9, 1.1)
    zoom_height_factor = draw_uniform(3, zoom_factor_range[0], zoom_factor_range[1])
    zoom_width_factor = draw_uniform(4, zoom_factor_range[0], zoom_factor_range[1])

    # 4. Translation shifts, drawn as a fraction of the image size
    translation_height_factor = 0.1
    translation_width_factor = 0.1
    shift_height = draw_uniform(5, -translation_height_factor, translation_height_factor) * img_height_f
    shift_width = draw_uniform(6, -translation_width_factor, translation_width_factor) * img_width_f

    # --- Build the combined affine matrix ---
    flip_matrix = get_flip_matrix(apply_flip, center_x, center_y)
    zoom_matrix = get_zoom_matrix(zoom_width_factor, zoom_height_factor, center_x, center_y)
    rotation_matrix = get_rotation_matrix(rotation_angle, center_x, center_y)
    translation_matrix = get_translation_matrix(shift_width, shift_height)
    # All four transforms are composed; the result is what gets applied.
    combined_matrix = translation_matrix @ rotation_matrix @ zoom_matrix @ flip_matrix

    # keras.ops.image.affine_transform takes the first 8 matrix entries.
    affine_params = matrix_to_affine_params(combined_matrix)

    augmented_sequence = keras.ops.image.affine_transform(
        sequence,
        affine_params,
        interpolation="bilinear",
        fill_mode="constant"
    )

    # === PERSPECTIVE AUGMENTATION ===

    # Maximum corner displacement as a fraction of image width/height.
    perspective_factor = 0.1

    # Corners of the original image.
    # Order: top-left, bottom-left, top-right, bottom-right.
    start_points = ops.convert_to_tensor([
        [0.0, 0.0],
        [0.0, img_height_f],
        [img_width_f, 0.0],
        [img_width_f, img_height_f]
    ], dtype="float32")  # Shape: (4, 2)

    max_shift_x = perspective_factor * img_width_f
    max_shift_y = perspective_factor * img_height_f

    # Normalized (dx, dy) shift in [-1, 1) for each of the 4 corners.
    random_shifts = draw_uniform(7, -1.0, 1.0, shape=(4, 2))

    # Scale to pixels; the (1, 2) factor broadcasts over the 4 corners.
    scaled_shifts = random_shifts * ops.convert_to_tensor([[max_shift_x, max_shift_y]], dtype="float32")

    end_points = start_points + scaled_shifts  # Shape (4, 2)

    # Repeat the same corner sets once per frame so every frame gets the
    # identical warp.
    start_points_tiled = ops.tile(ops.expand_dims(start_points, axis=0), [sequence_shape[0], 1, 1])
    end_points_tiled = ops.tile(ops.expand_dims(end_points, axis=0), [sequence_shape[0], 1, 1])

    # The perspective warp is applied AFTER the affine transform.
    augmented_sequence = keras.ops.image.perspective_transform(
        augmented_sequence,
        start_points_tiled,   # Shape (frames, 4, 2)
        end_points_tiled,     # Shape (frames, 4, 2)
        interpolation="bilinear",
        fill_value=0.0        # Value for pixels outside the warped image
    )

    return augmented_sequence

In [None]:
def _gaussian_blur(frames, sigma):
    """Blur a float (N, H, W, C) batch with a Gaussian kernel of std `sigma`.

    The kernel covers about three sigmas on each side and is capped at 15 taps.
    NOTE(review): the kernel is built with a single input channel, so this
    presumably requires C == 1 - confirm before using multi-channel images.
    """
    kernel_size = ops.cast(ops.ceil(sigma * 3) * 2 + 1, "int32")
    kernel_size = ops.minimum(kernel_size, 15)  # Limit kernel size

    # Integer offsets centred on zero (kernel_size is always odd).
    offsets = ops.arange(ops.cast(-kernel_size // 2 + 1, "float32"),
                         ops.cast(kernel_size // 2 + 1, "float32"))
    kernel_1d = ops.exp(-0.5 * ops.square(offsets / sigma))
    kernel_1d = kernel_1d / ops.sum(kernel_1d)

    # Outer product gives the separable 2D kernel; add (in, multiplier) axes.
    kernel_2d = ops.tensordot(kernel_1d, kernel_1d, axes=0)
    kernel_2d = ops.expand_dims(ops.expand_dims(kernel_2d, -1), -1)

    return tf.nn.depthwise_conv2d(
        frames, kernel_2d, strides=[1, 1, 1, 1], padding='SAME'
    )


def _adjust_tone(frames, brightness_delta, contrast_factor, gamma_value):
    """Apply brightness, contrast and gamma to float pixels on a 0-255 scale."""
    frames = frames + brightness_delta

    # Contrast is scaled around 127.5, the middle of the uint8 range.
    frames = (frames - 127.5) * contrast_factor + 127.5

    # Gamma works on [0, 1]; the floor keeps the base of the power positive.
    normalized = ops.power(ops.maximum(frames / 255.0, 1e-8), gamma_value)
    return normalized * 255.0


def apply_non_affine_augmentations(sequence, per_image=False):
    """
    Apply non-affine augmentations (blur, brightness, contrast, gamma, noise) to a sequence.

    Blur and noise are each applied with probability BLUR_PROB / NOISE_PROB;
    brightness, contrast and gamma are always applied. Processing is done in
    float32 and the result is clipped to [0, 255] and cast back to uint8.

    Args:
        sequence: Input image sequence tensor of shape (timesteps, H, W, C) - uint8 [0,255]
        per_image: If True, each image gets different augmentation; if False, same for all

    Returns:
        Augmented sequence tensor - uint8 [0,255]
    """
    timesteps = keras.ops.shape(sequence)[0]

    # One stateful draw gives each call its own base seed; every parameter
    # below is derived from it with a distinct offset.
    base_seed = tf.random.uniform([2], minval=0, maxval=2147483647, dtype="int32")

    # One value per frame when augmenting per image, otherwise one scalar
    # shared by the whole sequence.
    param_shape = (timesteps,) if per_image else ()

    def draw(seed, minval, maxval):
        """Draw uniform augmentation parameter(s) from the given seed."""
        return tf.random.stateless_uniform(
            shape=param_shape, seed=seed, minval=minval, maxval=maxval
        )

    blur_seed = base_seed + [1, 0]
    noise_seed = base_seed + [5, 0]

    blur_prob = draw(blur_seed, 0.0, 1.0)
    blur_sigma = draw(blur_seed + [0, 1], BLUR_SIGMA_RANGE[0], BLUR_SIGMA_RANGE[1])
    # Brightness shift on the 0-255 scale (roughly +/- 10% of the range).
    brightness_delta = draw(base_seed + [2, 0], -25.0, 25.0)
    contrast_factor = draw(base_seed + [3, 0], 0.9, 1.1)
    gamma_value = draw(base_seed + [4, 0], 0.9, 1.1)
    noise_prob = draw(noise_seed, 0.0, 1.0)
    # Noise standard deviation on the 0-255 scale.
    noise_std = draw(noise_seed + [0, 1], 1.0, 5.0)

    augmented_sequence = keras.ops.cast(sequence, "float32")

    if per_image:
        # Apply a different augmentation to each image in the sequence.
        augmented_images = []
        for t in range(timesteps):
            img = augmented_sequence[t]

            if blur_prob[t] < BLUR_PROB:
                # The blur helper expects a batch axis; add and remove it.
                img = ops.squeeze(
                    _gaussian_blur(ops.expand_dims(img, 0), blur_sigma[t]), 0
                )

            img = _adjust_tone(
                img, brightness_delta[t], contrast_factor[t], gamma_value[t]
            )

            if noise_prob[t] < NOISE_PROB:
                # Offset [t, 2] never coincides with the seeds used for
                # noise_prob ([0, 0]) or noise_std ([0, 1]).
                noise = tf.random.stateless_normal(
                    shape=ops.shape(img),
                    seed=noise_seed + [t, 2],
                    stddev=noise_std[t]
                )
                img = img + noise

            augmented_images.append(ops.clip(img, 0.0, 255.0))

        augmented_sequence = ops.stack(augmented_images, axis=0)

    else:
        # Apply the same augmentation to all images in the sequence.
        if blur_prob < BLUR_PROB:
            augmented_sequence = _gaussian_blur(augmented_sequence, blur_sigma)

        augmented_sequence = _adjust_tone(
            augmented_sequence, brightness_delta, contrast_factor, gamma_value
        )

        if noise_prob < NOISE_PROB:
            noise = tf.random.stateless_normal(
                shape=ops.shape(augmented_sequence),
                seed=noise_seed + [0, 2],
                stddev=noise_std
            )
            augmented_sequence = augmented_sequence + noise

        # Clip to valid uint8 range
        augmented_sequence = ops.clip(augmented_sequence, 0.0, 255.0)

    # Convert back to uint8 to save memory
    augmented_sequence = keras.ops.cast(augmented_sequence, "uint8")

    return augmented_sequence

In [None]:
def augment_sequence_complete(all_inputs, targets, mask):
    """
    Run the full augmentation pipeline on one dataset element.

    The image sequence is first passed through the non-affine augmentations
    (blur, brightness, contrast, gamma, noise), applied per image or per
    sequence according to PER_IMAGE_AUGMENTATION, and then through the
    geometric augmentation, which is shared by all frames. Coordinates, the
    calibration mask, targets and the sample mask are passed through untouched.

    Returns:
        A tuple (inputs, targets, mask) where inputs holds the augmented
        images under "Input_All_Images" plus the original
        "Input_All_Coords" and "Input_Calibration_Mask" entries.
    """
    frames = all_inputs["Input_All_Images"]
    coords = all_inputs["Input_All_Coords"]
    calibration_mask = all_inputs["Input_Calibration_Mask"]

    # Photometric changes first, geometric warp second.
    frames = apply_non_affine_augmentations(frames, per_image=PER_IMAGE_AUGMENTATION)
    frames = apply_affine_augmentation(frames, targets, mask)

    return (
        {
            "Input_All_Images": frames,
            "Input_All_Coords": coords,
            "Input_Calibration_Mask": calibration_mask,
        },
        targets,
        mask,
    )

In [None]:
# Default to the un-augmented pipeline; swap in the augmented one on request.
train_ds_for_model_augmented = train_ds_for_model
if AUGMENTATION:
    train_ds_for_model_augmented = (
        train_ds_for_model
        .map(augment_sequence_complete, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )

### Visualizing augmentation

In [None]:
def plot_sequence_grid(sequence, title="Image Sequence Timesteps", figsize=None, cmap='gray'):
    """
    Plots all frames (timesteps) of an image sequence in a grid layout.

    This helps visualize if the same augmentation was applied consistently
    across the time dimension.

    Args:
        sequence: A NumPy array, or any object exposing a `.numpy()` method
                  (e.g. a TensorFlow tensor), holding the image sequence.
                  Expected shape: (timesteps, H, W, C).
        title (str): The overall title for the plot.
        figsize (tuple, optional): Desired figure size (width, height) in inches.
                                   If None, a default size is calculated based
                                   on the grid dimensions.
        cmap (str): The colormap to use for grayscale images (when C=1).
                    Defaults to 'gray'. Ignored for RGB images (C=3).

    Raises:
        TypeError: If `sequence` is not (convertible to) a NumPy array.
        ValueError: If `sequence` is not 4-dimensional.
    """
    # --- Input Validation and Conversion ---
    if hasattr(sequence, 'numpy'):
        sequence = sequence.numpy()

    if not isinstance(sequence, np.ndarray):
        raise TypeError("Input 'sequence' must be a NumPy array or TensorFlow tensor.")

    if sequence.ndim != 4:
        raise ValueError(f"Input 'sequence' must be 4-dimensional (timesteps, H, W, C), got shape {sequence.shape}")

    num_timesteps, _, _, channels = sequence.shape

    if num_timesteps == 0:
        print("Sequence has 0 timesteps, nothing to plot.")
        return

    # --- Determine Grid Size ---
    # Roughly square grid: ceil(sqrt(T)) columns, as many rows as needed.
    ncols = int(math.ceil(math.sqrt(num_timesteps)))
    nrows = int(math.ceil(num_timesteps / float(ncols)))

    # --- Determine Figure Size ---
    if figsize is None:
        # Scale with the grid; tweak the multipliers for other aspect ratios.
        figsize = (ncols * 2.5, nrows * 2)

    # --- Create Subplots ---
    # squeeze=False always yields a 2D array of Axes, so the 1x1 grid needs
    # no special handling.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes_flat = axes.ravel()

    # --- Prepare Data for Plotting ---
    plot_cmap = None
    if channels == 1:
        plot_cmap = cmap
        # imshow expects (H, W) for grayscale, so drop the channel axis.
        sequence_to_plot = np.squeeze(sequence, axis=-1)
    elif channels == 3:
        sequence_to_plot = sequence
        # imshow expects float RGB data in [0, 1].
        if np.issubdtype(sequence_to_plot.dtype, np.floating):
            sequence_to_plot = np.clip(sequence_to_plot, 0, 1)
    else:
        print(f"Warning: Sequence has {channels} channels. Plotting the first channel as grayscale.")
        plot_cmap = cmap
        sequence_to_plot = sequence[..., 0]

    # A shared value range keeps grayscale frames comparable across timesteps.
    # It is not passed for RGB data, where imshow does not use it.
    vmin = vmax = None
    if plot_cmap is not None:
        vmin = np.min(sequence_to_plot)
        vmax = np.max(sequence_to_plot)

    # --- Plot Each Timestep ---
    for t, (ax, frame) in enumerate(zip(axes_flat, sequence_to_plot)):
        ax.imshow(frame, cmap=plot_cmap, vmin=vmin, vmax=vmax)
        ax.set_title(f"t={t}")
        ax.axis('off')

    # --- Hide Unused Subplots ---
    for ax in axes_flat[num_timesteps:]:
        ax.axis('off')

    # --- Final Touches ---
    fig.suptitle(title, fontsize=16)
    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [None]:
# Plot the frames of the first training element as a visual check of the
# augmentation. element[0] is the inputs dict of the dataset tuple.
for element in train_ds_for_model_augmented.take(1):
  plot_sequence_grid(element[0]["Input_All_Images"])

## Run the model

In [None]:
# Cosine decay with a linear warmup. Step counts are derived from whole
# batches per epoch (INDIVIDUALS // BATCH_SIZE) times the number of epochs.
learning_rate_scheduler = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=INITIAL_LEARNING_RATE,
    decay_steps=(INDIVIDUALS // BATCH_SIZE) * DECAY_EPOCHS,
    alpha=DECAY_ALPHA,
    warmup_target=LEARNING_RATE,
    warmup_steps=(INDIVIDUALS // BATCH_SIZE) * WARMUP_EPOCHS,
)

In [None]:
# Build the model and compile it with AdamW driven by the cosine schedule
# and the project's normalized weighted Euclidean distance loss.
mask_model = create_masked_model()
mask_model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=learning_rate_scheduler),
    loss=normalized_weighted_euc_dist,
    jit_compile=True,
)
mask_model.summary()

In [None]:
# Train on batches of BATCH_SIZE elements; metrics are logged to W&B through
# the WandbMetricsLogger callback.
mask_model.fit(train_ds_for_model_augmented.batch(BATCH_SIZE), epochs=TRAIN_EPOCHS, callbacks=[WandbMetricsLogger()])

# Save and export

In [None]:
# Save the complete model to a single .keras file.
mask_model.save('full_model.keras')

In [None]:
# Also save the weights on their own, separately from the full model file.
mask_model.save_weights('full_model.weights.h5')

In [None]:
# Attach both saved files to the current W&B run.
wandb.save('full_model.keras')
wandb.save('full_model.weights.h5')

In [None]:
# Build the dataset used for the final evaluation below.
# NOTE(review): this is derived from train_data_rescaled, so the "test"
# metrics computed later appear to be measured on training data - confirm
# whether a held-out split was intended.
masked_dataset = prepare_masked_dataset(train_data_rescaled, calibration_points=scaled_cal_points)

In [None]:
# Convert each element into the model's input structure. No augmentation is
# applied on this path.
test_ds = masked_dataset.map(
    prepare_model_inputs,
    num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

In [None]:
# Display the structure, shapes and dtypes of one dataset element.
test_ds.element_spec

In [None]:
# Predict one dataset element per batch.
predictions = mask_model.predict(test_ds.batch(1))

# Gather the targets (element index 1) in dataset order so that row i of
# y_true lines up with row i of predictions.
y_true = np.zeros(predictions.shape, dtype=np.float32)
for row, element in enumerate(test_ds.as_numpy_iterator()):
  y_true[row, :, :] = element[1]

# Loss values for every element, one row per element.
batch_losses = np.zeros(predictions.shape[:2])
for row, (truth, predicted) in enumerate(zip(y_true, predictions)):
  batch_losses[row, :] = normalized_weighted_euc_dist(truth, predicted).numpy()

# Get mean per subject
batch_losses = np.mean(batch_losses, axis=1)

In [None]:
# Wrap the per-element mean losses in a one-column W&B table and build a
# histogram plot from it.
table = wandb.Table(data=[[s] for s in batch_losses], columns=["Loss"])
final_loss_hist = wandb.plot.histogram(table, value="Loss", title="Normalized Euclidean Distance")

In [None]:
# Overall mean of the per-element mean losses.
final_loss_mean = np.mean(batch_losses)

In [None]:
# Log the histogram and the overall mean to the W&B run.
# NOTE(review): the key says "val", but the data comes from test_ds - confirm
# the intended naming.
wandb.log({"final_val_loss_hist": final_loss_hist, "final_loss_mean": final_loss_mean})

In [None]:
# Local matplotlib histogram of the per-element mean losses (batch_losses
# holds one mean value per dataset element at this point).
plt.figure(figsize=(8, 6))
plt.hist(batch_losses, bins=20, edgecolor='black')
plt.title('Histogram of Batch Losses')
plt.xlabel('Loss')
plt.ylabel('Frequency')
plt.show()

In [None]:
# Mark the W&B run as finished.
wandb.finish()

In [None]:
# Release the Colab runtime once everything above has completed.
from google.colab import runtime
runtime.unassign()