# VTNet Leave-One-User-Out Evaluation with 3-Second DGM Features

This notebook mirrors the leave-one-user-out (LOUO) evaluation performed on raw gaze coordinates, but replaces the temporal branch with Dynamic Gaze Metrics (DGMs) computed over 3-second tumbling windows. The CNN branch still consumes non-contaminated scanpath images while the RNN branch processes DGM sequences. The goal is to assess cross-user generalisation on the clean dataset for each graph type (`bar`, `line`, `pie`).


In [None]:
import os
from pathlib import Path
from functools import lru_cache

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
)



In [None]:
# Report the TensorFlow build and the devices it can see.
detected_gpus = tf.config.list_physical_devices('GPU')

print(f"TensorFlow version: {tf.__version__}")
print("Num GPUs Available: ", len(detected_gpus))
print("\n=== GPU Diagnostics ===")
print(f"Built with CUDA: {tf.test.is_built_with_cuda()}")
try:
    cuda_ready = tf.test.is_gpu_available(cuda_only=True)
except Exception:
    # The probe is best-effort: any failure is reported, never raised.
    print("CUDA availability check failed (method may be deprecated)")
else:
    print(f"CUDA available: {cuda_ready}")

print(f"\nPhysical devices: {tf.config.list_physical_devices()}")
print(f"GPU devices: {detected_gpus}")

# Configure GPU memory growth if possible
if not detected_gpus:
    print("\n❌ No GPU detected by TensorFlow. Training will fall back to CPU.")
else:
    try:
        for gpu in detected_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"Warning: Could not configure GPU memory growth: {e}")
    else:
        print("\n✅ GPU memory growth enabled")



In [None]:
def find_project_root(start: Path, marker: str = "non-contaminated datasets", max_levels: int = 6) -> Path:
    """Walk upwards from *start* until a directory containing *marker* is found.

    Args:
        start: Directory where the search begins; it is resolved first.
        marker: Name of the entry whose presence identifies the project root.
        max_levels: Number of directories inspected, counting *start* itself.

    Returns:
        The first inspected directory that contains *marker*.

    Raises:
        FileNotFoundError: If none of the inspected directories contains
            *marker*.
    """
    candidate = start.resolve()
    levels_left = max_levels
    while levels_left > 0:
        if (candidate / marker).exists():
            return candidate
        candidate = candidate.parent
        levels_left -= 1

    raise FileNotFoundError(
        f"Could not locate project root containing '{marker}' starting from {start}"
    )

# Resolve the project root, falling back to the working directory when the
# dataset folder cannot be located.
current_dir = Path.cwd()
try:
    base_dir = find_project_root(current_dir)
except FileNotFoundError as err:
    print(err)
    base_dir = current_dir  # fallback to current directory

# Every dataset location hangs off the same top-level folder.
clean_data_dir = base_dir / "non-contaminated datasets"
scanpaths_dir = clean_data_dir / "Scanpaths"
dgm_dir = clean_data_dir / "Organized Normalized Tumbling Window DGMs (3s)"
raw_csv_dir = clean_data_dir / "Raw Eye Tracking Data"
target_csv_path = base_dir / "Code" / "Utilities" / "users_literacy_results.csv"

for caption, shown in (
    ("Working directory", current_dir),
    ("Project root", base_dir),
    ("Scanpaths dir exists", scanpaths_dir.exists()),
    ("DGM dir exists", dgm_dir.exists()),
    ("Raw CSV dir exists", raw_csv_dir.exists()),
    ("Target CSV path", target_csv_path),
):
    print(f"{caption}: {shown}")

# Data constants
# Scanpath images are resized to IMG_HEIGHT x IMG_WIDTH in make_dataset.
IMG_HEIGHT = 150
IMG_WIDTH = 150
SEQ_LENGTH = 20  # maximum number of DGM windows
BATCH_SIZE = 16  # default batch size used by make_dataset
EPOCHS = 12  # default epoch budget per fold in run_fold
VALIDATION_SPLIT = 0.1  # fraction of a fold's training samples held out for validation
RANDOM_SEED = 42  # shared by NumPy, TensorFlow and train_test_split

