# Joint Cross-Attention Fusion Net (JCAF-Net)

### 1. Inputs (gaze and mouse):
- X/Y coordinates
- Speed and direction (velocity)
- Joint features: Euclidean distance and directional angle between gaze and mouse

### 2. Dual-Pathway design:
- Cross-Attention Path: Learns interactions between gaze and mouse using attention scores
- Joint Feature Path: Processes handcrafted joint features (e.g. distance, angle) with CNNs
- WARNING: VERY LITTLE MOUSE DATA, AND IT IS OF POOR QUALITY

### 3. Feature Extractor:
- ResNet-34 (2D CNNs) as backbone for unimodal (gaze and mouse) and joint inputs

### 4. Cross-Attention Module:
- Applied midway through the ResNet (between residual blocks 2 and 3)
- Attends to gaze features using mouse as key/value and vice versa

### 5. Fusion and Classification:
- Attended features + joint CNN features are concatenated
- FCN (with re-weighting) produces final class logits.


In [None]:
# TODO: 
# - Modify gaze mouse dataset to accommodate jcafnet inputs
# - Modify Train.py to accommodate both lstm and jcafnet
# - ADD A BOOLEAN MASK IDENTIFYING MISSING VALUES FOR MOUSE AND GAZE ?

In [1]:
import pandas as pd
import numpy as np
import os

from pathlib import Path

os.chdir('/cluster/home/kruu/git/eye_tracking/')
# os.chdir('/home/kruu/git_folder/eye_tracking/')
from utils.data_processing import EyeTrackingProcessor

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from models.lstm_classifier import LSTMClassifier
from utils.train import train_classifier, split_by_participant, export_to_onnx
from utils.dataset import GazeMouseDatasetJCAFNet

from utils.data_processing import EyeTrackingProcessor, GazeMetricsProcessor, MouseMetricsProcessor

In [2]:
# Data loading cell: runs the EyeTrackingProcessor pipeline
# (load_data -> get_features -> detect_blinks -> resample_tasks_fixed_time)
# and tags every row with a per-execution "id".
# data_path = "/store/kruu/eye_tracking"
data_path = "/cluster/home/kruu/store/eye_tracking"
# NOTE(review): every entry of data_path is passed to load_data, in os.listdir
# order (arbitrary); confirm the directory holds only recording files.
files_list = os.listdir(data_path)
files_list = [os.path.join(data_path, file) for file in files_list]

tasks = ['Task 1', 'Task 2', 'Task 3', 'Task 4', 'Task 5', 'Task 6']
# Raw columns kept by get_features. This name is rebound to a dict of
# per-pathway column lists in a later cell.
features = ['Recording timestamp', 'Gaze point X', 'Gaze point Y', 'Mouse position X', 'Mouse position Y', 'Event', 'Participant name']
# Columns handed to the fixed-time resampler (presumably the ones it
# interpolates onto the new time grid -- confirm in EyeTrackingProcessor).
interpolate_col = ['Recording timestamp', 'Gaze point X', 'Gaze point Y', 'Mouse position X', 'Mouse position Y', 'Blink']

processor = EyeTrackingProcessor()
all_data = processor.load_data(files_list)
dataset = processor.get_features(all_data, tasks, features)
dataset, blinks = processor.detect_blinks(dataset)

# Fixed Time step resampling
# timestep=0.01 -- presumably seconds (i.e. 100 Hz); confirm the unit in
# resample_tasks_fixed_time.
dataset_time_resampled = processor.resample_tasks_fixed_time(dataset, interpolate_col, timestep = 0.01)
# Re-binarise Blink after resampling: interpolated values can fall between
# 0 and 1, so anything above 0.5 counts as a blink.
dataset_time_resampled.Blink = (dataset_time_resampled.Blink > 0.5)
# Unique key per task execution: "<Participant name>_<Task_id>_<Task_execution>".
# Later used as the groupby key when enriching metrics.
dataset_time_resampled["id"] = dataset_time_resampled["Participant name"].astype(str) + "_" + dataset_time_resampled["Task_id"].astype(str) + "_" + dataset_time_resampled["Task_execution"].astype(str)

  df = pd.read_csv(path, sep='\t')
  df = pd.read_csv(path, sep='\t')
  df = pd.read_csv(path, sep='\t')
  df = pd.read_csv(path, sep='\t')


In [19]:
def compute_joint_features(df):
    """Return the per-row distance and direction between gaze point and mouse.

    Missing coordinates are replaced with 0 before the computation. A row
    whose gaze or mouse position is missing is therefore measured against the
    origin, not flagged as missing.

    Args:
        df: DataFrame with "Gaze point X", "Gaze point Y",
            "Mouse position X" and "Mouse position Y" columns.

    Returns:
        ``(distance, angle)``: two Series aligned with ``df``.
        ``distance`` is the Euclidean distance between gaze and mouse.
        ``angle`` is ``arctan2(dy, dx)`` in radians for the vector pointing
        from the mouse position to the gaze point.
    """
    # Offset of the gaze point relative to the mouse cursor.
    dx = df["Gaze point X"].fillna(0) - df["Mouse position X"].fillna(0)
    dy = df["Gaze point Y"].fillna(0) - df["Mouse position Y"].fillna(0)

    distance = np.sqrt(dx ** 2 + dy ** 2)
    angle = np.arctan2(dy, dx)  # radians, in (-pi, pi]
    return distance, angle

