In [20]:
import cv2, time, numpy as np, torch
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from typing import List, Tuple
import mediapipe as mp
from ultralytics import YOLO

In [28]:
# ---------- helpers ----------
def iou_xyxy(a, b):
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter_x1, inter_y1 = max(ax1, bx1), max(ay1, by1)
    inter_x2, inter_y2 = min(ax2, bx2), min(ay2, by2)
    iw, ih = max(0, inter_x2 - inter_x1), max(0, inter_y2 - inter_y1)
    inter = iw * ih
    if inter <= 0: return 0.0
    area_a = max(0, (ax2 - ax1)) * max(0, (ay2 - ay1))
    area_b = max(0, (bx2 - bx1)) * max(0, (by2 - by1))
    denom = area_a + area_b - inter
    return 0.0 if denom <= 0 else inter / denom

def normalize_dets_dict(dets_raw):
    """
    Accepts detector output like:
      [{'bbox':[x1,y1,x2,y2], 'conf':0.8}, ...]
    Returns list of ((x1,y1,x2,y2), conf) with floats.
    """
    clean = []
    if not dets_raw:
        return clean
    for d in dets_raw:
        try:
            box = d.get('bbox', None)
            conf = d.get('conf', None)
            if box is None or conf is None or len(box) != 4:
                continue
            x1, y1, x2, y2 = map(float, box)
            conf = float(conf)
            clean.append(((x1, y1, x2, y2), conf))
        except Exception as e:
            # Skip any malformed entries
            print("⚠️ Skipped invalid det:", d, e)
            continue
    return clean

def aggregate_temporal(dets_history, iou_merge_thr=0.6):
    """
    dets_history: deque where each item is [((x1,y1,x2,y2), conf), ...]
    Returns aggregated list [(bbox, conf), ...] using IOU clustering + confidence-weighted averaging.
    """
    all_dets = []
    for dets in dets_history:
        if not dets: 
            continue
        for (box, conf) in dets:
            if box is None: 
                continue
            x1,y1,x2,y2 = map(float, box)
            all_dets.append([np.array([x1,y1,x2,y2], dtype=np.float32), float(conf)])

    if not all_dets:
        return []

    clusters = []  # each: {'sum': sum(conf*box), 'w': sum(conf), 'max_conf': max(conf)}
    for box, conf in all_dets:
        matched = -1
        best_iou = 0.0
        # Compare with cluster average (sum/w)
        for idx, c in enumerate(clusters):
            avg_box = c['sum'] / max(c['w'], 1e-6)
            i = iou_xyxy(box, avg_box)
            if i > best_iou:
                best_iou, matched = i, idx
        if best_iou >= iou_merge_thr and matched >= 0:
            clusters[matched]['sum'] += conf * box
            clusters[matched]['w'] += conf
            clusters[matched]['max_conf'] = max(clusters[matched]['max_conf'], conf)
        else:
            clusters.append({'sum': conf * box.copy(), 'w': conf, 'max_conf': conf})

    aggregated = []
    for c in clusters:
        if c['w'] <= 0: 
            continue
        avg_box = (c['sum'] / c['w']).astype(np.float32)
        aggregated.append((tuple(avg_box.tolist()), float(c['max_conf'])))
    return aggregated


In [21]:
# -----------------------------
# Utils
# -----------------------------
def iou_xyxy(a, b):
    # a, b: (x1,y1,x2,y2)
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter_x1 = max(ax1, bx1); inter_y1 = max(ay1, by1)
    inter_x2 = min(ax2, bx2); inter_y2 = min(ay2, by2)
    iw = max(0, inter_x2 - inter_x1 + 1)
    ih = max(0, inter_y2 - inter_y1 + 1)
    inter = iw * ih
    if inter == 0:
        return 0.0
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return inter / float(area_a + area_b - inter)

def cosine_distance(a: np.ndarray, b: np.ndarray):
    # a, b: (D,)
    denom = (np.linalg.norm(a) * np.linalg.norm(b)) + 1e-12
    return 1.0 - float(np.dot(a, b) / denom)  # 0 is identical, 2 is opposite