# User ids that collect_samples_for_graph leaves out of the dataset.
skip_users = {5, 20}
# Graph types evaluated; each one is also a sub-directory name in the datasets.
graph_types = ["bar", "line", "pie"]
# Class directory name -> integer label.
class_map = {"illiterate": 0, "literate": 1}
# Label names listed in label-index order (index 0 = "illiterate").
label_names = ["illiterate", "literate"]

# Seed the NumPy and TensorFlow global RNGs.
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)



In [None]:
# Load the literacy results table and force both columns to integers before
# building the MEDIA_ID -> LITERACY lookup.
literacy_df = pd.read_csv(target_csv_path)
for int_column in ("MEDIA_ID", "LITERACY"):
    literacy_df[int_column] = literacy_df[int_column].astype(int)
literacy_map = {
    media_id: literacy
    for media_id, literacy in zip(literacy_df["MEDIA_ID"], literacy_df["LITERACY"])
}



In [None]:
def parse_user_id(path: Path) -> int | None:
    name = path.name
    if name.startswith("user_"):
        try:
            return int(name.split("_")[1])
        except (IndexError, ValueError):
            return None
    return None


def parse_question_id(filename: str) -> int | None:
    try:
        part = filename.split("question_")[1]
        return int(part.split("_")[0].split(".")[0])
    except (IndexError, ValueError):
        return None



In [None]:
@lru_cache(maxsize=1)
def get_dgm_feature_indices() -> list[int]:
    """Return the positional indices of the DGM columns fed to the model.

    The full column schema is listed in order below; every column whose name
    is in the exclusion set is dropped and the positions of the remaining
    columns are returned in ascending order. The result is cached, so all
    callers share one list object and must not mutate it.
    """
    # Twelve velocity statistics: six aggregations for each of the
    # peak / mean saccade velocity families.
    dropped = {
        f"{stat}_{family}_saccade_velocity"
        for family in ("peak", "mean")
        for stat in ("sum", "mean", "median", "std", "min", "max")
    }
    dropped.update(
        (
            "stationary_entropy",
            "transition_entropy",
            "total_number_of_blinks",
            "average_blink_rate_per_minute",
            "total_number_of_l_mouse_clicks",
            "average_pupil_size_of_left_eye",
            "average_pupil_size_of_right_eye",
            "average_pupil_size_of_both_eyes",
        )
    )

    # Column order matters: the returned indices are positions in this tuple.
    schema = (
        # fixations
        "total_number_of_fixations",
        "sum_of_all_fixation_duration_s",
        "mean_fixation_duration_s",
        "median_fixation_duration_s",
        "stdev_of_fixation_durations_s",
        "min_fixation_duration_s",
        "max_fixation_duration_s",
        # saccade counts and lengths
        "total_number_of_saccades",
        "sum_of_all_saccade_lengths",
        "mean_saccade_length",
        "median_saccade_length",
        "stdev_of_saccade_lengths",
        "min_saccade_length",
        "max_saccade_length",
        # saccade durations
        "sum_of_all_saccade_durations",
        "mean_saccade_duration",
        "median_saccade_duration",
        "stdev_of_saccade_durations",
        "min_saccade_duration",
        "max_saccade_duration",
        # saccade amplitudes
        "sum_of_all_saccade_amplitudes",
        "mean_saccade_amplitude",
        "median_saccade_amplitude",
        "stdev_of_saccade_amplitude",
        "min_saccade_amplitude",
        "max_saccade_amplitude",
        "scanpath_duration",
        "fixation_to_saccade_ratio",
        # peak saccade velocity
        "sum_peak_saccade_velocity",
        "mean_peak_saccade_velocity",
        "median_peak_saccade_velocity",
        "std_peak_saccade_velocity",
        "min_peak_saccade_velocity",
        "max_peak_saccade_velocity",
        # mean saccade velocity
        "sum_mean_saccade_velocity",
        "mean_mean_saccade_velocity",
        "median_mean_saccade_velocity",
        "std_mean_saccade_velocity",
        "min_mean_saccade_velocity",
        "max_mean_saccade_velocity",
        # absolute degrees
        "sum_of_all_absolute_degrees",
        "mean_absolute_degree",
        "median_absolute_degree",
        "stdev_of_absolute_degrees",
        "min_absolute_degree",
        "max_absolute_degree",
        # relative degrees
        "sum_of_all_relative_degrees",
        "mean_relative_degree",
        "median_relative_degree",
        "stdev_of_relative_degrees",
        "min_relative_degree",
        "max_relative_degree",
        "convex_hull_area",
        "stationary_entropy",
        "transition_entropy",
        "total_number_of_blinks",
        "average_blink_rate_per_minute",
        "total_number_of_valid_recordings",
        "average_pupil_size_of_left_eye",
        "average_pupil_size_of_right_eye",
        "average_pupil_size_of_both_eyes",
        "total_number_of_l_mouse_clicks",
        # window timing
        "beginning_timestamp",
        "ending_timestamp",
        "window_duration",
        "initial_seconds_elapsed_since_start",
        "final_seconds_elapsed_since_start",
        # pupil-size deltas
        "delta_average_pupil_size_of_left_eye",
        "delta_average_pupil_size_of_right_eye",
        "delta_average_pupil_size_of_both_eyes",
    )

    return [position for position, name in enumerate(schema) if name not in dropped]


