How to calculate the empirical NTK of a model used in a classification task #68
For a model being used for classification with `k` classes, for `n` datapoints, the NTK should be of size `nk` x `nk`. How would we get that with neural-tangents? Currently, I'm able to get an `n` x `n` matrix.
Set the `trace_axes=()` argument when constructing the empirical kernel function. The default is `trace_axes=(-1,)`, which traces over the last (logit) axis, which is why you currently get an `n` x `n` matrix.
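A minimal sketch of this usage; the architecture, input width, and sizes here are made up for illustration:

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Toy 10-class classifier (hypothetical architecture for illustration).
init_fn, apply_fn, _ = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(10))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 32))   # n = 64 datapoints, 32 features
_, params = init_fn(key, x.shape)

# Default trace_axes=(-1,) traces over the logit axis -> kernel of shape (64, 64).
# trace_axes=() keeps the logit axes -> kernel of shape (64, 64, 10, 10).
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=())
kernel = ntk_fn(x, None, params)       # x2=None computes the kernel of x with itself
print(kernel.shape)                    # (64, 64, 10, 10)
```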
Hey @romanngg - thanks for the quick response. That helped! I tried to play with the `trace_axes` option. Shouldn't I be expecting something along the lines of 640 x 640? How would I extrapolate from this to what I need? I'm sure I'm getting something wrong here; I would appreciate it if you could elaborate.
To clarify - you need to set `trace_axes=()`, i.e. an empty tuple. In general, if your outputs have shape `(n, k)`, then the output kernel will have shape `(n, n, k, k)`, BUT it will have pairs of axes, with each pair carrying a subscript 1 and 2 for the first and second inputs respectively, i.e. `(n1, n2, k1, k2)`. There is also a `diagonal_axes` argument to keep only the diagonal along chosen output axes.
Ah - that definitely helps. With the empty parens, I get an output kernel with shape `(64, 64, 10, 10)`. Does that sound alright to you?
Yep, correct!
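To then get the flat `640 x 640` matrix asked about above, one can interleave the datapoint and logit axes and reshape. A minimal sketch, reusing `kernel` from the sketch above; the `(n1, n2, k1, k2)` ordering follows the axis-pair convention described earlier:

```python
import jax.numpy as jnp

n, k = 64, 10
# kernel[i, j, a, b] couples (datapoint i, logit a) with (datapoint j, logit b),
# so reorder (n1, n2, k1, k2) -> (n1, k1, n2, k2) and flatten each half.
flat = jnp.transpose(kernel, (0, 2, 1, 3)).reshape(n * k, n * k)
print(flat.shape)  # (640, 640)
```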
Awesome. Thanks! PS: I was trying to tag this as a question, but I wasn't able to.
Also, I just went from a batch of 64 to a batch of 256. Computing the output empirical NTK for n = 256 and k = 10 takes about 7 times longer (4.12 sec -> 27 sec). How would you suggest I optimize this? If I have to compute the empirical NTK at every epoch (and not only at the beginning of training), this might take quite long.
Likely only repo owners are allowed to do this; I don't see an option to let users set it... Sadly the increased time is expected (in fact it would grow quadratically with batch size). See #30 for ongoing discussion about performance. Depending on your application, you may want to compute the empirical NTK with a single output logit to gain yourself an extra factor of `k**2` in kernel size (an `n` x `n` kernel instead of `nk` x `nk`).
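A minimal sketch of the single-logit variant, reusing `apply_fn`, `x`, and `params` from the sketch above; the slicing wrapper is a hypothetical helper, not a library API:

```python
# Restrict the network to a single output logit before taking the NTK,
# so the kernel is (64, 64) rather than (64, 64, 10, 10).
def single_logit_fn(params, x):
    return apply_fn(params, x)[..., :1]  # keep only logit 0

ntk_single_fn = nt.empirical_ntk_fn(single_logit_fn, trace_axes=(-1,))
kernel_single = ntk_single_fn(x, None, params)  # shape (64, 64)
```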
Okay - what you're saying is that in the infinite width/sample limit, we may be justified in computing the `n` x `n` kernel for a single logit. How do you think that would change for sparse layers/networks, though? Wouldn't the empirical NTK be more accurate for sparse networks if computed across all logits, as compared to only a single logit? Maybe there's no answer to that question yet.

Re: #30, I did comment on that earlier yesterday. I'm not sure I follow the method of accumulating the empirical output NTK, since that ignores cross-layer weight covariances. Unless we're doing that on purpose, which probably translates to accumulating the diagonal matrix (same as what we do here above using the `diagonal_axes` option) for a single logit.
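For concreteness, a minimal sketch of the `diagonal_axes` option referenced here, again reusing `apply_fn`, `x`, and `params` from the sketches above:

```python
# Keep only the logit-diagonal entries of the output kernel:
# shape (64, 64, 10) instead of (64, 64, 10, 10).
ntk_diag_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(), diagonal_axes=(-1,))
kernel_diag = ntk_diag_fn(x, None, params)
print(kernel_diag.shape)  # (64, 64, 10)
```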
Re accuracy, it likely indeed depends on how precisely you measure accuracy. For example, if your measure is how close the empirical kernel is, in Frobenius norm, to the infinite-width, infinite-sample NTK, or to the infinite-sample but finite-width NTK, I think you would still get better accuracy / FLOPs using a single logit rather than multiple. If you want the best linear approximation to your finite-width network, then having multiple logits is more accurate.

Re #30, I have not looked into the per-layer implementation, but AFAIK "cross-layer weight covariances" are not needed in the NTK; in fact, no cross-weight covariances are present in the expression Θ(x1, x2) = Σ_θ (∂f(x1; θ)/∂θ) (∂f(x2; θ)/∂θ)^T, which is a plain sum over individual parameters θ and hence decomposes into per-layer terms with no cross-layer contributions.

Also, I've just pushed f15b652, which should make computing the empirical NTK faster with respect to the number of output logits.
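A small sketch verifying this decomposition with plain JAX autodiff: the full NTK equals the sum of per-parameter-leaf (hence per-layer) Jacobian contractions, with no cross-leaf terms. `ntk_by_layer` is a hypothetical helper; `f` is any `f(params, x)`, e.g. the `apply_fn` from the sketches above:

```python
import jax
import jax.numpy as jnp

def ntk_by_layer(f, params, x1, x2):
    # Jacobians w.r.t. params: pytrees with leaves of shape (n, k, *param.shape).
    j1 = jax.jacobian(f)(params, x1)
    j2 = jax.jacobian(f)(params, x2)
    leaves1 = jax.tree_util.tree_leaves(j1)
    leaves2 = jax.tree_util.tree_leaves(j2)

    # Contract each parameter leaf separately over its parameter axes; the NTK
    # is the plain sum of these per-leaf terms, so no cross-leaf (and hence no
    # cross-layer) covariances ever appear.
    def contract(a, b):
        axes = list(range(2, a.ndim))
        return jnp.tensordot(a, b, axes=(axes, axes))  # shape (n1, k1, n2, k2)

    return sum(contract(a, b) for a, b in zip(leaves1, leaves2))

# e.g. ntk_by_layer(apply_fn, params, x, x) matches
# nt.empirical_ntk_fn(apply_fn, trace_axes=())(x, None, params) up to axis order.
```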