# Thin-Plate Spline Motion Model for Image Animation

<img src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" width="600px" />

<img src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/ted.gif" width="600px" />

- 원본 소스코드: [빵형의 개발도상국](https://www.youtube.com/@bbanghyong)

## 소스코드/모델 다운로드

In [1]:
# Fetch the Thin-Plate-Spline-Motion-Model repository; later cells import from
# its demo.py and read its config/ directory.
!git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model.git

Cloning into 'Thin-Plate-Spline-Motion-Model'...
remote: Enumerating objects: 115, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 115 (delta 43), reused 31 (delta 29), pack-reused 50[K
Receiving objects: 100% (115/115), 32.66 MiB | 25.22 MiB/s, done.
Resolving deltas: 100% (51/51), done.


In [2]:
cd Thin-Plate-Spline-Motion-Model

/content/Thin-Plate-Spline-Motion-Model


In [3]:
!mkdir checkpoints
!pip3 install wldhx.yadisk-direct
!curl -L $(yadisk-direct https://disk.yandex.com/d/bWopgbGj1ZUV1w) -o tpsmm.zip

Collecting wldhx.yadisk-direct
  Downloading wldhx.yadisk_direct-0.0.6-py3-none-any.whl (4.5 kB)
Installing collected packages: wldhx.yadisk-direct
Successfully installed wldhx.yadisk-direct-0.0.6
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1055M    0 1055M    0     0  8234k      0 --:--:--  0:02:11 --:--:-- 8148k


In [10]:
# Unpack the downloaded archive; its members extract into a tpsmm/ directory.
!unzip tpsmm.zip
# Move the *.pth.tar files into checkpoints/, where checkpoint_path points.
!mv tpsmm/* checkpoints/
# Remove the extraction directory once its contents have been moved.
!rm -rf tpsmm

Archive:  tpsmm.zip
  inflating: tpsmm/mgif.pth.tar      
  inflating: tpsmm/taichi.pth.tar    
  inflating: tpsmm/ted.pth.tar       
  inflating: tpsmm/vox.pth.tar       


In [5]:
!pip install -q face_alignment imageio_ffmpeg

## 설정

<img src="https://user-images.githubusercontent.com/48593306/197152487-45d5198a-1e7d-4e73-8709-cf7621827d60.png" width="600px" />

In [7]:
import torch

# ---- Settings a reader is expected to edit ----
dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']
predict_mode = 'relative' # ['standard', 'relative', 'avd']
find_best_frame = True # in 'relative' mode on faces, find_best_frame=True can give a better-quality result
output_video_path = './generated.mp4'

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs on a runtime without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---- Values derived from dataset_name ----
# ted is 384*384; vox, taichi and mgif are 256*256.
pixel = 384 if dataset_name == 'ted' else 256

# Derived so that changing dataset_name cannot silently keep loading the vox
# files. Assumes the repository's config/<dataset>-<resolution>.yaml naming and
# the checkpoints/<dataset>.pth.tar files extracted earlier; override these two
# lines when using custom files.
config_path = f'config/{dataset_name}-{pixel}.yaml'
checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'

## 패키지/모델 로드

In [3]:
try:
  import imageio
  import imageio_ffmpeg
except:
  !pip install imageio_ffmpeg

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
import os

# Suppress every Python warning for the rest of the session. No category or
# module filter is given, so this also hides warnings from imported libraries.
warnings.filterwarnings("ignore")

In [11]:
from demo import load_checkpoints

# Build the four networks returned by the repository's loader, using the
# config file, checkpoint file and device chosen in the settings cell.
(
    inpainting,
    kp_detector,
    dense_motion_network,
    avd_network,
) = load_checkpoints(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    device=device,
)

## 내 얼굴 녹화하기

Driving video: 아래 셀을 실행해 소스 이미지를 움직이는 데 사용할 영상을 웹캠으로 녹화합니다.

In [1]:
#@title
from IPython.display import display, Javascript,HTML
from google.colab.output import eval_js
from base64 import b64decode

def record_video(filename):
  """Record a clip from the browser camera in Colab and write it to ``filename``.

  Injects a small page with a live preview and "Start Recording" /
  "Stop Recording" buttons, records with MediaRecorder, passes the recorded
  bytes back to Python as base64 and writes them to disk.

  Args:
    filename: Path of the file to create. The recorder is configured for
      "video/webm; codecs=vp9", so the extension of ``filename`` does not
      change the recorded format.

  Any exception (denied camera permission, unsupported codec, ...) is printed
  rather than raised, so on failure the file may be missing.
  """
  # Imported locally under an alias: a later cell of this notebook defines its
  # own `display(source, driving, generated)` helper, which replaces the global
  # `display` and would break this function when the cell is re-run.
  from IPython.display import display as show_in_output

  js=Javascript("""
    async function recordVideo() {
      const options = { mimeType: "video/webm; codecs=vp9" };
      const div = document.createElement('div');
      const capture = document.createElement('button');
      const stopCapture = document.createElement("button");

      div.style.maxWidth = '400px';

      capture.textContent = "Start Recording";
      capture.style.background = "orange";
      capture.style.color = "white";

      stopCapture.textContent = "Stop Recording";
      stopCapture.style.background = "red";
      stopCapture.style.color = "white";
      div.appendChild(capture);

      const video = document.createElement('video');
      video.style.display = 'block';

      // "user" prefers the front (selfie) camera: this cell records the
      // user's own face. "environment" would prefer the rear camera.
      const stream = await navigator.mediaDevices.getUserMedia({audio:true, video: {
        facingMode: "user",
        aspectRatio: { exact: 1 }
      }});

      let recorder = new MediaRecorder(stream, options);
      document.body.appendChild(div);
      div.appendChild(video);

      video.srcObject = stream;
      video.style.maxWidth = '400px';
      // Mute the preview so the microphone is not played back (feedback).
      video.muted = true;

      await video.play();

      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      await new Promise((resolve) => {
        capture.onclick = resolve;
      });
      recorder.start();
      capture.replaceWith(stopCapture);

      await new Promise((resolve) => stopCapture.onclick = resolve);
      // Attach the handler before stop() so the final dataavailable event
      // cannot be missed.
      const recDataReady = new Promise((resolve) => recorder.ondataavailable = resolve);
      recorder.stop();
      let recData = await recDataReady;
      let arrBuff = await recData.data.arrayBuffer();

      // Stop every track (video AND audio) so neither the camera nor the
      // microphone stays active, then remove the preview UI.
      stream.getTracks().forEach((track) => track.stop());
      div.remove();

      // btoa() needs a binary string, so build one byte by byte.
      let binaryString = "";
      let bytes = new Uint8Array(arrBuff);
      bytes.forEach((byte) => {
        binaryString += String.fromCharCode(byte);
      })
    return btoa(binaryString);
    }
  """)
  try:
    show_in_output(js)
    # Blocks until the JavaScript promise resolves with the base64 payload.
    data=eval_js('recordVideo({})')
    binary=b64decode(data)
    with open(filename,"wb") as video_file:
      video_file.write(binary)
    print(f"Finished recording video at:{filename}")
  except Exception as err:
    # Best effort: report the problem instead of aborting the notebook.
    print(str(err))

# NOTE(review): the recorder is configured with a video/webm mimeType, so this
# file presumably holds WebM data despite the .mp4 extension. The name is what
# the preprocessing cell reads as driving_video_path.
record_video('capture.mp4')

<IPython.core.display.Javascript object>

Finished recording video at:capture.mp4


## 얼굴 파일 업로드

In [5]:
from google.colab import files
import shutil

# Open Colab's upload widget; the result is a mapping keyed by uploaded file name.
myfiles = files.upload()

# A cancelled upload yields an empty mapping; fail with a clear message instead
# of an opaque IndexError.
if not myfiles:
    raise ValueError("No file was uploaded; upload a source face image and re-run this cell.")

# Only the first uploaded file is used as the source image.
source_image_path = next(iter(myfiles))

Saving test_face.jpg to test_face.jpg


## 전처리

In [8]:
driving_video_path = 'capture.mp4'

# ---- Source image ----
source_image = imageio.imread(source_image_path)
if source_image.ndim == 2:
    # Grayscale upload: replicate the single channel to three. Without this the
    # `[..., :3]` below would slice image columns instead of channels.
    source_image = np.stack([source_image] * 3, axis=-1)
# Resize to the model resolution and drop any alpha channel.
source_image = resize(source_image, (pixel, pixel))[..., :3]

# ---- Driving video ----
reader = imageio.get_reader(driving_video_path)
try:
    # Reused later when the generated video is saved.
    fps = reader.get_meta_data()['fps']
    raw_frames = []
    try:
        for im in reader:
            raw_frames.append(im)
    except RuntimeError:
        # Best effort: keep the frames decoded before the reader failed.
        pass
finally:
    # Always release the reader, even if reading the metadata or frames raised.
    reader.close()

if not raw_frames:
    raise ValueError(f"No frames could be read from {driving_video_path}; record the driving video again.")

# Same resize / channel handling as the source image, applied to every frame.
driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in raw_frames]

def display(source, driving, generated=None, interval=50):
    """Build a side-by-side animation of source, driving and (optionally) generated frames.

    NOTE: this name shadows ``IPython.display.display``, which the recording
    cell of this notebook imports.

    Args:
        source: Single image array shown unchanged in the left column of every frame.
        driving: Sequence of frame arrays; one animation frame is produced per entry.
        generated: Optional sequence indexed in step with ``driving``; when given,
            each frame gets a third column.
        interval: Delay between animation frames in milliseconds (default 50).

    Returns:
        A ``matplotlib.animation.ArtistAnimation``; the caller renders it, e.g.
        with ``to_html5_video()``.
    """
    has_generated = generated is not None
    # 4 inches of width per column: two columns, or three with generated frames.
    fig = plt.figure(figsize=(8 + 4 * has_generated, 4))
    fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
    ax = fig.add_subplot(111)
    ax.axis('off')

    ims = []
    for i, driving_frame in enumerate(driving):
        cols = [source, driving_frame]
        if has_generated:
            cols.append(generated[i])
        # Join the columns horizontally into one image per animation frame.
        im = ax.imshow(np.concatenate(cols, axis=1), animated=True)
        ims.append([im])

    ani = animation.ArtistAnimation(fig, ims, interval=interval, repeat_delay=1000)
    # Close this specific figure (not whichever one happens to be current) so
    # the notebook does not also render a static copy.
    plt.close(fig)
    return ani


# Preview the source image next to each driving frame as an inline HTML5 video.
HTML(display(source_image, driving_video).to_html5_video())

## 추론/결과보기

In [None]:
# Overrides the output_video_path chosen in the settings cell.
output_video_path = 'assets/result.mp4'

from demo import find_best_frame as _find, make_animation
from skimage import img_as_ubyte

# Disabled alternative, kept for reference: in 'relative' mode, start from the
# driving frame selected by _find and animate forwards and backwards from it.
# While this stays commented out, the find_best_frame setting has no effect.
#
# if predict_mode=='relative' and find_best_frame:
#     i = _find(source_image, driving_video, device=='cpu')
#     print ("Best frame: " + str(i))
#     driving_forward = driving_video[i:]
#     driving_backward = driving_video[:(i+1)][::-1]
#     predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
#     predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
#     predictions = predictions_backward[::-1] + predictions_forward[1:]
# else:
#     predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)

# Animate the source image with the whole driving video in a single pass.
predictions = make_animation(
    source_image,
    driving_video,
    inpainting,
    kp_detector,
    dense_motion_network,
    avd_network,
    device=device,
    mode=predict_mode,
)

# Save the result: convert each predicted frame with img_as_ubyte and write the
# video at the driving video's frame rate.
imageio.mimsave(output_video_path, list(map(img_as_ubyte, predictions)), fps=fps)

# Show source, driving and generated frames side by side.
HTML(display(source_image, driving_video, predictions).to_html5_video())