# Deps

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
import torch
import torch.nn as nn
from utils.motion_process import *
import visualization.plot_3d_global as plot_3d
import warnings
warnings.filterwarnings('ignore')
from utils.motion_process import recover_from_ric 
from os.path import join as pjoin
import numpy as np
from utils.codebook import *
from models.motion_dec import MotionDec
from models.pg_tokenizer import PoseGuidedTokenizer

In [None]:
import random
def fixseed(seed):
    """Make the run repeatable by seeding every RNG in use.

    Seeds torch (CPU and the current CUDA device), NumPy's global RNG and
    Python's ``random`` with the same value, then switches cuDNN to
    deterministic kernels and turns off its autotuner.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Seed every RNG once at start-up (see fixseed above) so re-running the notebook is repeatable.
fixseed(1234)

In [None]:
import datetime
# All outputs of this run go under ./motion_editing_result/<YYYYMMDD>.
now = datetime.datetime.now()
timestamp = now.strftime("%Y%m%d")
use_gpu = True  # NOTE(review): not read by any cell visible here — confirm it is still needed
use_custom_samples = True  # True: hand-picked HumanML3D ids loaded from disk; False: first batch of the eval loader
save_dir = f'./motion_editing_result/{timestamp}'
os.makedirs(save_dir, exist_ok=True)

test_sample_num = 20  # samples taken from the loader batch when use_custom_samples is False
unit_length = 4  # stride used to subsample the per-frame code rows (one row kept every unit_length frames)

# Per-sample buffers, filled by the loading cells below and indexed in step with sample_ids.
sample_ids = []
code_motions = []
feats_motions = []
int_indices = []
motion_lens = []
texts = []

# Data Load

In [None]:
from utils.word_vectorizer import WordVectorizer
# Word vectorizer handed to the eval DataLoader (it is built even when custom samples are used).
w_vectorizer = WordVectorizer('./glove', 'our_vab')
from dataset import dataset_TM_eval  # evaluation DataLoader module

# Only needed when samples come from the loader instead of the hand-picked ids.
if not use_custom_samples:
    is_test = True
    # Batch size 32; codebook_size=392 equals the number of code ids in p_table2 (defined later).
    val_loader = dataset_TM_eval.DATALoader('t2m', is_test, 32, w_vectorizer, codebook_size=392,
                                            use_keywords=True,
                                            use_word_only=False)

    # presumably an endless iterator over val_loader batches — consumed with next() below
    val_loader_iter = dataset_TM_eval.cycle(val_loader)

### Search sample

In [None]:
import os
import numpy as np

# Find captions in the chosen split that contain any phrase of `example_text`.
split = 'test'
with open(f'./dataset/HumanML3D/{split}.txt', 'r', encoding='utf-8') as f:
    split_file_names = f.read().splitlines()
print(f"number of {split} samples: {len(split_file_names)}")
# Set copy for O(1) membership tests below; the list itself is kept unchanged.
split_file_set = set(split_file_names)

example_text = ['a person lowers their arms']
text_dir = './dataset/HumanML3D/texts'

file_names = []
retrieved_texts = []
# os.listdir order is arbitrary; sorting makes the hit list reproducible across runs.
for file in sorted(os.listdir(text_dir)):
    if not file.endswith('.txt'):
        continue
    file_name = file.replace('.txt', '')
    # Skip files outside the split before reading them at all.
    if file_name not in split_file_set:
        continue
    with open(f"{text_dir}/{file}", 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if any(phrase in line for phrase in example_text):
                # Keep only the caption text before the first '#'.
                caption = line.split('#')[0]
                retrieved_texts.append(caption)
                file_names.append(file_name)
                print("name:", file_name, "text:", caption)

searched_file_ids = list(file_names)

In [None]:
# HumanML3D locations read by the sample-loading cell below (one file per sample id).
motion_dir = './dataset/HumanML3D/new_joint_vecs'  # per-sample motion feature arrays (.npy)
motion_coord_dir = './dataset/HumanML3D/new_joints'  # NOTE(review): not read by any cell visible here
code_dir = './dataset/HumanML3D/codes'  # per-sample code arrays (.npy, presumably k-hot); rebound to the gt codes dir in a later cell
text_dir = './dataset/HumanML3D/texts'  # per-sample caption files (.txt)

In [None]:
from utils.misc import kh2index
target_file_list = ['013474']
retrieved_texts = []
use_ref_text = False  # True: prefer the first caption containing example_text[0]
for file_id in target_file_list:
    with open(f"{text_dir}/{file_id}.txt", 'r', encoding='utf-8') as f:
        # Keep only the caption text before the first '#' of every line.
        sentences = [line.strip().split('#')[0] for line in f]
    chosen = None
    if use_ref_text:
        chosen = next((s for s in sentences if example_text[0] in s), None)
        if chosen is None:
            print(f"[warn] no caption of {file_id} contains {example_text[0]!r}; "
                  f"falling back to its first caption.")
    if chosen is None:
        # Always append exactly one caption per id so `texts` stays aligned with `sample_ids`.
        chosen = sentences[0] if sentences else ''
    retrieved_texts.append(chosen)

code_motions = []
feats_motions = []
int_indices = []
motion_lens = []
unit_length = 4
print(retrieved_texts)

# NOTE: code_dir is rebound to the gt codes directory by a later cell; re-running
# this cell after that one would load codes from there instead of the dataset.
for target_file_id in target_file_list:
    m_feats = np.load(f"{motion_dir}/{target_file_id}.npy")
    m_code = np.load(f"{code_dir}/{target_file_id}.npy")

    # Trim to a multiple of unit_length, then keep one code row every unit_length frames.
    n_seq, _ = m_code.shape
    real_n_seq = (n_seq // unit_length) * unit_length
    sub_codes = m_code[:real_n_seq:unit_length]

    int_idx = kh2index(torch.from_numpy(sub_codes).float().unsqueeze(0)).squeeze(0)
    p_length = int_idx.shape[0]
    motion_lens.append(p_length)
    int_indices.append(int_idx)
    code_motions.append(sub_codes)
    feats_motions.append(m_feats[:real_n_seq])

sample_ids = target_file_list
texts = retrieved_texts

### Aggregate

In [None]:
from utils.misc import kh2index
# Loader mode: replace the buffers with the first `test_sample_num` items of one eval batch.
if not use_custom_samples:
    code_motions = []
    feats_motions = []
    int_indices = []
    motion_lens = []
    
    batch = next(val_loader_iter)
    # Positional unpack of the eval batch; any trailing fields are ignored.
    word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name, indices, keyword_embeddings, *_ = batch
    
    # NOTE(review): assumes the batch holds at least test_sample_num items (batch size is 32 above).
    sample_ids = name[:test_sample_num]
    texts = clip_text[:test_sample_num]

    for i in range(test_sample_num):
        idx = i
        # Codes within the valid length, subsampled with stride unit_length, converted by kh2index
        # (presumably k-hot rows -> hot-column indices).
        int_idx = kh2index(indices[idx:idx+1, :m_length[idx]:unit_length]).squeeze(0)
        p_length = int_idx.shape[0]
        motion_lens.append(p_length)
        int_indices.append(int_idx)
        code_motions.append(indices[idx, :m_length[idx]:unit_length])
        feats_motions.append(pose[idx, :m_length[idx]])


### Save Ground Truth

In [None]:
save_dir  # display the output directory of this run

In [None]:
from tqdm import tqdm

# Export every ground-truth sample as denormalised features, joint coordinates,
# codes, caption and (optionally) a rendered GIF under <save_dir>/gt/.
draw_gt = True
vis_save_dir = pjoin(save_dir, 'gt', 'gif')
npy_pose_save_dir = pjoin(save_dir, 'gt', 'npy_pose')
npy_coords_save_dir = pjoin(save_dir, 'gt', 'npy_coords')
npy_codes_save_dir = pjoin(save_dir, 'gt', 'codes')
texts_save_dir = pjoin(save_dir, 'gt', 'texts')

print("Saving results to:", save_dir)

for out_dir in (vis_save_dir, npy_pose_save_dir, npy_coords_save_dir,
                npy_codes_save_dir, texts_save_dir):
    os.makedirs(out_dir, exist_ok=True)

# ground truth pose; `sample_id` (not `id`) so the builtin id() is not shadowed.
for idx, sample_id in tqdm(enumerate(sample_ids), total=len(sample_ids)):
    print(f"## id:{sample_id} ##")
    pose = feats_motions[idx]

    # Loader samples are normalised tensors; custom samples are already raw NumPy features.
    if not use_custom_samples:
        gt_denorm = val_loader.dataset.inv_transform(pose.cpu().numpy())
    else:
        gt_denorm = pose
    np.save(f"{npy_pose_save_dir}/{sample_id}.npy", gt_denorm)

    # Recover joint positions for 22 joints from the feature representation.
    gt_pose_xyz = recover_from_ric(torch.from_numpy(gt_denorm).float(), 22).numpy()
    np.save(f"{npy_coords_save_dir}/{sample_id}.npy", gt_pose_xyz)

    code_ = code_motions[idx]
    np.save(f"{npy_codes_save_dir}/{sample_id}.npy", code_)

    sentence = texts[idx]
    with open(f"{texts_save_dir}/{sample_id}.txt", 'w', encoding='utf-8') as f:
        f.write(sentence)

    # Add a leading batch axis for the batch renderer.
    gt_pose_xyz = gt_pose_xyz[None, :]
    gt_save_path = f"{vis_save_dir}/{sample_id}.gif"
    if draw_gt:
        plot_3d.draw_to_batch(gt_pose_xyz, [sentence], [gt_save_path], footer_text='Ground Truth', footer_fontsize=15)

# Prepare Prompt

In [None]:
texts  # display the captions paired with sample_ids

In [None]:
# Edit scenarios keyed 1..N. Scenario k is paired with sample_ids[k-1] when the prompts are built.
# "text" is the motion caption given to the LLM; "scenario" is the edit instruction.
editing_scenario = {
    1: {
        "text": "the person put his hands on his knees.",
        "scenario": "Bend both arms more deeply"
    },
}


# Alternative scenario set, disabled by keeping it as a bare string literal (it is never executed).
'''editing_scenario = {
    1: {
        "text": "the person put his hands on his knees.",
        "scenario": "Bend both arms more deeply"
    },
    2: {
        "text": "a person is walking forward",
        "scenario": "Raise left hand at the end."
    },
    3: {
        "text": "this person bends forward as if to bow.",
        "scenario": "Bring both hands closer together."
    },
    4: {
        "text": "a person waves their arms over their heads.",
        "scenario": " Bend your knees more deeply"
    },
    5: {
        "text": "a man stands and raises both of his arms overhead and then makes up and downwards movements.",
        "scenario": "Bend both arms more deeply"
    }
}'''

print()  # prints an empty line only

In [None]:

# Table 1 -- joint-state index -> joint-state name. The last word of every name
# is its category: angle / distance / relX / relY / relZ / relV / ground.
_JOINT_STATE_NAMES = [
    # 0-3: joint angles
    'L-knee angle', 'R-knee angle', 'L-elbow angle', 'R-elbow angle',
    # 4-21: pairwise distances
    'L-elbow vs R-elbow distance', 'L-hand vs R-hand distance',
    'L-knee vs R-knee distance', 'L-foot vs R-foot distance',
    'L-hand vs L-shoulder distance', 'L-hand vs R-shoulder distance',
    'R-hand vs R-shoulder distance', 'R-hand vs L-shoulder distance',
    'L-hand vs R-elbow distance', 'R-hand vs L-elbow distance',
    'L-hand vs L-knee distance', 'L-hand vs R-knee distance',
    'R-hand vs R-knee distance', 'R-hand vs L-knee distance',
    'L-hand vs L-foot distance', 'L-hand vs R-foot distance',
    'R-hand vs R-foot distance', 'R-hand vs L-foot distance',
    # 22-27: relation along X
    'L-hand vs R-hand relX', 'neck vs pelvis relX',
    'L-hand vs L-shoulder relX', 'R-hand vs R-shoulder relX',
    'L-foot vs L-hip relX', 'R-foot vs R-hip relX',
    # 28-43: relation along Y
    'L-shoulder vs R-shoulder relY', 'L-elbow vs R-elbow relY',
    'L-hand vs R-hand relY', 'L-ankle vs neck relY', 'R-ankle vs neck relY',
    'L-knee vs R-knee relY', 'L-hip vs L-knee relY', 'R-hip vs R-knee relY',
    'L-hand vs L-shoulder relY', 'R-hand vs R-shoulder relY',
    'L-foot vs L-hip relY', 'R-foot vs R-hip relY',
    'L-wrist vs neck relY', 'R-wrist vs neck relY',
    'L-hand vs L-hip relY', 'R-hand vs R-hip relY',
    # 44-52: relation along Z
    'L-shoulder vs R-shoulder relZ', 'L-elbow vs R-elbow relZ',
    'L-hand vs R-hand relZ', 'L-knee vs R-knee relZ', 'neck vs pelvis relZ',
    'L-hand vs torso relZ', 'R-hand vs torso relZ',
    'L-foot vs torso relZ', 'R-foot vs torso relZ',
    # 53-65: vertical / horizontal relation
    'L-hip vs L-knee relV', 'R-hip vs R-knee relV',
    'L-knee vs L-ankle relV', 'R-knee vs R-ankle relV',
    'L-shoulder vs L-elbow relV', 'R-shoulder vs R-elbow relV',
    'L-elbow vs L-wrist relV', 'R-elbow vs R-wrist relV',
    'pelvis vs L-shoulder relV', 'pelvis vs R-shoulder relV',
    'pelvis vs neck relV', 'L-hand vs R-hand relV', 'L-foot vs R-foot relV',
    # 66-69: ground contact
    'L-knee ground', 'R-knee ground', 'L-foot ground', 'R-foot ground',
]
p_table1 = dict(enumerate(_JOINT_STATE_NAMES))

# Table 2 -- code id -> "<joint-state name> <level>". Each joint state owns a
# contiguous run of code ids, one per level of its category, in Table 1 order
# (392 ids in total).
_DISTANCE_LABELS = [
    'very close', 'slightly close', 'close', 'almost shoulder width apart',
    'shoulder width apart', 'almost spread', 'spread', 'slightly wide',
    'wide', 'very wide',
]
_LEVELS_BY_CATEGORY = {
    # 18 angle bins: 0..16 in 10-degree steps, 17 means straight.
    'angle': [f'{i} (bent to almost {10 * (i + 1)} degrees)' for i in range(17)] + ['17 (straight)'],
    'distance': [f'{i} ({label})' for i, label in enumerate(_DISTANCE_LABELS)],
    'relX': ['at the right of', 'x-ignored', 'at the left of'],
    'relY': ['below', 'y-ignored', 'above'],
    'relZ': ['behind', 'z-ignored', 'in front of'],
    'relV': ['vertical', 'ignored', 'horizontal'],
    'ground': ['on the ground', 'ground-ignored'],
}
p_table2 = dict(enumerate(
    f'{state} {level}'
    for state in _JOINT_STATE_NAMES
    for level in _LEVELS_BY_CATEGORY[state.rsplit(' ', 1)[-1]]
))

In [None]:
import json
import sys
# Build the four LLM prompts for every edit scenario and write them to
# <save_dir>/prompt/<sample_id>/<prompt_name>.txt.

# Disable summarisation ("...") when NumPy arrays / torch tensors are printed or
# formatted into strings, so long code sequences appear in full.
np.set_printoptions(threshold=sys.maxsize)
torch.set_printoptions(threshold=float('inf'))

# Workflow notes (a bare string expression; it has no runtime effect).
"""
- Q1. code_prompt_1: Identify which joint states can be affected by the editing instruction.
- Q2. frame_prompt_1: Use the joint states to determine which temporal segment should be edited.
- Q3. code_prompt_2: Decide how to edit the code sequence of length p_length.
- Q4. frame_prompt_2: Determine the frame range to edit.

