
Does this code reproduce the results of 8 stacked hg in the original paper? #15

Closed
zhiqiangdon opened this issue Sep 12, 2017 · 11 comments

Comments

@zhiqiangdon

Hi,

Thanks for sharing your code! Does this code reproduce the results of the 8-stacked HG in the original paper? If not, what are your results for the 8-stacked HG, and what are possible reasons for the gap?

Best,

@bearpaw
Owner

bearpaw commented Sep 13, 2017

@zhiqiangdon It should reproduce the results of the original paper. See our latest result with an 8-stack hourglass model on the MPII validation set https://github.com/bearpaw/pytorch-pose#evaluate-the-pckh05-score

@zhiqiangdon
Author

Here are the training logs of the original torch code.
https://github.com/anewell/pose-hg-train/files/623028/train.txt
https://github.com/anewell/pose-hg-train/files/623029/valid.txt

According to your figure and their logs, there seems to be a gap between your results and theirs (theirs are better). Please note that in their logs they don't count the head, neck, shoulders, chest, and stomach (6 points). If all the points are counted, their PCKh is more than 90.

@bearpaw
Owner

bearpaw commented Sep 13, 2017

@zhiqiangdon There might be some subtle differences in the implementation details. I'm not quite sure.

By the way, the goal of this project is not to reproduce the results of one specific paper. I try to make the code general and extensible enough for other datasets and tasks. If you find any bugs or potential problems in the implementation, you are welcome to create a pull request. Thank you.

@xiaoyong

@zhiqiangdon In this repo, the 6 points are also excluded when computing accuracy. Check
https://github.com/bearpaw/pytorch-pose/blob/master/example/mpii.py#L28
and
https://github.com/anewell/pose-hg-train/blob/master/src/util/dataset/mpii.lua#L6

So they can be fairly compared. What are the training and validation accuracies for the s8-b1 model? @bearpaw
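For context, the approximate accuracy both repos report during training compares heatmap peak locations against the ground truth, normalized by a tenth of the heatmap size, with the easy joints excluded. Here is a minimal sketch of that idea — the function name, joint indices, and toy numbers are all illustrative, not the repos' actual code:

```python
import numpy as np

def heatmap_pck(pred, gt, norm, thr=0.5, exclude=()):
    """Approximate training-time accuracy: per-joint distance between
    predicted and ground-truth heatmap peaks, divided by `norm`
    (typically heatmap_size / 10); a joint counts as correct if the
    normalized distance is at most `thr`. Joints in `exclude` are
    skipped, mirroring how both repos drop the 6 easy joints."""
    dists = np.linalg.norm(pred - gt, axis=1) / norm
    kept = [d <= thr for j, d in enumerate(dists) if j not in exclude]
    return float(np.mean(kept))

# toy example: 4 joints on a 64x64 heatmap, joint 0 excluded
pred = np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0], [40.0, 40.0]])
gt   = np.array([[10.0, 12.0], [25.0, 20.0], [30.0, 30.0], [40.0, 45.0]])
print(heatmap_pck(pred, gt, norm=64 / 10, exclude={0}))  # 1 of 3 kept joints correct
```

Because the threshold lives in heatmap coordinates rather than being tied to the person's head size, this metric can disagree with the official PCKh.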

@bearpaw
Owner

bearpaw commented Sep 13, 2017

Thanks @xiaoyong . The results are updated in https://github.com/bearpaw/pytorch-pose#evaluate-the-pckh05-score

You can also view the training log at Google Drive: https://drive.google.com/drive/u/1/folders/0B63t5HSgY4SQMUJqZDlQOGpQdlk

However, I don't think the log is directly comparable to the original log mentioned by @zhiqiangdon:
https://github.com/anewell/pose-hg-train/files/623028/train.txt
https://github.com/anewell/pose-hg-train/files/623029/valid.txt

Also, I cannot get more than 90 PCKh (only ~88-89) by running their code.

@anibali

anibali commented Sep 19, 2017

For anyone who might be interested, I have taken the MPII evaluation code (written in Matlab) and tweaked it slightly so that it can load "flat" prediction files (i.e. the format used by bearpaw's and anewell's repos) and evaluate on the validation set. When I do this, I actually see better performance from the PyTorch HG8 model (https://drive.google.com/drive/folders/0B63t5HSgY4SQMUJqZDlQOGpQdlk) than the original pretrained HG8 model (http://www-personal.umich.edu/~alnewell/pose/).

[Figure: total PCKh curve on the MPII validation set]

If you want to see all of the results for different joints, my adapted evaluation code is up at https://github.com/anibali/eval-mpii-pose and you can generate all of the result graphs with `octave evalMPII.m`.

@bearpaw
Owner

bearpaw commented Sep 20, 2017

@anibali Great job! Thanks for sharing the results and the great evaluation code!

@zhiqiangdon
Author

zhiqiangdon commented Sep 20, 2017

@bearpaw I have read your code for evaluating accuracy. As @xiaoyong mentioned, you also exclude the 6 points when computing PCKh during training and evaluation, so your log is comparable to anewell's original log. Note that this PCKh is only an approximation: it is normalized by 1/10 of the heatmap height/width, and it excludes the 6 points with the highest PCKh. In my experience, the average PCKh should be noticeably higher when computed on the original image, normalized by head size and counting all 16 points. But your exact and approximate PCKh values are very close. I wonder whether something is wrong in how you compute the exact PCKh?
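To make the contrast concrete, here is a minimal sketch of the exact metric described above: errors measured in original-image pixels, normalized per person by head size, over all annotated joints. All names and numbers here are illustrative; note that the official MPII toolkit derives the head size from the head bounding box scaled by a bias factor, which this sketch takes as a given input.

```python
import numpy as np

def pckh(pred, gt, head_size, visible, thr=0.5):
    """Exact PCKh: keypoint error in original-image pixels, normalized
    per person by that person's head size, averaged over all annotated
    joints. pred, gt: (N, J, 2); head_size: (N,); visible: (N, J) bool."""
    err = np.linalg.norm(pred - gt, axis=2) / head_size[:, None]
    correct = (err <= thr) & visible
    return correct.sum() / visible.sum()

# toy example: one person, two annotated joints, head size 10 px
pred = np.array([[[0.0, 0.0], [0.0, 0.0]]])
gt = np.array([[[0.0, 3.0], [0.0, 8.0]]])
print(pckh(pred, gt, np.array([10.0]), np.ones((1, 2), dtype=bool)))  # 0.5
```

The first joint's error (3 px) is within half the head size, the second (8 px) is not, so half the joints count as correct.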

@zhiqiangdon
Author

@anibali Thanks for your work. I guess the model at http://www-personal.umich.edu/~alnewell/pose/ is their best model. Do you have an explanation for why the PCKh@0.5 in your figure is lower than that reported in their paper?

@anibali

anibali commented Sep 20, 2017

@zhiqiangdon It's not worse than what is reported in their paper. See Fig. 9 of their paper, which shows results on the validation set. Table 2 does show a higher accuracy, but those results are for the test set. The first author reports that test set accuracy is generally higher than validation set accuracy (source: princeton-vl/pose-hg-demo#1 (comment)).

@zhiqiangdon
Author

@anibali , I see. Thanks!

@bearpaw bearpaw closed this as completed Sep 30, 2017
@mkocabas mkocabas mentioned this issue May 11, 2018