In [None]:
import sys
import os
import cv2
import math
from glob import glob

import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from ultralytics import YOLO
from transformers import AutoTokenizer, CLIPTextModelWithProjection

from semgaze.modeling.semgaze import SemGaze
from semgaze.utils.common import dark_coordinate_decoding, square_bbox

%config InlineBackend.figure_format = "retina"
import matplotlib.pyplot as plt
import matplotlib.cm as cm


In [None]:
# Globals: detection threshold, input normalization statistics and drawing palette
DET_TRH = 0.4  # head detections with confidence strictly above this are kept (see predict_gaze)
IMG_MEAN = [0.44232, 0.40506, 0.36457]  # per-channel RGB mean passed to TF.normalize
IMG_STD = [0.28674, 0.27776, 0.27995]  # per-channel RGB std passed to TF.normalize

# Drawing palette: COLOR_NAMES[i] is the name of the RGB triple COLORS[i]
COLOR_NAMES, COLORS = map(list, zip(
    ("mediumvioletred", (199, 21, 133)),
    ("green", (0, 128, 0)),
    ("dodgerblue", (30, 144, 255)),
    ("crimson", (220, 20, 60)),
    ("goldenrod", (218, 165, 32)),
    ("DarkSlateGray", (47, 79, 79)),
    ("saddlebrown", (139, 69, 19)),
    ("purple", (128, 0, 128)),
    ("teal", (0, 128, 128)),
))

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
def get_num_params(module):
    """Return the total number of scalar elements over all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total


def expand_bbox(bbox, img_w, img_h, k=0.1):
    """
    Grow an xyxy bounding box on every side and clip it to the image.

    Each side moves outwards by ``k`` times the box width (x sides) or height (y sides).
    The result is clamped to ``[0, img_w] x [0, img_h]``.

    The input box is NOT modified. A copy is expanded and returned, so passing a row of
    e.g. a detector's output tensor leaves that tensor intact.

    Args:
        bbox: (xmin, ymin, xmax, ymax) as a torch tensor, NumPy array or sequence.
        img_w: image width used for clipping.
        img_h: image height used for clipping.
        k: expansion ratio relative to the box width/height. Defaults to 0.1.

    Returns:
        The expanded box: a tensor for tensor input (``clone()``), an array for array
        input (``copy()``), a list otherwise.
    """
    if isinstance(bbox, torch.Tensor):
        out = bbox.clone()
    elif isinstance(bbox, np.ndarray):
        out = bbox.copy()
    else:
        out = list(bbox)

    w, h = out[2] - out[0], out[3] - out[1]
    out[0] = max(0, out[0] - k * w)
    out[1] = max(0, out[1] - k * h)
    out[2] = min(img_w, out[2] + k * w)
    out[3] = min(img_h, out[3] + k * h)
    return out


def load_head_detection_model(device, ckpt_path="weights/yolo11m_merged.torchscript"):
    """
    Load the pre-trained YOLO head detector.

    Args:
        device: accepted for symmetry with the other loaders but unused here. The model is
            constructed without an explicit device.
        ckpt_path: path to the exported detector weights. Defaults to the TorchScript
            export previously hard-coded here, so existing calls behave the same.

    Returns:
        The ultralytics ``YOLO`` model, configured for the "detect" task.
    """
    return YOLO(ckpt_path, task="detect")

def load_text_model(device, model_name="openai/clip-vit-base-patch16"):
    """
    Load a CLIP text encoder (with projection head) and its matching tokenizer.

    Args:
        device: accepted for symmetry with the other loaders but unused. The model is not
            moved anywhere here, which matches the notebook feeding it tokenizer outputs
            that are also left on the default (CPU) device.
        model_name: Hugging Face model id or local path, used for BOTH the model and the
            tokenizer so the two can never get out of sync.

    Returns:
        (text_model, text_tokenizer): the encoder in eval mode and its tokenizer.
    """
    text_model = CLIPTextModelWithProjection.from_pretrained(model_name)
    text_model.eval()
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return text_model, text_tokenizer

def load_semgaze_model(ckpt_path, device):
    """
    Build SemGaze, restore its weights from ``ckpt_path`` and prepare it for inference.

    Args:
        ckpt_path: checkpoint file with a ``"state_dict"`` entry. That entry holds a
            ``"logit_scale"`` value plus the model weights. Each weight key has its first
            6 characters stripped before loading (presumably a ``"model."`` prefix from the
            training wrapper; confirm against the training code).
        device: device the model is moved to.

    Returns:
        (semgaze, logit_scale): the model in eval mode on ``device``, and the checkpoint's
        ``logit_scale`` value as stored (loaded onto the CPU, not exponentiated).
    """
    semgaze = SemGaze(
        image_size=256,
        patch_size=16,
        token_dim=768,
        gaze_vec_dim=2,
        encoder_num_heads=12,
        encoder_depth=12,
        encoder_num_global_tokens=1,
        decoder_depth=2,
        decoder_num_heads=8,
        decoder_label_emb_dim=512,
    )

    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    full_state = checkpoint["state_dict"]
    logit_scale = full_state["logit_scale"]

    model_state = {}
    for key, tensor in full_state.items():
        if key == "logit_scale":
            continue  # not a module parameter of SemGaze; returned separately
        model_state[key[6:]] = tensor

    semgaze.load_state_dict(model_state, strict=True)
    semgaze.to(device).eval()
    return semgaze, logit_scale


def _to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and copied to the CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.array(x)


def draw_gaze(
    image,
    head_bboxes,
    gaze_points,
    gaze_vecs,
    inouts,
    pids,
    gaze_heatmaps,
    heatmap_pid=None,
    frame_nb=None,
    colors=COLORS,
    alpha: float = 0.5,
    io_thr: float = 0.4,
    gaze_pt_size: int = 10,
    gaze_vec_factor: float = 0.8,
    head_center_size: int = 10,
    thickness: int = 4,
    fs: float = 0.6,
):
    """
    Draws gaze results on the given image.

    Args:
        image (np.ndarray): (H, W, C) image to draw on. It is not modified; a new canvas is returned.
        head_bboxes (array-like): one xyxy box per head, in pixels or normalized to [0, 1]
            (boxes whose maximum value is <= 1 are treated as normalized).
        gaze_points (array-like | None): one (x, y) gaze point per head, in pixels or
            normalized to [0, 1]. None or empty to skip drawing gaze points.
        gaze_vecs (array-like | None): one 2D gaze direction per head, drawn as an arrow from
            the head center. None or empty to skip.
        inouts (array-like | None): one in/out score per head. Required when gaze points are
            given. When None/empty, the header shows "-" instead of a score.
        pids (array-like): one person ID per head, used for colors, labels and ``heatmap_pid``.
        gaze_heatmaps (torch.Tensor): one heatmap per head (indexed like ``pids``); only used
            when ``heatmap_pid`` is set.
        heatmap_pid (int, optional): person ID whose heatmap is overlaid. When set, only that
            person is annotated. Defaults to None.
        frame_nb (int, optional): frame number written at the bottom of the image. Defaults to None.
        colors (sequence of color tuples, optional): palette indexed by ``pid % len(colors)``.
            Defaults to COLORS.
        alpha (float, optional): blending weight of the heatmap overlay. Defaults to 0.5.
        io_thr (float, optional): a gaze point is drawn only if its in/out score is > io_thr.
            Defaults to 0.4.
        gaze_pt_size (int, optional): radius of the gaze points (at 1920 px). Defaults to 10.
        gaze_vec_factor (float, optional): gaze arrow length relative to the head radius.
            Defaults to 0.8.
        head_center_size (int, optional): size of the head center marker (at 1920 px). Defaults to 10.
        thickness (int, optional): line thickness (at 1920 px). Defaults to 4.
        fs (float, optional): font scale (at 1920 px). Defaults to 0.6.

    Returns:
        np.ndarray: The image with gaze results drawn on it.

    Raises:
        ValueError: if ``heatmap_pid`` is given without ``gaze_heatmaps``, head boxes are
            given without ``pids``, or gaze points are given without ``inouts``.
    """
    img_h, img_w = image.shape[:2]
    canvas = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    white = (255, 255, 255)

    # Scale drawing sizes relative to a 1920-px reference resolution. Line thickness is
    # clamped to >= 1 because OpenCV rejects a thickness of 0 for lines/outlines, which
    # small images would otherwise produce.
    scale = max(img_h, img_w) / 1920
    fs *= scale
    thickness = max(1, int(scale * thickness))
    gaze_pt_size = int(scale * gaze_pt_size)
    head_center_size = int(scale * head_center_size)

    # NumPy view of the person IDs so comparisons work for lists, arrays and CPU tensors alike.
    pid_array = np.asarray(pids if pids is not None else [])

    # Draw heatmap
    if heatmap_pid is not None:
        if gaze_heatmaps is None or len(gaze_heatmaps) == 0:
            raise ValueError("gaze_heatmaps must be provided if heatmap_pid is provided.")
        matches = np.flatnonzero(pid_array == heatmap_pid)
        if len(matches) == 1:  # only if exactly one detection carries this pid
            j = int(matches[0])
            heatmap = TF.resize(gaze_heatmaps[j:j + 1], (img_h, img_w), antialias=True)
            heatmap = heatmap.squeeze().detach().cpu().numpy()
            # Min-max normalize to [0, 1]; a constant heatmap would otherwise divide by zero.
            hm_min, hm_max = heatmap.min(), heatmap.max()
            if hm_max > hm_min:
                heatmap = (heatmap - hm_min) / (hm_max - hm_min)
            else:
                heatmap = np.zeros_like(heatmap)
            heatmap = cm.inferno(heatmap) * 255
            canvas = ((1 - alpha) * image + alpha * heatmap[..., :3]).astype(np.uint8)

            # Write the pid being used for the heatmap in the bottom-right corner
            hm_pid_text = f"Heatmap PID: {heatmap_pid}"
            (w_text, h_text), _ = cv2.getTextSize(hm_pid_text, font, fs, 1)
            cv2.rectangle(canvas, (img_w - w_text - 20, img_h - h_text - 15), (img_w, img_h), (0, 0, 0), -1)
            cv2.putText(canvas, hm_pid_text, (img_w - w_text - 10, img_h - 10), font, fs, white, 1, cv2.LINE_AA)

    # Draw head bboxes
    if head_bboxes is not None and len(head_bboxes) > 0:
        if len(pid_array) == 0:
            raise ValueError("pids must be provided if head_bboxes is provided.")

        head_bboxes = _to_numpy(head_bboxes)
        if head_bboxes.max() <= 1.0:  # normalized boxes -> pixels
            head_bboxes = head_bboxes * np.array([img_w, img_h, img_w, img_h])
        head_bboxes = head_bboxes.astype(int)
        head_centers = np.stack(
            [(head_bboxes[:, 0] + head_bboxes[:, 2]) / 2, (head_bboxes[:, 1] + head_bboxes[:, 3]) / 2],
            axis=1,
        ).astype(int)

        has_inouts = inouts is not None and len(inouts) > 0
        if has_inouts:
            inouts = _to_numpy(inouts)

        gaze_available = gaze_points is not None and len(gaze_points) > 0
        if gaze_available and not has_inouts:
            raise ValueError("inouts must be provided if gaze_pts is provided.")
        if gaze_available:
            gaze_points = _to_numpy(gaze_points)
            if gaze_points.max() <= 1.0:  # normalized points -> pixels
                gaze_points = gaze_points * np.array([img_w, img_h])
            gaze_points = gaze_points.astype(int)

        has_gaze_vecs = gaze_vecs is not None and len(gaze_vecs) > 0
        if has_gaze_vecs:
            gaze_vecs = _to_numpy(gaze_vecs)

        for i, head_bbox in enumerate(head_bboxes):
            pid = pid_array[i]
            # With a heatmap overlay, annotate only that person. Compare against the person
            # ID (not the loop index), consistent with the heatmap selection above.
            if heatmap_pid is not None and pid != heatmap_pid:
                continue

            xmin, ymin, xmax, ymax = (int(v) for v in head_bbox)
            head_radius = max(xmax - xmin, ymax - ymin) // 2
            color = colors[int(pid) % len(colors)]
            cx, cy = (int(v) for v in head_centers[i])

            # Head center marker and head circle
            half = head_center_size // 2
            cv2.rectangle(canvas, (cx - half, cy - half), (cx + half, cy + half), color, -1)
            cv2.circle(canvas, (cx, cy), head_radius, color, thickness)

            # Header "P<pid>: <in/out score>" on a filled box at the top of the head
            io = float(inouts[i]) if has_inouts else None
            header_text = f"P{pid}: {io:.2f}" if io is not None else f"P{pid}: -"
            (w_text, h_text), _ = cv2.getTextSize(header_text, font, fs, 1)
            header_ul = (int(cx - w_text / 2), int(ymin - thickness / 2))
            header_br = (int(cx + w_text / 2), int(ymin + h_text + 5))
            cv2.rectangle(canvas, header_ul, header_br, color, -1)
            cv2.putText(canvas, header_text, (header_ul[0], int(ymin + h_text)), font, fs, white, 1, cv2.LINE_AA)

            if gaze_available and io is not None and io > io_thr:
                gx, gy = (int(v) for v in gaze_points[i])
                # Start the gaze line on the head circle, pointing towards the gaze point.
                direction = np.array([gx - cx, gy - cy], dtype=float)
                direction /= np.linalg.norm(direction) + 1e-6
                start = (cx + int(direction[0] * head_radius), cy + int(direction[1] * head_radius))
                cv2.line(canvas, start, (gx, gy), color, thickness)
                cv2.circle(canvas, (gx, gy), gaze_pt_size, color, -1)

            if has_gaze_vecs:
                gv = gaze_vecs[i]
                tip = (
                    int(cx + gaze_vec_factor * head_radius * gv[0]),
                    int(cy + gaze_vec_factor * head_radius * gv[1]),
                )
                cv2.arrowedLine(canvas, (cx, cy), tip, color, thickness)

    # Write frame number centered at the bottom
    if frame_nb is not None:
        frame_text = str(frame_nb)
        (w_text, h_text), _ = cv2.getTextSize(frame_text, font, fs, 1)
        nb_ul = (int((img_w - w_text) / 2), img_h - h_text - 15)
        nb_br = (int((img_w + w_text) / 2), img_h)
        cv2.rectangle(canvas, nb_ul, nb_br, (0, 0, 0), -1)
        cv2.putText(canvas, frame_text, (int((img_w - w_text) / 2), img_h - 10), font, fs, white, 1, cv2.LINE_AA)

    return canvas



def predict_gaze(img_path, semgaze, head_detector):
    """
    Detect heads in an image and run SemGaze for every detected head.

    Args:
        img_path: path of an image readable by PIL (converted to RGB).
        semgaze: SemGaze model, called with a dict holding "image", "heads" and "head_bboxes".
        head_detector: ultralytics detector; detections with confidence > DET_TRH are kept.

    Returns:
        (image_np, head_bboxes, gaze_points, gaze_vecs, gaze_heatmaps, gaze_label_embs):
            image_np: the RGB image as a NumPy array.
            head_bboxes: kept detector boxes (xyxy, pixel coordinates) on the CPU.
            gaze_points: output of dark_coordinate_decoding(..., normalize=True) on the
                heatmaps (presumably normalized (x, y) per person; confirm).
            gaze_vecs, gaze_heatmaps, gaze_label_embs: model outputs with dim 1 squeezed, on the CPU.

    Raises:
        RuntimeError: if no head is detected with confidence > DET_TRH.
    """
    # 1. Read image
    image = Image.open(img_path).convert("RGB")
    image_np = np.array(image)
    img_h, img_w = image_np.shape[:2]

    # 2. Detect heads and keep confident detections
    detections = head_detector(image, verbose=False)[0]
    keep = detections.boxes.conf > DET_TRH
    head_bboxes = detections.boxes.xyxy[keep].cpu()
    num_heads = len(head_bboxes)
    if num_heads == 0:
        # Without this check the pipeline would fail later on an empty batch with an obscure error.
        raise RuntimeError(f"No head detected in {img_path} with confidence > {DET_TRH}.")
    t_head_bboxes = square_bbox(head_bboxes, img_w, img_h)
    print(f"Detected {num_heads} heads.")

    # 3. Extract, resize and normalize each head crop
    heads = torch.stack([
        TF.resize(TF.to_tensor(image.crop(bbox.numpy())), (224, 224))
        for bbox in t_head_bboxes
    ])
    heads = TF.normalize(heads, mean=IMG_MEAN, std=IMG_STD)

    # 4. Transform the full image
    image_t = TF.to_tensor(image)
    image_t = TF.resize(image_t, (256, 256))
    image_t = TF.normalize(image_t, mean=IMG_MEAN, std=IMG_STD)

    # 5. Normalize head bboxes to [0, 1]. Out-of-place so that the tensor returned by
    #    square_bbox (and anything it may alias, e.g. head_bboxes) is never modified.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    t_head_bboxes = t_head_bboxes / scale

    # 6. Build the input sample: the image is repeated once per head (one batch item per head)
    sample = {
        "image": image_t.unsqueeze(0).expand(num_heads, -1, -1, -1).to(device),  # (num_heads, 3, 256, 256)
        "heads": heads.unsqueeze(1).to(device),  # (num_heads, 1, 3, 224, 224)
        "head_bboxes": t_head_bboxes.unsqueeze(1).to(device),  # (num_heads, 1, 4)
    }

    # 7. Predict gaze
    with torch.no_grad():
        gaze_heatmaps, gaze_vecs, gaze_label_embs = semgaze(sample)
        gaze_heatmaps = gaze_heatmaps.squeeze(1).cpu()
        gaze_vecs = gaze_vecs.squeeze(1).cpu()
        gaze_points = dark_coordinate_decoding(gaze_heatmaps, kernel_size=9, normalize=True)
        gaze_label_embs = gaze_label_embs.squeeze(1).cpu()

    return image_np, head_bboxes, gaze_points, gaze_vecs, gaze_heatmaps, gaze_label_embs



In [None]:
# Load the head detector and display its number of parameters
head_detector = load_head_detection_model(device=device)
head_detector_num_params = get_num_params(module=head_detector)
head_detector_num_params

In [None]:
# Load SemGaze from the GazeFollow checkpoint and display its number of parameters
ckpt_path = "checkpoints/gazefollow.ckpt"
semgaze, logit_scale = load_semgaze_model(ckpt_path=ckpt_path, device=device)
semgaze_num_params = get_num_params(module=semgaze)
semgaze_num_params

In [None]:
# Load the CLIP text encoder used to embed the vocabulary
text_model, text_tokenizer = load_text_model(device)

# Candidate gaze-target classes (COCO classes here, but any list of class names works)
VOCAB = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
    'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
    'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
    'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Embed every class name on its own, then L2-normalize the stacked embeddings
with torch.no_grad():
    vocab_embs = [
        text_model(**text_tokenizer(class_name, return_tensors="pt")).text_embeds.squeeze(0)
        for class_name in VOCAB
    ]
vocab_embs = F.normalize(torch.stack(vocab_embs), p=2, dim=-1)
vocab_embs.shape

In [None]:
# Uncomment and run this cell once only if plots are not displayed inline (recent Jupyter versions render matplotlib figures inline by default)
#%matplotlib inline 

In [None]:
# Run head detection + gaze prediction on one image (point img_path at any image you want to test)
img_path = "samples/image.jpg"
output = predict_gaze(img_path, semgaze, head_detector)
(image_np, head_bboxes, gaze_points,
 gaze_vecs, gaze_heatmaps, gaze_label_embs) = output

# Score each person's gaze-label embedding against every vocabulary embedding
# (scaled by exp(logit_scale)) and turn each row of scores into class probabilities.
gaze_label_logits = logit_scale.data.exp() * (gaze_label_embs @ vocab_embs.T)
gaze_label_probs = torch.softmax(gaze_label_logits, dim=1)

In [None]:
def probs_to_labels(probs, vocab, top_k=3):
    """Map each row of class probabilities to its ``top_k`` most likely vocabulary labels.

    Args:
        probs: 2D tensor with one row per person and one column per vocabulary entry
            (presumably softmax outputs).
        vocab: class names, indexed like the columns of ``probs``.
        top_k: number of labels kept per row (all of them if the row is shorter).

    Returns:
        dict mapping row index -> list of (label, probability rounded to 3 decimals),
        most likely first.
    """
    def _top_labels(row):
        sorted_vals, order = torch.sort(row, descending=True)
        return [(vocab[j], round(p.item(), 3)) for p, j in zip(sorted_vals[:top_k], order[:top_k])]

    return {pid: _top_labels(row) for pid, row in enumerate(probs)}

probs_to_labels(probs=gaze_label_probs, vocab=VOCAB, top_k=3)

In [None]:
# Visualize: panel 0 is the input image, panel 1 shows every prediction (no heatmap),
# and each following panel overlays the gaze heatmap of one person.
show_gaze_vec = True
alpha = 0.7
fs = 1.
thickness = 10
gaze_pt_size = 20
head_center_size = 18
gaze_vec_factor = 0.6

img_h, img_w = image_np.shape[:2]
num_people = len(head_bboxes)
pids = np.arange(num_people)

num_axes = 2 + num_people
ncols = 2
nrows = math.ceil(num_axes / ncols)
fig_w = 20
ax_w = fig_w // ncols
# Keep each panel's aspect ratio close to the image's; never let the figure height collapse to 0.
ax_h = max(1, int(round(ax_w * img_h / img_w)))

fig, axes = plt.subplots(figsize=(fig_w, ax_h * nrows), nrows=nrows, ncols=ncols, tight_layout=True)
axes = axes.flatten()
for ax in axes:
    ax.axis("off")

# Show input image
axes[0].imshow(image_np)
axes[0].set_title("Input image")

# heatmap_pid=None draws all predictions without a heatmap; an integer pid overlays that person's heatmap
for i, heatmap_pid in enumerate([None, *pids.tolist()], start=1):
    frame = draw_gaze(
        image_np,
        head_bboxes=head_bboxes,
        gaze_points=gaze_points,
        gaze_vecs=gaze_vecs if show_gaze_vec else None,
        inouts=np.ones(num_people),  # no in/out scores are predicted here: treat everyone as "in"
        pids=pids,
        gaze_heatmaps=gaze_heatmaps,
        heatmap_pid=heatmap_pid,
        frame_nb=None,
        colors=COLORS,
        alpha=alpha,
        gaze_pt_size=gaze_pt_size,
        gaze_vec_factor=gaze_vec_factor,
        head_center_size=head_center_size,
        thickness=thickness,
        fs=fs,
    )
    axes[i].imshow(frame)
    axes[i].set_title("All predictions" if heatmap_pid is None else f"Heatmap of P{heatmap_pid}")

plt.show()