In [19]:
import torch
import torch.nn as nn
import numpy as np
import os
from PoseTransformer import PoseTransformer, CrossModalTransformer
from NoiseScheduler import NoiseScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from utils.motion_process import recover_from_ric

class ResidualMLPBlock(nn.Module):
    """Two-layer SiLU MLP with an identity skip connection: ``y = x + W2(SiLU(W1(x)))``.

    Input and output widths are both ``dim``. The MLP is kept in a ``net``
    ``nn.Sequential`` so parameter names (``net.0.*``, ``net.2.*``) match the
    saved ``noise_predictor`` checkpoint.
    """

    def __init__(self, dim):
        super().__init__()
        expand = nn.Linear(dim, dim)
        activation = nn.SiLU()
        project = nn.Linear(dim, dim)
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        residual = self.net(x)
        return x + residual

class Encoder(nn.Module):
    """MLP encoder: ``input_dim -> 256 -> 128 -> latent_dim`` with ReLU between layers.

    Must stay identical to the definition in the AE training script so the saved
    ``net.*`` parameters load into it.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        widths = (input_dim, 256, 128, latent_dim)
        layers = []
        for position, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                # ReLU between consecutive Linear layers, none after the last one.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(width_in, width_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    """MLP decoder: ``latent_dim -> 128 -> 256 -> output_dim`` with ReLU between layers.

    Must stay identical to the definition in the AE training script so the saved
    ``net.*`` parameters load into it.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        widths = (latent_dim, 128, 256, output_dim)
        layers = []
        for position, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                # ReLU between consecutive Linear layers, none after the last one.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(width_in, width_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Autoencoder(nn.Module):
    """MLP autoencoder: ``Encoder`` to ``latent_dim``, then ``Decoder`` back to ``input_dim``.

    The halves are stored as ``encoder`` / ``decoder`` so checkpoint keys
    (``encoder.net.*``, ``decoder.net.*``) load, and so callers can use
    ``autoencoder.decoder`` on its own.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)

    def forward(self, x):
        """Return the reconstruction of ``x``; the latent code itself is not returned."""
        return self.decoder(self.encoder(x))
    
MODEL_PATH = "model_saves/v7.0.0_TT_datapercent0.87_lr0.0002_WD0.1_P8C8_CFG0.1_ACC128"
# Built with os.path.join: a hard-coded backslash-separated literal only resolves on Windows.
AE_CHECKPOINT_PATH = os.path.join("model_saves", "autoencoder", "pose_ae_best.pth")
EMBEDDING_DIM = 512
POSE_FEATURES_DIM = 263
CLIP_MAX_LENGTH = 77
LATENT_DIM = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Validate both checkpoint paths up front so a bad path fails before any model is
# built or the CLIP weights are fetched.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found at: {MODEL_PATH}")
if not os.path.exists(AE_CHECKPOINT_PATH):
    raise FileNotFoundError(f"Autoencoder checkpoint not found at: {AE_CHECKPOINT_PATH}")

# Load trained models
print("Loading models...")
# Diffusion backbone; it operates on autoencoder latents (pose_dim=LATENT_DIM).
pose_transformer = PoseTransformer(
    pose_dim=LATENT_DIM,
    embedding_dim=EMBEDDING_DIM,
    num_heads=8,
    num_layers=8,
    dropout=0,
    use_decoder=False
).to(device)

# Fuses pose embeddings with CLIP text hidden states, which are passed in as memory.
text_cross_transformer = CrossModalTransformer(
    pose_dim=EMBEDDING_DIM,
    memory_dim=EMBEDDING_DIM,
    embedding_dim=EMBEDDING_DIM,
    num_heads=8,
    num_layers=8,
    dropout=0,
    use_decoder=True
).to(device)

# Head mapping fused embeddings to a LATENT_DIM noise estimate.
noise_predictor = nn.Sequential(
    ResidualMLPBlock(EMBEDDING_DIM),
    nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM),
    nn.SiLU(),
    nn.Linear(EMBEDDING_DIM, LATENT_DIM),
).to(device)

autoencoder = Autoencoder(POSE_FEATURES_DIM, LATENT_DIM).to(device)
# NOTE(review): weights_only is not forced here (unlike the main checkpoint below) because
# the AE checkpoint also stores metadata whose types are not visible from this notebook;
# torch emits a FutureWarning for this call. Switch to weights_only=True once confirmed safe.
ae_checkpoint = torch.load(AE_CHECKPOINT_PATH, map_location=device)
# Verify dimensions match
if ae_checkpoint['input_dim'] != POSE_FEATURES_DIM:
    raise ValueError(f"AE Checkpoint input_dim ({ae_checkpoint['input_dim']}) doesn't match expected ({POSE_FEATURES_DIM})")
if ae_checkpoint['latent_dim'] != LATENT_DIM:
    raise ValueError(f"AE Checkpoint latent_dim ({ae_checkpoint['latent_dim']}) doesn't match expected ({LATENT_DIM})")

autoencoder.load_state_dict(ae_checkpoint['model_state_dict'])
# The autoencoder is only used frozen at inference time.
for param in autoencoder.parameters():
    param.requires_grad = False

# CLIP Model and Tokenizer
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# weights_only=True restricts unpickling to tensors/primitive containers (and avoids torch's warning).
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)

pose_transformer.load_state_dict(checkpoint["pose_transformer"])
text_cross_transformer.load_state_dict(checkpoint["text_cross_transformer"])
noise_predictor.load_state_dict(checkpoint["noise_predictor"])
print("Model weights loaded successfully.")

# Inference only: disable dropout / training-mode behaviour everywhere.
pose_transformer.eval()
text_cross_transformer.eval()
noise_predictor.eval()
clip_text_model.eval()
autoencoder.eval()

noise_scheduler = NoiseScheduler(timesteps=1000)

Using device: cuda
Loading models...


  ae_checkpoint = torch.load("model_saves\\autoencoder\\pose_ae_best.pth", map_location=device)


Model weights loaded successfully.


In [20]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation, FFMpegFileWriter
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import mpl_toolkits.mplot3d.axes3d as p3

def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
    """Render a skeleton motion sequence to an animation file.

    Args:
        save_path: Destination passed to ``FuncAnimation.save``; its parent
            directory is created if missing.
        kinematic_tree: Sequence of joint-index chains; each chain is drawn as one
            polyline. Only as many chains as there are entries in ``colors`` (15)
            are drawn, and the first five are drawn thicker.
        joints: Array reshaped to ``(seq_len, num_joints, 3)``; coordinate 1 is
            treated as height (shifted so the sequence minimum sits at 0).
        title: Figure title; wrapped onto two lines when it has more than 10 words.
        figsize: Matplotlib figure size.
        fps: Frame rate for both the animation interval and the saved file.
        radius: Fixed axis extents: x in ``[-radius/2, radius/2]``, y and z in ``[0, radius]``.

    Side effects:
        Switches matplotlib to the 'Agg' backend (kernel-wide) and writes ``save_path``.
    """
    import os

    # Non-interactive backend: frames only need to be rendered into the saved file.
    matplotlib.use('Agg')

    title_words = title.split(' ')
    if len(title_words) > 10:
        title = '\n'.join([' '.join(title_words[:10]), ' '.join(title_words[10:])])

    def set_axis_limits():
        # Fixed extents. ax.cla() in update() resets them, so they are re-applied every
        # frame; otherwise each frame would be autoscaled to its own data.
        ax.set_xlim3d([-radius / 2, radius / 2])
        ax.set_ylim3d([0, radius])
        ax.set_zlim3d([0, radius])

    def init():
        set_axis_limits()
        fig.suptitle(title, fontsize=20)
        # Positional form works across matplotlib versions: the keyword was `b`
        # before 3.5 and `visible` afterwards, and newer releases reject `b=`.
        ax.grid(False)

    def plot_xzPlane(minx, maxx, miny, minz, maxz):
        # Translucent grey floor quad at height `miny` spanning [minx, maxx] x [minz, maxz].
        verts = [
            [minx, miny, minz],
            [minx, miny, maxz],
            [maxx, miny, maxz],
            [maxx, miny, minz],
        ]
        xz_plane = Poly3DCollection([verts])
        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
        ax.add_collection3d(xz_plane)

    # (seq_len, joints_num, 3); copied so the in-place shifts below do not touch the caller's array.
    data = joints.copy().reshape(len(joints), -1, 3)
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=figsize)

    init()
    MINS = data.min(axis=0).min(axis=0)
    MAXS = data.max(axis=0).max(axis=0)
    # One colour per chain; zip() below stops at the shorter of kinematic_tree / colors.
    colors = ['red', 'blue', 'black', 'red', 'blue',
              'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
              'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
    frame_number = data.shape[0]

    # Put the lowest point of the whole sequence on the floor (height 0).
    height_offset = MINS[1]
    data[:, :, 1] -= height_offset
    # Root (joint 0) x/z path. Fancy indexing copies, so the centring below leaves it intact.
    trajec = data[:, 0, [0, 2]]

    # Centre every frame on its root in x/z; floor and trajectory are shifted instead.
    data[..., 0] -= data[:, 0:1, 0]
    data[..., 2] -= data[:, 0:1, 2]

    def update(index):
        ax.cla()
        set_axis_limits()
        ax.view_init(elev=120, azim=-90)
        # NOTE: Axes3D.dist is deprecated since matplotlib 3.6 and may be ignored by newer releases.
        ax.dist = 7.5
        # Floor and past trajectory are offset by the current root position.
        plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0,
                     MINS[2] - trajec[index, 1], MAXS[2] - trajec[index, 1])

        if index > 1:
            ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
                      trajec[:index, 1] - trajec[index, 1], linewidth=1.0, color='blue')

        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
            linewidth = 4.0 if i < 5 else 2.0
            ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2],
                      linewidth=linewidth, color=color)

        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)

    # Make sure the destination folder exists so saving does not fail on a fresh checkout.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    ani.save(save_path, fps=fps)
    # Close this figure specifically rather than whichever figure happens to be current.
    plt.close(fig)


In [None]:
@torch.no_grad()
def infer(text_input, seq_length=60, batch_size=1, guidance_scale=7):
    """Sample latent motion for one prompt with classifier-free-guided DDPM sampling.

    The prompt is repeated ``batch_size`` times and encoded with CLIP. Starting
    from Gaussian noise, ``noise_scheduler.timesteps`` reverse-diffusion steps
    are then run in the autoencoder's latent space.

    Args:
        text_input: Prompt string.
        seq_length: Number of latent frames per sample.
        batch_size: Number of independent samples drawn for the prompt.
        guidance_scale: CFG weight ``w`` in ``eps_u + w * (eps_c - eps_u)``;
            ``w == 1`` is purely conditional, ``w == 0`` purely unconditional.

    Returns:
        CPU tensor of shape ``(batch_size, seq_length, LATENT_DIM)`` holding the
        final (t = 0) latents.
    """
    prompts = [text_input] * batch_size

    # Pad to the longest prompt in this batch. With max_length=None the tokenizer
    # presumably truncates at its own model maximum (HF default); confirm.
    encoded = clip_tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=None,
    ).to(device)
    token_ids = encoded.input_ids
    token_mask = encoded.attention_mask
    # True marks padding tokens that cross-attention should ignore.
    text_pad_mask = (token_mask == 0)

    cond_memory = clip_text_model(
        input_ids=token_ids,
        attention_mask=token_mask,
        return_dict=True,
    ).last_hidden_state
    # All-zero memory stands in for the "no prompt" branch of classifier-free guidance.
    uncond_memory = torch.zeros_like(cond_memory)

    x_t = torch.randn((batch_size, seq_length, LATENT_DIM), device=device)
    # NOTE(review): all zeros; whether 0 means "valid frame" is defined by PoseTransformer; confirm.
    pose_mask = torch.zeros((batch_size, seq_length), device=device, dtype=torch.float32)

    print("Starting denoising loop...")
    print(x_t.shape, pose_mask.shape)

    for t in reversed(range(noise_scheduler.timesteps)):
        step = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # The pose encoding does not see the text, so both CFG branches share it.
        pose_embeddings = pose_transformer(x_t, pose_mask=pose_mask, timesteps=step)

        eps_uncond = noise_predictor(
            text_cross_transformer(pose_embeddings, uncond_memory, pose_mask, memory_mask=None)
        )
        eps_cond = noise_predictor(
            text_cross_transformer(pose_embeddings, cond_memory, pose_mask, memory_mask=text_pad_mask)
        )
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        beta_t = noise_scheduler.betas[t].to(device)
        alpha_t = noise_scheduler.alphas[t].to(device)
        sqrt_one_minus_alpha_bar_t = noise_scheduler.sqrt_one_minus_alphas_cumprod[t].to(device)

        # DDPM mean of x_{t-1}: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (1.0 / torch.sqrt(alpha_t)) * (x_t - (beta_t / sqrt_one_minus_alpha_bar_t) * eps)

        if t == 0:
            # Final step is deterministic.
            x_t = mean
        else:
            # sigma_t^2 = beta_t
            x_t = mean + torch.sqrt(beta_t) * torch.randn_like(x_t)

    return x_t.cpu()


In [22]:
# Example usage: draw `batch_size` latent sequences for one prompt, then map every
# latent frame back to pose-feature space with the frozen autoencoder decoder.
text_prompt = "a man walks forward in a straight line."
seq_length = 50
batch_size = 3
generated_latent_poses = infer(
    text_prompt,
    seq_length=seq_length,
    batch_size=batch_size,
    guidance_scale=3,
)
# Flatten (batch, time) into one axis so the Linear-stack decoder sees 2-D input.
latent_poses_flat = generated_latent_poses.view(-1, LATENT_DIM).to(device)
with torch.no_grad():  # decoder is frozen; no autograd graph needed
    poses_flat = autoencoder.decoder(latent_poses_flat)

Starting denoising loop...
torch.Size([3, 50, 64]) torch.Size([3, 50])


In [23]:
# Undo the mean/std normalisation of the decoded pose features.
poses = poses_flat.view(batch_size, seq_length, POSE_FEATURES_DIM).cpu()
# Read both statistics from one open archive; the `with` closes the NpzFile handle.
with np.load("pose_stats.npz") as pose_stats:
    # Assumes the stats are per-feature vectors (POSE_FEATURES_DIM,); confirm.
    # They then broadcast over (batch, seq).
    mean = torch.tensor(pose_stats['mean'], dtype=torch.float32).unsqueeze(0)
    std = torch.tensor(pose_stats['std'], dtype=torch.float32).unsqueeze(0)
poses = poses * std + mean

In [24]:
import matplotlib.pyplot as plt
import numpy as np
from utils.utils import *

%matplotlib inline

joints_num = 22
kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
# print(example_data_ml3d[0:5])

print(poses.shape)

for idx, sequence in enumerate(poses):
    joint = recover_from_ric(sequence.float(), joints_num).numpy()
    joint = motion_temporal_filter(joint, sigma=1)
    plot_3d_motion(f"output/AE_CFG3_test_ani2.{idx+1}.mp4", kinematic_chain, joint, title=text_prompt, fps=10)


torch.Size([3, 50, 263])
