# Talking Head Anime from a Single Image (Manual Poser Tool)

A local Jupyter-notebook inference version of the [original Talking Head Anime demo](https://github.com/pkhungurn/talking-head-anime-demo/blob/master/tha_colab.ipynb).

### Instructions
1. Set `path_morpher_ckpt` and `path_rotator_ckpt` to the paths of your checkpoint files.
2. If a GPU is available, change `device` to `'cuda'` (or any other PyTorch device you want to use).
3. Run the cells below one by one, then play with the GUI.

### Input Images

1. Must be an image of a single humanoid anime character.
2. Must have an alpha channel, and background pixels must have the value RGBA = (0, 0, 0, 0).
3. The character should preferably have a closed mouth, a human-like appearance, and a face centered in the image.

In [1]:
# --- Runtime configuration: edit these before running the remaining cells. ---
device = 'cpu'  # PyTorch device string; use e.g. 'cuda' when a GPU is available

# Checkpoint files for the face morpher and face rotator networks.
path_morpher_ckpt = 'checkpoint/step_163548.pth'
path_rotator_ckpt = 'checkpoint/step_148680.pth'

imsize = 256  # uploaded images are resized to imsize x imsize pixels

In [2]:
import torch
import cv2
import numpy as np
from PIL import Image
import IPython.display as ipd

from models.tha1 import FaceMorpher, TwoAlgorithmFaceRotator

In [3]:
import ipywidgets
import io

In [4]:
def show_pytorch_image(pytorch_image, output_widget=None):
    np_image = (255 * pytorch_image.detach().cpu().permute((1, 2, 0)).numpy()).astype(np.uint8)
    np_image = np.clip(np_image, 0, 255)
    ipd.display(Image.fromarray(np_image))
    
def extract_pytorch_image_from_filelike(file):
    """Load an image from a path or binary file-like object as a (C, H, W) tensor.

    The image is resized to ``imsize`` x ``imsize`` with ``cv2.resize`` and its
    8-bit values are divided by 255, giving a float tensor on a [0, 1] scale.
    Channel order is RGB(A), as produced by PIL.

    Images whose PIL mode is neither RGB nor RGBA (grayscale, palette, ...) are
    converted first: to RGBA when the source carries transparency, otherwise to
    RGB. The array is therefore always 3-D, so callers can report a missing
    alpha channel instead of crashing on a 2-D grayscale/palette array. For
    palette images this also maps indices to real colours.
    """
    pil_image = Image.open(file)
    if pil_image.mode not in ('RGB', 'RGBA'):
        # 'LA'/'PA' have an alpha band; grayscale/palette PNGs may instead
        # declare a transparent colour in info['transparency'].
        has_alpha = pil_image.mode in ('LA', 'PA') or 'transparency' in pil_image.info
        pil_image = pil_image.convert('RGBA' if has_alpha else 'RGB')
    im = np.asarray(pil_image)
    im = cv2.resize(im, (imsize, imsize))
    # HWC uint8 -> CHW, scaled to [0, 1].
    return torch.from_numpy(im).permute((2, 0, 1)) / 255.

In [5]:
# torch_input_image: the most recently accepted upload (None until one is valid).
# last_torch_input_image: the image last rendered, used to skip redundant redraws.
torch_input_image = last_torch_input_image = None

# image widgets, upload button
def _framed_output():
    """Create a 256x256 output area with a thin black border."""
    return ipywidgets.Output(
        layout=dict(border='1px solid black', width='256px', height='256px')
    )


# Left pane shows the uploaded image; right pane shows the rendered pose.
input_image_widget = _framed_output()
output_image_widget = _framed_output()

# Single-file PNG upload button, as wide as the image panes.
upload_input_image_button = ipywidgets.FileUpload(
    accept='.png',
    multiple=False,
    layout=dict(width='256px'),
)


# control sliders
def _pose_slider(description, lower, upper, step):
    """Create a pose-control FloatSlider starting at 0.0 with a 2-decimal readout."""
    return ipywidgets.FloatSlider(
        value=0.0,
        min=lower,
        max=upper,
        step=step,
        description=description,
        readout=True,
        readout_format=".2f",
    )


# Facial-feature sliders: range [0, 1], fine-grained steps.
eye_left_slider = _pose_slider("Left Eye:", 0.0, 1.0, 0.01)
eye_right_slider = _pose_slider("Right Eye:", 0.0, 1.0, 0.01)
mouth_slider = _pose_slider("Mouth:", 0.0, 1.0, 0.01)

# Head-rotation sliders: range [-30, 30] in whole steps
# (presumably degrees -- confirm against the rotator's training setup).
head_x_slider = _pose_slider("X-axis:", -30, 30, 1)
head_y_slider = _pose_slider("Y-axis:", -30, 30, 1)
neck_z_slider = _pose_slider("Z-axis:", -30, 30, 1)


# control panels
# Middle column: rotation sliders, a divider, then facial-feature sliders.
control_panel = ipywidgets.VBox(
    [ipywidgets.HTML(value="<center><b>Head Rotation</b></center>")]
    + [head_x_slider, head_y_slider, neck_z_slider]
    + [
        ipywidgets.HTML(value="<hr>"),
        ipywidgets.HTML(value="<center><b>Facial Features</b></center>"),
    ]
    + [eye_left_slider, eye_right_slider, mouth_slider]
)

# Full layout: [input image + upload button] | controls | spacer | output image.
_input_column = ipywidgets.VBox([input_image_widget, upload_input_image_button])
_spacer = ipywidgets.HTML(value="&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;")
controls = ipywidgets.HBox([_input_column, control_panel, _spacer, output_image_widget])

In [6]:
def _load_for_inference(model, ckpt_path, key):
    """Load ``checkpoint[key]['state_dict']`` into *model*, move it to ``device``, set eval mode."""
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint[key]['state_dict'])
    return model.to(device).eval()


