Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate initializers and activation functions to jax.nn #1262

Merged
merged 6 commits into from
Sep 4, 2019

Conversation

jekbradbury
Copy link
Contributor

@jekbradbury jekbradbury commented Aug 29, 2019

This PR does a few things:

  • Adds a simple truncated normal sampler, which I've manually verified does the right thing for inputs in the regime relevant to neural net initialization (I'm still a little afraid that the special cases I've seen in other truncated normal implementations are important, but I'm fairly confident that they're not needed for initializers).
  • Adds a jax.nn namespace to centralize/official-ize the location of shared neural net-related functions.
  • Puts implementations of standard initializers, activation functions, and a couple other things there. The goal is to have one place with reliable semantics and numerics; I'm not yet sure what kinds of tests to add to make sure that stays the case.
  • Updates the log_softmax implementation to fix a numerical issue seen internally.
  • Updates stax to use jax.nn, hopefully without breaking users.
  • Adds docs for jax.nn.

Closes #1194 and #1195, re-closes #985.

@jekbradbury jekbradbury marked this pull request as ready for review August 30, 2019 01:07
Copy link
Member

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! My main request is not to put code in __init__.py files and not to do from blah import *. A few other minor suggestions too.

jax/experimental/stax.py Outdated Show resolved Hide resolved
jax/nn/__init__.py Outdated Show resolved Hide resolved
jax/nn/initializers.py Outdated Show resolved Hide resolved
jax/nn/__init__.py Outdated Show resolved Hide resolved
jax/nn/initializers.py Outdated Show resolved Hide resolved
jax/nn/initializers.py Outdated Show resolved Hide resolved
jax/nn/__init__.py Outdated Show resolved Hide resolved
jax/nn/__init__.py Outdated Show resolved Hide resolved
jax/random.py Outdated Show resolved Hide resolved
jax/random.py Outdated Show resolved Hide resolved
Copy link
Member

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Great idea to make this change. It's a big improvement and emblematic of a good way to think about jax-core's tools-for-nn-library-builders going forward.

jax/nn/__init__.py Show resolved Hide resolved
@jekbradbury jekbradbury merged commit 146b5d1 into master Sep 4, 2019
@jekbradbury jekbradbury deleted the jb/initializers branch September 4, 2019 21:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Feature Request: More Initializers in stax logsoftmax has numerical issue
3 participants