Skip to content

Commit

Permalink
Fixed issue #1 (output channel bug)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpskex committed May 24, 2018
1 parent cb51da8 commit 2b3234f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ data/
*.ckpt*

# Ignore test images
test.*
*.jpg
11 changes: 4 additions & 7 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ def predict(img_list, model_path=None, thresh=0.05, is_name=False, cpu_only=True
model.restore_sess(model_path)

# get the last stage's result
pred_map = model.sess.run(model.output, feed_dict={model.img: _input / 255.0})[:, -1]
pred_map = model.sess.run(model.output, feed_dict={model.img: _input / 255.0})[:, -1, :, :, :-1]
if debug:
np.save('pred.npy', pred_map)

j = -1 * np.ones((len(_img_list), pred_map.shape[-1]-1, 2))
w = np.zeros((len(_img_list), pred_map.shape[-1]-1))
j = -1 * np.ones((len(_img_list), pred_map.shape[-1], 2))
w = np.zeros((len(_img_list), pred_map.shape[-1]))
for idx in range(len(_img_list)):
# re-project heatmap to origin size

Expand Down Expand Up @@ -192,10 +192,7 @@ def predict(img_list, model_path=None, thresh=0.05, is_name=False, cpu_only=True
if __name__=='__main__':
""" Demo of Using the model API
"""
img_names = ['000061164.jpg', '000078951.jpg', '000094304.jpg', '000099899.jpg',
'000065339.jpg', '000085370.jpg', '000094342.jpg', '000109154.jpg',
'000071686.jpg', '000090584.jpg', '000099186.jpg','000111209.jpg'
]
img_names = ['test.jpg']
# input must be greater than one
assert len(img_names) >= 1
j = predict(img_names, 'model/model.ckpt-99', debug=True, is_name=True)
Expand Down
9 changes: 5 additions & 4 deletions valPCK.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ def compute_distance(model, dataset, metric='PCKh', debug=False):
j_gt = np.array(j_list)

# estimate by network
j_dt,_ = predict.predict(im_list, model=model, thresh=0.0)
j_dt,_ = predict.predict(im_list, model=model, thresh=0.0, debug=True)
w = np.transpose(np.hstack((np.expand_dims(w,1), np.expand_dims(w,1))), axes=[0,2,1])
assert j_dt.shape == j_gt.shape

for n in range(len(Global.joint_list)):
for n in range(j_dt.shape[0]):
err[_iter*paral+n] = np.linalg.norm(w[n]*(j_gt[n,:,:]-j_dt[n,:,-1::-1]),axis=1) / np.linalg.norm(j_gt[n,normJ_a,:]-j_gt[n,normJ_b,:], axis=0)
print "[*]\tTemp Error is ", np.average(err[_iter*paral:_iter*paral+paral], axis=0)
if debug:
break
return err

aver_err = np.average(err)
print "[*]\tAverage PCKh Normalised distance is ", aver_err
Expand All @@ -123,5 +124,5 @@ def compute_distance(model, dataset, metric='PCKh', debug=False):
model.BuildModel()
model.restore_sess('model/model.ckpt-99')

dist = compute_distance(model, dataset, metric='PCKh', debug=True)
dist = compute_distance(model, dataset, metric='PCKh')
visualize_accuracy(dist, start=0.01, end=0.5, showlist=[0,1,4,5,10,11,14,15])

0 comments on commit 2b3234f

Please sign in to comment.