def crop_xyxy(img, box):
    x1,y1,x2,y2 = [int(v) for v in box]
    h, w = img.shape[:2]
    x1 = max(0, x1); y1 = max(0, y1)
    x2 = min(w-1, x2); y2 = min(h-1, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    return img[y1:y2, x1:x2].copy()

# -----------------------------
# Appearance encoder (ReID embedder)
# -----------------------------
class ReIDEmbedder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        # Lightweight backbone: ResNet18 global pooled features (512-D)
        # Tip: if weights fail to download in your environment, set weights=None and train/freeze later.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone.fc = nn.Identity()
        self.device = device
        self.to(device).eval()

        self.tf = T.Compose([
            T.ToPILImage(),
            T.Resize((128, 128)),
            T.ToTensor(),
            T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
        ])

    @torch.no_grad()
    def forward(self, imgs: List[np.ndarray]) -> np.ndarray:
        if len(imgs) == 0:
            return np.zeros((0, 512), dtype=np.float32)
        batch = torch.stack([self.tf(im) for im in imgs]).to(self.device)
        feats = self.backbone(batch)               # (N, 512)
        feats = nn.functional.normalize(feats, p=2, dim=1)  # L2-normalize
        return feats.detach().cpu().numpy()        # (N, 512)

# -----------------------------
# Track structure
# -----------------------------
class Track:
    def __init__(self, tid: int, bbox, conf: float, emb: np.ndarray, frame_idx: int):
        self.id = tid
        self.bbox = np.array(bbox, dtype=np.float32)  # (x1,y1,x2,y2)
        self.conf = float(conf)
        self.emb = emb.astype(np.float32)  # (D,)
        self.last_seen = frame_idx
        self.age = 0       # frames since creation
        self.time_since_update = 0
        self.state = "active"  # "active" | "lost"

    def update(self, bbox, conf, emb, frame_idx, emb_momentum=0.9):
        self.bbox = np.array(bbox, dtype=np.float32)
        self.conf = float(conf)
        # EMA on embedding
        self.emb = emb_momentum * self.emb + (1.0 - emb_momentum) * emb
        self.emb = self.emb / (np.linalg.norm(self.emb) + 1e-12)
        self.last_seen = frame_idx
        self.time_since_update = 0
        self.state = "active"
        self.age += 1

    def mark_lost(self):
        self.state = "lost"

# -----------------------------
# IOU + ReID Tracker
# -----------------------------
class IOUReIDTracker:
    def __init__(
        self,
        reid_encoder: ReIDEmbedder,
        iou_thresh: float = 0.3,
        reid_thresh: float = 0.35,     # cosine distance threshold (0 = identical)
        max_age_lost: int = 50,        # keep lost tracks this many frames for re-id
        w_iou: float = 0.5,            # cost blend: cost = w_iou*(1-IOU) + (1-w_iou)*cosine_dist
        device: str = 'cpu'
    ):
        self.encoder = reid_encoder
        self.iou_thresh = iou_thresh
        self.reid_thresh = reid_thresh
        self.max_age_lost = max_age_lost
        self.w_iou = w_iou
        self.device = device

        self.tracks: List[Track] = []
        self.next_id = 1
        self.frame_idx = -1

    def _match_by_iou(self, dets: List[Tuple], iou_thr: float):
        # Greedy IOU matching
        # dets: list of (bbox, conf, emb)
        unmatched_tracks = list(range(len(self.tracks)))
        unmatched_dets = list(range(len(dets)))
        matches = []

        if len(unmatched_tracks) == 0 or len(unmatched_dets) == 0:
            return matches, unmatched_tracks, unmatched_dets

        iou_matrix = np.zeros((len(self.tracks), len(dets)), dtype=np.float32)
        for ti, tr in enumerate(self.tracks):
            if tr.state != "active":
                continue
            for di, (dbox, _, _) in enumerate(dets):
                iou_matrix[ti, di] = iou_xyxy(tr.bbox, dbox)

        while True:
            ti, di = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
            max_iou = iou_matrix[ti, di]
            if max_iou < iou_thr:
                break
            matches.append((ti, di))
            # invalidate row/col
            iou_matrix[ti, :] = -1
            iou_matrix[:, di] = -1
            if ti in unmatched_tracks: unmatched_tracks.remove(ti)
            if di in unmatched_dets: unmatched_dets.remove(di)

            if (iou_matrix > -1).sum() == 0:
                break

        return matches, unmatched_tracks, unmatched_dets

    def _match_by_reid(self, lost_track_ids: List[int], dets: List[Tuple]):
        # Lost tracks can be matched to *any* remaining det via appearance only
        # cost = cosine distance; accept if below threshold
        # We’ll do greedy min-cost.
        if len(lost_track_ids) == 0 or len(dets) == 0:
            return [], lost_track_ids, list(range(len(dets)))

        cost = np.full((len(lost_track_ids), len(dets)), 1e3, dtype=np.float32)
        for i, ti in enumerate(lost_track_ids):
            tr = self.tracks[ti]
            for j, (_, _, demb) in enumerate(dets):
                cost[i, j] = cosine_distance(tr.emb, demb)

        matches = []
        used_tracks, used_dets = set(), set()
        while True:
            idx = np.unravel_index(np.argmin(cost), cost.shape)
            i, j = idx
            best = cost[i, j]
            if best > self.reid_thresh:
                break
            ti = lost_track_ids[i]
            if ti in used_tracks or j in used_dets:
                cost[i, j] = 1e3
                continue
            matches.append((ti, j))
            used_tracks.add(ti); used_dets.add(j)
            cost[i, :] = 1e3
            cost[:, j] = 1e3
            if (cost < 1e3).sum() == 0:
                break

        unmatched_tracks = [ti for ti in lost_track_ids if ti not in used_tracks]
        unmatched_dets = [j for j in range(len(dets)) if j not in used_dets]
        return matches, unmatched_tracks, unmatched_dets

    def update(self, frame_bgr: np.ndarray, detections: List[Tuple[Tuple[int,int,int,int], float]]):
        """
        detections: list of (bbox, conf) with bbox=(x1,y1,x2,y2)
        Returns: list of active tracks
        """
        self.frame_idx += 1

        # STEP 0: extract embeddings for all detections
        det_crops, valid_idx = [], []
        for i, (box, _) in enumerate(detections):
            crop = crop_xyxy(frame_bgr, box)
            if crop is None or crop.size == 0:
                continue
            det_crops.append(crop); valid_idx.append(i)

        det_embs = self.encoder(det_crops)  # (M, 512)
        # Rebuild dets with embeddings (skip invalid crops)
        dets = []
        for k, vi in enumerate(valid_idx):
            box, conf = detections[vi]
            dets.append((np.array(box, dtype=np.float32), float(conf), det_embs[k]))

        # Mark aging & move old actives to lost if not updated for long
        for tr in self.tracks:
            tr.time_since_update += 1
            # Passive aging: we won't drop immediately; we handle pruning below.

        # STEP 1: IOU matching among ACTIVE tracks
        active_ids = [i for i, tr in enumerate(self.tracks) if tr.state == "active"]
        matches_iou, unmatched_tracks, unmatched_dets = self._match_by_iou(dets, self.iou_thresh)

        # Apply IOU matches
        used_det_ids = set()
        for ti, di in matches_iou:
            tr = self.tracks[ti]
            dbox, dconf, demb = dets[di]
            tr.update(dbox, dconf, demb, self.frame_idx)
            used_det_ids.add(di)

        # STEP 2: ReID matching — allow LOST tracks to come back
        # Collect candidate lost tracks that are within reid horizon
        lost_candidates = []
        for i, tr in enumerate(self.tracks):
            if tr.state == "lost":
                if (self.frame_idx - tr.last_seen) <= self.max_age_lost:
                    lost_candidates.append(i)

        # Build list of remaining dets (not matched yet)
        remaining_dets = [dets[i] for i in range(len(dets)) if i not in used_det_ids]
        reid_matches, unmatched_lost, remaining_unmatched_det_indices = self._match_by_reid(lost_candidates, remaining_dets)

        # Map remaining_dets indices back to original det indices
        rem_map = [i for i in range(len(dets)) if i not in used_det_ids]
        for ti, j in reid_matches:
            di = rem_map[j]
            dbox, dconf, demb = dets[di]
            tr = self.tracks[ti]
            tr.update(dbox, dconf, demb, self.frame_idx)
            used_det_ids.add(di)

        # STEP 3: Create new tracks for unmatched detections
        for di in range(len(dets)):
            if di in used_det_ids:
                continue
            dbox, dconf, demb = dets[di]
            new_tr = Track(self.next_id, dbox, dconf, demb, self.frame_idx)
            self.tracks.append(new_tr)
            self.next_id += 1

        # STEP 4: Mark unmatched ACTIVE tracks as LOST
        unmatched_active = [i for i in unmatched_tracks if self.tracks[i].state == "active"]
        for ti in unmatched_active:
            self.tracks[ti].mark_lost()

        # STEP 5: Prune very old LOST tracks
        pruned = []
        for tr in self.tracks:
            if tr.state == "lost" and (self.frame_idx - tr.last_seen) > self.max_age_lost:
                pruned.append(tr.id)
        if pruned:
            self.tracks = [tr for tr in self.tracks if not (tr.state == "lost" and (self.frame_idx - tr.last_seen) > self.max_age_lost)]

        # Return ACTIVE tracks only
        return [tr for tr in self.tracks if tr.state == "active"]

# -----------------------------
# Example usage per frame
# -----------------------------
# Initialize once


# In your video loop:
# detections = [(bbox, conf), ...]  # from your detector
# tracks = tracker.update(frame, detections)
# for tr in tracks:
#     x1,y1,x2,y2 = map(int, tr.bbox)
#     cv2.rectangle(frame, (x1,y1), (x2,y2), (0,255,0), 2)
#     # top-left: ID
#     cv2.putText(frame, f"ID {tr.id}", (x1, max(0, y1-6)),
#                 cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
#     # top-right: Conf
#     txt = f"Conf {tr.conf:.2f}"
#     (tw, th), _ = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
#     cv2.putText(frame, txt, (x2 - tw, max(0, y1-6)),
#                 cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)


In [14]:
# ---------- YOLO detector wrapper ----------
class YoloFaceDetector:
    def __init__(self, weights="yolov12n-face.pt", imgsz=640, conf=0.35, iou=0.5):
        self.model = YOLO(weights)
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def __call__(self, frame_bgr):
        results = self.model.predict(
            source=frame_bgr, imgsz=self.imgsz,
            conf=self.conf, iou=self.iou, verbose=False, device="cpu"
        )[0]
        dets = []
        if results.boxes is not None and len(results.boxes) > 0:
            xyxy = results.boxes.xyxy.cpu().numpy()
            conf = results.boxes.conf.cpu().numpy()
            for k in range(len(xyxy)):
                x1, y1, x2, y2 = xyxy[k].tolist()
                dets.append({"bbox": [x1, y1, x2, y2], "conf": float(conf[k])})
        return dets



In [15]:
# ------------------------ Utils ------------------------
def clamp_box(x1, y1, x2, y2, w, h):
    return [max(0, x1), max(0, y1), min(w-1, x2), min(h-1, y2)]

def pad_and_square(b, pad, w, h):
    x1,y1,x2,y2 = b
    cx = (x1+x2)/2; cy = (y1+y2)/2; s = max(x2-x1, y2-y1) * (1+pad*2)
    nx1 = cx - s/2; ny1 = cy - s/2; nx2 = cx + s/2; ny2 = cy + s/2
    return clamp_box(nx1, ny1, nx2, ny2, w, h)

def preprocess_face(frame_bgr, box, size=160):
    x1,y1,x2,y2 = map(int, box); crop = frame_bgr[y1:y2, x1:x2]
    crop = cv2.resize(crop, (size, size), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # ImageNet-ish normalization (adapt to your training)
    mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225])
    rgb = (rgb - mean) / std
    chw = np.transpose(rgb, (2,0,1))
    return chw

