
Analytic kernel evaluated on sparse inputs #14

Closed
liutianlin0121 opened this issue Dec 19, 2019 · 6 comments

Comments

@liutianlin0121

liutianlin0121 commented Dec 19, 2019

Hi!

A bug seems to occur when I evaluate analytic NTKs on sparse input data -- the evaluated kernel contains nan entries. This can be reproduced with the following lines of code:

from jax import random
from neural_tangents import stax

key = random.PRNGKey(1)

# a batch of dense inputs 
x_dense = random.normal(key, (3, 32, 32, 3))

# a batch of sparse inputs 
x_sparse = x_dense * (abs(x_dense) > 1.2)


# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
     stax.Conv(128, (3, 3)),
     stax.Relu(),
     stax.Flatten(),
     stax.Dense(10) )

# Evaluate the analytic NTK upon dense inputs

print('NTK evaluated w/ dense inputs: \n', kernel_fn(x_dense, x_dense, 'ntk')) # the outputs look fine.

print('\n')

# Evaluate the analytic NTK upon sparse inputs

print('NTK evaluated w/ sparse inputs: \n', kernel_fn(x_sparse, x_sparse, 'ntk')) # the output contains nan

The output of the above script should be:

NTK evaluated w/ dense inputs: 
 [[0.97102666 0.16131128 0.16714054]
 [0.16131128 0.9743941  0.17580226]
 [0.16714054 0.17580226 1.0097454 ]]


NTK evaluated w/ sparse inputs: 
 [[       nan        nan        nan]
 [       nan 0.66292834        nan]
 [       nan        nan        nan]]

Thanks for your time in advance!

@sschoenholz
Contributor

Thanks for the report! The issue is that if the variance of an input (here taken to mean a single pixel of a single datapoint) is zero, it can cause NaNs. There is a simple fix: add a small stability term to our normalization. I expect us to push a fix today.
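A minimal sketch of the failure mode and the proposed stabilization (the helper names here are illustrative, not the library's actual internals):

import numpy as np

def inv_sqrt(var):
    # Naive normalization: 1/sqrt(var) is inf when var == 0,
    # and inf * 0 downstream produces nan.
    return 1.0 / np.sqrt(var)

def safe_inv_sqrt(var, eps=1e-20):
    # Stabilized version: a small additive term keeps
    # zero-variance entries finite instead of inf.
    return 1.0 / np.sqrt(var + eps)

var = np.array([1.0, 0.0])       # second pixel has zero variance
print(inv_sqrt(var) * var)       # [1., nan] -- inf * 0 = nan
print(safe_inv_sqrt(var) * var)  # [1., 0.] -- finite everywhere

Multiplying the normalized quantity back by the variance (as happens inside the kernel recursion) is exactly where the inf * 0 = nan appears, which matches the all-nan rows seen for sparse inputs.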

@liutianlin0121
Author

Thanks so much for your reply!

I am not sure whether the input variance plays a major role here. In the example above, the sparse inputs are drawn from a Gaussian distribution and then truncated by magnitude, so their values should spread symmetrically around 0. But I found a similar nan problem with non-negative sparse inputs. The script below shows this phenomenon with MNIST images:


import tensorflow as tf
import numpy as np
from jax import random
from neural_tangents import stax

mnist = tf.keras.datasets.mnist

(x_train, _), (_, _) = mnist.load_data()


x_train = x_train / 255.0 # normalize the input values to the range [0, 1]

x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1]) # sparse input samples. 


# standardize the data 
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std

x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples

# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
     stax.Conv(128, (3, 3)),
     stax.Relu(),
     stax.Flatten(),
     stax.Dense(10) )

print('NTK evaluated w/ sparse MNIST images: \n', kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk')) # the output contains nan

print('NTK evaluated w/ dense, standardized MNIST images: \n', kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk')) # the output looks fine

The output of the above script should be like:

NTK evaluated w/ sparse MNIST images: 
 [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images: 
 [[1.1637697  0.6116184  0.21783468]
 [0.6116184  1.3009455  0.213599  ]
 [0.21783468 0.213599   0.79291606]]

So, without standardization, the sparse MNIST images cause the nan problem, while standardizing the images to zero mean (which makes them dense) avoids it.

Thanks for your time!

@SiuMath
Contributor

SiuMath commented Dec 20, 2019 via email

@liutianlin0121
Author

@SiuMath @sschoenholz Many thanks for your explanations! I had previously misunderstood the variance in question as the one defined across multiple samples for a single pixel :)
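To make the distinction concrete: the relevant "variance" is the per-datapoint, per-pixel second moment entering the kernel recursion, which is exactly zero wherever a pixel is zero. A quick numpy check on inputs built like the first example (a sketch mirroring the jax script, not the library's internals):

import numpy as np

rng = np.random.RandomState(1)
x_dense = rng.normal(size=(3, 32, 32, 3))
x_sparse = x_dense * (np.abs(x_dense) > 1.2)

# Fraction of pixels whose per-pixel "variance" (squared value)
# is exactly zero -- roughly 0.77 for this threshold, since
# P(|Z| <= 1.2) is about 0.77 for a standard normal.
print((x_sparse ** 2 == 0).mean())
print((x_dense ** 2 == 0).mean())  # 0.0: dense inputs have no exact zeros

Every one of those exact zeros is a zero-variance entry that the unstabilized normalization turns into a nan, which is why the sparse kernel matrix is almost entirely nan.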

romanngg pushed a commit that referenced this issue Jan 16, 2020
…ctly zero, which implied zero variance. Solved by using a safe inverse sqrt for normalization in the activation functions.

PiperOrigin-RevId: 287355386
@romanngg
Contributor

FYI, I believe Sam has fixed it in 0e92b0f!

@liutianlin0121
Author

@romanngg many thanks!!
