In [None]:
# Clone the first-order-model repository and make it the working directory.
# Everything downloaded by the following cells is therefore stored inside the
# checkout, and the project-local imports used later (presumably provided by
# this repository -- e.g. `modules`, `animate`, `sync_batchnorm`) resolve from here.
!git clone https://github.com/AliaksandrSiarohin/first-order-model
%cd first-order-model

In [None]:
# Download still image of statue
!wget --no-clobber https://myshare.uni-osnabrueck.de/f/2c323b06f1354ed2ae72/?dl=1 -O still.png

In [None]:
# Download video to be mapped onto statue
!wget --no-clobber https://myshare.uni-osnabrueck.de/f/43efb4770f54456192ae/?dl=1 -O leo.mp4

In [None]:
# Download pretrained model on celebrity faces (VoxCeleb)
!wget --no-clobber https://myshare.uni-osnabrueck.de/f/86c0195eb2e74845b77d/?dl=1 -O vox-cpk.pth.tar

In [None]:
# Download config file
!wget --no-clobber https://myshare.uni-osnabrueck.de/f/41983adf0c4c4fcbabda/?dl=1 -O config.yml

In [None]:
import torch

# Plain boolean flag: True only when PyTorch reports a usable CUDA device.
# Later cells derive the CPU fallback switch from it (CPU = not USE_GPU).
USE_GPU = torch.cuda.is_available()
if USE_GPU:
    print("Found GPU: {}".format(torch.cuda.get_device_name(torch.cuda.current_device())))
else:
    print("Using CPU. It will bleed. RIP")

In [None]:
# Run configuration and a last check that all input files exist
from pathlib import Path

# Stays False until every check at the bottom of this cell has passed; the
# inference cell refuses to run while it is False.
CONF_VALID=False

# Input Data
CONFIG='config.yml' # Path to model config
DRIVING_VIDEO='leo.mp4' # Driving Video to use
SOURCE_IMAGE='still.png' # Source image to use
CKPT='vox-cpk.pth.tar' # path to checkpoint to restore
CPU=not USE_GPU # run on the CPU whenever no GPU was detected

# Output
OUTPUT_VIDEO='result.mp4'
OVERWRITE=True # Overwrite existing OUTPUT_VIDEO

# Input Data Extraction
FIND_BEST_FRAME=False # search the driving video for the frame best matching the source image
BEST_FRAME=None # explicit frame index to start from; when set, the search is skipped

RELATIVE=True # use relative or absolute keypoint coordinates
ADAPT_SCALE=True # adapt movement scale based on convex hull of keypoints

# Every input must already be on disk. The assertion message names the missing
# file so a failed download is easy to spot.
for _label, _path in (("source image", SOURCE_IMAGE),
                      ("driving video", DRIVING_VIDEO),
                      ("checkpoint", CKPT),
                      ("config", CONFIG)):
    assert Path(_path).exists(), "Missing {} file: {}".format(_label, _path)

# Protect an existing result unless overwriting was explicitly allowed.
if not OVERWRITE and Path(OUTPUT_VIDEO).exists():
    raise Exception("Config would overwrite existing output file!")

CONF_VALID=True

In [None]:
from IPython.display import HTML

print("  Driving Video (left), Source Image (right)")

# Side-by-side preview of the two inputs. The template has exactly two
# placeholders: the driving video path and the source image path.
HTML("""
<div>
<table>
<tr>
  <td>
  <video width="350" height="350" controls autoplay loop>
    <source src="{}" type="video/mp4">
  </video>
  </td>
  <td>
  <img src="{}" alt="Source image" width="350" height="350">
  </td>
  </tr>
  </table>
</div>
""".format(DRIVING_VIDEO, SOURCE_IMAGE))

In [None]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull


