[TRAX] Backends are enums instead of strings.
PiperOrigin-RevId: 323567051
afrozenator authored and Copybara-Service committed Jul 28, 2020
1 parent 2149bbf commit aafc0d8
Showing 24 changed files with 136 additions and 99 deletions.
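In short, backend selection now goes through a fastmath.Backend enum instead of raw strings such as 'jax' or 'tf'. A minimal usage sketch, not part of the commit, assuming the trax.fastmath exports shown in the diff below:

from trax import fastmath

# Select the backend with an enum member rather than a raw string.
with fastmath.use_backend(fastmath.Backend.TFNP):
  assert fastmath.is_backend(fastmath.Backend.TFNP)
  assert fastmath.backend_name() == 'tensorflow-numpy'

# Outside the context manager the gin-configurable default ('jax') applies.
assert fastmath.is_backend(fastmath.Backend.JAX)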
47 changes: 35 additions & 12 deletions trax/fastmath/ops.py
@@ -30,12 +30,21 @@
"""

import contextlib
import enum

import gin
from trax.fastmath.jax import JAX_BACKEND
from trax.fastmath.numpy import NUMPY_BACKEND
from trax.fastmath.tf import TF_BACKEND


@enum.unique
class Backend(enum.Enum):
JAX = 'jax'
TFNP = 'tensorflow-numpy'
NUMPY = 'numpy'


# For numpy and random modules, we need to call "backend()" lazily, only when
# the function is called -- so that it can be set by gin configs.
# (Otherwise, backend() is called on import before gin-config is parsed.)
@@ -226,34 +235,48 @@ def device_count(*args, **kwargs):

# Backend selection functions.


override_backend_name = None
override_backend = None
_backend_dict = {
Backend.JAX: JAX_BACKEND,
Backend.NUMPY: NUMPY_BACKEND,
Backend.TFNP: TF_BACKEND,
}


@gin.configurable()
def backend(name='jax'):
"""Return the backend used to provide fastmath ops ('tf' or 'jax')."""
name = name if not override_backend_name else override_backend_name
if name == 'numpy':
return NUMPY_BACKEND
elif name == 'tf':
return TF_BACKEND
"""Returns the backend used to provide fastmath ops ('tf' or 'jax')."""
if override_backend:
return _backend_dict[override_backend]

name = override_backend or name
if isinstance(name, Backend):
return _backend_dict[name]

# name is a string.
for backend_ in Backend:
if backend_.value == name:
return _backend_dict[backend_]
return JAX_BACKEND


@contextlib.contextmanager
def use_backend(name):
"""Call fastmath functions with a specified backend."""
global override_backend_name
prev_name = override_backend_name
override_backend_name = name
global override_backend
prev_name_or_backend = override_backend
override_backend = name
# Run the decorated function in try-finally in case it throws, e.g. for tests.
try:
yield
finally:
override_backend_name = prev_name
override_backend = prev_name_or_backend


def backend_name():
"""Returns the name of the backend curently in use ('tf' or 'jax')."""
return backend()['name']


def is_backend(backend_):
return backend()['name'] == backend_.value
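For reference, backend() above still accepts plain string names and falls back to the JAX backend when a name is not recognized. A hypothetical snippet, not part of the commit, illustrating that resolution logic:

from trax import fastmath
from trax.fastmath import ops

# A string name resolves to the same backend dict as its enum counterpart.
assert ops.backend('tensorflow-numpy') is ops.backend(fastmath.Backend.TFNP)

# Unrecognized names fall through the loop and default to the JAX backend.
assert ops.backend('no-such-backend') is ops._backend_dict[fastmath.Backend.JAX]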
9 changes: 9 additions & 0 deletions trax/fastmath/ops_test.py
@@ -86,6 +86,15 @@ def test_nested_stack(self):
out = {'a': ([[0, 1], [1, 2]], [2, 3]), 'b': _TestNamedtuple([3, 4])}
onp.testing.assert_equal(fastmath.nested_stack(inp), out)

def test_names_match(self):
# Names match up.
for backend_enum, backend_obj in fastmath.ops._backend_dict.items():
self.assertEqual(backend_enum.value, backend_obj['name'])

# Every backend appears in the dictionary.
for backend_enum in fastmath.ops.Backend:
self.assertIn(backend_enum, fastmath.ops._backend_dict)


if __name__ == '__main__':
test.main()
2 changes: 1 addition & 1 deletion trax/fastmath/tf.py
@@ -108,7 +108,7 @@ def _tf_pmap(*args, **kwargs):


TF_BACKEND = {
'name': 'tf',
'name': 'tensorflow-numpy',
'np': tf_np,
'jit': _tf_jit,
'stop_gradient': tf_np_extensions.stop_gradient,
2 changes: 1 addition & 1 deletion trax/layers/acceleration.py
@@ -190,7 +190,7 @@ def f(x):
def for_n_devices(x, n_devices):
"""Replicates/broadcasts `x` for `n_devices`."""
def f(x):
if n_devices > 1 and fastmath.backend_name() == 'jax':
if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
return _multi_device_put(x)
elif n_devices > 1:
return jnp.broadcast_to(x, (n_devices,) + x.shape)
10 changes: 5 additions & 5 deletions trax/layers/attention.py
@@ -181,7 +181,7 @@ def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
# We must ensure that both mask and the -1e9 constant have a data dependency
# on the input. Broadcasted copies of these use a lot of memory, so they
# should be computed at runtime (rather than being global constants).
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
mask = jax.lax.tie_in(dots, mask)
# JAX's `full_like` already ties in -1e9 to dots.
dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
@@ -272,7 +272,7 @@ def forward(self, inputs):
# Not all backends define jnp.tril. However, using np.tril is inefficient
# in that it creates a large global constant. TODO(kitaev): try to find an
# alternative that works across all backends.
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
mask = jnp.tril(
jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
else:
@@ -360,7 +360,7 @@ def forward(self, inputs):
for dim in self._dropout_broadcast_dims:
noise_shape[dim] = 1
keep_prob = 1.0 - self._dropout
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
keep = fastmath.random.bernoulli(self.rng, keep_prob,
tuple(noise_shape))
@@ -437,9 +437,9 @@ def _fast_inference_update_state(inputs, state):
Returns:
Updated state.
"""
if fastmath.backend_name() != 'jax':
if not fastmath.is_backend(fastmath.Backend.JAX):
raise ValueError(f'JAX backend is required in predict mode, but found '
f'backend ({fastmath.backend_nameO()}).')
f"backend ({fastmath.backend()['name']}).")

# Fast inference: run step-by-step, storing the sequence
# of keys and values calculated so far in state.
18 changes: 10 additions & 8 deletions trax/layers/base_test.py
@@ -26,8 +26,8 @@
from trax.fastmath import numpy as jnp
import trax.layers as tl

BACKENDS = ['jax', 'tf']
CUSTOM_GRAD_BACKENDS = ['jax'] # TODO(afrozm): Delete after TF 2.3
BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP]
CUSTOM_GRAD_BACKENDS = [fastmath.Backend.JAX] # TODO(afrozm): del after TF 2.3


class BaseLayerTest(parameterized.TestCase):
@@ -65,8 +65,9 @@ def test_output_signature(self):
self.assertNotEqual(output_signature, (shapes.ShapeDtype((4, 7)),) * 3)
self.assertNotEqual(output_signature, (shapes.ShapeDtype((5, 7)),) * 2)

@parameterized.named_parameters([('_' + b, b) for b in CUSTOM_GRAD_BACKENDS])
def test_custom_zero_grad(self, backend_name):
@parameterized.named_parameters(
[('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS])
def test_custom_zero_grad(self, backend):

class IdWithZeroGrad(tl.Layer):

@@ -80,7 +81,7 @@ def has_backward(self):
def backward(self, inputs, output, grad, weights, state, new_state, rng):
return (jnp.zeros_like(grad), ())

with fastmath.use_backend(backend_name):
with fastmath.use_backend(backend):
layer = IdWithZeroGrad()
rng = fastmath.random.get_prng(0)
input_signature = shapes.ShapeDtype((9, 17))
@@ -92,8 +93,9 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng):
self.assertEqual(grad.shape, (9, 17)) # Gradient for each input.
self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0.

@parameterized.named_parameters([('_' + b, b) for b in CUSTOM_GRAD_BACKENDS])
def test_custom_id_grad(self, backend_name):
@parameterized.named_parameters(
[('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS])
def test_custom_id_grad(self, backend):

class IdWithIdGrad(tl.Layer):

@@ -107,7 +109,7 @@ def has_backward(self):
def backward(self, inputs, output, grad, weights, state, new_state, rng):
return (inputs, ())

with fastmath.use_backend(backend_name):
with fastmath.use_backend(backend):
layer = IdWithIdGrad()
rng = fastmath.random.get_prng(0)
input_signature = shapes.ShapeDtype((9, 17))
19 changes: 10 additions & 9 deletions trax/layers/combinators_test.py
@@ -469,7 +469,8 @@ def some_layer():
self.assertEqual(output_shapes, [(3,), (5,), (2,)])


@parameterized.named_parameters(('_' + b, b) for b in ('jax', 'tf'))
@parameterized.named_parameters(
('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP))
class ScanTest(parameterized.TestCase):

def _AddWithCarry(self): # pylint: disable=invalid-name
@@ -479,8 +480,8 @@ def f(x, carry):
return res, res # output and carry are the same
return tl.Fn('AddWithCarry', f, n_out=2)

def test_default_axis(self, backend_name):
with fastmath.use_backend(backend_name):
def test_default_axis(self, backend):
with fastmath.use_backend(backend):
layer = tl.Scan(self._AddWithCarry())
xs = [
np.array([[0, 1, 2, 3],
@@ -497,8 +498,8 @@ def test_default_axis(self, backend_name):
[9000, 8111, 7222, 6333]
])

def test_axis_1(self, backend_name):
with fastmath.use_backend(backend_name):
def test_axis_1(self, backend):
with fastmath.use_backend(backend):
layer = tl.Scan(self._AddWithCarry(), axis=1)
xs = [
np.array([[0, 1, 2, 3],
@@ -519,13 +520,13 @@ def test_axis_1(self, backend_name):
7600]
])

def test_multi_input(self, backend_name):
def test_multi_input(self, backend):
def _MultiInputFn(): # pylint: disable=invalid-name
def f(a, b, carry):
return a + b, b, carry + 1
return tl.Fn('MultiInputFn', f, n_out=2)

with fastmath.use_backend(backend_name):
with fastmath.use_backend(backend):
layer = tl.Scan(_MultiInputFn(), axis=1)
xs = [
np.array([[0, 1, 2],
@@ -545,11 +546,11 @@ def f(a, b, carry):
8003]
])

def test_no_carry(self, backend_name):
def test_no_carry(self, backend):
def _AddOne(): # pylint: disable=invalid-name
return tl.Fn('AddOne', lambda x: x + 1)

with fastmath.use_backend(backend_name):
with fastmath.use_backend(backend):
layer = tl.Scan(_AddOne(), n_carry=0)
x = np.array([[1, 3, 7],
[10, 30, 70]])
4 changes: 2 additions & 2 deletions trax/layers/core.py
@@ -210,12 +210,12 @@ def forward(self, x):
mask_shape = list(x.shape)
for axis in self._shared_axes:
mask_shape[axis] = 1
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
keep_prob = jax.lax.tie_in(self.rng, 1.0 - rate)
else:
keep_prob = 1.0 - rate
keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
keep_prob = jax.lax.tie_in(keep, keep_prob)
mask = keep.astype(x.dtype) / keep_prob
return x * mask
2 changes: 1 addition & 1 deletion trax/layers/metrics.py
@@ -166,7 +166,7 @@ def f(values, weights):  # pylint: disable=invalid-name
def one_hot(x, n_categories, dtype=jnp.float32): # pylint: disable=invalid-name
"""Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
indices_less_than_n = jnp.arange(n_categories)
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
# Work around a jax broadcasting issue.
indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
12 changes: 6 additions & 6 deletions trax/layers/normalization_test.py
@@ -35,9 +35,9 @@ def test_forward_shape(self):
self.assertEqual(y.shape, x.shape)

@parameterized.named_parameters(
('jax32', 'jax', np.float32),
('tf32', 'tf', np.float32),
('tf64', 'tf', np.float64),
('jax32', fastmath.Backend.JAX, np.float32),
('tf32', fastmath.Backend.TFNP, np.float32),
('tf64', fastmath.Backend.TFNP, np.float64),
)
def test_forward_dtype(self, backend, dtype):
with fastmath.use_backend(backend):
@@ -98,9 +98,9 @@ def test_forward_shape(self):
self.assertEqual(y.shape, x.shape)

@parameterized.named_parameters(
('jax32', 'jax', np.float32),
('tf32', 'tf', np.float32),
('tf64', 'tf', np.float64),
('jax32', fastmath.Backend.JAX, np.float32),
('tf32', fastmath.Backend.TFNP, np.float32),
('tf64', fastmath.Backend.TFNP, np.float64),
)
def test_forward_dtype(self, backend, dtype):
with fastmath.use_backend(backend):
4 changes: 2 additions & 2 deletions trax/layers/research/efficient_attention.py
@@ -46,7 +46,7 @@


def tie_in(x, y):
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
return jax.lax.tie_in(x, y)
return y

@@ -1158,7 +1158,7 @@ def hash_vectors(self, vecs, rng, mask=None):
rng = fastmath.stop_gradient(tie_in(vecs, rng))
random_rotations = fastmath.random.normal(rng, rotations_shape).astype(
np.float32)
if fastmath.backend_name() == 'jax':
if fastmath.is_backend(fastmath.Backend.JAX):
rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)
else:
random_rotations = np.reshape(random_rotations,
