Skip to content

Commit

Permalink
add joints3d
Browse files Browse the repository at this point in the history
  • Loading branch information
mkocabas committed Dec 26, 2019
1 parent 1db923a commit b95074a
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main(args):

with torch.no_grad():

pred_cam, pred_verts, pred_pose, pred_betas, norm_joints2d = [], [], [], [], []
pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], []

for batch in dataloader:
if has_keypoints:
Expand All @@ -163,12 +163,14 @@ def main(args):
pred_verts.append(output['verts'].reshape(batch_size * seqlen, -1, 3))
pred_pose.append(output['theta'][:,:,3:75].reshape(batch_size * seqlen, -1))
pred_betas.append(output['theta'][:, :,75:].reshape(batch_size * seqlen, -1))
pred_joints3d.append(output['kp_3d'].reshape(batch_size * seqlen, -1, 3))


pred_cam = torch.cat(pred_cam, dim=0)
pred_verts = torch.cat(pred_verts, dim=0)
pred_pose = torch.cat(pred_pose, dim=0)
pred_betas = torch.cat(pred_betas, dim=0)
pred_joints3d = torch.cat(pred_joints3d, dim=0)

del batch

Expand All @@ -180,7 +182,7 @@ def main(args):

# Run Temporal SMPLify
update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
new_opt_joint_loss, opt_joint_loss = smplify_runner(
new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
pred_rotmat=pred_pose,
pred_betas=pred_betas,
pred_cam=pred_cam,
Expand All @@ -200,6 +202,7 @@ def main(args):
pred_cam[update] = new_opt_cam[update]
pred_pose[update] = new_opt_pose[update]
pred_betas[update] = new_opt_betas[update]
pred_joints3d[update] = new_opt_joints3d[update]

elif args.run_smplify and args.tracking_method == 'bbox':
print('[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!')
Expand All @@ -210,6 +213,7 @@ def main(args):
pred_verts = pred_verts.cpu().numpy()
pred_pose = pred_pose.cpu().numpy()
pred_betas = pred_betas.cpu().numpy()
pred_joints3d = pred_joints3d.cpu().numpy()

orig_cam = convert_crop_cam_to_orig_img(
cam=pred_cam,
Expand All @@ -224,6 +228,7 @@ def main(args):
'verts': pred_verts,
'pose': pred_pose,
'betas': pred_betas,
'joints3d': pred_joints3d,
'joints2d': joints2d,
'bboxes': bboxes,
'frame_ids': frames,
Expand Down

0 comments on commit b95074a

Please sign in to comment.