
Ensure float32 inputs imply float32 outputs when jax_enable_x64=1 #36

Closed
trevorcai opened this issue Apr 21, 2020 · 0 comments
Labels
bug Something isn't working


@trevorcai
Contributor

trevorcai commented Apr 21, 2020

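```python
import os
os.environ["JAX_ENABLE_X64"] = "1"  # must be set before JAX is imported

import jax
import haiku as hk
import numpy as np

@hk.transform
def f(x):
  return hk.Linear(4)(x)

f32_data = np.zeros((4, 8), dtype=np.float32)

p = f.init(jax.random.PRNGKey(428), f32_data)
# Inspect the dtype of every parameter created by init.
print(jax.tree_util.tree_map(lambda t: t.dtype, p))
# Work around the bug by casting all parameters back to float32.
f32_params = jax.tree_util.tree_map(lambda t: t.astype(np.float32), p)
# Recent Haiku versions take an rng argument in apply; None is fine here
# because hk.Linear draws no randomness at apply time.
print(f.apply(f32_params, None, f32_data).dtype)
```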

Prints:

```
frozendict({
  'linear': frozendict({'b': dtype('float32'), 'w': dtype('float64')}),
})
dtype('float32')
```
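The float64 `w` in the output above is consistent with JAX's defaults in x64 mode. The minimal sketch below uses plain `jax.random.normal` rather than Haiku's initializers (an assumption about where the float64 comes from, not a trace through Haiku's code). A sampler called without an explicit `dtype` returns float64, and multiplying float32 inputs by float64 weights promotes the result to float64. That promotion is why the repro has to cast the parameters before calling `apply`.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode, equivalent to JAX_ENABLE_X64=1 in the repro.
jax.config.update("jax_enable_x64", True)

key = jax.random.PRNGKey(428)
x32 = jnp.zeros((4, 8), dtype=jnp.float32)

# With x64 enabled, samplers default to float64 when no dtype is given.
w_default = jax.random.normal(key, (8, 4))
# Passing the input's dtype explicitly keeps the parameter in float32.
w_matched = jax.random.normal(key, (8, 4), dtype=x32.dtype)

y_promoted = x32 @ w_default  # float32 @ float64 promotes to float64
y_matched = x32 @ w_matched   # float32 @ float32 stays float32

print(w_default.dtype, y_promoted.dtype)  # float64 float64
print(w_matched.dtype, y_matched.dtype)   # float32 float32
```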

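With x64 enabled, `w` is created as float64 even though the input is float32 (only `b` matches the input dtype), so the parameters have to be cast back to float32 by hand before `apply` returns float32. Hopefully the bfloat16 compatibility work means we get most of this for free and only need to port the initializers.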
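Porting an initializer mostly means threading the requested dtype through to the sampler. The following is a sketch under stated assumptions. `truncated_normal_init`, `init_linear` and `apply_linear` are hypothetical stand-ins, not Haiku's API. The `1/sqrt(fan_in)` standard deviation roughly mirrors the default `hk.Linear` weight initializer. Both parameters take their dtype from the input, so float32 data yields float32 parameters and outputs even with x64 enabled.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

def truncated_normal_init(key, shape, dtype, stddev=1.0):
    """Hypothetical initializer that samples directly in the requested dtype.

    Draws from a standard normal truncated to [-2, 2], then scales by stddev
    (cast to the same dtype) so no float64 value leaks into a float32 param.
    """
    samples = jax.random.truncated_normal(key, -2.0, 2.0, shape, dtype)
    return jnp.asarray(stddev, dtype) * samples

def init_linear(key, x, output_size):
    """Hypothetical Linear init: parameter dtypes follow the input dtype."""
    input_size = x.shape[-1]
    stddev = 1.0 / np.sqrt(input_size)
    w = truncated_normal_init(key, (input_size, output_size), x.dtype, stddev)
    b = jnp.zeros((output_size,), dtype=x.dtype)
    return {"w": w, "b": b}

def apply_linear(params, x):
    """Affine map; stays in x.dtype because the params share that dtype."""
    return x @ params["w"] + params["b"]

x32 = np.zeros((4, 8), dtype=np.float32)
params = init_linear(jax.random.PRNGKey(428), x32, 4)
print({name: v.dtype for name, v in params.items()})
print(apply_linear(params, x32).dtype)
```

Deriving the dtype from the input, rather than from a global default, is what makes "float32 in, float32 out" hold regardless of the x64 flag.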

@trevorcai trevorcai self-assigned this Apr 21, 2020
@trevorcai trevorcai added the bug Something isn't working label Apr 21, 2020