In [1]:
import os
import argparse
import torch
import numpy as np
import json
from tqdm import tqdm
from ipywidgets import interact, widgets
from os.path import join as pjoin
from PIL import Image
from decimal import Decimal
from glob import glob
from renderer import Renderer
from model import Model
from anymole_render_utils import save_cam_to_json, load_cam_from_json

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [9]:
torch.cuda.empty_cache()

# Initial camera configuration handed to the renderer.
cam_params = argparse.Namespace(
    znear=0.1,
    zfar=100.0,
    aspect_ratio=1.0,
    fov=30.0,
    dist=2.7,
    elev=0.0,
    azim=0.0,
    lookat_x=0.0,
    lookat_y=1.0,
    lookat_z=0.0,
)

renderer = Renderer(device, cam_params=cam_params)

# load model
data_path    = "data/characters"
motion_data  = "Amy_careful_Walking_input.npz"
char_name = "Amy"

# All per-character assets live under <data_path>/<char_name>/.
char_dir = pjoin(data_path, char_name)
texture_dir = pjoin(char_dir, "textures")

motion_name = pjoin(char_dir, motion_data)
print(f"char : {char_name}\t motion : {motion_name}")

# Pick the diffuse texture, if any. os.listdir() returns entries in arbitrary
# order, so sort to make the choice reproducible when several files match.
if os.path.isdir(texture_dir):
    diffuse_names = sorted(x for x in os.listdir(texture_dir) if "Diffuse" in x)
    if not diffuse_names:
        # Previously this case died with a bare IndexError; fall back to the
        # same un-textured load that is used when there is no textures folder.
        print(f"warning: no 'Diffuse' texture in {texture_dir}; "
              f"loading {char_name} without textures")
else:
    diffuse_names = []

if diffuse_names:
    model = Model(pjoin(char_dir, f"{char_name}.pkl"),
                  texture_paths=[pjoin(texture_dir, diffuse_names[0])],
                  device=device)
else:
    model = Model(pjoin(char_dir, f"{char_name}.pkl"),
                  device=device)
model.set_motion(np.load(motion_name))

# pre-compute deformed vertices
model.deform_by_motion(model.local_quats, model.root_pos)

char : Amy	 motion : data/characters/Amy/Amy_careful_Walking_input.npz


In [10]:
# --- camera orbit: azimuth, elevation and distance sliders ---
azimuth = widgets.FloatSlider(value=0.0, min=-180, max=180, step=1,
                              description="azim")
elevation = widgets.FloatSlider(value=0.0, min=-180, max=180, step=1,
                                description="elev")
distance = widgets.FloatSlider(value=2.7, min=1.0, max=10.0, step=0.1,
                               description="dist")
cam_display = widgets.VBox([azimuth, elevation, distance])

# --- look-at point: one slider per coordinate, all sharing the same range ---
lookat_x = widgets.FloatSlider(value=0.0, min=-5, max=5, step=0.1,
                               description="lookat_x")
lookat_y = widgets.FloatSlider(value=1.0, min=-5, max=5, step=0.1,
                               description="lookat_y")
lookat_z = widgets.FloatSlider(value=0.0, min=-5, max=5, step=0.1,
                               description="lookat_z")

at_display = widgets.VBox([lookat_x, lookat_y, lookat_z])

# --- frame selector: spans every frame of the loaded motion ---
frame_idx = widgets.IntSlider(value=0, min=0,
                              max=model.local_quats.shape[0]-1, step=1,
                              description="frame")

In [11]:
def get_cam_relative_3d_joints(frame_idx):
    """Return the joints of frame ``frame_idx`` mapped through the current camera.

    The joint positions come from ``model.get_joint_pos`` and are passed through
    ``renderer.cam.transform_points_screen``, so the result depends on whatever
    camera the renderer holds at call time.
    """
    joints = model.get_joint_pos(frame_idx)
    return renderer.cam.transform_points_screen(joints)
    
def draw_image(frame_idx, azimuth, elevation, distance, lookat_x, lookat_y, lookat_z):
    """Render one frame with the given camera and show it in the cell output.

    All arguments are plain values (not widgets); this is the callback that
    ``widgets.interactive_output`` invokes with the current slider values.
    """
    # The look-at point is passed to the renderer as a (1, 3) float32 tensor.
    target = torch.tensor([[lookat_x, lookat_y, lookat_z]], dtype=torch.float32)
    renderer.set_camera(azimuth, elevation, distance, target)

    rendered = renderer.render(model.get_mesh(frame_idx))
    display(Image.fromarray(renderer.tensor2image(rendered)))
    
def save_image(img_btn: widgets.Button):
    """Button callback: render the currently selected frame and write ``image.png``.

    Camera and frame are read from the slider widgets at click time. The output
    path is fixed, so each click overwrites the previous image.
    """
    target = torch.tensor(
        [[lookat_x.value, lookat_y.value, lookat_z.value]],
        dtype=torch.float32,
    )
    renderer.set_camera(azimuth.value, elevation.value, distance.value, target)

    rendered = renderer.render(model.get_mesh(frame_idx.value))
    Image.fromarray(renderer.tensor2image(rendered)).save("image.png")

In [None]:

