Add unary xeinsum and allow named axis reductions for unary and binary xeinsums
bloops authored and apaszke committed Apr 26, 2022
1 parent 098f212 commit a147046
Showing 3 changed files with 178 additions and 45 deletions.
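The key behavioral change: a subscript or named axis that appears in only one input and not in the output is now summed out before any contraction, and single-operand specs are accepted. A minimal sketch of the positional-subscript case, adapted from the tests added in this commit (tolerances are loose to allow float32 execution):

import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(2, 4, 2)

# 'b' and 'c' each appear in only one input and not in the output, so each
# operand is summed over them before 'j' is contracted.
out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
np.testing.assert_allclose(out, np.einsum('bij,cjk->ik', x, y),
                           rtol=1e-4, atol=1e-4)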
110 changes: 67 additions & 43 deletions jax/_src/lax/parallel.py
@@ -36,7 +36,7 @@
from jax._src.lax import slicing
from jax._src.numpy import lax_numpy
import jax._src.util as util
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo

@@ -419,50 +419,74 @@ def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
precision=lax.canonicalize_precision(precision))


def xeinsum(spec: str, x, y):
def xeinsum(spec: str, *operands):
in_spec, out_spec = spec.split('->')
(lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args()
all_in_subs, all_in_named = unzip2(XeinsumSpecParser(in_spec).parse_args())
(out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args()
all_named = {*lhs_named, *rhs_named, *out_named}
all_subs = {*lhs_subs, *rhs_subs, *out_subs}
lhs_uniques = set(lhs_subs) - set(rhs_subs)
rhs_uniques = set(rhs_subs) - set(lhs_subs)
if all_subs & all_named:
raise NotImplementedError
if not set(out_named).issubset({*lhs_named, *rhs_named}):
raise ValueError

# if a named axis appears in both inputs and not the output, contract!
named_contract = list(all_named - set(out_named))

# if a subscript appears in both inputs and not the outputs, contract!
subs_contract = all_subs - set(out_subs)

lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract]
if lhs_reduce_axes:
x = lax._reduce_sum(x, lhs_reduce_axes)
for i in sorted(lhs_reduce_axes, reverse=True):
del lhs_subs[i]

rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract]
if rhs_reduce_axes:
y = lax._reduce_sum(y, rhs_reduce_axes)
for i in sorted(rhs_reduce_axes, reverse=True):
del rhs_subs[i]

pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
for n in subs_contract - (lhs_uniques | rhs_uniques))

# if a subscript appears in both inputs _and_ the outputs, batch!
subs_batch = all_subs - subs_contract
if subs_batch & (lhs_uniques | rhs_uniques):
raise NotImplementedError

pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n))
for n in subs_batch)

return pdot(x, y, axis_name=named_contract,
pos_contract=pos_contract, pos_batch=pos_batch)

if len(operands) != len(all_in_named):
raise ValueError("Expecting the same number of argument specs in the "
"subscript ({in_spec}) as the number of operands. But got "
"{len(all_in_named)} argument specs for "
"{len(operands)} operands")

if len(operands) > 2:
raise NotImplementedError("Only one or two operands are supported, "
f"but got {len(operands)} operands")

# output subs and named axes must appear in at least one of the inputs.
if not set(out_named).issubset(set().union(*all_in_named)):
raise ValueError("Found named axes "
f"{set(out_named) - set().union(*all_in_named)} "
"appearing in the output spec but not in the input")
if not set(out_subs).issubset(set().union(*all_in_subs)):
raise ValueError("Found subscript(s) "
f"{set(out_subs) - set().union(*all_in_subs)} "
"appearing in the output spec but not in the input")

xs = list(operands)
for idx, (in_subs, in_named) in enumerate(safe_zip(all_in_subs, all_in_named)):
# if a subscript axis appears only in one input and not the output, reduce!
other_named = set().union( # type: ignore
*[named for i, named in enumerate(all_in_named) if i != idx])
other_subs = set().union( # type: ignore
*[subs for i, subs in enumerate(all_in_subs) if i != idx])

subs_reduce = list(set(in_subs) - {*out_subs, *other_subs})
subs_reduce_axes = [in_subs.index(n) for n in subs_reduce]
named_reduce_axes = list(set(in_named) - {*out_named, *other_named})

if subs_reduce_axes or named_reduce_axes:
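# note: psum accepts positional axes given as ints alongside named axes,
# so this single call performs both kinds of reduction.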
xs[idx] = psum(xs[idx], axis_name=subs_reduce_axes + named_reduce_axes)
for i in sorted(subs_reduce_axes, reverse=True):
del all_in_subs[idx][i]
for named_axis in named_reduce_axes:
all_in_named[idx].remove(named_axis)

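# in the unary case the per-operand reductions above do all the work, e.g.
# xeinsum('{i,j}->{i}', x) simply psum-reduces x over the named axis 'j'.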
if len(operands) == 1:
return xs[0]

if len(operands) == 2:
x, y = xs
lhs_subs, rhs_subs = all_in_subs
lhs_named, rhs_named = all_in_named

# if a named axis appears in both inputs and not the output, contract!
named_contract = list((set(lhs_named) & set(rhs_named)) - set(out_named))

# if a subscript appears in both inputs and not the outputs, contract!
subs_contract = (set(lhs_subs) & set(rhs_subs)) - set(out_subs)

pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
for n in subs_contract)
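# e.g. for 'bij,bjk->bik': subs_contract == {'j'}, so pos_contract ends up
# as ((2,), (1,)), while 'b' becomes a batch subscript below.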

# if a subscript appears in both inputs _and_ the outputs, batch!
subs_batch = (set(lhs_subs) & set(rhs_subs)) - subs_contract
pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n)) for n in subs_batch)

return pdot(x, y, axis_name=named_contract,
pos_contract=pos_contract, pos_batch=pos_batch)


class XeinsumSpecParser:
spec: str
3 changes: 1 addition & 2 deletions jax/_src/numpy/lax_numpy.py
@@ -2826,8 +2826,7 @@ def einsum(*operands, out=None, optimize='optimal', precision=None,
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
len(operands[1:]) == 2):
if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0]):
return lax.xeinsum(*operands)

optimize = 'optimal' if optimize is True else optimize
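With the relaxed check above, any einsum spec containing '{' now dispatches to lax.xeinsum regardless of operand count. A minimal sketch of the newly supported unary path, adapted from the tests below (assuming xmap from jax.experimental.maps, as the tests use):

from functools import partial
import numpy as np
import jax.numpy as jnp
from jax.experimental.maps import xmap

x = np.random.RandomState(0).randn(8, 6, 4)

# unary xeinsum: named axes 'i' and 'j' appear in the input spec but not in
# the output spec, so they are psum-reduced away.
out = xmap(partial(jnp.einsum, '{b,i,j}->{b}'),
           in_axes=['b', 'i', 'j'],
           out_axes=['b'])(x)
np.testing.assert_allclose(out, np.einsum('bij->b', x), rtol=1e-4, atol=1e-4)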
110 changes: 110 additions & 0 deletions tests/xmap_test.py
@@ -1332,6 +1332,15 @@ def test_xeinsum_no_named_axes_batch_vector_dot(self):
expected = np.einsum('ij,ij->i', x, y)
self.assertAllClose(out, expected, check_dtypes=True)

def test_xeinsum_no_named_axes_batch_matmul(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(3, 4, 2)
out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
expected = np.einsum('bij,bjk->bik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)

def test_xeinsum_no_named_axes_reduce_sum(self):
rng = self.rng()
x = rng.randn(3)
@@ -1341,6 +1350,107 @@ def test_xeinsum_no_named_axes_reduce_sum(self):
self.assertAllClose(out, expected, check_dtypes=True)


def test_xeinsum_no_named_axes_reduce_and_contract(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(2, 4, 2)
out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
expected = np.einsum('bij,cjk->ik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)

def test_xeinsum_named_axes_reduce(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 4)
y = rng.randn(5,)

def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=(['i', 'j'], ['k']),
out_axes=['i', 'k'])(x, y)
expected = np.einsum('ij,k->ik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)
check('{i,j},{k}->{i,k}')

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_xeinsum_named_axes_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(6, 4)
y = rng.randn(8,)

def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=(['i', 'j'], ['k']),
out_axes=['i', 'k'],
axis_resources={'i': 'x', 'k': 'y'})(x, y)
expected = np.einsum('ij,k->ik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)

check('{i,j},{k}->{i,k}')
check('{i,j},{k}->{k,i}') # order of named axes in the spec doesn't matter!
check('{j,i},{k}->{i,k}')
check('{j,i},{k}->{k,i}')

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 3, 4)
y = rng.randn(8, 4, 5)

def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=(['b', 'i', 'j'], ['b', 'j', 'k']),
out_axes=['b', 'i', 'k'],
axis_resources={'b': 'x', 'j': 'y'})(x, y)
expected = np.einsum('bij,bjk->bik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)

check('{b,i,j},{b,j,k}->{b,i,k}')
check('{j,i,b},{j,b,k}->{i,b,k}') # order of named axes in the spec doesn't matter!

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 6, 4)

def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=['b', 'i', 'j'],
out_axes=['b'],
axis_resources={'b': 'x', 'i': 'y'})(x)
expected = np.einsum('bij->b', x)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)

check('{b,i,j}->{b}')
check('{b,j,i}->{b}') # order of named axes in the spec doesn't matter!
check('{i,j,b}->{b}')

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 6, 4, 5)

def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=['b', 'i', ...],
out_axes=['b', ...],
axis_resources={'b': 'x', 'i': 'y'})(x)
expected = np.einsum('bijk->bk', x)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)

check('jk{i,b}->k{b}')


class XMapErrorTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 2)])
