In [None]:
import os
from os.path import join as pjoin
import torch
from diffusers import DDIMScheduler

from utils.get_opt import get_opt
from models.vae.model import VAE
from models.denoiser.model import Denoiser
from models.denoiser.trainer import DenoiserTrainer


In [None]:
def load_vae(vae_opt):
    """Build a VAE from ``vae_opt`` and restore its best-FID checkpoint.

    The checkpoint is read from
    ``<checkpoints_dir>/<dataset_name>/<name>/model/net_best_fid.tar`` and
    deserialized onto the CPU; the caller is responsible for moving the
    returned model to another device.

    Args:
        vae_opt: options object; this function reads ``name``,
            ``checkpoints_dir`` and ``dataset_name`` from it and passes the
            whole object to the ``VAE`` constructor.

    Returns:
        The ``VAE`` instance after ``load_state_dict`` and ``freeze()``.
    """
    print(f'Loading VAE Model {vae_opt.name}')

    vae = VAE(vae_opt)

    ckpt_path = pjoin(
        vae_opt.checkpoints_dir,
        vae_opt.dataset_name,
        vae_opt.name,
        'model',
        'net_best_fid.tar',
    )
    # The VAE weights live under the "vae" key of the saved dictionary.
    state = torch.load(ckpt_path, map_location='cpu')
    vae.load_state_dict(state["vae"])

    # freeze() is defined by the project's VAE class -- presumably it turns
    # off gradients / switches to eval mode; confirm in models/vae/model.py.
    vae.freeze()
    return vae


def load_denoiser(opt, vae_dim):
    """Build a Denoiser and restore its best-FID checkpoint.

    The checkpoint is read from
    ``<checkpoints_dir>/<dataset_name>/<name>/model/net_best_fid.tar`` onto
    the CPU and loaded non-strictly. After loading, two invariants are
    asserted: the checkpoint contains no key the model does not know, and
    every key absent from the checkpoint starts with ``clip_model.``.

    Args:
        opt: options object; ``name``, ``checkpoints_dir``, ``dataset_name``
            and ``device`` are read here, and the whole object is passed to
            the ``Denoiser`` constructor.
        vae_dim: forwarded unchanged as the second ``Denoiser`` argument.

    Returns:
        The ``Denoiser`` instance after being moved to ``opt.device``.

    Raises:
        AssertionError: if the checkpoint has unexpected keys, or is missing
            any key that is not under the ``clip_model.`` prefix.
    """
    print(f'Loading Denoiser Model {opt.name}')
    net = Denoiser(opt, vae_dim)

    ckpt_path = pjoin(
        opt.checkpoints_dir,
        opt.dataset_name,
        opt.name,
        'model',
        'net_best_fid.tar',
    )
    state = torch.load(ckpt_path, map_location='cpu')

    # strict=False so that keys absent from the checkpoint are reported
    # instead of raising; they are validated explicitly just below.
    absent, surplus = net.load_state_dict(state["denoiser"], strict=False)

    assert not surplus, f'Unexpected keys in denoiser model: {surplus}'
    assert all(key.startswith('clip_model.') for key in absent), f'Missing keys in denoiser model: {absent}'

    net.to(opt.device)
    return net

In [None]:
# Experiment selection: edit these two names to point at other checkpoints.
dataset_name = "t2m"
denoiser_name = "t2m_denoiser_vpred_vaegelu" # << Change this to use different denoisers
# NOTE(review): the device is fixed to CUDA for the whole notebook.
device = torch.device("cuda")

# Every checkpoint folder stores its options in an opt.txt file. The denoiser
# options name the VAE to pair with (opt.vae_name), so they are read first.
ckpt_root = pjoin("checkpoints", dataset_name)
opt = get_opt(pjoin(ckpt_root, denoiser_name, "opt.txt"), device)
vae_opt = get_opt(pjoin(ckpt_root, opt.vae_name, "opt.txt"), device)

opt.num_inference_timesteps = 50 # << Change the diffusion timesteps for generation as you want

# Restore both networks; load_denoiser moves its model to opt.device itself,
# whereas the VAE is moved here.
vae_model = load_vae(vae_opt).to(opt.device)
denoiser = load_denoiser(opt, vae_opt.latent_dim)

# The noise schedule is taken from the denoiser's saved options.
scheduler_config = dict(
    num_train_timesteps=opt.num_train_timesteps,
    beta_start=opt.beta_start,
    beta_end=opt.beta_end,
    beta_schedule=opt.beta_schedule,
    prediction_type=opt.prediction_type,
    clip_sample=False,
)
scheduler = DDIMScheduler(**scheduler_config)

# The trainer object is what later cells call for generation.
trainer = DenoiserTrainer(opt, denoiser, vae_model, scheduler)

In [None]:
from utils.utils import attn2img
import numpy as np
from utils.fixseed import fixseed

# Fix the random seed so repeated runs of this cell give the same sample.
fixseed(42)

# inputs for generation
text = [
    "a man jumps forward and then walks forward, while raising the arms to the sides",
]
m_lens = torch.tensor([100]).to(device) # << Change the length of the motion sequence as you want

# Validate before allocating or generating anything: the attention map below
# is reshaped with `num_frames // 4`, so the length must be divisible by 4.
# torch.all keeps the check valid if m_lens ever holds more than one length.
assert torch.all(m_lens % 4 == 0), "every motion length in m_lens must be a multiple of 4"

num_frames = m_lens[0].item()
feat_dim = 263 if dataset_name == "t2m" else 251

# Input tensor handed to trainer.generate; `motion` is rebound to the
# generated result right after. Zero-filled (rather than the uninitialized
# memory that torch.FloatTensor(...) returns) so the input is deterministic.
# Presumably only its shape matters to generate() -- TODO confirm.
motion = torch.zeros(1, num_frames, feat_dim, dtype=torch.float32, device=device)

