In [1]:
# python train.py \
# --config configs/pose3d/MB_train_h36m.yaml \
# --evaluate checkpoint/pose3d/MB_train_h36m/best_epoch.bin         

In [2]:
import getpass
user = getpass.getuser()
motionbert_root = '/home/{}/codes/MotionBERT'.format(user)

In [3]:
import os
import numpy as np
import argparse
import errno
import math
import pickle
import tensorboardX
from tqdm import tqdm
from time import time
import copy
import random
import prettytable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

os.chdir(motionbert_root)

from lib.utils.tools import *
from lib.utils.learning import *
from lib.utils.utils_data import flip_data
from lib.data.dataset_motion_2d import PoseTrackDataset2D, InstaVDataset2D
from lib.data.dataset_motion_3d import MotionDataset3D
from lib.data.augmentation import Augmenter2D
from lib.data.datareader_aihub import DataReaderAIHUB
from lib.model.loss import *

from train import set_random_seed, save_checkpoint

In [4]:
config = 'MB_ft_tr_aihub_sport_ts_30'
model_name = 'FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30'
#model_name = 'MB_train_h36m'

In [5]:
import easydict

opts = easydict.EasyDict({
    "config": "configs/pose3d/{}.yaml".format(config),
    "checkpoint": 'checkpoint',
    "pretrained": 'checkpoint',
    "resume": '',
    "evaluate": 'checkpoint/pose3d/{}/best_epoch.bin'.format(model_name),
    "selection": 'best_epoch.bin',
    "seed": 0,
    })
set_random_seed(opts.seed)
args = get_config(opts.config)

In [6]:
try:
    os.makedirs(opts.checkpoint)
except OSError as e:
    if e.errno != errno.EEXIST:
        raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint)
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))

In [7]:
args.batch_size

16

In [8]:
args.subset_list

['AIHUB_tr_SPORT_ts_30']

In [9]:
print('Loading dataset...')
trainloader_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': 12,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
}

testloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 12,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
}

train_dataset = MotionDataset3D(args, args.subset_list, 'train')
test_dataset = MotionDataset3D(args, args.subset_list, 'test')
train_loader_3d = DataLoader(train_dataset, **trainloader_params)
test_loader = DataLoader(test_dataset, **testloader_params)

Loading dataset...


In [10]:
datareader = DataReaderAIHUB(n_frames=args.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=args.clip_len, dt_root = 'data/motion3d', dt_file=args.dt_file)

In [11]:
min_loss = 100000
model_backbone = load_backbone(args)
model_params = 0
for parameter in model_backbone.parameters():
    model_params = model_params + parameter.numel()
print('INFO: Trainable parameter count:', model_params)

if torch.cuda.is_available():
    model_backbone = nn.DataParallel(model_backbone)
    model_backbone = model_backbone.cuda()

INFO: Trainable parameter count: 42466317


In [12]:
args.finetune, opts.resume, opts.evaluate

(True,
 '',
 'checkpoint/pose3d/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30/best_epoch.bin')

In [13]:
chk_filename = opts.evaluate if opts.evaluate else opts.resume
print('Loading checkpoint', chk_filename)
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
model_backbone.load_state_dict(checkpoint['model_pos'], strict=True)
model_pos = model_backbone

Loading checkpoint checkpoint/pose3d/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30/best_epoch.bin


In [14]:
args.partial_train

In [15]:
opts.evaluate

'checkpoint/pose3d/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30/best_epoch.bin'

#### evaluate

In [16]:
# args, model_pos, test_loader, datareader

In [17]:
args.no_conf, args.flip, args.rootrel, args.gt_2d

(False, True, True, False)

In [18]:
args.flip = True

In [19]:
torch.cuda.is_available()

True

In [20]:
results_all = []
model_pos.eval()            
with torch.no_grad():
    for batch_input, batch_gt in tqdm(test_loader):
        N, T = batch_gt.shape[:2] # B, N
        if torch.cuda.is_available():
            batch_input = batch_input.cuda()
        if args.flip:    
            batch_input_flip = flip_data(batch_input)
            predicted_3d_pos_1 = model_pos(batch_input)
            predicted_3d_pos_flip = model_pos(batch_input_flip)
            predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip)                   # Flip back
            predicted_3d_pos = (predicted_3d_pos_1+predicted_3d_pos_2) / 2
        else:
            predicted_3d_pos = model_pos(batch_input)
        results_all.append(predicted_3d_pos.cpu().numpy())
