In [5]:
# Imports
import matplotlib.pyplot as plt
import cv2
import torch
import pathlib
import os
from tqdm import tqdm
import numpy as np
from l2cs import Pipeline, render
import pandas as pd

# Constants: every path below is derived from the directory the notebook is
# launched from, with the repository root taken to be two levels up.
CWD = pathlib.Path.cwd()
GIT_ROOT = CWD.parent.parent
DATA_DIR = GIT_ROOT.joinpath("data", "ICMI2024")

# Per-subset gaze CSVs are written here; create it up front so later cells can write.
OUTPUT_DIR = DATA_DIR.joinpath("gaze_vectors")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

In [6]:
# L2CS-Net gaze estimator (ResNet50 backbone, Gaze360 weights stored next to the
# notebook under models/). include_detector=False because the inputs fed to it
# later are already-cropped face images.
# Use the GPU when one is visible to torch, otherwise fall back to the CPU
# (valid device strings are e.g. 'cuda' / 'cpu' — 'gpu' is not accepted by torch).
GAZE_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

gaze_pipeline = Pipeline(
    weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
    arch='ResNet50',
    device=GAZE_DEVICE,
    include_detector=False
)

In [12]:
def process(cropped_dir: pathlib.Path, output_file: pathlib.Path, pipeline=None) -> pd.DataFrame:
    """Estimate gaze (pitch, yaw) for every cropped face image in a directory and save a CSV.

    Each ``*.png`` file in ``cropped_dir`` is read with OpenCV and passed through the
    gaze pipeline. One row ``(frame, tracked_id, pitch, yaw)`` is recorded per image.
    The ids are parsed from the file stem split on ``_``: field ``[1]`` is the frame
    number and the last field is the tracked id. NOTE(review): the naming scheme is
    inferred from this parsing — confirm against the step that produced the crops.

    Args:
        cropped_dir: Directory containing the cropped face ``.png`` images.
        output_file: Destination CSV path (overwritten if it exists; parent
            directories are created as needed).
        pipeline: Object exposing ``step(frame)`` whose result has ``.pitch`` and
            ``.yaw`` sequences. Defaults to the notebook-level ``gaze_pipeline``.

    Returns:
        The DataFrame that was written, sorted by ``frame`` then ``tracked_id``.

    Raises:
        FileNotFoundError: If ``cropped_dir`` is not an existing directory.
        OSError: If OpenCV cannot read one of the images.
        ValueError: If a file name's frame / tracked-id fields are not integers.
    """
    if pipeline is None:
        pipeline = gaze_pipeline
    cropped_dir = pathlib.Path(cropped_dir)
    output_file = pathlib.Path(output_file)

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not cropped_dir.is_dir():
        raise FileNotFoundError(f"Cropped-face directory not found: {cropped_dir}")

    # List the directory once and keep only the images that will be processed, so the
    # progress-bar total is accurate; sorting gives a deterministic processing order.
    image_paths = sorted(p for p in cropped_dir.iterdir() if p.suffix == '.png')

    output_container = {"frame": [], "tracked_id": [], "pitch": [], "yaw": []}

    for img_fp in tqdm(image_paths, desc=cropped_dir.name):
        frame = cv2.imread(str(img_fp))
        # cv2.imread signals an unreadable/corrupt file by returning None, not by raising.
        if frame is None:
            raise OSError(f"Could not read image: {img_fp}")

        results = pipeline.step(frame)
        # Only the first prediction is kept. NOTE(review): assumes the pipeline yields
        # one prediction per crop (it is built with include_detector=False) — confirm.
        pitch, yaw = results.pitch[0], results.yaw[0]

        fields = img_fp.stem.split("_")
        output_container['frame'].append(int(fields[1]))
        output_container['tracked_id'].append(int(fields[-1]))
        output_container['pitch'].append(pitch)
        output_container['yaw'].append(yaw)

    df = (
        pd.DataFrame(output_container)
        .sort_values(by=['frame', 'tracked_id'])
        .reset_index(drop=True)
    )
    output_file.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_file, index=False)
    return df

# Dataset subsets to process: each has its own folder of cropped faces under
# reid/cropped_faces and gets its own gaze CSV in OUTPUT_DIR.
SUBSETS = ('d1g1', 'd1g2', 'd2g1', 'd2g2')

for subset in SUBSETS:
    process(
        DATA_DIR / 'reid' / 'cropped_faces' / subset,
        OUTPUT_DIR / f'gaze_vector_{subset}.csv',
    )

100%|██████████| 52141/52141 [10:51<00:00, 80.08it/s] 
100%|██████████| 50882/50882 [08:34<00:00, 98.96it/s] 
100%|██████████| 43030/43030 [06:21<00:00, 112.67it/s]
100%|██████████| 57104/57104 [08:33<00:00, 111.25it/s]
