Empirical Fisher Estimation #11
Comments
Thanks for your comment! You’re right that the order in which the square and average operations are applied matters, exactly for the reason you point out. But the code in my repository does use the correct order, because the Fisher Information matrix is calculated with batches of size 1 (see here). I admit this might not be the most efficient implementation (alternative suggestions are very welcome!), but the reason I chose it for now is that, as far as I’m aware, it is currently not possible in PyTorch to access the gradients of the individual elements in a sum.
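The per-sample approach described above can be sketched roughly as follows. This is purely illustrative and not the repository’s code: it uses an analytic logistic-regression gradient instead of PyTorch autograd, and the model, data, and names (`per_sample_grad`, `fisher_diag`, etc.) are all made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy logistic-regression "model": w are the parameters we want a
# Fisher diagonal for; (X, y) is a small synthetic labelled data set.
w = rng.normal(size=5)
X = rng.normal(size=(100, 5))
y = (rng.random(100) < 0.5).astype(float)

def per_sample_grad(w, x, t):
    """Gradient of the negative log-likelihood for ONE sample."""
    p = 1.0 / (1.0 + np.exp(-x @ w))  # sigmoid
    return (p - t) * x

# Loop over the samples one-by-one ("batches of size 1"):
# square each per-sample gradient, THEN average.
fisher_diag = np.zeros_like(w)
for x, t in zip(X, y):
    g = per_sample_grad(w, x, t)
    fisher_diag += g ** 2
fisher_diag /= len(X)
```

In a real PyTorch model the per-sample gradient would come from one backward pass per example, which is exactly why the loop above cannot simply be replaced by a single backward pass over the summed loss.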
I see. Estimating the Fisher Information matrix with a single example seems to result in high variance (see this notebook for an example). Have you seen any difference when a larger batch size is used? As you said, I don’t think it’s possible to access individual gradients with just one backward call.
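The variance point can be checked with a quick simulation. This is illustrative only: the “gradients” are synthetic zero-mean draws, not gradients from any real model, and `fisher_estimate` is a made-up helper name.

```python
import numpy as np

rng = np.random.default_rng(42)

true_var = 2.0  # the quantity mean(g**2) estimates for zero-mean g

def fisher_estimate(n_samples):
    """Monte Carlo estimate of one diagonal Fisher entry."""
    g = rng.normal(scale=np.sqrt(true_var), size=n_samples)
    return np.mean(g ** 2)

# Spread of the estimate over repeated trials, for 1 vs 1000 samples:
err_1 = np.std([fisher_estimate(1) for _ in range(500)])
err_1000 = np.std([fisher_estimate(1000) for _ in range(500)])
# The single-sample estimate is far noisier than the 1000-sample one.
print(err_1, err_1000)
```

Note that this only speaks to how many samples are averaged over, not to whether the forward passes are batched; the two are independent choices.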
Sorry, I should have been clearer; I realise my use of ‘batches’ is a bit confusing here. What I meant is that the backward passes(*) are done one-by-one for each sample used to calculate the Fisher Information matrix. The number of samples used to calculate the Fisher Information matrix is typically not 1 (this is set by the option …).

(*) Actually also the forward passes; I now realise this could be made more efficient by at least performing the forward passes with larger ‘batches’. I’ll look into that when I get some time.
It seems convenient to average the gradients over samples by calling `F.nll_loss` before squaring them, as we then only need one backward pass. However, I feel like the diagonal of the empirical Fisher Information matrix should be calculated by squaring the gradients before taking their average (as done in this TensorFlow implementation). Can you please confirm that the order doesn't matter here?

My understanding is that the expected values of the gradients are 0 (see this Wiki), so if you do the averaging first, the Fisher values come out very close to 0, which seems incorrect. Am I missing something here? Please let me know what you think. Thank you.
continual-learning/continual_learner.py, lines 110 to 125 in ff0e03c
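The order-of-operations concern can be illustrated with synthetic zero-mean “gradients” (illustrative numbers only, not taken from any real model):

```python
import numpy as np

rng = np.random.default_rng(0)

# Per-sample gradients whose expectation is (close to) zero,
# as for a model near a local optimum of the log-likelihood.
grads = rng.normal(loc=0.0, scale=1.0, size=10_000)

square_then_average = np.mean(grads ** 2)   # ~ Var(g) ~ 1: the Fisher term
average_then_square = np.mean(grads) ** 2   # ~ 0: loses all the information
print(square_then_average, average_then_square)
```

With zero-mean gradients, averaging first collapses the estimate toward zero, while squaring first recovers the second moment, so the two orders give very different answers.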