# Motion Continuation (Inference)

## Imports

In [None]:
from matplotlib import pyplot as plt

import motion_model
import motion_synthesis
import motion_sender
import motion_gui
import motion_control

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
import networkx as nx
import scipy.linalg as sclinalg

import os, sys, time, subprocess
import numpy as np
import math
import pickle
from time import sleep

from common import utils
from common import bvh_tools as bvh
from common import fbx_tools as fbx
from common import mocap_tools as mocap
from common.quaternion import qmul, qrot, qnormalize_np, slerp, qfix

import IPython
from IPython.display import display
import ipywidgets as widgets

In [None]:
# Enable IPython's Qt event-loop integration so the PyQt GUI created later in
# this notebook can run without a blocking app.exec_() call.
%gui qt

## Settings

### Compute Device

In [None]:
# Run the model on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

## Mocap Settings

In [None]:
# Default mocap settings; each value can be changed in the widgets displayed below.
mocap_file_path = "../../../Data/Mocap/"
mocap_files = ["Daniel_ChineseRoom_Take1_50fps.fbx"]
mocap_pos_scale = 1.0
mocap_fps = 50

mocap_pos_scale_gui = widgets.FloatText(mocap_pos_scale, description="Mocap Position Scale:", style={'description_width': 'initial'})
mocap_fps_gui = widgets.IntText(mocap_fps, description="Mocap FPS:", style={'description_width': 'initial'})

# All regular files in the mocap folder, sorted because os.listdir returns
# entries in arbitrary order.
mocap_files_all = sorted(
    f for f in os.listdir(mocap_file_path)
    if os.path.isfile(os.path.join(mocap_file_path, f))
)

# SelectMultiple rejects a value that is not among its options, so only
# pre-select default files that actually exist in the folder.
missing_mocap_files = [f for f in mocap_files if f not in mocap_files_all]
if missing_mocap_files:
    print("default mocap files not found in {}: {}".format(mocap_file_path, missing_mocap_files))

mocap_files_gui = widgets.SelectMultiple(
    options=mocap_files_all,
    value=[f for f in mocap_files if f in mocap_files_all],  # pre-selected default files (may be empty)
    description='Mocap Files:',
    layout=widgets.Layout(width='400px'),
    style={'description_width': 'initial'}
)

display(mocap_pos_scale_gui)
display(mocap_fps_gui)
display(mocap_files_gui)

In [None]:
# Read back the values currently entered in the mocap widgets.
mocap_pos_scale, mocap_fps = mocap_pos_scale_gui.value, mocap_fps_gui.value
mocap_files = [*mocap_files_gui.value]

## Model Settings

In [None]:
# Default model hyper-parameters. They presumably must match the checkpoint
# loaded later in this notebook. TODO confirm.
sequence_length, rnn_layer_dim, rnn_layer_count = 64, 512, 2

sequence_length_gui = widgets.IntText(
    value=sequence_length,
    description="Mocap Sequence Length:",
    style={'description_width': 'initial'},
)
rnn_layer_dim_gui = widgets.IntText(
    value=rnn_layer_dim,
    description="RNN Layer Dim:",
    style={'description_width': 'initial'},
)
rnn_layer_count_gui = widgets.IntText(
    value=rnn_layer_count,
    description="RNN Layer Count:",
    style={'description_width': 'initial'},
)

display(sequence_length_gui, rnn_layer_dim_gui, rnn_layer_count_gui)

In [None]:
# Read back the values currently entered in the model widgets.
sequence_length, rnn_layer_dim, rnn_layer_count = (
    sequence_length_gui.value,
    rnn_layer_dim_gui.value,
    rnn_layer_count_gui.value,
)

## Training Settings

In [None]:
# Default training run to load weights from; changeable in the widgets below.
training_folder_path = "../../../Data/Models/MotionContinuation/"
training_folder = "ChineseRoom_Daniel"
training_epoch = 400

# Weights are later loaded from <training_folder>/weights/..., so only
# sub-directories qualify. Sort them because os.listdir order is arbitrary.
training_folders_all = sorted(
    f for f in os.listdir(training_folder_path)
    if os.path.isdir(os.path.join(training_folder_path, f))
)
#print(training_folders_all)

# Dropdown rejects a value that is not among its options, so fall back to the
# first available folder (with a notice) when the default is missing.
if training_folder not in training_folders_all and training_folders_all:
    print("training folder {} not found, defaulting to {}".format(training_folder, training_folders_all[0]))
    training_folder = training_folders_all[0]

training_folders_gui = widgets.Dropdown(
    options=training_folders_all,
    value=training_folder,  # pre-selected default training folder
    description='Training Folder:',
    layout=widgets.Layout(width='400px'),
    style={'description_width': 'initial'}
)