# Fail early on legacy interpreters: only the major version number is inspected.
python_major_version = sys.version_info[0]
if not python_major_version >= 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and restore their weights.

    Args:
        config_path: Path to the YAML model config. Its ``model_params``
            section must provide ``generator_params``, ``kp_detector_params``
            and ``common_params``.
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains ``'generator'`` and ``'kp_detector'`` state dicts.
        cpu: When True, everything stays on the CPU. Otherwise both networks
            are moved to CUDA and wrapped in ``DataParallelWithCallback``.

    Returns:
        Tuple ``(generator, kp_detector)``, both switched to eval mode.
    """
    with open(config_path) as f:
        # safe_load works across PyYAML versions (a bare yaml.load(f) needs an
        # explicit Loader on recent releases) and never constructs arbitrary
        # Python objects. Assumes the config is plain YAML without
        # Python-specific tags -- TODO confirm if custom configs are used.
        config = yaml.safe_load(f)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    if cpu:
        # Remap tensor storages to the CPU while loading.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        # Wrap only after the weights are loaded, so the state-dict keys are
        # matched against the bare modules.
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector


def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
    """Animate ``source_image`` with the motion of ``driving_video``.

    Args:
        source_image: Image array laid out as (height, width, channels).
        driving_video: Sequence of frames, each laid out like ``source_image``.
        generator: Network called as
            ``generator(source, kp_source=..., kp_driving=...)``; its result
            must contain a ``'prediction'`` tensor.
        kp_detector: Network that maps an image batch to keypoints.
        relative: Forwarded to ``normalize_kp`` as both
            ``use_relative_movement`` and ``use_relative_jacobian``.
        adapt_movement_scale: Forwarded to ``normalize_kp``.
        cpu: When False, every tensor handed to the networks is moved to CUDA.

    Returns:
        List with one (height, width, channels) numpy array per driving frame.
    """
    def _to_device(tensor):
        # One place decides device placement for all network inputs.
        return tensor if cpu else tensor.cuda()

    with torch.no_grad():
        predictions = []
        # (H, W, C) -> (1, C, H, W)
        source = _to_device(torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2))
        # (T, H, W, C) -> (1, C, T, H, W). The full video stays on the CPU and
        # only the frame currently being processed is transferred.
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        # The first driving frame is the reference pose; it gets the same
        # explicit device placement as every other network input.
        kp_driving_initial = kp_detector(_to_device(driving[:, :, 0]))

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = _to_device(driving[:, :, frame_idx])
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            # (1, C, H, W) -> (H, W, C), back on the CPU as a numpy array
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions

def find_best_frame(source, driving, cpu=False):
    """Return the index of the driving frame whose face landmarks best match the source.

    Landmarks are centred and scaled by the square root of the area of their
    convex hull before comparison; the frame with the smallest sum of squared
    differences wins.

    Args:
        source: Source image; it is multiplied by 255 before landmark
            detection (presumably float values in [0, 1] -- confirm against caller).
        driving: Iterable of driving frames in the same format as ``source``.
        cpu: Run the landmark detector on the CPU instead of CUDA.

    Returns:
        Index of the best matching frame, or 0 when no driving frame yields
        any landmarks.

    Raises:
        ValueError: If no face is detected in the source image.
    """
    import face_alignment

    def normalize_kp(kp):
        # Centre the landmarks, then scale x/y by sqrt(hull area). For 2-D
        # points ConvexHull.volume is the enclosed area.
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    # NOTE(review): newer face_alignment releases appear to expose this enum
    # member as TWO_D instead of _2D; accept whichever one is present.
    landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
    if landmarks_type is None:
        landmarks_type = face_alignment.LandmarksType._2D

    fa = face_alignment.FaceAlignment(landmarks_type, flip_input=True,
                                      device='cpu' if cpu else 'cuda')

    # get_landmarks yields nothing usable when no face is found; fail with a
    # clear message instead of an opaque indexing error.
    source_landmarks = fa.get_landmarks(255 * source)
    if not source_landmarks:
        raise ValueError("No face detected in the source image; cannot search for the best frame.")
    kp_source = normalize_kp(source_landmarks[0])

    norm = float('inf')
    frame_num = 0
    # Wrapping the frames (not the enumerate object) lets tqdm show a total
    # when ``driving`` has a length.
    for i, image in enumerate(tqdm(driving)):
        frame_landmarks = fa.get_landmarks(255 * image)
        if not frame_landmarks:
            # A frame without a detectable face cannot be the best match.
            continue
        kp_driving = normalize_kp(frame_landmarks[0])
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num

if __name__ == "__main__":
    # Refuse to run unless the configuration cell completed all of its checks.
    if not CONF_VALID:
        raise Exception("Sorry, you ran into issues with config and ignored them. Not so fast, bucko.")

    source_image = imageio.imread(SOURCE_IMAGE)

    # Read the whole driving video into memory. The reader is closed even if
    # reading the metadata or decoding a frame fails unexpectedly.
    video_reader = imageio.get_reader(DRIVING_VIDEO)
    try:
        fps = video_reader.get_meta_data()['fps']
        driving_video = []
        try:
            for im in video_reader:
                driving_video.append(im)
        except RuntimeError:
            # Deliberate best effort: keep the frames decoded before the error.
            pass
    finally:
        video_reader.close()

    # The best-effort read above can end up with nothing; stop here with a
    # clear message instead of failing later inside the model code.
    if not driving_video:
        raise ValueError("No frames could be read from driving video: {}".format(DRIVING_VIDEO))

    # Bring both inputs to 256x256 and keep only the first three channels
    # (drops an alpha channel if present).
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    generator, kp_detector = load_checkpoints(config_path=CONFIG, checkpoint_path=CKPT, cpu=CPU)

    if FIND_BEST_FRAME or BEST_FRAME is not None:
        # Animate outwards from the chosen frame in both directions, then
        # stitch the halves together; the chosen frame appears in both, so it
        # is dropped once from the forward half.
        i = BEST_FRAME if BEST_FRAME is not None else find_best_frame(source_image, driving_video, cpu=CPU)
        print ("Best frame: " + str(i))
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i+1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=RELATIVE, adapt_movement_scale=ADAPT_SCALE, cpu=CPU)
        predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=RELATIVE, adapt_movement_scale=ADAPT_SCALE, cpu=CPU)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=RELATIVE, adapt_movement_scale=ADAPT_SCALE, cpu=CPU)

    # Write the result with the frame rate of the driving video.
    imageio.mimsave(OUTPUT_VIDEO, [img_as_ubyte(frame) for frame in predictions], format='.mp4', fps=fps)
    print("Output video saved to: {}".format(OUTPUT_VIDEO))

In [None]:
from IPython.display import HTML

print("  Driving Video (left), Source Image (middle), Result (right)")

# Three-column comparison. The result column points at OUTPUT_VIDEO, so it
# follows the configured output path instead of a hardcoded file name.
HTML("""
<div>
<table>
<tr>
  <td>
  <video width="250" height="250" controls autoplay loop>
    <source src="{}" type="video/mp4">
  </video>
  </td>
  <td>
  <img src="{}" width="250" height="250">
  </td>
  <td>
  <video width="250" height="250" controls autoplay loop>
    <source src="{}" type="video/mp4">
  </video>
  </td>
  </tr>
</table>
</div>
""".format(DRIVING_VIDEO, SOURCE_IMAGE, OUTPUT_VIDEO))