In [1]:
import sys; sys.path.append('..')

import cv2
from PIL import Image
import numpy as np
import face_alignment
from continuous_landmarks.dataset.transforms import (
    Compose, Align, Resize,
    CenterCrop, AbsToRelLdmks,
    ToTensor, Normalize
)

In [2]:
def get_video_frames(video_path):
    """Decode every frame of a video file into memory.

    Args:
        video_path: Path handed straight to ``cv2.VideoCapture``.

    Returns:
        A list with one array per decoded frame, in playback order. Each
        array is the frame returned by ``cap.read()`` with its last
        (channel) axis reversed; OpenCV decodes to BGR, so the result is
        RGB. The arrays are reversed *views*, not copies. The list is empty
        if the capture could not be opened or yielded no frame.
    """
    frames = []

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream (or a read failure): stop decoding.
                break
            frames.append(frame[..., ::-1])
    finally:
        # Always give the decoder handle back, even if decoding raised.
        cap.release()

    return frames

In [3]:
from tqdm import tqdm
import torch
from face_alignment.utils import crop, get_preds_fromhm


def get_face_detections(frames):
    """Run the face detector on every frame, requiring exactly one face each.

    Relies on the notebook-level ``fa`` object (a ``FaceAlignment``
    instance) being defined before this is called.

    Args:
        frames: Iterable of frame arrays; each must support ``.copy()``.

    Returns:
        A list with the single detection found in each frame, in order.

    Raises:
        AssertionError: If the detector reports zero or several faces for a
            frame. The message names the frame index and the count.
    """
    face_detections = []

    for index, frame in enumerate(tqdm(frames)):
        # The detector is given a copy rather than the frame object itself.
        dets = fa.face_detector.detect_from_image(frame.copy())
        if len(dets) != 1:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O` and reports which frame is at fault.
            raise AssertionError(
                f'expected exactly one face in frame {index}, got {len(dets)}'
            )
        face_detections.append(dets[0])

    return face_detections


def get_detection_center_scale(detection):
    """Turn one detection into the (center, scale) pair used for cropping.

    Args:
        detection: Indexable whose first four entries are read as
            ``[0], [1], [2], [3]`` (presumably x_min, y_min, x_max, y_max;
            confirm against the detector's output format).

    Returns:
        ``(center, scale)`` where ``center`` is a 2-element tensor holding
        the box midpoint with its second coordinate shifted by 12% of the
        box height, and ``scale`` is the sum of the box extents divided by
        the detector's ``reference_scale`` (read from the notebook-level
        ``fa`` object).
    """
    x0, y0, x1, y1 = detection[0], detection[1], detection[2], detection[3]
    width = x1 - x0
    height = y1 - y0

    center = torch.tensor([x1 - width / 2.0, y1 - height / 2.0])
    # Shift the second coordinate of the midpoint by 12% of the box height.
    center[1] = center[1] - height * 0.12

    scale = (width + y1 - y0) / fa.face_detector.reference_scale
    return center, scale


def crop_face(frame, center, scale):
    """Crop a frame around a detection and convert it to a float tensor.

    Args:
        frame: Image array forwarded to ``face_alignment.utils.crop``.
        center: Crop center, forwarded unchanged to ``crop``.
        scale: Crop scale, forwarded unchanged to ``crop``.

    Returns:
        A float tensor built from the crop with its third axis moved to the
        front (channel-first for an H x W x C crop) and every value divided
        by 255.
    """
    patch = crop(frame, center, scale)
    # Axis order (0, 1, 2) -> (2, 0, 1), then cast to float32.
    tensor = torch.from_numpy(patch.transpose(2, 0, 1)).float()
    # In-place division; div_ returns the same tensor.
    return tensor.div_(255.0)

In [4]:
def get_landmarks(frames, face_detections, device='cuda', batch_size=None):
    """Predict 68 landmarks per frame with the face-alignment network.

    Relies on the notebook-level ``fa`` object plus the helpers
    ``get_detection_center_scale`` and ``crop_face``.

    Args:
        frames: Sequence of frame arrays.
        face_detections: One detection per frame, in the same order. Frames
            and detections are paired with ``zip``, so the shorter of the
            two determines how many frames are processed.
        device: Device the face crops are moved to before the forward pass.
        batch_size: Maximum number of crops sent through the network at
            once. ``None`` (the default) processes everything as a single
            batch, which matches the previous behaviour.

    Returns:
        A tensor of shape ``(num_frames, 68, 2)`` built from the second
        value returned by ``get_preds_fromhm`` for each frame (presumably
        landmark coordinates in the original image space; confirm against
        the face_alignment documentation).

    Raises:
        RuntimeError: From ``torch.stack`` when there are no frames.
    """
    det_center_scales = [get_detection_center_scale(d) for d in face_detections]
    pairs = list(zip(frames, det_center_scales))

    # A falsy batch_size means "one batch"; max(..., 1) keeps range() valid
    # when there is nothing to process.
    step = batch_size if batch_size else max(len(pairs), 1)

    landmarks = []
    for start in range(0, len(pairs), step):
        chunk = pairs[start:start + step]

        face_crops = torch.stack([
            crop_face(frame, *cs) for frame, cs in chunk
        ]).to(device)

        with torch.no_grad():
            heatmaps = fa.face_alignment_net(face_crops).cpu().numpy()

        for heatmap, (_, (center, scale)) in zip(heatmaps, chunk):
            # Only the second returned value is used; the other two outputs
            # of get_preds_fromhm are not needed here.
            _, pts_img, _ = get_preds_fromhm(
                heatmap[None, ...], center.numpy(), scale
            )
            landmarks.append(torch.from_numpy(pts_img).view(68, 2))

    return torch.stack(landmarks)

In [5]:
# Input clip; every frame is decoded into memory up front.
VIDEO_PATH = 'test.mov'
frames = get_video_frames(VIDEO_PATH)

In [6]:
from pathlib import Path

lm_cache_path = Path('landmarks_cache.pth')
# When True, reuse landmarks from lm_cache_path instead of re-running the
# detector and alignment network. Set to False to force a recompute.
LAZY = True

landmarks = None
if LAZY and lm_cache_path.exists():
    landmarks = torch.load(lm_cache_path)
    if len(landmarks) != len(frames):
        # The cache file is not keyed by video, so a leftover cache from a
        # different clip would otherwise be silently truncated by the
        # zip() calls further down. Treat it as stale and recompute.
        print(
            f'Ignoring {lm_cache_path}: it holds {len(landmarks)} entries '
            f'but the video has {len(frames)} frames.'
        )
        landmarks = None

if landmarks is None:
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
    face_detections = get_face_detections(frames)
    landmarks = get_landmarks(frames, face_detections)
    torch.save(landmarks, lm_cache_path)

# Continuous landmarks

In [7]:
def get_eyes_mouth(points):
    """Reduce a 68-point landmark set to four alignment anchor points.

    Args:
        points: Sequence of exactly 68 points supporting slicing and
            ``.mean(axis=0)`` on the slice.

    Returns:
        A 4-tuple: the mean of points 36-41, the mean of points 42-47,
        point 48 and point 54 (the two eye centres and the two mouth
        corners, going by the function name).

    Raises:
        AssertionError: If ``points`` does not contain 68 entries.
    """
    assert len(points) == 68

    first_eye, second_eye = slice(36, 42), slice(42, 48)
    first_mouth_idx, second_mouth_idx = 48, 54

    return (
        points[first_eye].mean(axis=0),
        points[second_eye].mean(axis=0),
        points[first_mouth_idx],
        points[second_mouth_idx],
    )

# Preprocessing pipeline applied jointly to an image and its landmarks.
# The individual transforms come from continuous_landmarks.dataset.transforms;
# presumably this mirrors the preprocessing used in training (TODO confirm).
tfm = Compose([
    Align(get_eyes_mouth),
    Resize(224),
    CenterCrop(224),
    AbsToRelLdmks(),
    ToTensor(),
    Normalize([.5, .5, .5], [.2, .2, .2]),
])

# Transform every (frame, landmarks) pair, then split the resulting pairs
# into two parallel sequences and stack each into a single tensor.
samples = [
    tfm(Image.fromarray(frame), np.array(lms))
    for frame, lms in zip(frames, landmarks)
]
frame_tensors, lm_tensors = zip(*samples)
frame_batch = torch.stack(frame_tensors)
lm_batch = torch.stack(lm_tensors)

In [8]:
from continuous_landmarks.model import FeatureExtractor, LandmarkPredictor, \
    PositionEncoder
from torch import nn

device = 'cuda'

# The three sub-networks. The landmark predictor is sized from the position
# encoder's encoding size and the feature extractor's feature size.
pos_encoder = PositionEncoder()
feat_extractor = FeatureExtractor('MobileNetV3')
lm_predictor = LandmarkPredictor(
    model_name='MLP',
    query_size=pos_encoder.encoding_size,
    feature_size=feat_extractor.feature_size,
)

# Grouped in a ModuleDict so a single state dict covers all three; the keys
# must match the names stored in the checkpoint.
model = nn.ModuleDict()
model['PositionEncoder'] = pos_encoder
model['FeatureExtractor'] = feat_extractor
model['LandmarkPredictor'] = lm_predictor

model.load_state_dict(torch.load('../ckpts/95yvg3pn_best.pth'))
model.to(device)
model.eval();

In [9]:
# Canonical shape whose points are used as landmark queries during inference.
CANONICAL_SHAPE_PATH = '../continuous_landmarks/dataset/facescape_mouth_stretch.pth'
canonical_shape = torch.load(CANONICAL_SHAPE_PATH)
canonical_shape = canonical_shape.to(device)


In [10]:
# Split the stacked frames into consecutive chunks of at most batch_size;
# the final chunk may be shorter.
batch_size = 32
frame_batches = []
for start in range(0, len(frame_batch), batch_size):
    frame_batches.append(frame_batch[start:start + batch_size])

In [11]:
from continuous_landmarks.training.training_loop import get_inv_tfm
from torchvision.transforms.functional import to_pil_image

# Inverse of the preprocessing pipeline, as built by get_inv_tfm from tfm.
inv_tfm = get_inv_tfm(tfm)

# Parallel lists, one entry per frame: the network input as a PIL image and
# the predicted landmarks, both after inv_tfm has been applied.
input_ims, pred_lms = [], []

with torch.no_grad():
    for batch in frame_batches:
        # Repeat the canonical shape once per image in the batch (expand
        # creates a broadcast view, not a copy).
        canon_batch = canonical_shape.unsqueeze(0).expand(len(batch), -1, -1)
        B, N, _ = canon_batch.shape

        # Encode all B*N query points in one call, then restore (B, N, ...).
        encoded = pos_encoder(canon_batch.flatten(0, 1))
        query_sequence = encoded.unflatten(0, (B, N))

        feature = feat_extractor(batch.to(device))
        lm_pred, var_pred = lm_predictor(query_sequence, feature)

        for net_input, net_lms in zip(batch, lm_pred):
            restored_im, restored_lms = inv_tfm(net_input.cpu(), net_lms.cpu())
            input_ims.append(to_pil_image(restored_im))
            pred_lms.append(restored_lms)

In [12]:
from continuous_landmarks.utils.draw_points import draw_points
from continuous_landmarks.utils.face_alignment import get_matrix_and_size

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# frames[0].shape is (height, width, channels); [1::-1] picks (width, height).
# NOTE(review): the output frame rate is hard-coded to 30.0 and is not read
# from the source video; confirm it matches the input clip.
out = cv2.VideoWriter('output.mp4', fourcc, 30.0, frames[0].shape[1::-1])

try:
    for orig_img, align_lms, input_im, lms in zip(tqdm(frames), landmarks, input_ims, pred_lms):
        # Recompute the alignment matrix from this frame's 68 landmarks.
        eye_0, eye_1, mouth_0, mouth_1 = get_eyes_mouth(align_lms)
        M, align_size = get_matrix_and_size(
            eye_0, eye_1, mouth_0, mouth_1
        )

        # Undo resize: rescale from the network input size to the aligned
        # crop size. The network input must be square.
        assert input_im.width / input_im.height == 1
        input_size = input_im.width
        lms = lms * align_size / input_size

        # Undo alignment.
        # NOTE(review): this applies inv(A) first and subtracts the
        # translation column afterwards. If M is a standard 2x3 affine
        # [A | t] with p' = A p + t, the inverse would be
        # (lms - t) @ inv(A).T; confirm against get_matrix_and_size.
        lms = lms @ np.linalg.inv(M[:, :2]).T - M[:, 2]

        # Draw every 10th predicted point on the original frame.
        im = draw_points(orig_img, lms[::10], size=1)
        # Reverse the channel axis before handing the frame to OpenCV.
        out.write(np.array(im)[..., ::-1])
finally:
    # Finalise the file even if a frame fails partway through.
    out.release()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 188/188 [00:05<00:00, 31.65it/s]