training_epoch_gui = widgets.IntText(training_epoch, description="Training Epoch:", style={'description_width': 'initial'})

display(training_folders_gui)
display(training_epoch_gui)

In [None]:
# Read back the selected training folder and epoch.
training_folder, training_epoch = training_folders_gui.value, training_epoch_gui.value

## OSC Settings

### OSC Receive Settings

In [None]:
# Defaults for the OSC control receiver; "0.0.0.0" conventionally denotes all
# local network interfaces.
osc_receive_ip, osc_receive_port = "0.0.0.0", 9002

osc_receive_ip_gui = widgets.Text(
    value=osc_receive_ip,
    description="OSC Receive IP:",
    style={'description_width': 'initial'},
)
osc_receive_port_gui = widgets.IntText(
    value=osc_receive_port,
    description="OSC Receive Port:",
    style={'description_width': 'initial'},
)

display(osc_receive_ip_gui, osc_receive_port_gui)

In [None]:
# Read back the OSC receive address entered by the user.
osc_receive_ip, osc_receive_port = osc_receive_ip_gui.value, osc_receive_port_gui.value

### OSC Send Settings

In [None]:
# Defaults for the destination of outgoing OSC messages (local machine).
osc_send_ip, osc_send_port = "127.0.0.1", 9004

osc_send_ip_gui = widgets.Text(
    value=osc_send_ip,
    description="OSC Send IP:",
    style={'description_width': 'initial'},
)
osc_send_port_gui = widgets.IntText(
    value=osc_send_port,
    description="OSC Send Port:",
    style={'description_width': 'initial'},
)

display(osc_send_ip_gui, osc_send_port_gui)

In [None]:
# Read back the OSC send address entered by the user.
osc_send_ip, osc_send_port = osc_send_ip_gui.value, osc_send_port_gui.value

## Load Mocap Data

In [None]:
bvh_tools = bvh.BVH_Tools()
fbx_tools = fbx.FBX_Tools()
mocap_tools = mocap.Mocap_Tools()

# Load every selected mocap file, rescale positions, zero the root's x/z offset
# and convert local joint rotations from Euler angles to quaternions
# (stored under mocap_data["motion"]["rot_local"]).
all_mocap_data = []

for mocap_file in mocap_files:

    print("process file ", mocap_file)

    # Dispatch on the case-insensitive extension. Anything other than .bvh/.fbx
    # is rejected explicitly. Otherwise mocap_data would be stale data from the
    # previous iteration (or undefined) and get appended silently.
    file_ext = os.path.splitext(mocap_file)[1].lower()
    file_full_path = os.path.join(mocap_file_path, mocap_file)

    if file_ext == ".bvh":
        bvh_data = bvh_tools.load(file_full_path)
        mocap_data = mocap_tools.bvh_to_mocap(bvh_data)
    elif file_ext == ".fbx":
        fbx_data = fbx_tools.load(file_full_path)
        mocap_data = mocap_tools.fbx_to_mocap(fbx_data)[0] # first skeleton only
    else:
        raise ValueError("unsupported mocap file format (expected .bvh or .fbx): {}".format(mocap_file))

    mocap_data["skeleton"]["offsets"] *= mocap_pos_scale
    mocap_data["motion"]["pos_local"] *= mocap_pos_scale

    # set x and z offset of root joint to zero
    mocap_data["skeleton"]["offsets"][0, 0] = 0.0
    mocap_data["skeleton"]["offsets"][0, 2] = 0.0

    # BVH and FBX data use different Euler-to-quaternion helpers.
    if file_ext == ".bvh":
        mocap_data["motion"]["rot_local"] = mocap_tools.euler_to_quat_bvh(mocap_data["motion"]["rot_local_euler"], mocap_data["rot_sequence"])
    else:
        mocap_data["motion"]["rot_local"] = mocap_tools.euler_to_quat(mocap_data["motion"]["rot_local_euler"], mocap_data["rot_sequence"])

    all_mocap_data.append(mocap_data)

# Fail with a clear message instead of an IndexError further down.
if not all_mocap_data:
    raise ValueError("no mocap files selected; choose at least one in the 'Mocap Files' widget")

all_pose_sequences = []

for mocap_data in all_mocap_data:

    pose_sequence = mocap_data["motion"]["rot_local"].astype(np.float32)
    all_pose_sequences.append(pose_sequence)

# Pose dimensions come from the first sequence; rot_local is presumably
# (frames, joints, quaternion components). This assumes all files share one
# skeleton. TODO confirm.
joint_count = all_pose_sequences[0].shape[1]
joint_dim = all_pose_sequences[0].shape[2]
pose_dim = joint_count * joint_dim

## Load Model

