Add support and test for KFAC-expand and KFAC-reduce #26

Merged 30 commits on Oct 24, 2023

runame (Collaborator) commented Oct 17, 2023

Resolves #14 and partially resolves #13.

singd/optim/utils.py can probably still be improved; for now, I focused on supporting KFAC-expand and KFAC-reduce and on their correctness. However, I think it is already clean enough for a first release.

A note on one design choice I made: For conv layers, I assume that the final loss is averaged over batch_size terms, since this is always the case for the losses we consider when conv layers are used. In contrast, if the KFAC-expand approximation is used for linear layers, I assume that the loss was also averaged over the sequence dimension when batch_averaged=True. This holds for language modelling, but with, e.g., a vision transformer on a classification task, kfac_approx="expand", and batch_averaged=True, it leads to a mismatch between the scale of the preconditioner and the scale of the gradient. We could consider adding an additional flag for this or modifying the batch_averaged flag.
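
For readers unfamiliar with the two approximations, here is a minimal sketch of the input-based Kronecker factor of a linear layer whose inputs carry a sequence dimension. It follows the commonly stated KFAC-expand and KFAC-reduce definitions and is not taken from singd; the function name `input_factor`, its signature, and the shapes are assumptions made for illustration. The 1/(N * R) normalization in the expand branch is the one that presumes a loss averaged over the sequence dimension as well.

```python
import torch


def input_factor(a: torch.Tensor, kfac_approx: str) -> torch.Tensor:
    """Input-based Kronecker factor for layer inputs of shape (N, R, D).

    Hypothetical helper for illustration only, not singd's implementation.
    """
    N, R, D = a.shape
    if kfac_approx == "expand":
        # Treat every sequence position as its own data point:
        # N * R outer products, normalized by N * R.
        a = a.reshape(N * R, D)
        return a.T @ a / (N * R)
    if kfac_approx == "reduce":
        # Average over the sequence dimension first:
        # N outer products, normalized by N.
        a = a.mean(dim=1)
        return a.T @ a / N
    raise ValueError(f"Unknown kfac_approx: {kfac_approx!r}")


torch.manual_seed(0)
a = torch.randn(4, 3, 5)  # N=4 samples, R=3 positions, D=5 features (arbitrary)
A_expand = input_factor(a, "expand")
A_reduce = input_factor(a, "reduce")
```

Both factors have shape (D, D). They coincide when R = 1, which is why the choice of normalization only matters once a sequence dimension is present.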

runame added the enhancement label on Oct 17, 2023
runame requested a review from f-dangel on October 17, 2023 20:29

f-dangel (Owner) commented Oct 18, 2023

Thanks for the PR, I will take a look now.

Regarding the scaling issues: Do people usually write custom loss functions for the settings you describe, or do they rely on built-in ones? I think our implementation should prioritize ease of use with PyTorch's built-in modular losses such as nn.MSELoss and nn.CrossEntropyLoss, which both have a reduction argument ('mean' or 'sum') and also work with d>2-dimensional predictions. For custom loss functions, the user must specify more information than just the reduction argument.
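
As a quick illustration of how the built-in losses behave here, the following sketch compares reduction="sum" and reduction="mean" of nn.CrossEntropyLoss for per-position predictions and for per-sample predictions. The shapes and variable names are arbitrary assumptions; the ratio of the two reductions recovers the number of terms the mean is taken over.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, R, C = 4, 6, 10  # batch size, sequence length, number of classes (arbitrary)

# Language-modelling-like shapes: one prediction per sequence position.
# nn.CrossEntropyLoss expects the class dimension second: (N, C, R).
logits_seq = torch.randn(N, C, R)
targets_seq = torch.randint(0, C, (N, R))
seq_mean = nn.CrossEntropyLoss(reduction="mean")(logits_seq, targets_seq)
seq_sum = nn.CrossEntropyLoss(reduction="sum")(logits_seq, targets_seq)

# Classification-like shapes: one prediction per sample.
logits_cls = torch.randn(N, C)
targets_cls = torch.randint(0, C, (N,))
cls_mean = nn.CrossEntropyLoss(reduction="mean")(logits_cls, targets_cls)
cls_sum = nn.CrossEntropyLoss(reduction="sum")(logits_cls, targets_cls)

# Without class weights, sum / mean equals the number of averaged terms.
terms_seq = (seq_sum / seq_mean).item()  # close to N * R
terms_cls = (cls_sum / cls_mean).item()  # close to N
print(round(terms_seq), round(terms_cls))
```

So the same reduction="mean" setting divides by N * R in one case and by N in the other, depending only on the shape of the predictions.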

runame (Collaborator, Author) commented Oct 18, 2023

They usually rely on built-in ones, and I agree that we should prioritize nn.MSELoss and nn.CrossEntropyLoss. However, this does not solve the issue; see, for example, the vision transformer case I described above. To rephrase: when using a vision transformer for image classification with reduction="mean", i.e. batch_averaged=True, the loss is divided by batch_size. However, we currently divide the preconditioner by batch_size * sequence_length for KFAC-expand if batch_averaged=True. Conversely, if we chose to divide only by batch_size, then in a language modelling task we would scale the preconditioner by batch_size only, whereas the loss is averaged over batch_size * sequence_length terms.

In other words, this discrepancy arises because we can apply both approximations, KFAC-expand and KFAC-reduce, in both the expand and the reduce setting, i.e. with a loss that has N * R or N terms, where N is the batch size and R is the sequence length. Without additional information, we cannot deduce the setting from the KFAC approximation alone.
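
To make the mismatch concrete, here is a small sketch of the bookkeeping. The helper names `num_loss_terms` and `scale_mismatch` and the "setting" strings are hypothetical and exist only for this illustration; how the factor propagates into the actual preconditioner depends on implementation details not shown here.

```python
def num_loss_terms(batch_size: int, seq_len: int, setting: str) -> int:
    """Number of terms a mean-reduced loss averages over (hypothetical helper)."""
    if setting == "expand":
        # One loss term per sequence position, e.g. language modelling.
        return batch_size * seq_len
    if setting == "reduce":
        # One loss term per sample, e.g. image classification with a ViT.
        return batch_size
    raise ValueError(f"Unknown setting: {setting!r}")


def scale_mismatch(
    batch_size: int, seq_len: int, assumed_setting: str, actual_setting: str
) -> float:
    """Ratio between the assumed normalization and the loss's true normalization."""
    assumed = num_loss_terms(batch_size, seq_len, assumed_setting)
    actual = num_loss_terms(batch_size, seq_len, actual_setting)
    return assumed / actual


N, R = 8, 16  # arbitrary batch size and sequence length

# Normalization assumes N * R terms, but the loss only averages over N terms.
vit_case = scale_mismatch(N, R, assumed_setting="expand", actual_setting="reduce")

# Normalization assumes N terms, but the loss averages over N * R terms.
lm_case = scale_mismatch(N, R, assumed_setting="reduce", actual_setting="expand")

print(vit_case, lm_case)
```

In both directions the normalization is off by a factor of R or 1/R, which is why the setting would have to be supplied separately from kfac_approx, for instance through an additional flag as suggested above.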

f-dangel (Owner) left a comment

These are the comments from my first pass through the PR.

There are ~5 bigger refactorings that might make sense to work on before giving it a second pass. I also complained about missing documentation, but don't worry too much for now.

Resolved review threads:
singd/optim/optimizer.py: 4 threads (3 outdated)
singd/optim/utils.py: 1 thread (outdated)
test/optim/test_kfac.py: 3 threads (all outdated)
test/optim/utils.py: 2 threads (all outdated)

runame requested a review from f-dangel on October 19, 2023 14:43

runame (Collaborator, Author) commented Oct 19, 2023

Thanks a lot for the thorough review. I think I have addressed all comments, and the test already looks much better now.

f-dangel (Owner) left a comment

Getting there! I only have one major refactoring request left + many smaller things.

Think we will need 1-2 more rounds to merge 👍

Resolved review threads:
singd/optim/optimizer.py: 1 thread (outdated)
test/optim/test_kfac.py: 5 threads (4 outdated)
test/optim/utils.py: 4 threads (3 outdated)

runame requested a review from f-dangel on October 19, 2023 22:18

f-dangel (Owner) left a comment

Only very minor things. Ping me once you've gone through them and I will finish things off 👍

Resolved review threads:
test/optim/test_kfac.py: 3 threads (all outdated)

runame requested a review from f-dangel on October 21, 2023 17:01

f-dangel (Owner) commented

Looks good to me! Applied minor refactoring and fixes. Currently running CI and will merge if passing.

runame merged commit 0c121df into main on Oct 24, 2023
14 checks passed
runame deleted the test-kfac branch on October 24, 2023 00:25
Labels: enhancement
Linked issues: Support KFAC-expand and KFAC-reduce; Clean up singd/optim/utils.py
Participants: 2