# setup output directory
os.makedirs("exp_attn", exist_ok=True)
motion, attns = trainer.generate((text, motion, m_lens), need_attn=True)
skel_attn, temp_attn, cross_attn = attns

# attn map
# Average over dims 1-3 and keep entry 1 of the leading dim.
# NOTE(review): index 1 with a single prompt suggests the leading dim holds
# more than one entry per prompt (e.g. an unconditional/conditional pair) --
# confirm against DenoiserTrainer.generate.
attn_vis = torch.mean(cross_attn, dim=(1, 2, 3))[1]
# Reshape to (num_frames // 4, 7, 77) and swap the first two axes.
# NOTE(review): 7 and 77 are presumably the skeleton-group count and the CLIP
# token length -- confirm against the denoiser definition.
attn_vis = attn_vis.reshape(num_frames // 4, 7, 77).transpose(0, 1)

In [None]:
import numpy as np
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion_attn, plot_3d_motion
from scipy.ndimage import gaussian_filter1d

def plot_t2m(data, text):
    """Render one generated motion to video and save its joint positions.

    Reads the notebook globals ``m_lens`` (number of frames to keep) and
    ``opt`` (``joints_num`` and ``kinematic_chain``).

    Args:
        data: NumPy array of de-normalized motion features, indexed by frame
            along axis 0; only the first ``m_lens[0]`` frames are used.
        text: string passed to the plot as its title.

    Side effects:
        Writes ``exp_attn/video.mp4`` and then ``exp_attn/motion.npy``.
    """
    num_frames = m_lens[0].item()
    features = torch.from_numpy(data[:num_frames]).float()

    # Recover joints from the feature representation, then smooth them along
    # the frame axis with a sigma=1 Gaussian.
    joints = recover_from_ric(features, opt.joints_num).numpy()
    joints = gaussian_filter1d(joints, sigma=1, axis=0)

    video_path = pjoin("exp_attn", "video.mp4")
    plot_3d_motion(video_path, opt.kinematic_chain, joints, title=text, fps=20)

    # The saved array holds the smoothed joints, i.e. exactly what was plotted.
    np.save(pjoin("exp_attn", "motion.npy"), joints)
    
    
def plot_t2m_attn(data, text, attn_map, idx):
    """Render one generated motion together with an attention map.

    Reads the notebook globals ``m_lens`` (number of frames to keep) and
    ``opt`` (``joints_num`` and ``kinematic_chain``). Unlike ``plot_t2m``,
    the joints are not smoothed and nothing besides the video is saved.

    Args:
        data: NumPy array of de-normalized motion features, indexed by frame
            along axis 0; only the first ``m_lens[0]`` frames are used.
        text: string passed to the plot and embedded in the output filename.
        attn_map: attention values forwarded unchanged to
            ``plot_3d_motion_attn``.
        idx: integer used as a zero-padded two-digit filename prefix.

    Side effects:
        Writes ``exp_attn/<idx>-<text>.mp4``.
    """
    num_frames = m_lens[0].item()
    features = torch.from_numpy(data[:num_frames]).float()
    joints = recover_from_ric(features, opt.joints_num).numpy()

    video_path = pjoin("exp_attn", f"{idx:02d}-{text}.mp4")
    plot_3d_motion_attn(video_path, opt.kinematic_chain, joints, text, attn_map, fps=20)


# mean and std for de-normalization
# Use the notebook-wide `device` rather than a second hard-coded CUDA device,
# so the device is configured in exactly one place.
wrapper_opt = get_opt(opt.dataset_opt_path, device)
mean = np.load(pjoin(wrapper_opt.meta_dir, 'mean.npy'))
std = np.load(pjoin(wrapper_opt.meta_dir, 'std.npy'))

# Undo the dataset normalization on the generated motion.
motion_np = motion.detach().cpu().numpy() * std + mean

# text info
tokens = denoiser.clip_model.tokenize(text)
# Only max_id is used below, as the number of token positions to visualize.
word_emb, mask, max_id = denoiser.clip_model.encode_text(text)

# plot
print(text)
# Make sure the output directory exists, then drop attention images left over
# from a previous run. Done in Python instead of shelling out to `rm`, so it
# does not depend on a POSIX shell and stays silent when nothing matches.
# Dot-files are skipped, matching what the shell pattern `*.png` selects.
os.makedirs("exp_attn", exist_ok=True)
for stale_name in os.listdir("exp_attn"):
    if stale_name.endswith(".png") and not stale_name.startswith("."):
        os.remove(pjoin("exp_attn", stale_name))
plot_t2m(motion_np[0], text[0])

from IPython.display import display
# One attention image per token position of the first prompt, up to max_id.
for j in range(max_id[0].item()):
    text_decoded0 = denoiser.clip_model.decode_text_from_tokens(tokens.input_ids[0][j])
    attn = attn2img(attn_vis[:, :, j], pjoin("exp_attn", f"attn-{j:02d}-{text_decoded0}.png"), title=text_decoded0)
    # Alternative: visualize two consecutive tokens together.
    # text_decoded1 = denoiser.clip_model.decode_text_from_tokens(tokens.input_ids[0][j+1])
    # attn = attn2img(attn_vis[:, :, j] + attn_vis[:, :, j+1], pjoin("exp_attn", f"attn-{j:02d}-{text_decoded0}-{text_decoded1}.png"), title=text_decoded0 + " " + text_decoded1)
    display(attn)
    # Alternative: render the motion with this token's attention overlaid.
    # plot_t2m_attn(motion_np[0], text_decoded0, torch.repeat_interleave(attn_vis[:, :, j].transpose(0,1), 4, dim=0), j)