
Commit cd1a3d3

Merge d9f36d0 into 56b23b4
trax-robot committed Jul 1, 2020
2 parents 56b23b4 + d9f36d0 commit cd1a3d3
Showing 2 changed files with 41 additions and 47 deletions.
70 changes: 33 additions & 37 deletions trax/layers/base_test.py
@@ -24,8 +24,7 @@
 from trax import fastmath
 from trax import shapes
 from trax.fastmath import numpy as jnp
-from trax.layers import base
-
+import trax.layers as tl
 
 BACKENDS = ['jax', 'tf']
 CUSTOM_GRAD_BACKENDS = ['jax']  # TODO(afrozm): Delete after TF 2.3
@@ -34,35 +33,33 @@
 class BaseLayerTest(parameterized.TestCase):
 
   def test_call_raises_error(self):
-    layer = base.Layer()
-    x = np.array([[1, 2, 3, 4, 5],
-                  [10, 20, 30, 40, 50]])
-    with self.assertRaisesRegex(base.LayerError, 'NotImplementedError'):
+    layer = tl.Layer()
+    x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
+    with self.assertRaisesRegex(tl.LayerError, 'NotImplementedError'):
       _ = layer(x)
 
   def test_forward_raises_error(self):
-    layer = base.Layer()
-    x = np.array([[1, 2, 3, 4, 5],
-                  [10, 20, 30, 40, 50]])
+    layer = tl.Layer()
+    x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
     with self.assertRaises(NotImplementedError):
       _ = layer.forward(x)
 
   def test_init_returns_empty_weights_and_state(self):
-    layer = base.Layer()
+    layer = tl.Layer()
     input_signature = shapes.ShapeDtype((2, 5))
     weights, state = layer.init(input_signature)
     self.assertEmpty(weights)
     self.assertEmpty(state)
 
   def test_output_signature(self):
-    input_signature = (shapes.ShapeDtype((2, 3, 5)),
-                       shapes.ShapeDtype((2, 3, 5)))
-    layer = base.Fn('2in1out', lambda x, y: x + y)
+    input_signature = (shapes.ShapeDtype((2, 3, 5)), shapes.ShapeDtype(
+        (2, 3, 5)))
+    layer = tl.Fn('2in1out', lambda x, y: x + y)
     output_signature = layer.output_signature(input_signature)
     self.assertEqual(output_signature, shapes.ShapeDtype((2, 3, 5)))
 
     input_signature = shapes.ShapeDtype((5, 7))
-    layer = base.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
+    layer = tl.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
     output_signature = layer.output_signature(input_signature)
     self.assertEqual(output_signature, (shapes.ShapeDtype((5, 7)),) * 3)
     self.assertNotEqual(output_signature, (shapes.ShapeDtype((4, 7)),) * 3)
@@ -71,7 +68,7 @@ def test_output_signature(self):
   @parameterized.named_parameters([('_' + b, b) for b in CUSTOM_GRAD_BACKENDS])
   def test_custom_zero_grad(self, backend_name):
 
-    class IdWithZeroGrad(base.Layer):
+    class IdWithZeroGrad(tl.Layer):
 
       def forward(self, x):
         return x
@@ -87,8 +84,8 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng):
     layer = IdWithZeroGrad()
     rng = fastmath.random.get_prng(0)
     input_signature = shapes.ShapeDtype((9, 17))
-    random_input = fastmath.random.uniform(rng, input_signature.shape,
-                                           minval=-1.0, maxval=1.0)
+    random_input = fastmath.random.uniform(
+        rng, input_signature.shape, minval=-1.0, maxval=1.0)
     layer.init(input_signature)
     f = lambda x: jnp.mean(layer(x))
     grad = fastmath.grad(f)(random_input)
@@ -98,7 +95,7 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng):
   @parameterized.named_parameters([('_' + b, b) for b in CUSTOM_GRAD_BACKENDS])
   def test_custom_id_grad(self, backend_name):
 
-    class IdWithIdGrad(base.Layer):
+    class IdWithIdGrad(tl.Layer):
 
       def forward(self, x):
         return x
@@ -114,8 +111,8 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng):
     layer = IdWithIdGrad()
     rng = fastmath.random.get_prng(0)
     input_signature = shapes.ShapeDtype((9, 17))
-    random_input = fastmath.random.uniform(rng, input_signature.shape,
-                                           minval=-1.0, maxval=1.0)
+    random_input = fastmath.random.uniform(
+        rng, input_signature.shape, minval=-1.0, maxval=1.0)
     layer.init(input_signature)
     f = lambda x: jnp.mean(layer(x))
     grad = fastmath.grad(f)(random_input)
@@ -124,7 +121,7 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng):
 
   def test_weights_and_state_signature(self):
 
-    class MyLayer(base.Layer):
+    class MyLayer(tl.Layer):
 
       def init_weights_and_state(self, input_signature):
         self.weights = jnp.zeros((2, 3))
@@ -139,18 +136,18 @@ def forward(self, inputs):
     self.assertEqual(s.shape, (3, 4))
 
   def test_custom_name(self):
-    layer = base.Layer()
+    layer = tl.Layer()
     self.assertIn('Layer', str(layer))
     self.assertNotIn('CustomLayer', str(layer))
 
-    layer = base.Layer(name='CustomLayer')
+    layer = tl.Layer(name='CustomLayer')
     self.assertIn('CustomLayer', str(layer))
 
 
 class PureLayerTest(absltest.TestCase):
 
   def test_forward(self):
-    layer = base.PureLayer(lambda x: 2 * x)
+    layer = tl.PureLayer(lambda x: 2 * x)
 
     # Use Layer.__call__.
     in_0 = np.array([1, 2])
@@ -164,29 +161,27 @@ def test_forward(self):
 
     # Use Layer.pure_fn
     in_2 = np.array([5, 6])
-    out_2, _ = layer.pure_fn(
-        in_2, base.EMPTY_WEIGHTS, base.EMPTY_WEIGHTS, None)
+    out_2, _ = layer.pure_fn(in_2, tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None)
     self.assertEqual(out_2.tolist(), [10, 12])
 
 
 class FnTest(absltest.TestCase):
 
   def test_bad_f_has_default_arg(self):
     with self.assertRaisesRegex(ValueError, 'default arg'):
-      _ = base.Fn('', lambda x, sth=None: x)
+      _ = tl.Fn('', lambda x, sth=None: x)
 
   def test_bad_f_has_keyword_arg(self):
     with self.assertRaisesRegex(ValueError, 'keyword arg'):
-      _ = base.Fn('', lambda x, **kwargs: x)
+      _ = tl.Fn('', lambda x, **kwargs: x)
 
   def test_bad_f_has_variable_arg(self):
     with self.assertRaisesRegex(ValueError, 'variable arg'):
-      _ = base.Fn('', lambda *args: args[0])
+      _ = tl.Fn('', lambda *args: args[0])
 
   def test_forward(self):
-    layer = base.Fn('SumAndMax',
-                    lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
-                    n_out=2)
+    layer = tl.Fn(
+        'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)
 
     x0 = np.array([1, 2, 3, 4, 5])
     x1 = np.array([10, 20, 30, 40, 50])
@@ -199,16 +194,17 @@ def test_forward(self):
     self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
     self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])
 
-    (y4, y5), state = layer.pure_fn(
-        (x0, x1), base.EMPTY_WEIGHTS, base.EMPTY_STATE, None)
+    (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE,
+                                    None)
     self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
     self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
-    self.assertEqual(state, base.EMPTY_STATE)
+    self.assertEqual(state, tl.EMPTY_STATE)
 
   def test_weights_state(self):
-    layer = base.Fn(
+    layer = tl.Fn(
         '2in2out',
-        lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), n_out=2)
+        lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)),
+        n_out=2)
     layer.init_weights_and_state(None)
     self.assertEmpty(layer.weights)
     self.assertEmpty(layer.state)
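Note: the tests above now reach every symbol (`Layer`, `Fn`, `PureLayer`, `EMPTY_WEIGHTS`, `EMPTY_STATE`) through the public flattened namespace `import trax.layers as tl` rather than the internal `trax.layers.base` module. For readers unfamiliar with that API, here is a minimal standalone sketch (not part of this commit) of the `tl.Fn` call pattern the `FnTest` cases exercise; it assumes a working trax installation and that a multi-input layer is called on a tuple of inputs, as in these tests.

```python
import numpy as np

from trax.fastmath import numpy as jnp
import trax.layers as tl

# Wrap a pure two-input, two-output function as a trax layer; n_out=2
# declares how many outputs the lambda returns (same call as in the diff).
layer = tl.Fn(
    'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, 20, 30, 40, 50])

# Fn layers carry no weights, so the layer can be called directly.
sums, maxima = layer((x0, x1))
print(sums.tolist())    # [11, 22, 33, 44, 55]
print(maxima.tolist())  # [10, 20, 30, 40, 50]
```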
18 changes: 8 additions & 10 deletions trax/layers/core_test.py
@@ -20,7 +20,7 @@
 import numpy as np
 
 from trax import shapes
-import trax.layers as tl  # Flattened view for API users (hides subpackaging).
+import trax.layers as tl
 
 
 class DenseTest(absltest.TestCase):
@@ -43,8 +43,7 @@ def test_call_uses_and_caches_supplied_weights(self):
     w_init, b_init = layer.weights
 
     # Call the layer with externally specified weights.
-    w = np.array([[10000, 20000, 30000, 40000],
-                  [100, 200, 100, 200]])
+    w = np.array([[10000, 20000, 30000, 40000], [100, 200, 100, 200]])
     b = np.array([9, 8, 7, 6])
     y = layer(x, weights=(w, b))
 
@@ -85,9 +84,7 @@ def test_call_no_bias(self):
     x = np.array([2, 5, 3])
     _, _ = layer.init(shapes.signature(x))
 
-    w = np.array([[100, 200, 300, 400],
-                  [10, 10, 10, 10],
-                  [1, 2, 1, 2]])
+    w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]])
     y = layer(x, weights=w)
     self.assertEqual(y.tolist(), [253, 456, 653, 856])
 
@@ -164,6 +161,7 @@ def test_new_weights(self):
     self.assertLess(np.abs(np.mean(w)), .4)  # .4 is 4 sigma deviation
 
   def test_explicit_kernel_initializer(self):
+
     def f(shape, rng):
       del rng
       n_elements = np.prod(shape)
@@ -188,8 +186,8 @@ def test_call_in_train_mode(self):
     n_remaining = np.count_nonzero(y)
     mu_of_remaining = 9000  # N * q: 10000 * .9
     sigma_of_remaining = 30  # sqrt(N * p * q): sqrt(10000 * .1 * .9)
-    self.assertLess(np.abs(n_remaining - mu_of_remaining),
-                    4 * sigma_of_remaining)
+    self.assertLess(
+        np.abs(n_remaining - mu_of_remaining), 4 * sigma_of_remaining)
 
   def test_call_in_eval_mode_does_no_dropout(self):
     layer = tl.Dropout(rate=0.1, mode='eval')
@@ -238,7 +236,7 @@ def test_log_gaussian_pdf(self):
     x = np.zeros((2, 5), dtype=np.float32)
     mu = x
     dsigma = np.eye(5)[None, :, :]
-    sigma = np.concatenate([dsigma, 2*dsigma], axis=0)
+    sigma = np.concatenate([dsigma, 2 * dsigma], axis=0)
     prob = tl.log_gaussian_pdf(x, mu, sigma)
     self.assertEqual(prob.shape, (2,))
     self.assertEqual(int(prob[0]), -4)
@@ -248,7 +246,7 @@ def test_log_gaussian_diag_pdf(self):
     x = np.zeros((2, 5), dtype=np.float32)
     mu = x
     sigma = np.ones((5,))[None, :]
-    sigma = np.concatenate([sigma, 2*sigma], axis=0)
+    sigma = np.concatenate([sigma, 2 * sigma], axis=0)
     prob = tl.log_gaussian_diag_pdf(x, mu, sigma)
     self.assertEqual(prob.shape, (2,))
     self.assertEqual(int(prob[0]), -4)
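As a sanity check on the `test_call_no_bias` expectations reformatted above: when the test passes only a kernel `w` (no bias), the `Dense` output reduces to a matrix product, so the asserted values can be reproduced with plain numpy. A small sketch, assuming the layer computes `y = x @ w` in the no-bias case (not part of this commit):

```python
import numpy as np

# Same inputs as test_call_no_bias in the diff above.
x = np.array([2, 5, 3])
w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]])

y = x @ w
print(y.tolist())  # [253, 456, 653, 856] -- the values the test asserts
```

Similarly, the bound in `test_call_in_train_mode` follows from binomial statistics, as its comments note: with N = 10000 draws and keep probability q = 0.9, the surviving count has mean N * q = 9000 and standard deviation sqrt(N * p * q) = sqrt(900) = 30, so the test allows a 4-sigma deviation of 120.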
