In [1]:
import json
import shutil
import traceback
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import src.api.controllers.generate_embeddings as generate_embeddings
import torch
import torchvision.transforms.functional as F
from sqlalchemy.orm import Session
from src.config import GAZE_FOVEA_FOV, TOBII_FOV_X
from src.db import engine
from src.db.models import Recording, SimRoomClass
from src.api.controllers.gaze_segmentation import (
    get_gaze_points,
    match_frames_to_gaze,
    parse_gazedata_file,
    mask_was_viewed
)
from src.utils import cv2_video_fps, cv2_video_frame_count, cv2_video_resolution
from torchvision.ops import masks_to_boxes
from torchvision.transforms import InterpolationMode
from ultralytics import FastSAM
from tqdm import tqdm
import faiss
from typing import Any

2025-04-07 17:31:49.051694: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-07 17:31:49.211665: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744039909.273034   11280 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744039909.289846   11280 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 17:31:49.435226: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Experiment metadata: which trial recordings to process, plus the labeling recordings.
with open("experiment_metadata.json") as file:
    experiment_metadata = json.load(file)

# These only need the parsed dict, so the file is already closed at this point.
trial_recordings_metadata = experiment_metadata["trial_recordings_metadata"]
trial_recording_uuids = list(trial_recordings_metadata.keys())
labeling_same_background_uuid = experiment_metadata["labeling_same_background_uuid"]
labeling_diff_background_uuid = experiment_metadata["labeling_diff_background_uuid"]

with Session(engine) as session:
    trial_recordings = (
        session.query(Recording).filter(Recording.uuid.in_(trial_recording_uuids)).all()
    )

# SQL `IN` filtering gives no ordering guarantee; re-order the rows to follow the
# metadata file so downstream processing happens in a reproducible order.
uuid_position = {uuid: position for position, uuid in enumerate(trial_recording_uuids)}
trial_recordings.sort(key=lambda recording: uuid_position[recording.uuid])

In [3]:
# Data locations read by the cells below.
GAZE_SEGMENTATION_RESULTS_PATH = Path("data/gaze_segmentation_results")
SAME_BACKGROUND_VECTOR_INDEXES_PATH = Path("data/vector_indexes/same_background")
DIFF_BACKGROUND_VECTOR_INDEXES_PATH = Path("data/vector_indexes/diff_background")

# Embedding model passed to generate_embeddings.get_embeddings when embedding ROIs.
dinov2 = generate_embeddings.load_model()

