Add support and test for KFAC-expand and KFAC-reduce #26
Conversation
Thanks for the PR, I will take a look now. Regarding the scaling issues: do people usually write custom loss functions for the settings you describe, or do they rely on built-in ones? I think our implementation should prioritize easy usage together with PyTorch's built-in modular losses.
They usually rely on built-in ones, and I agree that we should prioritize those. In other words, this discrepancy arises because we can apply both approximations, KFAC-expand and KFAC-reduce, in the expand AND the reduce setting.
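To make the two settings concrete, here is a minimal PyTorch sketch (an illustration, not code from this PR; the mean-pooling step is just one way a reduce setting can arise):

```python
import torch

batch_size, seq_len, num_classes = 4, 8, 10
logits = torch.randn(batch_size, seq_len, num_classes)
targets = torch.randint(num_classes, (batch_size, seq_len))

loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

# Expand setting: one loss term per token, so the mean runs over
# batch_size * seq_len terms (typical for language modelling).
expand_loss = loss_fn(logits.flatten(0, 1), targets.flatten())

# Reduce setting: per-token outputs are pooled into one prediction per
# example first, so the mean runs over batch_size terms only.
pooled = logits.mean(dim=1)                   # (batch_size, num_classes)
reduce_loss = loss_fn(pooled, targets[:, 0])  # one label per example
```

Either KFAC approximation can be applied in either of these two settings, which is where the scaling discrepancy comes from.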
These are the comments from my first pass through the PR.
There are ~5 bigger refactorings that might make sense to work on before giving it a second pass. I also complained about missing documentation, but don't worry too much about that for now.
Thanks a lot for the thorough review! I think I have addressed all comments, and the test is already looking much better now.
Getting there! I only have one major refactoring request left + many smaller things.
Think we will need 1-2 more rounds to merge 👍
Only very minor things. Ping me once you've gone through them and I will finish things off 👍
Looks good to me! Applied minor refactoring and fixes. Currently running CI and will merge if passing.
Resolves #14 and partially resolves #13.
`singd/optim/utils.py` can probably still be improved; for now I focused on supporting KFAC-expand and KFAC-reduce and on their correctness. However, I think it is already clean enough for a first release.

A note on one design choice I made: here, I assume that the final loss is averaged over `batch_size` terms, since this is always the case for the losses we consider when conv layers are used. In contrast, here I assume that if the KFAC-expand approximation is used for linear layers, the loss was also averaged over the sequence dimension when `batch_averaged=True`. This holds for language modelling, but e.g. with a vision transformer, a classification task, `kfac_approx="expand"`, and `batch_averaged=True`, this will lead to a mismatch between the scale of the preconditioner and the gradient. We could consider adding an additional flag for this or modifying the `batch_averaged` flag.
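As a back-of-the-envelope illustration of this mismatch (a sketch with made-up numbers, not code from the PR):

```python
batch_size, seq_len = 4, 16

# kfac_approx="expand" with batch_averaged=True assumes the loss is a mean
# over batch_size * seq_len terms, as in language modelling.
assumed_terms = batch_size * seq_len

# A vision transformer used for classification emits one prediction per
# image, so its loss is a mean over batch_size terms only.
actual_terms = batch_size

# The preconditioner's scale then disagrees with the gradient's by:
print(assumed_terms / actual_terms)  # 16.0, i.e. seq_len
```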