In [None]:
!pip install retina-face

Collecting retina-face
  Downloading retina_face-0.0.17-py3-none-any.whl.metadata (10 kB)
Downloading retina_face-0.0.17-py3-none-any.whl (25 kB)
Installing collected packages: retina-face
Successfully installed retina-face-0.0.17


In [None]:
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from retinaface import RetinaFace
from tqdm import tqdm
import os

class VideoGazeAnalyzer:
    """Annotate video frames with detected faces and their estimated gaze targets.

    Faces are located with RetinaFace; gaze heatmaps and in-frame scores come
    from the Gazelle model loaded through ``torch.hub``.
    """

    def __init__(self, use_cuda=True):
        """Load the Gazelle model and move it to the selected device.

        Args:
            use_cuda: Run on the GPU when ``torch.cuda.is_available()``;
                otherwise (or when False) the model runs on the CPU.
        """
        self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load Gazelle model together with its matching input transform
        self.model, self.transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
        self.model.eval()
        self.model.to(self.device)

        # Colors for visualization, cycled per detected face
        self.colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']

    def process_frame(self, frame):
        """Process a single frame and return the visualization.

        Args:
            frame: Image array in OpenCV's BGR channel order.

        Returns:
            A BGR array with face boxes, in-frame scores and gaze lines drawn
            on it, or ``frame`` itself (unchanged) when no face is detected.
        """
        # Convert BGR to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb)
        width, height = image.size

        # Detect faces
        # NOTE(review): detection runs on the RGB array; RetinaFace may expect
        # OpenCV-style BGR input -- confirm against the retina-face docs.
        resp = RetinaFace.detect_faces(frame_rgb)
        # A non-dict result or an empty dict both mean "no faces": bail out
        # before the model is called with zero bounding boxes.
        if not isinstance(resp, dict) or not resp:
            return frame

        # Extract [xmin, ymin, xmax, ymax] pixel boxes and scale them to 0-1
        bboxes = [face['facial_area'] for face in resp.values()]
        scale = np.array([width, height, width, height])
        norm_bboxes = [[np.array(bbox) / scale for bbox in bboxes]]

        # Prepare input for Gazelle (batch of one image)
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        input_data = {
            "images": img_tensor,
            "bboxes": norm_bboxes
        }

        # Get model predictions
        with torch.no_grad():
            output = self.model(input_data)

        # Visualize results
        result_image = self.visualize_all(
            image,
            output['heatmap'][0],
            norm_bboxes[0],
            output['inout'][0] if output['inout'] is not None else None
        )

        # Convert back to BGR for OpenCV
        result_array = np.array(result_image)
        return cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

    def visualize_all(self, pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        """Visualize all detected faces and their gaze directions.

        Args:
            pil_image: PIL image to draw on (a converted copy is drawn, the
                input is not modified).
            heatmaps: Per-face gaze heatmaps, indexed like ``bboxes``.
            bboxes: Per-face ``(xmin, ymin, xmax, ymax)`` boxes normalized to 0-1.
            inout_scores: Per-face in-frame scores, or ``None``; gaze lines are
                only drawn when scores are available.
            inout_thresh: Minimum in-frame score for drawing the gaze target.

        Returns:
            An RGB PIL image with the annotations drawn.
        """
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size

        # Stroke widths scale with the frame size but never drop to 0, which
        # would make the face box invisible on small frames.
        short_side = min(width, height)
        box_width = max(1, int(short_side * 0.01))
        line_width = max(1, int(0.005 * short_side))

        for i, bbox in enumerate(bboxes):
            xmin, ymin, xmax, ymax = bbox
            color = self.colors[i % len(self.colors)]

            # Draw face bounding box
            draw.rectangle(
                [xmin * width, ymin * height, xmax * width, ymax * height],
                outline=color,
                width=box_width
            )

            if inout_scores is None:
                continue
            inout_score = inout_scores[i]

            # Draw in-frame score just below the box
            text = f"in-frame: {inout_score:.2f}"
            text_y = ymax * height + int(height * 0.01)
            draw.text(
                (xmin * width, text_y),
                text,
                fill=color,
                font=None  # Using default font
            )

            # Draw gaze direction only if looking in-frame
            if inout_score > inout_thresh:
                heatmap_np = heatmaps[i].detach().cpu().numpy()
                max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)

                # Gaze target = heatmap peak rescaled to image pixels
                gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
                gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                # Draw gaze target point and a line from the face center to it
                draw.ellipse(
                    [(gaze_target_x - 5, gaze_target_y - 5),
                     (gaze_target_x + 5, gaze_target_y + 5)],
                    fill=color,
                    width=line_width
                )
                draw.line(
                    [(bbox_center_x, bbox_center_y),
                     (gaze_target_x, gaze_target_y)],
                    fill=color,
                    width=line_width
                )

        # Convert to RGB for OpenCV compatibility
        return overlay_image.convert('RGB')

    def process_video(self, input_path, output_path, start_time=0, duration=None):
        """Process a video file and save the annotated result.

        Args:
            input_path: Path of the video to read.
            output_path: Path of the ``mp4v``-encoded video to write.
            start_time: Offset in seconds at which processing starts.
            duration: Number of seconds to process; a falsy value (``None`` or
                ``0``) processes up to the reported end of the video.

        Raises:
            OSError: If the input cannot be opened or the writer cannot be created.
            ValueError: If the input reports a non-positive frame rate.
        """
        cap = cv2.VideoCapture(input_path)
        out = None
        try:
            if not cap.isOpened():
                raise OSError(f"Could not open input video: {input_path}")

            # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
            # shift the start/end frames and change the output playback speed.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:  # also rejects NaN
                raise ValueError(f"Invalid frame rate ({fps}) reported for: {input_path}")
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Calculate start and end frames
            start_frame = max(0, int(start_time * fps))
            if duration:
                end_frame = start_frame + int(duration * fps)
                # Do not promise more frames than the file reports
                if total_frames > 0:
                    end_frame = min(end_frame, total_frames)
            else:
                end_frame = total_frames

            # Set up video writer
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(
                output_path,
                fourcc,
                fps,
                (frame_width, frame_height)
            )
            if not out.isOpened():
                raise OSError(f"Could not open output video for writing: {output_path}")

            # Seek to start frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            with tqdm(total=max(0, end_frame - start_frame)) as pbar:
                for _ in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break  # ran out of frames early

                    out.write(self.process_frame(frame))
                    pbar.update(1)
        finally:
            # Release both handles even when processing fails part-way.
            # No GUI window is opened here, so cv2.destroyAllWindows() is not
            # called (it can raise on headless OpenCV builds).
            cap.release()
            if out is not None:
                out.release()