results_all = np.concatenate(results_all)
results_all = datareader.denormalize(results_all)

100%|██████████| 1/1 [00:02<00:00,  2.03s/it]


In [21]:
results_all.shape

(11, 243, 17, 3)

In [22]:
np.save('custom_codes/evaluation/{}_result_denormalized.npy'.format(model_name), results_all)

In [23]:
results_all = np.load('custom_codes/evaluation/{}_result_denormalized.npy'.format(model_name))

_, split_id_test = datareader.get_split_id()
actions = np.array(datareader.dt_dataset['test']['action'])
factors = np.array(datareader.dt_dataset['test']['2.5d_factor'])
gts = np.array(datareader.dt_dataset['test']['joints_2.5d_image'])
sources = np.array(datareader.dt_dataset['test']['source'])

num_test_frames = len(actions)
frames = np.array(range(num_test_frames))
action_clips = np.array([actions[split_id_test[i]] for i in range(len(split_id_test))]) # actions[split_id_test]
factor_clips = np.array([factors[split_id_test[i]] for i in range(len(split_id_test))]) # factors[split_id_test]
source_clips = np.array([sources[split_id_test[i]] for i in range(len(split_id_test))]) # sources[split_id_test]
frame_clips = np.array([frames[split_id_test[i]] for i in range(len(split_id_test))]) # frames[split_id_test]
gt_clips = np.array([gts[split_id_test[i]] for i in range(len(split_id_test))]) # gts[split_id_test]
assert len(results_all)==len(action_clips)

e1_all = np.zeros(num_test_frames)
e2_all = np.zeros(num_test_frames)
oc = np.zeros(num_test_frames)
action_names = sorted(set(datareader.dt_dataset['test']['action']))
block_list = ['s_09_act_05_subact_02', 
                's_09_act_10_subact_02', 
                's_09_act_13_subact_01']

for idx in range(len(action_clips)):
    source = source_clips[idx][0]
    if source in block_list:
        continue
    frame_list = frame_clips[idx] # numpy.ndarray
    action = action_clips[idx][0]
    factor = factor_clips[idx][:,None,None]
    gt = gt_clips[idx]
    pred = copy.deepcopy(results_all[idx])
    pred *= factor
    
    # Root-relative Errors
    pred = pred - pred[:,0:1,:] # (243, 17, 3)
    gt = gt - gt[:,0:1,:] # (243, 17, 3)
    err1 = mpjpe(pred, gt) # (243,)
    err2 = p_mpjpe(pred, gt) # (243,)
    e1_all[frame_list] += err1 # numpy.ndarray를 인덱스로 사용 가능
    e2_all[frame_list] += err2
    oc[frame_list] += 1 # 프레임별 카운팅

In [24]:
results = {}
results_procrustes = {}

for action in action_names:
    results[action] = []
    results_procrustes[action] = []

for idx in range(num_test_frames):
    if e1_all[idx] > 0:
        err1 = e1_all[idx] / oc[idx]
        err2 = e2_all[idx] / oc[idx]
        action = actions[idx]
        results[action].append(err1)
        results_procrustes[action].append(err2)

final_result = []
final_result_procrustes = []
summary_table = prettytable.PrettyTable()
summary_table.field_names = ['test_name'] + action_names
for action in action_names:
    final_result.append(np.mean(results[action]))
    final_result_procrustes.append(np.mean(results_procrustes[action]))
summary_table.add_row(['P1'] + final_result)
summary_table.add_row(['P2'] + final_result_procrustes)
print(summary_table)
e1 = np.mean(np.array(final_result))
e2 = np.mean(np.array(final_result_procrustes))
print('Protocol #1 Error (MPJPE):', e1, 'mm')
print('Protocol #2 Error (P-MPJPE):', e2, 'mm')
print('----------')