def load_dgm_from_csv(csv_path: Path, max_sequence_length: int = SEQ_LENGTH) -> np.ndarray:
    """Load one DGM CSV as a fixed-length ``float32`` feature matrix.

    Columns are selected by position using ``get_dgm_feature_indices()``;
    NaNs are replaced by zeros. Sequences longer than *max_sequence_length*
    keep their last rows, shorter ones are zero-padded at the end.

    Loading is best-effort: any failure (missing file, malformed CSV, too few
    columns, non-numeric cells) is printed and treated as an empty sequence,
    so the function returns an all-zero matrix instead of raising.

    Args:
        csv_path: Path of the DGM CSV file.
        max_sequence_length: Number of rows in the returned matrix.

    Returns:
        Array of shape ``(max_sequence_length, n_features)`` and dtype
        ``float32``.
    """
    keep_indices = get_dgm_feature_indices()
    num_features = len(keep_indices)
    # Positional selection needs the file to reach the highest kept index,
    # not merely to have as many columns as there are kept features.
    required_columns = max(keep_indices, default=-1) + 1

    try:
        df = pd.read_csv(csv_path)
        if len(df.columns) < required_columns:
            raise ValueError(
                f"DGM file {csv_path} has insufficient columns: {len(df.columns)} "
                f"(need at least {required_columns})"
            )
        data = df.iloc[:, keep_indices].to_numpy(dtype=np.float32)
        data = np.nan_to_num(data, nan=0.0)
    except Exception as e:
        # Deliberately broad: this runs inside the tf.data pipeline, where a
        # single unreadable file should not abort the whole fold.
        print(f"Error loading DGM {csv_path}: {e}")
        data = np.zeros((0, num_features), dtype=np.float32)

    n_rows = data.shape[0]
    if n_rows >= max_sequence_length:
        # Explicit start index: ``data[-0:]`` would return every row when
        # max_sequence_length is 0.
        data = data[n_rows - max_sequence_length:]
    else:
        padding = np.zeros((max_sequence_length - n_rows, data.shape[1]), dtype=np.float32)
        data = np.vstack([data, padding])

    return data.astype(np.float32, copy=False)


