Skip to content

Commit

Permalink
Try to avoid transposes in jnp.einsum by considering both argument or…
Browse files Browse the repository at this point in the history
…ders to dot_general.
  • Loading branch information
hawkinsp committed Mar 4, 2021
1 parent 6c102d9 commit 9832df8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
18 changes: 14 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3866,11 +3866,21 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
batch_names_str = ''.join(batch_names)
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
deleted_names = batch_names_str + ''.join(contracted_names)
names = (batch_names_str + _removechars(lhs_names, deleted_names)
+ _removechars(rhs_names, deleted_names))
remaining_lhs_names = _removechars(lhs_names, deleted_names)
remaining_rhs_names = _removechars(rhs_names, deleted_names)
# Try both orders of lhs and rhs, in the hope that one of them means we
# don't need an explicit transpose. opt_einsum likes to contract from
# right to left, so we expect (rhs,lhs) to have the best chance of not
# needing a transpose.
names = batch_names_str + remaining_rhs_names + remaining_lhs_names
if names == result_names:
dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
else:
raise NotImplementedError # if this is actually reachable, open an issue!

Expand Down
9 changes: 9 additions & 0 deletions tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@


from collections import defaultdict
from functools import partial
import itertools

import numpy as np
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
import jax.numpy as jnp
import jax.test_util as jtu
Expand Down Expand Up @@ -339,6 +341,13 @@ def test_broadcasting_issue_2189(self):
s = '...ij,...j'
self._check(s, x, y)

def test_no_unnecessary_transpose(self):
r = self.rng()
x = r.randn(2, 2, 2)
y = r.randn(2, 2)
jaxpr = jax.make_jaxpr(partial(jnp.einsum, "ijk,kl->ijl"))(x, y)
self.assertNotIn('transpose', str(jaxpr))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 9832df8

Please sign in to comment.