+-----------+--------------------+
| test_name |         30         |
+-----------+--------------------+
|     P1    | 67.40624741414243  |
|     P2    | 46.555171046777424 |
+-----------+--------------------+
Protocol #1 Error (MPJPE): 67.40624741414243 mm
Protocol #2 Error (P-MPJPE): 46.555171046777424 mm
----------


### Visualization

In [25]:
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.gridspec as gridspec
os.chdir(motionbert_root)
from custom_codes.test_utils import *

plt.switch_backend('TkAgg')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

In [26]:
pred.shape, gt.shape

((243, 17, 3), (243, 17, 3))

In [27]:
pred[0].shape

(17, 3)

In [28]:
frame = 200
visualize_3d_pose([pred[frame], gt[frame]])

In [29]:
source_clips.shape, frame_clips.shape

((11, 243), (11, 243))

In [30]:
source_clips[-1][frame], frame_clips[-1][frame]

('res_30_F170D_5', 3142)

### Visualize one clip

In [31]:
# AIHUB_tr_SPORT_ts_30/test/00000010.pkl
idx = 10
factor = factor_clips[idx][:,None,None]
gt = copy.deepcopy(gt_clips[idx])
pred = copy.deepcopy(results_all[idx])
gt /= factor
pred = pred - pred[:,0:1,:] # (243, 17, 3)
gt = gt - gt[:,0:1,:] # (243, 17, 3)

#### save frames

In [37]:
xlim=(-512, 512)
ylim=(-512, 512)
zlim=(-512, 512)
fig = plt.figure(0, figsize=(10, 10))
ax = plt.axes(projection="3d")
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
for i in tqdm(range(pred.shape[0])):
    _ax = copy.deepcopy(ax)
    _ax.view_init(elev=12., azim=80+i)
    visualize_multiple_3d_pose([pred[i], gt[i]], _ax, save=True, save_path='./custom_codes/evaluation/{}_idx{}_result'.format(model_name, idx), name='{}.jpg'.format(i), i=i)

100%|██████████| 243/243 [00:59<00:00,  4.10it/s]


#### make the video

In [49]:
import imageio
from natsort import natsorted

img_list = natsorted(os.listdir('./custom_codes/evaluation/{}_idx{}_result'.format(model_name, idx)))
videowriter = imageio.get_writer('./custom_codes/evaluation/{}_idx{}_result/video.mp4'.format(model_name, idx), fps=30)

for img in img_list:
    img_path = os.path.join('./custom_codes/evaluation/{}_idx{}_result'.format(model_name, idx), img)
    print(img_path)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    videowriter.append_data(img)
videowriter.close()



./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/0.jpg




./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/1.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/2.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/3.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/4.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/5.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/6.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/7.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/8.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/9.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/10.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aihub_sport_ts_30_idx10_result/11.jpg
./custom_codes/evaluation/FT-MB_ft_h36m-MB_ft_tr_aih

### D3DP

In [27]:
# pip install timm einops
d3dp_root = "/home/hrai/codes/D3DP"
os.chdir(d3dp_root)
from common.diffusionpose import *

In [28]:
import argparse