def get_dgm_path_from_image_path(image_path: Path, base_dgm_dir: Path) -> Path:
    """Map a scanpath image path to its DGM CSV under *base_dgm_dir*.

    The literacy class is taken from the path components of *image_path*
    ("illiterate" is checked before "literate"), and the file name has its
    ``_scanpath.png`` suffix swapped for ``_tumbling_all_window_DGMs.csv``.

    NOTE(review): the result is ``base_dgm_dir / <class> / <file>`` with no
    graph-type or ``user_<id>`` level, whereas collect_samples_for_graph
    builds DGM paths that include both — confirm the intended layout before
    relying on this helper.

    Raises:
        ValueError: If neither class name is a component of *image_path*.
    """
    components = image_path.parts
    for class_name in ("illiterate", "literate"):
        if class_name in components:
            csv_name = image_path.name.replace(
                "_scanpath.png", "_tumbling_all_window_DGMs.csv"
            )
            return base_dgm_dir / class_name / csv_name

    raise ValueError(f"Could not infer literacy label from image path: {image_path}")



In [None]:
def collect_samples_for_graph(graph_type: str):
    """Pair every scanpath image of *graph_type* with its DGM CSV.

    For each class in ``class_map`` the ``user_*`` directories under the
    scanpath tree are scanned in sorted order. Users in ``skip_users``,
    users without a matching ``user_<id>`` DGM directory, and images without
    a matching DGM file are skipped silently.

    NOTE(review): the DGM directory is rebuilt as ``user_<id>`` from the
    parsed integer, so a scanpath directory such as ``user_07`` would look
    for ``user_7`` — confirm the naming is identical in both trees.

    Returns:
        A list of dicts with keys ``image``, ``dgm``, ``label``, ``user_id``
        and ``question_id`` (``question_id`` may be ``None``).
    """
    image_suffix = "_scanpath.png"
    dgm_suffix = "_tumbling_all_window_DGMs.csv"
    collected = []

    for literacy_name, label in class_map.items():
        image_class_dir = scanpaths_dir / graph_type / literacy_name
        feature_class_dir = dgm_dir / graph_type / literacy_name
        if not (image_class_dir.exists() and feature_class_dir.exists()):
            continue

        for user_dir in sorted(image_class_dir.glob("user_*")):
            user_id = parse_user_id(user_dir)
            if user_id is None or user_id in skip_users:
                continue

            feature_user_dir = feature_class_dir / f"user_{user_id}"
            if not feature_user_dir.exists():
                continue

            for img_path in sorted(user_dir.glob("*.png")):
                dgm_path = feature_user_dir / img_path.name.replace(image_suffix, dgm_suffix)
                if not dgm_path.exists():
                    continue
                collected.append(
                    dict(
                        image=img_path,
                        dgm=dgm_path,
                        label=label,
                        user_id=user_id,
                        question_id=parse_question_id(img_path.name),
                    )
                )

    return collected



In [None]:
def _py_load_dgm(path_tensor):
    """Load the DGM matrix for a path handed over by ``tf.py_function``.

    Accepts an object exposing ``.numpy()``, raw ``bytes`` or anything
    convertible with ``str``; bytes are decoded as UTF-8 before the path is
    passed to ``load_dgm_from_csv``.
    """
    raw = path_tensor.numpy() if hasattr(path_tensor, "numpy") else path_tensor
    path_str = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
    return load_dgm_from_csv(Path(path_str))


