In [None]:
import os
import subprocess
import tempfile
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
from sqlalchemy.orm import Session
from src.api.db import engine
from src.api.models.gaze import GazePoint
from src.api.models.pydantic import SimRoomClassDTO
from src.api.repositories import simrooms_repo
from src.api.services import gaze_service, simrooms_service, labeling_service
from src.config import TOBII_GLASSES_FPS, VIEWED_RADIUS
from src.api.utils import image_utils
from src.utils import (
    cv2_video_fps,
    cv2_video_frame_count,
    cv2_video_resolution,
    extract_frames_to_dir,
)
from tqdm import tqdm

from experiment.settings import (
    FULLY_LABELED_RECORDINGS,
    GROUND_TRUTH_PATH,
    LABELING_VALIDATION_VIDEOS_PATH,
    TRIAL_RECORDING_IDS,
)

# Create the ground truth dataset

In [None]:
# Inspect which arrays a single tracking-result file contains.
# NOTE(review): this absolute path only exists on the original author's machine;
# point it at a tracking-result file on your own system before running this cell.
SAMPLE_ANNOTATION_PATH = Path(
    "/home/zilian/projects/bachelorproef/experiments/experiment/data/labeling_results/2/1/0.npz"
)
# `np.load` on an .npz returns an NpzFile holding the archive open; the context
# manager closes it once we have read the key list.
with np.load(SAMPLE_ANNOTATION_PATH) as sample_annotation:
    print(sample_annotation.files)

['mask', 'box', 'roi', 'class_id', 'frame_idx']


In [None]:
def get_viewed_annotations_per_frame(
    annotated_classes: list[SimRoomClassDTO],
    gaze_point_per_frame: dict[int, GazePoint],
    video_resolution: tuple[int, int],
) -> dict[int, list[Path]]:
    """Collect, per frame, the tracking-result files whose mask was viewed.

    An annotation counts as viewed when ``gaze_service.mask_was_viewed`` accepts
    its full-frame mask together with that frame's gaze position.

    Args:
        annotated_classes: Classes whose tracking results are inspected.
        gaze_point_per_frame: Gaze point per frame index. Frames without an entry
            are skipped and their annotation files are never loaded.
        video_resolution: Shape of the full-frame mask canvas. The canvas is
            indexed as ``[y, x]`` below, so this is expected to be
            ``(height, width)``. NOTE(review): confirm against ``cv2_video_resolution``.

    Returns:
        Mapping from frame index to the annotation ``.npz`` paths viewed in that
        frame. Frames with no viewed annotation are absent from the mapping.
    """
    annotations_per_frame: dict[int, list[Path]] = {}
    for anno_class in annotated_classes:
        for annotation_path in labeling_service.get_class_tracking_results(anno_class.id):
            # Result files are named after their frame index (e.g. "0.npz").
            frame_idx = int(annotation_path.stem)

            # Check for a gaze point before touching the file: without one the
            # annotation can never be viewed, so loading it would be wasted work.
            gaze_point = gaze_point_per_frame.get(frame_idx)
            if gaze_point is None:
                continue

            with np.load(annotation_path) as annotation_file:
                mask = annotation_file["mask"]
                x1, y1, x2, y2 = annotation_file["box"]

            # The stored mask only covers the bounding box; place it on a
            # frame-sized canvas so it lines up with the gaze position.
            mask_full = np.zeros(video_resolution, dtype=np.uint8)
            mask_full[y1:y2, x1:x2] = mask

            if gaze_service.mask_was_viewed(
                torch.from_numpy(mask_full), gaze_point.position
            ):
                annotations_per_frame.setdefault(frame_idx, []).append(annotation_path)

    return annotations_per_frame

