In [1]:
import os

os.environ["CHECKPOINTS_PATH"] = "../checkpoints"

import subprocess
import tempfile
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
from sqlalchemy.orm import Session
from src.api.db import engine
from src.api.models.gaze import GazePoint
from src.api.models.pydantic import SimRoomClassDTO
from src.api.repositories import simrooms_repo
from src.api.services import gaze_service, simrooms_service, labeling_service
from src.config import TOBII_GLASSES_FPS, VIEWED_RADIUS, TOBII_GLASSES_RESOLUTION
from src.api.utils import image_utils
from src.utils import (
    extract_frames_to_dir,
)
from tqdm import tqdm

from experiment.settings import (
    FULLY_LABELED_RECORDINGS,
    GROUND_TRUTH_PATH,
    LABELING_VALIDATION_VIDEOS_PATH,
    SIMROOM_ID,
)

# Create the ground truth dataset

In [2]:
# Peek at the keys stored in one tracking-result file. The NpzFile is used as a
# context manager so its underlying file handle is closed right away.
with np.load("data/labeling_results/2/1/0.npz") as sample_annotation:
    print(sample_annotation.files)

['mask', 'box', 'roi', 'class_id', 'frame_idx']


In [3]:
def get_viewed_annotations_per_frame(
    cal_rec_id: int,
    tracked_classes: list[SimRoomClassDTO],
    gaze_position_per_frame: dict[int, GazePoint],
    video_resolution: tuple[int, int] = TOBII_GLASSES_RESOLUTION,
) -> dict[int, list[Path]]:
    """Collect, per frame, the annotation files whose mask was hit by the gaze.

    For every tracked class, the tracking results of the calibration recording
    are listed via ``labeling_service.get_class_tracking_results``. Each
    annotation mask is pasted into a frame-sized canvas at its bounding box, and
    ``gaze_service.mask_was_viewed`` decides whether that frame's gaze position
    falls on it.

    Args:
        cal_rec_id: ID of the calibration recording whose tracking results are read.
        tracked_classes: Classes whose annotations are considered.
        gaze_position_per_frame: Gaze position keyed by frame index. Frames with
            no entry can never yield a viewed annotation.
        video_resolution: Shape of the canvas passed to ``np.zeros``. The canvas
            is indexed as ``[y, x]``, so this must be ``(height, width)``
            (NOTE(review): confirm TOBII_GLASSES_RESOLUTION uses that order).

    Returns:
        Mapping ``frame_idx -> list of annotation .npz paths`` viewed in that
        frame. Frames without any viewed annotation are absent from the dict.
    """
    annotations_per_frame: dict[int, list[Path]] = {}
    for anno_class in tracked_classes:
        annotation_paths = labeling_service.get_class_tracking_results(
            calibration_id=cal_rec_id, class_id=anno_class.id
        )

        print(
            f"Processing {len(annotation_paths)} annotations for class {anno_class.class_name}"
        )
        for annotation_path in tqdm(annotation_paths):
            # The annotation file is named after the frame index it belongs to.
            frame_idx = int(annotation_path.stem)

            # Without a gaze point the annotation cannot be "viewed", so skip it
            # before paying for loading and rasterising its mask.
            gaze_position = gaze_position_per_frame.get(frame_idx)
            if gaze_position is None:
                continue

            # Context manager closes the NpzFile handle immediately; the arrays
            # read from it are fully loaded into memory.
            with np.load(annotation_path) as annotation_file:
                mask = annotation_file["mask"]
                x1, y1, x2, y2 = annotation_file["box"]

            # Drop a leading singleton dimension, e.g. (1, h, w) -> (h, w).
            if mask.ndim == 3 and mask.shape[0] == 1:
                mask = mask[0]

            # Paste the box-local mask into a full-frame canvas so it lives in the
            # same coordinate space as the gaze position.
            mask_full = np.zeros(video_resolution, dtype=np.uint8)
            mask_full[y1:y2, x1:x2] = mask
            mask_full_torch = torch.from_numpy(mask_full)

            if gaze_service.mask_was_viewed(mask_full_torch, gaze_position):
                annotations_per_frame.setdefault(frame_idx, []).append(annotation_path)

    return annotations_per_frame

