RetinaFace (face detection): https://github.com/serengil/retinaface

Gaze-LLE (gaze target estimation): https://github.com/fkryan/gazelle

Install the face detector with: uv pip install retina-face

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import numpy as np
import cv2
from tqdm import tqdm
from deepface.modules.detection import extract_faces, DetectedFace, FacialAreaRegion, is_valid_landmark

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Load the Gaze-LLE model together with its matching input transform from the
# 'fkryan/gazelle' repository through torch.hub.
# NOTE(review): torch.hub.load fetches and runs code from that repository --
# confirm the source is trusted and available in the target environment.
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
# Put the model in evaluation mode; it is only used for inference here.
model.eval()
# Move the weights to the selected device (last expression, so the notebook
# also displays the returned module).
model.to(device)

In [42]:
# Palette cycled through (by index modulo its length) so each detected face
# gets its own drawing colour.
colors = [
    'yellow',
    'red',
    'green',
    'blue',
    'lime',
]

In [59]:
from retinaface import RetinaFace

def process_frame(frame):
    """Annotate one video frame with detected faces and predicted gaze targets.

    Faces are located with RetinaFace, their boxes are normalised by the
    frame width/height and passed to the Gaze-LLE model together with the
    transformed image; the prediction is then drawn by ``viz_all``.

    Args:
        frame: Image array in OpenCV channel order (BGR), e.g. a frame
            returned by ``cv2.VideoCapture.read``.

    Returns:
        A 3-channel BGR image array with the overlay drawn on it. When no
        face is detected the input ``frame`` is returned unchanged.
    """
    # RetinaFace handles ndarray input like an image loaded by cv2.imread
    # (BGR), so it receives the original frame rather than the RGB copy.
    faces = RetinaFace.detect_faces(frame)
    # Older releases return a non-dict when nothing is found, newer ones may
    # return an empty dict; in both cases there is nothing to run the gaze
    # model on, so skip it.
    if not isinstance(faces, dict) or not faces:
        return frame

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(frame_rgb)
    width, height = img.size

    # Boxes are [x1, y1, x2, y2] in pixels; divide by the image size to get
    # the fractional coordinates used by the model and by viz_all. The outer
    # list is the batch dimension (a single image here).
    scale = np.array([width, height, width, height])
    norm_bboxes = [[np.array(face['facial_area']) / scale for face in faces.values()]]

    img_t = transform(img).unsqueeze(0).to(device)

    input_data = {
        'images': img_t,
        'bboxes': norm_bboxes,
    }

    with torch.no_grad():
        output = model(input_data)

    inout = output["inout"]
    res_img = viz_all(
        img,
        output["heatmap"][0],
        norm_bboxes[0],
        inout[0] if inout is not None else None,
        0.5,
    )

    # viz_all returns an RGBA image; drop the alpha channel explicitly so the
    # RGB->BGR conversion always receives a 3-channel array.
    res_arr = np.array(res_img.convert("RGB"))
    return cv2.cvtColor(res_arr, cv2.COLOR_RGB2BGR)

def process_video(input_path, output_path):
    """Run ``process_frame`` over every frame of a video and save the result.

    The output is encoded as 'mp4v' at the source frame rate and at half the
    source width and height.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the annotated video to write.

    Raises:
        OSError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    out = None

    try:
        if not cap.isOpened():
            raise OSError(f"Cannot open input video: {input_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2
        out_size = (width, height)

        # Created inside the try block so the capture is still released if
        # constructing the writer fails.
        out = cv2.VideoWriter(output_path, fourcc, fps, out_size)

        with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Processing Video") as pbar:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                processed_frame = process_frame(frame)
                # The writer only accepts frames of exactly the size it was
                # opened with; process_frame returns full-resolution frames,
                # so scale them down to the half-size output first.
                if (processed_frame.shape[1], processed_frame.shape[0]) != out_size:
                    processed_frame = cv2.resize(
                        processed_frame, out_size, interpolation=cv2.INTER_AREA
                    )
                out.write(processed_frame)
                pbar.update(1)
    finally:
        cap.release()
        if out is not None:
            out.release()
        
def viz_all(pil_img, heatmaps, bboxes, inout_scores, inout_threshold):
    """Draw face boxes, in/out scores and gaze targets on a copy of an image.

    Args:
        pil_img: PIL image to annotate. It is converted to RGBA and the
            converted copy is drawn on.
        heatmaps: Per-face gaze heatmaps, indexed in step with ``bboxes``.
            Each one is moved to the CPU and converted to a NumPy array.
        bboxes: Sequence of ``(xmin, ymin, xmax, ymax)`` boxes expressed as
            fractions of the image width and height.
        inout_scores: Per-face scores indexed in step with ``bboxes``, or
            ``None``. With ``None`` only the boxes are drawn.
        inout_threshold: A gaze point and line are drawn only for faces whose
            score is strictly greater than this value.

    Returns:
        The annotated RGBA image.
    """
    canvas = pil_img.convert("RGBA")
    pen = ImageDraw.Draw(canvas)
    img_w, img_h = pil_img.size

    for idx, (left, top, right, bottom) in enumerate(bboxes):
        # Cycle through the palette so neighbouring faces differ in colour.
        shade = colors[idx % len(colors)]

        box_px = [left * img_w, top * img_h, right * img_w, bottom * img_h]
        pen.rectangle(box_px, outline=shade, width=2)

        # Without scores there is nothing more to draw for this face.
        if inout_scores is None:
            continue

        score = inout_scores[idx].item()
        # Place the label slightly below the bottom edge of the box.
        label_y = bottom * img_h + int(img_h / 100)
        pen.text((left * img_w, label_y), f"frame: {score:.2f}", fill=shade)

        if score > inout_threshold:
            gaze_map = heatmaps[idx].cpu().numpy()
            # Location of the heatmap maximum, rescaled to image pixels.
            peak = np.unravel_index(np.argmax(gaze_map), gaze_map.shape)
            target_x = peak[1] / gaze_map.shape[1] * img_w
            target_y = peak[0] / gaze_map.shape[0] * img_h
            face_cx = (left + right) / 2 * img_w
            face_cy = (top + bottom) / 2 * img_h

            pen.ellipse(
                [(target_x - 5, target_y - 5), (target_x + 5, target_y + 5)],
                fill=shade,
                width=2,
            )
            pen.line([(face_cx, face_cy), (target_x, target_y)], fill=shade, width=2)

    return canvas
    
    
    

In [60]:
# Source clip and the file the annotated result is written to.
# Alternative clip: 'people_gallery.mp4' -> 'people_gallery_output.mp4'
input_video, output_video = 'museum2.mp4', 'museum2_output.mp4'

In [61]:
# Annotate the selected clip and write the result next to it.
process_video(input_path=input_video, output_path=output_video)

Processing Video: 100%|██████████| 603/603 [05:22<00:00,  1.87it/s]
