# In [None]:
import cv2 as cv
import numpy as np
import pickle
from matplotlib import pyplot as plt

import face_alignment
import time
import cv2
import scipy.optimize as opt
from calibration.undistort import undistort

def draw_marks(image, marks, color=(0, 255, 0), ratio=1):
    """
    Paint every landmark onto ``image`` as a small filled, anti-aliased dot.

    Parameters
    ----------
    image : np.uint8 array
        Image that the dots are drawn onto (OpenCV draws in place).
    marks : list or numpy array
        Landmark points; element 0 of each point is x and element 1 is y.
    color : tuple, optional
        Colour passed to OpenCV for the dots. The default is (0, 255, 0).
    ratio : float, optional
        Factor applied to both coordinates before drawing. Use it when the
        landmarks were detected on a resized copy of ``image``.
        The default is 1 (no scaling).

    Returns
    -------
    None.
    """
    dot_radius = 2
    filled = -1  # a negative thickness makes OpenCV fill the circle
    for point in marks:
        centre = (int(point[0] * ratio), int(point[1] * ratio))
        cv.circle(image, centre, dot_radius, color, filled, cv.LINE_AA)
        

# Module-level 2D face-landmark detector; capture_initial_video calls
# fa.get_landmarks() on every resized frame.
# NOTE(review): device='cuda' presumably requires a CUDA-capable GPU, and some
# face_alignment releases name this enum member LandmarksType.TWO_D instead of
# _2D — confirm both against the installed version/hardware.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cuda')

def err(p1, p2):
    """
    Return the sum of squared differences between two point sets.

    Parameters
    ----------
    p1, p2 : array_like
        Arrays of the same (or broadcast-compatible) shape. Elsewhere in this
        file they are (2, n) arrays with one point per column.

    Returns
    -------
    numpy scalar
        Sum over every element of (p1 - p2)**2. For column-per-point inputs
        this is the total squared Euclidean distance between corresponding
        points.
    """
    # np.sum reduces over all axes at once. The previous nested builtin
    # sum(sum(...)) raised TypeError for 1-D inputs and for inputs with zero
    # rows, because the inner sum then produced a plain scalar.
    diff = np.asarray(p1) - np.asarray(p2)
    return np.sum(diff * diff)



def to_optimize(x, original_points, target_points):
    """
    Objective for fitting a 3x3 projective transform with scipy.optimize.

    Parameters
    ----------
    x : np.ndarray
        The 9 entries of the transform in row-major order.
    original_points : np.ndarray
        (3, n) homogeneous source points, one point per column.
    target_points : np.ndarray
        2-D array with at least two rows and n columns, one point per column.
        Only the first two rows (x and y) are compared.

    Returns
    -------
    scalar
        Sum of squared differences between the transformed source points,
        after dividing by their homogeneous coordinate, and the targets.
    """
    transform = np.reshape(x, (3, 3))
    mapped = transform @ original_points
    # Perspective divide for every column at once. This replaces the
    # per-point Python loop; the function is evaluated many times per
    # minimize() call.
    projected = mapped[:2] / mapped[2]
    return err(projected, target_points[:2, :])