In [4]:
def create_grounding_dataset(
    dataset_path: Path,
    gaze_segmentation_results: list[Any],
    index: "faiss.IndexIDMap",
    k: int,
):
    """Write per-(frame, object, class) nearest-neighbour distance statistics to a CSV.

    For every gaze-segmentation result, each ROI is embedded with the notebook-global
    ``dinov2`` model. Its ``k`` nearest neighbours are retrieved from ``index`` via
    ``generate_embeddings.search_index``, and the neighbour distances are grouped by the
    returned ID (treated as a class ID). One row per (frame_idx, object_id, class_id) is
    emitted with the mean / min / max / variance of that group's distances.

    Args:
        dataset_path: Destination CSV file (overwritten). A header row is always
            written, even when no rows are produced.
        gaze_segmentation_results: Mapping-like results exposing ``"frame_idx"``,
            ``"rois"`` and ``"object_ids"``. ROIs and object IDs must be aligned by position.
        index: Vector index to search.
        k: Number of neighbours requested per ROI.

    Raises:
        ValueError: If the number of search results or object IDs for a frame does not
            match its number of ROIs.
    """
    columns = [
        "frame_idx",
        "object_id",
        "class_id",
        "avg_distance",
        "min_distance",
        "max_distance",
        "var_distance",
    ]
    result_rows = []
    for result in gaze_segmentation_results:
        frame_idx = result["frame_idx"]
        rois = result["rois"]
        object_ids = result["object_ids"]

        # Nothing to embed; also avoids indexing into an empty embedding generator.
        if len(rois) == 0:
            continue

        # Embed and search every batch get_embeddings yields, not just the first one,
        # collecting one row of distances / IDs per ROI.
        # NOTE(review): assumes batches are yielded in ROI order - confirm in get_embeddings.
        per_roi_distances = []
        per_roi_class_ids = []
        for embeddings, _, _ in generate_embeddings.get_embeddings(dinov2, rois):
            batch_distances, batch_class_ids = generate_embeddings.search_index(
                index, embeddings, k=k
            )
            per_roi_distances.extend(batch_distances)
            per_roi_class_ids.extend(batch_class_ids)

        if not len(per_roi_distances) == len(object_ids) == len(rois):
            raise ValueError(
                f"frame {frame_idx}: got {len(per_roi_distances)} search results and "
                f"{len(object_ids)} object ids for {len(rois)} ROIs"
            )

        for object_id, distances, class_ids in zip(
            object_ids, per_roi_distances, per_roi_class_ids
        ):
            # Group this ROI's neighbour distances by class ID.
            class_to_dists = defaultdict(list)
            for cid, d in zip(class_ids, distances):
                # Faiss pads results with label -1 when fewer than k neighbours exist.
                # Assumes search_index passes raw faiss labels through - confirm.
                if cid < 0:
                    continue
                class_to_dists[cid].append(d)

            for cid, dists in class_to_dists.items():
                result_rows.append({
                    "frame_idx": frame_idx,
                    "object_id": object_id,
                    "class_id": cid,
                    "avg_distance": np.mean(dists),
                    "min_distance": np.min(dists),
                    "max_distance": np.max(dists),
                    "var_distance": np.var(dists),
                })

    # Explicit columns keep the header even when result_rows is empty, so the CSV can be read back.
    pd.DataFrame(result_rows, columns=columns).to_csv(dataset_path, index=False)


In [5]:
def process_recording(recording_uuid: str, index: "faiss.IndexIDMap", k: int):
    """Load the gaze-segmentation results saved for one recording, in frame order.

    Files in ``GAZE_SEGMENTATION_RESULTS_PATH / recording_uuid`` are expected to be named
    ``<frame number>.<ext>``. Entries that are not files or whose stem is not an integer
    are skipped.

    ``numpy.load`` returns a lazily-read, still-open ``NpzFile`` for ``.npz`` input, so
    archives are read eagerly into plain dicts inside a ``with`` block. This closes each
    file handle immediately and unpickles every array only once.

    Args:
        recording_uuid: Recording whose results directory is read.
        index: Currently unused; kept for signature compatibility.
        k: Currently unused; kept for signature compatibility.

    Returns:
        The loaded results, sorted by ascending frame number.
    """
    results_dir = GAZE_SEGMENTATION_RESULTS_PATH / recording_uuid
    frame_files = sorted(
        (path for path in results_dir.iterdir() if path.is_file() and path.stem.isdigit()),
        key=lambda path: int(path.stem),
    )

    results = []
    for path in frame_files:
        # NOTE: allow_pickle=True can execute code embedded in the file; only load trusted data.
        data = np.load(path, allow_pickle=True)
        if isinstance(data, np.lib.npyio.NpzFile):
            with data as archive:
                data = {key: archive[key] for key in archive.files}
        results.append(data)
    return results

K_OPTIONS = [50]  # earlier sweeps also used 100, 200, 300, 400, 500
GROUNDING_DATASETS_PATH = Path("data/grounding_datasets")

# The vector indexes do not depend on the recording, so read each one once up front
# instead of once per recording. Index file names are expected to start with
# "<sample_count>_"; process them in ascending sample-count order, because iterdir()
# order is arbitrary.
vector_indexes = sorted(
    (
        (int(vector_index_path.name.split("_")[0]), faiss.read_index(str(vector_index_path)))
        for vector_index_path in SAME_BACKGROUND_VECTOR_INDEXES_PATH.iterdir()
    ),
    key=lambda item: item[0],
)