#Enrich dataframe with metrics for JCAFNet
def enrich_with_gaze_mouse_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with per-execution gaze/mouse metrics added.

    Rows are grouped by the ``"id"`` column. For each group the following
    columns are filled:

    - ``"Gaze Velocity"`` / ``"Gaze Acceleration"`` from
      ``GazeMetricsProcessor(group).compute_velocity_acceleration()``
    - ``"Mouse Velocity"`` / ``"Mouse Acceleration"`` from
      ``MouseMetricsProcessor(group).compute_velocity_acceleration()``
    - ``"Gaze-Mouse Distance"`` / ``"Angle Between Gaze and Mouse"`` from
      ``compute_joint_features(group)``

    NaNs returned by those helpers are replaced with 0. Rows that belong to no
    group (e.g. a missing ``"id"``, which ``groupby`` drops) keep NaN in the
    new columns.

    Args:
        df: Resampled recording with an ``"id"`` column plus the gaze/mouse
            columns read by the helpers.

    Returns:
        A new DataFrame. ``df`` itself is not modified.
    """
    df = df.copy()

    new_columns = [
        "Gaze Velocity",
        "Gaze Acceleration",
        "Mouse Velocity",
        "Mouse Acceleration",
        "Gaze-Mouse Distance",
        "Angle Between Gaze and Mouse",
    ]
    # Create the columns up front (all NaN). This gives them a fixed position
    # in the output and lets each `group` passed to the helpers carry them,
    # exactly as before.
    for name in new_columns:
        df[name] = np.nan

    # Result buffers, filled by row position and written back once at the end.
    buffers = {name: np.full(len(df), np.nan) for name in new_columns}

    task_group = df.groupby("id")
    # Positional (0-based) row indices of every group, computed once. The
    # previous per-group boolean mask (df["id"] == id) rescanned the whole
    # frame for each group, i.e. O(rows x groups).
    positions_by_id = task_group.indices

    for task_id, group in tqdm(task_group, desc="Enriching metrics"):
        positions = positions_by_id[task_id]

        # Gaze metrics
        gaze_vel, gaze_acc = GazeMetricsProcessor(group).compute_velocity_acceleration()
        # Mouse metrics
        mouse_vel, mouse_acc = MouseMetricsProcessor(group).compute_velocity_acceleration()
        # Joint features
        dist, angle = compute_joint_features(group)

        results = (gaze_vel, gaze_acc, mouse_vel, mouse_acc, dist, angle)
        for name, values in zip(new_columns, results):
            # Group rows are in the same order as `positions`, so values align.
            buffers[name][positions] = values.fillna(0).to_numpy()

    for name in new_columns:
        df[name] = buffers[name]

    return df

In [None]:
# Column names for each input pathway of GazeMouseDatasetJCAFNet
# (gaze, mouse, joint). NOTE: this rebinds `features`, which in the data
# loading cell held the raw column list passed to processor.get_features.
features = {
    "gaze": ["Gaze point X", "Gaze point Y", "Gaze Velocity", "Gaze Acceleration"],
    "mouse": ["Mouse position X", "Mouse position Y", "Mouse Velocity", "Mouse Acceleration"],
    "joint": ["Gaze-Mouse Distance", "Angle Between Gaze and Mouse"]
}

# Adds the velocity/acceleration and gaze-mouse distance/angle columns
# listed above, computed separately for each task execution ("id").
dataset_enriched = enrich_with_gaze_mouse_metrics(dataset_time_resampled)

Enriching metrics: 100%|██████████| 684/684 [01:01<00:00, 11.08it/s]


In [None]:
# Wrap the full enriched frame (no participant split yet, despite the name)
# in the JCAF-Net dataset. No augmentation; mean/std are left as None.
# NOTE(review): the sample printed by the next cell shows all-NaN first two
# rows in 'gaze' and 'mouse'. These are presumably the raw X/Y coordinates,
# which nothing visible in this notebook NaN-fills. Handle them (fill value
# plus a missing-value mask, see the TODO at the top) before training.
test_set = GazeMouseDatasetJCAFNet(
    dataset=dataset_enriched,
    gaze_features=features["gaze"],
    mouse_features=features["mouse"],
    joint_features=features["joint"],
    augment=None, 
    mean = None, 
    std = None,
)

In [24]:
# Inspect the first sample; indexing dispatches to Dataset.__getitem__.
test_set[0]

{'gaze': tensor([[        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [-2.9660e-01, -2.9660e-01, -2.9660e-01,  ..., -2.9660e-01,
          -2.9660e-01, -2.9660e-01],
         [-1.0893e-04, -1.0893e-04, -1.0893e-04,  ..., -1.0893e-04,
          -1.0893e-04, -1.0893e-04]]),
 'mouse': tensor([[        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [-2.0017e-02, -2.0017e-02, -2.0017e-02,  ..., -2.0017e-02,
          -2.0017e-02, -2.0017e-02],
         [-1.8592e-20, -1.8592e-20, -1.8592e-20,  ..., -1.8592e-20,
          -1.8592e-20, -1.8592e-20]]),
 'joint': tensor([[-0.7718, -0.7718, -0.7718,  ..., -0.7718, -0.7718, -0.7718],
         [-0.7443, -0.7443, -0.7443,  ..., -0.7443, -0.7443