# Original FOMM Demo

In [6]:
# Import the packages needed for demonstration

import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML

from demo import make_animation
from skimage import img_as_ubyte

import warnings
warnings.filterwarnings("ignore")

In [7]:
# Read the still image to animate (from target_path, stored in `source_image`) and the
# driving video (from source_path, stored frame by frame in `driving_video`).
target_path = "raw_data/targets/3.png"
source_path = "raw_data/sources/00048.mp4"

# pre process the image: resize to 256x256 and keep only the first 3 channels (drops alpha)
source_image = imageio.imread(target_path)
source_image = resize(source_image, (256, 256))[..., :3]

# Add each frame of the video; the context manager closes the reader even if decoding fails
driving_video = []
with imageio.get_reader(source_path) as reader:
    fps = reader.get_meta_data()['fps']  # frames per second of the driving video (not a frame count)
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # Best-effort: keep the frames decoded so far. Presumably raised by the decoder on a
        # truncated/corrupt tail of the file — TODO confirm for the imageio backend in use.
        pass

# resize each frame in the video to 256x256 and keep only the first 3 channels
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

In [None]:
# A function that builds a side-by-side animation: source | driving frame | generated frame
def display(source, driving, generated=None, interval=50, repeat_delay=1000):
    """Build a matplotlib animation showing the source image next to each driving frame.

    Note: in the notebook namespace this shadows IPython's built-in ``display``.

    Args:
        source: image array shown in the first column of every animation frame.
        driving: sequence of frames; each is concatenated with ``source`` along axis 1 (width),
            so heights and channel counts must match.
        generated: optional sequence of generated frames; when given, a third column is added.
            Must contain at least ``len(driving)`` frames.
        interval: delay between animation frames in milliseconds (e.g. ``1000 / fps``).
        repeat_delay: pause in milliseconds before the animation loops.

    Returns:
        matplotlib.animation.ArtistAnimation; call ``.to_html5_video()`` to embed it.
    """
    fig, ax = plt.subplots(figsize=(8 + 4 * (generated is not None), 6))
    ax.axis('off')

    ims = []
    for i, driving_frame in enumerate(driving):
        cols = [source, driving_frame]
        if generated is not None:
            cols.append(generated[i])
        im = ax.imshow(np.concatenate(cols, axis=1), animated=True)
        ims.append([im])

    ani = animation.ArtistAnimation(fig, ims, interval=interval, repeat_delay=repeat_delay)
    # Close the figure so the notebook does not also render a static copy of the last frame.
    plt.close(fig)
    return ani

In [None]:
# Load the deep networks (generator + keypoint detector) from the vox-256 config/checkpoint, on CPU
from demo import load_checkpoints

generator, kp_detector = load_checkpoints(
    config_path='config/vox-256.yaml',
    checkpoint_path='pre_trains/vox-cpk.pth.tar',
    cpu=True,
)

In [None]:
# Generate the animated frames from source_image and the driving_video frames (CPU)
predictions = make_animation(
    source_image,
    driving_video,
    generator,
    kp_detector,
    relative=True,
    cpu=True,
)

In [None]:
# Show source | driving | generated side by side as an HTML5 video
fomm_clip = display(source_image, driving_video, predictions)
HTML(fomm_clip.to_html5_video())

# Demo with the Expression Stylizer
Next, we insert the expression stylizer into the animation pipeline to see how its effect carries over to the expressions of the animated character.

In [2]:
# Import the packages needed for demonstration

import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML

from demo import make_animation
from skimage import img_as_ubyte

import warnings
warnings.filterwarnings("ignore")

In [3]:
# Read the still image to animate (from target_path, stored in `source_image`) and the
# driving video (from source_path, stored frame by frame in `driving_video`).
target_path = "raw_data/targets/3.png"
source_path = "raw_data/sources/00048.mp4"

# pre process the image: resize to 256x256 and keep only the first 3 channels (drops alpha)
source_image = imageio.imread(target_path)
source_image = resize(source_image, (256, 256))[..., :3]

