In [1]:
import torch
import cv2
import time
import os
import argparse
import datetime
import numpy as np
import tkinter as tk
import torchvision.transforms as transforms

from torchvision import datasets
from torch.utils.data import DataLoader
from PIL import Image, ImageTk
from facenet_pytorch import MTCNN, InceptionResnetV1

In [2]:
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Running on device: {}'.format(device))

Running on device: cuda:0


In [4]:
# Rebind `device` to the CPU. When the notebook is run top to bottom this
# replaces the value chosen by the CUDA auto-detection cell, so every model
# created afterwards is placed on the CPU.
# NOTE(review): the execution counters around this cell are out of order, so it
# is unclear whether this override was active when the models were built —
# confirm the intended device and delete this cell if GPU use is wanted.
device = torch.device('cpu')

In [3]:
# Detector used when enrolling the reference photos (keep_all=False).
mtcnn0 = MTCNN(
    image_size=240,
    margin=0,
    keep_all=False,
    min_face_size=40,
    device=device,
)

# Detector used on webcam frames (keep_all=True); its results are iterated
# face by face in the recognition loop.
mtcnn = MTCNN(
    image_size=240,
    margin=0,
    keep_all=True,
    min_face_size=40,
    device=device,
)

# Embedding network with the 'vggface2' pretrained weights, moved to the
# selected device and switched to evaluation mode.
resnet = InceptionResnetV1(pretrained='vggface2').to(device).eval()

In [4]:
# Read data from folder

def collate_fn(x):
    """Return the first element of the batch ``x`` without any stacking.

    Passed to ``DataLoader`` so that each iteration yields the sample itself
    instead of a collated batch.
    """
    first_sample = x[0]
    return first_sample

dataset = datasets.ImageFolder('photos') # photos folder path
# Invert class_to_idx so a numeric label can be mapped back to its folder name.
idx_to_class = {i: c for c, i in dataset.class_to_idx.items()}
loader = DataLoader(dataset, collate_fn=collate_fn)

name_list = []       # folder name for every accepted face, parallel to embedding_list
embedding_list = []  # one embedding tensor per accepted face, as returned by resnet

# Inference only: disabling autograd avoids building a computation graph for
# every photo, which would otherwise be allocated and immediately thrown away.
with torch.no_grad():
    for img, idx in loader:
        face, prob = mtcnn0(img, return_prob=True)
        # Skip photos where no face was found or the detection is not confident enough.
        if face is not None and prob > 0.9:
            emb = resnet(face.unsqueeze(0).to(device))
            embedding_list.append(emb.cpu().detach())
            name_list.append(idx_to_class[idx])

# save data
data = [embedding_list, name_list]
torch.save(data, 'data.pt') # saving data.pt file

In [6]:
# Recognise faces from the webcam.
# Load the enrolled embeddings and their names back from data.pt.
load_data = torch.load('data.pt')
embedding_list, name_list = load_data[0], load_data[1]

# A detection is only processed when its probability is above this value.
PROBABILITY_THRESHOLD = 0.9
# A face is labelled with a name only when its smallest distance to an
# enrolled embedding is below this value; otherwise it is shown as unknown.
MINDIST_THRESHOLD = 1.1

