In [3]:
import torch
import numpy as np
import onnxruntime as ort
from PIL import Image
import os
from ultralytics import YOLO
import cv2
import pickle
from IPython.display import display

# Класс базы

In [4]:
class FaceBase:
    def __init__(self, db_path='./base.pkl'):
        self.encoder = FaceEncoder()
        self.detector = FaceDetector()
        
        self.path = db_path
        self.base = self._load_db() if os.path.exists(db_path) else {"name": [], "face": []}
        # print(self.base['name'])

    def find_person(self, image, treshold=0.6):
        
        face = self.extract_face(image)
        if face is None:
            return "No faces"

        query_embedding = self.get_emgedding(face)
        
        # print(query_embedding.shape, np.array(self.base['face']).shape)
        
        similarity = self._cos_sim(query_embedding)
        # print("SIM: ", similarity)
        
        if np.max(similarity) > treshold:
            max_index = np.argmax(similarity)
            # print(max_index)
            return self.base['name'][max_index]
        else:
            return "Unknown"

    def _cos_sim(self, query):
        
        return np.dot(self.base['face'], query) / (
            np.linalg.norm(self.base['face'], axis=1) * np.linalg.norm(query)
        )
    
    def extract_face(self, image):
        
        boxes = self.detector.detect(image)
        if len(boxes) == 0:
            return None

        x1, y1, x2, y2 = boxes[0].astype(int)
        return Image.fromarray(image[y1:y2, x1:x2])

    def add_image(self, name, image_path):
        embeddings = []

        for path in image_path:
            image = cv2.imread(path)
            face = self.extract_face(image)
            
            if face is not None:
                embedded = self.get_emgedding(face)
                # print(embedded)
                embeddings.append(embedded)

        # print(embeddings)

        if embeddings:
            avg_embedding = np.mean(embeddings, axis=0)
            
            if name not in self.base['name']:
                self.base['name'].append(name)
                self.base['face'].append(avg_embedding)
                print("*"*50)
                print(f"Новое лицо")
                print("*"*50)
            else:
                self.base['face'][-1] = (self.base['face'][-1] + avg_embedding)/2
                print("*"*50)
                print(f"* Добавлено *")
                print("*"*50)

        self._save()
        
    def _preprocess(self, image):
        
        image = image.resize((112, 112))  # Размер, ожидаемый моделью
        image = np.array(image).transpose(2, 0, 1)  # HWC → CHW
        image = (image / 255.0 - 0.5) / 0.5  # Нормализация [0,255] → [-1,1]
        
        return image.astype(np.float32)[None, ...]  # Добавляем batch-ось

    def get_emgedding(self, face):
        processed_face = self._preprocess(face)
        embedding = self.encoder.encode(processed_face)
        return embedding
    
    def _save(self):
        with open('base.pkl', 'wb') as f:
            pickle.dump(self.base, f)

    def _load_db(self):
        with open(self.path, 'rb') as f:
            return pickle.load(f)

class FaceEncoder:
    def __init__(self, model_name="./archface/model.onnx"):
        self.__face_encoder = ort.InferenceSession(model_name, providers=["CPUExecutionProvider"])

    def encode(self, image):
        outputs = self.__face_encoder.run(
            output_names=None,  # Все выходы модели
            input_feed={
                self.__face_encoder.get_inputs()[0].name: image
            }
        )
        embeddings = outputs[0][0]
        return embeddings

class FaceDetector:
    def __init__(self, model_name='./yolov8-face/yolov8x-face-lindevs.pt'):
        self.__face_extractor = YOLO(model_name)

    def detect(self, image):
        results = self.__face_extractor(image)
        boxes = results[0].boxes.xyxy.cpu().numpy()
        return boxes

In [None]:
db = FaceBase()

In [9]:
from datetime import datetime


datetime.now()

datetime.datetime(2025, 6, 1, 15, 9, 54, 508759)

# Обучение

In [None]:

root = './data/train/'

names = os.listdir(root)

batch_size = 20

for name in names:
    
    start = datetime.now()
    # print(f"start: {name}, {start}")

    all_frames = os.listdir(os.path.join(root, name))
    
    for batch_num in range(len(all_frames) // batch_size):
        batch = all_frames[batch_num*batch_size:(batch_num+1)*batch_size]
        print(f"SEND {batch_num} BATCH")
        db.add_image(
            name, [os.path.join(root, name, img) for img in batch]
        )
    
    # print(f'end: {name}')
   

# Тест

In [55]:
with open('./base.pkl', 'rb') as f:
        base = pickle.load(f)

In [65]:
classes = base['name']

In [None]:
db = FaceBase()

In [67]:
# classes.append("Unknown")

In [None]:
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
encoder.fit(classes)
enc_peoples = encoder.transform(classes)


In [None]:
root = './data/test/'
names = os.listdir(root)

predictions = []
real = []

for name in names:
    
    for image_name in os.listdir(os.path.join(root, name)):
        
        image = cv2.imread(os.path.join(root, name, image_name))
        # Имя
        res = db.find_person(image, treshold=.5)
        print(name, res)
        pred = encoder.transform([res])
        predictions.append(pred[0])
        real.append(encoder.transform([name])[0])

In [77]:
report = classification_report(real, predictions, target_names=encoder.classes_)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
print(report)