def make_dataset(samples, shuffle=False, batch_size=BATCH_SIZE):
    """Build a batched ``tf.data`` pipeline from sample dictionaries.

    Each element is ``((image, dgm), one_hot_label)``: the image is decoded
    from PNG, resized to ``IMG_HEIGHT x IMG_WIDTH``, converted to grayscale
    and scaled by 1/255; the DGM matrix is loaded through ``_py_load_dgm``;
    the label is one-hot encoded with depth 2.

    Args:
        samples: Sequence of dicts with ``image``, ``dgm`` and ``label`` keys.
        shuffle: When true, shuffle with a buffer covering all samples and
            reshuffle on every iteration.
        batch_size: Number of samples per batch.

    Raises:
        ValueError: If *samples* is empty.
    """
    sample_count = len(samples)
    if sample_count == 0:
        raise ValueError("No samples provided to build dataset.")

    feature_dim = len(get_dgm_feature_indices())

    def _read_scanpath(image_path):
        pixels = tf.image.decode_png(tf.io.read_file(image_path), channels=3)
        pixels = tf.image.resize(pixels, [IMG_HEIGHT, IMG_WIDTH])
        pixels = tf.image.rgb_to_grayscale(pixels)
        return tf.cast(pixels, tf.float32) / 255.0

    def _read_dgm(dgm_path):
        sequence = tf.py_function(func=_py_load_dgm, inp=[dgm_path], Tout=tf.float32)
        # py_function loses static shape information; restore it for Keras.
        sequence.set_shape((SEQ_LENGTH, feature_dim))
        return sequence

    def _to_example(image_path, dgm_path, label):
        features = (_read_scanpath(image_path), _read_dgm(dgm_path))
        return features, tf.one_hot(label, depth=2)

    columns = (
        [str(sample["image"]) for sample in samples],
        [str(sample["dgm"]) for sample in samples],
        [sample["label"] for sample in samples],
    )
    pipeline = tf.data.Dataset.from_tensor_slices(columns)

    if shuffle:
        pipeline = pipeline.shuffle(buffer_size=sample_count, reshuffle_each_iteration=True)

    pipeline = pipeline.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)



In [None]:
def create_vtnet_dgm_model(rnn_type: str = "gru", hidden_size: int = 256):
    """Build and compile the two-branch VTNet model.

    A small CNN encodes the grayscale scanpath image, a recurrent layer
    encodes the DGM sequence, and both encodings are concatenated before a
    dense layer and a 2-way softmax output.

    Args:
        rnn_type: ``"gru"`` or ``"lstm"`` (case-insensitive); any other value
            selects a ``SimpleRNN``.
        hidden_size: Number of units in the recurrent layer.

    Returns:
        A compiled ``tf.keras.Model`` (Adam, categorical cross-entropy,
        accuracy metric).
    """
    layers = tf.keras.layers

    scanpath_input = layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1), name="scanpath_input")
    dgm_input = layers.Input(
        shape=(SEQ_LENGTH, len(get_dgm_feature_indices())), name="dgm_input"
    )

    # Visual branch: the layers are applied in the order listed.
    cnn_stack = (
        layers.Conv2D(6, (5, 5), activation="relu", name="conv1"),
        layers.MaxPooling2D((2, 2), name="pool1"),
        layers.Conv2D(16, (5, 5), activation="relu", name="conv2"),
        layers.MaxPooling2D((2, 2), name="pool2"),
        layers.Flatten(name="cnn_flatten"),
        layers.Dense(50, activation="relu", name="cnn_fc1"),
    )
    cnn_features = scanpath_input
    for layer in cnn_stack:
        cnn_features = layer(cnn_features)

    # Temporal branch: unknown rnn_type values fall back to SimpleRNN.
    rnn_choices = {
        "gru": (layers.GRU, "dgm_gru"),
        "lstm": (layers.LSTM, "dgm_lstm"),
    }
    rnn_cls, rnn_name = rnn_choices.get(rnn_type.lower(), (layers.SimpleRNN, "dgm_simple"))
    dgm_features = rnn_cls(hidden_size, name=rnn_name)(dgm_input)

    merged = layers.Concatenate(name="fusion_concat")([cnn_features, dgm_features])
    merged = layers.Dense(20, activation="relu", name="fusion_fc1")(merged)
    output = layers.Dense(2, activation="softmax", name="output")(merged)

    model = tf.keras.Model(
        inputs=[scanpath_input, dgm_input], outputs=output, name="VTNet_LOUO_DGM"
    )
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model



