
Adversarial Robustness: problem with loading pretrained model #317

Closed
YiDongOuYang opened this issue Jan 5, 2022 · 4 comments

Comments

@YiDongOuYang

https://github.com/deepmind/deepmind-research/blob/fba48d1e44d86628b65a31549560b7be2a25d823/adversarial_robustness/jax/eval.py#L89

Dear authors, thank you for your great work! However, I cannot load the pretrained model you provide, i.e., cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy. The keys in this .npy file are different from those expected by the WideResNet class in model_zoo.py.

What the .npy file saved:
[screenshot]

What we need for params:
[screenshot]

What we need for state:
[screenshot]
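To pinpoint exactly which keys differ, one can diff the two nested structures leaf by leaf. The following is a minimal sketch, assuming both the unpacked checkpoint and the model's params (or state) are nested dicts of the Haiku form `{module_name: {param_name: array}}`. How the .npy file unpacks (for example via `np.load(path, allow_pickle=True)`) depends on how it was saved and is not shown here. The helper names `flatten_keys` and `diff_keys` and the module names in the example are hypothetical.

```python
from typing import Any, Mapping, Set, Tuple


def flatten_keys(tree: Mapping[str, Any], prefix: str = "") -> Set[str]:
    """Return the set of '/'-joined leaf paths in a nested mapping."""
    paths: Set[str] = set()
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into sub-dicts (e.g. one dict per Haiku module).
            paths |= flatten_keys(value, path)
        else:
            # Leaves are the actual arrays; only their path matters here.
            paths.add(path)
    return paths


def diff_keys(loaded: Mapping[str, Any],
              expected: Mapping[str, Any]) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) leaf paths of `loaded` vs `expected`."""
    got = flatten_keys(loaded)
    want = flatten_keys(expected)
    return want - got, got - want


# Hypothetical module names, for illustration only; leaves stand in for arrays.
loaded = {"wide_res_net/conv2_d": {"w": 0}, "old_name/linear": {"w": 0, "b": 0}}
expected = {"wide_res_net/conv2_d": {"w": 0}, "wide_res_net/linear": {"w": 0, "b": 0}}
missing, unexpected = diff_keys(loaded, expected)
print(sorted(missing))     # ['wide_res_net/linear/b', 'wide_res_net/linear/w']
print(sorted(unexpected))  # ['old_name/linear/b', 'old_name/linear/w']
```

An empty pair of sets means the checkpoint's key structure matches the model; any entries show which modules were saved under different names.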

@sgowal
Collaborator

sgowal commented Jan 5, 2022

Thank you for reporting the error. Indeed, some JAX checkpoints were wrongly imported (all PyTorch checkpoints should work). Could you try the new cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy model and let us know if it works?

@YiDongOuYang
Author

Dear Gowal, thank you very much for your precise and swift reply. I really appreciate it!
https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_cutmix_ddpm_v2.npy This model works well!

BTW, the PyTorch checkpoints work very well! However, this repository provides no training code for the PyTorch version, and eval.py can only measure clean accuracy, not robust accuracy. Even with the pretrained model, I cannot fine-tune it or evaluate it under a robust criterion unless I write that code in PyTorch myself.

I did look at Rahul Rade's implementation https://github.com/imrahulr/adversarial_robustness_pytorch. However, there are also some gaps: it doesn't provide pretrained models, and the pretrained models you provide cannot be loaded into that repository (though I believe they could be with a little effort).
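A rough sketch of the kind of "little effort" mentioned above is renaming checkpoint keys by prefix substitution before loading them into another model. The rename rules and key names below are hypothetical placeholders; the real correspondence between these checkpoints and Rahul Rade's model would have to be worked out by comparing the two key lists. It uses plain dicts so it runs without PyTorch, but since a PyTorch `state_dict` is an ordered mapping from names to tensors, the same function should apply to one.

```python
from typing import Dict, List, Tuple, TypeVar

V = TypeVar("V")


def remap_keys(state: Dict[str, V], rules: List[Tuple[str, str]]) -> Dict[str, V]:
    """Rename keys by replacing the first matching prefix in `rules`.

    Keys matching no rule are kept unchanged. Raises KeyError if two
    keys end up with the same name after remapping.
    """
    remapped: Dict[str, V] = {}
    for key, value in state.items():
        new_key = key
        for old_prefix, new_prefix in rules:
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break  # apply only the first matching rule
        if new_key in remapped:
            raise KeyError(f"Duplicate key after remapping: {new_key}")
        remapped[new_key] = value
    return remapped


# Hypothetical key names and rules, for illustration only.
checkpoint = {"init_conv.weight": 1, "layer.0.block.0.conv_0.weight": 2, "logits.bias": 3}
rules = [("init_conv.", "conv1."), ("logits.", "fc.")]
print(remap_keys(checkpoint, rules))
# {'conv1.weight': 1, 'layer.0.block.0.conv_0.weight': 2, 'fc.bias': 3}
```

Failing loudly on collisions is a deliberate choice: a silent overwrite would hide a wrong rule until the model's accuracy turned out to be off.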

I am wondering whether you could release some of the necessary PyTorch code. (I don't want to impose on you. If it is not convenient, I could also bridge the gap between your pretrained models and Rahul Rade's implementation to help others like me adopt your code!) Thank you again :D

@sgowal
Collaborator

sgowal commented Jan 7, 2022

Hi YiDongOuYang,

Unfortunately we do not plan to release an official PyTorch version of the training pipeline. We encourage you to use the JAX implementation if you can.

Feel free to reach out to me directly if you find other JAX checkpoints that cannot be loaded. I apologize for the inconvenience.

Regards,
Sven

@YiDongOuYang
Author

Hi Sven,

No worries! Thank you again!

I checked the other JAX checkpoints just now. I found that four of the CIFAR-10 links at https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness/iclrw2021doing are not correct.

Best wishes,
Yidong