def d3dp_parse_args():
    parser = argparse.ArgumentParser(description='Training script')

    # General arguments
    parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
    parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
    parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST',
                        help='training subjects separated by comma')
    parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
    parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST',
                        help='unlabeled subjects separated by comma for self-supervision')
    parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
                        help='actions to train/test on, separated by comma, or * for all')
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='checkpoint directory')
    parser.add_argument('-l', '--log', default='log/default', type=str, metavar='PATH',
                        help='log file directory')
    parser.add_argument('-cf','--checkpoint-frequency', default=20, type=int, metavar='N',
                        help='create a checkpoint every N epochs')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
                        help='checkpoint to resume (file name)')
    parser.add_argument('--nolog', action='store_true', help='forbiden log function')
    parser.add_argument('--evaluate', default='h36m_best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('--render', action='store_true', help='visualize a particular video')
    parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
    parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')


    # Model arguments
    parser.add_argument('-s', '--stride', default=243, type=int, metavar='N', help='chunk size to use during training')
    parser.add_argument('-e', '--epochs', default=400, type=int, metavar='N', help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=11, type=int, metavar='N', help='batch size in terms of predicted frames')
    parser.add_argument('-drop', '--dropout', default=0., type=float, metavar='P', help='dropout probability')
    parser.add_argument('-lr', '--learning-rate', default=0.00006, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('-lrd', '--lr-decay', default=0.993, type=float, metavar='LR', help='learning rate decay per epoch')
    parser.add_argument('--coverlr', action='store_true', help='cover learning rate with assigned during resuming previous model')
    parser.add_argument('-mloss', '--min_loss', default=100000, type=float, help='assign min loss(best loss) during resuming previous model')
    parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false',
                        help='disable train-time flipping')
    parser.add_argument('-cs', default=512, type=int, help='channel size of model, only for trasformer') 
    parser.add_argument('-dep', default=8, type=int, help='depth of model')    
    parser.add_argument('-alpha', default=0.01, type=float, help='used for wf_mpjpe')
    parser.add_argument('-beta', default=2, type=float, help='used for wf_mpjpe')
    parser.add_argument('--postrf', action='store_true', help='use the post refine module')
    parser.add_argument('--ftpostrf', action='store_true', help='For fintune to post refine module')
    # parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false',
    #                     help='disable test-time flipping')
    # parser.add_argument('-arc', '--architecture', default='3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma')
    parser.add_argument('-f', '--number-of-frames', default='243', type=int, metavar='N',
                        help='how many frames used as input')
    # parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing')
    # parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers')

    # Experimental
    parser.add_argument('-gpu', default='0', type=str, help='assign the gpu(s) to use')
    parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
    parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
    parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision')
    parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
    parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
    parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
    parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
    parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
                        help='disable bone length term in semi-supervised settings')
    parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')
    parser.add_argument('--ft', action='store_true', help='use ft 2d(only for detection keypoints!)')
    parser.add_argument('--ftpath', default='checkpoint/exp13_ft2d', type=str, help='assign path of ft2d model chk path')
    parser.add_argument('--ftchk', default='epoch_330.pth', type=str, help='assign ft2d model checkpoint file name')
    parser.add_argument('--no_eval', action='store_true', default=False, help='no_eval')
    
    # Visualization
    parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
    parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
    parser.add_argument('--viz-camera', type=int, default=1, metavar='N', help='camera to render')
    parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
    parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
    parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
    parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates')
    parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
    parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
    parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
    parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
    parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')
    parser.add_argument('--compare', action='store_true', default=False, help='Whether to compare with other methods e.g. Poseformer')
    # parser.add_argument('-comchk', type=str, default='/mnt/data3/home/zjl/workspace/3dpose/PoseFormer/checkpoint/detected81f.bin', help='checkpoint of comparison methods')

    # ft2d.py
    parser.add_argument('-lcs', '--linear_channel_size', type=int, default=1024, metavar='N', help='channel size of the LinearModel')
    parser.add_argument('-depth', type=int, default=4, metavar='N', help='nums of blocks of the LinearModel')
    parser.add_argument('-ldg', '--lr_decay_gap', type=float, default=10000, metavar='N', help='channel size of the LinearModel')

    parser.add_argument('-scale', default=1.0, type=float, help='the scale of SNR')
    parser.add_argument('-timestep', type=int, default=1000, metavar='N', help='timestep')
    #parser.add_argument('-timestep_eval', type=int, default=1000, metavar='N', help='timestep_eval')
    parser.add_argument('-sampling_timesteps', type=int, default=5, metavar='N', help='sampling_timesteps')
    parser.add_argument('-num_proposals', type=int, default=5, metavar='N')
    parser.add_argument('--debug', action='store_true', default=False, help='debugging mode')
    parser.add_argument('--p2', action='store_true', default=False, help='using protocol #2, i.e., P-MPJPE')


    parser.set_defaults(bone_length_term=True)
    parser.set_defaults(data_augmentation=True)
    #parser.set_defaults(test_time_augmentation=True)
    parser.set_defaults(test_time_augmentation=False)

    args = parser.parse_args('')
    # Check invalid configuration
    if args.resume and args.evaluate:
        print('Invalid flags: --resume and --evaluate cannot be set at the same time')
        exit()
        
    if args.export_training_curves and args.no_eval:
        print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
        exit()

    return args

d3dp_args = d3dp_parse_args()

