
How to calculate empirical NTK of model being used in a classification task #68

Open
uditsaxena opened this issue Sep 10, 2020 · 10 comments
Labels: question (Further information is requested)

uditsaxena commented Sep 10, 2020

For a model being used for classification with k classes, on n datapoints, the NTK should be of size nk x nk. How would we get that with neural-tangents?

Currently, I'm able to get an n x n matrix.

romanngg (Contributor) commented
Set trace_axes=() (see more details on this argument in https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html - happy to elaborate if needed!)
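For example, a minimal sketch (the MLP architecture, input shape, and random key below are placeholders, not from this thread):

```python
# Minimal sketch (hypothetical model and shapes): keep the logit axes by
# passing trace_axes=() to nt.empirical_ntk_fn.
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

# Toy classifier: n = 64 datapoints, k = 10 logits.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))
key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 784))
x = random.normal(key, (64, 784))

ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=())
ntk = ntk_fn(x, None, params)  # shape (64, 64, 10, 10), i.e. n x n x k x k
```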

uditsaxena (Author) commented
Hey @romanngg - thanks for the quick response. That helped!

I tried playing with the trace_axes option.
For n = 64 and k = 10, setting trace_axes to:

  • 0 yields a 10 x 10 matrix
  • 1 yields a 64 x 64 matrix

Shouldn't I be expecting something along the lines of 640 x 640? How would I extrapolate from this to what I need?

I'm sure I'm getting something wrong here. I would appreciate it if you could elaborate.
Thanks!

romanngg (Contributor) commented
To clarify - you need to set trace_axes=() (empty tuple, not 0 or 1).

In general, if your outputs f1 and f2 have shapes
(N1_0, N1_1, N1_2, ..., N1_K),
(N2_0, N2_1, N2_2, ..., N2_K),

then the output kernel will have shape
(N1_0, N2_0, N1_1, N2_1, N1_2, N2_2, ..., N1_K, N2_K) (not 2D, but rather 2(K+1)-D),

BUT each pair of axes with subscript i will be missing if i is in trace_axes (and a similar mechanism applies to diagonal_axes, which keeps a single axis per pair). Lmk if this helps!
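Continuing the toy sketch above (output axis 0 = datapoints, axis 1 = logits), the effect on shapes would be roughly:

```python
# Same apply_fn, params, x as in the sketch above.
full = nt.empirical_ntk_fn(apply_fn, trace_axes=())(x, None, params)
traced = nt.empirical_ntk_fn(apply_fn, trace_axes=(1,))(x, None, params)
diag = nt.empirical_ntk_fn(apply_fn, trace_axes=(), diagonal_axes=(1,))(x, None, params)

print(full.shape)    # (64, 64, 10, 10): every axis pair kept
print(traced.shape)  # (64, 64): the logit axis pair is traced out
print(diag.shape)    # (64, 64, 10): the logit pair collapsed to its diagonal
```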

uditsaxena (Author) commented Sep 10, 2020

Ah - that definitely helps.

With the empty tuple, I get an output kernel of shape (n, n, k, k). The way I understand it: for an entry (i, j, a, b), the pair (i, j) indexes datapoints i and j, and the k x k submatrix at (i, j) refers to their logits.
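For example, the nk x nk block matrix from the original question can be assembled from this (n, n, k, k) layout (a sketch reusing the ntk array from above; the interleaving is my own illustration, not from this thread):

```python
# View the (n, n, k, k) kernel as an (n*k, n*k) block matrix by interleaving
# datapoint and logit indices: entry ((i, a), (j, b)) = ntk[i, j, a, b].
import numpy as np

n, _, k, _ = ntk.shape
ntk_2d = np.transpose(ntk, (0, 2, 1, 3)).reshape(n * k, n * k)  # (640, 640)
```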

Does that sound alright to you?

romanngg (Contributor) commented
Yep, correct!

uditsaxena (Author) commented
Awesome. Thanks!

PS: I was trying to tag this as a question, but I wasn't able to.

romanngg added the question label on Sep 10, 2020
uditsaxena (Author) commented Sep 10, 2020

Also, I just went from a batch of 64 to a batch of 256. Computing the output empirical NTK for n = 256 and k = 10 takes about 7 times longer (4.12 sec -> 27 sec).

How would you suggest I optimize this? If I have to run this for many epochs, calculating the empirical NTK at each epoch (and not only at the beginning of training) might take a while.
All options are on the table for now.

romanngg (Contributor) commented
Likely only repo owners are allowed to do this; I don't see an option to let users set labels...

Sadly, the increased time is expected (in fact, it grows quadratically with batch size). See #30 for the ongoing discussion about performance.

Depending on your application, you may want to compute the empirical NTK with a single output logit to save an extra factor of k x k. Note that in many use cases (e.g. if you have a stax.Dense layer on top), all those k x k tensor slices will converge to a constant-diagonal matrix in the infinite-width/sample limit, so you may be justified in computing the NTK for a single logit only, if your goal is to approximate the infinite-width/sample behavior.
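Concretely, that could look something like this (a sketch; single_logit_fn is a hypothetical wrapper around the toy apply_fn above):

```python
# Hypothetical wrapper: restrict the network to a single output logit before
# taking the NTK, saving roughly a factor of k x k in compute.
def single_logit_fn(params, x):
    return apply_fn(params, x)[:, :1]  # keep only logit 0

ntk_single = nt.empirical_ntk_fn(single_logit_fn, trace_axes=())(x, None, params)
print(ntk_single.shape)  # (64, 64, 1, 1)
```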

uditsaxena (Author) commented
Okay - so what you're saying is that in the infinite-width/sample limit, we may be justified in computing the n x n matrix (i.e. the kernel for a single logit) instead of the whole n x n x k x k matrix, since the k x k slices converge to the same constant-diagonal matrix.
Got it - that should help with the speedup.

How do you think that would change for sparse layers/networks, though? Wouldn't the empirical NTK be more accurate for sparse networks if computed across all logits rather than only a single logit? Maybe there's no answer to that question yet.

Re: #30, I did comment on that earlier yesterday. I'm not sure I follow the method of accumulating the empirical output NTK, since that ignores cross-layer weight covariances. Unless we're doing that on purpose, which probably translates to accumulating the diagonal matrix (same as what we do above with the diagonal_axes option) for a single logit.

romanngg (Contributor) commented
Re accuracy, it likely depends on how exactly you measure accuracy. For example, if your measure is how close the empirical kernel is, in Frobenius norm, to the infinite-width, infinite-sample NTK (or the finite-width, infinite-sample NTK), I think you would still get better accuracy per FLOP with a single logit than with multiple. If you want the best linear approximation to your finite-width network, then having multiple logits is more accurate.

Re #30, I have not looked into the per-layer implementation, but AFAIK "cross-layer weight covariances" are not needed in the NTK; in fact, no cross-weight covariances appear in the expression at all:

df/dp(p, x1) df/dp(p, x2)^T = sum over every individual scalar w in p of [df/dw(p, x1) * df/dw(p, x2)],

so you only need covariances between gradients of the outputs w.r.t. the same weight, at different x1 and x2 (in contrast to df/dp(p, x1)^T df/dp(p, x2), which would be a #p x #p matrix of cross-weight covariances). But perhaps this is a trivial point and not what you mean; I may need to look at their code more to understand it better...
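As an illustration only (a sketch against the toy apply_fn from earlier, not this library's internals; ntk_by_hand is an illustrative name), the identity can be checked by summing same-weight Jacobian products:

```python
# Sketch: compute the NTK "by hand" from per-parameter Jacobians to see that
# only same-weight products appear -- there are no cross-weight terms.
import jax
import jax.numpy as jnp

def ntk_by_hand(params, x1, x2):
    # df/dp for each input batch; one leaf per parameter array in the pytree.
    j1 = jax.jacobian(lambda p: apply_fn(p, x1))(params)
    j2 = jax.jacobian(lambda p: apply_fn(p, x2))(params)
    # Sum over every scalar weight w of df/dw(x1) * df/dw(x2): contract each
    # pair of Jacobian leaves over all parameter axes, then add them up.
    kern = sum(jnp.tensordot(a, b, axes=(tuple(range(2, a.ndim)),) * 2)
               for a, b in zip(jax.tree_util.tree_leaves(j1),
                               jax.tree_util.tree_leaves(j2)))
    # Reorder (n1, k, n2, k) -> (n1, n2, k, k) to match neural-tangents.
    return jnp.transpose(kern, (0, 2, 1, 3))
```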

Also, I've just pushed f15b652, which should make computing the empirical NTK faster with respect to n, especially for CNNs, so hopefully this will also help (you'd need to pass something like vmap_axes=0 to your nt.empirical_ntk_fn - see https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html for more details).
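With the toy model above, that would look something like:

```python
# vmap_axes=0 declares that inputs and outputs are batched along axis 0,
# letting neural-tangents vectorize the Jacobian computation over the batch.
ntk_fn_fast = nt.empirical_ntk_fn(apply_fn, trace_axes=(), vmap_axes=0)
ntk_fast = ntk_fn_fast(x, None, params)  # same (64, 64, 10, 10) kernel, faster
```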
