In [1]:
from pathlib import Path
from facenet_pytorch import MTCNN, InceptionResnetV1, extract_face
import torch
import torchvision.transforms as T
import cv2
from PIL import Image

import os

In [19]:
# Folder with one reference image per known person; each file name (minus its
# extension) becomes the label drawn on screen.
people_path = 'data/faces'

# Minimum cosine similarity required to report a match; a best score at or
# below this value is shown as 'unknown'.
sim_threshold = 0.4

# --- OpenCV drawing settings ---
font = cv2.FONT_HERSHEY_SIMPLEX
BRG_color = (0, 0, 255)   # OpenCV colours are in B, G, R order -> pure red
thickness = 2             # box line width; also passed as the text fontScale
text_y_offset = -10       # label y-position relative to the box's top edge (negative = above it)


In [20]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {device}')

Running on device: cuda:0


In [21]:
# Initialize networks.
# MTCNN locates face boxes; keep_all=True returns every face in the frame.
# NOTE(review): the detection loop only calls mtcnn.detect(), so post_process
# (which presumably only affects face tensors returned by mtcnn(...) itself)
# has no effect on the boxes used here.
mtcnn = MTCNN(keep_all=True, post_process=True, device=device)
# InceptionResnetV1 (VGGFace2 weights) maps a face crop to an embedding vector.
# It is used purely for inference: eval() fixes dropout / batch-norm behaviour,
# and requires_grad_(False) stops autograd from recording a graph on every
# forward pass (otherwise each stored reference embedding keeps its graph alive).
resnet = InceptionResnetV1(pretrained='vggface2').eval().requires_grad_(False).to(device)


In [22]:
# helper functions
# Cosine similarity computed along dim=1 (the PyTorch default), so one live
# embedding can be scored against every stored reference embedding at once.
cos_sim = torch.nn.CosineSimilarity(dim=1)
# Turns a PIL image into a float (C, H, W) tensor; 8-bit images are scaled to [0, 1].
convert_tensor = T.ToTensor()

In [23]:
# Calculate one reference embedding per image in `people_path`.
# names[i] (the file name without its extension) labels sample_embedings[i].

names = []
sample_embedings = []

# sorted() makes the name <-> embedding order deterministic across runs;
# os.listdir() order is arbitrary.
for name in sorted(os.listdir(people_path)):
    img_path = os.path.join(people_path, name)
    # Skip sub-directories and hidden files (e.g. .DS_Store), which are not
    # face images and would make Image.open fail.
    if name.startswith('.') or not os.path.isfile(img_path):
        continue
    # The context manager closes the file handle. convert('RGB') drops an alpha
    # channel / expands grayscale or palette images so the tensor always has
    # 3 channels, which the network requires.
    # NOTE(review): ToTensor() scales pixels to [0, 1]; the live-frame path
    # divides extract_face() output by 255 to match — keep the two in sync.
    with Image.open(img_path) as img:
        img_tensor = convert_tensor(img.convert('RGB'))
    # Inference only: don't record an autograd graph for each reference image.
    with torch.no_grad():
        sample_embedings.append(resnet(img_tensor.unsqueeze(0).to(device)))
    names.append(Path(name).stem)

In [24]:
# Merge the per-image embeddings into one tensor (one row per reference image,
# assuming each embedding carries a batch dimension of 1) so a live face can be
# compared against all references in a single cos_sim call.
# The isinstance guard keeps this cell safe to re-run: after the first run the
# name already holds the stacked tensor, and vstack would reject it.
if isinstance(sample_embedings, list):
    if not sample_embedings:
        raise RuntimeError(f'No reference face images found in {people_path!r}')
    sample_embedings = torch.vstack(sample_embedings)

In [37]:
# Capture webcam and run live recognition until the 's' key is pressed.
# Each detected face gets a box and a label: the closest reference name, or
# 'unknown' when no reference is similar enough.


def _label_for_face(frame_rgb, box):
    """Return the label for the face inside ``box`` of an RGB frame.

    The crop is embedded with ``resnet`` and compared against every row of
    ``sample_embedings``. The best-matching entry of ``names`` is returned when
    its cosine similarity is strictly above ``sim_threshold``; otherwise
    ``'unknown'``.
    """
    # Dividing by 255 is meant to mirror the [0, 1] scaling ToTensor() applies
    # to the reference images; both sides must share the same preprocessing for
    # the similarity to be meaningful.
    face = extract_face(frame_rgb, box) / 255.
    with torch.no_grad():  # inference only; don't build an autograd graph per frame
        emb = resnet(face.unsqueeze(0).to(device))
        sims = cos_sim(emb, sample_embedings)
    best_idx = int(torch.argmax(sims))
    if sims[best_idx] > sim_threshold:
        return str(names[best_idx])
    return 'unknown'


capture = cv2.VideoCapture(0)
try:
    if not capture.isOpened():
        raise RuntimeError('Could not open webcam (device index 0)')

    while capture.isOpened():
        ret, frame = capture.read()
        if not ret:
            # read() failed (camera unplugged / stream ended) and frame is None.
            break

        # OpenCV delivers BGR; convert to RGB for MTCNN and extract_face.
        frame_RGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        boxes, _ = mtcnn.detect(frame_RGB)

        # detect() gives None when the frame contains no face.
        if boxes is not None:
            for box in boxes:
                x1, y1, x2, y2 = (int(v) for v in box[:4])
                cv2.rectangle(frame, (x1, y1), (x2, y2), color=BRG_color, thickness=thickness)
                # NOTE(review): fontScale is driven by `thickness`; a dedicated
                # font-scale setting would be clearer.
                cv2.putText(frame, text=_label_for_face(frame_RGB, box),
                            org=(x1, y1 + text_y_offset), fontFace=font,
                            fontScale=thickness, color=BRG_color)

        cv2.imshow('webCam', frame)

        # Press the 's' key to stop. Mask to the low byte because waitKey can
        # return extra modifier bits on some platforms.
        if (cv2.waitKey(1) & 0xFF) == ord('s'):
            break
finally:
    # Always free the camera and close the window, even when the loop is
    # interrupted (e.g. the notebook's stop button) or raises.
    capture.release()
    cv2.destroyAllWindows()