In [None]:
def run_fold(train_samples, test_samples, epochs=EPOCHS):
    """Train a fresh model on *train_samples* and evaluate it on *test_samples*.

    A validation subset of ``VALIDATION_SPLIT`` (at least one sample) is
    carved out of *train_samples* for early stopping on ``val_loss``. The
    split is made at sample level, so validation samples may come from the
    same users as the training samples.

    Args:
        train_samples: Sample dicts used for training and validation.
        test_samples: Sample dicts of the held-out user.
        epochs: Maximum number of training epochs.

    Returns:
        ``(metrics, history, y_true, y_pred)`` where *metrics* holds
        accuracy, binary precision/recall/F1 and the 2x2 confusion matrix,
        *history* is the Keras history dict, and *y_true* / *y_pred* are
        lists of integer labels in test-sample order.
    """
    # One model is built per fold; drop the global Keras state left by the
    # previous fold so memory does not grow across the whole evaluation.
    tf.keras.backend.clear_session()

    labels = [s["label"] for s in train_samples]
    distinct_labels = set(labels)
    n_classes = len(distinct_labels)
    test_size = max(1, int(len(train_samples) * VALIDATION_SPLIT))
    train_size = len(train_samples) - test_size
    smallest_class = min((labels.count(c) for c in distinct_labels), default=0)

    # Stratified splitting requires every class to have at least two members
    # and both sides of the split to hold at least one sample per class;
    # otherwise fall back to a plain random split instead of raising.
    use_stratify = (
        test_size >= n_classes
        and train_size >= n_classes
        and smallest_class >= 2
    )

    train_subset, val_subset = train_test_split(
        train_samples,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels if use_stratify else None,
    )

    train_ds = make_dataset(train_subset, shuffle=True, batch_size=BATCH_SIZE)
    val_ds = make_dataset(val_subset, shuffle=False, batch_size=BATCH_SIZE)
    test_ds = make_dataset(test_samples, shuffle=False, batch_size=BATCH_SIZE)

    model = create_vtnet_dgm_model()

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[early_stop],
        verbose=0,
    )

    # Predict over the whole (unshuffled) dataset in one call rather than
    # calling predict() once per batch; the labels are read straight from
    # the samples, which are in the same order as the dataset.
    probabilities = model.predict(test_ds, verbose=0)
    y_pred = np.argmax(probabilities, axis=1).tolist()
    y_true = [int(s["label"]) for s in test_samples]

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),
    }

    return metrics, history.history, y_true, y_pred



