In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import yaml

In [None]:
# Make the project packages (`data`, `model`) importable from this notebook.
# Guarded so re-running the cell does not append duplicate sys.path entries.
PROJECT_ROOT = "/home/caleml/main-pe/"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from data.datasets.h36m import Human36M
from data.utils.data_utils import TEST_MODE, TRAIN_MODE, VALID_MODE
from data.loader import BatchLoader

In [None]:
from model import config
from model.utils import pose_format, log

# Dataset

In [None]:
# Human3.6M dataset rooted at h36m_path, with the pa17j3d pose layout
# (presumably 17 joints in 3-D, matching n_joints=17 used for the models below — TODO confirm)
# and frame-level topology.
h36m_path = '/home/caleml/datasets/h36m'
h36m = Human36M(
    h36m_path,
    dataconf=config.human36m_dataconf,
    poselayout=pose_format.pa17j3d,
    topology='frames',
)

In [None]:
# Validation-split batch loader over the Human3.6M dataset.
# The two key lists are presumably the input keys (['frame']) and the target keys
# (poses, affine matrix, camera) — confirm against BatchLoader's signature.
# NOTE(review): shuffle=True with no visible seed, so the batch visualised below
# may differ between runs.
data_val_h36m = BatchLoader(
    h36m,
    ['frame'],
    ['pose_w', 'pose_uvd', 'afmat', 'camera'],
    VALID_MODE,
    batch_size=16,
    shuffle=True,
)

# Load models

In [None]:
# Experiment name -> run directory. Each directory is expected to hold a config.yaml
# and weights_<N>* checkpoint files (both are read by the model-loading cell).
# Insertion order matters: models are loaded in this order, which fixes the model
# column order in the comparison figure.
# NOTE(review): cycle_tf2 lives under .../external/ unlike the other two runs.
exps = dict(
    cycle_tf2='/home/caleml/pe_experiments/external/exp_20190517_1748_cycle_reduced_None_2b_bs16',
    cycle_tf12='/home/caleml/pe_experiments/exp_20190517_1940_cycle_reduced_None_2b_bs16',
    cycle_128='/home/caleml/pe_experiments/exp_20190507_1845_cycle_reduced_None_2b_bs16',
)

In [None]:
# Per-experiment [epoch id as a string, score]. Presumably the best validation epoch
# and its error (e.g. MPJPE) — TODO confirm.
# NOTE(review): not referenced by any visible cell; the model-loading cell selects
# the highest-numbered weights_<N> file, not these epochs.
best_data = dict(
    cycle_128=['15', 70.45250463981606],
    cycle_tf12=['7', 71.84316002648272],
    cycle_tf2=['2', 77.12471914392282],
)

In [None]:
from model.networks.cycle_reduced import CycleReduced


def find_latest_weights(model_folder):
    """Return the path of the highest-numbered checkpoint file in ``model_folder``.

    A checkpoint is a file whose name starts with ``weights_``; its number is the
    text between the first ``_`` and the following ``.`` (e.g. ``weights_012.h5`` -> 12).
    Names whose number does not parse as an integer are skipped. Names are
    visited in sorted order so ties resolve deterministically (first name wins).

    Raises:
        FileNotFoundError: if the folder holds no numbered ``weights_`` file.
    """
    latest_file = None
    latest_num = -1  # start below 0 so a checkpoint numbered 0 can still be selected
    for filename in sorted(os.listdir(model_folder)):
        if not filename.startswith('weights_'):
            continue
        try:
            file_id = int(filename.split('_')[1].split('.')[0])
        except ValueError:
            continue  # e.g. 'weights_best.h5': not an epoch-numbered checkpoint
        if file_id > latest_num:
            latest_file = os.path.join(model_folder, filename)
            latest_num = file_id
    if latest_file is None:
        raise FileNotFoundError('No weights_<N> checkpoint found in %s' % model_folder)
    return latest_file


eval_models = dict()

for model_name, model_folder in exps.items():

    # config: the run's own config.yaml supplies the number of pose blocks
    config_file = os.path.join(model_folder, 'config.yaml')
    with open(config_file, 'r') as f_conf:
        model_config = yaml.safe_load(f_conf)

    # checkpoint: highest-numbered weights_<N> file in the run folder
    latest_file = find_latest_weights(model_folder)
    print('Found latest weights %s for model %s' % (latest_file, model_name))

    # build the network with the run's architecture, then load its weights
    eval_model = CycleReduced(dim=3, n_joints=17, nb_pose_blocks=model_config['pose_blocks'])
    eval_model.build()
    eval_model.load_weights(latest_file)

    eval_models[model_name] = eval_model

## Evaluation on a validation batch

In [None]:
# Fetch one batch of validation data and show the shape of its 'frame' entry.
# NOTE(review): the first argument (1) is presumably a batch index — confirm
# against BatchLoader.get_data.
val_data = data_val_h36m.get_data(1, VALID_MODE)
print(val_data['frame'].shape)

In [None]:
# Run every loaded model on the same validation frames. Outputs are keyed by model
# name in the same order as eval_models.
val_preds = {}
for model_name, eval_model in eval_models.items():
    val_preds[model_name] = eval_model.predict(val_data['frame'])

### Figure: input frames vs. each model's `i_hat_mixed` output

In [None]:
# Comparison grid, one row per sample in the batch:
#   col 0: z_a_img = sample i
#   col 1: z_p_img = sample i+1 (wrapping to 0 at the end of the batch)
#   cols 2..: each model's last output (i_hat_mixed) for sample i
batch_size = len(val_data['frame'])
model_names = list(val_preds.keys())

n_cols = 2 + len(model_names)
n_rows = batch_size
# squeeze=False keeps `axes` 2-D even for a single row/column, so axes[i][j] always works
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(12, 50), squeeze=False)

# column titles
col_names = ['z_a_img', 'z_p_img'] + model_names
for ax, cname in zip(axes[0], col_names):
    ax.set_title(cname)

for i in range(batch_size):
    z_a_img = val_data['frame'][i]
    z_p_img = val_data['frame'][(i + 1) % batch_size]

    # z_a image
    axes[i][0].imshow(z_a_img)

    # z_p image
    axes[i][1].imshow(z_p_img)

    # i_hat_mixed, looked up by model name so each image sits under its own column title
    for j, name in enumerate(model_names):
        pred_img = val_preds[name][-1][i]  # i_hat_mixed is the last output
        axes[i][2 + j].imshow(pred_img)

plt.tight_layout()
# Save before showing so the file is written regardless of what the backend does
# with the figure once it has been displayed.
fig.savefig('/home/caleml/main-pe/experiments/cycle_viz_comparison.png')
plt.show()

In [None]:
# Enlarged view of a single i_hat_mixed prediction: the last model column for the last
# sample of the batch (the same image the comparison loop draws last). The image is
# selected explicitly instead of reusing the `pred_img` loop variable. It is drawn on its
# own figure so the comparison cell's `fig` / `axes` are not overwritten.
inspect_name = list(val_preds)[-1]
inspect_idx = batch_size - 1
inspect_img = val_preds[inspect_name][-1][inspect_idx]

fig_inspect, ax_inspect = plt.subplots(figsize=(6, 6))
ax_inspect.imshow(inspect_img)
ax_inspect.set_title('%s, sample %d' % (inspect_name, inspect_idx))
plt.show()