Skip to content

Commit

Permalink
Merge pull request #41 from fritzo/torch-rename
Browse files Browse the repository at this point in the history
Rename symbols to support torch.einsum
  • Loading branch information
dgasmith committed Aug 18, 2018
2 parents 7ef12a5 + 2656141 commit 49e2a91
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
11 changes: 6 additions & 5 deletions opt_einsum/backends/torch.py
Expand Up @@ -5,8 +5,7 @@
from __future__ import absolute_import
import numpy as np

from ..parser import einsum_symbols_base

from ..parser import convert_to_valid_einsum_chars, einsum_symbols_base

_TORCH_DEVICE = None

Expand All @@ -31,6 +30,10 @@ def transpose(a, axes):
def einsum(equation, *operands):
"""Variadic version of torch.einsum to match numpy api.
"""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
equation = convert_to_valid_einsum_chars(equation)

torch, _ = _get_torch_and_device()
return torch.einsum(equation, operands)

Expand All @@ -39,8 +42,6 @@ def tensordot(x, y, axes=2):
"""Simple translation of tensordot syntax to einsum.
"""
# XXX: tensordot should be directly implemented in torch soon
torch, _ = _get_torch_and_device()

xnd = x.ndimension()
ynd = y.ndimension()

Expand Down Expand Up @@ -80,7 +81,7 @@ def tensordot(x, y, axes=2):

# form full string and contract!
einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
return torch.einsum(einsum_str, (x, y))
return einsum(einsum_str, x, y)


def to_torch(array):
Expand Down
4 changes: 2 additions & 2 deletions opt_einsum/helpers.py
Expand Up @@ -4,8 +4,8 @@

import numpy as np

chars = 'abcdefghijklmopq'
sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3])
chars = 'abcdefghijklmopqABC'
sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
default_dim_dict = {c: s for c, s in zip(chars, sizes)}


Expand Down
15 changes: 4 additions & 11 deletions opt_einsum/parser.py
Expand Up @@ -57,18 +57,11 @@ def gen_unused_symbols(used, n):

def convert_to_valid_einsum_chars(einsum_str):
"""Convert the str ``einsum_str`` to contain only the alphabetic characters
valid for numpy einsum.
valid for numpy einsum. If there are too many symbols, let the backend
throw an error.
"""
# partition into valid and invalid sets
valid, invalid = set(), set()
for x in einsum_str:
(valid if is_valid_einsum_char(x) else invalid).add(x)

# get replacements for invalid chars that are not already used
available = gen_unused_symbols(valid, len(invalid))

# map invalid to available and replace in the inputs
replacer = dict(zip(invalid, available))
symbols = sorted(set(einsum_str) - set(',->'))
replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
return "".join(replacer.get(x, x) for x in einsum_str)


Expand Down
1 change: 1 addition & 0 deletions opt_einsum/tests/test_backends.py
Expand Up @@ -35,6 +35,7 @@
'ijk,ikj',
'i,j->ij',
'ijk,k->ij',
'AB,BC->CA',
]


Expand Down

0 comments on commit 49e2a91

Please sign in to comment.