# Deps

In [1]:
import torch
import clip
import argparse
from os.path import join as pjoin
from utils.codebook import *
import yaml
import visualization.plot_3d_global as plot_3d
from utils.motion_process import recover_from_ric 
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np

In [2]:
def get_cfg_ckpt_path(folder_path):
    if folder_path is None:
        return None, None
    else:
        ckpt_path = pjoin(folder_path, 'net_best_fid.pth')
        config_path = pjoin(folder_path, 'arguments.yaml')
    
    return config_path, ckpt_path

# Args

In [3]:
config_path = './pretrained/exp_refine_transformer/arguments.yaml'
rt2m_checkpoint_folder_path = config_path.replace('arguments.yaml', '')
use_gpu = True

with open(config_path, 'r') as f:
    arg_dict = yaml.safe_load(f)

args = argparse.Namespace(**arg_dict)

In [4]:
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cpu'), jit=False)  # Must set jit=False for training
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Load Models

In [5]:
from models.pg_tokenizer import PoseGuidedTokenizer
dec_config, dec_checkpoint_path = get_cfg_ckpt_path(args.dec_checkpoint_folder)

if dec_config is None or dec_checkpoint_path is None:
    raise ValueError("Decoder config or checkpoint path is None. Please provide a valid folder path.")

with open(dec_config, 'r') as f:
    arg_dict = yaml.safe_load(f)

dec_args = argparse.Namespace(**arg_dict)
net = PoseGuidedTokenizer(
                    dec_args, 
                    dec_args.nb_code,                      # nb_code
                    dec_args.code_dim,                    # code_dim
                    dec_args.output_emb_width,            # output_emb_width
                    dec_args.down_t,                      # down_t
                    dec_args.stride_t,                    # stride_t
                    dec_args.width,                       # width
                    dec_args.depth,                       # depth
                    dec_args.dilation_growth_rate,        # dilation_growth_rate
                    dec_args.vq_act,                      # activation
                    dec_args.vq_norm,                     # norm
                    num_quantizers=dec_args.rvq_num_quantizers,
                    shared_codebook=dec_args.rvq_shared_codebook,
                    quantize_dropout_prob=dec_args.rvq_quantize_dropout_prob,
                    quantize_dropout_cutoff_index=dec_args.rvq_quantize_dropout_cutoff_index,
                    rvq_nb_code=dec_args.rvq_nb_code,
                    mu=dec_args.rvq_mu,
                    residual_ratio=dec_args.rvq_residual_ratio,
                    vq_loss_beta=dec_args.rvq_vq_loss_beta,
                    quantizer_type=dec_args.rvq_quantizer_type,
                    params_soft_ent_loss=dec_args.params_soft_ent_loss,
                    use_ema=(not dec_args.unuse_ema),
                    init_method=dec_args.rvq_init_method
                    )
    
print ('loading decoder checkpoint from {}'.format(dec_checkpoint_path))
ckpt = torch.load(dec_checkpoint_path, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)

net.eval()
if use_gpu:
    net.cuda()

loading decoder checkpoint from ./pretrained/exp_pg_tokenizer/net_best_fid.pth


In [6]:
import models.t2m_trans as t2m
t2m_config, t2m_checkpoint_path = get_cfg_ckpt_path(args.t2m_checkpoint_folder)

if t2m_config is None or t2m_checkpoint_path is None:
    raise ValueError("Decoder config or checkpoint path is None. Please provide a valid folder path.")

with open(t2m_config, 'r') as f:
    arg_dict = yaml.safe_load(f)

t2m_args = argparse.Namespace(**arg_dict)
trans_net = t2m.BaseTrans(num_vq=t2m_args.nb_code, 
                        embed_dim=t2m_args.embed_dim_gpt, 
                        clip_dim=t2m_args.clip_dim, 
                        block_size=t2m_args.block_size, 
                        num_layers=t2m_args.num_layers, 
                        n_head=t2m_args.n_head_gpt, 
                        drop_out_rate=t2m_args.drop_out_rate, 
                        fc_rate=t2m_args.ff_rate)

print ('loading transformer checkpoint from {}'.format(t2m_checkpoint_path))
trans_ckpt = torch.load(t2m_checkpoint_path, map_location='cpu')
trans_net.load_state_dict(trans_ckpt['trans'], strict=True)

if use_gpu:
    trans_net.cuda()
trans_net.eval()

## Token Emb: Linear Token Embedding Selected ##
loading transformer checkpoint from ./pretrained/exp_base_transformer/net_best_fid.pth


