Skip to content

Commit

Permalink
Merge pull request #397 from google/fix-dot-batch-rule
Browse files Browse the repository at this point in the history
fix broken dot batch rule case
  • Loading branch information
mattjj committed Feb 17, 2019
2 parents 0bcf3a3 + 793e055 commit 8756721
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
8 changes: 3 additions & 5 deletions jax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,11 +1326,9 @@ def _dot_batch_rule(batched_args, batch_dims):

assert lbd is not None and rbd is not None
assert lhs.ndim == rhs.ndim == 2 # dot only supports rank 1 and above
if lbd != 0:
batching.move_dim_to_front(lhs, lbd)
if rbd != 0:
batching.move_dim_to_front(rhs, rbd)
return dot_general(lhs, rhs, [((1,), (1,)), ((0,), (0,))])
lhs = batching.move_dim_to_front(lhs, lbd)
rhs = batching.move_dim_to_front(rhs, rbd)
return dot_general(lhs, rhs, [((1,), (1,)), ((0,), (0,))]), 0

if lbd is None:
assert rbd is not None
Expand Down
11 changes: 9 additions & 2 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,15 @@ def vecvec(a, b):

assert vecvec(np.zeros((3,)), np.zeros((3,))).shape == ()
assert vecvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,)
# TODO(mattjj): this fails due to an xla error in dot_general
# assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2)
assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2)

def testDot2(self):
R = onp.random.RandomState(0).randn
xs = R(10, 3)
ys = R(10, 3)
ans = vmap(np.dot)(xs, ys)
expected = onp.einsum('ni,ni->n', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)

def testPad(self):
R = onp.random.RandomState(0).randn
Expand Down

0 comments on commit 8756721

Please sign in to comment.