![logo](assets/logo.png "anim∞")

In [None]:
import torch
from tqdm import tqdm

from model.demo_diffusion import GaussianDiffusion
#from lafan.preprocess import Normalizer
from lafan.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, sk_parents, amass_offsets)
import imageio

import pickle
import onnx
import onnxruntime as ort
from tempfile import TemporaryDirectory
from IPython.display import Image
from pathlib import Path
from matplotlib import pyplot as plt
import matplotlib.animation as animation
import os
import numpy as np

import glob
import random

## Utility functions for processing, transforming, and plotting data.

In [None]:
def splitPosQ(posq, joint_num=22):
    """Split a packed motion array into positions, orientations, phases and contacts.

    Each row of ``posq`` is laid out as::

        [4 contact labels | 3*joint_num positions | 4*joint_num quaternions | phase variables]

    Args:
        posq: 2-D array indexed as ``posq[frame, feature]``.
        joint_num: number of joints packed into every frame (default 22).

    Returns:
        Tuple ``(pos, orn, phase, contact)``:
        ``pos`` reshaped to ``(seq_len, joint_num, 3)``,
        ``orn`` reshaped to ``(seq_len, joint_num, 4)``,
        ``phase`` holding every column after the quaternions (possibly empty),
        ``contact`` as the first four columns thresholded at 0.5.
    """
    pos_start = 4
    orn_start = pos_start + joint_num * 3
    phase_start = orn_start + joint_num * 4

    contact = posq[:, :pos_start] > 0.5
    pos = posq[:, pos_start:orn_start]
    # Stop the quaternion slice where the phase block begins; an open-ended
    # slice would fold phase values into the reshape below as extra "joints".
    orn = posq[:, orn_start:phase_start]
    phase = posq[:, phase_start:]
    pos = pos.reshape((pos.shape[0], -1, 3))
    orn = orn.reshape((orn.shape[0], -1, 4))
    return pos, orn, phase, contact

def global_to_local(orn, x, skeleton_mocap):
    """Convert global joint orientations to local ones and rerun forward kinematics.

    Args:
        orn: global joint orientations for one sequence, handed unchanged to
            ``skeleton_mocap.global_to_local``.
        x: joint positions for the same sequence; ``x[:, 0]`` is used as the
            root position of every frame.
        skeleton_mocap: object providing ``global_to_local`` and
            ``forward_kinematics_with_rotation``.

    Returns:
        Tuple ``(global_pos, local_orn, root_pos_tensor)``: the first output of
        ``forward_kinematics_with_rotation``, the local orientations as
        returned by the skeleton, and the batched root-position tensor.
    """
    # Use the GPU only when one exists; an unconditional .cuda() raises on
    # CPU-only machines even though the notebook selects a CPU fallback device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    root_pos_tensor = torch.Tensor(x[:, 0]).unsqueeze(0).to(device)
    local_orn = skeleton_mocap.global_to_local(orn)
    local_orn_tensor = torch.Tensor(local_orn).unsqueeze(0).to(device)  # 1 x s x j x 4
    global_pos, _ = skeleton_mocap.forward_kinematics_with_rotation(local_orn_tensor, root_pos_tensor)
    return global_pos, local_orn, root_pos_tensor

def set_line_data_3d(line, x):
    """Load the points in ``x`` into a line object that has 3-D properties.

    The first two columns of ``x`` are handed to ``line.set_data`` as a single
    transposed (2 x N) array; the third column goes to
    ``line.set_3d_properties``.
    """
    planar = x[:, :2].T
    depth = x[:, 2]
    line.set_data(planar)
    line.set_3d_properties(depth)
        
def plot_single_pose(
    num, poses, lines, ax, skeleton
):
    """Draw frame ``num`` of a pose sequence on a 3-D axis.

    Used as the per-frame callback of ``animation.FuncAnimation``.

    Args:
        num: index of the frame to draw.
        poses: array indexed as ``poses[frame][joint]`` giving one (x, y, z)
            point per joint.
        lines: one line object per joint; entry ``i`` is updated to the segment
            joining joint ``i`` to its parent. Entry 0 (the root) is skipped.
        ax: 3-D axis whose limits and labels are refreshed for this frame.
        skeleton: object whose ``parents()`` yields each joint's parent index.
    """
    parent_idx = skeleton.parents()

    pose = poses[num]
    for i, (p, line) in enumerate(zip(parent_idx, lines)):
        # don't plot root
        if i == 0:
            continue
        # stack to create a line
        data = np.stack((pose[i], pose[p]), axis=0)
        set_line_data_3d(line, data)

    # Fit a cube around the pose: every axis gets the same half-extent (the
    # largest per-axis range) so the skeleton is not stretched.
    coords = pose[:, :3]
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    extent = upper - lower
    center = lower + extent / 2
    half_span = extent.max() / 2
    x_min, y_min, z_min = center - half_span
    x_max, y_max, z_max = center + half_span

    # Each axis must use the range of the column actually drawn on it
    # (columns 0/1 -> x/y via set_data, column 2 -> z). Crossing the y and z
    # ranges centres the view on the wrong coordinates.
    ax.set_xlim(x_min, x_max)
    ax.set_xlabel("$X$ Axis")

    ax.set_ylim(y_min, y_max)
    ax.set_ylabel("$Y$ Axis")

    ax.set_zlim(z_min, z_max)
    ax.set_zlabel("$Z$ Axis")

    plt.draw()
    