def _render_view_sequence(save_dir, prefix, azim_offset=0, lookat_z_offset=0):
    """Render the selected frames of the loaded motion from one camera view.

    The camera is set from the current slider values, with ``azim_offset``
    added to the azimuth and ``lookat_z_offset`` added to the look-at z
    coordinate. For every selected frame ``i`` two files are written into
    ``save_dir``:

    * ``{prefix}_{i:04d}.npy`` - output of ``get_cam_relative_3d_joints(i)``
    * ``{prefix}_{i:04d}.png`` - the rendered mesh
    """
    at = torch.tensor([[lookat_x.value,
                        lookat_y.value,
                        lookat_z.value + lookat_z_offset]], dtype=torch.float32)
    renderer.set_camera(azimuth.value + azim_offset,
                        elevation.value,
                        distance.value, at)

    # Only context and keyframes are rendered: every frame up to and including
    # index 60, and after that only every 30th frame.
    n_frames = model.local_quats.shape[0]
    selected = [i for i in range(n_frames) if i <= 60 or i % 30 == 0]

    for i in tqdm(selected, desc=prefix):
        # The joints must be projected after set_camera() above so that they
        # match the image rendered for this view.
        joint_pos = get_cam_relative_3d_joints(i)
        np.save(f"{save_dir}/{prefix}_{i:04d}.npy", joint_pos.cpu().numpy())

        mesh = model.get_mesh(i)
        image = renderer.render(mesh)
        image = Image.fromarray(renderer.tensor2image(image))
        image.save(f"{save_dir}/{prefix}_{i:04d}.png")


def save_video(video_btn: widgets.Button):
    """Button callback: dump images and projected joints for four camera views.

    Output goes to ``images/<char_name>/<motion file name without .npz>/``:

    * ``frame_*`` - the view currently set on the sliders
    * ``back_*``  - azimuth shifted by -179
    * ``left_*``  - azimuth shifted by -5
    * ``right_*`` - azimuth shifted by +5
    * ``cam_params.json`` - camera settings of the ``frame_*`` view

    The renderer's camera is left on the last (``right``) view when this
    returns.
    """
    # Offset applied to the look-at z coordinate of the left (+) and right (-)
    # views; currently zero, i.e. all views share the same look-at point.
    z_difference = 0

    # render setting
    print('x',lookat_x.value)
    print('y',lookat_y.value)
    print('z',lookat_z.value)
    print('azimuth',azimuth.value)
    print('elevation',elevation.value)
    print('distance',distance.value)

    # save images
    save_dir = pjoin("images", char_name, motion_data.replace(".npz", ""))
    os.makedirs(save_dir, exist_ok=True)

    # Main view, exactly as shown in the interactive preview.
    _render_view_sequence(save_dir, "frame")

    # save render config
    # NOTE(review): zfar here is 30.0 while the renderer in this notebook is
    # constructed with zfar=100.0 - confirm which value consumers of
    # cam_params.json expect before changing either.
    cam_params = {
        "znear": 0.1,
        "zfar": 30.0,
        "aspect_ratio": 1.0,
        "fov": 30.0,
        "azim": float(azimuth.value),
        "elev": float(elevation.value),
        "dist": float(distance.value),
        "lookat_x": float(lookat_x.value),
        "lookat_y": float(lookat_y.value),
        "lookat_z": float(lookat_z.value),
    }
    with open(pjoin(save_dir, "cam_params.json"), "w") as f:
        json.dump(cam_params, f)

    # Additional views: (file prefix, azimuth offset, look-at z offset).
    extra_views = (
        ("back", -179, 0),
        ("left", -5, z_difference),
        ("right", 5, -z_difference),
    )
    for prefix, azim_offset, lookat_z_offset in extra_views:
        _render_view_sequence(save_dir, prefix, azim_offset, lookat_z_offset)

    # Reported only once every view has actually been written.
    print("Images saved!")

    
def reset_value(reset_btn: widgets.Button):
    """Button callback: move every control slider to a fixed preset view.

    The preset is azim 90, elev 0, dist 1.8, look-at (0, 0.8, 1.3), frame 0.
    These are not the values the sliders are created with (azim 0, dist 2.7,
    look-at (0, 1, 0)).
    """
    # Applied in this order; each assignment updates the corresponding widget.
    preset = (
        (azimuth, 90.0),
        (elevation, 0.0),
        (distance, 1.8),
        (lookat_x, 0.0),
        (lookat_y, 0.8),
        (lookat_z, 1.3),
        (frame_idx, 0),
    )
    for slider, new_value in preset:
        slider.value = new_value

In [13]:
# Ask PyTorch to release its unused cached GPU memory before the interactive
# viewer below starts rendering.
torch.cuda.empty_cache()

In [14]:
# Live preview: draw_image is re-run with the current value of every slider.
preview_controls = {
    "frame_idx": frame_idx,
    "azimuth": azimuth,
    "elevation": elevation,
    "distance": distance,
    "lookat_x": lookat_x,
    "lookat_y": lookat_y,
    "lookat_z": lookat_z,
}
output = widgets.interactive_output(draw_image, preview_controls)

# Action buttons, each wired to its click handler.
save_img_button = widgets.Button(description="Save Image")
save_img_button.on_click(save_image)

save_vid_button = widgets.Button(description="Save Video")
save_vid_button.on_click(save_video)

reset_params_button = widgets.Button(description="Reset")
reset_params_button.on_click(reset_value)

# Layout: controls stacked in a column on the left, rendered preview on the right.
control_display = widgets.VBox([
    reset_params_button,
    cam_display,
    at_display,
    widgets.HBox([save_img_button, save_vid_button]),
    frame_idx,
])
display(widgets.HBox([control_display, output]))

HBox(children=(VBox(children=(Button(description='Reset', style=ButtonStyle()), VBox(children=(FloatSlider(val…