Question: simple example poor performance, what am I doing wrong? #163

Open
mfouesneau opened this issue Aug 24, 2022 · 5 comments
Labels
question Further information is requested

Comments

@mfouesneau

Dear team, great package, I'm very excited to use it.

However, I tried a simple case and failed miserably to get decent performance.

I generate a multi-dimensional dataset with a relatively simple feature

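import numpy as np

# Create some fake data: y depends only on feature 6, plus Gaussian noise
np.random.seed(0)
m = 1000          # number of samples
n = 10            # number of features
noise_std = 1.
X = 80 * np.random.uniform(size=(m, n)) - 40   # uniform on [-40, 40]
y = np.abs(X[:, 6] - 4.0) + noise_std * np.random.normal(size=m)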

And I followed your examples as

import neural_tangents as nt
from neural_tangents import stax
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    X, y.reshape(-1, 1), test_size=0.4, random_state=42)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(256), stax.Relu(),
    stax.Dense(1)
)
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, 
    x_train,
    y_train)

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'), compute_cov=True)

Visual inspection shows terrible predictions, and loss values are large:

import jax.numpy as jnp

# Half mean squared error between predictions and targets
loss = lambda ypred, y_hat: 0.5 * jnp.mean((ypred - y_hat) ** 2)
print("loss_nngp = {}".format(loss(y_test_nngp.mean, y_test)))
print("loss_ntk = {}".format(loss(y_test_ntk.mean, y_test)))

Output:

loss_nngp = 6.877374649047852
loss_ntk = 6.610106468200684

I varied the network in many ways and fiddled with learning_rate and diag_reg, but the results hardly changed.

I'm sure I am doing something wrong, but I cannot see what it is. Any obvious mistake?

Thanks for your help.

@romanngg
Contributor

At a glance, the library usage looks good to me! One way to figure this out is to establish a baseline with some other method (kernel, neural network, etc.) to see what loss values to expect. For example, y should have a mean of about 20: x_6 - 4 is uniform on [-44, 36], so the expectation of its absolute value is (44/80) * 44/2 + (36/80) * 36/2 = 20.2. The scale of the outputs is therefore fairly large, and it's not obvious to me that these loss values are bad.

Another angle is to increase the training set size; it's hard to say whether 600 training points are enough for the model to learn well.
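
To make the baseline suggestion concrete, here is a minimal numpy sketch (not from the thread). It regenerates data with the same recipe as the original post and computes two reference values for the half-MSE loss: a constant predictor that always outputs the training mean, and an "oracle" that knows the true function and is limited only by the noise. The helper name `half_mse` and the variable names are assumptions made for illustration. Under this data model the oracle loss should be close to 0.5 * noise_std**2 = 0.5, and the constant predictor should land on the order of 70 (half the variance of y), so the reported losses of about 6.6 to 6.9 sit between the two.

```python
import numpy as np

# Same data recipe as in the original post (local RandomState instead of global seed).
rng = np.random.RandomState(0)
m, n, noise_std = 1000, 10, 1.0
X = 80 * rng.uniform(size=(m, n)) - 40
y = np.abs(X[:, 6] - 4.0) + noise_std * rng.normal(size=m)


def half_mse(y_pred, y_true):
    """Half mean squared error, the loss convention used in the post."""
    return 0.5 * np.mean((y_pred - y_true) ** 2)


# Analytic mean of |U| for U ~ Uniform(-a, b) with a = 44, b = 36.
a, b = 44.0, 36.0
expected_abs = (a ** 2 + b ** 2) / (2 * (a + b))

# Baseline 1: always predict the sample mean of y.
mean_baseline = half_mse(np.full_like(y, y.mean()), y)

# Baseline 2: noise-free ground truth, i.e. the best any model could do.
oracle_loss = half_mse(np.abs(X[:, 6] - 4.0), y)

print("analytic E|x_6 - 4| =", expected_abs)
print("sample mean of y    =", y.mean())
print("mean-predictor loss =", mean_baseline)
print("oracle (noise) loss =", oracle_loss)
```

Taking the square root of twice a loss value gives an RMSE in the units of y, which is often easier to compare against the plot.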

romanngg added the question (Further information is requested) label on Aug 24, 2022

@mfouesneau
Author

mfouesneau commented Aug 24, 2022

import matplotlib.pyplot as plt

# Sort the test points by x_6 so the band is drawn left to right
args = np.argsort(x_test[:, 6])
y_mean = np.reshape(y_test_ntk.mean, (-1,))[args]
y_std = np.sqrt(np.diag(y_test_ntk.covariance))[args]

# All data in black, NTK predictive band (mean +/- 3 std) in red
plt.plot(X[:, 6], y, 'k.', alpha=0.1, rasterized=True)
plt.fill_between(
    np.reshape(x_test[args, 6], (-1,)),
    y_mean - 3 * y_std,
    y_mean + 3 * y_std,
    color='red', alpha=0.2)
plt.xlabel('x_6')
plt.ylabel('y')

image

The thing is that changing the layer from 50 nodes to 5000 hardly changes the output. I would expect at least some changes.

I tried 10_000 points, and the loss only improved by a factor of 2.
image

Is there any guidance on what a correct training set should be?

@mfouesneau
Author

I get memory errors if I try 100 000 points in my dataset, even with the batch trick:

kernel_fn = nt.batch(kernel_fn,
                     device_count=0,
                     batch_size=1_000)

@romanngg
Contributor

Note that in your example you are doing inference with an infinitely wide neural network (kernel_fn), so the width doesn't matter in this case. Also, the plot does look like the learned function mimics |x_6 - 4| (at least it's not doing something obviously wrong; it has the right shape and kink location), so I'm inclined to think that it's working as intended?...
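
As a rough illustration of why the width drops out, the following numpy sketch (not neural_tangents code, and ignoring the library's weight and bias variance conventions) compares a single ReLU layer with random Gaussian weights at several finite widths against its closed-form infinite-width limit, which is the degree-1 arc-cosine kernel up to a factor of 1/2. The function names are hypothetical. The closed form has no width parameter at all, which is the situation when calling kernel_fn: finite-width estimates only fluctuate around it.

```python
import numpy as np


def relu_kernel_infinite_width(x, y):
    """Closed form of E_w[relu(w.x) * relu(w.y)] for w ~ N(0, I)."""
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    cos_t = np.clip(x @ y / (nx * ny), -1.0, 1.0)
    theta = np.arccos(cos_t)
    return nx * ny * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)


def relu_kernel_finite_width(x, y, width, rng):
    """Monte Carlo estimate of the same expectation using `width` random units."""
    W = rng.normal(size=(width, x.shape[0]))
    return np.mean(np.maximum(W @ x, 0.0) * np.maximum(W @ y, 0.0))


rng = np.random.default_rng(0)
x = rng.normal(size=10)
y = rng.normal(size=10)

exact = relu_kernel_infinite_width(x, y)
estimates = {
    width: relu_kernel_finite_width(x, y, width, rng)
    for width in (50, 5_000, 500_000)
}

print("infinite-width value:", exact)
for width, value in estimates.items():
    print(f"width {width:>7}: {value}")
```

The finite-width numbers should approach the closed-form value as the width grows, while the closed form itself never changes.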

Re training set, I think it's constructed correctly, I'm just not sure how to reason about the generalization that we should expect from it (per your plot, it seems to be at least OKish?...).

And yes, 100K is too much for most GPUs.
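
A back-of-the-envelope estimate shows why. The sketch below assumes float32 entries and counts only one dense n x n kernel matrix; the function name is hypothetical, and the real footprint is likely higher (both NNGP and NTK kernels were requested, and the linear solve needs workspace). Batching limits the memory used while computing each block, but the assembled matrix still has n squared entries.

```python
def kernel_matrix_gib(n_points, bytes_per_entry=4):
    """Memory of one dense n_points x n_points matrix, in GiB (float32 by default)."""
    return n_points ** 2 * bytes_per_entry / 2 ** 30


# 600 and 6_000 are the training sizes after the 60/40 split of 1_000 and 10_000 points.
for n_points in (600, 6_000, 60_000, 100_000):
    print(f"{n_points:>7} points: {kernel_matrix_gib(n_points):8.3f} GiB")
```

Under these assumptions, a single 100 000 x 100 000 matrix already takes roughly 37 GiB, and even the 60 000 training points left after the split need over 13 GiB.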

@mfouesneau
Author

You're right; it seems to be doing ok, but with serious overfit.

Is there a paper to read to get a feeling for an appropriate network architecture? My understanding is that adding more layers will not change anything unless a "layer" is already something complex, right?

image
image
