# Setup

In [None]:
!pip install osfclient --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet
!pip install keras_cv --quiet
!pip install plotnine --quiet
!pip install wandb --quiet
!pip install albumentations --quiet

In [None]:
import os

# Keras reads KERAS_BACKEND when it is first imported, so set it before any
# TensorFlow / Keras import.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import numpy as np
import tensorflow.keras as keras
import keras_cv
from plotnine import ggplot, geom_point, aes, geom_line, scale_y_reverse, theme_void, scale_color_manual
import pandas as pd
import wandb
from wandb.integration.keras import WandbMetricsLogger
import albumentations as A

import matplotlib.pyplot as plt
from google.colab import userdata

import et_util.dataset_utils as dataset_utils
import et_util.embedding_preprocessing as embed_pre
import et_util.model_layers as model_layers
from et_util import experiment_utils
from et_util.custom_loss import normalized_weighted_euc_dist
from et_util.model_analysis import plot_model_performance

In [None]:
# Export Colab secrets as environment variables (read by wandb.login and,
# presumably, by the osf CLI used to download the data).
for env_var, secret_name in (
    ('WANDB_API_KEY', 'WANDB_API_KEY'),
    ('OSF_TOKEN', 'osftoken'),
    ('OSF_USERNAME', 'osfusername'),
):
    os.environ[env_var] = userdata.get(secret_name)

In [None]:
# Global Keras mixed precision: float16 computation with float32 variables.
# The ridge-regression layer casts its inputs back to float32 for its linear solve.
keras.mixed_precision.set_global_policy('mixed_float16')

# Configure W&B experiment

In [None]:
# Authenticate with W&B (uses the WANDB_API_KEY environment variable).
wandb.login()

In [None]:
# Fixed constants
MAX_TARGETS = 144  # targets per example (16 x 9 screen grid); padded length of the target inputs

# Config constants (logged to W&B through `config`)
EMBEDDING_DIM = 200           # length of each eye-image embedding
RIDGE_REGULARIZATION = 1e-1   # ridge penalty (lambda) used by the regression layer
TRAIN_EPOCHS = 30
MIN_CAL_POINTS = 8            # lower bound on the random calibration-set size
MAX_CAL_POINTS = 40           # upper bound on that size; also the padded calibration length
DENSE_NET_STACKWISE_NUM_REPEATS = [4] * 3
LEARNING_RATE = 1e-3
AUGMENTATION = True

In [None]:
# Hyperparameters attached to the W&B run; keys mirror the constants above.
config = dict(
    embedding_dim=EMBEDDING_DIM,
    ridge_regularization=RIDGE_REGULARIZATION,
    train_epochs=TRAIN_EPOCHS,
    min_cal_points=MIN_CAL_POINTS,
    max_cal_points=MAX_CAL_POINTS,
    dense_net_stackwise_num_repeats=DENSE_NET_STACKWISE_NUM_REPEATS,
    learning_rate=LEARNING_RATE,
    augmentation=AUGMENTATION,
)

In [None]:
# Start a W&B run in this project; `config` is stored as the run's hyperparameters.
run = wandb.init(
    project='eye-tracking-dense-full-data-set-single-eye',
    config=config
)

# Download dataset from OSF

In [None]:
# Download the single-eye TFRecord archive from OSF project 6b5cd.
# NOTE(review): authentication presumably comes from the OSF_USERNAME / OSF_TOKEN
# environment variables exported earlier — confirm osfclient reads them.
!osf -p 6b5cd fetch single_eye_tfrecords.tar.gz

# Process raw data records into TF Dataset

In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:
def parse(element):
    """Parse one serialized example from the single-eye TFRecords.

    Used as the parse function passed to ``dataset_utils.process_tfr_to_tfds``
    for the records extracted from ``single_eye_tfrecords.tar.gz``.

    :param element: serialized ``tf.train.Example`` from the raw TFRecord dataset
    :return: ``(image, landmarks, label, subject_id)`` where ``image`` is the
        decoded uint8 eye-image tensor (reshaped to 36x144x1 downstream),
        ``landmarks`` is a ``(478, 3)`` float32 tensor of MediaPipe landmarks,
        ``label`` is ``[x, y]`` in 0-100 screen coordinates (rescaled to 0-1
        later) and ``subject_id`` is an int64 scalar.
    """
    # Every feature is a required FixedLenFeature. img_width / img_height are
    # still parsed (so malformed records fail), but they are not used below.
    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    label = [content['x'], content['y']]
    subject_id = content['subject_id']

    landmarks = tf.io.parse_tensor(content['landmarks'], out_type=tf.float32)
    landmarks = tf.reshape(landmarks, shape=(478, 3))

    image = tf.io.parse_tensor(content['eye_img'], out_type=tf.uint8)

    return image, landmarks, label, subject_id

In [None]:
# Parse every TFRecord in the extracted directory. All data goes to the training
# split (val/test fractions are 0), so validation_data / test_data are presumably
# empty. group_function returns the subject id (4th element of `parse`'s output),
# presumably so the split is made per subject — confirm in et_util.
train_data, validation_data, test_data = dataset_utils.process_tfr_to_tfds(
    'single_eye_tfrecords/',
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, landmarks, coords, z: z
)

## Rescale the `x,y` coordinates to be 0-1 instead of 0-100.

In [None]:
def rescale_coords_map(eyes, mesh, coords, id):
  """Convert the gaze label from the 0-100 range to 0-1; other fields pass through."""
  percent = tf.constant([100.])
  unit_coords = tf.divide(coords, percent)
  return eyes, mesh, unit_coords, id

In [None]:
# Apply the 0-100 -> 0-1 label rescaling to each split.
# NOTE(review): validation_data_rescaled is not used later in this notebook.
train_data_rescaled = train_data.map(rescale_coords_map)
validation_data_rescaled = validation_data.map(rescale_coords_map)

# Generate a dataset of calibration points, target points, and target outputs

In [None]:
def pad_to_fixed_size(tensor, target_shape, pad_value=0):
    """Pad a rank-4 tensor at the end of each axis up to ``target_shape``.

    Padding per axis is clamped at zero, so an axis that is already long enough
    is not padded. The result is then checked with ``tf.ensure_shape``, so an
    axis longer than its target raises at runtime.

    :param tensor: rank-4 tensor to pad
    :param target_shape: length-4 sequence of target sizes (Python ints)
    :param pad_value: value written into the padded elements (default 0)
    :return: tensor of shape ``target_shape``
    """
    current_shape = tf.shape(tensor)
    paddings = tf.maximum(0, target_shape - current_shape)
    padded = tf.pad(
        tensor,
        [[0, paddings[0]], [0, paddings[1]], [0, paddings[2]], [0, paddings[3]]],
        constant_values=pad_value,  # previously ignored; default 0 keeps old behavior
    )
    # Ensure fixed shape
    padded = tf.ensure_shape(padded, [target_shape[0], target_shape[1], target_shape[2], target_shape[3]])
    return padded

def create_padding_mask(tensor, max_len):
    """Return a float32 vector of length ``max_len``.

    Positions below ``tensor``'s leading-axis length (real elements) are 1.0;
    the remaining positions (padding) are 0.0.
    """
    n_real = tf.shape(tensor)[0]
    positions = tf.range(max_len)
    return tf.cast(tf.less(positions, n_real), tf.float32)

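The 144 target locations (`MAX_TARGETS`) form a 16 × 9 grid in percent-of-screen coordinates:

- Columns (x): 5, 11, 17, 23, 29, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95
- Rows (y): 5, 16.25, 27.5, 38.75, 50, 61.25, 72.5, 83.75, 95

Fixed points used as calibration — a 4 × 5 subset of the grid with x ∈ {5, 35, 65, 95} and y ∈ {5, 27.5, 50, 72.5, 95}: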

In [None]:
# Fixed 4 x 5 calibration grid in percent-of-screen units, listed x-major:
# x in (5, 35, 65, 95), y in (5, 27.5, 50, 72.5, 95) -> 20 points.
cal_points = tf.constant(
    [[x, y] for x in (5, 35, 65, 95) for y in (5, 27.5, 50, 72.5, 95)],
    dtype=tf.float32,
)

# Same points on the 0-1 scale used by the rescaled labels.
scaled_cal_points = tf.divide(cal_points, tf.constant([100.]))

def filter_cal_points(image, mesh, coords, id):
  """True when the sample's (rescaled) coords exactly equal one of ``scaled_cal_points``.

  Relies on exact float equality between the label and the calibration grid.
  """
  matches_each_point = tf.reduce_all(tf.equal(coords, scaled_cal_points), axis=1)
  return tf.reduce_any(matches_each_point)

def filter_non_cal_points(image, mesh, coords, id):
  """True when the sample's coords match none of ``scaled_cal_points`` (exact comparison)."""
  # De Morgan: all_points(any_coord_differs) == not any_points(all_coords_equal)
  is_cal_point = tf.reduce_any(tf.reduce_all(tf.equal(coords, scaled_cal_points), axis=1))
  return tf.logical_not(is_cal_point)

def filter_subjects_missing_cal_points(cal, target):
  """True when the calibration batch has exactly as many images as there are fixed calibration points."""
  n_cal_imgs = tf.shape(cal[0])[0]
  n_expected = tf.shape(cal_points)[0]
  return tf.equal(n_cal_imgs, n_expected)

def map_for_calibration_pts(image, mesh, coords, id):
  """Reshape a batch of calibration eye images to (N, 36, 144, 1); drop mesh and id.

  Padding to a fixed length is done later, in the merge step.
  """
  eye_batch = tf.reshape(image, (-1, 36, 144, 1))
  return eye_batch, coords

def map_for_non_calibration_pts(image, mesh, coords, id):
  """Reshape a batch of target eye images to (N, 36, 144, 1); drop mesh and id.

  Padding to a fixed length is done later, in the merge step.
  """
  eye_batch = tf.reshape(image, (-1, 36, 144, 1))
  return eye_batch, coords

def map_for_merged(cal, non_cal):
  """Zero-pad the calibration and target sets to fixed lengths and build model features.

  :param cal: ``(cal_imgs, cal_coords)`` with ``cal_imgs`` of shape (n_cal, 36, 144, 1)
  :param non_cal: ``(target_imgs, target_coords)`` with ``target_imgs`` of shape (n_target, 36, 144, 1)
  :return: ``((cal_imgs, cal_coords, cal_mask, target_imgs, target_mask), target_coords)``
      with calibration tensors padded to MAX_CAL_POINTS rows and target tensors to
      MAX_TARGETS rows; masks are 1.0 for real rows and 0.0 for padding.
  """
  cal_imgs, cal_coords = cal
  target_imgs, target_coords = non_cal

  def image_paddings(n):
    # Pad only the leading (point) axis of a (N, 36, 144, 1) tensor.
    return [[0, n], [0, 0], [0, 0], [0, 0]]

  def coord_paddings(n):
    # Pad only the leading (point) axis of a (N, 2) tensor.
    return [[0, n], [0, 0]]

  n_cal_pad = MAX_CAL_POINTS - tf.shape(cal_imgs)[0]
  n_target_pad = MAX_TARGETS - tf.shape(target_imgs)[0]

  features = (
      tf.ensure_shape(tf.pad(cal_imgs, image_paddings(n_cal_pad)), [MAX_CAL_POINTS, 36, 144, 1]),
      tf.ensure_shape(tf.pad(cal_coords, coord_paddings(n_cal_pad)), [MAX_CAL_POINTS, 2]),
      create_padding_mask(cal_imgs, MAX_CAL_POINTS),
      tf.ensure_shape(tf.pad(target_imgs, image_paddings(n_target_pad)), [MAX_TARGETS, 36, 144, 1]),
      create_padding_mask(target_imgs, MAX_TARGETS),
  )
  padded_target_coords = tf.ensure_shape(
      tf.pad(target_coords, coord_paddings(n_target_pad)), [MAX_TARGETS, 2])

  return features, padded_target_coords

def map_for_merged_with_id(cal, non_cal, id):
  """Same as ``map_for_merged`` but also returns the subject id.

  Delegates to ``map_for_merged`` so the padding/masking logic exists in one place
  (this function previously duplicated it line for line).

  :param cal: ``(cal_imgs, cal_coords)`` calibration batch
  :param non_cal: ``(target_imgs, target_coords)`` target batch
  :param id: subject id passed through unchanged
  :return: ``((cal_imgs, cal_coords, cal_mask, target_imgs, target_mask), target_coords, id)``
  """
  features, padded_target_coords = map_for_merged(cal, non_cal)
  return features, padded_target_coords, id

def reducer_function(k, ds):
  """Turn one subject window into training examples with a random calibration set.

  A random number of calibration points, drawn from [MIN_CAL_POINTS, MAX_CAL_POINTS]
  (both inclusive), is taken from a shuffled copy of the window and repeated so it
  pairs with every batch of up to MAX_TARGETS targets from the window. The
  calibration points are also present among the targets.

  :param k: window key (subject id); unused
  :param ds: dataset of ``(image, mesh, coords, id)`` elements for one subject
  :return: dataset of ``(features, target_coords)`` pairs built by ``map_for_merged``
  """
  # NOTE(review): the shuffle buffer (MAX_TARGETS) is smaller than the window size
  # used with group_by_window (200), so larger windows are only partially shuffled.
  ds_random = ds.shuffle(MAX_TARGETS)

  # tf.random.uniform excludes maxval, so add 1 to make MAX_CAL_POINTS reachable.
  # MAX_CAL_POINTS is also the padded calibration length, so a full set still fits.
  n_cal_points = tf.random.uniform(
      shape=[], minval=MIN_CAL_POINTS, maxval=MAX_CAL_POINTS + 1, dtype=tf.int64)

  calibration_points = ds_random.take(n_cal_points).batch(n_cal_points).map(map_for_calibration_pts).repeat()

  non_calibration_points = ds.batch(MAX_TARGETS).map(map_for_non_calibration_pts)

  merged = tf.data.Dataset.zip(calibration_points, non_calibration_points)
  return merged.map(map_for_merged, num_parallel_calls=tf.data.AUTOTUNE)

def reducer_function_fixed_pts(subject_id, ds):
  """Build examples for one subject window using the fixed calibration grid.

  Targets are full batches of MAX_TARGETS samples (any remainder is dropped). The
  calibration set is the subject's samples lying on ``scaled_cal_points``, also in
  full batches of ``len(cal_points)`` only, repeated so it pairs with every target batch.
  """
  n_fixed = len(cal_points)
  calibration = (
      ds.filter(filter_cal_points)
      .batch(n_fixed, drop_remainder=True)
      .map(map_for_calibration_pts)
      .repeat()
  )
  targets = ds.batch(MAX_TARGETS, drop_remainder=True).map(map_for_non_calibration_pts)
  return tf.data.Dataset.zip(calibration, targets).map(map_for_merged)

def reducer_function_fixed_pts_with_id(subject_id, ds):
  """Like ``reducer_function_fixed_pts`` but keeps the subject id in each example.

  Unlike that variant, the calibration batch is NOT drop_remainder, so a subject
  with fewer matching samples than ``len(cal_points)`` gets a smaller (padded and
  masked) calibration set instead of no examples.
  """
  targets = ds.batch(MAX_TARGETS, drop_remainder=True).map(map_for_non_calibration_pts)

  calibration = (
      ds.filter(filter_cal_points)
      .batch(len(cal_points))
      .map(map_for_calibration_pts)
      .repeat()
  )

  paired = tf.data.Dataset.zip(calibration, targets)
  return paired.map(lambda cal_part, target_part: map_for_merged_with_id(cal_part, target_part, subject_id))

## Create augmentation pipeline

In [None]:
# Augmentations for the training eye images. Only a small random rotation
# (factor 0.02 of a full turn, zero-filled) is enabled; brightness, contrast
# and translation augmentations are currently disabled.
augmentation_layers = [
    keras.layers.RandomRotation(factor=0.02, fill_mode="constant", fill_value=0),
]

augmentation_model = keras.Sequential(augmentation_layers)

@tf.function
def apply_consistent_augmentations(inputs, augmentation_model):
    """Augment the calibration and target eye images of one example.

    The images of each set are stacked along the channel axis,
    (N, H, W, 1) -> (1, H, W, N), so one call of ``augmentation_model`` applies
    the same random transform to every image in that set. The calibration set and
    the target set are augmented by separate calls, so they may receive different
    transforms from each other.

    Args:
        inputs: Tuple of (cal_imgs, cal_coords, cal_mask, target_imgs, target_mask).
        augmentation_model: Keras model/layer applied to a batch of images.

    Returns:
        Tuple with the same structure as ``inputs``; the two image tensors are
        augmented, while coordinates and masks pass through unchanged.
    """
    cal_imgs, cal_coords, cal_mask, target_imgs, target_mask = inputs

    # Swap point and channel axes so each set becomes a single multi-channel image.
    merged_cal_images = tf.transpose(cal_imgs, perm=[3, 1, 2, 0])
    merged_target_images = tf.transpose(target_imgs, perm=[3, 1, 2, 0])

    # training=True: Keras random preprocessing layers may act as the identity
    # when not in training mode, so request the augmentation explicitly.
    merged_cal_imgs_aug = augmentation_model(merged_cal_images, training=True)
    merged_target_imgs_aug = augmentation_model(merged_target_images, training=True)

    # Undo the axis swap.
    cal_imgs_aug = tf.transpose(merged_cal_imgs_aug, perm=[3, 1, 2, 0])
    target_imgs_aug = tf.transpose(merged_target_imgs_aug, perm=[3, 1, 2, 0])

    return cal_imgs_aug, cal_coords, cal_mask, target_imgs_aug, target_mask

In [None]:
# Cache the rescaled samples once; both subject groupings below read from it.
t0 = train_data_rescaled.cache()


def _subject_key(img, m, c, z):
    """group_by_window key: the subject id (last element of each sample)."""
    return z


# Training examples: each subject window gets a randomly sized calibration set.
# window_size caps each window at 200 samples of the same subject.
t1 = t0.group_by_window(key_func=_subject_key, reduce_func=reducer_function, window_size=200)

# Evaluation examples: fixed calibration grid, with the subject id kept.
t2 = t0.group_by_window(
    key_func=_subject_key,
    reduce_func=reducer_function_fixed_pts_with_id,
    window_size=200,
)

# Augment the training images only when the AUGMENTATION config flag is set
# (the flag is also logged to W&B, so it must reflect what actually runs).
if AUGMENTATION:
    train_ds = t1.map(
        lambda x, y: (apply_consistent_augmentations(x, augmentation_model), y),
        num_parallel_calls=tf.data.AUTOTUNE
    )
else:
    train_ds = t1

# Convert the 5-tuple of features into the named-input dict the model expects.
def prepare_batch_for_model(features, labels):
    """Map ``(features, labels)`` to ``(inputs_dict, labels)``.

    ``features`` is ``(cal_imgs, cal_coords, cal_mask, target_imgs, target_mask)``;
    the dict keys match the ``Input`` layer names in ``create_full_model``.
    Labels pass through unchanged.
    """
    (calibration_eyes, calibration_points, calibration_mask,
     target_eyes, target_mask) = features

    model_inputs = dict(
        Input_Calibration_Eyes=calibration_eyes,
        Input_Calibration_Points=calibration_points,
        Input_Calibration_Mask=calibration_mask,
        Input_Target_Eyes=target_eyes,
        Input_Target_Mask=target_mask,
    )
    return model_inputs, labels

# Training pipeline in the model's named-input format, with prefetching.
train_ds_for_model = train_ds.map(prepare_batch_for_model)
train_ds_for_model = train_ds_for_model.prefetch(tf.data.AUTOTUNE)


In [None]:
# Inspect the structure, shapes and dtypes of each training element.
train_ds.element_spec

## Visualize augmentation

In [None]:
def visualize_augmentations(dataset, num_examples=5):
    """Plot every calibration image of the first element of ``dataset`` in a 5-column grid.

    Expects elements shaped like ``train_ds``: ``((cal_imgs, ...), labels)``.
    Zero-padded calibration slots are shown as well.

    :param dataset: tf.data.Dataset of ``(features, labels)`` pairs
    :param num_examples: currently unused; kept for backward compatibility
    :raises ValueError: if ``dataset`` yields no elements
    """
    images = None
    for element in dataset.take(1):
        images = element[0][0]  # calibration images of the first example
    if images is None:
        raise ValueError("dataset is empty; nothing to visualize")

    num_images = images.shape[0]
    grid_cols = 5
    grid_rows = (num_images + grid_cols - 1) // grid_cols  # ceil(num_images / grid_cols)

    # squeeze=False keeps `axes` 2-D even when the grid has a single row.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 3 * grid_rows), squeeze=False)

    for i in range(grid_rows * grid_cols):
        ax = axes[i // grid_cols, i % grid_cols]
        if i < num_images:
            ax.imshow(images[i].numpy().squeeze(), cmap='gray')
            ax.set_title(f"Image {i + 1}")
        ax.axis('off')  # also hides the unused trailing cells

    plt.tight_layout()
    plt.show()

# Use this to check your augmentations: plots the calibration images of the
# first training example.
visualize_augmentations(train_ds)

# Construct Model



## Eye image processing

In [None]:
def create_embedding_model():
    """Build the per-image eye encoder: DenseNet backbone -> Flatten -> tanh Dense.

    Input: one (36, 144, 1) eye image. Output: an EMBEDDING_DIM-long embedding.
    The returned Keras model is named "Eye_Image_Embedding".
    """
    image_shape = (36, 144, 1)
    eye_image = keras.layers.Input(shape=image_shape)

    backbone_features = keras_cv.models.DenseNetBackbone(
        include_rescaling=False,
        input_shape=image_shape,
        stackwise_num_repeats=DENSE_NET_STACKWISE_NUM_REPEATS,
    )(eye_image)

    flat_features = keras.layers.Flatten()(backbone_features)
    embedding = keras.layers.Dense(units=EMBEDDING_DIM, activation="tanh")(flat_features)

    return keras.Model(inputs=eye_image, outputs=embedding, name="Eye_Image_Embedding")

## Regression Layer

In [None]:
class WeightedRidgeRegressionLayer(keras.layers.Layer):
    """Closed-form weighted ridge regression from eye embeddings to 2-D coordinates.

    For each batch element the layer solves

        (X^T W X + (lambda_ridge + epsilon) * I) beta = X^T W y

    with X the calibration embeddings, y the calibration coordinates and
    W = diag(weights * cal_mask), then returns target_embeddings @ beta.
    There is no bias term and no trainable variable in this layer; gradients
    flow through the solve back to its inputs.

    Inputs (a list in this order):
        unknown_embeddings:     (batch, n_target, embedding_dim)
        calibration_embeddings: (batch, n_calibration, embedding_dim)
        calibration_coords:     (batch, n_calibration, 2)
        weights:                (batch, n_calibration) per-point weights
        cal_mask:               (batch, n_calibration), 1 = real point, 0 = padding
        target_mask:            (batch, n_target), 1 = real point, 0 = padding

    Output: float32 tensor (batch, n_target, 2), zero where target_mask is 0.
    """
    def __init__(self, lambda_ridge, embedding_dim, epsilon=1e-6, **kwargs):
        """
        :param lambda_ridge: ridge penalty added to the diagonal of X^T W X
        :param embedding_dim: static embedding size; checked against the inputs in build()
        :param epsilon: small extra diagonal term added to lambda_ridge for stability
        """
        self.lambda_ridge = lambda_ridge
        self.epsilon = epsilon
        self.embedding_dim = embedding_dim  # Required static embedding dimension
        super(WeightedRidgeRegressionLayer, self).__init__(**kwargs)

        # Create identity matrix constant during initialization
        # This ensures it's created only once and is available immediately
        self._identity = tf.eye(embedding_dim, dtype=tf.float32)

    @tf.function(jit_compile=True)
    def call(self, inputs):
        """Fit the per-example ridge regression and apply it to the targets (XLA-compiled)."""
        unknown_embeddings, calibration_embeddings, calibration_coords, weights, cal_mask, target_mask = inputs

        # Explicitly set shapes where possible (assuming batch_size can remain dynamic)
        batch_size = tf.shape(calibration_embeddings)[0]

        # Force shapes to match the static embedding dimension
        calibration_embeddings = tf.ensure_shape(
            calibration_embeddings, [None, None, self.embedding_dim]
        )
        unknown_embeddings = tf.ensure_shape(
            unknown_embeddings, [None, None, self.embedding_dim]
        )

        # Cast all inputs to the same precision
        # (presumably so the solve runs in float32 even under a mixed_float16 policy)
        X = tf.cast(calibration_embeddings, tf.float32)  # (batch_size, n_calibration, embedding_dim)
        y = tf.cast(calibration_coords, tf.float32)      # (batch_size, n_calibration, 2)
        w = tf.cast(weights, tf.float32)                 # (batch_size, n_calibration)
        cal_mask = tf.cast(cal_mask, tf.float32)         # (batch_size, n_calibration)
        target_mask = tf.cast(target_mask, tf.float32)   # (batch_size, n_target)

        # Ensure shapes are consistent
        n_calibration = tf.shape(X)[1]
        n_target = tf.shape(unknown_embeddings)[1]

        # Broadcast the identity matrix to batch size
        I = tf.broadcast_to(self._identity, [batch_size, self.embedding_dim, self.embedding_dim])

        # Apply mask to weights - ensure this is broadcastable
        # Padded calibration points get weight 0 and so drop out of the fit.
        w = w * cal_mask

        # Apply weights to X and y with explicit reshaping
        # Scaling rows by sqrt(w) makes X_weighted^T X_weighted == X^T W X.
        w_sqrt = tf.sqrt(w)
        w_sqrt = tf.reshape(w_sqrt, [batch_size, n_calibration, 1])  # Reshape for broadcasting

        X_weighted = X * w_sqrt  # (batch_size, n_calibration, embedding_dim)
        y_weighted = y * w_sqrt  # (batch_size, n_calibration, 2)

        # Matrix multiplication with explicit transpose
        X_t = tf.transpose(X_weighted, perm=[0, 2, 1])  # (batch_size, embedding_dim, n_calibration)
        X_t_X = tf.matmul(X_t, X_weighted)  # (batch_size, embedding_dim, embedding_dim)

        # Add regularization with stable epsilon
        ridge_term = tf.multiply(I, self.lambda_ridge + self.epsilon)
        lhs = tf.add(X_t_X, ridge_term)  # (batch_size, embedding_dim, embedding_dim)

        # Compute right-hand side
        rhs = tf.matmul(X_t, y_weighted)  # (batch_size, embedding_dim, 2)

        # Solve linear system
        # Add stability check for better XLA compatibility
        # NOTE(review): this only detects NaNs anywhere in the batch's lhs; it does not
        # detect ill-conditioning. When it fires, the whole batch gets an all-zero
        # kernel, so every prediction in the batch is 0.
        is_singular = tf.math.reduce_any(tf.math.is_nan(lhs))
        kernel = tf.cond(
            is_singular,
            lambda: tf.zeros([batch_size, self.embedding_dim, 2], dtype=tf.float32),
            lambda: tf.linalg.solve(lhs, rhs)
        )  # Shape: [batch_size, embedding_dim, 2]

        # Apply regression to unknown point with explicit masking
        unknown_embeddings = tf.cast(unknown_embeddings, tf.float32)  # Shape: [batch_size, n_target, embedding_dim]
        target_mask_expanded = tf.reshape(target_mask, [batch_size, n_target, 1])
        unknown_embeddings_masked = unknown_embeddings * target_mask_expanded

        # Matrix multiplication with proper dimensions:
        # unknown_embeddings_masked: [batch_size, n_target, embedding_dim]
        # kernel:                    [batch_size, embedding_dim, 2]
        # Result:                    [batch_size, n_target, 2]
        output = tf.matmul(unknown_embeddings_masked, kernel)  # Shape: [batch_size, n_target, 2]

        # Apply final mask
        # Padded targets come out as exactly (0, 0).
        output = output * target_mask_expanded

        return output

    def compute_output_shape(self, input_shapes):
        """Output is (batch, n_target, 2), following the unknown-embeddings input."""
        unknown_embeddings_shape, _, _, _, _, _ = input_shapes
        return (unknown_embeddings_shape[0], unknown_embeddings_shape[1], 2)

    def build(self, input_shapes):
        """Validate that both embedding inputs have ``embedding_dim`` features.

        :raises ValueError: if either embedding input's last dimension differs from embedding_dim
        """
        # This method is called once before the first call() to build the layer
        # We already handle initialization in __init__, but this is a good place to validate shapes
        unknown_shape, cal_shape, coords_shape, weights_shape, cal_mask_shape, target_mask_shape = input_shapes

        if unknown_shape[-1] != self.embedding_dim:
            raise ValueError(f"Unknown embeddings dimension {unknown_shape[-1]} doesn't match specified embedding_dim {self.embedding_dim}")

        if cal_shape[-1] != self.embedding_dim:
            raise ValueError(f"Calibration embeddings dimension {cal_shape[-1]} doesn't match specified embedding_dim {self.embedding_dim}")

        # We don't need to create any weights in build() since we don't have trainable parameters
        super(WeightedRidgeRegressionLayer, self).build(input_shapes)

    def get_config(self):
        """Serialize constructor arguments so the layer can be re-created from config."""
        config = super(WeightedRidgeRegressionLayer, self).get_config()
        config.update({
            "lambda_ridge": self.lambda_ridge,
            "epsilon": self.epsilon,
            "embedding_dim": self.embedding_dim
        })
        return config

In [None]:
def create_full_model():
  """Assemble the end-to-end calibration model.

  Calibration and target eye images share one embedding model, applied per image
  via TimeDistributed. A sigmoid Dense layer turns each calibration embedding into
  a per-point weight, and WeightedRidgeRegressionLayer fits a weighted ridge
  regression from calibration embeddings to calibration points and applies it to
  the target embeddings.

  Inputs (padded to MAX_CAL_POINTS / MAX_TARGETS, with 1/0 masks):
    Input_Calibration_Eyes, Input_Calibration_Points, Input_Calibration_Mask,
    Input_Target_Eyes, Input_Target_Mask.
  Output: (batch, MAX_TARGETS, 2) predicted coordinates, zero at masked targets.
  """
  embedding_model = create_embedding_model()

  input_calibration_eyes = keras.layers.Input(shape=(MAX_CAL_POINTS, 36, 144, 1), name="Input_Calibration_Eyes")
  input_calibration_points = keras.layers.Input(shape=(MAX_CAL_POINTS, 2), name="Input_Calibration_Points")
  input_calibration_mask = keras.layers.Input(shape=(MAX_CAL_POINTS,), name="Input_Calibration_Mask")

  # Target inputs are sized by MAX_TARGETS so they stay in sync with the dataset padding.
  input_target_eyes = keras.layers.Input(shape=(MAX_TARGETS, 36, 144, 1), name="Input_Target_Eyes")
  input_target_mask = keras.layers.Input(shape=(MAX_TARGETS,), name="Input_Target_Mask")

  # Embed every image (padded slots included) with the shared embedding model;
  # padding is handled by the masks inside the ridge layer.
  target_embedding = keras.layers.TimeDistributed(embedding_model)(input_target_eyes)
  calibration_embeddings = keras.layers.TimeDistributed(embedding_model)(input_calibration_eyes)

  # One sigmoid-activated weight per calibration point.
  calibration_weights = keras.layers.Dense(1, activation="sigmoid", name="Calibration_Weights")(calibration_embeddings)
  calibration_weights_reshaped = keras.layers.Reshape((-1,), name="Calibration_Weights_Reshaped")(calibration_weights)

  ridge = WeightedRidgeRegressionLayer(RIDGE_REGULARIZATION, embedding_dim=EMBEDDING_DIM)([
    target_embedding,
    calibration_embeddings,
    input_calibration_points,
    calibration_weights_reshaped,
    input_calibration_mask,
    input_target_mask
  ])

  full_model = keras.Model(inputs=[
    input_calibration_eyes,
    input_calibration_points,
    input_calibration_mask,
    input_target_eyes,
    input_target_mask
  ], outputs=ridge, name="FullEyePredictionModel")

  return full_model

## Full trainable model

In [None]:
# Build the model and print its layer summary.
full_model = create_full_model()

full_model.summary()

def masked_normalized_weighted_euc_dist(y_true, y_pred):
    """Masked, aspect-ratio-weighted Euclidean distance between point sets.

    Weighting multiplies the x-coordinate by 1.778 (16:9 aspect ratio) to match the
    scale of the y-axis. Distances are divided by the weighted diagonal length
    203.992 and multiplied by 100. With 0-100 coordinates the corner-to-corner
    distance is therefore 100; with the 0-1 coordinates used in this notebook it is 1.0.

    Padded target rows are excluded. They are detected from ``y_true``: padded labels
    are exactly (0, 0), while real targets lie on the 5-95% screen grid. Detecting them
    from ``y_pred`` instead would silently drop real rows whose prediction is exactly
    (0, 0), e.g. the ridge layer's all-zero fallback.

    :param y_true (tensor): (batch_size, seq_len, 2) ground-truth x- and y-coordinates
    :param y_pred (tensor): (batch_size, seq_len, 2) predicted x- and y-coordinates

    :returns: scalar tensor, the mean normalized distance over non-padded points
        (0 when every point is padding).
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # 1.0 for real target rows, 0.0 for zero-padded rows.
    mask = tf.cast(tf.reduce_any(tf.not_equal(y_true, 0.0), axis=-1), tf.float32)

    # Weighting treats y-axis as unit-scale and creates a rectangle that's 177.8x100 units.
    x_weight = tf.constant([1.778, 1.0], dtype=tf.float32)
    y_true_weighted = tf.math.multiply(x_weight, y_true)
    y_pred_weighted = tf.math.multiply(x_weight, y_pred)

    squared_dist = tf.math.reduce_sum(tf.math.square(y_pred_weighted - y_true_weighted), axis=-1)

    # sqrt has an infinite derivative at 0, so even a zero upstream gradient through
    # sqrt(0) becomes NaN (0 * inf). Padded rows always have zero distance. Therefore
    # take the sqrt only where the squared distance is non-zero and use an exact 0
    # elsewhere. This gives the same forward value with a finite gradient. (NaNs still
    # propagate because NaN != 0.)
    nonzero = tf.not_equal(squared_dist, 0.0)
    safe_squared_dist = tf.where(nonzero, squared_dist, tf.ones_like(squared_dist))
    dist = tf.where(nonzero, tf.math.sqrt(safe_squared_dist), tf.zeros_like(squared_dist))

    # Euclidean Distance from [0,0] to [177.8, 100] = 203.992
    norm_scale = tf.constant(203.992, dtype=tf.float32)
    normalized_dist = (dist / norm_scale) * 100

    # Average over valid (non-padded) points only; the max() avoids division by zero.
    masked_dist = normalized_dist * mask
    num_valid = tf.maximum(tf.reduce_sum(mask), 1.0)
    return tf.reduce_sum(masked_dist) / num_valid

# Compile with the masked loss; jit_compile=True compiles the model's steps with XLA.
# NOTE(review): the metric comes from et_util and presumably does not apply the
# target mask, so zero-padded target rows would count toward it — confirm.
full_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=masked_normalized_weighted_euc_dist,
    metrics=[normalized_weighted_euc_dist],  # Keep the original for comparison if needed
    jit_compile=True
)

In [None]:
def lr_schedule(epoch):
  """Step schedule: LEARNING_RATE for epochs 0-14, LEARNING_RATE * 0.1 from epoch 15 on."""
  decay = 1.0 if epoch < 15 else 0.1
  return LEARNING_RATE * decay

# Keras callback that sets the optimizer's learning rate from lr_schedule each epoch.
lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)

# Train model

In [None]:
# Train on batches of 64 examples, logging metrics to W&B and applying the step LR schedule.
train_history = full_model.fit(
    train_ds_for_model.batch(64),
    epochs=TRAIN_EPOCHS,
    callbacks=[WandbMetricsLogger(), lr_scheduler]
)

In [None]:
# Save architecture + weights in the native .keras format. Reloading needs the
# custom WeightedRidgeRegressionLayer (and custom loss) supplied as custom objects,
# since neither is registered as serializable here.
full_model.save('full_model.keras')

In [None]:
# Save the weights only (HDF5), e.g. to load into a model rebuilt with create_full_model().
full_model.save_weights('full_model.weights.h5')

In [None]:
# Attach both saved files to the current W&B run.
wandb.save('full_model.keras')
wandb.save('full_model.weights.h5')

In [None]:
# Evaluate on the fixed-calibration-grid examples (t2), one example per step.
# NOTE(review): t2 is built from the training split (train_split=1.0), so these
# per-subject scores are not on held-out subjects.
t2_processed = t2.map(
    lambda x, y, id: (prepare_batch_for_model(x, y)[0], y, id)
).prefetch(tf.data.AUTOTUNE)

batch_losses = []
subject_ids = []
y_true = []
prediction_batches = []

# Single pass over the dataset: each prediction comes from the same element as its
# labels and subject id, so they are aligned by construction (the pipeline runs once).
for inputs, labels, ids in t2_processed.batch(1).as_numpy_iterator():
    prediction_batches.append(full_model.predict_on_batch(inputs))
    y_true.append(labels)
    subject_ids.append(ids[0])

y_true = np.array(y_true).reshape(-1, MAX_TARGETS, 2)
predictions = np.concatenate(prediction_batches, axis=0)

for i in range(len(predictions)):
    loss = normalized_weighted_euc_dist(y_true[i], predictions[i]).numpy()
    batch_losses.append(loss)

batch_losses = np.array(batch_losses)

# Mean over the MAX_TARGETS points of each example. A subject yields one example per
# full batch of MAX_TARGETS samples, so subjects with many samples can appear more
# than once. Assumes normalized_weighted_euc_dist returns one distance per point
# (axis=1 needs a 2-D array) — confirm in et_util.
batch_losses = np.mean(batch_losses, axis=1)

In [None]:
# One row per evaluated example: (subject id, mean normalized distance).
final_loss_table = wandb.Table(
    columns=["subject", "scores"],
    data=[list(row) for row in zip(subject_ids, batch_losses)],
)

In [None]:
# W&B histogram of the per-example scores from the table above.
final_loss_hist = wandb.plot.histogram(final_loss_table, "scores", title="Normalized Euclidean Distance")

In [None]:
# Overall mean of the per-example scores.
final_loss_mean = batch_losses.mean()

In [None]:
# Log the score table, its histogram and the overall mean to W&B.
# NOTE(review): despite the "final_val_*" key names, these scores come from t2,
# which is built from the training data in this notebook.
wandb.log({"final_val_loss_table": final_loss_table, "final_val_loss_hist": final_loss_hist, "final_loss_mean": final_loss_mean})

In [None]:
# Local histogram of the per-example scores.
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(batch_losses, bins=20, edgecolor='black')
ax.set_title('Histogram of Batch Losses')
ax.set_xlabel('Loss')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# Mark the W&B run as finished so remaining logs and files are uploaded.
wandb.finish()