# Capturing neutral face video

# Rendering face
## Remember to only run this once the headset is on! 

In [None]:
import cv2 as cv
import numpy as np
import pickle
from matplotlib import pyplot as plt

import face_alignment
import time
import cv2
import scipy.optimize as opt
from calibration.undistort import undistort

def draw_marks(image, marks, color=(0, 255, 0), radius=2):
    """
    Draw facial landmarks on an image, in place.

    Parameters
    ----------
    image : np.ndarray
        Image (as accepted by ``cv.circle``) on which landmarks are drawn.
        It is modified in place.
    marks : iterable of (x, y) pairs
        Facial landmark points. Coordinates may be Python or NumPy ints or
        floats; they are truncated to ints before drawing.
    color : tuple, optional
        Colour the landmarks are drawn with. The default is (0, 255, 0).
    radius : int, optional
        Radius of each filled landmark dot in pixels. The default is 2.

    Returns
    -------
    None.
    """
    for mark in marks:
        # OpenCV requires plain Python ints for the circle centre; landmark
        # arrays may hold NumPy integers or floats.
        center = (int(mark[0]), int(mark[1]))
        cv.circle(image, center, radius, color, -1, cv.LINE_AA)
        

# Landmark detector shared by the capture code in this cell. face_alignment 1.4
# renamed LandmarksType._2D to LandmarksType.TWO_D; accept either spelling.
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D
    if hasattr(face_alignment.LandmarksType, 'TWO_D')
    else face_alignment.LandmarksType._2D,
    device='cuda',
)

def err(p1, p2):
    """
    Return the sum of squared element-wise differences between two point sets.

    Parameters
    ----------
    p1, p2 : array-like
        Point sets of the same (or broadcast-compatible) shape. In this file
        each column is one point.

    Returns
    -------
    float
        ``sum((p1 - p2) ** 2)`` over every element, for inputs of any rank.
    """
    diff = np.asarray(p1) - np.asarray(p2)
    return np.sum(diff * diff)



def to_optimize(x, original_points, target_points):
    """
    Objective for fitting a 3x3 projective transform with ``scipy.optimize``.

    Parameters
    ----------
    x : array-like of 9 numbers
        Row-major entries of the 3x3 transform, flattened (the form in which
        ``scipy.optimize.minimize`` passes parameters).
    original_points : np.ndarray, shape (3, N)
        Homogeneous source points, one per column.
    target_points : np.ndarray, shape (>= 2, N)
        Target points, one per column; only the first two rows (x, y) are used.

    Returns
    -------
    float
        Sum of squared differences between the projected source points and
        the targets.
    """
    transform = np.asarray(x).reshape((3, 3))
    projected = transform @ original_points
    # Homogeneous divide for all columns at once; this objective is evaluated
    # many times per fit, so avoid a Python-level loop.
    cartesian = projected[:2] / projected[2]
    residual = cartesian - target_points[:2]
    return np.sum(residual * residual)


def crop_image(image, margin=80):
    """
    Trim ``margin`` columns from both the left and right edges of an image.

    Works for (H, W) and (H, W, C) arrays. Returns a view (no copy). When the
    image is narrower than ``2 * margin`` the result is empty.
    """
    width = image.shape[1]
    # Slice to ``width - margin`` rather than ``-margin`` so margin=0 keeps
    # the full width instead of producing an empty slice.
    return image[:, margin:width - margin]

def _to_homogeneous(points):
    """Return (N, 2) points as a (3, N) float array of homogeneous column vectors."""
    points = np.asarray(points)
    return np.append(points, np.ones((len(points), 1)), axis=1).T


def _project(transform, points):
    """Apply a 3x3 transform to (3, N) homogeneous points; return (2, N) Cartesian points."""
    mapped = transform @ points
    return mapped[:2] / mapped[2]


def capture_initial_video():
    """
    Interactively record a reference video aligned to the headset jawline.

    Phase 1: frames from camera 0 are cropped, resized to 256x256 and annotated
    (blue) with the first 17 landmarks found by ``fa`` (the jawline). Pressing
    ``q`` while a face is detected stores those landmarks as the alignment
    target.

    Phase 2: live landmarks (red) are drawn next to the stored target. Pressing
    ``r`` while a face is detected fits a 3x3 transform (``to_optimize``)
    mapping the live jawline onto the target, draws the fitted jawline (green)
    and starts writing un-annotated frames to ``./init_video.mp4``. Pressing
    ``q`` stops.

    Returns
    -------
    np.ndarray or None
        The last un-annotated 256x256 frame, also pickled to
        ``./starting_picture.p``; None if no phase-2 frame was captured.
    """
    jaw_count = 17  # leading landmarks used for alignment (the jawline)
    cap = cv2.VideoCapture(0)
    video_out = None
    clean_frame = None
    try:
        ## Phase 1: capture the jawline with the headset on and save the feature points.
        headset_features = None
        target_points = None
        while target_points is None:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue
            cv2.imshow(winname="RAW FRAME", mat=frame)
            frame = cv.resize(crop_image(frame), (256, 256))

            preds = fa.get_landmarks(frame)
            if preds:
                headset_features = preds[0].astype(int)[:jaw_count]
                draw_marks(frame, headset_features, color=(255, 0, 0))
            cv2.imshow(winname="Face", mat=cv2.resize(frame, (720, 720)))
            if cv2.waitKey(5) & 0xFF == ord('q'):
                if not preds:
                    # No detection in this frame, so there is nothing to store.
                    print('no face detected, press q again once landmarks are shown')
                    continue
                target_points = _to_homogeneous(headset_features)
                print("got target points with shape", target_points.shape)

        ## Phase 2: align the no-headset face to the stored jawline and record.
        # XVID is not a valid tag for the MP4 container (OpenCV's FFmpeg backend
        # warns and falls back to mp4v), so request mp4v directly.
        video_out = cv2.VideoWriter('./init_video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (256, 256))
        print('now trying to find a good no-headset match')
        record = False
        result_marks = None
        while True:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue
            frame = cv.resize(crop_image(frame), (256, 256))
            # Un-annotated copy of *this* frame for recording, so frames without
            # a detection are not replaced by a stale earlier frame.
            clean_frame = frame.copy()

            preds = fa.get_landmarks(frame)
            if preds:
                new_face = preds[0].astype(int)
                draw_marks(frame, new_face, color=(0, 0, 255))
            if result_marks is not None:
                draw_marks(frame, result_marks, color=(0, 255, 0))
            draw_marks(frame, headset_features, color=(255, 0, 0))
            cv2.imshow(winname="Face", mat=cv2.resize(frame, (720, 720)))

            pressed_key = cv2.waitKey(5) & 0xFF
            if pressed_key == ord('q'):
                print("stop recording")
                break
            if pressed_key == ord('r'):
                if not preds:
                    print('no face detected, cannot align; try again')
                else:
                    print('started recording')
                    record = True
                    original_points = _to_homogeneous(new_face[:jaw_count])
                    print(original_points.shape)
                    fit = opt.minimize(to_optimize, np.eye(3).ravel(),
                                       args=(original_points, target_points))
                    result_marks = _project(fit.x.reshape((3, 3)), original_points).T.astype(int)
                    print(result_marks)
            if record:
                video_out.write(clean_frame)
    finally:
        # Release the camera and writer even if the loop is interrupted.
        cap.release()
        if video_out is not None:
            video_out.release()
        cv2.destroyAllWindows()

    if clean_frame is not None:
        with open('./starting_picture.p', 'wb') as fh:
            pickle.dump(clean_frame, fh)
    return clean_frame

# Run the interactive capture defined above.
print("run")

capture_initial_video()
# Make sure no OpenCV window is left open once the capture returns.
cv2.destroyAllWindows()

In [None]:
import cv2
## Save the 30th frame of ./init_video.mp4 (recorded by capture_initial_video)
## as ./init_img2.png, the source image for the animation cell.
## If the video is shorter than 30 frames, its last frame is used instead.
cap1 = cv2.VideoCapture('./init_video.mp4')
frame = None
try:
    for i in range(30):
        res, next_frame = cap1.read()
        if not res:
            break  # end of video: keep the last frame that was read
        frame = next_frame
        cv2.imshow('aaaa', frame)
        cv2.waitKey(100)
finally:
    cap1.release()
if frame is None:
    raise RuntimeError('./init_video.mp4 has no readable frames; re-run the capture cell')
cv2.imwrite('./init_img2.png', frame)
cv2.destroyAllWindows()

# Attention: 
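## Run the command below before the next cell. It reloads the v4l2loopback kernel module that provides the virtual camera. If it fails, see: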
### https://stackoverflow.com/questions/70775129/runtimeerror-v4l2loopback-backend-stdexception-when-using-pyvirtualcam


# video_nr must match the device the streaming cell opens (/dev/video5).
!sudo modprobe -r v4l2loopback && sudo modprobe v4l2loopback devices=1 video_nr=5 card_label="Virtual" exclusive_caps=1 max_buffers=2


In [None]:
import getpass
import os
import subprocess

# Reload the v4l2loopback kernel module so /dev/video5 exists as a virtual
# camera. The password is piped to `sudo -S` on stdin instead of being
# interpolated into a shell string. It therefore never appears in the process
# list, and shell metacharacters in it cannot break or inject into the command.
password = getpass.getpass()
_sudo_input = password + '\n'

# Unloading fails when the module is not loaded yet; that is harmless, so
# report it and still load the module.
unload = subprocess.run(['sudo', '-S', 'modprobe', '-r', 'v4l2loopback'],
                        input=_sudo_input, text=True, capture_output=True)
if unload.returncode != 0:
    print('modprobe -r v4l2loopback failed (module may not be loaded):', unload.stderr.strip())

subprocess.run(['sudo', '-S', 'modprobe', 'v4l2loopback', 'devices=1', 'video_nr=5',
                'card_label=Virtual', 'exclusive_caps=1', 'max_buffers=2'],
               input=_sudo_input, text=True, check=True)

In [None]:
#### import matplotlib
# matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull
import pdb
import pyvirtualcam
import time
from calibration.undistort import undistort
import cv2


# Number of rows, counted from the top of a 256x256 driving frame, that are
# overwritten with rows from the recorded reference video before keypoints
# are extracted.
MERGE_HEIGHT = 150
def crop_image(image, margin=80):
    """
    Trim ``margin`` columns from both the left and right edges of an image.

    Works for (H, W) and (H, W, C) arrays. Returns a view (no copy). When the
    image is narrower than ``2 * margin`` the result is empty.
    """
    width = image.shape[1]
    # Slice to ``width - margin`` rather than ``-margin`` so margin=0 keeps
    # the full width instead of producing an empty slice.
    return image[:, margin:width - margin]

# Refuse to run under Python 2.
if sys.version_info[0] < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """
    Build the generator and keypoint detector from a YAML config and load their weights.

    Parameters
    ----------
    config_path : str
        YAML file with a ``model_params`` section containing
        ``generator_params``, ``kp_detector_params`` and ``common_params``.
    checkpoint_path : str
        ``torch.load``-able file with ``'generator'`` and ``'kp_detector'``
        state dicts.
    cpu : bool, optional
        If True, keep the models on the CPU and map the checkpoint there.
        Otherwise move them to CUDA and wrap them in ``DataParallelWithCallback``.

    Returns
    -------
    (generator, kp_detector)
        Both models, in eval mode.
    """
    with open(config_path) as f:
        # PyYAML >= 6 requires an explicit Loader for yaml.load(); safe_load
        # is enough for a config that presumably holds only plain YAML types.
        config = yaml.safe_load(f)

    model_params = config['model_params']
    common_params = model_params['common_params']
    generator = OcclusionAwareGenerator(**model_params['generator_params'], **common_params)
    kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)
    if not cpu:
        generator.cuda()
        kp_detector.cuda()

    # map_location=None is torch.load's default (restore to the saved device).
    checkpoint = torch.load(checkpoint_path,
                            map_location=torch.device('cpu') if cpu else None)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector


def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
    """
    Animate ``source_image`` with the motion of every frame of ``driving_video``.

    Keypoints of each driving frame are normalised against the first driving
    frame (``normalize_kp``) before the generator renders the source with them.

    Returns
    -------
    list of np.ndarray
        One (H, W, C) prediction per driving frame.
    """
    def _to_device(tensor):
        # Move to the GPU unless running on the CPU.
        return tensor if cpu else tensor.cuda()

    with torch.no_grad():
        source = _to_device(torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2))
        # Stack the frames as (batch, channel, time, H, W).
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector(driving[:, :, 0])

        frames_out = []
        for idx in tqdm(range(driving.shape[2])):
            current = _to_device(driving[:, :, idx])
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_detector(current),
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            rendered = generator(source, kp_source=kp_source, kp_driving=kp_norm)['prediction']
            frames_out.append(np.transpose(rendered.data.cpu().numpy(), [0, 2, 3, 1])[0])
    return frames_out

def find_best_frame(source, driving, cpu=False):
    """
    Return the index of the driving frame whose facial landmarks best match the source's.

    Landmarks are centred and scaled by the square root of their convex-hull
    area before comparison, so only face shape (not position or size) matters.

    Parameters
    ----------
    source : np.ndarray
        Source image; multiplied by 255 before detection, so presumably a
        float image in [0, 1] (TODO confirm).
    driving : iterable of np.ndarray
        Driving frames in the same format as ``source``.
    cpu : bool, optional
        Run the landmark detector on the CPU instead of CUDA.

    Returns
    -------
    int
        Index of the best-matching frame. Frames without a detected face are
        skipped; 0 is returned if no driving frame contains a face.

    Raises
    ------
    ValueError
        If no face is detected in ``source``.
    """
    import face_alignment

    def normalize_kp(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = np.sqrt(ConvexHull(kp[:, :2]).volume)  # volume of a 2-D hull is its area
        kp[:, :2] = kp[:, :2] / area
        return kp

    # face_alignment 1.4 renamed LandmarksType._2D to LandmarksType.TWO_D.
    landmark_types = face_alignment.LandmarksType
    landmarks_type = landmark_types.TWO_D if hasattr(landmark_types, 'TWO_D') else landmark_types._2D
    fa = face_alignment.FaceAlignment(landmarks_type, flip_input=True,
                                      device='cpu' if cpu else 'cuda')

    source_marks = fa.get_landmarks(255 * source)
    if not source_marks:
        raise ValueError('no face detected in the source image')
    kp_source = normalize_kp(source_marks[0])

    best_norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        driving_marks = fa.get_landmarks(255 * image)
        if not driving_marks:
            continue  # no face in this frame, so it cannot be the best match
        new_norm = ((kp_source - normalize_kp(driving_marks[0])) ** 2).sum()
        if new_norm < best_norm:
            best_norm = new_norm
            frame_num = i
    return frame_num

def create_frame(source_image, driving_video, generator, kp_detector, kp_source, kp_driving_initial, source, relative=True, adapt_movement_scale=True, cpu=False):
    """
    Render one animated frame: the source face driven by a single driving frame.

    Parameters
    ----------
    source_image : np.ndarray, shape (H, W, C)
        Source image. Only used to build the source tensor when ``source`` is None.
    driving_video : np.ndarray, shape (H, W, C)
        A single driving frame (despite the name).
    generator, kp_detector
        Models as returned by ``load_checkpoints``.
    kp_source : dict
        Keypoints of the source image (output of ``kp_detector``).
    kp_driving_initial : dict
        Keypoints of the reference driving frame that motion is measured against.
    source : torch.Tensor or None
        Pre-built (1, C, H, W) source tensor on the target device. Passing it
        avoids rebuilding and re-uploading the source every frame. If None,
        it is built from ``source_image``.
    relative, adapt_movement_scale : bool
        Forwarded to ``normalize_kp``.
    cpu : bool
        Keep tensors on the CPU instead of moving them to CUDA.

    Returns
    -------
    np.ndarray, shape (H, W, C)
        The generator's prediction for this frame.
    """
    with torch.no_grad():
        if source is None:
            source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
            if not cpu:
                source = source.cuda()
        driving_frame = torch.tensor(driving_video[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            driving_frame = driving_frame.cuda()
        kp_driving = kp_detector(driving_frame)
        kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                               kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                               use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
        out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
        return np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]



# Source image for the animation: the still written to ./init_img2.png earlier in this notebook.
source_image = (imageio.imread('./init_img2.png')).astype(np.uint8)
# reader = cv2.VideoCapture('./test_video.mp4')
# Recorded reference video; its frames are read into `regular_video` below.
reader2 = cv2.VideoCapture('./init_video.mp4')
# reader = imageio.get_reader('./test_video.mp4')
# reader2 = imageio.get_reader('./init_video.mp4')
# fps = reader.get_meta_data()['fps']
# `driving_video` is later replaced by the pre-processed first live camera frame;
# `regular_video` collects the RGB frames of ./init_video.mp4.
driving_video = []
regular_video = []

def pre_process_frame(frame):
    """
    Convert a raw BGR camera frame into the model's input format.

    The frame is cropped at the sides, resized to 256x256 and converted to RGB.
    It is then passed through skimage ``resize``, which (with default
    arguments) also converts it to floats. Returns a one-element list holding
    that frame.
    """
    rgb = cv2.cvtColor(cv2.resize(crop_image(frame), (256, 256)), cv2.COLOR_BGR2RGB)
    return [resize(rgb, (256, 256))[..., :3]]

        
# Read every frame of the reference video, converted to RGB.
while True:
    ret, frame = reader2.read()
    if not ret:
        break
    regular_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
reader2.release()  # all frames are in memory now; free the file handle
if not regular_video:
    # Later cells index regular_video[0] and take len(regular_video) as a modulus.
    raise RuntimeError('./init_video.mp4 contains no frames; re-run the capture cell')

# Bring everything to the 256x256, 3-channel float format the model expects.
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
regular_video = [resize(frame, (256, 256))[..., :3] for frame in regular_video]


# Open the live camera and keep retrying until it delivers a first frame.
cap = cv2.VideoCapture(0)
res, frame = cap.read()

while not res:
    res, frame = cap.read()
    print('retry')
    time.sleep(0.3)
# frame = undistort(frame)

# The first live frame becomes the driving reference, with its top
# MERGE_HEIGHT rows taken from the first frame of the reference video.
driving_video = pre_process_frame(frame)
reference_rows = regular_video[0][:MERGE_HEIGHT, :]
driving_video[0][:MERGE_HEIGHT, :] = reference_rows

# cv2.imshow(winname="Face", mat=cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))

generator, kp_detector = load_checkpoints(config_path='./config/vox-adv-256.yaml', checkpoint_path='./vox-adv-cpk.pth.tar', cpu=False)

relative = True
adapt_movement_scale = True
predictions = []
counter = 0

# Keypoints of the fixed inputs are computed once for the whole session. They
# are inference-only; without no_grad each result would keep an autograd graph
# (and its activations) alive for as long as the variables exist.
with torch.no_grad():
    source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
    source = source.cuda()
    kp_source = kp_detector(source)

    # (batch, channel, time, H, W); only the first (merged) live frame is used
    # as the reference that later motion is measured against.
    driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
    kp_driving_initial = kp_detector(driving[:, :, 0])

# Stream animated frames to the virtual camera until 'q' is pressed in the
# preview window (or the cell is interrupted).
try:
    with pyvirtualcam.Camera(width=256, height=256, fps=30, device='/dev/video5') as cam:
        while True:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue
            # frame = undistort(frame)
            cv2.imwrite("normalframe.jpg", frame)
            frame = pre_process_frame(frame)
            # Merge the same number of rows as the reference frame used for
            # kp_driving_initial; a different merge line would register as
            # spurious keypoint motion.
            frame[0][:MERGE_HEIGHT, :] = regular_video[counter % len(regular_video)][:MERGE_HEIGHT, :]
            # Frames are RGB; flip the channels so the OpenCV (BGR) preview shows true colours.
            preview = np.ascontiguousarray(frame[0][..., ::-1])
            cv2.imshow("split frame", cv2.resize(preview, (720, 720)))
            output = create_frame(source_image, frame[0], generator, kp_detector, kp_source,
                                  kp_driving_initial, source, relative=True,
                                  adapt_movement_scale=True, cpu=False)
            output255 = (output * 255).astype(np.uint8)
            cam.send(output255)
            # waitKey lets HighGUI repaint the preview and provides a way out of the loop.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            cam.sleep_until_next_frame()
            counter = (counter + 1) % len(regular_video)
finally:
    cap.release()
    cv2.destroyAllWindows()

# predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
# Nothing is appended to `predictions` in the live loop; only save when there is something to write.
if predictions:
    imageio.mimsave('res.mp4', [img_as_ubyte(frame) for frame in predictions], fps=25)

In [None]:
# Scratch check of NumPy's row-major reshape: 1..9 laid out as a 3x3 matrix.
x = np.arange(1, 10)
print(x.reshape(3, 3))