- Find the joint states that are mentioned in common in Q1 and Q2.
- Based on those joint states, decide how to edit the code for Q3.
- Decide which frame range to edit.
"""

# Scenario k (keys 1..N) is paired with the k-th loaded sample, sample_ids[k-1].
for idx in range(1, len(editing_scenario) + 1):
    save_prompt_dir = pjoin(save_dir, f'prompt/{sample_ids[idx-1]}')
    os.makedirs(save_prompt_dir, exist_ok=True)
    # NOTE(review): p_code is not interpolated into any prompt below. frame_prompt_2 and
    # code_prompt_2 contain a literal "{p_code}" placeholder (written {{p_code}} in the
    # f-strings), presumably filled by a later step — confirm this is intended for
    # frame_prompt_2, which otherwise never receives the code sequence.
    p_code = int_indices[idx-1].numpy()
    p_length = motion_lens[idx-1]  # number of code frames (after unit_length subsampling)
    p_details = editing_scenario[idx]["text"]  # motion caption
    p_edit = editing_scenario[idx]["scenario"]  # edit instruction
    
    # Prompt texts are sent to the LLM as-is, including their indentation and
    # wording (NOTE(review): a few typos such as "You are be required" remain).
    # Q2: which joint states to examine to locate the edited segment.
    frame_prompt_1 = f"""

    Motion is represented by a set of joint states, defined as follows:
    Table 1 Joint State Meanings (Key: Joint State Index, Value: Joint State Meaning): {p_table1}
    Given the edit instruction: {p_edit}
    Return a semi-colon separated sequence of the ids of the joint states you will need to examine
    in order to determine the starting and ending frame of a motion sequence that will be affected
    by the edit instruction.
    Format example: 0;1;5;9. Do not reply anything else.

    """

    # Q4: starting/ending frame of the edited segment, using both code tables.
    frame_prompt_2 = f"""

    You will be provided with a text description of the motion, a motion code sequence and a motion
    edit instruction. You are be required to determine the starting and ending frame of the sequence
    that will be affected by the edit. Here is what you need to know about the encoding of the motion
    sequences:
    The motion is represented a number of time frames, each time frame contains a set of joint states,
    each joint state contains a code value. The definitions are:
    Table 1 Joint State Meanings (Key: Joint State Index, Value: Joint State Meaning): {p_table1}
    Table 2 Code Meaning (Key: Code ID, Value: Code Meaning): {p_table2}
    Rules: smaller angles indicates more bending.
    The motion code sequence is: {{p_code}}
    The total number of time frames is {p_length}
    The text description is: {p_details}
    The edit instruction is: {p_edit}
    Return the starting index and ending index of the segment that is affected by the edit, separated
    by semi-colon, if the edit affects the overall movement, select the entire sequence. Format
    example: 0;19. Do not reply anything else.

    """

    # Q1: which joint states the edit may affect.
    code_prompt_1 = f"""

    Motion is represented by a set of joint states, defined as follows:
    Table 1 Joint State Meanings (Key: Joint State Index, Value: Joint State Meaning): {p_table1}
    Given the edit instruction: {p_edit}
    Return a semi-colon separated sequence of the ids of the joint states you may be affected by the
    edit instruction. Format example: 0;1;5;9. Do not reply anything else.

    """

    # Q3: rewrite the code sequence of one joint state. {joint_state} and {p_code}
    # remain literal placeholders in the saved text.
    # NOTE(review): the requested output length is {p_length} (the full sequence) although
    # the prompt says the input has been sliced to the edited segment — confirm.
    code_prompt_2 = f"""

    You will be provided with a text description of the motion, a motion code sequence for a given
    joint state and a motion edit instruction. You will be required to determine how to modify the
    codes within the provided sequence accordingly.
    Here is what you need to know about the encoding of the motion sequences: The motion is
    represented as a list of joint states of length T, T is the number time frames. Each joint
    state contains a code value. The usable codes are defined as follows:
    Table 1 Usable Code Meaning (Key: Code ID, Value: Code Meaning): {p_table2}
    Rules: smaller angles indicates more bending.
    You are given this motion code sequence for the joint state {{joint_state}}, it has already been sliced
    to keep only the segment you will need to edit: {{p_code}}.
    The text description of the overall motion sequence is: {p_details}.
    The edit instruction is: {p_edit}
    Return the edited motion only as a sequence of integer code ids of length {p_length} separated by
    semi-colons, only use code ids in the provided table. If no edit needs to be made, return the
    original sequence. Format example: 1;2;3;4. Do not reply anything else. No explanation needed.

    """

    # File name (without .txt) -> prompt text.
    sample_data = {
        "frame_prompt_1": frame_prompt_1,
        "frame_prompt_2": frame_prompt_2,
        "code_prompt_1": code_prompt_1,
        "code_prompt_2": code_prompt_2
    }

    for file_name, item in sample_data.items():
        file_path = pjoin(save_prompt_dir, file_name + '.txt')
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(item)


# Get Reply From ChatGPT

In [None]:
from openai import OpenAI
import os
import time
import numpy as np
from os.path import join as pjoin
import sys
import re
from tqdm import tqdm
from utils.codebook import group_id_to_full_group_name
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

def wait(interval=15):
    """Sleep for ``interval`` seconds between API retries, logging via ``tqdm.write``."""
    before_msg, after_msg = f"Await {interval} seconds to avoid rate limit...", "Resuming..."
    tqdm.write(before_msg)
    time.sleep(interval)
    tqdm.write(after_msg)

# Seconds per unit for durations found in rate-limit messages ("ms" is handled by division).
_DURATION_UNIT_SECONDS = {'h': 3600.0, 'm': 60.0, 's': 1.0}
# A wait hint such as 'in 20s', 'in "250ms"' or the compound 'in 6m12s'.
_WAIT_HINT_RE = re.compile(r'in\s+"?((?:\d*\.?\d+\s*(?:ms|h|m|s)\s*)+)"?')
# One "<number><unit>" component of a wait hint.
_DURATION_PART_RE = re.compile(r'(\d*\.?\d+)\s*(ms|h|m|s)')


def parse_wait_interval(error_message: str) -> float:
    """Extract the suggested retry delay, in seconds, from an API error message.

    Recognises hints like ``"in 20s"``, ``"in 250ms"``, ``'in "1.5s"'`` and compound
    durations such as ``"in 6m12s"`` (hours, minutes, seconds, milliseconds).

    Args:
        error_message: text of the exception raised by the API client.

    Returns:
        The delay in seconds as a float, or 20.0 when no hint is found.
    """
    match = _WAIT_HINT_RE.search(error_message)
    if match is None:
        return 20.0

    seconds = 0.0
    for value, unit in _DURATION_PART_RE.findall(match.group(1)):
        if unit == 'ms':
            # divide (rather than multiply by 0.001) so e.g. 250ms is exactly 0.25
            seconds += float(value) / 1000.0
        else:
            seconds += float(value) * _DURATION_UNIT_SECONDS[unit]
    return seconds
    

def kh2index(k_hot_array: np.ndarray) -> np.ndarray:
    """Convert k-hot rows into the sorted column indices of their ones.

    Args:
        k_hot_array: (T, D) array of 0/1 values with T >= 1. Every row is assumed to
            contain the same number k of ones; k is read from the first row.

    Returns:
        (T, k) integer array whose row t lists, in ascending order, the columns
        where row t of ``k_hot_array`` is 1. When k is 0 the result is (T, 0).
    """
    k = int(k_hot_array.sum(axis=-1)[0])
    if k == 0:
        # `[:, -0:]` would select *every* column, so the empty case is handled explicitly.
        return np.empty((k_hot_array.shape[0], 0), dtype=np.intp)

    # Ones sort after zeros, so the last k argsort positions are the columns of the ones.
    indices = np.argsort(k_hot_array, axis=-1)[:, -k:]
    return np.sort(indices, axis=-1)

# The key is taken from OPENAI_API_KEY (the environment, or the .env file loaded by
# load_dotenv at the top of this cell); the placeholder is only a last-resort fallback.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "your api key here"),
)

from os.path import join as pjoin
import datetime
# Set to an existing run's timestamp to resume it: samples that already have a
# reply directory under save_reply_dir are skipped. None starts a new run.
time_str = None

if time_str is None:
    time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

dateid = timestamp
prompt_dir = f'./motion_editing_result/{dateid}/prompt'
save_reply_dir = f'./motion_editing_result/{dateid}/reply/{time_str}'
code_dir = f'./motion_editing_result/{dateid}/gt/codes'
model = "gpt-5"
max_waiting_interval = 600  # seconds
max_tol_pose_code_num = 40  # a Q1 reply listing more joint states than this is re-requested
n_max_retry = 3             # re-requests allowed for an unacceptable reply before skipping the sample

name_list = target_file_list
print("total samples:", len(name_list))
os.makedirs(save_reply_dir, exist_ok=True)
ban_list = []  # sample ids to skip entirely

##########################
# System prompt shared by every request of the Q1-Q4 dialogue.
_GPT_INSTRUCTIONS = "You are a motion-editing assistant. Follow the instructions written in the input text exactly and only return the requested output format"


def _parse_int_list(text):
    """Parse a semicolon-separated reply such as ``"0;1;5"`` into a list of ints.

    Surrounding whitespace and empty fields (e.g. from a trailing ``;``) are
    ignored; any other non-integer field raises ``ValueError``.
    """
    return [int(tok) for tok in text.strip().split(';') if tok.strip()]


def _ask_for_int_list(sample_id, stage, request_kwargs, scale=5, progress=None, max_len=None):
    """Send one Responses API request and parse its semicolon-separated int reply.

    API errors are retried indefinitely: the delay hinted in the error (see
    ``parse_wait_interval``) is multiplied by ``scale``, padded by 5 s and capped
    at ``max_waiting_interval``. Unparsable replies, or replies with more than
    ``max_len`` items, are re-requested at most ``n_max_retry`` times.

    Args:
        sample_id: sample being processed (for log messages).
        stage: question label used in log messages ("Q1".."Q4").
        request_kwargs: extra arguments for ``client.responses.create``.
        scale: multiplier applied to the hinted wait after an API error.
        progress: optional percentage shown in the API-error log message.
        max_len: optional maximum number of integers accepted in the reply.

    Returns:
        ``(response, values)`` on success, or ``(None, None)`` when the retry
        budget for bad replies is exhausted.
    """
    n_bad_replies = 0
    while True:
        try:
            response = client.responses.create(
                model=model,
                instructions=_GPT_INSTRUCTIONS,
                **request_kwargs,
            )
        except Exception as e:  # network / rate-limit errors: back off, then retry
            tqdm.write(str(e), file=sys.stderr)
            # max_waiting_interval is an upper bound on a single wait, hence min().
            interval = min(max_waiting_interval, parse_wait_interval(str(e)) * scale + 5)
            where = stage if progress is None else f"{stage}(progress:{progress:.2f}%)"
            tqdm.write(f"Error processing sample_id {sample_id} in {where}. will be retried after {interval}sec", file=sys.stderr)
            wait(interval=interval)
            continue

        try:
            values = _parse_int_list(response.output_text)
        except ValueError:
            reason = f"model returned an unparsable {stage} reply {response.output_text!r}"
        else:
            if max_len is None or len(values) <= max_len:
                return response, values
            reason = "model tries to return too many joint states"

        n_bad_replies += 1
        tqdm.write(f"{reason}, retry...", file=sys.stderr)
        if n_bad_replies >= n_max_retry:
            tqdm.write(f"exceed max retry limit({n_max_retry}), sample_id:{sample_id}", file=sys.stderr)
            return None, None


for sample_id in tqdm(name_list):
    tqdm.write(f"Processing sample_id: {sample_id}", file=sys.stderr)

    if sample_id in ban_list:
        tqdm.write(f"sample_id:{sample_id} is in ban list, skip.", file=sys.stderr)
        continue

    if os.path.exists(pjoin(save_reply_dir, sample_id)):
        tqdm.write(f"found existing reply, skip:{sample_id}", file=sys.stderr)
        continue

    # load prompt and reply template
    prompt_names = ['code_prompt_1', 'code_prompt_2', 'frame_prompt_1', 'frame_prompt_2']
    prompt_contents = []
    answers = {}
    for prompt_name in prompt_names:
        with open(pjoin(prompt_dir, sample_id, f'{prompt_name}.txt'), 'r', encoding='utf-8') as f:
            prompt_contents.append(f.read())
        answers[f"{prompt_name}_reply"] = []

    # Q1. Joint states affected by the edit instruction
    tqdm.write("Q1 processing...", file=sys.stderr)
    res1, joint_states = _ask_for_int_list(
        sample_id, "Q1",
        {"input": prompt_contents[0], "store": True},
        scale=5, max_len=max_tol_pose_code_num,
    )
    if res1 is None:
        tqdm.write(f"Skipping sample_id:{sample_id}", file=sys.stderr)
        continue
    answers['code_prompt_1_reply'] = joint_states

    tqdm.write("Q1 done", file=sys.stderr)
    tqdm.write("Q2 processing...", file=sys.stderr)

    # Q2. Edited code sequence for every joint state chosen in Q1.
    # Inputs are prepared once, outside any retry: a missing code file or an
    # out-of-range joint state id returned by the model cannot be fixed by retrying.
    try:
        int_code = kh2index(np.load(pjoin(code_dir, f'{sample_id}.npy')))
        # keep only the codes of the selected joint states; one row per joint state
        selected_code_rows = int_code[:, joint_states].T.tolist()
        joint_state_names = [group_id_to_full_group_name[j] for j in joint_states]
    except (OSError, ValueError, IndexError, KeyError) as e:
        tqdm.write(f"cannot build Q2 prompts ({e!r}), skipping sample_id:{sample_id}", file=sys.stderr)
        continue

    skip_sample = False
    n_joint_states = len(joint_states)
    for j, (joint_state, selected_codes) in enumerate(zip(joint_state_names, selected_code_rows)):
        new_prompt_content = (
            prompt_contents[1]
            .replace("{joint_state}", str(joint_state))
            .replace("{p_code}", str(selected_codes))
        )
        res2, edited_seq = _ask_for_int_list(
            sample_id, "Q2",
            {"input": new_prompt_content, "previous_response_id": res1.id},
            scale=n_joint_states, progress=j / n_joint_states * 100,
        )
        if res2 is None:
            skip_sample = True
            break
        answers['code_prompt_2_reply'].append(edited_seq)

    if skip_sample:
        tqdm.write(f"Skipping sample_id:{sample_id}", file=sys.stderr)
        continue

    tqdm.write("Q2 done", file=sys.stderr)
    tqdm.write("Q3 processing...", file=sys.stderr)

    # Q3. Joint states to examine when locating the edited frame range.
    res3, joint_states_2 = _ask_for_int_list(
        sample_id, "Q3",
        {"input": prompt_contents[2], "store": True, "previous_response_id": res1.id},
        scale=5,
    )
    if res3 is None:
        tqdm.write(f"Skipping sample_id:{sample_id}", file=sys.stderr)
        continue
    answers['frame_prompt_1_reply'] = joint_states_2

    tqdm.write("Q3 done", file=sys.stderr)
    tqdm.write("Q4 processing...", file=sys.stderr)

    # Q4. Start/end frame for each edited code sequence from Q2.
    edited_sequences = answers['code_prompt_2_reply']
    n_sequences = len(edited_sequences)
    for j, p_codes in enumerate(edited_sequences):
        new_prompt = prompt_contents[3].replace("{p_code}", str(p_codes))
        res4, frame_range = _ask_for_int_list(
            sample_id, "Q4",
            {"input": new_prompt, "previous_response_id": res3.id},
            scale=n_sequences, progress=j / n_sequences * 100,
        )
        if res4 is None:
            skip_sample = True
            break
        answers['frame_prompt_2_reply'].append(frame_range)

    if skip_sample:
        tqdm.write(f"Skipping sample_id:{sample_id}", file=sys.stderr)
        continue

    tqdm.write("Q4 done", file=sys.stderr)
    tqdm.write("Saving all answers...", file=sys.stderr)
    # Save all answers (parse_reply reads the A1-A4 lines back as JSON lists)
    reply_format = f"""

    Q1: Return a semi-colon separated sequence of the ids of the joint states you may be affected by the edit instruction.
    A1: {answers['code_prompt_1_reply']}

    Q2: Return the edited motion only as a sequence of integer code ids of length 49 separated by semi-colons, only use code ids in the provided table.
    A2: {answers['code_prompt_2_reply']}

    Q3: Return a semi-colon separated sequence of the ids of the joint states you will need to examine in order to determine the starting and ending frame of a motion sequence that will be affected by the edit instruction.
    A3: {answers['frame_prompt_1_reply']}

    Q4: Return the starting index and ending index of the segment that is affected by the edit, separated by semi-colon, if the edit affects the overall movement, select the entire sequence.
    A4: {answers['frame_prompt_2_reply']}

    """

    file_dir = pjoin(save_reply_dir, sample_id)
    os.makedirs(file_dir, exist_ok=True)
    file_path = pjoin(file_dir, 'reply.txt')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(reply_format)

    tqdm.write(f"Finished sample_id:{sample_id}", file=sys.stderr)

# Load Reply Data With Completed Samples

In [None]:
import re
from os.path import join as pjoin
import json
import ast
# Reuse the date and run timestamp of the reply-collection cell so this cell reads
# the replies written under save_reply_dir by that run.
date = timestamp
time_stamp = time_str
reply_dir = f'./motion_editing_result/{date}/reply/{time_stamp}'
# Cached ground-truth artifacts (codes, poses, texts) of the same date.
cached_dir = f'./motion_editing_result/{date}/gt'

def parse_reply(file_path):
    """Parse the four answers (A1-A4) stored in a ``reply.txt`` file.

    Each ``A<n>:`` entry is expected to hold a JSON-compatible list (Python list
    reprs of ints, as written by the reply-collection cell).

    Args:
        file_path: path of the reply file.

    Returns:
        Tuple ``(A1, A2, A3, A4)`` of the decoded answers.

    Raises:
        ValueError: if the file does not contain exactly four answers, or an
            answer cannot be decoded as JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Everything after "A<n>:" up to the next line starting with "Q" (or end of file).
    matches = re.findall(r'A\d:\s*(.*?)(?=\n\s*Q|\Z)', text, re.DOTALL)
    if len(matches) != 4:
        print("Error: Could not find all answers for A1, A2, A3, and A4 in the file.")
        raise ValueError(f"Expected 4 answers in {file_path}, found {len(matches)}.")

    answer_strs = [m.strip() for m in matches]
    try:
        return tuple(json.loads(s) for s in answer_strs)
    except json.JSONDecodeError as e:
        print(f"Error: Failed to convert the extracted string into a Python list. - {e}")
        print("--- Original string ---")
        for i, s in enumerate(answer_strs, start=1):
            print(f"A{i}_str: {s}")
        raise ValueError(f"Failed to decode the answers in {file_path}.") from e