class EMA:
    def __init__(self, alpha=0.6, init=None): self.a=alpha; self.v=init
    def __call__(self, x):
        self.v = x if self.v is None else self.a*x + (1-self.a)*self.v
        return self.v

# MediaPipe face mesh
mp_face_mesh = mp.solutions.face_mesh
mesh = mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=8,
                             refine_landmarks=True, min_detection_confidence=0.5,
                             min_tracking_confidence=0.5)

def mouth_aspect_ratio(landmarks):
    """
    Use MediaPipe FaceMesh indices:
    mouth corners ~ 78 (left), 308 (right)
    upper/lower inner lip center ~ 13 (upper), 14 (lower)
    """
    p = landmarks
    A = np.linalg.norm(p[13] - p[14])       # vertical
    B = np.linalg.norm(p[78] - p[308])      # horizontal
    return float(A/(B+1e-6))



I0000 00:00:1760405981.895969 353355743 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.3), renderer: Apple M3 Pro
W0000 00:00:1760405981.899737 353372820 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1760405981.923389 353372819 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


In [22]:
# ---------- helpers ----------
def iou_xyxy(a, b):
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter_x1, inter_y1 = max(ax1, bx1), max(ay1, by1)
    inter_x2, inter_y2 = min(ax2, bx2), min(ay2, by2)
    iw, ih = max(0, inter_x2 - inter_x1), max(0, inter_y2 - inter_y1)
    inter = iw * ih
    if inter <= 0: return 0.0
    area_a = max(0, (ax2 - ax1)) * max(0, (ay2 - ay1))
    area_b = max(0, (bx2 - bx1)) * max(0, (by2 - by1))
    denom = area_a + area_b - inter
    return 0.0 if denom <= 0 else inter / denom

