Skip to content

Commit

Permalink
Merge 4f73dc6 into 695859f
Browse files Browse the repository at this point in the history
  • Loading branch information
trax-robot committed May 14, 2020
2 parents 695859f + 4f73dc6 commit 593b980
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions trax/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def __init__(self,
self._use_bias = use_bias

def forward(self, x, weights):
if len(weights) != 2:
raise ValueError(f'Weights has length {len(weights)}; should instead '
f'have two elements: w, b.')
if self._use_bias:
if not isinstance(weights, (tuple, list)):
raise ValueError(f'Weights should be a (w, b) tuple or list; '
f'instead got: {weights}')
w, b = weights
return jnp.matmul(x, w) + b # Affine map.
else:
Expand Down
5 changes: 3 additions & 2 deletions trax/layers/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,14 @@ def test_shared_instance_means_shared_weights(self):

def test_call_no_bias(self):
layer = tl.Dense(4, use_bias=False)
x = np.array([2, 3])
x = np.array([2, 5, 3])
_, _ = layer.init(shapes.signature(x))

w = np.array([[100, 200, 300, 400],
[10, 10, 10, 10],
[1, 2, 1, 2]])
y = layer(x, weights=w)
self.assertEqual(y.tolist(), [203, 406, 603, 806])
self.assertEqual(y.tolist(), [253, 456, 653, 856])

def test_new_weights_use_bias(self):
layer = tl.Dense(4)
Expand Down

0 comments on commit 593b980

Please sign in to comment.