def plot(sample, skeleton_mocap, out):
    """Render every sequence in ``sample`` to an animated GIF inside ``out``.

    Args:
        sample: mapping with ``"labels"`` and ``"samples"`` entries that are
            iterated in lockstep; each sample is a packed array understood by
            ``splitPosQ``.
        skeleton_mocap: skeleton passed to ``global_to_local`` and queried for
            ``parents()`` to decide how many lines to draw.
        out: output directory; created (with parents) if it does not exist.

    Each GIF is written to ``<out>/<label><index>.gif``.
    """
    out_dir = Path(out)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, (label, sequence) in enumerate(zip(sample["labels"], sample["samples"])):
        num_steps = sequence.shape[0]
        pos, orn, _, _ = splitPosQ(sequence)

        # convert orn from global to local, then recompute joint positions
        global_pos, _, _ = global_to_local(orn, pos, skeleton_mocap)
        # Drop only the leading batch axis added by global_to_local; a bare
        # squeeze() would also collapse any other size-1 axis (e.g. a
        # single-frame sequence).
        joints = global_pos.squeeze(0).detach().cpu().numpy()

        # Swap the second and third coordinates before plotting.
        joints[:, :, [1, 2]] = joints[:, :, [2, 1]]

        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")

        # Create lines initially without data
        lines = [ax.plot([], [], [], zorder=10)[0] for _ in skeleton_mocap.parents()]

        anim = animation.FuncAnimation(
            fig, plot_single_pose, num_steps,
            fargs=(joints, lines, ax, skeleton_mocap), interval=1000 // 30)
        gifname = str(out_dir / '{}{}.gif'.format(label, i))
        try:
            anim.save(gifname)
        finally:
            # Close this specific figure (not just the "current" one), even if
            # saving fails, so figures do not pile up across iterations.
            plt.close(fig)

## Constants setup

In [None]:
# find device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Skeleton
offset = sk_offsets
skeleton_mocap = Skeleton(offsets=offset, parents=sk_parents, device=device)
skeleton_mocap.remove_joints(sk_joints_to_remove)

# Per-frame feature layout: contact labels (2 per foot joint), joint positions,
# joint quaternions, then phase variables (none in this demo).
num_phases = 0
num_foot_joints = 2
pos_dim = 22 * 3
rot_dim = 22 * 4
repr_dim = 2 * num_foot_joints + pos_dim + rot_dim + 2 * num_phases

# unfortunately you can't change this, since it's baked into the onnx model
sample_size = 4

# auxiliary trained hyperparameters
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a demo_info.pkl that comes from a source you trust.
with open("demo_info.pkl", "rb") as info_file:
    UNC_LABEL, le, normalizer, horizon, num_labels = pickle.load(info_file)

# trained model, compiled into a fast onnx graph
# Providers are tried in order, so listing the CPU provider last lets the
# session start on machines without a usable CUDA runtime.
model = ort.InferenceSession(
    'unet.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)

diffusion = GaussianDiffusion(
    model,
    horizon,
    repr_dim,
    skeleton=skeleton_mocap,
    predict_epsilon=False,
    unc_token=UNC_LABEL,
    n_timesteps=1000,
    guidance_weight=3,
)
diffusion.to(device)
diffusion.eval()
# generated shape
shape = (sample_size, 1, horizon-1, repr_dim)

## Let's generate a sample!

In [None]:
# pick random semantic labels
y = torch.randint(0, num_labels, (sample_size, 1))
# or, you can set them manually with names from the LAFAN dataset as follows
#y = torch.Tensor(le.transform(["aiming"] * sample_size)).int()
# Flatten to 1-D. Unlike a bare squeeze(), this stays iterable even when there
# is a single sample (squeeze() would give a 0-d tensor).
y = y.reshape(-1)

# Decode the label names once and reuse them for the log line and the output.
labels = [le.inverse_transform([idx.cpu().numpy()])[0] for idx in y]
print("Generating a sample for classes {}".format(labels))

sample = diffusion(shape, y).squeeze(1)
sample = normalizer.unnormalize(sample.cpu())

sample = sample.detach().cpu().numpy()
out = {
    "labels": labels,
    "samples": sample,
    "constraints": None,
}
print("Plotting samples")
plot(out, skeleton_mocap, "out")
torch.cuda.empty_cache()

## Pick a random gif to show*

*Results vary from run to run: the labels are sampled at random, and so is the GIF that gets displayed.

In [None]:
files = glob.glob("out/*.gif")
if not files:
    # Fail with an actionable message instead of random.choice's bare IndexError.
    raise FileNotFoundError("No GIFs found in 'out/'; run the generation cell first.")
# read_bytes() opens and closes the file, so no handle is leaked. The Image
# expression stays last so the notebook displays it.
Image(Path(random.choice(files)).read_bytes())