Calculating Marginal log-likelihood of NNGP #81

Closed
TZeng20 opened this issue Oct 26, 2020 · 7 comments
Labels
question Further information is requested

Comments

@TZeng20

TZeng20 commented Oct 26, 2020

Hi,

I've been using the neural-tangents library a lot over the past few months; it's been extremely helpful.

I just had a question about calculating the marginal log-likelihood for NNGPs, which I came across when reading the Neural Tangents library paper (e.g. Figure 3, Figure 7).

I have tried to calculate the NLL on the CIFAR-10 dataset as well and have linked my Jupyter notebook. The problem I'm getting is that as I increase the training set size, the training NLL increases as well, which is the opposite of the results in the paper. Could you please point out the errors in my code/calculation, or perhaps share the code for the calculations?

Thanks

@jaehlee
Collaborator

jaehlee commented Oct 30, 2020

Hi there! Thanks for trying out our library!

As far as I can tell, the notebook looks good to me. My sense is that you should compare the 'mean' training NLL when comparing different dataset sizes, which is what we do in the paper. Recall that if you have N independent, identically distributed random variables, the negative log-likelihood of their joint distribution scales with N. So if you want to compare across different sizes, you should normalize by the size. Let us know if this reproduces the decreasing average training NLL.
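For concreteness, here is a minimal sketch of the per-example (mean) NLL; this is not the notebook's exact code, and `k_train_train` / `y_train` are placeholder names for the (n, n) train NNGP kernel and the (n, C) targets:

```python
import jax.numpy as np

def mean_train_nll(k_train_train, y_train):
  """Per-example negative log marginal likelihood of the NNGP prior.

  Each of the C output classes is treated as an independent GP draw with
  the shared (n, n) kernel `k_train_train`; `y_train` has shape (n, C).
  The joint NLL grows with n, so dividing by n makes different dataset
  sizes comparable.
  """
  n, c = y_train.shape
  alpha = np.linalg.solve(k_train_train, y_train)   # K^{-1} y, shape (n, C)
  data_fit = 0.5 * np.sum(y_train * alpha)          # 0.5 * sum_c y_c^T K^{-1} y_c
  _, logdet = np.linalg.slogdet(k_train_train)
  complexity = 0.5 * c * logdet                     # 0.5 * C * log|K|
  const = 0.5 * n * c * np.log(2 * np.pi)
  return (data_fit + complexity + const) / n        # normalize by dataset size n
```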

@jaehlee
Collaborator

jaehlee commented Oct 30, 2020

Sorry, I noticed that you do normalize in your notebook, so I'll take a deeper look to see if there's anything else.

@TZeng20
Author

TZeng20 commented Oct 30, 2020

Hi @jaehlee, thanks a lot. Here is the Colab notebook in case it's easier to reproduce.

@jaehlee
Collaborator

jaehlee commented Oct 31, 2020

Thanks for sharing. I think I have identified a few differences. As an aside, there are a few efficiency tricks to make the code run faster: using jax.numpy and jax.scipy and running on accelerators, and, for example, not doing a Cholesky decomposition of the full 10n-by-10n matrix but instead exploiting its Kronecker structure.

Here's an extension of your notebook to include some of our computation.

A few things to notice: at the bottom, the NLL computation is quite sensitive to the regularization strength. So should you use the unregularized kernel for computing the NLL, or the optimal regularization strength? I think in GP model selection you ought to model the output noise strength via diagonal regularization of the kernel, and do the parameter selection using the training set NLL.

The plots in the paper use the optimal regularization strength. In the Colab we compare the case where we tune the regularization strength vs. setting it to zero. The latter case shows the behavior you observed (modulo a possible factor from the number of classes C), but for the former we see the NLL decreasing with the dataset size.

Let me know if the notebook makes sense to you, and please point out anything that looks incorrect.
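As a rough illustration of the points above (a sketch under assumptions, not the notebook's exact code): because the full output covariance is K ⊗ I_C, a single Cholesky of the (n, n) train kernel is shared by all C classes instead of factoring the 10n-by-10n matrix, and the diagonal regularization (the output-noise strength) can then be swept to minimize the train NLL. The scaling of the regularizer by the mean kernel diagonal and the variable names are assumptions.

```python
import jax.numpy as np
import jax.scipy as sp

def regularized_mean_nll(k_dd, y_train, diag_reg):
  """Mean train NLL with diagonal regularization (output-noise model).

  `k_dd` is the (n, n) train/train NNGP kernel, `y_train` the (n, C)
  targets. The Kronecker structure K ⊗ I_C means one (n, n) Cholesky
  serves all C classes.
  """
  n, c = y_train.shape
  reg = diag_reg * np.trace(k_dd) / n     # scale reg by mean kernel diagonal (assumed convention)
  chol = sp.linalg.cholesky(k_dd + reg * np.eye(n), lower=True)
  alpha = sp.linalg.cho_solve((chol, True), y_train)      # (K + reg I)^{-1} y, shape (n, C)
  data_fit = 0.5 * np.einsum('ik,ik->', y_train, alpha)   # quadratic term summed over classes
  complexity = c * np.sum(np.log(np.diag(chol)))          # 0.5 * C * log|K + reg I|
  const = 0.5 * n * c * np.log(2 * np.pi)
  return (data_fit + complexity + const) / n

# Pick the regularization strength that minimizes the train NLL, e.g.:
#   regs = 10. ** np.arange(-7., 2.)
#   best = min(regs.tolist(), key=lambda r: regularized_mean_nll(k_dd, y_train, r))
```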

@TZeng20
Author

TZeng20 commented Nov 2, 2020

Thanks a lot for the notebook! When I used a diag reg of around eps = 1 with my original function, the NLL was also decreasing with the training set size. I had some additional questions about the notebook:

  • When comparing the marginal NLL across different architectures, do the covariance matrices need to be scaled down (e.g. by dividing by the maximum)?
  • Also, how are you able to rewrite the first term of the log-likelihood as np.einsum('ik,ik->k', t, alpha)? Is this related to a property of the Kronecker product?

I'm also curious about the relationship between the marginal NLL and accuracy. I know that in some of the papers you have written, you tune the diagonal reg to maximise accuracy on a validation set. Does this value also minimise the marginal NLL? And is there a reason why you use accuracy rather than minimising the marginal NLL?

@jaehlee
Collaborator

jaehlee commented Nov 6, 2020

  • One should not scale down the covariance matrix, since the scale itself captures the uncertainty of the output. However, one way to think about changing the scale in a GP is as temperature scaling (https://arxiv.org/abs/2010.07355, https://arxiv.org/abs/2008.00029).

  • The assumption is that different output classes are independent, so one can just sum over them for the log-likelihood. The independence is a result of the infinite-width limit, which gives the full output covariance a Kronecker product structure (see the sketch below).
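A small hypothetical example of why that einsum gives the per-class quadratic term (toy shapes and names `k`, `t`, `alpha` chosen to mirror the expression above):

```python
import jax.numpy as np
from jax import random

# Toy shapes for illustration: n = 4 train points, C = 3 classes.
key_k, key_t = random.split(random.PRNGKey(0))
a = random.normal(key_k, (4, 4))
k = a @ a.T + 1e-3 * np.eye(4)     # stands in for the (n, n) train NNGP kernel
t = random.normal(key_t, (4, 3))   # stands in for the (n, C) targets

# Because the full output covariance is K ⊗ I_C, its inverse is K^{-1} ⊗ I_C
# and the quadratic form decouples into a sum over classes:
#   vec(t)^T (K ⊗ I_C)^{-1} vec(t) = sum_k t[:, k]^T K^{-1} t[:, k].
alpha = np.linalg.solve(k, t)                  # one solve of K, shared by all classes
per_class = np.einsum('ik,ik->k', t, alpha)    # per-class quadratic terms, shape (C,)
total = np.sum(per_class)                      # data-fit term summed over classes
```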

While we have not extensively studied the relationship between marginal NLL and accuracy, we believe validation accuracy is the better metric to use if the goal is to find models that do well on test accuracy. In small comparisons, the optimal diag reg or other hyperparameters are in a similar ballpark whether chosen by NLL or by validation accuracy, but different enough to change the test accuracy. From a more principled, model-selection point of view, selection on marginal NLL does make more sense, but for finding a model that performs well on a particular metric, validation is the way to go.

@TZeng20
Author

TZeng20 commented Nov 7, 2020

Great, appreciate your help!

@TZeng20 TZeng20 closed this as completed Nov 7, 2020
@romanngg romanngg added the question Further information is requested label Nov 17, 2021