# Real Time Animation

## Import libs and modules

In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize

from src.services.animation import AnimationService
from src.services.model import ModelService
from src.services.video_animation import VideoAnimationService

## Configuration

In [None]:
# Keypoint-normalisation switches; both are forwarded to
# AnimationService.normalize_kp for every generated frame.
RELATIVE = True
ADAPT_MOVEMENT_SCALE = True

# When True, tensors stay on the CPU instead of being moved to CUDA.
USE_CPU = True

# Output video: FOURCC codec and destination. The file path is derived from
# the directory so the two settings cannot drift apart.
VIDEO_CODEC = 'MJPG'
RESULT_VIDEO_DIR = './data/output'
RESULT_VIDEO_NAME = f'{RESULT_VIDEO_DIR}/real_time_test.avi'

# Inputs: the still image to animate and the model configuration/weights.
SOURCE_IMAGE_NAME = './data/input/nick.jpg'
MODEL_CONFIG_PATH = './data/configs/vox-256.yaml'
MODEL_CHECKPOINT_PATH = './data/checkpoints/vox-cpk.pth.tar'

# Title of the OpenCV preview window.
WINDOW_NAME = 'Real Time Animation'

In [None]:
%matplotlib inline
os.makedirs('./data/output', exist_ok=True)

## Prepare source image

In [None]:
# Load the still image that will be animated.
source_image = cv2.imread(SOURCE_IMAGE_NAME)
if source_image is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising, which would otherwise surface as an opaque cvtColor error.
    raise FileNotFoundError(f'Could not read source image: {SOURCE_IMAGE_NAME}')

# OpenCV decodes to BGR; switch to RGB, then scale to 256x256 and keep at
# most the first three channels.
source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB)
source_image = resize(source_image, (256, 256))[..., :3]

In [None]:
# Preview the prepared source image without axis ticks or frame.
plt.imshow(source_image)
plt.gca().set_axis_off()
plt.show()

## Prepare model

In [None]:
# Build the model service from the notebook configuration; all arguments are
# passed by keyword, so their order here is irrelevant.
model_service = ModelService(
    cpu=USE_CPU,
    checkpoint_path=MODEL_CHECKPOINT_PATH,
    config_path=MODEL_CONFIG_PATH,
)

In [None]:
# load_eval_models() yields a pair that is unpacked into the image generator
# and the keypoint detector used by the cells below.
(generator, kp_detector) = model_service.load_eval_models()

## Additional functions

In [None]:
def preprocess_frame(frame, crop_box, target_size=(256, 256)):
    """Mirror, crop and resize a captured frame.

    Args:
        frame: Image array as delivered by the capture device.
        crop_box: ``(x, y, w, h)`` region to keep, in pixels of the
            mirrored frame.
        target_size: ``(rows, cols)`` passed to ``resize``.

    Returns:
        The resized crop, limited to at most the first three channels.
        Channel order is left as it was in ``frame``.
    """
    left, top, width, height = crop_box
    row_span = slice(top, top + height)
    col_span = slice(left, left + width)

    # Flip around the vertical axis so the preview behaves like a mirror.
    mirrored = cv2.flip(frame, 1)
    cropped = mirrored[row_span, col_span]
    return resize(cropped, target_size)[..., :3]

In [None]:
def to_tensor(img, use_cpu=False):
    """Convert a channel-last image array into a batched float32 tensor.

    Args:
        img: Array indexed as ``(height, width, channels)``.
        use_cpu: Return the tensor as created when True; otherwise move it
            to the CUDA device.

    Returns:
        A ``float32`` tensor of shape ``(1, channels, height, width)``.
    """
    # Prepend the batch axis, then reorder NHWC -> NCHW.
    batched = img[np.newaxis, ...].astype(np.float32)
    channels_first = torch.tensor(batched).permute(0, 3, 1, 2)
    if use_cpu:
        return channels_first
    return channels_first.cuda()

