論文  
https://arxiv.org/abs/2203.14367<br>
<br>  
GitHub<br>
https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model.git<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/ThinPlateSplineMotionModel_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Print the NVIDIA driver/GPU status for the current runtime.
!nvidia-smi

## GitHubからコード取得

In [None]:
# Clone the Thin-Plate-Spline-Motion-Model repository under /content.
%cd /content

!git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model.git

## ライブラリのインストール

In [None]:
%cd /content/

# Install face_alignment; pip's stdout is discarded to keep the cell output short.
!pip install face_alignment > /dev/null
# Clone pulse: its align_face.py script is used for face alignment in later cells.
!git clone https://github.com/adamian98/pulse.git

## ライブラリのインポート

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model

import torch

import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
import os

from demo import load_checkpoints
from demo import make_animation
from skimage import img_as_ubyte

from google.colab import files
from moviepy.editor import *

warnings.filterwarnings("ignore")

# テストデータのセットアップ
[使用動画](https://www.pexels.com/ja-jp/video/5981354/)<br>
[使用画像](https://www.pakutaso.com/shared/img/thumb/nissinIMGL0823_TP_V.jpg)

## アップロード

In [None]:
# Colab form: start/end (in seconds) of the section to cut from the uploaded video.
# The form text below warns that clips of 30 s or more are likely to run out of memory.
#@markdown 動画の切り抜き範囲(秒)を指定してください。\
#@markdown 30秒以上の場合OOM発生の可能性が高いため注意
start_sec =  3#@param {type:"integer"}
end_sec =  7#@param {type:"integer"}

# Copy the form values into the names that the clipping step passes to subclip().
(start_pt, end_pt) = (start_sec, end_sec)

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model

# Recreate an empty test_data tree on every run so stale files never leak in.
!rm -rf test_data
!mkdir test_data
%cd test_data
!mkdir image aligned_image video frames aligned_video

%cd video
print("upload video...")
# The keys of the upload result are used as file names; only the first one is processed.
video = files.upload()
video = list(video.keys())
video_file = video[0]
# Cut out the requested time range and save it as video.mp4
with VideoFileClip(video_file) as video:
    subclip = video.subclip(start_pt, end_pt)
    subclip.write_videofile("./video.mp4")
# Split the clip into numbered PNG frames
!ffmpeg -i video.mp4 ../frames/%02d.png

%cd ../image
print("upload image...")
image = files.upload()
image = list(image.keys())
# NOTE(review): image_file is not referenced again in the visible cells.
image_file = image[0]

## aligned image

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model/test_data


# Run pulse's align_face.py on everything in test_data/image and write the
# results (with -output_size 256) to test_data/aligned_image.
!python /content/pulse/align_face.py \
  -input_dir /content/Thin-Plate-Spline-Motion-Model/test_data/image \
  -output_dir /content/Thin-Plate-Spline-Motion-Model/test_data/aligned_image \
  -output_size 256 \
  -seed 1234

## aligned video

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model

# Run the same alignment script over the extracted video frames.
# For the ted model, use output_size 384
!python /content/pulse/align_face.py \
  -input_dir /content/Thin-Plate-Spline-Motion-Model/test_data/frames \
  -output_dir /content/Thin-Plate-Spline-Motion-Model/test_data/aligned_video \
  -output_size 256 \
  -seed 1234

In [None]:

!ffmpeg -i /content/Thin-Plate-Spline-Motion-Model/test_data/aligned_video/%02d_0.png -c:v libx264 -vf "fps=25,format=yuv420p" /content/Thin-Plate-Spline-Motion-Model/test_data/aligned_video/aligned.mp4

# モデルのセットアップ

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model

# Colab form fields, in order: model selection, source image, driving video,
# output path, prediction mode, and the best-frame toggle.
# @markdown モデル選択
dataset_name = 'vox' #@param ["vox", "taichi", "ted", "mgif"]
# NOTE(review): the default below names one specific uploaded image (with a
# `_0` suffix, presumably appended by align_face.py) — change it to match the
# file actually present in test_data/aligned_image.
# @markdown 入力画像
source_image_path = '/content/Thin-Plate-Spline-Motion-Model/test_data/aligned_image/nissinIMGL0823_TP_V_0.png' #@param {type:"string"}
# @markdown 入力動画
driving_video_path = '/content/Thin-Plate-Spline-Motion-Model/test_data/aligned_video/aligned.mp4' #@param {type:"string"}
# @markdown 出力先
output_video_path = './generated.mp4' #@param {type:"string"}
# @markdown predict mode
predict_mode = 'relative' #@param ['standard', 'relative', 'avd']
# Setting this to True in "relative" mode improves the quality of the output
find_best_frame = False  #@param {type:"boolean"}

In [None]:
# Work from the cloned repository directory.
%cd /content/Thin-Plate-Spline-Motion-Model

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on a runtime without a GPU. The inference
# cell already checks device.type == 'cpu' when searching for the best frame.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## 学習済みモデルのダウンロード

In [None]:
%cd /content/Thin-Plate-Spline-Motion-Model

!mkdir checkpoints

# dataset_name = 'vox' #@param ["vox", "taichi", "ted", "mgif"]

if dataset_name == 'vox':
  !wget -c https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1 -O checkpoints/vox.pth.tar
  config_path = 'config/vox-256.yaml'
  checkpoint_path = 'checkpoints/vox.pth.tar'
  pixel = 256
if dataset_name == 'taichi':
  !wget -c https://cloud.tsinghua.edu.cn/f/9ec01fa4aaef423c8c02/?dl=1 -O checkpoints/taichi.pth.tar
  config_path = 'config/taichi-256.yaml'
  checkpoint_path = 'checkpoints/taichi.pth.tar'
  pixel = 256
if dataset_name == 'ted':
  !wget -c https://cloud.tsinghua.edu.cn/f/483ef53650b14ac7ae70/?dl=1 -O checkpoints/ted.pth.tar
  config_path = 'config/ted-384.yaml'
  checkpoint_path = 'checkpoints/ted.pth.tar'
  pixel = 384
if dataset_name == 'mgif':
  !wget -c https://cloud.tsinghua.edu.cn/f/cd411b334a2e49cdb1e2/?dl=1 -O checkpoints/mgif.pth.tar
  config_path = 'config/mgif-256.yaml'
  checkpoint_path = 'checkpoints/mgif.pth.tar'
  pixel = 256

# 表示用関数定義

In [None]:
def display(source, driving, generated=None):
    """Build a side-by-side animation of the source, driving and generated frames.

    Each animation frame is the source image, the i-th driving frame and, when
    given, the i-th generated frame concatenated along axis 1, so all inputs
    must agree in their remaining dimensions.

    Args:
        source: Image array shown unchanged in every frame.
        driving: Indexable sequence of driving frames; its length sets the
            number of animation frames.
        generated: Optional sequence of generated frames, indexed in step with
            ``driving``. When omitted, only two panels are drawn.

    Returns:
        A ``matplotlib.animation.ArtistAnimation``; the figure is closed before
        returning so it is not rendered a second time as static cell output.
    """
    has_generated = generated is not None
    # Widen the canvas by 4 inches when a third panel is shown.
    fig = plt.figure(figsize=(8 + 4 * has_generated, 6))

    ims = []
    for idx, driving_frame in enumerate(driving):
        panels = [source, driving_frame]
        if has_generated:
            panels.append(generated[idx])
        strip = np.concatenate(panels, axis=1)
        artist = plt.imshow(strip, animated=True)
        plt.axis('off')
        ims.append([artist])

    ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
    plt.close()
    return ani

# データのロード

In [None]:
# Load the source image and scale it to the model resolution, keeping at most
# the first three channels of the last axis.
source_image = imageio.imread(source_image_path)
source_image = resize(source_image, (pixel, pixel))[..., :3]

# Read the driving video frame by frame. The reader is closed in `finally`, so
# it is released even when reading metadata or decoding fails.
reader = imageio.get_reader(driving_video_path)
try:
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError as err:
        # Best effort: keep the frames decoded so far, but say so instead of
        # silently continuing with a shortened video.
        print(f"Stopped reading {driving_video_path} after {len(driving_video)} frames: {err}")
finally:
    reader.close()

# An empty frame list would only fail later with an obscure error.
if not driving_video:
    raise RuntimeError(f"No frames could be read from {driving_video_path}")

driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]

HTML(display(source_image, driving_video).to_html5_video())

# Inference

In [None]:
# Build the networks from the selected config and checkpoint on the chosen device.
inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    device=device,
)


def _animate(frames):
    """Run make_animation on *frames* with the loaded networks and selected mode."""
    return make_animation(
        source_image,
        frames,
        inpainting,
        kp_detector,
        dense_motion_network,
        avd_network,
        device=device,
        mode=predict_mode,
    )


use_best_frame = predict_mode == 'relative' and find_best_frame
if not use_best_frame:
    predictions = _animate(driving_video)
else:
    # Imported under an alias because the form flag is also named find_best_frame.
    from demo import find_best_frame as _find
    i = _find(source_image, driving_video, device.type == 'cpu')
    print ("Best frame: " + str(i))
    # Animate outward from frame i in both directions. Frame i starts both
    # halves, so the first forward prediction is dropped when joining
    # (assumes make_animation returns one prediction per input frame).
    driving_forward = driving_video[i:]
    driving_backward = driving_video[:(i+1)][::-1]
    predictions_forward = _animate(driving_forward)
    predictions_backward = _animate(driving_backward)
    predictions = predictions_backward[::-1] + predictions_forward[1:]

# Save the resulting video at the driving video's frame rate.
imageio.mimsave(output_video_path, list(map(img_as_ubyte, predictions)), fps=fps)

HTML(display(source_image, driving_video, predictions).to_html5_video())