In [None]:
# Location of the stored RNN weights for the chosen training folder and epoch.
rnn_weights_file = "{}{}/weights/rnn_weights_epoch_{}".format(training_folder_path, training_folder, training_epoch)

# Copy the notebook settings into the model module's shared config, then build
# the model from that config.
model_settings = {
    "input_length": sequence_length,
    "data_dim": pose_dim,
    "node_dim": rnn_layer_dim,
    "layer_count": rnn_layer_count,
    "device": device,
    "weights_path": rnn_weights_file,
}
for setting_name, setting_value in model_settings.items():
    motion_model.config[setting_name] = setting_value

model = motion_model.createModel(motion_model.config)

## Setup Motion Synthesis

In [None]:
# synthesis_config aliases the synthesis module's config object, which is
# filled in place and then handed to the MotionSynthesis constructor.
synthesis_config = motion_synthesis.config

synthesis_settings = {
    "skeleton": all_mocap_data[0]["skeleton"],
    "model": model,
    "seq_length": motion_model.config["input_length"],
    "orig_sequences": all_pose_sequences,
    "orig_seq_index": 0,
    "device": motion_model.config["device"],
}
for setting_name, setting_value in synthesis_settings.items():
    synthesis_config[setting_name] = setting_value

synthesis = motion_synthesis.MotionSynthesis(synthesis_config)

## Create OSC Sender

In [None]:
# Point the sender module's config at the chosen OSC destination, then create the sender.
for sender_key, sender_value in (("ip", osc_send_ip), ("port", osc_send_port)):
    motion_sender.config[sender_key] = sender_value

osc_sender = motion_sender.OscSender(motion_sender.config)

## Create Application

In [None]:
from PyQt5 import QtWidgets
from PyQt5.QtCore import Qt
import pyqtgraph as pg
import pyqtgraph.opengl as gl
from pathlib import Path

# Wire the synthesis engine and OSC sender into the GUI. The GUI update
# interval is one mocap frame.
motion_gui.config["synthesis"] = synthesis
motion_gui.config["sender"] = osc_sender
motion_gui.config["update_interval"] = 1.0 / mocap_fps

# Qt permits only one QApplication per process. Reuse an existing instance
# (possibly created by the `%gui qt` integration, or by an earlier run of this
# cell) instead of constructing a second one.
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
gui = motion_gui.MotionGui(motion_gui.config)

# set close event
def closeEvent():
    """Quit the Qt application once its last window has been closed."""
    QtWidgets.QApplication.quit()
app.lastWindowClosed.connect(closeEvent) # closeEvent is the callable invoked on the signal

## Create OSC Control

In [None]:
# Give the control module the same original sequence that synthesis starts
# from. Index it explicitly instead of relying on the `pose_sequence` loop
# variable left over from loading, which is the *last* loaded file.
motion_control.config["motion_seq"] = all_pose_sequences[synthesis_config["orig_seq_index"]]
motion_control.config["synthesis"] = synthesis
motion_control.config["gui"] = gui
motion_control.config["ip"] = osc_receive_ip
motion_control.config["port"] = osc_receive_port

osc_control = motion_control.MotionControl(motion_control.config)

## Start Application

In [None]:
# Start the OSC control receiver (configured with osc_receive_ip/osc_receive_port),
# then show the GUI window. No app.exec_() is called here; the Qt event loop is
# presumably driven by the `%gui qt` integration enabled at the top of the notebook.
osc_control.start()
gui.show()

## Interactive Control

In [None]:
# Initial choices for which original sequence / start frame synthesis should use.
motion_sequence_index, motion_frame_index = 0, 64

motion_sequence_index_gui = widgets.IntText(
    value=motion_sequence_index,
    description="Select Motion Sequence:",
    style={'description_width': 'initial'},
)
motion_frame_index_gui = widgets.IntText(
    value=motion_frame_index,
    description="Select Motion Frame:",
    style={'description_width': 'initial'},
)

display(motion_sequence_index_gui, motion_frame_index_gui)

def on_motion_sequence_index_change(value):
    """Widget observer: record the new sequence index and forward it to synthesis.setOrigSeqIndex."""
    global motion_sequence_index
    new_index = value['new']
    motion_sequence_index = new_index
    synthesis.setOrigSeqIndex(new_index)

def on_motion_frame_index_change(value):
    """Widget observer: record the new frame index and forward it to synthesis.setOrigSeqStartFrameIndex."""
    global motion_frame_index
    new_frame = value['new']
    motion_frame_index = new_frame
    synthesis.setOrigSeqStartFrameIndex(new_frame)

motion_sequence_index_gui.observe(on_motion_sequence_index_change, names='value')
motion_frame_index_gui.observe(on_motion_frame_index_change, names='value')

## Stop OSC Control

In [None]:
# Stop the OSC control receiver that was started together with the GUI.
osc_control.stop()