In [13]:
from facenet_pytorch import InceptionResnetV1
from mtcnn import MTCNN
from mtcnn.utils.images import load_image, load_images_batch
from mtcnn.utils.plotting import plot
import torch
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import random
import chromadb
import pandas as pd
import queue
import threading

In [14]:
class App:
    class NoFaceDetectedError(Exception):
        """Raised when no face is detected in an image."""
        def __init__(self, message="No face detected in the provided image."):
            super().__init__(message)

    class VideoStream:
        """Threaded camera capture to avoid blocking"""
        def __init__(self, src=0):
            self.cap = cv2.VideoCapture(src)
            if not self.cap.isOpened():
                raise RuntimeError("Failed to open camera or video file.")
            self.q = queue.Queue(maxsize=1)
            self.running = True
            self.thread = threading.Thread(target=self.update, daemon=True)
            self.thread.start()

        def update(self):
            while self.running:
                ret, frame = self.cap.read()
                if not ret:
                    self.running = False
                    break
                if not self.q.full():
                    self.q.put(frame)

        def read(self):
            if not self.q.empty():
                return self.q.get()
            return None

        def release(self):
            self.running = False
            self.thread.join()
            self.cap.release()

    def __init__(self):
        self.KNOWN_FACES_PATH = "../known_faces"
        self.UNKNOWN_FACES_PATH = "../unknown_faces"
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.mtcnn = MTCNN(device="GPU:0")
        self.facenet = InceptionResnetV1(pretrained='vggface2').eval()
        self.chroma_client = chromadb.PersistentClient("../chromadb")
        self.collection = self.chroma_client.get_or_create_collection(
            name="faces",
            embedding_function=None,
            configuration={
                "hnsw": {
                    "space": "l2"
                }
            }
        )
        self.process_known_images()

    def select_random_known_img_path(self, KNOWN_FACES_PATH: str, k: int):
        res = []
        for i in range(k):
            folder_name = random.choice(os.listdir(KNOWN_FACES_PATH))
            folder_path = os.path.join(KNOWN_FACES_PATH, folder_name)
            file_name = random.choice(os.listdir(folder_path))
            res.append(os.path.join(folder_path, file_name))
        return res

    def detect_significant_face(self, img):
        faces = self.mtcnn.detect_faces(img, box_format="xywh")
        if not faces:
            raise App.NoFaceDetectedError
        significant_face = max(faces, key=lambda x: x['confidence'])
        return significant_face

    def crop_img(self, img, x, y, w, h):
        x1, y1 = x, y
        x2, y2 = x + w, y + h

        cropped_img = img[y1:y2, x1:x2]
        return cropped_img

    def img_to_tensor(self, img):
        resized_img = cv2.resize(img, (160, 160))
        tensor = torch.tensor(resized_img).permute(2, 0, 1).float()  # CxHxW
        return tensor.unsqueeze(0).to(self.device)

    def normalize_tensor(self, tensor):
        return (tensor - 127.5) / 128.0

    def embed_image(self, img):
        return self.facenet(img).detach().cpu()

    def process_known_images(self):
        for name in tqdm(os.listdir(self.KNOWN_FACES_PATH)):
            for image_name in os.listdir(os.path.join(self.KNOWN_FACES_PATH, name)):
                img_path = os.path.join(self.KNOWN_FACES_PATH, name, image_name)
                existing_record = self.collection.get(where={"img_path": img_path})
                if existing_record["ids"]:
                    continue
                try:
                    img = load_image(img_path)
                    detection_result = self.detect_significant_face(img)
                    cropped_img = self.crop_img(img, *detection_result['box'])
                    face_tensor = self.img_to_tensor(cropped_img)
                    normalized_tensor = self.normalize_tensor(face_tensor)
                    face_embedding = self.embed_image(normalized_tensor)
                    self.collection.add(
                        ids=image_name,
                        embeddings=face_embedding.tolist(),
                        metadatas={
                            "name": name,
                            "img_path": img_path
                        }
                    )
                except Exception as e:
                    print(f"Can't process file {img_path}\nError: {e}")
                    continue

    def process_unknown_images(self, results_path, tolerance: float = 0.5):
        data = []
        for img_name in tqdm(os.listdir(self.UNKNOWN_FACES_PATH)):
            img_path = os.path.join(self.UNKNOWN_FACES_PATH, img_name)
            try:
                results = self.search_image(img_path, 1)
                predicted_name = "UNKNOWN"
                predicted_image_path = None
                distance = None
                if results['distances'][0][0] <= tolerance:
                    predicted_name = results['metadatas'][0][0]['name']
                    predicted_image_path = results['metadatas'][0][0]['img_path']
                    distance = results['distances'][0][0]

                actual_name = "_".join(img_name.split('_')[:-1])
                is_correct = actual_name == predicted_name
                row = {
                    "actual_name": actual_name,
                    "predicted_name": predicted_name,
                    "is_correct": is_correct,
                    "predicted_image_path": predicted_image_path,
                    "known_image_path": img_path,
                    "distance": distance
                }
                data.append(row)

            except App.NoFaceDetectedError:
                print(f"No face detected in image: {img_path}")
                continue
        pd.DataFrame(data).to_csv(results_path, index=False)
    
    def search_image(self, path, n_results):
        img = load_image(path)
        return self.search_frame(img, n_results)

    def search_frame(self, frame, n_results):
        detection_result = self.detect_significant_face(frame)
        cropped_img = self.crop_img(frame, *detection_result['box'])
        face_tensor = self.img_to_tensor(cropped_img)
        normalized_tensor = self.normalize_tensor(face_tensor)
        face_embedding = self.embed_image(normalized_tensor)
        results = self.collection.query(
            query_embeddings=face_embedding.tolist(),
            n_results=n_results,
            include=["embeddings", "metadatas", "distances", "documents"]
        )
        return results
    
    def video_recognize(self, tolerance: float = 0.5, frame_step: int = 5):
        """
        Recognize faces from video stream, processing every `frame_step`-th frame.

        Parameters:
            tolerance (float): Distance threshold for a match.
            frame_step (int): Process every n-th frame. Example: frame_step=5 -> process every 5th frame.
        """
        # Open default camera (0). You can replace 0 with a video file path.
        stream = self.VideoStream(0)
        print("Press 'q' or 'Esc' to quit.")

        # Initialize timer for FPS calculation
        prev_time = time.time()
        fps = 0.0

        # Frame counter
        frame_count = 0

        # Cache last detection
        last_detection = None
        last_name = "UNKNOWN"
        last_matched_image = None

        while stream.running:
            frame = stream.read()
            if frame is None:
                continue

            frame_count += 1
            # ✅ Process only every n-th frame
            if frame_count % frame_step == 0:
                try:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    detection_result = self.detect_significant_face(frame_rgb)
                    x, y, w, h = detection_result['box']

                    cropped_img = self.crop_img(frame_rgb, x, y, w, h)
                    face_tensor = self.img_to_tensor(cropped_img)
                    normalized_tensor = self.normalize_tensor(face_tensor)
                    face_embedding = self.embed_image(normalized_tensor)

                    results = self.collection.query(
                        query_embeddings=face_embedding.tolist(),
                        n_results=1,
                        include=["embeddings", "metadatas", "distances", "documents"]
                    )

                    predicted_name = "UNKNOWN"
                    matched_image = None
                    if results['distances'][0][0] <= tolerance:
                        predicted_name = results['metadatas'][0][0]['name']
                        matched_image_path = results['metadatas'][0][0]['img_path']
                        matched_image = cv2.imread(matched_image_path)
                        
                        if matched_image is not None:
                            frame_h = frame.shape[0]
                            aspect_ratio = matched_image.shape[1] / matched_image.shape[0]
                            new_w = int(frame_h * aspect_ratio)
                            matched_image = cv2.resize(matched_image, (new_w, frame_h))

                    # ✅ Cache the detection
                    last_detection = (x, y, w, h)
                    last_name = predicted_name
                    last_matched_image = matched_image

                except self.NoFaceDetectedError:
                    pass

            # ✅ Always redraw from cache
            if last_detection:
                x, y, w, h = last_detection
                top_left, bottom_right = (x, y), (x+w, y+h)
                cv2.rectangle(frame, top_left, bottom_right, (0, 0, 255), 2)

                # Label
                label_y = y + h + 20 if (y + h + 20) < frame.shape[0] else y - 10
                cv2.putText(frame, last_name, (x, label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

                # Matched image (if available)
                if last_matched_image is not None:
                    frame = cv2.hconcat([frame, last_matched_image])

            # Calculate FPS
            curr_time = time.time()
            elapsed = curr_time - prev_time
            if elapsed > 0:
                fps = 0.9 * fps + 0.1 * (1 / elapsed)
            prev_time = curr_time

            # Display FPS on frame
            cv2.putText(frame, f"FPS: {fps:.2f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Show the frame
            cv2.imshow("Video", frame)

            # Wait for 1 ms and check for key press
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q') or key == 27:  # 27 is Esc
                print("Exiting...")
                break

            # Cap the FPS to prevent CPU from running 100%
            time.sleep(0.01)

        # Release resources
        stream.release()
        cv2.destroyAllWindows()

In [15]:
app = App()

 64%|██████▍   | 1084/1681 [00:03<00:02, 199.06it/s]

Can't process file ../known_faces\Marilyn_Monroe\Marilyn_Monroe_0001.jpg
Error: No face detected in the provided image.


100%|██████████| 1681/1681 [00:06<00:00, 274.10it/s]


In [16]:
app.search_image("../unknown_faces/Abdoulaye_Wade_0004.jpg", 1)

{'ids': [['Abdoulaye_Wade_0001.jpg']],
 'embeddings': [array([[-0.01500503, -0.00814223, -0.08948263,  0.08022343, -0.02133966,
           0.04164231,  0.02203878, -0.0089441 , -0.08365081, -0.04914057,
          -0.00914059, -0.03725747, -0.02519825, -0.1073973 , -0.04476332,
          -0.06416232,  0.02533201,  0.02871547,  0.01664347, -0.09567705,
          -0.03124513,  0.03051565, -0.04863874,  0.08989248, -0.00244354,
           0.05888009,  0.08617198,  0.02023518,  0.05148148,  0.07205155,
           0.0379464 , -0.01517769, -0.00086331, -0.0138928 , -0.05126122,
           0.03332475, -0.03779221,  0.02241295,  0.03552938,  0.01191207,
          -0.02577062, -0.05721798,  0.05069862, -0.08422723,  0.03962384,
          -0.00718057, -0.09832343, -0.07272041,  0.08498655,  0.01572044,
           0.01590642,  0.0754925 ,  0.04452571,  0.05517029, -0.01918009,
          -0.01369224,  0.05149269, -0.00765376,  0.06864413,  0.00968034,
           0.05867082, -0.01660311,  0.06783019

In [17]:
#for index, tol in enumerate(np.arange(0.5, 1.0, 0.05), start=1):
#    tol = np.round(tol, 2)
#    app.process_unknown_images(f"../analysis/results{str(index)}.csv", tol)

In [None]:
app.video_recognize(tolerance=0.55, frame_step=2)

Press 'q' or 'Esc' to quit.
Exiting...
