In [1]:
import shutil
import subprocess
import tempfile
from pathlib import Path

import cv2
import pandas as pd
from src.api.controllers.calibration_controller import (
    get_calibration_recording_by_uuid,
    get_class_by_id,
    get_gaze_data_path,
    get_recording_path,
)
from src.api.controllers.gaze_controller import (
    get_gaze_point_per_frame,
)
from src.api.models.gaze import GazePoint
from src.config import TOBII_GLASSES_FPS, VIEWED_RADIUS
from src.utils import (
    cv2_video_fps,
    cv2_video_frame_count,
    cv2_video_resolution,
    draw_annotation_on_frame,
    extract_frames_to_dir,
)
from tqdm import tqdm

from controlled_experiment.settings import (
    CLASS_ID_TO_NAME,
    FULLY_LABELED_RECORDINGS,
    TRIAL_RECORDING_UUIDS,
)

In [2]:
def draw_video_frames(
    frames: list[Path],
    gaze_point_per_frame: dict[int, GazePoint],
    predictions_df: pd.DataFrame,
):
    # Iterate over frames and draw the annotations on them if they exist
    for frame in tqdm(frames, desc="Drawing annotations on frames"):
        frame_idx = int(frame.stem)
        frame_img = cv2.imread(str(frame))

        # iterate through the predictions with the same frame index
        frame_predictions = predictions_df[predictions_df["frame_idx"] == frame_idx]
        for _, row in frame_predictions.iterrows():
            class_id = int(row["predicted_class_id"])
            class_name = CLASS_ID_TO_NAME[class_id]
            box = (int(row["x1"]), int(row["y1"]), int(row["x2"]), int(row["y2"]))

            try:
                sim_room_class = get_class_by_id(class_id)
                class_color_hex = sim_room_class.color
            except ValueError:
                # For unknown classes, use a default color
                class_color_hex = "#FF0000"

            frame_img = draw_annotation_on_frame(
                frame_img=frame_img,
                mask=None,
                box=box,
                class_color_hex=class_color_hex,
                class_name=class_name,
            )

        # Draw the gaze point on the frame
        gaze_point = gaze_point_per_frame.get(frame_idx)
        if gaze_point is not None:
            gaze_x, gaze_y = gaze_point.position
            cv2.circle(
                frame_img,
                (int(gaze_x), int(gaze_y)),
                radius=VIEWED_RADIUS,
                color=(0, 0, 255),
                thickness=2,
            )

        # Save the modified image back to its original location
        cv2.imwrite(str(frame), frame_img)

In [3]:
FINAL_PREDICTIONS_PATH = Path("data/final_predictions")
RECORDING_FRAMES_PATH = Path("data/recording_frames")

FINAL_PREDICTION_VIDEOS_PATH = Path("data/final_prediction_videos")
if FINAL_PREDICTION_VIDEOS_PATH.exists():
    shutil.rmtree(FINAL_PREDICTION_VIDEOS_PATH)
FINAL_PREDICTION_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)

for trial_recording_uuid in tqdm(
    TRIAL_RECORDING_UUIDS, desc="Processing trial recordings"
):
    if trial_recording_uuid not in FULLY_LABELED_RECORDINGS:
        continue

    calibration_recording = get_calibration_recording_by_uuid(trial_recording_uuid)

    # Get statistics of the video
    trial_recording_path = get_recording_path(calibration_recording.id)
    trial_video_resolution = cv2_video_resolution(trial_recording_path)
    trial_video_fps = cv2_video_fps(trial_recording_path)
    trial_video_frame_count = cv2_video_frame_count(trial_recording_path)

    # Load and preprocess gaze points
    print(f"Loading gaze data for {trial_recording_uuid}")
    gaze_data_path = get_gaze_data_path(calibration_recording.id)
    gaze_point_per_frame = get_gaze_point_per_frame(
        gaze_data_path=gaze_data_path,
        resolution=trial_video_resolution,
        frame_count=trial_video_frame_count,
        fps=trial_video_fps,
    )

    calibration_recording = get_calibration_recording_by_uuid(trial_recording_uuid)

    print(f"Loading prediction results for {trial_recording_uuid}")
    final_predictions_path = FINAL_PREDICTIONS_PATH / f"{trial_recording_uuid}.csv"
    predictions_df = pd.read_csv(final_predictions_path)

    print(f"Extracting frames for {trial_recording_uuid}")
    tmp_frames_dir = tempfile.TemporaryDirectory()
    tmp_frames_path = Path(tmp_frames_dir.name)
    extract_frames_to_dir(
        video_path=trial_recording_path,
        frames_path=tmp_frames_path,
        print_output=False,
    )
    frames = sorted(list(tmp_frames_path.glob("*.jpg")), key=lambda x: int(x.stem))

    print(f"Drawing annotations for {trial_recording_uuid}")
    draw_video_frames(
        frames=frames,
        gaze_point_per_frame=gaze_point_per_frame,
        predictions_df=predictions_df,
    )

    print(f"Creating video for {trial_recording_uuid}")
    cmd = f'ffmpeg -hwaccel cuda -y -pattern_type glob -framerate {TOBII_GLASSES_FPS} -i "{tmp_frames_path!s}/*.jpg" -c:v libx264 -pix_fmt yuv420p "{FINAL_PREDICTION_VIDEOS_PATH}/{trial_recording_uuid}.mp4"'
    subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

Processing trial recordings:   0%|          | 0/14 [00:00<?, ?it/s]

Loading gaze data for 67b71a70-da64-467a-9fb6-91bc29265fd1
Loading prediction results for 67b71a70-da64-467a-9fb6-91bc29265fd1
Extracting frames for 67b71a70-da64-467a-9fb6-91bc29265fd1
Drawing annotations for 67b71a70-da64-467a-9fb6-91bc29265fd1


Drawing annotations on frames: 100%|██████████| 2064/2064 [00:15<00:00, 131.87it/s]


Creating video for 67b71a70-da64-467a-9fb6-91bc29265fd1


Processing trial recordings:   7%|▋         | 1/14 [00:58<12:34, 58.00s/it]

Loading gaze data for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Loading prediction results for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Extracting frames for 32f02db7-adc0-4556-a2da-ed2ba60a58c9
Drawing annotations for 32f02db7-adc0-4556-a2da-ed2ba60a58c9


Drawing annotations on frames: 100%|██████████| 1365/1365 [00:12<00:00, 105.69it/s]


Creating video for 32f02db7-adc0-4556-a2da-ed2ba60a58c9


Processing trial recordings:  14%|█▍        | 2/14 [01:38<09:35, 47.96s/it]

Loading gaze data for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Loading prediction results for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Extracting frames for b8eeecc0-06b1-47f7-acb5-89aab3c1724d
Drawing annotations for b8eeecc0-06b1-47f7-acb5-89aab3c1724d


Drawing annotations on frames: 100%|██████████| 1557/1557 [00:11<00:00, 136.24it/s]


Creating video for b8eeecc0-06b1-47f7-acb5-89aab3c1724d


Processing trial recordings: 100%|██████████| 14/14 [02:19<00:00,  9.93s/it]