class Application:
    """Tkinter window that shows the webcam feed with detected faces annotated.

    Every frame is run through the module-level ``mtcnn`` detector and
    ``resnet`` embedder; each confident detection is compared against
    ``embedding_list`` / ``name_list`` and labelled with the closest name or
    ``unknown``. Buttons allow saving a snapshot and recording the annotated
    video.
    """

    def __init__(self):
        """Open camera 0, build the window and start the frame loop."""
        self.vs = cv2.VideoCapture(0)
        self.screenshot_path = "screenshots"                     # directory snapshots are saved in
        self.video_path = "videos"                               # directory recordings are saved in
        self.current_image = None                                # last displayed frame as a PIL image

        self.width = int(self.vs.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.vs.get(cv2.CAP_PROP_FRAME_HEIGHT))

        self.writer = None                                       # cv2.VideoWriter while recording, else None
        self.record_status = False

        self.root = tk.Tk()                                      # initialize root window
        self.root.title("Face Recognition")
        self.root.protocol('WM_DELETE_WINDOW', self.destructor)  # run self.destructor when the window is closed

        self.panel = tk.Label(self.root)                         # initialize image panel
        self.panel.pack(padx=10, pady=10)

        btn_screenshot = tk.Button(self.root, text="Take picture!", command=self.take_snapshot)
        btn_screenshot.pack(fill="both", expand=True, padx=10, pady=10)

        btn_start_record = tk.Button(self.root, text="Record", command=self.start_record)
        btn_start_record.pack(fill="both", expand=True, padx=10, pady=10)

        btn_stop_record = tk.Button(self.root, text="Stop Recording", command=self.stop_record)
        btn_stop_record.pack(fill="both", expand=True, padx=10, pady=10)

        self.video_loop()

    def video_loop(self):
        """Grab one frame, annotate it, display it, then reschedule itself.

        A failed read is skipped (nothing is detected, recorded or displayed)
        and the loop simply tries again on the next tick.
        """
        ready, frame = self.vs.read()                            # read frame from video stream

        # Only touch the frame once the read is known to have succeeded;
        # on failure `frame` is not a usable image.
        if ready and frame is not None:
            # OpenCV delivers BGR, while the enrolled photos were loaded as RGB,
            # so convert before detection/embedding to compare like with like.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(rgb_frame)
            img_cropped_list, prob_list = mtcnn(img, return_prob=True)

            if img_cropped_list is not None:
                boxes, _ = mtcnn.detect(img)
                if boxes is not None:
                    self._annotate_faces(frame, img_cropped_list, prob_list, boxes)

            # The writer receives the annotated BGR frame, as OpenCV expects.
            if self.record_status and self.writer is not None:
                self.writer.write(frame)

            cv2image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.current_image = Image.fromarray(cv2image)
            imgtk = ImageTk.PhotoImage(image=self.current_image)
            self.panel.imgtk = imgtk                             # keep a reference so the image is not garbage-collected
            self.panel.config(image=imgtk)                       # show the image
        self.root.after(30, self.video_loop)                     # call the same function after 30 milliseconds

    def _annotate_faces(self, frame, faces, probs, boxes):
        """Draw a label and a bounding box on ``frame`` for every confident detection."""
        for i, prob in enumerate(probs):
            if prob is None or not prob > PROBABILITY_THRESHOLD or i >= len(boxes):
                continue

            # Inference only: no autograd graph is needed for a per-frame embedding.
            with torch.no_grad():
                emb = resnet(faces[i].unsqueeze(0).to(device)).cpu().detach()

            label, colour = self._identify(emb)

            # Plain Python ints: OpenCV drawing functions expect them for point coordinates.
            x1, y1, x2, y2 = [int(v) for v in boxes[i][:4]]
            cv2.putText(frame, label, (x1, y1), cv2.FONT_HERSHEY_TRIPLEX, 1, colour, 1, cv2.LINE_AA)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)

    def _identify(self, emb):
        """Return ``(label, BGR colour)`` for an embedding based on its nearest enrolled face."""
        unknown = ('unknown', (0, 0, 255))
        if len(embedding_list) == 0:
            # Nothing enrolled: there is no nearest neighbour to compare against.
            return unknown

        # Euclidean distance to every enrolled embedding; the smallest one decides the identity.
        dist_list = [torch.dist(emb, emb_db, 2).item() for emb_db in embedding_list]
        exp_dists = np.exp(dist_list)
        softmax_list = exp_dists / np.sum(exp_dists, axis=0)
        min_dist = min(dist_list)
        min_dist_idx = dist_list.index(min_dist)                 # index of the closest enrolled face
        # Displayed score: one minus the softmax weight of the closest distance.
        confidence_score = 1 - softmax_list[min_dist_idx]

        if min_dist < MINDIST_THRESHOLD:
            name = name_list[min_dist_idx]                       # name corresponding to the minimum distance
            return name + ' ' + str("%.3f" % confidence_score), (0, 255, 0)
        return unknown

    def take_snapshot(self):
        """Save the most recently displayed frame as a timestamped JPEG."""
        if self.current_image is None:
            # No frame has been displayed yet, so there is nothing to save.
            print("[INFO] no frame available, snapshot skipped")
            return
        os.makedirs(self.screenshot_path, exist_ok=True)
        ts = datetime.datetime.now()
        filename = "{}.jpg".format(ts.strftime("%Y-%m-%d_%H-%M-%S"))
        p = os.path.join(self.screenshot_path, filename)
        self.current_image.save(p, "JPEG")
        print("[INFO] saved {}".format(filename))

    def start_record(self):
        """Start writing annotated frames to a new timestamped video file."""
        if self.writer is not None:
            # Already recording: finalise the current file before opening a new one.
            self.writer.release()
            self.writer = None
        os.makedirs(self.video_path, exist_ok=True)
        ts = datetime.datetime.now()
        video_name = "{}.mp4".format(ts.strftime("%Y-%m-%d_%H-%M-%S"))
        path = os.path.join(self.video_path, video_name)
        self.writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'DIVX'), 20, (self.width, self.height))
        self.record_status = True
        print("[INFO] start recording {}".format(video_name))

    def stop_record(self):
        """Finalise the current recording; does nothing when not recording."""
        if self.writer is None:
            self.record_status = False
            print("[INFO] not recording")
            return
        self.writer.release()
        self.writer = None
        self.record_status = False
        print("[INFO] stop recording")

    def destructor(self):
        """Release the recorder and camera, then close the window."""
        print("[INFO] closing...")
        if self.writer is not None:
            # Finalise a recording that is still in progress.
            self.writer.release()
            self.writer = None
            self.record_status = False
        self.vs.release()
        self.root.destroy()
        cv2.destroyAllWindows()

# Start the app: constructing Application opens the camera, builds the window
# and schedules the first frame; mainloop() then runs the Tk event loop.
print("[INFO] starting...")
pba = Application()
pba.root.mainloop()

[INFO] starting...
[INFO] saved 2021-02-01_23-24-34.jpg
[INFO] start recording 2021-02-01_23-25-25.mp4
[INFO] stop recording
[INFO] closing...
