In [None]:
from t2m import Text2Motion
from utils.get_opt import get_opt
from utils.fixseed import fixseed

import torch
import numpy as np
from os.path import join as pjoin

In [None]:
# Which denoiser checkpoint and dataset configuration to use.
denoiser_name = "t2m_denoiser_vpred_vaegelu"
dataset_name = "t2m"

# Build the text-to-motion generator and keep a handle on its options.
generator = Text2Motion(denoiser_name, dataset_name)
opt = generator.opt

# Dataset-level options give access to the directory holding the
# normalization statistics (mean / std) saved as .npy files.
wrapper_opt = get_opt(opt.dataset_opt_path, torch.device("cuda"))
mean, std = (
    np.load(pjoin(wrapper_opt.meta_dir, stat_file))
    for stat_file in ("mean.npy", "std.npy")
)

### Original

In [None]:
# Fix the seed before sampling (presumably seeds the RNGs so the run is repeatable).
fixseed(42)

# Prompt and sampling settings for the source motion.
src_text = "a man is walking"
m_lens = 64
cfg_scale = 7.5
num_inference_timesteps = 50

# generate() returns the initial noise, the generated motion, and a triple
# (sa, ta, ca) that is later passed to generator.edit as src_sa / src_ta / src_ca.
init_noise, src_motion, (sa, ta, ca) = generator.generate(
    src_text, m_lens, cfg_scale, num_inference_timesteps
)

### Edit - 4 Different Cases

In [None]:
# edit_text = "slowly"
edit_text = "a man is walking while waving his right hand"
src_proportion = 0.2

# Select which of the four editing cases to run:
#   "mirror" (case 1), "reweight" (case 2), "refine" (case 3), "word_swap" (case 4).
edit_case = "refine"

# Per-case keyword arguments for generator.edit. Only the arguments that
# differ between cases live here; the shared ones are passed in the call below.
edit_cases = {
    # case 1: mirror
    "mirror": dict(edit_text=edit_text,
                   edit_mode="mirror",
                   mirror_mode="lower",
                   src_ta=ta),
    # case 2: reweight (edits the source prompt itself, re-weighting one word)
    "reweight": dict(edit_text=src_text,
                     edit_mode="reweight",
                     tgt_word="high",
                     reweight_scale=-1.0,
                     src_ta=ta),
    # case 3: refine
    "refine": dict(edit_text=edit_text,
                   edit_mode="refine",
                   src_ta=ta),
    # case 4: word swap (note: src_ta is deliberately None for this case)
    "word_swap": dict(edit_text=edit_text,
                      edit_mode="word_swap",
                      src_ta=None,
                      swap_src_proportion=0.2),
}

if edit_case not in edit_cases:
    raise ValueError(
        f"Unknown edit_case {edit_case!r}; expected one of {sorted(edit_cases)}"
    )

# Run the selected edit starting from the same initial noise as the source motion.
edit_motion = generator.edit(init_noise,
                             src_text=src_text,
                             cfg_scale=cfg_scale,
                             num_inference_timesteps=num_inference_timesteps,
                             src_sa=sa,
                             src_ca=ca,
                             src_proportion=src_proportion,
                             **edit_cases[edit_case])


### Visualize

In [None]:
import os
from os.path import join as pjoin
import torch
import numpy as np
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.get_opt import get_opt

def plot_t2m(data, text, filename, length=None, out_dir="edit_result", fps=20):
    """Render one motion to an .mp4 and save its joints and features as .npy.

    Args:
        data: numpy array of motion features; it is sliced along its first
            axis and passed through ``torch.from_numpy``.
        text: title drawn on the rendered video.
        filename: base name (no extension) for the files written to ``out_dir``.
        length: slice bound applied to the first axis of ``data``. When
            ``None`` (default) the notebook-level ``m_lens`` is used, which
            matches the previous behaviour.
        out_dir: directory that receives the outputs; created if missing.
        fps: frame rate forwarded to ``plot_3d_motion``.

    Writes ``<filename>.mp4``, ``<filename>_pos.npy`` (output of
    ``recover_from_ric``) and ``<filename>_feats.npy`` (the truncated
    features) into ``out_dir``. Relies on the notebook-level ``opt`` for
    ``joints_num`` and ``kinematic_chain``.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Fall back to the notebook-level motion length when none is given.
    if length is None:
        length = m_lens
    data = data[:length]

    # Convert the feature representation into joint data for plotting.
    joint = recover_from_ric(torch.from_numpy(data).float(), opt.joints_num).numpy()
    save_path = pjoin(out_dir, f"{filename}.mp4")
    plot_3d_motion(save_path, opt.kinematic_chain, joint, title=text, fps=fps)

    # Keep the raw arrays next to the video for later inspection.
    np.save(pjoin(out_dir, f"{filename}_pos.npy"), joint)
    np.save(pjoin(out_dir, f"{filename}_feats.npy"), data)
    
# Load the dataset statistics used to de-normalize generated motions
# (motion * std + mean) before plotting.
wrapper_opt = get_opt(opt.dataset_opt_path, torch.device('cuda'))
mean, std = map(
    np.load,
    (pjoin(wrapper_opt.meta_dir, 'mean.npy'), pjoin(wrapper_opt.meta_dir, 'std.npy')),
)

### Plot Motions

In [None]:
# De-normalize into NEW names instead of overwriting the tensors: rebinding
# src_motion / edit_motion to numpy arrays made this cell fail on a second run
# (ndarray has no .detach()) and left the names holding a different type.
src_motion_denorm = src_motion.detach().cpu().numpy() * std + mean
plot_t2m(src_motion_denorm[0], src_text, "src")

edit_motion_denorm = edit_motion.detach().cpu().numpy() * std + mean
plot_t2m(edit_motion_denorm[0], edit_text, "edit")

### Video Visualization

In [None]:
from moviepy.editor import VideoFileClip, clips_array
from IPython.display import Video

# Open both clips as context managers so their readers and file handles are
# released once the side-by-side video is written, even if writing fails.
with VideoFileClip(pjoin("edit_result", "src.mp4")) as src_video, \
        VideoFileClip(pjoin("edit_result", "edit.mp4")) as edit_video:
    # One row, two columns: source motion on the left, edited motion on the right.
    final_video = clips_array([[src_video, edit_video]])
    final_video.write_videofile(pjoin("edit_result", "final.mp4"))

# Last expression of the cell so the notebook displays the player.
Video("edit_result/final.mp4", width=800, height=400)