morpher = _load_for_inference(FaceMorpher(None), path_morpher_ckpt, 'FaceMorpher')
rotator = _load_for_inference(TwoAlgorithmFaceRotator(None), path_rotator_ckpt, 'FaceRotator')

# Pose vector: [head_x, head_y, neck_z, mouth, eye_left, eye_right] (filled by get_pose).
# last_pose holds the pose of the most recent render; it starts at all zeros.
pose_size = 6
last_pose = torch.zeros(1, pose_size).to(device)

In [7]:
def get_pose():
    """Read the six control sliders into a (1, pose_size) pose tensor on ``device``.

    Layout: [head_x, head_y, neck_z, mouth, eye_left, eye_right]. In update(),
    the first three entries drive the rotator and the last three the morpher.
    """
    slider_values = (
        head_x_slider.value,
        head_y_slider.value,
        neck_z_slider.value,
        mouth_slider.value,
        eye_left_slider.value,
        eye_right_slider.value,
    )
    pose = torch.zeros(1, pose_size)
    for index, value in enumerate(slider_values):
        pose[0, index] = value
    return pose.to(device)

In [8]:
def update(change):
    """Re-render the output pane when the input image or the slider pose changed.

    Registered as an ``observe`` callback on every slider (``change`` is
    ignored) and also called directly after an upload. Does nothing until a
    valid image has been uploaded, or when neither the image nor the pose
    differs from the last render.
    """
    global last_pose
    global last_torch_input_image

    if torch_input_image is None:
        return

    pose = get_pose()
    image_changed = (
        last_torch_input_image is None
        or not torch.equal(torch_input_image, last_torch_input_image)
    )
    pose_changed = not torch.equal(pose, last_pose)
    if not (image_changed or pose_changed):
        return

    # Inference only: skip building an autograd graph on every slider move.
    with torch.no_grad():
        # pose[:, 3:] = mouth/eyes for the morpher; pose[:, :3] = head rotation.
        morphed = morpher(torch_input_image, pose[:, 3:])['e2']
        output_image = rotator(morphed, pose[:, :3])['e4'][0]

    with output_image_widget:
        output_image_widget.clear_output(wait=True)
        show_pytorch_image(output_image, output_image_widget)

    last_torch_input_image = torch_input_image
    last_pose = pose
        
def upload_image(change):
    """Load the uploaded file, validate it, show it, and refresh the output render.

    Accepts both ``FileUpload.value`` layouts: a dict keyed by filename
    (ipywidgets 7) and a tuple of file dicts (ipywidgets 8). Each file's bytes
    are read from its ``'content'`` entry. A rejected image resets
    ``torch_input_image`` to None and shows an error in the input pane.
    """
    global torch_input_image

    uploaded = upload_input_image_button.value
    files = uploaded.values() if isinstance(uploaded, dict) else uploaded
    for file_info in files:
        image = extract_pytorch_image_from_filelike(io.BytesIO(file_info['content']))
        torch_input_image = image.to(device).unsqueeze(0)  # add batch dimension

    if torch_input_image is None:
        return

    _, c, h, w = torch_input_image.shape
    # The alpha check wins when both fail, matching the message the previous
    # sequential checks left visible.
    if c != 4:
        error = "Image must have an alpha channel!!!"
    elif h != imsize or w != imsize:
        error = f"Image must be {imsize}x{imsize} in size!!!"
    else:
        error = None

    with input_image_widget:
        input_image_widget.clear_output(wait=True)
        if error is None:
            show_pytorch_image(torch_input_image[0], input_image_widget)
        else:
            ipd.display(ipywidgets.HTML(error))

    if error is not None:
        torch_input_image = None
    update(None)

In [9]:
display(controls)

# Uploading validates and renders; moving any slider re-renders.
upload_input_image_button.observe(upload_image, names='value')
for _widget in (
    eye_left_slider,
    eye_right_slider,
    mouth_slider,
    head_x_slider,
    head_y_slider,
    neck_z_slider,
):
    _widget.observe(update, names='value')

HBox(children=(VBox(children=(Output(layout=Layout(border='1px solid black', height='256px', width='256px')), …

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  "Default grid_sample and affine_grid behavior has changed "