for trial_recording in tqdm(trial_recordings, desc="Processing trial recordings"):
    # Load this recording's gaze-segmentation results in frame order. .npz archives are
    # read eagerly into dicts, so each file is closed and unpickled once instead of being
    # re-read on every key access in the (index, k) loops below.
    gaze_segmentation_results_path = GAZE_SEGMENTATION_RESULTS_PATH / trial_recording.uuid
    gaze_segmentation_results = []
    for result_path in sorted(gaze_segmentation_results_path.iterdir(), key=lambda p: int(p.stem)):
        data = np.load(result_path, allow_pickle=True)
        if isinstance(data, np.lib.npyio.NpzFile):
            with data as archive:
                data = {key: archive[key] for key in archive.files}
        gaze_segmentation_results.append(data)

    # Start from an empty output directory for this recording (previous files are deleted).
    grounding_datasets_path = GROUNDING_DATASETS_PATH / trial_recording.uuid
    if grounding_datasets_path.exists():
        shutil.rmtree(grounding_datasets_path)
    grounding_datasets_path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): create_grounding_dataset re-embeds every ROI on each call, so the
    # embedding cost scales with len(vector_indexes) * len(K_OPTIONS) per recording.
    for sample_count, index in tqdm(vector_indexes, desc="Processing vector indexes", leave=False):
        for k in K_OPTIONS:
            if k > sample_count:
                continue  # presumably k must not exceed the per-index sample count - confirm

            # tqdm.write keeps the nested progress bars intact, unlike a bare print.
            tqdm.write(f"Processing k={k} for sample_count={sample_count}")
            grounding_dataset_path = grounding_datasets_path / f"grounding_dataset_k={k}_samples={sample_count}.csv"
            create_grounding_dataset(
                dataset_path=grounding_dataset_path,
                gaze_segmentation_results=gaze_segmentation_results,
                index=index,
                k=k,
            )


Processing trial recordings:   0%|          | 0/14 [00:00<?, ?it/s]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:   7%|▋         | 1/14 [14:07<3:03:42, 847.88s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  14%|█▍        | 2/14 [22:25<2:08:25, 642.09s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  21%|██▏       | 3/14 [31:39<1:50:16, 601.47s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  29%|██▊       | 4/14 [41:40<1:40:12, 601.30s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  36%|███▌      | 5/14 [53:17<1:35:24, 636.01s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  43%|████▎     | 6/14 [1:04:05<1:25:20, 640.01s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  50%|█████     | 7/14 [1:12:49<1:10:15, 602.23s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  57%|█████▋    | 8/14 [1:20:52<56:25, 564.18s/it]  

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  64%|██████▍   | 9/14 [1:30:04<46:41, 560.37s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  71%|███████▏  | 10/14 [1:37:14<34:40, 520.19s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600
Processing k=100 for sample_count=600
Processing k=200 for sample_count=600
Processing k=300 for sample_count=600
Processing k=400 for sample_count=600
Processing k=500 for sample_count=600




Processing k=50 for sample_count=400
Processing k=100 for sample_count=400
Processing k=200 for sample_count=400
Processing k=300 for sample_count=400
Processing k=400 for sample_count=400




Processing k=500 for sample_count=400
Processing k=50 for sample_count=500
Processing k=100 for sample_count=500
Processing k=200 for sample_count=500
Processing k=300 for sample_count=500
Processing k=400 for sample_count=500
Processing k=500 for sample_count=500


Processing trial recordings:  79%|███████▊  | 11/14 [1:48:16<28:10, 563.58s/it]

Processing k=50 for sample_count=200
Processing k=100 for sample_count=200
Processing k=200 for sample_count=200




Processing k=300 for sample_count=200
Processing k=400 for sample_count=200
Processing k=500 for sample_count=200
Processing k=50 for sample_count=300
Processing k=100 for sample_count=300
Processing k=200 for sample_count=300
Processing k=300 for sample_count=300




Processing k=400 for sample_count=300
Processing k=500 for sample_count=300
Processing k=50 for sample_count=100
Processing k=100 for sample_count=100




Processing k=200 for sample_count=100
Processing k=300 for sample_count=100
Processing k=400 for sample_count=100
Processing k=500 for sample_count=100
Processing k=50 for sample_count=600


Processing trial recordings:  79%|███████▊  | 11/14 [1:51:00<30:16, 605.50s/it]


KeyboardInterrupt: 