def parse_edit_text(file_path):
    """Extract the edit instruction from a saved prompt file.

    Looks for ``Given the edit instruction: <text>`` and returns ``<text>`` up to
    the end of its line, without surrounding whitespace.

    Args:
        file_path: path of the prompt file.

    Returns:
        The instruction string, e.g. ``"Make the movement look more energetic."``.

    Raises:
        ValueError: if the file contains no instruction line (``ValueError`` is an
            ``Exception`` subclass, so broader existing handlers still catch it).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    match = re.search(r"Given the edit instruction:\s*(.+?)\s*(?:\n|$)", text)
    if match is None:
        print("No instruction found")
        raise ValueError(f"No instruction found in the provided file: {file_path}")
    return match.group(1)





In [None]:
# Every sub-directory of reply_dir holds the reply of one completed sample.
# Sorted so the processing order is reproducible across runs and filesystems.
if os.path.isdir(reply_dir):
    sample_ids = sorted(name for name in os.listdir(reply_dir) if os.path.isdir(pjoin(reply_dir, name)))
else:
    print(f"Warning: reply directory not found: {reply_dir}; keeping the existing sample_ids.")

print(sample_ids)

In [None]:
def index2kh(indices_array: np.ndarray, D: int) -> np.ndarray:
    """Build a (T, D) k-hot integer array from per-row column indices.

    Args:
        indices_array: (T, K) integer array (or array-like); row t lists the columns
            to set in output row t. A 1-D (T,) input is treated as (T, 1), i.e. one-hot.
        D: number of columns of the output.

    Returns:
        (T, D) array of dtype ``int`` with 1 at every listed index and 0 elsewhere.
    """
    indices = np.asarray(indices_array)
    if indices.ndim == 1:
        # Without this, (T, 1) row indices broadcast against (T,) columns to (T, T).
        indices = indices[:, np.newaxis]

    T = indices.shape[0]
    k_hot_array = np.zeros((T, D), dtype=int)
    row_indices = np.arange(T)[:, np.newaxis]  # (T, 1), broadcast across the K columns
    k_hot_array[row_indices, indices] = 1
    return k_hot_array

def kh2index(k_hot_array: np.ndarray) -> np.ndarray:
    """Convert k-hot rows into the sorted column indices of their ones.

    Args:
        k_hot_array: (T, D) array of 0/1 values with T >= 1. Every row is assumed to
            contain the same number k of ones; k is read from the first row.

    Returns:
        (T, k) integer array whose row t lists, in ascending order, the columns
        where row t of ``k_hot_array`` is 1. When k is 0 the result is (T, 0).
    """
    k = int(k_hot_array.sum(axis=-1)[0])
    if k == 0:
        # `[:, -0:]` would select *every* column, so the empty case is handled explicitly.
        return np.empty((k_hot_array.shape[0], 0), dtype=np.intp)

    # Ones sort after zeros, so the last k argsort positions are the columns of the ones.
    indices = np.argsort(k_hot_array, axis=-1)[:, -k:]
    return np.sort(indices, axis=-1)

In [None]:
# Per-sample buffers filled by the loop below, in the same order as sample_ids.
code_motions = []
feats_motions = []
int_indices = []
texts = []
edit_texts = []
edit_mode = 'gpt_edit'

# In gpt_edit mode the edit instruction is recovered from each sample's saved prompt.
if edit_mode == 'gpt_edit':
    prompt_dir = f'./motion_editing_result/{date}/prompt'

for sample_id in sample_ids:
    # Cached ground-truth artifacts: k-hot pose codes (consumed by kh2index below) and
    # pose features (collected into feats_motions).
    code_motion = np.load(pjoin(cached_dir, 'codes', f'{sample_id}.npy'))
    pose = np.load(pjoin(cached_dir, 'npy_pose', f'{sample_id}.npy'))
    # NOTE(review): `coords` (presumably joint coordinates) is loaded but not used in this cell — confirm it is needed.
    coords = np.load(pjoin(cached_dir, 'npy_coords', f'{sample_id}.npy'))
    # Caption text of the sample.
    txt_path = pjoin(cached_dir, 'texts', f'{sample_id}.txt')
    with open(txt_path, 'r', encoding='utf-8') as f:
        text = f.read().strip()
    texts.append(text)
    
    if edit_mode == 'gpt_edit':
        edit_text = parse_edit_text(pjoin(prompt_dir, sample_id, 'frame_prompt_1.txt'))
        edit_texts.append(edit_text)
    
    # Sorted column indices of the active codes in every frame.
    int_idx = kh2index(code_motion)
    code_motions.append(code_motion)
    int_indices.append(int_idx)
    feats_motions.append(pose)
    


# Model Load

In [None]:
import yaml
import argparse

# Experiment directory holding the training arguments (arguments.yaml) and the
# best-FID checkpoint (net_best_fid.pth).
checkpoint_dir = './pretrained/exp_pg_tokenizer'
exp_name = checkpoint_dir.split('/')[-1]
print(exp_name)

# Rebuild the training-time arguments so the tokenizer is constructed with the
# same hyper-parameters as the checkpoint.
config = pjoin(checkpoint_dir, 'arguments.yaml')
with open(config, 'r') as f:
    arg_dict = yaml.safe_load(f)

args = argparse.Namespace(**arg_dict)

ours = PoseGuidedTokenizer(
                    args, 
                    args.nb_code,                      # nb_code
                    args.code_dim,                    # code_dim
                    args.output_emb_width,            # output_emb_width
                    args.down_t,                      # down_t
                    args.stride_t,                    # stride_t
                    args.width,                       # width
                    args.depth,                       # depth
                    args.dilation_growth_rate,        # dilation_growth_rate
                    args.vq_act,                      # activation
                    args.vq_norm,                     # norm
                    args.cfg_cla,                     # cfg_cla
                    aggregate_mode=None,    # aggregate_mode
                    num_quantizers=args.rvq_num_quantizers,
                    shared_codebook=args.rvq_shared_codebook,
                    quantize_dropout_prob=args.rvq_quantize_dropout_prob,
                    quantize_dropout_cutoff_index=args.rvq_quantize_dropout_cutoff_index,
                    rvq_nb_code=args.rvq_nb_code,
                    mu=args.rvq_mu,
                    resi_beta=args.rvq_resi_beta,
                    # optional arguments fall back to defaults when absent from arguments.yaml
                    quantizer_type=getattr(args, 'rvq_quantizer_type', 'hard'),
                    params_soft_ent_loss=0.0,
                    use_ema=(not getattr(args, 'unuse_ema', False)),
                    init_method=getattr(args, 'rvq_init_method', 'enc'),  # 'enc', 'xavier', 'uniform',
)

torch.manual_seed(args.seed)

# Load the checkpoint on CPU; the model is moved to the GPU below when use_gpu is set.
checkpoint_path = pjoin(checkpoint_dir, 'net_best_fid.pth')
ckpt = torch.load(checkpoint_path, map_location='cpu')
unit_length = 4

# Dataset-dependent settings: 'kit' uses 21 joints, anything else 22.
# NOTE(review): these are assigned after `ours` was built, so they do not affect its construction — confirm they are still needed.
if args.dataname == 'kit' : 
    args.nb_joints = 21
    args.max_motion_len = 196
else:
    args.nb_joints = 22
    args.max_motion_len = 196

# strict=True: checkpoint keys must match the model's state dict exactly.
ours.load_state_dict(ckpt['net'], strict=True)
print(f"loaded iter:{ckpt['nb_iter']}")

if use_gpu:
    ours.cuda()

# Inference only: switch layers such as dropout/normalisation to evaluation behaviour.
ours.eval()
print()


In [None]:
# Motion length limit (same value as args.max_motion_len) and the unit to whose
# multiple motion lengths are trimmed before editing.
max_motion_len =  196
unit_length = 4

def inv_transform(data, mean, std):
    """Undo z-score normalisation: map normalised features back to the raw scale.

    Inverse of ``transform``; ``mean`` and ``std`` broadcast against ``data``.
    """
    rescaled = data * std
    return rescaled + mean

def transform(data, mean, std):
    """Apply z-score normalisation: subtract ``mean``, then divide by ``std``."""
    centered = data - mean
    return centered / std

# Feature normalisation statistics; inv_transform(pred, mean, std) maps the
# tokenizer's normalised outputs back to raw features.
meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
mean = np.load(pjoin(meta_dir, 'mean.npy'))
std = np.load(pjoin(meta_dir, 'std.npy'))

# Edit

In [None]:
from utils.codebook import *
import torch
# Reverse lookup: full category name -> category id (used by edit_custom).
full_cat_name_to_cat_id = {v: k for k, v in cat_id_to_full_cat_name.items()}

# Per-sample outputs of the editing cell.
edited_codes = []
target_lens = []
mot_downsampled_list = []
edit_mode = 'gpt_edit' # or 'custom_edit'

In [None]:
def edit_custom(code_indices, target_cat_name, force_edit=True):
    """Force one pose-code category to be the active one of its group in every frame.

    Args:
        code_indices: (N, C) k-hot code tensor (N frames); it is cloned, not modified.
        target_cat_name: full category name, looked up in ``full_cat_name_to_cat_id``.
        force_edit: only ``True`` (overwrite every frame) is implemented.

    Returns:
        ``(new_code_indices, result_dict)``: the edited clone, and a mapping
        ``"frame_<i>(src -> tgt)" -> (source category name, target_cat_name)``.

    Raises:
        NotImplementedError: if ``force_edit`` is False.
    """
    if not force_edit:
        raise NotImplementedError()

    cat_id = full_cat_name_to_cat_id[target_cat_name]
    group_id = cat_id_to_group_id[cat_id]
    # NOTE(review): vq_to_range is unpacked as (end, start) — confirm this order against utils.codebook.
    end, start = vq_to_range[group_id]

    # Category active in each frame: argmax within the group's columns gives a
    # group-relative position, shifted by `start` to an absolute category id.
    activated = start + torch.argmax(code_indices[:, start:end + 1].float(), dim=1)

    # Over the whole time span: clear the group's columns, then switch on the target.
    new_code_indices = code_indices.clone()
    new_code_indices[:, start:end + 1] = 0
    new_code_indices[:, cat_id] = 1

    result_dict = {
        f"frame_{i}(src -> tgt)": (id_to_name[src_id], target_cat_name)
        for i, src_id in enumerate(activated.tolist())
    }
    return new_code_indices, result_dict

In [None]:
# to_edit_code_indices: a variable that indicates which codes need to be changed

def edit_motion_code_from_reply(source_code, joint_indices, new_code_sequences, frame_ranges):
    """Return a copy of ``source_code`` with per-joint frame segments replaced.

    For the i-th entry, column ``joint_indices[i]`` over frames
    ``frame_ranges[i][0]`` .. ``frame_ranges[i][1]`` (inclusive) is overwritten by
    ``new_code_sequences[i]``. Entries whose sequence length differs from the
    segment length are skipped with a printed warning. Iteration stops at the
    shortest of the three lists (``zip`` semantics).
    """
    edited = source_code.copy()  # never touch the caller's array

    for joint_idx, new_codes, frame_range in zip(joint_indices, new_code_sequences, frame_ranges):
        start_frame, end_frame = frame_range
        segment_len = end_frame - start_frame + 1  # end_frame is inclusive

        if len(new_codes) != segment_len:
            print(f"Warning: The code sequence length ({len(new_codes)}) for joint {joint_idx} "
                  f"does not match the length of the frame range ({start_frame}-{end_frame}), so it is skipped.")
            continue

        edited[start_frame:end_frame + 1, joint_idx] = new_codes

    return edited

def edit_motion_code_from_reply_v2(source_code, joint_indices_1, joint_indices_2, new_code_sequences, frame_ranges):
    """Apply GPT-proposed code edits to the joints selected in both joint lists.

    ``new_code_sequences[i]`` and ``frame_ranges[i]`` belong to
    ``joint_indices_1[i]``: the reply loop asks one Q2 question per Q1 joint state
    and one Q4 question per Q2 reply. The three lists are therefore walked
    together, and entries whose joint is absent from ``joint_indices_2`` are
    dropped. Zipping an (unordered) set intersection against them instead would
    pair joints with the wrong codes and ranges.

    Assumptions:
    - source_code.shape == (T, J)
    - new codes are indexed by absolute frame: len(new_codes) >= end_frame + 1
      (shorter sequences shrink the edited range, with a warning)
    - frame_ranges holds (start_frame, end_frame) pairs, end inclusive; other
      entries are skipped with a warning

    Returns:
        An edited copy of ``source_code``.
    """
    modified_code = source_code.copy()
    T = source_code.shape[0]

    keep = set(joint_indices_2)
    aligned = [
        (joint_idx, new_codes, frame_range)
        for joint_idx, new_codes, frame_range in zip(joint_indices_1, new_code_sequences, frame_ranges)
        if joint_idx in keep
    ]
    print("intersected joint indices:", [joint_idx for joint_idx, _, _ in aligned])

    for joint_idx, new_codes, frame_range in aligned:
        if len(frame_range) != 2:
            print(f"Warning: For joint {joint_idx}, frame range {frame_range} is not a (start, end) pair, so it is skipped.")
            continue
        start_frame, end_frame = frame_range

        # Clamp to valid array bounds
        start = max(0, start_frame)
        end = min(T - 1, end_frame)

        if end < start:
            print(f"Warning: Invalid frame range ({start_frame}, {end_frame}), so it is skipped.")
            continue

        if end >= len(new_codes):
            print(
                f"Warning: For joint {joint_idx}, end_frame({end}) exceeds new_codes length ({len(new_codes)}), "
                f"so it is adjusted to {len(new_codes)-1}."
            )
            end = len(new_codes) - 1
            if end < start:
                print(
                    f"Warning: The adjusted range ({start}, {end}) is invalid, so it is skipped."
                )
                continue

        target_slice = slice(start, end + 1)
        modified_code[target_slice, joint_idx] = new_codes[target_slice]

    return modified_code


### Edit Pose Code

In [None]:
from torch import device


# For every sample, build the edited k-hot code tensor (edited_codes) and the source
# motion trimmed to a multiple of unit_length frames with a leading batch axis
# (mot_downsampled_list).
# NOTE(review): the loop variable `id` shadows the builtin id() at notebook scope.
if edit_mode == 'custom_edit':
    # Categories forced on in every frame, applied one after another via edit_custom.
    target_code_list = ['L-hand vs R-hand distance 1', 'L-elbow vs R-elbow distance 1', 'L-hand vs R-elbow distance 1', 'R-hand vs L-elbow distance 1']
    for idx, id in enumerate(sample_ids):
        print(f"## id:{id} ##")
        
        # data
        mo_pose_codes = code_motions[idx] # cached k-hot pose codes of this sample
        mo_pose_codes = torch.from_numpy(mo_pose_codes)

        motion = feats_motions[idx]
        motion = torch.from_numpy(motion)
        
        mo_len, cb_num = mo_pose_codes.shape

        # edit: force each category of target_code_list to be active in every frame
        temp = mo_pose_codes
        for target_code in target_code_list:
            temp, _ = edit_custom(temp, target_code, force_edit=True)
        new_code_indices = temp
        
        # Trim the motion to a multiple of unit_length; unsqueeze(0) adds a batch axis.
        # NOTE(review): the codes keep all mo_len frames while the motion is trimmed — confirm the tokenizer accepts this.
        target_len = (mo_len // unit_length) * unit_length
        motion_downsampled = motion[:target_len,:].unsqueeze(0)
        mot_downsampled_list.append(motion_downsampled)
        
        target_lens.append(target_len)
        
        if use_gpu:
            new_code_indices = new_code_indices.cuda()

        edited_codes.append(new_code_indices.float())
        edit_texts.append(f"Custom Edit:({str(target_code_list)})")

elif edit_mode == 'gpt_edit':
    for idx, id in enumerate(sample_ids):
        print(f"## id:{id} ##")

        mo_pose_codes = code_motions[idx] # cached k-hot pose codes of this sample
        mo_pose_codes = torch.from_numpy(mo_pose_codes)

        motion = feats_motions[idx]
        motion = torch.from_numpy(motion)
        
        mo_len, cb_num = mo_pose_codes.shape
        
        # edit from gpt reply: joint states chosen in both Q1 (a1) and Q3 (a3) receive the
        # Q2 code sequences (a2) inside the Q4 frame ranges (a4), on the integer code indices.
        reply_file_path = pjoin(reply_dir, id, 'reply.txt')
        a1, a2, a3, a4 = parse_reply(reply_file_path)
        edited_code = edit_motion_code_from_reply_v2(int_indices[idx], a1, a3, a2, a4)

        # Back to k-hot form. D=392 is presumably the total number of pose-code categories — confirm.
        edited_kh_code = index2kh(edited_code, D=392)
        target_len = (mo_len // unit_length)*unit_length
        motion_downsampled = motion[:target_len,:].unsqueeze(0)
        mot_downsampled_list.append(motion_downsampled)

        target_lens.append(target_len)

        _edited_kh_code = torch.from_numpy(edited_kh_code).float()

        if use_gpu:
            _edited_kh_code = _edited_kh_code.cuda()
        edited_codes.append(_edited_kh_code)
        
            

In [None]:
edit_save_dir = pjoin(save_dir, 'gt', 'edit')
edit_coord_dir = f"{edit_save_dir}/{time_stamp}/{exp_name}/edit/{edit_mode}/npy_coords"
vis_save_dir = f"{edit_save_dir}/{time_stamp}/{exp_name}/edit/{edit_mode}/gif"

os.makedirs(edit_coord_dir, exist_ok=True)
os.makedirs(vis_save_dir, exist_ok=True)

# Decode each edited code sequence into a motion, save its joint positions and
# render it to a GIF; samples whose GIF already exists are skipped.
for idx, sample_id in enumerate(sample_ids):
    print(f"## id:{sample_id} ##")

    save_path = f"{vis_save_dir}/{sample_id}.gif"
    if os.path.exists(save_path):
        print(f"Skip existing {save_path}")
        continue

    mo_pose_codes = edited_codes[idx].unsqueeze(0)  # add a batch axis
    mot_downsampled = mot_downsampled_list[idx]
    print(mo_pose_codes.shape)
    print(mot_downsampled.shape)

    # Inference only: no_grad avoids recording an autograd graph for this forward pass.
    with torch.no_grad():
        pred_motion, *_ = ours(code_indices=mo_pose_codes.float(), motion=mot_downsampled.float(), drop_out_residual_quantization=True)

    pred_denorm = inv_transform(pred_motion.cpu().detach().numpy(), mean, std)
    # 22 joints, matching args.nb_joints for the non-KIT setting.
    pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float(), 22).numpy()

    np.save(pjoin(edit_coord_dir, f"{sample_id}.npy"), pred_xyz.squeeze(0))
    print(pred_xyz.shape)

    plot_3d.draw_to_batch(pred_xyz, [texts[idx]], [save_path], footer_text=f'Edit, edit_text:{edit_texts[idx]}', footer_fontsize=15)