eagerpy not working together with Neural Tangents #19

Open
ybj14 opened this issue Sep 20, 2020 · 3 comments

ybj14 commented Sep 20, 2020

I'm trying to differentiate through predict_fn provided by https://github.com/google/neural-tangents. This is doable with jax.grad, but not with eagerpy.value_and_grad.
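
For reference, this is the usage pattern I would expect to be equivalent (a minimal sketch where f and x are placeholders for the real loss around predict_fn and the real inputs):

import jax
import jax.numpy as jnp
import eagerpy as ep

def f(x):
  # plain pure function standing in for the loss built around predict_fn
  return 0.5 * (x ** 2).mean()

x = jnp.ones((3,))
g_jax = jax.grad(f)(x)  # plain JAX: works

x_ep = ep.astensor(x)  # wrap the jax array as an EagerPy tensor
loss, g_ep = ep.value_and_grad(f, x_ep)  # also works for a pure function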

ybj14 commented Sep 20, 2020

Code snippet:

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example doing inference with an infinitely wide fully-connected network.

By default, this example does inference on a small MNIST subset.
"""

import time
from absl import app
from absl import flags
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from examples import util


flags.DEFINE_integer('train_size', 10000,
                     'Dataset size to use for training.')
flags.DEFINE_integer('test_size', 1000,
                     'Dataset size to use for testing.')
flags.DEFINE_integer('batch_size', 0,
                     'Batch size for kernel computation. 0 for no batching.')


FLAGS = flags.FLAGS


def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
    datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  _, _, kernel_fn = stax.serial(
      stax.Dense(1, 2., 0.05),
      stax.Relu(),
      stax.Dense(1, 2., 0.05)
  )

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                        y_train, diag_reg=1e-3)
  fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)


  # Differentiating through predict_fn with plain JAX works.
  from jax import grad

  def MSELoss(x_test):
    loss = 0.5 * ((predict_fn(x_test=x_test, get='ntk') - y_test) ** 2).mean()
    return loss

  def Norm(x_test):
    return (x_test ** 2).mean()

  print(grad(MSELoss)(x_test).shape)
  print(x_test.shape)

  # The same thing through eagerpy breaks for MSELoss but works for Norm.
  import eagerpy as ep
  print(type(x_test))
  x_test = np.array(x_test)  # convert to a jax array
  print(type(x_test))
  x_test = ep.astensor(x_test)  # wrap as an eagerpy tensor
  print(type(x_test))
  loss, g = ep.value_and_grad(MSELoss, x_test)  # Error!
  loss, g = ep.value_and_grad(Norm, x_test)  # works: pure function
  print(g.shape)



if __name__ == '__main__':
  app.run(main)

ybj14 commented Sep 20, 2020

As can be seen from the code above, eagerpy works well with JAX for pure functions, but breaks as soon as predict_fn is involved.
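
A sketch of a possible workaround (not verified; it assumes the problem is that predict_fn cannot handle the wrapped EagerPy tensor it receives, and predict_fn, y_test and x_test refer to the snippet above): unwrap the input with .raw before calling predict_fn and re-wrap the scalar loss with ep.astensor.

import eagerpy as ep

def MSELoss_unwrapped(x_test):
  # x_test arrives as an EagerPy tensor; hand predict_fn the underlying jax array
  fx = predict_fn(x_test=x_test.raw, get='ntk')
  # re-wrap the scalar loss so eagerpy gets back a tensor it can unwrap
  return ep.astensor(0.5 * ((fx - y_test) ** 2).mean())

loss, g = ep.value_and_grad(MSELoss_unwrapped, x_test)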

jonasrauber (Owner) commented

Thanks for reporting this. Could you add syntax highlighting to your code and share the exact error message? Could you also try value_and_grad_fn instead of value_and_grad? I think we should be able to fix this once we know what the exact error message is.
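
For reference, roughly what I have in mind (a sketch, assuming x_test is already wrapped with ep.astensor as in your snippet and that value_and_grad_fn takes the tensor first and the function second):

import eagerpy as ep

# build the value-and-gradient function once; the tensor argument only tells
# eagerpy which framework to dispatch to
value_and_grad = ep.value_and_grad_fn(x_test, MSELoss)
loss, g = value_and_grad(x_test)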