In [None]:
def draw_validation_video_frames(
    frames: list[Path],
    annotations_per_frame: dict[int, list[Path]],
    gaze_point_per_frame: dict[int, GazePoint],
    annotated_classes: list[SimRoomClassDTO],
) -> None:
    """Overlay viewed annotations and gaze points onto frame images, in place.

    Each frame file is read, decorated, and written back to the same path.

    Args:
        frames: Frame image paths whose stem is the frame index.
        annotations_per_frame: Annotation ``.npz`` paths to draw, keyed by frame index.
        gaze_point_per_frame: Gaze point per frame index; drawn as a red circle.
        annotated_classes: Classes used to resolve each annotation's name and color.

    Raises:
        FileNotFoundError: If a frame image cannot be read.
        KeyError: If an annotation's ``class_id`` is not among ``annotated_classes``.
    """
    class_by_id = {anno_class.id: anno_class for anno_class in annotated_classes}

    for frame in tqdm(frames, desc="Drawing annotations on frames"):
        frame_idx = int(frame.stem)
        frame_img = cv2.imread(str(frame))
        if frame_img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read frame image: {frame}")

        for annotation_path in annotations_per_frame.get(frame_idx, []):
            with np.load(annotation_path) as annotation_file:
                class_id = int(annotation_file["class_id"])
                x1, y1, x2, y2 = annotation_file["box"]
                mask = annotation_file["mask"]

            # Drop a leading singleton axis, e.g. (1, H, W) -> (H, W).
            if mask.ndim == 3 and mask.shape[0] == 1:
                mask = mask[0]
            mask = mask.astype(bool, copy=False)

            anno_class = class_by_id[class_id]
            box = (x1, y1, x2, y2)
            frame_img = image_utils.draw_mask(frame_img, mask, box)
            frame_img = image_utils.draw_labeled_box(
                frame_img, box, anno_class.class_name, anno_class.color
            )

        gaze_point = gaze_point_per_frame.get(frame_idx)
        if gaze_point is not None:
            gaze_x, gaze_y = gaze_point.position
            # Radius is VIEWED_RADIUS, presumably the same radius gaze_service
            # uses to decide whether a mask was viewed.
            cv2.circle(
                frame_img,
                (int(gaze_x), int(gaze_y)),
                radius=VIEWED_RADIUS,
                color=(0, 0, 255),
                thickness=2,
            )

        # Overwrite the frame so the video is later assembled from decorated frames.
        cv2.imwrite(str(frame), frame_img)

In [None]:
# Prepare output locations. The ground-truth CSV is removed up front so that an
# interrupted run cannot leave a stale file that looks like fresh results.
LABELING_VALIDATION_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
GROUND_TRUTH_PATH.unlink(missing_ok=True)

# Set to True to also render one MP4 per recording with the viewed annotations
# and gaze point drawn on every frame (slow: extracts and re-encodes all frames).
CREATE_VALIDATION_VIDEO = False
# NOTE(review): every recording is looked up under simroom 1; confirm this holds.
SIMROOM_ID = 1


def _ground_truth_rows(recording_id, annotations_per_frame: dict[int, list[Path]]) -> list[dict]:
    """Build one ground-truth row per viewed annotation of a single recording.

    Each row holds the recording id, frame index, class id, mask area (sum of the
    mask values, i.e. the pixel count for a binary mask), the variance of the
    Laplacian of the stored ROI, and the bounding box corners.
    """
    # TODO: Might be interesting to add blur metric per frame to the ground truth dataset
    rows = []
    for frame_idx, annotation_paths in annotations_per_frame.items():
        for annotation_path in annotation_paths:
            with np.load(annotation_path) as annotation_file:
                class_id = int(annotation_file["class_id"])
                mask_area = np.sum(annotation_file["mask"])
                x1, y1, x2, y2 = annotation_file["box"]
                roi = annotation_file["roi"]

            # Laplacian variance is a common sharpness measure: lower means blurrier.
            laplacian_variance = cv2.Laplacian(roi, cv2.CV_64F).var()

            rows.append({
                "recording_id": recording_id,
                "frame_idx": frame_idx,
                "class_id": class_id,
                "mask_area": mask_area,
                "laplacian_variance": laplacian_variance,
                "x1": x1,
                "y1": y1,
                "x2": x2,
                "y2": y2,
            })
    return rows