In [None]:
def run_louo_for_graph(graph_type: str):
    """Run leave-one-user-out evaluation for one graph type.

    Every user id found in the collected samples is held out once; the model
    is trained on all other users via ``run_fold`` and the predictions of
    all folds are pooled for the aggregate metrics.

    Returns:
        Dict with pooled ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``confusion_matrix``, plus ``folds`` mapping each user id to its
        fold metrics and training history.

    Raises:
        ValueError: If no samples exist for *graph_type*.
    """
    samples = collect_samples_for_graph(graph_type)
    if not samples:
        raise ValueError(f"No samples found for graph type {graph_type}")

    sample_table = pd.DataFrame(samples)
    fold_results = {}
    pooled_true, pooled_pred = [], []

    for user_id in sorted(sample_table["user_id"].unique()):
        is_held_out = sample_table["user_id"] == user_id
        test_samples = sample_table[is_held_out].to_dict("records")
        train_samples = sample_table[~is_held_out].to_dict("records")
        if not train_samples or not test_samples:
            continue

        metrics, history, y_true, y_pred = run_fold(train_samples, test_samples)
        fold_results[user_id] = {"metrics": metrics, "history": history}
        pooled_true += y_true
        pooled_pred += y_pred

    precision, recall, f1, _ = precision_recall_fscore_support(
        pooled_true, pooled_pred, average="binary", zero_division=0
    )
    return {
        "accuracy": accuracy_score(pooled_true, pooled_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": confusion_matrix(pooled_true, pooled_pred, labels=[0, 1]),
        "folds": fold_results,
    }



In [None]:
# Run the LOUO evaluation for every graph type and keep the aggregates.
results = {}
metric_captions = (
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1", "f1"),
)

for graph in graph_types:
    print(f"\n=== Running VTNet DGM LOUO for {graph} graphs ===")
    agg = results[graph] = run_louo_for_graph(graph)
    print(" | ".join(f"{caption}: {agg[key]:.3f}" for caption, key in metric_captions))



In [None]:
# One row of pooled metrics per graph type, ordered by graph type.
summary_rows = [
    {
        "graph_type": graph,
        **{metric: res[metric] for metric in ("accuracy", "precision", "recall", "f1")},
    }
    for graph, res in results.items()
]
summary_df = pd.DataFrame(summary_rows).sort_values("graph_type")
summary_df


In [None]:
# Per-user (per-fold) metrics for a single graph type.
graph_type = "bar"

fold_columns = ["user_id", "accuracy", "precision", "recall", "f1"]

fold_rows = []
for user_id, info in results.get(graph_type, {}).get("folds", {}).items():
    m = info["metrics"]
    fold_rows.append(
        {
            "user_id": user_id,
            "accuracy": m["accuracy"],
            "precision": m["precision"],
            "recall": m["recall"],
            "f1": m["f1"],
        }
    )

# Explicit columns keep the frame sortable when the graph type has no folds;
# without them an empty frame has no "user_id" column and sort_values raises
# KeyError, defeating the defensive .get() lookups above.
pd.DataFrame(fold_rows, columns=fold_columns).sort_values("user_id")


In [None]:
# Draw one confusion-matrix heatmap per held-out user of the chosen graph type.
graph_type = "bar"

heatmap_style = dict(
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=["Pred 0", "Pred 1"],
    yticklabels=["True 0", "True 1"],
)

if graph_type not in results:
    print(f"Graph type '{graph_type}' not in results.")
else:
    for user_id, info in results[graph_type]["folds"].items():
        cm = info["metrics"]["confusion_matrix"]
        plt.figure(figsize=(3.5, 3))
        sns.heatmap(cm, **heatmap_style)
        plt.title(f"{graph_type} | User {user_id}")
        plt.ylabel("True Label")
        plt.xlabel("Predicted Label")
        plt.tight_layout()
        plt.show()



In [None]:
# User-level accuracy: a user counts as correct when more than half of that
# user's test samples were classified correctly.
user_level_accuracy = {}

for graph_type in graph_types:
    if graph_type not in results:
        continue

    user_details = [
        {
            "user_id": user_id,
            "accuracy": info["metrics"]["accuracy"],
            "correct": info["metrics"]["accuracy"] > 0.5,
        }
        for user_id, info in results[graph_type]["folds"].items()
    ]
    total_users = len(user_details)
    correct_users = sum(1 for detail in user_details if detail["correct"])
    new_accuracy = correct_users / total_users if total_users else 0.0

    user_level_accuracy[graph_type] = {
        "new_accuracy": new_accuracy,
        "correct_users": correct_users,
        "total_users": total_users,
        "user_details": user_details,
    }

    print(
        f"{graph_type} graphs:\n"
        f"  New Accuracy: {new_accuracy:.3f} ({correct_users}/{total_users} users)\n"
    )

summary_user_level = pd.DataFrame(
    [
        {
            "graph_type": graph,
            **{field: info[field] for field in ("new_accuracy", "correct_users", "total_users")},
        }
        for graph, info in user_level_accuracy.items()
    ]
).sort_values("graph_type")

print("Summary:")
summary_user_level