def crop_image(image, margin=80):
    """
    Trim ``margin`` columns from the left and right edges of an image.

    With the default margin a 640-pixel-wide frame becomes 480 wide. That
    matches the 480x480 frame size used by the video writers in this file.

    Parameters
    ----------
    image : np.ndarray
        (H, W) or (H, W, C) image.
    margin : int, optional
        Number of columns removed from each side. The default is 80.
        0 returns the full width.

    Returns
    -------
    np.ndarray
        A view of ``image`` with shape (H, W - 2*margin[, C]). The result is
        empty when the image is narrower than 2*margin.

    Raises
    ------
    ValueError
        If ``margin`` is negative.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    width = image.shape[1]
    # Index only the first two axes so grayscale (2-D) images work too. An
    # explicit end index (rather than -margin) keeps margin=0 from producing
    # an empty slice.
    return image[:, margin:width - margin]

# Landmarks are detected on a square copy of each cropped frame of this size.
_LANDMARK_INPUT_SIZE = 256
# Factor mapping landmark coordinates back onto the cropped frame for drawing.
# NOTE(review): hard-codes a 480-pixel crop (matching the 480x480 VideoWriter
# size below), i.e. it assumes the camera delivers 640x480 frames — confirm.
_DRAW_RATIO = 480 / _LANDMARK_INPUT_SIZE


def _jaw_and_mouth_points(landmarks):
    """
    Return the jaw and mouth landmarks of one face as an (m, 2) int array.

    These are the alignment features tracked with the headset on.
    NOTE(review): assuming the usual 68-point layout, [:17] is the jawline,
    but [49:61] skips outer-lip point 48 and includes inner-lip point 60.
    Confirm whether 48:60 was intended before changing it.
    """
    points = np.asarray(landmarks).astype(int)
    return np.concatenate((points[:17], points[49:61]))


def _to_homogeneous(points):
    """Convert (m, 2) points into a (3, m) array of homogeneous column vectors."""
    points = np.asarray(points)
    return np.vstack((points.T, np.ones(points.shape[0])))


def _project_points(transform, points):
    """Apply a 3x3 projective transform to (3, m) homogeneous points; return (2, m)."""
    mapped = transform @ points
    return mapped[:2] / mapped[2]


def capture_initial_video():
    """
    Record a headset-on clip and a matching headset-off clip from camera 0.

    Phase 1 (headset on): each frame is cropped and its landmarks are
    detected. The jaw/mouth points are drawn on the preview. Pressing 'r'
    starts writing the unannotated frames to ./video_headset_on.mp4.
    Pressing 'q' while a face is detected freezes the current jaw/mouth
    points as the alignment target and ends the phase.

    Phase 2 (headset off): the frozen target points and the live landmarks
    are drawn on the preview. Pressing 'r' while a face is detected does
    three things:
      * fits a 3x3 projective transform mapping the live jaw/mouth points
        onto the target;
      * draws the mapped points;
      * starts writing unannotated frames to ./video_headset_off.mp4.
    Pressing 'q' stops.

    Returns
    -------
    np.ndarray
        The last unannotated frame in which a face was detected. It is also
        pickled to ./starting_picture.p.

    Raises
    ------
    RuntimeError
        If camera 0 cannot be opened.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        # Otherwise the read loops below would spin forever printing 'no frame'.
        cap.release()
        raise RuntimeError('could not open camera 0')
    video_out_1 = cv2.VideoWriter('./video_headset_on.mp4', cv2.VideoWriter_fourcc(*'XVID'), 25, (480, 480))
    video_out = None
    try:
        ## capture jawline with headset on and save the feature pts
        record = False
        headset_features = None
        target_points = None
        while True:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue
            cv2.imshow(winname="RAW FRAME", mat=frame)

            frame = crop_image(frame)
            resized_frame = cv2.resize(frame, (_LANDMARK_INPUT_SIZE, _LANDMARK_INPUT_SIZE))
            clean_frame = frame.copy()
            preds = fa.get_landmarks(resized_frame)

            if preds:
                headset_features = _jaw_and_mouth_points(preds[0])
                draw_marks(frame, headset_features, color=(255, 0, 0), ratio=_DRAW_RATIO)

            cv2.imshow(winname="Face", mat=cv2.resize(frame, (720, 720)))
            pressed_key = cv2.waitKey(5) & 0xFF
            if pressed_key == ord('q'):
                # The target comes from this frame's landmarks, so it can only
                # be taken when a face was detected (preds is None otherwise).
                if preds:
                    target_points = _to_homogeneous(headset_features)
                    print("got target points with shape", target_points.shape)
                    break
                print('no face detected, cannot take target points')
            elif pressed_key == ord('r'):
                record = True

            if record:
                print(record)
                video_out_1.write(clean_frame)

        video_out_1.release()
        video_out = cv2.VideoWriter('./video_headset_off.mp4', cv2.VideoWriter_fourcc(*'XVID'), 25, (480, 480))
        print('now trying to find a good no-headset match')
        record = False

        ## align
        result_marks = None
        # The frame that ended phase 1 always contains a detected face.
        last_face_frame = clean_frame
        while True:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue

            frame = crop_image(frame)
            resized_frame = cv2.resize(frame, (_LANDMARK_INPUT_SIZE, _LANDMARK_INPUT_SIZE))
            # Take the clean copy for every frame, so the recording never
            # repeats a stale frame when no face is detected.
            clean_frame = frame.copy()

            preds = fa.get_landmarks(resized_frame)
            if preds:
                new_face = preds[0].astype(int)
                last_face_frame = clean_frame
                draw_marks(frame, new_face, color=(0, 0, 255), ratio=_DRAW_RATIO)
            if result_marks is not None:
                print("generating result")
                draw_marks(frame, result_marks, color=(0, 255, 0), ratio=_DRAW_RATIO)

            draw_marks(frame, headset_features, color=(255, 0, 0), ratio=_DRAW_RATIO)
            cv2.imshow(winname="Face", mat=cv2.resize(frame, (720, 720)))

            pressed_key = cv2.waitKey(5) & 0xFF

            if pressed_key == ord('q'):
                print("stop recording")
                break
            elif pressed_key == ord('r'):
                # The alignment needs this frame's landmarks.
                if preds:
                    print('started recording')
                    record = True
                    original_points = _to_homogeneous(_jaw_and_mouth_points(preds[0]))
                    print(original_points.shape)
                    # Start from the identity transform.
                    fit = opt.minimize(to_optimize, np.eye(3).ravel(), args=(original_points, target_points))
                    transform = fit.x.reshape((3, 3))
                    result_marks = _project_points(transform, original_points).T.astype(int)
                    print(result_marks)
                else:
                    print('no face detected, cannot align')
            if record:
                video_out.write(clean_frame)
                print(clean_frame.shape)
    finally:
        # Release the camera and writers even if detection or fitting raises.
        cap.release()
        video_out_1.release()
        if video_out is not None:
            video_out.release()
        cv2.destroyAllWindows()

    with open('./starting_picture.p', 'wb') as picture_file:
        pickle.dump(last_face_frame, picture_file)
    return last_face_frame

# Only start the interactive camera capture when run as a script (or as a
# notebook cell), not when this module is imported.
if __name__ == "__main__":
    print("run")
    capture_initial_video()
    cv2.destroyAllWindows()