# Add each frame of the video; the context manager closes the reader even if decoding fails
driving_video = []
with imageio.get_reader(source_path) as reader:
    fps = reader.get_meta_data()['fps']  # frames per second of the driving video (not a frame count)
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # Best-effort: keep the frames decoded so far. Presumably raised by the decoder on a
        # truncated/corrupt tail of the file — TODO confirm for the imageio backend in use.
        pass

# resize each frame in the video to 256x256 and keep only the first 3 channels
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

In [4]:
# A function that builds a side-by-side animation: source | driving frame | generated frame
def display(source, driving, generated=None, interval=50, repeat_delay=1000):
    """Build a matplotlib animation showing the source image next to each driving frame.

    Note: in the notebook namespace this shadows IPython's built-in ``display``.

    Args:
        source: image array shown in the first column of every animation frame.
        driving: sequence of frames; each is concatenated with ``source`` along axis 1 (width),
            so heights and channel counts must match.
        generated: optional sequence of generated frames; when given, a third column is added.
            Must contain at least ``len(driving)`` frames.
        interval: delay between animation frames in milliseconds (e.g. ``1000 / fps``).
        repeat_delay: pause in milliseconds before the animation loops.

    Returns:
        matplotlib.animation.ArtistAnimation; call ``.to_html5_video()`` to embed it.
    """
    fig, ax = plt.subplots(figsize=(8 + 4 * (generated is not None), 6))
    ax.axis('off')

    ims = []
    for i, driving_frame in enumerate(driving):
        cols = [source, driving_frame]
        if generated is not None:
            cols.append(generated[i])
        im = ax.imshow(np.concatenate(cols, axis=1), animated=True)
        ims.append([im])

    ani = animation.ArtistAnimation(fig, ims, interval=interval, repeat_delay=repeat_delay)
    # Close the figure so the notebook does not also render a static copy of the last frame.
    plt.close(fig)
    return ani

In [5]:
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull

checkpoint_path = "pre_trains/vox-cpk.pth.tar"
config_path = 'config/anim-256.yaml'
with open(config_path) as f:
    # yaml.load() without an explicit Loader is unsafe and raises TypeError on PyYAML >= 6.
    # safe_load assumes the config uses no custom Python tags — confirm if that ever changes.
    config = yaml.safe_load(f)

# initialize generator
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                    **config['model_params']['common_params'])
# initialize kp detector
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                         **config['model_params']['common_params'])

# Run on the first GPU when available, otherwise stay on CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if use_gpu else 'cpu')
if use_gpu:
    print("using GPU")
generator.to(device)
kp_detector.to(device)

# load in the pretrained modules
train_params = config['train_params']

# map_location puts every tensor on `device`, so a GPU-saved checkpoint also loads on a
# CPU-only machine (and a checkpoint saved on another GPU index loads on cuda:0).
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])

# Inference only: switch both models to eval mode (batch-norm layers use running statistics).
generator.eval()
kp_detector.eval()

using GPU


