Analytic kernel evaluated on sparse inputs #14
Hi!

A bug seems to occur when trying to evaluate analytic NTKs using sparse input data -- the evaluated kernel contains nan entries. It can be reproduced with a few lines of code (see the sketch below); the output is a kernel matrix filled with nan values.

Thanks for your time in advance!
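The original reproduction script is not quoted in this thread. Based on the description in the replies below (sparse inputs drawn from a Gaussian distribution and then truncated based on magnitude, fed through a convolutional kernel_fn), a minimal sketch that triggers the same nan entries might look like the following; the input shape, truncation threshold, and architecture are assumptions:

```python
import numpy as np
from neural_tangents import stax

# Sparse inputs: draw from a Gaussian, then zero out small-magnitude entries.
# The 3 x (8, 8, 1) shape and the 1.0 threshold are illustrative guesses.
rng = np.random.RandomState(0)
x = rng.randn(3, 8, 8, 1)
x_sparse = np.where(np.abs(x) > 1.0, x, 0.0)

# A small convolutional architecture (assumed; any Conv + Relu stack suffices).
_, _, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

print(kernel_fn(x_sparse, x_sparse, 'ntk'))  # contained nan before the fix
```

Comments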
Thanks for the report! The issue is that if the variance of an input (here meaning a single pixel of a single datapoint) is zero, it can cause NaNs. There is a simple fix: add a small stability term to our normalization. I expect us to push a fix today.
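A rough illustration of what such a stability term does (a sketch, not the library's actual code; the helper name and eps value are made up):

```python
import numpy as np

def normalized_cosine(cov12, var1, var2, eps=1e-12):
    # Hypothetical helper: arccosine-kernel nonlinearities like Relu divide
    # the covariance by sqrt(var1 * var2). If a pixel's variance is exactly
    # zero, the bare division produces nan; a small stability term keeps
    # the result finite (and ~0 wherever the signal is zero).
    return cov12 / np.sqrt(var1 * var2 + eps)
```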
Thanks so much for your reply! I am not sure whether the input variance plays a major role here. Indeed, in the example above, the sparse inputs are drawn from a Gaussian distribution and then truncated based on magnitude, so their values should be spread symmetrically around 0. But I also found a similar nan problem with non-negative sparse inputs. The script below shows this phenomenon with MNIST images:

```python
import tensorflow as tf
import numpy as np
from neural_tangents import stax

mnist = tf.keras.datasets.mnist
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0  # normalize the input values to [0, 1]
x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1])  # sparse input samples

# Standardize the data.
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std
x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples

# A CNN architecture.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

print('NTK evaluated w/ sparse MNIST images: \n',
      kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk'))  # the output contains nan
print('NTK evaluated w/ dense, standardized MNIST images: \n',
      kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk'))  # the output looks fine
```

The output of the above script should be like:

```
NTK evaluated w/ sparse MNIST images:
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images:
[[1.1637697  0.6116184  0.21783468]
 [0.6116184  1.3009455  0.213599  ]
 [0.21783468 0.213599   0.79291606]]
```

So, without standardization, the sparse MNIST images seem to cause the nan problem; the standardized MNIST images with zero mean actually seem to solve it. Thanks for your time!

Thanks Tianlin. The problem is caused by the fact that the per-pixel variance is zero for many pixels in a single input of MNIST (or of sparse inputs in general). When computing the NTK of a CNN, we need to keep track of the variance of each pixel in each input. You should be able to see this by computing the NTK right after the conv layer (see the sketch below): there are many zero terms in it. Subtracting the mean eliminates the zeros in the pixels and fixes this issue.
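A quick way to see those zeros, reusing x_train_subset_sparse from the script above (a sketch; the exact shape of the returned kernel may vary across library versions):

```python
from neural_tangents import stax

# Kernel right after the conv layer only: for sparse inputs, many of its
# entries are exactly zero, which later breaks the Relu normalization.
_, _, conv_kernel_fn = stax.serial(stax.Conv(128, (3, 3)))
print(conv_kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk'))
```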
@SiuMath @sschoenholz Many thanks for your explanations! I had previously misunderstood the variance we are talking about here as the one defined across multiple samples for a single pixel :)
FYI, I believe Sam has fixed it in 0e92b0f! The commit message reads: "…ctly zero, which implied zero variance. Solved by using a safe inverse sqrt for normalization in the activation functions. PiperOrigin-RevId: 287355386"
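For reference, a "safe inverse sqrt" typically follows this pattern (a sketch, not the actual code from 0e92b0f; the function name and tolerance are assumptions):

```python
import numpy as np

def safe_rsqrt(x, tol=1e-30):
    # Where the variance x is (near) zero, return 0 instead of dividing by
    # zero, so zero-variance pixels propagate finite values rather than nan.
    x_ok = np.where(x > tol, x, 1.0)  # dummy value where x is ~0
    return np.where(x > tol, 1.0 / np.sqrt(x_ok), 0.0)
```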
@romanngg many thanks!!