In [29]:
from common.h36m_dataset import Human36mDataset
dataset_path = 'data/data_3d_' + d3dp_args.dataset + '.npz'
dataset = Human36mDataset(dataset_path)
joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())

In [30]:
model_pos = D3DP(d3dp_args, joints_left, joints_right,  is_train=False, num_proposals=d3dp_args.num_proposals, sampling_timesteps=d3dp_args.sampling_timesteps)

In [31]:
# make model parallel
if torch.cuda.is_available():
    #model_pos = nn.DataParallel(model_pos)
    model_pos = model_pos.float().cuda()

if d3dp_args.resume or d3dp_args.evaluate:
    chk_filename = os.path.join(d3dp_args.checkpoint, d3dp_args.resume if d3dp_args.resume else d3dp_args.evaluate)
    # chk_filename = args.resume or args.evaluate
    print('Loading checkpoint', chk_filename)
    checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
    print('This model was trained for {} epochs'.format(checkpoint['epoch']))
    model_pos.load_state_dict(checkpoint['model_pos'], strict=False)

Loading checkpoint checkpoint/h36m_best_epoch.bin
This model was trained for 194 epochs


In [32]:
def eval_data_prepare(receptive_field, inputs_2d, inputs_3d):
    # inputs_2d_p = torch.squeeze(inputs_2d)
    # inputs_3d_p = inputs_3d.permute(1,0,2,3)
    # out_num = inputs_2d_p.shape[0] - receptive_field + 1
    # eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
    # for i in range(out_num):
    #     eval_input_2d[i,:,:,:] = inputs_2d_p[i:i+receptive_field, :, :]
    # return eval_input_2d, inputs_3d_p
    ### split into (f/f1, f1, n, 2)
    assert inputs_2d.shape[:-1] == inputs_3d.shape[:-1], "2d and 3d inputs shape must be same! "+str(inputs_2d.shape)+str(inputs_3d.shape)
    inputs_2d_p = torch.squeeze(inputs_2d)
    inputs_3d_p = torch.squeeze(inputs_3d)

    if inputs_2d_p.shape[0] / receptive_field > inputs_2d_p.shape[0] // receptive_field: 
        out_num = inputs_2d_p.shape[0] // receptive_field+1
    elif inputs_2d_p.shape[0] / receptive_field == inputs_2d_p.shape[0] // receptive_field:
        out_num = inputs_2d_p.shape[0] // receptive_field

    eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
    eval_input_3d = torch.empty(out_num, receptive_field, inputs_3d_p.shape[1], inputs_3d_p.shape[2])

    for i in range(out_num-1):
        eval_input_2d[i,:,:,:] = inputs_2d_p[i*receptive_field:i*receptive_field+receptive_field,:,:]
        eval_input_3d[i,:,:,:] = inputs_3d_p[i*receptive_field:i*receptive_field+receptive_field,:,:]
    if inputs_2d_p.shape[0] < receptive_field:
        from torch.nn import functional as F
        pad_right = receptive_field-inputs_2d_p.shape[0]
        inputs_2d_p = rearrange(inputs_2d_p, 'b f c -> f c b')
        inputs_2d_p = F.pad(inputs_2d_p, (0,pad_right), mode='replicate')
        # inputs_2d_p = np.pad(inputs_2d_p, ((0, receptive_field-inputs_2d_p.shape[0]), (0, 0), (0, 0)), 'edge')
        inputs_2d_p = rearrange(inputs_2d_p, 'f c b -> b f c')
    if inputs_3d_p.shape[0] < receptive_field:
        pad_right = receptive_field-inputs_3d_p.shape[0]
        inputs_3d_p = rearrange(inputs_3d_p, 'b f c -> f c b')
        inputs_3d_p = F.pad(inputs_3d_p, (0,pad_right), mode='replicate')
        inputs_3d_p = rearrange(inputs_3d_p, 'f c b -> b f c')
    eval_input_2d[-1,:,:,:] = inputs_2d_p[-receptive_field:,:,:]
    eval_input_3d[-1,:,:,:] = inputs_3d_p[-receptive_field:,:,:]

    return eval_input_2d, eval_input_3d