def _create_validation_video(
    recording_id,
    video_path,
    annotations_per_frame: dict[int, list[Path]],
    gaze_point_per_frame: dict[int, GazePoint],
    annotated_classes: list[SimRoomClassDTO],
) -> None:
    """Render a recording's frames with overlays and encode them to an MP4.

    Frames are extracted into a temporary directory that is removed when done.
    ffmpeg failures are reported, not raised, so the ground-truth build continues.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_frames_path = Path(tmp_dir)

        print(f"Extracting frames for {recording_id}")
        extract_frames_to_dir(
            video_path=video_path,
            frames_path=tmp_frames_path,
            print_output=False,
        )
        frames = sorted(tmp_frames_path.glob("*.jpg"), key=lambda p: int(p.stem))

        print(f"Drawing annotations for {recording_id}")
        draw_validation_video_frames(
            frames=frames,
            annotations_per_frame=annotations_per_frame,
            gaze_point_per_frame=gaze_point_per_frame,
            annotated_classes=annotated_classes,
        )

        print(f"Creating video for {recording_id}")
        # NOTE(review): ffmpeg's glob input orders files lexicographically, while
        # `frames` is sorted numerically. Frame order is only correct if the
        # extracted file names are zero-padded; confirm for extract_frames_to_dir.
        # NOTE(review): encodes at TOBII_GLASSES_FPS, not the measured video fps.
        cmd = [
            "ffmpeg",
            "-hide_banner",
            "-loglevel", "error",
            "-hwaccel", "cuda",
            "-y",
            "-pattern_type", "glob",
            "-framerate", str(TOBII_GLASSES_FPS),
            "-i", f"{tmp_frames_path}/*.jpg",
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            str(LABELING_VALIDATION_VIDEOS_PATH / f"{recording_id}.mp4"),
        ]
        # An argument list avoids shell quoting issues; stderr is kept so a
        # failed encode is reported instead of silently producing no video.
        result = subprocess.run(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True
        )
        if result.returncode != 0:
            print(
                f"ffmpeg failed for {recording_id} (exit {result.returncode}):\n"
                f"{result.stderr.strip()}"
            )


# Skip recordings that are not fully labeled before iterating, so the progress
# bar counts only the recordings that are actually processed.
recording_ids = [
    rid for rid in TRIAL_RECORDING_IDS if rid in FULLY_LABELED_RECORDINGS
]

gt_rows = []
for recording_id in tqdm(recording_ids):
    with Session(engine) as session:
        calibration_recording = simrooms_repo.get_calibration_recording(
            db=session,
            simroom_id=SIMROOM_ID,
            recording_id=recording_id,
        )
        annotated_classes = simrooms_service.get_tracked_classes(
            session=session,
            simroom_id=SIMROOM_ID,
            recording_id=recording_id,
        )
        # Read these while the session is open; `.recording` may be lazily loaded.
        trial_recording_path = calibration_recording.recording.video_path
        gaze_data_path = calibration_recording.recording.gaze_data_path

    trial_video_resolution = cv2_video_resolution(trial_recording_path)
    trial_video_fps = cv2_video_fps(trial_recording_path)
    trial_video_frame_count = cv2_video_frame_count(trial_recording_path)

    print(f"Loading gaze data for {recording_id}")
    gaze_point_per_frame = gaze_service.get_gaze_point_per_frame(
        gaze_data_path=gaze_data_path,
        resolution=trial_video_resolution,
        frame_count=trial_video_frame_count,
        fps=trial_video_fps,
    )

    print(f"Getting viewed annotations for {recording_id}")
    annotations_per_frame = get_viewed_annotations_per_frame(
        annotated_classes=annotated_classes,
        gaze_point_per_frame=gaze_point_per_frame,
        video_resolution=trial_video_resolution,
    )

    print(f"Building ground truth for {recording_id}")
    gt_rows.extend(_ground_truth_rows(recording_id, annotations_per_frame))

    if CREATE_VALIDATION_VIDEO:
        _create_validation_video(
            recording_id=recording_id,
            video_path=trial_recording_path,
            annotations_per_frame=annotations_per_frame,
            gaze_point_per_frame=gaze_point_per_frame,
            annotated_classes=annotated_classes,
        )

ground_truth_df = pd.DataFrame(gt_rows)
ground_truth_df.to_csv(
    GROUND_TRUTH_PATH,
    index=False,
)

  0%|          | 0/14 [00:00<?, ?it/s]

Loading gaze data for 67b71a70-da64-467a-9fb6-91bc29265fd1
Getting viewed annotations for 67b71a70-da64-467a-9fb6-91bc29265fd1
Building ground truth for 67b71a70-da64-467a-9fb6-91bc29265fd1


  7%|▋         | 1/14 [01:49<23:37, 109.07s/it]

Loading gaze data for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Getting viewed annotations for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Building ground truth for 32f02db7-adc0-4556-a2da-ed2ba60a58c9


 14%|█▍        | 2/14 [03:38<21:48, 109.05s/it]

Loading gaze data for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Getting viewed annotations for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Building ground truth for b8eeecc0-06b1-47f7-acb5-89aab3c1724d


100%|██████████| 14/14 [05:02<00:00, 21.58s/it]
