In [1]:
import glob
import time
import torch
import cv2
from PIL import Image
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import tqdm

# See github.com/timesler/facenet-pytorch:
from facenet_pytorch import InceptionResnetV1, extract_face
from facenet_pytorch import MTCNN

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {device}')

Running on device: cpu


In [2]:
# Load face detector
mtcnn = MTCNN(margin=14, keep_all=True, factor=0.5, device=device).eval()

In [3]:
# Load facial recognition model
resnet = InceptionResnetV1(pretrained='vggface2', device=device).eval()

Downloading parameters (1/2)
Downloading parameters (2/2)


In [15]:
class DetectionPipeline:
    """Pipeline class for detecting faces in the frames of a video file."""
    
    def __init__(self, detector, n_frames=None, batch_size=10, resize=None):
        """Constructor for DetectionPipeline class.
        
        Keyword Arguments:
            n_frames {int} -- Total number of frames to load. These will be evenly spaced
                throughout the video. If not specified (i.e., None), all frames will be loaded.
                (default: {None})
            batch_size {int} -- Batch size to use with MTCNN face detector. (default: {32})
            resize {float} -- Fraction by which to resize frames from original prior to face
                detection. A value less than 1 results in downsampling and a value greater than
                1 result in upsampling. (default: {None})
        """
        self.detector = detector
        self.n_frames = n_frames
        self.batch_size = batch_size
        self.resize = resize
    
    def __call__(self, filename):
        """Load frames from an MP4 video and detect faces.

        Arguments:
            filename {str} -- Path to video.
        """
        # Create video reader and find length
        v_cap = cv2.VideoCapture(filename)
        v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Pick 'n_frames' evenly spaced frames to sample
        if self.n_frames is None:
            sample = np.arange(0, v_len)
        else:
            sample = np.linspace(0, v_len - 1, self.n_frames).astype(int)

        # Loop through frames
        faces = []
        frames = []
        for j in range(v_len):
            success = v_cap.grab()
            if j in sample:
                # Load frame
                success, frame = v_cap.retrieve()
                if not success:
                    continue
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                
                # Resize frame to desired size
                if self.resize is not None:
                    frame = frame.resize([int(d * self.resize) for d in frame.size])
                frames.append(frame)

                # When batch is full, detect faces and reset frame list
                if len(frames) % self.batch_size == 0 or j == sample[-1]:
                    faces.extend(self.detector(frames))
                    frames = []

        v_cap.release()

        return faces    


def process_faces(faces, resnet):
    # Filter out frames without faces
    faces = [f for f in faces if f is not None]
    faces = torch.cat(faces).to(device)

    # Generate facial feature vectors using a pretrained model
    embeddings = resnet(faces)

    # Calculate centroid for video and distance of each face's feature vector from centroid
    centroid = embeddings.mean(dim=0)
    x = (embeddings - centroid).norm(dim=1).cpu().numpy()
    
    return x

In [33]:
# Define face detection pipeline
detection_pipeline = DetectionPipeline(detector=mtcnn, n_frames=30, batch_size=10, resize=0.25)

# Get all test videos
filenames = glob.glob('../data/aassnaulhq.mp4')

X = []
start = time.time()
n_processed = 0
with torch.no_grad():
    for i, filename in enumerate(filenames):
        try:
            # Load frames and find faces
            faces = detection_pipeline(filename)
            
            # Calculate embeddings
            X.append(process_faces(faces, resnet))

        except KeyboardInterrupt:
            print('\nStopped.')
            break

        except Exception as e:
            print(e)
            X.append(None)
        
        n_processed += len(faces)
        print(f'Frames per second (load+detect+embed): {n_processed / (time.time() - start):6.3}\r', end='')



In [65]:
a=resnet(torch.cat(faces))

In [68]:
torch.cat(faces).shape

torch.Size([31, 3, 160, 160])

In [66]:
a

tensor([[-0.0024,  0.0218, -0.0622,  ..., -0.0197, -0.0586, -0.0048],
        [ 0.0265,  0.0207, -0.0502,  ..., -0.0311, -0.0177, -0.0092],
        [ 0.0323,  0.0007, -0.0613,  ..., -0.0248, -0.0363, -0.0045],
        ...,
        [ 0.0036,  0.0707, -0.0460,  ...,  0.0056, -0.0603,  0.0092],
        [ 0.0360,  0.0182, -0.0438,  ..., -0.0305,  0.0057,  0.0370],
        [ 0.0186, -0.0229, -0.0828,  ...,  0.0175, -0.0572, -0.0325]],
       grad_fn=<DivBackward0>)

In [67]:
a.shape

torch.Size([31, 512])

In [71]:
a.mean(dim=0).shape

torch.Size([512])

In [72]:
(a-a.mean(dim=0)).shape

torch.Size([31, 512])

In [73]:
(a-a.mean(dim=0)).norm(dim=1).shape

torch.Size([31])

In [69]:
v_cap = cv2.VideoCapture(filename)
v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
success = v_cap.grab()
print('grab',success)
success, frame = v_cap.retrieve()
print('retrieve',success)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
frame = frame.resize([int(d * 0.5) for d in frame.size])
face = mtcnn(frame)
print(face)


grab True
retrieve True
tensor([[[[-0.3477, -0.3477, -0.3398,  ..., -0.4023, -0.4023, -0.4023],
          [-0.3477, -0.3477, -0.3477,  ..., -0.4023, -0.4023, -0.4023],
          [-0.3555, -0.3555, -0.3555,  ..., -0.4102, -0.4102, -0.4102],
          ...,
          [-0.7695, -0.7695, -0.7695,  ..., -0.8477, -0.8711, -0.8789],
          [-0.7695, -0.7695, -0.7695,  ..., -0.8555, -0.8867, -0.8945],
          [-0.7695, -0.7695, -0.7695,  ..., -0.8633, -0.8945, -0.9023]],

         [[-0.4336, -0.4336, -0.4414,  ..., -0.4883, -0.4883, -0.4883],
          [-0.4336, -0.4336, -0.4414,  ..., -0.4883, -0.4883, -0.4883],
          [-0.4258, -0.4258, -0.4336,  ..., -0.4805, -0.4805, -0.4805],
          ...,
          [-0.8711, -0.8711, -0.8711,  ..., -0.9336, -0.9570, -0.9648],
          [-0.8711, -0.8711, -0.8711,  ..., -0.9492, -0.9648, -0.9727],
          [-0.8711, -0.8711, -0.8711,  ..., -0.9570, -0.9727, -0.9805]],

         [[-0.4883, -0.4883, -0.4961,  ..., -0.5273, -0.5273, -0.5273],
      