BaseTrans(
  (trans_base): CrossCondTransBase(
    (tok_emb): Linear(in_features=394, out_features=1024, bias=True)
    (cond_emb): Linear(in_features=512, out_features=1024, bias=True)
    (pos_embedding): Embedding(62, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (key): Linear(in_features=1024, out_features=1024, bias=True)
          (query): Linear(in_features=1024, out_features=1024, bias=True)
          (value): Linear(in_features=1024, out_features=1024, bias=True)
          (attn_drop): Dropout(p=0.1, inplace=False)
          (resid_drop): Dropout(p=0.1, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (mlp): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
   

In [7]:
eval_reference = 'net_best_fid'
r_trans_config, r_trans_checkpoint_path = get_cfg_ckpt_path(rt2m_checkpoint_folder_path)

if r_trans_config is None or r_trans_checkpoint_path is None:
    raise ValueError("Residual Transformer config or checkpoint path is None. Please provide a valid folder path.")

with open(r_trans_config, 'r') as f:
    arg_dict = yaml.safe_load(f)

r_trans_args = argparse.Namespace(**arg_dict)

In [8]:
from models.rt2m_trans import RefineTrans

res_trans_net = RefineTrans(num_vq=r_trans_args.nb_code, 
                                    num_rvq=dec_args.rvq_nb_code,
                                    embed_dim=r_trans_args.embed_dim_gpt, 
                                    clip_dim=r_trans_args.clip_dim, 
                                    block_size=r_trans_args.block_size, 
                                    num_layers=r_trans_args.num_layers, 
                                    n_head=r_trans_args.n_head_gpt, 
                                    drop_out_rate=r_trans_args.drop_out_rate, 
                                    fc_rate=r_trans_args.ff_rate,
                                    num_key=11,
                                    mode=None,
                                    num_quantizer=dec_args.rvq_num_quantizers,
                                    share_weight=r_trans_args.share_weight)


print ('loading transformer checkpoint from {}'.format(r_trans_checkpoint_path))
r_trans_ckpt = torch.load(r_trans_checkpoint_path, map_location='cpu')
res_trans_net.load_state_dict(r_trans_ckpt['r_trans'], strict=True)

# eval mode로 고정
if use_gpu:
    res_trans_net.cuda()
res_trans_net.eval()
print()

loading transformer checkpoint from ./pretrained/exp_refine_transformer/net_best_fid.pth



# Test Data Load

In [9]:
from utils.word_vectorizer import WordVectorizer
w_vectorizer = WordVectorizer('./glove', 'our_vab')
from dataset import dataset_TM_eval # 

is_test = True
val_loader = dataset_TM_eval.DATALoader(args.dataname, is_test, 32, w_vectorizer, codebook_size=392,
                                        use_keywords=args.use_keywords,
                                        use_word_only=args.use_word_only,
                                        codes_folder_name=args.codes_folder_name)

val_loader_iter = dataset_TM_eval.cycle(val_loader)

# Note!: Validation Shuffle :True
# Codes Folder name: codes


100%|██████████| 4384/4384 [00:03<00:00, 1305.78it/s]

#--------- RESULT ---------#
split:test
186 files filtered
0 files occured an error
#--------- END ---------#
Pointer Pointing at 0





# Search Sample

In [10]:
split = 'test'
with open(f'./dataset/HumanML3D/{split}.txt', 'r') as f:
    split_file_names = f.read().splitlines()
print(f"number of {split} samples: {len(split_file_names)}")

number of test samples: 4384


In [13]:
import os
import numpy as np

query = ['abruptly']
text_dir = './dataset/HumanML3D/texts'

file_names = []
texts = []
for file in os.listdir(text_dir):
    if file.endswith('.txt'):
        with open(f"{text_dir}/{file}", 'r') as f:
            for idx, line in enumerate(f):
                line = line.strip()
                matched = any(phrase in line for phrase in query)
                if matched:
                    file_name = file.replace('.txt', '')
                    if file_name in split_file_names:
                        texts.append(line.split('#')[0])
                        file_names.append(file_name)
                        print("name:", file_name, "text:", line.split('#')[0])
                    else:
                        continue

searched_file_ids = [file_name for file_name in file_names]

name: 011965 text: a person abruptly stumbles forward and regains his balance as if he had been pushed from behind.
name: M013056 text: a person angrily and quickly walks then abruptly stops with force.
name: 013249 text: a person walks diagonally forward, stops abruptly, grabs his head with his left hand, walks backward briefly, and then turn to the left and walks straight.
name: M012488 text: a person abruptly staggers backwards.
name: M011965 text: a person abruptly stumbles forward and regains his balance as if he had been pushed from behind.
name: M013249 text: a person walks diagonally forward, stops abruptly, grabs his head with his right hand, walks backward briefly, and then turn to the right and walks straight.
name: 000067 text: a person casually walks abruptly steps left and then recovers to a straight walk.
name: M004472 text: a person who is standing is getting pushed by something and moving abruptly.
name: M000067 text: a person casually walks abruptly steps right and th

In [16]:
target_file_list = ['011334'] # e.g., ['M004719']
text_dir = './dataset/HumanML3D/texts'
texts = []
pose_list = []

for id in target_file_list:
    with open(f"{text_dir}/{id}.txt", 'r') as f:
        lines = f.readlines()
        for line in lines:
            sentence = line.strip().split('#')[0]
            if query[0] in sentence:
                texts.append(sentence)
                break


print(texts)

['person walks forward speedily then abruptly stops']


In [17]:
import torch
import numpy as np

m_length_list = []
search_mode = True
unit_length = 4
keyword_embeddings = None

keyword_emb_dir = './dataset/HumanML3D/keyword_embeddings'
pose_dir = './dataset/HumanML3D/new_joint_vecs'

for file_name in target_file_list:
    print("file_name:", file_name)
    emb = np.load(os.path.join(keyword_emb_dir, f'{file_name}.npy'))
    keyword_indices = np.arange(0,33,3) + np.random.randint(0,3,size = (11,))    
    keyword_data = torch.from_numpy(emb[keyword_indices]).unsqueeze(0) # 변환 작업

    pose = torch.from_numpy(np.load(os.path.join(pose_dir, f'{file_name}.npy'))).unsqueeze(0)
    pose_list.append(val_loader.dataset.forward_transform(pose))
    m_length_list.append((pose.shape[1] // unit_length) * unit_length)
    
    if keyword_embeddings is None:
        keyword_embeddings = keyword_data
    else:
        keyword_embeddings = torch.cat([keyword_embeddings, keyword_data], dim=0)

file_name: 011334


# Custom Inference

In [18]:
import clip
import datetime
now = datetime.datetime.now()
date_time = now.strftime("%Y-%m-%d %H-%M-%S")
print("Current date and time : ", date_time)

text = clip.tokenize(texts, truncate=True) # .cuda('cuda:1')
feat_clip_text = clip_model.encode_text(text).float().unsqueeze(1) 
feat_clip_text = torch.cat((feat_clip_text, keyword_embeddings.float()), dim=1)  # bs x 11+1 x 512
feat_clip_text = feat_clip_text.cuda()

split = 'test'

save_dir = f'./inference/{date_time}/{split}/'
os.makedirs(save_dir, exist_ok=True)

Current date and time :  2025-12-15 15-20-26


In [19]:
feat_clip_text.shape

torch.Size([1, 12, 512])

In [20]:
#############################################
pred_p_code_list = []
pred_r_code_list = []
name_list = target_file_list
clip_text_list = []
clip_text = texts

for k in range(len(target_file_list)):
    print(f"iter:{k}")
    pred_p_codes, pred_r_codes = res_trans_net.sample(feat_clip_text[k:k+1], trans_net) # 1 x t x code_num -> k-hot vector (bs, seq_len, 394)
    
    pred_p_code_list.append(pred_p_codes)
    pred_r_code_list.append(pred_r_codes)
    name_list.append(name[k])
    clip_text_list.append(clip_text[k])


iter:0


In [21]:
num_joints = 22

In [22]:
def postprocess_custom_inference(pose, val_loader, sentence, save_dir, sample_name, save_name):
    pred_denorm = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
    pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float(), num_joints)
    pred_xyz = pred_xyz.numpy()
    print(pred_xyz.shape)
    os.makedirs(f"{save_dir}/{sample_name}", exist_ok=True)
    np.save(f"{save_dir}/{sample_name}/{save_name}_coord.npy", pred_xyz)
    np.save(f"{save_dir}/{sample_name}/{save_name}_pose.npy", pose.cpu().detach().numpy())
    save_path = f"{save_dir}/{sample_name}/{save_name}.gif"
    plot_3d.draw_to_batch(pred_xyz, [sentence], [save_path], footer_text=f'{save_name}', footer_fontsize=15)

In [23]:
from tqdm import tqdm

for k in tqdm(range(len(pred_p_code_list))):
    sentence = clip_text_list[k]
    pred_pose_codes = pred_p_code_list[k][..., :-2]
    pred_residual_codes = pred_r_code_list[k]
    gt_pose = pose_list[k]
    name = name_list[k]

    print(f"sample: {name}")

    ##################
    gt_pose = gt_pose[:, :m_length_list[k], :]
    postprocess_custom_inference(gt_pose, val_loader, sentence, save_dir, name, 'ground_truth')

    ##################
    pred_pose = net.inference(code_indices=pred_pose_codes.float(), residual_codes=pred_residual_codes.float(), drop_out_residual_quantization=True)
    postprocess_custom_inference(pred_pose, val_loader, sentence, save_dir, name, 'wo_residual')

    #################
    pred_pose = net.inference(pred_residual_codes.float(), pred_pose_codes.float())
    postprocess_custom_inference(pred_pose, val_loader, sentence, save_dir, name, 'w_residual')

  0%|          | 0/1 [00:00<?, ?it/s]

sample: 011334
(1, 68, 22, 3)
(1, 120, 22, 3)
(1, 120, 22, 3)


100%|██████████| 1/1 [00:39<00:00, 39.04s/it]
