# Predicting from a pre-trained motion-generating model

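### Running the notebook

The notebook needs the following in the working directory:

- the dataset `1004zipped_testSCALED_yseqtest_seqlen256.npz`
- the weights file `DD-SCALED-units-mixed-mixtures3-drop0.2-lr1e-05-seqlen256.h5`

The animation cells are MATLAB cells (MoCap Toolbox). They receive Python variables through the SoS `%get` magic, so a multi-kernel (SoS) setup with a MATLAB kernel is required. The video cells load `<filename>.mp4`, which is presumably produced by the MATLAB `build_animation` helper.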


In [1]:
from tensorflow.compat.v1 import keras
from tensorflow.compat.v1.keras import backend as K
from tensorflow.compat.v1.keras.layers import Dense, Input
import tensorflow.compat.v1 as tf
import numpy as np
import mdn

import csv
from IPython.display import HTML, display

## Load a small dataset to use when generating new motion

In [2]:
# Path to the .npz archive holding the primer sequences (a file, despite the variable name).
datafolder = '1004zipped_testSCALED_yseqtest_seqlen256.npz'

# np.load on an .npz returns an NpzFile that keeps the archive open; using it as a
# context manager closes the file handle once 'x' has been read into memory.
# (A 'y' array was previously read here as well, but nothing below uses it.)
with np.load(datafolder) as loaded:
    x = loaded['x'].astype('float32', copy=False)

print('x.shape: ', x.shape)

x.shape:  (14, 256, 66)


## Load the pre-trained model

In [3]:
HIDDEN_UNITS1 = 1024  # hidden units in the first LSTM layer
HIDDEN_UNITS2 = 512   # hidden units in the second LSTM layer
HIDDEN_UNITS3 = 256   # hidden units in the third LSTM layer
N_MIXES = 3  # number of mixture components
INPUT_DIMS = 66  # 22 joints * 3
OUTPUT_DIMS = 66  # number of real values predicted by each mixture component
SEQ_LEN = 256  # number of frames in an example
lr = 0.00001
opt = keras.optimizers.Adam(learning_rate=lr)
freq = 30  # frame rate of the data (written to the TSV FREQUENCY header)

# Inference model. The LSTM layers are stateful, so the batch size is fixed to 1 via
# batch_input_shape on the FIRST layer only; Keras infers the input shape of later
# layers (lstm_1 receives the 1024-dim output of the first layer, see the summary).
decoder = keras.Sequential()
decoder.add(keras.layers.LSTM(HIDDEN_UNITS1, batch_input_shape=(1, SEQ_LEN, INPUT_DIMS), return_sequences=True, stateful=True))
decoder.add(keras.layers.LSTM(HIDDEN_UNITS2, return_sequences=True, stateful=True))
decoder.add(keras.layers.LSTM(HIDDEN_UNITS3, stateful=True))
decoder.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))
# NOTE(review): presumably forces a float32 output (cf. 'mixed' in model_name) — confirm.
decoder.add(keras.layers.Activation('linear', dtype='float32'))
decoder.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=opt)
decoder.summary()

model_name = 'DD-SCALED-units-mixed-mixtures3-drop0.2-lr1e-05-seqlen256'
decoder.load_weights(model_name + '.h5')  # load weights independently from file

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (1, 256, 1024)            4468736   
_________________________________________________________________
lstm_1 (LSTM)                (1, 256, 512)             3147776   
_________________________________________________________________
lstm_2 (LSTM)                (1, 256)                  787456    
_________________________________________________________________
mdn (MDN)                    (1, 399)                  102543    
_________________________________________________________________
activation (Activation)      (1, 399)                  0         
Total params: 8,506,511
Trainable params: 8,506,511
Non-trainable params: 0
_________________________________________________________________


# Predicting motion

**TODO:**

- add the date to the output filename
- remove the debug print statements from the `mdn` module

In [4]:
def write_ex_to_tsv(ex, fn, frequency=None):
    """Write one motion example to ``fn + '.tsv'``.

    The file has three parts, in order:

    1. Header rows (NO_OF_FRAMES, NO_OF_MARKERS, FREQUENCY, ...).
    2. A MARKER_NAMES row.
    3. One tab-separated row per frame.

    The header layout appears to follow the Qualisys TSV format read by the
    MATLAB MoCap Toolbox — confirm.

    Args:
        ex: 2-D array with one row per frame and 3 columns per marker
            (22 markers -> 66 columns).
        fn: output path without the ``.tsv`` extension.
        frequency: value written to the FREQUENCY header. Defaults to the
            notebook-level ``freq`` (frame rate of the data).

    Raises:
        ValueError: if ``ex`` is not 2-D with 3 columns per marker.
    """
    marker_names = ['Head', 'neck', 'lsho', 'lelb', 'lwri', 'lhan', 'rsho', 'relb', 'rwri', 'rhan', 't10', 'root',
                    'lhip', 'lknee', 'lank', 'lfoot', 'ltoe', 'rhip', 'rknee', 'rank', 'rfoot', 'rtoe']
    num_markers = len(marker_names)

    # A wrongly shaped array would otherwise produce a silently malformed file.
    if ex.ndim != 2 or ex.shape[1] != 3 * num_markers:
        raise ValueError(
            'expected an array of shape (num_frames, {}), got {}'.format(3 * num_markers, ex.shape))

    if frequency is None:
        frequency = freq  # notebook-level frame rate

    num_frames = ex.shape[0]
    header_rows = [
        ['NO_OF_FRAMES', num_frames],
        ['NO_OF_CAMERAS', 0],
        ['NO_OF_MARKERS', num_markers],
        ['FREQUENCY', frequency],
        ['NO_OF_ANALOG', 0],
        ['ANALOG_FREQUENCY', 0],
        ['DESCRIPTION--', ''],
        ['TIME_STAMP--', ''],
        ['DATA_INCLUDED', '3D'],
        ['MARKER_NAMES'] + marker_names,
    ]

    # newline='' is required by the csv module: otherwise the writer's '\r\n'
    # row terminator is translated to '\r\r\n' on Windows.
    with open(fn + '.tsv', 'w', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        tsv_writer.writerows(header_rows)
        tsv_writer.writerows(ex)  # one row per frame

            
# Helpers for predicting marker positions from a sequence of previous frames and saving the result as TSV.
def shift(arr, num, fill_value=np.nan):
    """Return a copy of ``arr`` shifted by ``num`` positions along axis 0.

    A positive ``num`` moves elements toward higher indices, and a negative
    ``num`` moves them toward lower indices. The slots vacated at the leading
    or trailing edge are filled with ``fill_value``.
    """
    shifted = np.empty_like(arr)

    # No shift requested: plain copy.
    if num == 0:
        shifted[:] = arr
        return shifted

    if num > 0:
        # Vacated slots are at the front.
        shifted[:num] = fill_value
        shifted[num:] = arr[:-num]
    else:
        # Vacated slots are at the back.
        shifted[num:] = fill_value
        shifted[:num] = arr[-num:]
    return shifted


def predict_sequence(model, pi=1e-5, sigma=1e-5, frames=256, primer_idx=0, select_mix=False, use_priming=False, mix=0, fn_prefix='100420'):
    """Generate motion frame by frame with an MDN model and save it as a TSV file.

    A sliding window of ``SEQ_LEN`` frames, initialised from ``x[primer_idx]``,
    is fed to ``model``. Each step samples one frame from the MDN output and
    shifts the window left by one frame.

    Args:
        model: Keras MDN model to sample from. Its input is reshaped to
            (1, SEQ_LEN, OUTPUT_DIMS), and it must support ``reset_states()``.
        pi: passed to mdn as ``temp`` (labelled 'pi_temp' in the filename).
        sigma: passed to mdn as ``sigma_temp`` (labelled 'sig_temp' in the filename).
        frames: number of frames to generate.
        primer_idx: index into the global ``x`` of the starting window.
        select_mix: if True, sample via ``mdn.sample_from_output_select_mix``
            with component ``mix``.
        use_priming: if True, advance the window with ground-truth frames
            from ``x`` instead of the model's own samples.
        mix: mixture component used when ``select_mix`` is True. It is also
            written into the filename.
        fn_prefix: prefix of the output filename (default keeps the
            previously hard-coded '100420').

    Returns:
        The output filename without the ``.tsv`` extension.
    """
    # The LSTM layers are stateful, so state left over from a previous call would
    # otherwise leak into this sequence; start every generation from a clean state.
    model.reset_states()

    motion = []
    idx = primer_idx
    pred_on = x[idx, :, :]  # starting window of SEQ_LEN frames

    for i in range(frames):
        reshaped_pred_on = tf.reshape(pred_on, [1, SEQ_LEN, OUTPUT_DIMS])
        params = model.predict(reshaped_pred_on, steps=1)

        if select_mix:
            pred = mdn.sample_from_output_select_mix(params[0], OUTPUT_DIMS, N_MIXES, temp=pi, sigma_temp=sigma, mix=mix)
        else:
            pred = mdn.sample_from_output(params[0], OUTPUT_DIMS, N_MIXES, temp=pi, sigma_temp=sigma)

        motion.append(pred.reshape((OUTPUT_DIMS,)))

        if use_priming:
            # Advance the window with ground-truth frames, moving to the next example
            # in `x` every SEQ_LEN steps (IndexError once idx passes the last example).
            # NOTE(review): at i == 0 this appends x[idx, 0], the first frame of the
            # current window, so the window first cycles through x[primer_idx] itself;
            # confirm this priming schedule is intended.
            if i != 0 and i % SEQ_LEN == 0:
                idx += 1
            pred_on = shift(pred_on, -1, fill_value=x[idx, i % SEQ_LEN, :])
        else:
            # Autoregressive: the sampled frame becomes the newest frame of the window.
            pred_on = shift(pred_on, -1, fill_value=pred)

    motion = np.array(motion)
    fn = fn_prefix + '-pi_temp-' + str(pi) + '-sig_temp-' + str(sigma) + "-mix-" + str(mix) + "-primer_idx-" + str(primer_idx)

    if use_priming:
        fn = fn + 'PRIMING'
    if select_mix:
        fn = fn + 'MIX'

    print('Generated motion ', motion.shape, 'with filename ', fn)
    write_ex_to_tsv(motion, fn)

    return fn

In [5]:
def display_animation(var_film):
    """Embed ``var_film + '.mp4'`` in the notebook output as an HTML5 video player."""
    video_src = var_film + '.mp4'
    player_template = "<div align='middle'><video width='80%' controls><source src='{href}' type='video/mp4'></video></div>"

    # Render the filled-in template as rich HTML output of the cell.
    display(HTML(player_template.format(href=video_src)))

## Sampling from different mixture components

In [6]:
# Sample 512 frames, always drawing from mixture component 1.
# The mdn sampler prints debug lines for every frame (see the TODO above).
fn1 = predict_sequence(model=decoder, frames=512, select_mix=True, mix=1)


[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (pis):  1
[0. 1. 0.]
m sampled from categorical (p

m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0.

m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
[0.

m sampled from categorical (pis):  2
[0. 0. 1.]
m sampled from categorical (pis):  2
Generated motion  (512, 66) with filename  100420-pi_temp-1e-05-sig_temp-1e-05-mix-1-primer_idx-0MIX


### Animating in MATLAB
Using the MoCap Toolbox.

In [10]:
%get fn1
build_animation(fn1);  % helper defined outside this notebook; presumably renders fn1.tsv to fn1.mp4 — confirm


fn =

    '100420-pi_temp-1e-05-sig_temp-1e-05-mix-1-primer_idx-0MIX.tsv'



In [9]:
print(fn1)  # echo the base filename of the generated motion
display_animation(var_film=fn1)

100420-pi_temp-1e-05-sig_temp-1e-05-mix-1-primer_idx-0MIX


## Sampling with temperature adjustment

In [11]:
# Sample 512 frames with raised mixture (pi) and sigma temperatures.
fn2 = predict_sequence(model=decoder, frames=512, pi=1e-3, sigma=1e-2)

Generated motion  (512, 66) with filename  100420-pi_temp-0.001-sig_temp-0.01-mix-0-primer_idx-0


In [12]:
%get fn2
build_animation(fn2);  % helper defined outside this notebook; presumably renders fn2.tsv to fn2.mp4 — confirm


fn =

    '100420-pi_temp-0.001-sig_temp-0.01-mix-0-primer_idx-0.tsv'



In [13]:
display_animation(var_film=fn2)  # embed the temperature-adjusted sample

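## Sampling with priming
The model is primed on a motion example: the input window is advanced with ground-truth frames from the dataset instead of the model's own samples.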

In [14]:
# Generate 512 frames while feeding ground-truth frames starting from example 0.
fn3 = predict_sequence(model=decoder, frames=512, use_priming=True, primer_idx=0)

Generated motion  (512, 66) with filename  100420-pi_temp-1e-05-sig_temp-1e-05-mix-0-primer_idx-0PRIMING


In [15]:
%get fn3
build_animation(fn3);  % helper defined outside this notebook; presumably renders fn3.tsv to fn3.mp4 — confirm


fn =

    '100420-pi_temp-1e-05-sig_temp-1e-05-mix-0-primer_idx-0PRIMING.tsv'



## Displaying animations

In [19]:
display_animation(var_film=fn3)  # embed the primed sample