In [33]:
results_all = []
model_pos.eval()       
model_pos.device = 'cuda'
with torch.no_grad():
    for batch_input, batch_gt in tqdm(test_loader):
        N, T = batch_gt.shape[:2] # B, N
        if torch.cuda.is_available():
            batch_input = batch_input.cuda()
        if args.flip:    
            batch_input_flip = flip_data(batch_input)
            print(batch_input_flip.shape)
            predicted_3d_pos = model_pos(batch_input.float()[:,:,:,:2], None, input_2d_flip=batch_input_flip.float()[:,:,:,:2])
            #predicted_3d_pos_flip = model_pos(batch_input_flip)
            #predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip)                   # Flip back
            #predicted_3d_pos = (predicted_3d_pos_1+predicted_3d_pos_2) / 2
            print(type(predicted_3d_pos), len(predicted_3d_pos), predicted_3d_pos[0].shape)
        else:
            predicted_3d_pos = model_pos(batch_input)
        for i in range(len(predicted_3d_pos)):
            results_all.append(predicted_3d_pos[i].cpu().numpy())
        #results_all.append(predicted_3d_pos.cpu().numpy())
results_all = np.concatenate(results_all)
#results_all = datareader.denormalize(results_all)

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([11, 243, 17, 3])
<class 'list'> 5 torch.Size([11, 5, 243, 17, 3])


100%|██████████| 1/1 [00:04<00:00,  4.66s/it]


In [34]:
results_all.shape

(55, 5, 243, 17, 3)

In [35]:
results_all[0, 0, 0], results_all[3, 0, 0]

(array([[ 0.2839991 ,  0.6319914 , -0.3044045 ],
        [ 0.5497556 ,  0.41571048, -0.121594  ],
        [ 0.3424947 ,  0.5514433 , -0.30137658],
        [ 0.5646411 ,  0.64200276, -0.01160304],
        [ 0.6602721 ,  0.51664174,  0.04166837],
        [ 0.6816317 ,  0.29820374,  0.17094654],
        [ 0.6440543 ,  0.04696334,  0.16382079],
        [ 0.2817564 ,  0.65781814, -0.27978858],
        [ 0.63170516,  0.563669  ,  0.07134522],
        [ 0.56473154,  0.31432903,  0.08045112],
        [ 0.63782513,  0.22986901,  0.21254227],
        [ 0.46128896,  0.5747071 , -0.10331982],
        [ 0.5176493 ,  0.77502596, -0.11199886],
        [ 0.6877425 ,  0.22399183,  0.18962893],
        [ 0.6727799 ,  0.11066622,  0.21668226],
        [ 0.57504237,  0.25577262,  0.15488945],
        [ 0.6294997 ,  0.51427007, -0.14901315]], dtype=float32),
 array([[ 0.57071877,  0.20632711,  0.17829446],
        [ 0.7602687 ,  0.19235393,  0.22151288],
        [ 0.27884093,  0.51204085, -0.3774857 ],
   

In [41]:
get_rootrel_pose(results_all[0, 0, 0]*1000)

array([[   0.       ,    0.       ,    0.       ],
       [ 265.75647  , -216.28091  ,  182.81052  ],
       [  58.495605 ,  -80.548096 ,    3.0279236],
       [ 280.64203  ,   10.011353 ,  292.80148  ],
       [ 376.273    , -115.34967  ,  346.07288  ],
       [ 397.63263  , -333.78766  ,  475.35104  ],
       [ 360.05524  , -585.0281   ,  468.22528  ],
       [  -2.2426758,   25.826721 ,   24.615936 ],
       [ 347.70605  ,  -68.32239  ,  375.74973  ],
       [ 280.73248  , -317.66235  ,  384.85562  ],
       [ 353.82605  , -402.12238  ,  516.9468   ],
       [ 177.28989  ,  -57.2843   ,  201.08469  ],
       [ 233.6502   ,  143.03455  ,  192.40565  ],
       [ 403.7434   , -407.99957  ,  494.03345  ],
       [ 388.78082  , -521.3252   ,  521.0868   ],
       [ 291.04327  , -376.21878  ,  459.29395  ],
       [ 345.5006   , -117.72131  ,  155.39136  ]], dtype=float32)

In [45]:
visualize_3d_pose([get_rootrel_pose(results_all[0, 0, 5]*1000)])