In [None]:
# Import the packages needed for demonstration

import imageio
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from skimage import img_as_ubyte
from skimage.transform import resize

from demo import load_checkpoints, make_animation

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Read the still image to animate and the driving video.
# Naming note: the image at `target_path` is loaded into `source_image` and the
# video at `source_path` into `driving_video`; the variable names follow the
# arguments of `make_animation`, not the folder names.
target_path = "raw_data/targets/3.png"
source_path = "raw_data/sources/00048.mp4"

# Resolution every image/frame is resized to before being fed to the network.
# NOTE(review): presumably must match the 256 in config/vox-256.yaml — confirm.
FRAME_SIZE = (256, 256)


def _resize_rgb(image):
    """Resize ``image`` to FRAME_SIZE and return exactly three colour channels.

    Channels beyond the third (e.g. an alpha channel) are dropped. Images with
    fewer than three channels (grayscale, or grayscale + alpha) have their first
    channel replicated three times, so the result always has shape
    ``FRAME_SIZE + (3,)`` for 2-D or 3-D inputs.
    """
    resized = resize(image, FRAME_SIZE)
    if resized.ndim == 2:
        # 2-D grayscale: add an explicit channel axis so the check below applies.
        resized = resized[..., np.newaxis]
    if resized.shape[-1] < 3:
        resized = np.repeat(resized[..., :1], 3, axis=-1)
    return resized[..., :3]


# Pre-process the still image.
source_image = _resize_rgb(imageio.imread(target_path))

# Collect every frame of the driving video; the `with` block guarantees the
# reader is closed even if reading the metadata or a frame raises.
driving_frames = []
with imageio.get_reader(source_path) as reader:
    fps = reader.get_meta_data()['fps']  # frames per second reported by the container
    try:
        for frame in reader:
            driving_frames.append(frame)
    except RuntimeError:
        # Deliberate best-effort: keep every frame decoded before the reader
        # failed (presumably a truncated/corrupt tail) instead of aborting.
        pass

# Fail here with a clear message rather than later inside the model/animation.
if not driving_frames:
    raise ValueError(f"no frames could be read from {source_path!r}")

# Resize each frame to FRAME_SIZE with exactly three colour channels.
driving_video = [_resize_rgb(frame) for frame in driving_frames]

In [None]:
# Build an animation that plays the source image, the driving video and
# (optionally) the generated video side by side.
# NOTE(review): the name `display` shadows IPython's built-in `display` in the
# notebook namespace; kept for backward compatibility with existing cells.
def display(source, driving, generated=None, interval=50, repeat_delay=1000):
    """Return an animation whose i-th frame is ``[source | driving[i] | generated[i]]``.

    Parameters
    ----------
    source : array
        Still image shown in the left column of every frame. Its height must
        match the frames in ``driving``/``generated`` so the columns can be
        concatenated horizontally.
    driving : sequence of arrays
        Driving-video frames; one animation frame is produced per element.
    generated : indexable of arrays, optional
        Generated frames shown in the right column. Must contain at least
        ``len(driving)`` frames. When ``None`` only two columns are shown.
    interval : int or float, default 50
        Delay between animation frames in milliseconds. Pass ``1000 / fps`` to
        play at a video's native frame rate.
    repeat_delay : int, default 1000
        Pause in milliseconds before the animation loops.

    Returns
    -------
    matplotlib.animation.ArtistAnimation
        The animation (e.g. render with ``.to_html5_video()``).
    """
    # Widen the canvas when a third column (the generated video) is present.
    fig, ax = plt.subplots(figsize=(8 + 4 * (generated is not None), 6))
    # Loop-invariant: every frame is drawn into this same axes, so hide it once.
    ax.axis('off')

    frames = []
    for i, driving_frame in enumerate(driving):
        panels = [source, driving_frame]
        if generated is not None:
            panels.append(generated[i])
        image = ax.imshow(np.concatenate(panels, axis=1), animated=True)
        frames.append([image])

    anim = animation.ArtistAnimation(fig, frames, interval=interval, repeat_delay=repeat_delay)
    # Close this specific figure so the notebook does not additionally render it
    # as a static image; the animation keeps its own reference to `fig`.
    plt.close(fig)
    return anim

In [None]:
# Load the deep network: the generator and keypoint detector, with cpu=True.
VOX_CONFIG_PATH = 'config/vox-256.yaml'
VOX_CHECKPOINT_PATH = 'pre_trains/vox-cpk.pth.tar'
generator, kp_detector = load_checkpoints(config_path=VOX_CONFIG_PATH,
                                          checkpoint_path=VOX_CHECKPOINT_PATH,
                                          cpu=True)

In [None]:
# Generate the animated frames from `source_image` and `driving_video`
# using the loaded networks (CPU only).
predictions = make_animation(
    source_image,
    driving_video,
    generator,
    kp_detector,
    relative=True,
    cpu=True,
)

In [None]:
# Render source | driving | generated side by side as an inline HTML5 video.
comparison_clip = display(source_image, driving_video, predictions)
HTML(comparison_clip.to_html5_video())