In [None]:
def generate_frame(source, kp_source, frame_tensor, kp_initial, generator, kp_detector):
    """Generate one animated frame of the source image.

    Keypoints are detected on the driving frame, normalised against the
    source and the initial driving keypoints (using the module-level
    ``RELATIVE`` and ``ADAPT_MOVEMENT_SCALE`` settings), and passed to the
    generator together with the source.

    Args:
        source: Source image tensor given to ``generator``.
        kp_source: Keypoints of the source image.
        frame_tensor: Driving frame tensor given to ``kp_detector``.
        kp_initial: Keypoints of the first driving frame.
        generator: Callable returning a mapping with a ``'prediction'`` entry.
        kp_detector: Callable returning keypoints for a frame tensor.

    Returns:
        The first prediction of the batch as a NumPy array, with its axes
        reordered by ``permute(1, 2, 0)`` (presumably channel-first to
        channel-last -- confirm against the generator's output layout).
    """
    driving_keypoints = kp_detector(frame_tensor)
    normalized_keypoints = AnimationService.normalize_kp(
        kp_source=kp_source,
        kp_driving=driving_keypoints,
        kp_driving_initial=kp_initial,
        use_relative_movement=RELATIVE,
        use_relative_jacobian=RELATIVE,
        adapt_movement_scale=ADAPT_MOVEMENT_SCALE,
    )
    generated = generator(
        source,
        kp_source=kp_source,
        kp_driving=normalized_keypoints,
    )
    # Take the first item of the batch and bring it back to the host.
    first_image = generated['prediction'][0].data.cpu()
    return first_image.permute(1, 2, 0).numpy()

## Prepare real time animation

In [None]:
# Open the default camera and fail early with a clear message if it is
# unavailable, instead of silently producing an empty recording later.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    cap.release()
    raise RuntimeError('Could not open the default camera (device 0)')

# The writer records three 256x256 panels side by side at 12 fps, in colour.
fourcc = cv2.VideoWriter_fourcc(*VIDEO_CODEC)
out_video = cv2.VideoWriter(RESULT_VIDEO_NAME, fourcc, 12, (256 * 3, 256), True)
if not out_video.isOpened():
    cap.release()
    raise RuntimeError(f'Could not open video writer for {RESULT_VIDEO_NAME}')

# NOTE: source_image is already RGB, so this channel swap yields a BGR copy
# (despite the variable name) suitable for OpenCV display and writing. The
# float32 cast is needed because cvtColor does not accept float64 input.
source_rgb = cv2.cvtColor(source_image.astype('float32'), cv2.COLOR_BGR2RGB)
# The model itself receives the RGB image.
source_tensor = to_tensor(source_image, use_cpu=USE_CPU)

## Start real time animation

In [None]:
# Main capture -> animate -> display/record loop. The try/finally guarantees
# that the camera, the video writer and the preview window are released even
# if a frame fails to process or the cell is interrupted.
try:
    with torch.no_grad():
        kp_source = kp_detector(source_tensor)
        kp_initial = None
        # (x, y, w, h) region of the mirrored camera frame fed to the model.
        crop_box = (143, 87, 322, 322)

        while True:
            ret, frame = cap.read()
            if not ret:
                # Camera stopped delivering frames.
                break

            frame_processed = preprocess_frame(frame, crop_box)
            # Camera frames are BGR while the source image was converted to
            # RGB; reverse the channel axis so the keypoint detector sees the
            # same channel order for both. The BGR frame is kept for display.
            frame_model_input = np.ascontiguousarray(frame_processed[..., ::-1])
            frame_tensor = to_tensor(frame_model_input, use_cpu=USE_CPU)

            # The first captured frame is the reference pose for normalisation.
            if kp_initial is None:
                kp_initial = kp_detector(frame_tensor)

            prediction = generate_frame(
                source_tensor, kp_source, frame_tensor, kp_initial,
                generator, kp_detector
            )

            # Join frames: source | animated result | driving frame, all in
            # BGR order for OpenCV (source_rgb holds BGR data despite its name).
            joined = np.concatenate([
                source_rgb,
                cv2.cvtColor(prediction, cv2.COLOR_RGB2BGR),
                frame_processed
            ], axis=1)

            # Add text-hint
            cv2.putText(
                joined, "Press 'Q' to quit", (10, 245),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA
            )

            # Show and write; the writer needs 8-bit data, so scale and clip.
            cv2.imshow(WINDOW_NAME, joined)
            out_video.write(np.clip(joined * 255, 0, 255).astype(np.uint8))

            # Accept both cases so the hint holds with Caps Lock/Shift.
            key = cv2.waitKey(20) & 0xFF
            if key in (ord('q'), ord('Q')):
                break
finally:
    cap.release()
    out_video.release()
    cv2.destroyAllWindows()

# Image Animation by Prepared Video

In [None]:
# Animate a still image with a pre-recorded driving video. The model
# config/checkpoint paths and the CPU flag reuse the notebook's configuration
# constants (same values as before) instead of repeating the literals.
service = VideoAnimationService(
    config_path=MODEL_CONFIG_PATH,
    checkpoint_path=MODEL_CHECKPOINT_PATH,
    source_image_path='./data/input/monalisa.png',
    driving_video_path='./data/output/output.mp4',
    result_video_path='./data/output/result.mp4',
    relative=False,
    adapt_scale=False,
    find_best=False,
    best_frame=None,
    cpu=USE_CPU,
)
service.run()