In [None]:
# Source clip and destination for the annotated copy.
# Point input_video at your own file before running.
output_video = "output_video.mp4"
input_video = "/content/movie.mp4"

# Build the analyzer once; it is reused by the processing cell below.
analyzer = VideoGazeAnalyzer()


Using device: cuda


Downloading: "https://github.com/fkryan/gazelle/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth
100%|██████████| 1.13G/1.13G [00:06<00:00, 194MB/s]
Downloading: "https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14_inout.pt" to /root/.cache/torch/hub/checkpoints/gazelle_dinov2_vitl14_inout.pt
100%|██████████| 12.2M/12.2M [00:00<00:00, 46.0MB/s]


In [None]:
# Annotate a 10-second window that begins 5 seconds into the clip.
analyzer.process_video(
    input_path=input_video,
    output_path=output_video,
    start_time=5,
    duration=10,
)

  0%|          | 0/250 [00:00<?, ?it/s]

24-12-22 20:59:29 - Directory /root/.deepface created
24-12-22 20:59:29 - Directory /root/.deepface/weights created
24-12-22 20:59:29 - retinaface.h5 will be downloaded from the url https://github.com/serengil/deepface_models/releases/download/v1.0/retinaface.h5


Downloading...
From: https://github.com/serengil/deepface_models/releases/download/v1.0/retinaface.h5
To: /root/.deepface/weights/retinaface.h5

  0%|          | 0.00/119M [00:00<?, ?B/s][A
 39%|███▉      | 46.7M/119M [00:00<00:00, 465MB/s][A
100%|██████████| 119M/119M [00:00<00:00, 429MB/s] 
  1%|          | 2/250 [00:13<27:27,  6.64s/it]
