In [1]:
from T2M_GPT_lightning.models.text2sign import Text2Sign
from T2M_GPT_lightning.dataset.toy_vq_vae_dataset import ToyDataset as ToyVQVAEDataset
from T2M_GPT_lightning.dataset.toy_t2m_trans_dataset import ToyDataset as ToyT2MTransDataset

from capstone_utils.plot_skeletons import plot_skeletons_video

In [2]:
# Checkpoint / config locations for each stage of the text-to-sign pipeline.
# All paths are relative, so they resolve against the notebook's working directory.

# VQ-VAE model
VQ_VAE_MODEL_WEIGHT_PATH = "../logs/vq_vae_finetune/version_0/checkpoints/epoch=999-step=5000.ckpt"
VQ_VAE_MODEL_CONFIG_PATH = "./configs/vq_vae_model_config.yaml"
# CLIP model
CLIP_MODEL = "ViT-B/32"
# GPT model
T2M_TRANS_MODEL_WEIGHT_PATH = "./weights/t2m_trans_model.pth"
T2M_TRANS_MODEL_CONFIG_PATH = "./configs/t2m_trans_model_config.yaml"
# Backward-compatible alias for the original misspelled name ("CONIFG"), kept so
# that cells still referring to the old identifier continue to run.
T2M_TRANS_MODEL_CONIFG_PATH = T2M_TRANS_MODEL_CONFIG_PATH

In [3]:
# Positional arguments for Text2Sign.from_path, in call order:
# VQ-VAE weights, VQ-VAE config, CLIP model name, transformer weights, transformer config.
_text2sign_sources = (
    VQ_VAE_MODEL_WEIGHT_PATH,
    VQ_VAE_MODEL_CONFIG_PATH,
    CLIP_MODEL,
    T2M_TRANS_MODEL_WEIGHT_PATH,
    T2M_TRANS_MODEL_CONIFG_PATH,
)

# Build the combined text-to-sign object from the configured sources.
text_to_sign = Text2Sign.from_path(*_text2sign_sources)

In [4]:
# Both toy datasets read the same skeleton file, so the paths are named once.
toy_skels_path = "../data/toy_data/train.skels"
toy_text_path = "../data/toy_data/train.text"

# NOTE(review): the numeric arguments (150, 100, 32) are passed positionally;
# their meaning is defined by the dataset classes — confirm against them.
vq_vae_dataset = ToyVQVAEDataset(toy_skels_path, 150, 100, 32)

# The transformer dataset is built from the CLIP and VQ-VAE models already
# loaded on `text_to_sign`, plus the paired text / skeleton files.
t2m_trans_dataset = ToyT2MTransDataset(
    text_to_sign.clip_model,
    text_to_sign.vq_vae_model,
    toy_text_path,
    toy_skels_path,
    100,
)

In [5]:
def _print_tensor_report(label, tensor):
    """Print ``tensor`` and its shape under ``label``, followed by a separator line."""
    print(f"{label}: {tensor}")
    print(f"{label} shape: {tensor.shape}")
    print("-" * 50)


# Index of the toy sample inspected throughout the rest of the notebook.
data_index = 3

# Source sentence for the chosen sample.
text = t2m_trans_dataset.texts[data_index]
print(f'Text: "{text}"')
print("-" * 50)

# Full skeleton sequence for the same sample.
skel = vq_vae_dataset.get_full_sequences_by_idx(data_index)
_print_tensor_report("Skeleton", skel)

# Second element of the transformer dataset item: the skeleton code indices.
skel_indices = t2m_trans_dataset[data_index][1]
_print_tensor_report("Skeleton indices", skel_indices)

# Drop the last index before decoding (presumably an end-of-sequence marker —
# TODO confirm) and add a leading dimension with unsqueeze(0).
skel_reconstructed = text_to_sign.vq_vae_model.decode_indices(skel_indices[:-1].unsqueeze(0))
_print_tensor_report("Reconstructed skeleton", skel_reconstructed)

# Code indices predicted from the text alone.
indices_prediction = text_to_sign.text_to_indices(text)
_print_tensor_report("Indices prediction", indices_prediction)

# Skeleton sequence predicted from the text alone.
sign_prediction = text_to_sign.text_to_skels(text)
_print_tensor_report("Sign prediction", sign_prediction)

Text: "auch am tag gibt es verbreitet zum teil kräftige schauer oder gewitter und in manchen regionen fallen ergiebige regenmengen ."
--------------------------------------------------
Skeleton: tensor([[0.4154, 0.2883, 0.3328,  ..., 0.4363, 0.6263, 0.5929],
        [0.4155, 0.2875, 0.3333,  ..., 0.4393, 0.6241, 0.5934],
        [0.4123, 0.2885, 0.3333,  ..., 0.4338, 0.6260, 0.6037],
        ...,
        [0.4846, 0.2343, 0.4674,  ..., 0.5611, 0.9324, 0.7510],
        [0.4793, 0.2324, 0.4676,  ..., 0.5613, 0.9443, 0.7521],
        [0.4763, 0.2313, 0.4667,  ..., 0.5614, 0.9495, 0.7484]])
Skeleton shape: torch.Size([185, 150])
--------------------------------------------------
Skeleton indices: tensor([411,  52,  69, 411, 326,  52,  69,  69,  69, 411,  69,  69,  69, 326,
         14,  69,  14,  14,  69,  14,  69,  69,  69,  14,  69,  69,  52,  67,
        411,  69,  69,  69, 411,  69,  14,  69, 411, 411, 479, 326,  14, 326,
         69, 411, 411, 411, 512], device='mps:0')
Skeleton indice

In [6]:
# Scaling bounds exposed by the VQ-VAE dataset; used below to map values back
# to the original range.
min_value, max_value = vq_vae_dataset.min_value, vq_vae_dataset.max_value

In [7]:
# Map the skeleton back to its original value range by scaling with
# (max - min) and shifting by min.
value_range = max_value - min_value
converted_skel = (skel * value_range) + min_value

In [8]:
import os

# Output directory for rendered videos. Leave as None to write into the
# current directory; set a path to have it created on demand.
FOLDER_NAME = None
if FOLDER_NAME:
    os.makedirs(FOLDER_NAME, exist_ok=True)
else:
    FOLDER_NAME = "."

# Take the first item of the VQ-VAE reconstruction, move it to CPU, and map it
# back to the original value range.
converted_skel_reconstructed = skel_reconstructed[0].cpu().detach()
converted_skel_reconstructed = (converted_skel_reconstructed * (max_value - min_value)) + min_value
# Use the de-normalized ground truth (`converted_skel`) as the reference so both
# sequences are on the same scale; the original passed the still-normalized `skel`.
# NOTE(review): assumes the 4th argument is a reference sequence drawn alongside
# the first — confirm against plot_skeletons_video.
plot_skeletons_video(converted_skel_reconstructed, FOLDER_NAME, "reconstruction", converted_skel, 1, f"{data_index}")

# Same conversion for the text-conditioned prediction.
converted_sign_prediction = sign_prediction[0].cpu().detach()
converted_sign_prediction = (converted_sign_prediction * (max_value - min_value)) + min_value
# The GPT model is not working yet, so plotting its prediction is disabled.
# plot_skeletons_video(converted_sign_prediction, FOLDER_NAME, "prediction", converted_skel, 1, f"{data_index}")