In [4]:
def draw_validation_video_frames(
    frames: list[Path],
    annotations_per_frame: dict[int, list[Path]],
    gaze_position_per_frame: dict[int, GazePoint],
    tracked_classes: list[SimRoomClassDTO],
) -> None:
    """Overlay viewed annotations and the gaze point on frame images, in place.

    Each frame image is read from disk, annotated, and written back to the same
    path, so the original images are overwritten. The frame index is parsed
    from the file stem.

    For each annotation listed for the frame:
    - its mask is drawn with ``image_utils.draw_mask``;
    - its box is drawn with ``image_utils.draw_labeled_box``, using the tracked
      class's name and color.

    If the frame has a gaze position, it is drawn as a circle of radius
    ``VIEWED_RADIUS``.

    Args:
        frames: Paths of the frame images (file stems are integer frame indices).
        annotations_per_frame: Annotation ``.npz`` paths to draw, per frame index.
        gaze_position_per_frame: Gaze position per frame index, unpacked as (x, y).
        tracked_classes: Classes providing the color and name for each class_id.

    Raises:
        FileNotFoundError: if OpenCV cannot read a frame image.
        OSError: if OpenCV cannot write an annotated frame back to disk.
    """
    class_by_id = {anno_class.id: anno_class for anno_class in tracked_classes}

    for frame in tqdm(frames, desc="Drawing annotations on frames"):
        frame_idx = int(frame.stem)
        frame_img = cv2.imread(str(frame))
        if frame_img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read frame image: {frame}")

        for annotation_path in annotations_per_frame.get(frame_idx) or ():
            # Context manager closes the NpzFile handle right after reading.
            with np.load(annotation_path) as annotation_file:
                class_id = int(annotation_file["class_id"])
                x1, y1, x2, y2 = annotation_file["box"]
                mask = annotation_file["mask"]

            # Squeeze a leading singleton dimension and make the mask boolean.
            if mask.ndim == 3 and mask.shape[0] == 1:
                mask = mask[0]
            if mask.dtype != bool:
                mask = mask.astype(bool)

            anno_class = class_by_id[class_id]
            box = (x1, y1, x2, y2)

            frame_img = image_utils.draw_mask(frame_img, mask, box)
            frame_img = image_utils.draw_labeled_box(
                frame_img, box, anno_class.class_name, anno_class.color
            )

        # Draw the gaze point (BGR red circle) on the frame.
        gaze_position = gaze_position_per_frame.get(frame_idx)
        if gaze_position is not None:
            gaze_x, gaze_y = gaze_position
            cv2.circle(
                frame_img,
                (int(gaze_x), int(gaze_y)),
                radius=VIEWED_RADIUS,
                color=(0, 0, 255),
                thickness=2,
            )

        # Save the modified image back to its original location.
        if not cv2.imwrite(str(frame), frame_img):
            raise OSError(f"Could not write annotated frame: {frame}")

In [5]:
# Output directory for the optional validation videos. Existing videos are kept;
# ffmpeg runs with -y, so a recording's video is overwritten when re-rendered.
LABELING_VALIDATION_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)

# Remove the previous ground truth up front so a run that fails midway cannot
# leave an outdated CSV behind.
GROUND_TRUTH_PATH.unlink(missing_ok=True)

# Set to True to also render a per-recording video showing the viewed
# annotations and the gaze point (extracts and re-encodes every frame).
CREATE_VALIDATION_VIDEO = False

# Column order of the ground-truth CSV. It is passed explicitly so the header is
# written even when no annotation was viewed at all.
GT_COLUMNS = [
    "recording_id",
    "frame_idx",
    "class_id",
    "mask_area",
    "laplacian_variance",
    "x1",
    "y1",
    "x2",
    "y2",
]


def _ground_truth_rows(
    recording_id,
    annotations_per_frame: dict[int, list[Path]],
) -> list[dict]:
    """Build one ground-truth row per viewed annotation of a single recording."""
    rows = []
    for frame_idx, annotation_paths in annotations_per_frame.items():
        for annotation_path in annotation_paths:
            # Context manager closes the NpzFile handle right after reading.
            with np.load(annotation_path) as annotation_file:
                class_id = int(annotation_file["class_id"])
                mask_area = np.sum(annotation_file["mask"])
                x1, y1, x2, y2 = annotation_file["box"]
                roi = annotation_file["roi"]

            # Variance of the ROI's Laplacian. This is presumably a sharpness
            # score (low = blurry); NOTE(review): confirm the intended use.
            laplacian_variance = cv2.Laplacian(roi, cv2.CV_64F).var()

            rows.append({
                "recording_id": recording_id,
                "frame_idx": frame_idx,
                "class_id": class_id,
                "mask_area": mask_area,
                "laplacian_variance": laplacian_variance,
                "x1": x1,
                "y1": y1,
                "x2": x2,
                "y2": y2,
            })
    return rows


