# Notebook to parallelize frame extraction

In [None]:
import os
import json
import subprocess
import numpy as np
from ipywidgets import interact
from collections import defaultdict

import face_recognition
import cv2

from PIL import Image


import pretorched.visualizers as vutils
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_style("dark")

# Plotting
%matplotlib inline
plt.rcParams['font.size'] = 18.0
plt.rcParams['figure.figsize'] = (24.0, 16.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
# For MB Pro retina display
%config InlineBackend.figure_format = 'retina'

# For auto-reloading external modules
%load_ext autoreload
%autoreload 2

from data import utils

In [None]:
root = '/data/datasets/DeepfakeDetection/'
frames_root = os.path.join(root, 'frames')
videos_root = os.path.join(root, 'videos')

# Directory of pre-extracted frames for the sample clip; swap in an alternative as needed.
frame_dir = os.path.join(frames_root, 'dfdc_train_part_0', 'aaqaifqrwn.mp4')
# frame_dir = os.path.join(frames_root, 'dfdc_train_part_0', 'gnlvnyrpfq.mp4')
# frame_dir = os.path.join(frames_root, 'dfdc_train_part_2', 'cmliuimutv.mp4')
test_image = os.path.join(frame_dir, '000200.jpg')

# Source video used by the frame-reading / tracking cells below.
# test_video = os.path.join(videos_root, 'dfdc_train_part_0', 'aaqaifqrwn.mp4')
test_video = os.path.join(videos_root, 'dfdc_train_part_2', 'aljjmeqszq.mp4')

In [None]:
# Crop the face out of the single test frame with the project helper and display it.
face_image = utils.extract_face_frame(test_image)
vutils.imshow(face_image)

In [None]:
# Run the project's multi-face extractor over the whole test video.
# NOTE(review): `faces` is reassigned by the frame-reading and tracking cells below,
# so this result is overwritten once those cells run.
faces = utils.extract_multi_faces(test_video)

In [None]:
video = test_video
v_margin = 100   # vertical padding (pixels) added around each detected face box
h_margin = 100   # horizontal padding (pixels) added around each detected face box
batch_size = 64  # frames per face-detection call
fps = 30
device_id = 0
imsize = 360     # side length of the square face crops
num_frames = 16
max_frames = None  # set to e.g. num_frames to decode only the first frames of the video

# Decode the video into `frames`, a list of RGB uint8 arrays (one per frame).
video_capture = cv2.VideoCapture(video)
# NOTE(review): for file inputs OpenCV generally does not resample on CAP_PROP_FPS;
# confirm this call has any effect for the backend in use.
video_capture.set(cv2.CAP_PROP_FPS, fps)
faces = []
frames = []
frame_count = 0

try:
    while video_capture.isOpened():
        ret, frame = video_capture.read()
        if not ret:  # end of stream (or read failure)
            break

        # OpenCV decodes to BGR; face_recognition expects RGB. cvtColor returns a
        # contiguous copy, unlike the negative-stride view from frame[:, :, ::-1].
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame_count += 1

        if max_frames is not None and frame_count >= max_frames:
            break
finally:
    # Always free the decoder handle, even if decoding raises partway through.
    video_capture.release()

In [None]:
# Detect faces in every decoded frame, `batch_size` frames per detector call.
# Entry k is the list of (top, right, bottom, left) boxes found in frames[k].
batched_face_locations = face_recognition.batch_face_locations(
    frames, number_of_times_to_upsample=0, batch_size=batch_size)

In [None]:
def crop_face_location(frame, face_location, v_margin, h_margin, imsize):
    """Cut a padded face region out of ``frame`` and resize it to a square image.

    Args:
        frame: image array indexed as ``frame[row, col, ...]``.
        face_location: ``(top, right, bottom, left)`` box, as produced by
            ``face_recognition`` face detection.
        v_margin: pixels added above and below the box.
        h_margin: pixels added left and right of the box.
        imsize: side length of the returned square image.

    Returns:
        A PIL image of size ``(imsize, imsize)``. The padded box is clamped to the
        frame borders, and the crop is stretched (aspect ratio not preserved).
    """
    height, width = frame.shape[:2]
    top, right, bottom, left = face_location

    y0, y1 = max(top - v_margin, 0), min(bottom + v_margin, height)
    x0, x1 = max(left - h_margin, 0), min(right + h_margin, width)

    crop = frame[y0:y1, x0:x1]
    return Image.fromarray(crop).resize((imsize, imsize))

In [None]:
# Imported for experiments with `rutils.AverageMeter` (running-average track
# encodings); not used by the active code in this notebook.
import pretorched.runners.utils as rutils

In [None]:
def get_match(known_faces, face_encoding, tolerance=0.50, step=0.01):
    """Return the index of the known face that ``face_encoding`` belongs to.

    A known face matches when the Euclidean distance between its encoding and
    ``face_encoding`` is <= ``tolerance``. This is the same test that
    ``face_recognition.compare_faces`` applies. If nothing matches, the tolerance
    is relaxed by ``step`` until something does. The FIRST matching index at
    that tolerance is returned, which is not necessarily the closest.

    Args:
        known_faces: sequence of reference encodings (equal-length 1-D arrays).
        face_encoding: encoding to classify.
        tolerance: initial match threshold.
        step: amount the threshold grows per relaxation round.

    Returns:
        int index into ``known_faces``.

    Raises:
        ValueError: if ``known_faces`` is empty, if no distance is finite, or if
            ``step`` is not positive. The old recursive version recursed until
            RecursionError in the empty/non-finite cases.
    """
    if len(known_faces) == 0:
        raise ValueError('known_faces is empty; nothing to match against')
    if step <= 0:
        raise ValueError('step must be positive')

    # Compute all distances once instead of once per relaxation step.
    distances = np.linalg.norm(
        np.asarray(known_faces) - np.asarray(face_encoding), axis=1)
    if not np.isfinite(distances).any():
        raise ValueError('no finite distance between face_encoding and known_faces')

    # Terminates: at least one distance is finite and the threshold grows each round.
    while True:
        matches = np.flatnonzero(distances <= tolerance)
        if matches.size:
            return int(matches[0])
        tolerance += step

In [None]:
# Group face crops by identity across frames.
# faces[k] is the list of PIL crops assigned to track k; known_faces[k] holds the
# most recent encoding seen for track k and is the reference for later matches.
faces = defaultdict(list)
known_faces = []

for frameno, (frame, face_locations) in enumerate(zip(frames, batched_face_locations)):
    number_of_faces_in_frame = len(face_locations)
    print(frameno, number_of_faces_in_frame)

    if number_of_faces_in_frame < 1:
        raise ValueError('WARNING TODO: NEED TO TRY ANOTHER METHOD')

    if number_of_faces_in_frame == 1 and not known_faces:
        # No identities established yet: default the lone face to track 0
        # without paying for an encoding.
        faces[0].append(crop_face_location(frame, face_locations[0], v_margin, h_margin, imsize))
        continue

    if number_of_faces_in_frame > 1:
        print('MORE THAN ONE FACE')

    face_encodings = face_recognition.face_encodings(frame, face_locations)

    if not known_faces:
        # The first multi-face frame defines the tracks. Copy so that updating a
        # track below never mutates this frame's encoding list.
        known_faces = list(face_encodings)

    # Once tracks exist, every face (including lone faces) is matched, so a frame
    # showing only person #2 is no longer filed under track 0.
    # NOTE(review): two faces in one frame can still map to the same track.
    for face_encoding, face_location in zip(face_encodings, face_locations):
        face_idx = get_match(known_faces, face_encoding)
        faces[face_idx].append(crop_face_location(frame, face_location, v_margin, h_margin, imsize))
        # Compare later frames against this track's latest appearance.
        known_faces[face_idx] = face_encoding

In [None]:
# Inspect the tracks collected above: track index -> list of PIL face crops.
faces

In [None]:
# Show each track's crops as a single grid (16 per row) and print the stacked shape.
for track_idx, track_images in faces.items():
    stacked = np.stack(list(map(np.array, track_images)))
    grid = vutils.make_grid(stacked, nrow=16)
    vutils.imshow(grid)
    print(stacked.shape)

In [None]:
def process_batch(frames, imsize):
    """Detect faces in a batch of frames and return a resized crop of every face.

    Detection runs once for the whole batch. Each detected face is then cropped
    from the frame it was detected in, padded by the module-level
    ``v_margin``/``h_margin`` and clamped to the frame, then resized to
    ``imsize`` x ``imsize``.

    Args:
        frames: sequence of RGB frames.
        imsize: side length of the returned square crops.

    Returns:
        A flat list of PIL images, ordered by frame and, within a frame, by
        detection order.
    """
    face_batch = []
    batch_of_face_locations = face_recognition.batch_face_locations(frames, number_of_times_to_upsample=0)

    # Pair each frame with ITS OWN detections. The previous version zipped all
    # frames against one frame's boxes, so it cropped faces from the wrong frame.
    for frame, face_locations in zip(frames, batch_of_face_locations):
        for face_location in face_locations:
            face_batch.append(crop_face_location(frame, face_location, v_margin, h_margin, imsize))
    return face_batch