def aggregate_temporal(dets_history, iou_merge_thr=0.6):
    """
    dets_history: deque of lists; each list is [(bbox, conf), ...] with bbox=(x1,y1,x2,y2)
    Returns aggregated list [(bbox, conf), ...] using IOU-based clustering and confidence-weighted averaging.
    """
    all_dets = []
    for dets in dets_history:
        if not dets: 
            continue
        for (box, conf) in dets:
            if box is None: 
                continue
            x1,y1,x2,y2 = map(float, box)
            all_dets.append([np.array([x1,y1,x2,y2], dtype=np.float32), float(conf)])

    if not all_dets:
        return []

    clusters = []  # each: dict(keys: 'sum', 'w', 'max_conf')
    for box, conf in all_dets:
        matched = -1
        best_iou = 0.0
        for idx, c in enumerate(clusters):
            i = iou_xyxy(box, c['sum'] / max(c['w'], 1e-6))
            if i > best_iou:
                best_iou, matched = i, idx
        if best_iou >= iou_merge_thr and matched >= 0:
            clusters[matched]['sum'] += conf * box
            clusters[matched]['w'] += conf
            clusters[matched]['max_conf'] = max(clusters[matched]['max_conf'], conf)
        else:
            clusters.append({'sum': conf * box.copy(), 'w': conf, 'max_conf': conf})

    aggregated = []
    for c in clusters:
        if c['w'] <= 0: 
            continue
        avg_box = (c['sum'] / c['w']).astype(np.float32)
        aggregated.append((tuple(avg_box.tolist()), float(c['max_conf'])))
    return aggregated