def _create_validation_video(
    recording_id,
    trial_recording_path,
    annotations_per_frame: dict[int, list[Path]],
    gaze_position_per_frame: dict[int, GazePoint],
    tracked_classes: list[SimRoomClassDTO],
) -> None:
    """Render ``<recording_id>.mp4`` with the viewed annotations and gaze drawn on."""
    # The context manager deletes the extracted frames even if drawing or
    # encoding fails, instead of leaving them until garbage collection.
    with tempfile.TemporaryDirectory() as tmp_frames_dir:
        tmp_frames_path = Path(tmp_frames_dir)

        print(f"Extracting frames for {recording_id}")
        extract_frames_to_dir(
            video_path=trial_recording_path,
            frames_path=tmp_frames_path,
            print_output=False,
        )
        video_frames = sorted(
            tmp_frames_path.glob("*.jpg"), key=lambda path: int(path.stem)
        )

        print(f"Drawing annotations for {recording_id}")
        draw_validation_video_frames(
            frames=video_frames,
            annotations_per_frame=annotations_per_frame,
            gaze_position_per_frame=gaze_position_per_frame,
            tracked_classes=tracked_classes,
        )

        # Give the frames zero-padded sequential names in numeric order. An
        # ffmpeg glob input ("*.jpg") is ordered lexicographically, which puts
        # "10.jpg" before "2.jpg" when the extracted names are not zero-padded.
        for seq_idx, frame in enumerate(video_frames):
            frame.rename(tmp_frames_path / f"seq_{seq_idx:08d}.jpg")

        print(f"Creating video for {recording_id}")
        # Argument list instead of a shell string: no quoting problems with paths.
        cmd = [
            "ffmpeg",
            "-hwaccel", "cuda",
            "-y",
            "-framerate", str(TOBII_GLASSES_FPS),
            "-start_number", "0",
            "-i", str(tmp_frames_path / "seq_%08d.jpg"),
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            str(LABELING_VALIDATION_VIDEOS_PATH / f"{recording_id}.mp4"),
        ]
        result = subprocess.run(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True
        )
        if result.returncode != 0:
            # Report the failure instead of swallowing it, but keep going so the
            # ground truth of the remaining recordings is still produced.
            print(
                f"WARNING: ffmpeg exited with code {result.returncode} for {recording_id}:\n"
                f"{result.stderr[-2000:]}"
            )


gt_rows: list[dict] = []
for recording_id in FULLY_LABELED_RECORDINGS:
    with Session(engine) as session:
        calibration_recording = simrooms_repo.get_calibration_recording(
            db=session,
            simroom_id=SIMROOM_ID,
            recording_id=recording_id,
        )
        tracked_classes = simrooms_service.get_tracked_classes(
            db=session, calibration_id=calibration_recording.id
        )

        trial_recording_path = calibration_recording.recording.video_path
        gaze_data_path = calibration_recording.recording.gaze_data_path

    # Only the number of frames is used below (as frame_count for the gaze lookup).
    print(f"Extracting video frames for {recording_id}")
    frames, _ = simrooms_service.extract_tmp_frames(
        recording_id=recording_id,
    )

    # Load and preprocess gaze points.
    print(f"Loading gaze data for {recording_id}")
    gaze_position_per_frame = gaze_service.get_gaze_position_per_frame(
        recording_id=recording_id,
        frame_count=len(frames),
    )

    # Get all annotations that were viewed.
    print(f"Getting viewed annotations for {recording_id}")
    annotations_per_frame = get_viewed_annotations_per_frame(
        cal_rec_id=calibration_recording.id,
        tracked_classes=tracked_classes,
        gaze_position_per_frame=gaze_position_per_frame,
    )

    print(f"Building ground truth for {recording_id}")
    gt_rows.extend(_ground_truth_rows(recording_id, annotations_per_frame))

    if CREATE_VALIDATION_VIDEO:
        _create_validation_video(
            recording_id=recording_id,
            trial_recording_path=trial_recording_path,
            annotations_per_frame=annotations_per_frame,
            gaze_position_per_frame=gaze_position_per_frame,
            tracked_classes=tracked_classes,
        )

ground_truth_df = pd.DataFrame(gt_rows, columns=GT_COLUMNS)
ground_truth_df.to_csv(
    GROUND_TRUTH_PATH,
    index=False,
)

Extracting video frames for 67b71a70-da64-467a-9fb6-91bc29265fd1
Loading gaze data for 67b71a70-da64-467a-9fb6-91bc29265fd1
Getting viewed annotations for 67b71a70-da64-467a-9fb6-91bc29265fd1
Processing 1147 annotations for class naaldcontainer