KPDetector(
  (predictor): Hourglass(
    (encoder): Encoder(
      (down_blocks): ModuleList(
        (0): DownBlock2d(
          (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
        )
        (1): DownBlock2d(
          (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
        )
        (2): DownBlock2d(
          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): SynchronizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)

In [None]:
# Build the expression stylizer and its discriminator, then load their pretrained weights.
import torch
from modules.stylizer import StylizerGenerator
from modules.stylizer_discriminator import StylizerDiscrim

# create network models
stylizer = StylizerGenerator(**config['model_params']['stylizer_params'])
styDiscrim = StylizerDiscrim(**config['model_params']['stylizerDiscrim_params'])

# If GPU Available, adapt to it
use_gpu = torch.cuda.is_available()
if use_gpu:
    print("using GPU")
    stylizer.to(0)
    styDiscrim.to(0)

# Load the weights from the stylizer's OWN checkpoint, not from the FOMM `checkpoint` dict
# loaded in the previous cell. map_location lets a GPU-saved file load on a CPU-only machine.
stylizer_checkpoint_path = "pre_trains/00000099-checkpoint.pth.tar"
stylizer_checkpoint = torch.load(stylizer_checkpoint_path,
                                 map_location=torch.device('cuda:0' if use_gpu else 'cpu'))
stylizer.load_state_dict(stylizer_checkpoint['stylizer'])
styDiscrim.load_state_dict(stylizer_checkpoint['styDiscrim'])
# set to evaluate mode
stylizer.eval()
styDiscrim.eval()

## Define the prediction function

In [None]:
# With modules given, generate final results
import torch.nn.functional as F
from animate import normalize_kp


def _stylize_sparse_motion(sparse_motion, stylizer):
    """Replace the keypoint part of a sparse-motion tensor with the stylizer's output, in place.

    ``sparse_motion`` is the 5-D tensor returned by ``create_sparse_motion``; it is unpacked as
    (bs, K, h, w, c) with c presumably 2 (x/y offsets) — TODO confirm.
    Returns the same tensor object, with ``sparse_motion[:, :-1]`` overwritten.
    """
    # NOTE(review): if create_sparse_motion follows upstream FOMM, the identity (background)
    # grid sits at index 0 of dim 1, so `[:, :-1]` keeps the background and drops the LAST
    # keypoint; `[:, 1:]` would select exactly the keypoints. Left unchanged because the
    # pretrained stylizer presumably saw the same slicing during training — confirm there.
    kp_motion = sparse_motion[:, :-1]
    bs, n_kp, h, w, n_ch = kp_motion.shape
    # (bs, n_kp, h, w, c) -> (bs, n_kp * c, h, w): the stylizer takes a 4-D, channels-first tensor
    flat_motion = kp_motion.permute(0, 1, 4, 2, 3).reshape(bs, n_kp * n_ch, h, w)
    stylized = stylizer(flat_motion)['prediction']
    # undo the flattening: (bs, n_kp * c, h, w) -> (bs, n_kp, h, w, c)
    sparse_motion[:, :-1] = stylized.reshape(bs, n_kp, n_ch, h, w).permute(0, 1, 3, 4, 2)
    return sparse_motion


def _stylized_dense_motion(dm_network, stylizer, src_image, kp_driving, kp_source):
    """Dense-motion step built from ``dm_network``'s sub-modules, with the sparse motion stylized.

    ``src_image`` must already be downscaled with ``dm_network.down`` when
    ``dm_network.scale_factor != 1``.
    Returns a dict with 'sparse_deformed', 'mask', 'deformation' and, when the network has an
    occlusion head, 'occlusion_map'.
    """
    bs, _, h, w = src_image.shape
    out = {}

    heatmap = dm_network.create_heatmap_representations(src_image, kp_driving, kp_source)
    sparse_motion = dm_network.create_sparse_motion(src_image, kp_driving, kp_source)
    sparse_motion = _stylize_sparse_motion(sparse_motion, stylizer)

    deformed_source = dm_network.create_deformed_source_image(src_image, sparse_motion)
    out['sparse_deformed'] = deformed_source

    # Stack heatmaps and deformed sources along dim 2, then fold them into channels for the hourglass.
    hourglass_input = torch.cat([heatmap, deformed_source], dim=2).view(bs, -1, h, w)
    prediction = dm_network.hourglass(hourglass_input)

    # Softmax over dim 1 yields per-pixel weights for combining the sparse motions.
    mask = F.softmax(dm_network.mask(prediction), dim=1)
    out['mask'] = mask

    # Mask-weighted sum of the sparse motions -> one dense deformation field, channels last.
    motion = sparse_motion.permute(0, 1, 4, 2, 3)
    deformation = (motion * mask.unsqueeze(2)).sum(dim=1).permute(0, 2, 3, 1)
    out['deformation'] = deformation

    # Sec. 3.2 in the paper
    if dm_network.occlusion:
        out['occlusion_map'] = torch.sigmoid(dm_network.occlusion(prediction))
    return out


def _stylized_generate(generator, stylizer, source, source_features, src_image, kp_source, kp_driving):
    """Generator forward pass for one frame, using the stylized dense motion.

    ``source_features`` is ``source`` already passed through ``generator.first`` and
    ``generator.down_blocks``; ``src_image`` is the input expected by ``_stylized_dense_motion``.
    Returns the output dict; 'prediction' holds the decoded frame after a sigmoid.
    """
    out = source_features
    output_dict = {}
    dm_network = generator.dense_motion_network
    if dm_network is not None:
        dense_motion = _stylized_dense_motion(dm_network, stylizer, src_image, kp_driving, kp_source)
        output_dict['mask'] = dense_motion['mask']
        output_dict['sparse_deformed'] = dense_motion['sparse_deformed']

        occlusion_map = dense_motion.get('occlusion_map')
        if occlusion_map is not None:
            output_dict['occlusion_map'] = occlusion_map

        # Warp the encoded source features with the dense deformation field.
        deformation = dense_motion['deformation']
        out = generator.deform_input(out, deformation)

        if occlusion_map is not None:
            # Bring the occlusion map to the feature resolution before masking.
            if out.shape[2:] != occlusion_map.shape[2:]:
                occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
            out = out * occlusion_map

        output_dict['deformed'] = generator.deform_input(source, deformation)

    # Decoding part
    out = generator.bottleneck(out)
    for up_block in generator.up_blocks:
        out = up_block(out)
    output_dict['prediction'] = torch.sigmoid(generator.final(out))
    return output_dict


def my_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
    """Animate ``source_image`` with the motion of ``driving_video``, routing motion through the stylizer.

    The generator's forward pass is unrolled step by step so that the sparse keypoint motion can
    be replaced by the output of the global ``stylizer`` (created in an earlier cell).

    Args:
        source_image: H x W x C float array, the image to animate.
        driving_video: non-empty list of H x W x C float arrays, the driving frames.
        generator: generator whose sub-modules (first, down_blocks, dense_motion_network,
            deform_input, bottleneck, up_blocks, final) are called directly.
        kp_detector: keypoint detector applied to the source and each driving frame.
        relative: forwarded to ``normalize_kp`` as use_relative_movement / use_relative_jacobian.
        adapt_movement_scale: forwarded to ``normalize_kp``.
        cpu: if False, input tensors are moved with ``.cuda()``; must match where the models live.

    Returns:
        list of H x W x C numpy arrays, one generated frame per driving frame.

    Raises:
        ValueError: if ``driving_video`` is empty.
    """
    if len(driving_video) == 0:
        raise ValueError("driving_video must contain at least one frame")

    def to_device(tensor):
        return tensor if cpu else tensor.cuda()

    with torch.no_grad():
        # (H, W, C) -> (1, C, H, W)
        source = to_device(torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2))
        # (1, T, H, W, C) -> (1, C, T, H, W); frames are moved to the device one at a time below
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)

        kp_source = kp_detector(source)
        # The first driving frame must be on the same device as the models, like every later frame.
        kp_driving_initial = kp_detector(to_device(driving[:, :, 0]))

        # Work that depends only on the source image is done once instead of once per frame.
        source_features = generator.first(source)
        for down_block in generator.down_blocks:
            source_features = down_block(source_features)
        dm_network = generator.dense_motion_network
        src_image = source
        if dm_network is not None and dm_network.scale_factor != 1:
            src_image = dm_network.down(source)

        predictions = []
        for frame_idx in tqdm(range(driving.shape[2])):
            kp_driving = kp_detector(to_device(driving[:, :, frame_idx]))
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            # The normalized keypoints drive the generator; using the raw kp_driving here would make
            # `relative` and `adapt_movement_scale` have no effect.
            output_dict = _stylized_generate(generator, stylizer, source, source_features, src_image,
                                             kp_source, kp_norm)
            predictions.append(np.transpose(output_dict['prediction'].cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions

In [None]:
# Generate animation with the stylizer in the loop.
# The loading cells above put the models on the GPU whenever one is available, so the inputs must
# follow them: a hard-coded cpu=True would feed CPU tensors to CUDA models and fail.
predictions = my_animation(source_image, driving_video, generator, kp_detector,
                           relative=True, cpu=not torch.cuda.is_available())

In [None]:
# Show the stylized result next to the source image and the driving frames, as an HTML5 video
stylized_clip = display(source_image, driving_video, predictions)
HTML(stylized_clip.to_html5_video())