In [51]:
def main():
    cap = cv2.VideoCapture(0)  # webcam
    win = "YOLO + IOU+ReID (Temporal-agg)"
    cv2.namedWindow(win, cv2.WINDOW_NORMAL)
    # det = YoloFaceDetector(weights="best.pt", imgsz=640, conf=0.4)
    det = YoloFaceDetector(weights="yolov12n-face.pt", imgsz=640, conf=0.4)
    # det = YoloFaceDetector(weights="face_yolo12_best.pt", imgsz=640, conf=0.4)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedder = ReIDEmbedder(device=device)
    tracker = IOUReIDTracker(embedder, iou_thresh=0.4, reid_thresh=0.35, max_age_lost=80, w_iou=0.5, device=device)

    K = 3                      # run detector every K frames
    fidx = 0
    last_dets = []             # normalized detections
    A = 3                      # aggregate the last A detection frames
    det_history = deque(maxlen=A)

    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            fidx += 1
    
            # 1) DETECT every Kth frame; otherwise reuse last normalized dets
            if fidx % K == 1:
                raw = det(frame) or []        # -> [{'bbox': [...], 'conf': ...}, ...]
                dets = normalize_dets_dict(raw)
                last_dets = dets
            else:
                dets = last_dets
    
            # 2) aggregate recent detections
            det_history.append(dets)
            agg_dets = aggregate_temporal(det_history, iou_merge_thr=0.6)  # -> [((x1,y1,x2,y2), conf), ...]
    
            # 3) TRACK
            tracks = tracker.update(frame, agg_dets)
    
            # 4) Draw: ID (top-left), Conf (top-right)
            for tr in tracks:
                x1, y1, x2, y2 = map(int, tr.bbox)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0,255,0), 2)
                
    
                y_text = max(0, y1 - 6)
                # ID left
                cv2.putText(frame, f"ID {tr.id}", (x1, y_text),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
                # Conf right
                text = f"Conf {tr.conf:.2f}"
                (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
                cv2.putText(frame, text, (x2 - tw, y_text),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)
    
            cv2.imshow(win, frame)
            k = cv2.waitKey(1) & 0xFF
            if k in (27, ord('q')): # ESC or q
                print(dets)
                break

            # also exit if user clicks the window's X button
            if cv2.getWindowProperty(win, cv2.WND_PROP_VISIBLE) < 1:
                break
    
        # cap.release()
        # cv2.destroyAllWindows()
    finally:
        cap.release()
        cv2.destroyWindow(win)                # close just this window
        # pump the event queue a few times so the OS actually closes it
        for _ in range(3):
            cv2.waitKey(1)
        # tiny sleep can help on some macOS builds
        time.sleep(0.05)

In [52]:
if __name__ == "__main__":
    main()

[((881.050048828125, 513.7398071289062, 1239.62890625, 932.495361328125), 0.8655760288238525), ((0.0, 610.517578125, 123.80026245117188, 984.6851806640625), 0.7591698169708252), ((392.1365661621094, 822.3052368164062, 461.5881042480469, 900.2393188476562), 0.7365306615829468)]