100%|██████████| 1147/1147 [00:16<00:00, 67.86it/s]


Processing 868 annotations for class spuit


100%|██████████| 868/868 [00:12<00:00, 70.06it/s]


Processing 1190 annotations for class keukenmes


100%|██████████| 1190/1190 [00:16<00:00, 70.20it/s]


Processing 208 annotations for class infuus


100%|██████████| 208/208 [00:03<00:00, 64.15it/s]


Processing 1125 annotations for class stethoscoop


100%|██████████| 1125/1125 [00:17<00:00, 65.11it/s]


Processing 529 annotations for class snoep


100%|██████████| 529/529 [00:08<00:00, 61.98it/s]


Processing 1195 annotations for class iced tea


100%|██████████| 1195/1195 [00:22<00:00, 53.43it/s]


Processing 977 annotations for class bril


100%|██████████| 977/977 [00:22<00:00, 43.40it/s]


Processing 110 annotations for class rollator


100%|██████████| 110/110 [00:02<00:00, 41.12it/s]


Processing 380 annotations for class ampulevloeistof


100%|██████████| 380/380 [00:09<00:00, 42.21it/s]


Processing 388 annotations for class ampulepoeder


100%|██████████| 388/388 [00:08<00:00, 44.24it/s]


Building ground truth for 67b71a70-da64-467a-9fb6-91bc29265fd1
Extracting video frames for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Loading gaze data for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Detected 3 gaze points for a frame in the video. This is unexpected.
Getting viewed annotations for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Processing 715 annotations for class naaldcontainer


100%|██████████| 715/715 [00:10<00:00, 65.75it/s]


Processing 370 annotations for class spuit


100%|██████████| 370/370 [00:06<00:00, 59.85it/s]


Processing 505 annotations for class keukenmes


100%|██████████| 505/505 [00:09<00:00, 52.63it/s]


Processing 590 annotations for class stethoscoop


100%|██████████| 590/590 [00:09<00:00, 65.33it/s]


Processing 593 annotations for class bol wol


100%|██████████| 593/593 [00:07<00:00, 78.84it/s]


Processing 620 annotations for class snoep


100%|██████████| 620/620 [00:08<00:00, 73.25it/s]


Processing 621 annotations for class nuchter


100%|██████████| 621/621 [00:09<00:00, 66.57it/s]


Processing 711 annotations for class fotokader


100%|██████████| 711/711 [00:10<00:00, 66.43it/s]


Processing 655 annotations for class iced tea


100%|██████████| 655/655 [00:09<00:00, 67.90it/s]


Processing 635 annotations for class bril


100%|██████████| 635/635 [00:09<00:00, 67.25it/s]


Processing 760 annotations for class monitor


100%|██████████| 760/760 [00:11<00:00, 67.10it/s]


Processing 387 annotations for class ampulevloeistof


100%|██████████| 387/387 [00:05<00:00, 67.67it/s]


Processing 507 annotations for class ampulepoeder


100%|██████████| 507/507 [00:07<00:00, 66.98it/s]


Building ground truth for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Extracting video frames for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Loading gaze data for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Detected 3 gaze points for a frame in the video. This is unexpected.
Getting viewed annotations for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Processing 862 annotations for class spuit


100%|██████████| 862/862 [00:11<00:00, 76.34it/s]


Processing 852 annotations for class keukenmes


100%|██████████| 852/852 [00:10<00:00, 77.78it/s]


Processing 951 annotations for class stethoscoop


100%|██████████| 951/951 [00:13<00:00, 69.37it/s]


Processing 873 annotations for class snoep


100%|██████████| 873/873 [00:11<00:00, 79.14it/s]


Processing 64 annotations for class nuchter


100%|██████████| 64/64 [00:00<00:00, 77.33it/s]


Processing 63 annotations for class fotokader


100%|██████████| 63/63 [00:00<00:00, 75.05it/s]


Processing 973 annotations for class bril


100%|██████████| 973/973 [00:12<00:00, 78.93it/s]


Processing 281 annotations for class monitor


100%|██████████| 281/281 [00:03<00:00, 71.46it/s]


Processing 143 annotations for class rollator


100%|██████████| 143/143 [00:02<00:00, 69.41it/s]


Processing 840 annotations for class ampulevloeistof


100%|██████████| 840/840 [00:10<00:00, 79.87it/s]


Processing 280 annotations for class ampulepoeder


100%|██████████| 280/280 [00:03<00:00, 81.33it/s]


Building ground truth for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
