Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/dgasmith/opt_einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jun 27, 2018
2 parents b08d789 + a0268e2 commit 521fde9
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Function Reference
opt_einsum.contract.ContractExpression
opt_einsum.paths.optimal
opt_einsum.paths.greedy
opt_einsum.parser.get_symbol
7 changes: 4 additions & 3 deletions docs/source/ex_large_expr_with_greedy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ that looks like::

The meaning of this is not that important other than its a large, useful
contraction. For ``n=100`` it involves 200 different tensors and about 300
unique indices.
unique indices. With this many indices it can be useful to generate them with
the function :func:`~opt_einsum.parser.get_symbol`.

Let's set up the required einsum string:

Expand All @@ -42,7 +43,7 @@ Let's set up the required einsum string:
... # |
... # --O--
... j = 3 * i
... ul, ur, m, ll, lr = (oe.parser.einsum_symbols[i]
... ul, ur, m, ll, lr = (oe.get_symbol(i)
... for i in (j - 1, j + 2, j, j - 2, j + 1))
>>> einsum_str += "{}{}{},{}{}{},".format(m, ul, ur, m, ll, lr)
Expand All @@ -52,7 +53,7 @@ Let's set up the required einsum string:
... # --O
>>> i = n - 1
>>> j = 3 * i
>>> ul, m, ll, = (oe.parser.einsum_symbols[i] for i in (j - 1, j, j - 2))
>>> ul, m, ll, = (oe.get_symbol(i) for i in (j - 1, j, j - 2))
>>> einsum_str += "{}{},{}{}".format(m, ul, m, ll)
Generate the shapes:
Expand Down
6 changes: 3 additions & 3 deletions docs/source/greedy_path.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ The ``greedy`` path iterates through the possible pair contractions and chooses
The "best" contraction pair is determined by the smallest of the tuple ``(-removed_size, cost)`` where ``removed_size`` is the size of the contracted tensors minus the size of the tensor created and ``cost`` is the cost of the contraction.
Effectively, the algorithm chooses the best inner or dot product, Hadamard product, and then outer product at each iteration with a sieve to prevent large outer products.
This algorithm has proven to be quite successful for general production and only misses a few complex cases that make it slightly worse than the ``optimal`` algorithm.
Fortunately, these often only lead to increases in prefactor than missing the optimal scaling.
Fortunately, these often only lead to increases in prefactor than missing the optimal scaling.

The ``greedy`` scale like N^2 rather than factorially making ``greedy`` much more suitable for large numbers of contractions and has a lower prefactor that helps decrease latency.
As :mod:`opt_einsum` can handle more than a thousand unique indices the low scaling is especially important for very large contraction networks.
The ``greedy`` approach scales like N^2 rather than factorially, making ``greedy`` much more suitable for large numbers of contractions where the lower prefactor helps decrease latency.
As :mod:`opt_einsum` can handle an arbitrary number of indices the low scaling is especially important for very large contraction networks.
The ``greedy`` functionality is provided by :func:`~opt_einsum.paths.greedy`.
1 change: 1 addition & 0 deletions opt_einsum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from .contract import contract, contract_path, contract_expression
from .parser import get_symbol
from . import paths
from . import blas
from . import helpers
Expand Down
3 changes: 1 addition & 2 deletions opt_einsum/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Contains helper functions for opt_einsum testing scripts
"""

import itertools
import numpy as np

chars = 'abcdefghijklmopq'
Expand Down Expand Up @@ -123,7 +122,7 @@ def find_contraction(positions, input_sets, output_set):
for i in sorted(positions, reverse=True):
idx_contract |= remaining.pop(i)

idx_remain = set(itertools.chain(output_set, *remaining))
idx_remain = output_set.union(*remaining)

new_result = idx_remain & idx_contract
idx_removed = (idx_contract - new_result)
Expand Down
64 changes: 40 additions & 24 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
#!/usr/bin/env python
# coding: utf-8
"""
A functionally equivalent parser of the numpy.einsum input parser
"""

import sys

import numpy as np

einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyz'
einsum_symbols = einsum_symbols_base + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

# boost the number of symbols using unicode if python3
if sys.version_info[0] >= 3:
einsum_symbols += ''.join(map(chr, range(193, 688)))
einsum_symbols += ''.join(map(chr, range(913, 1367)))

einsum_symbols_set = set(einsum_symbols)
einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def is_valid_einsum_char(x):
Expand All @@ -29,6 +22,39 @@ def has_valid_einsum_chars_only(einsum_str):
return all(map(is_valid_einsum_char, einsum_str))


def get_symbol(i):
"""Get the symbol corresponding to int ``i`` - runs through the usual 52
letters before resorting to unicode characters, starting at ``chr(192)``.
Examples
--------
>>> get_symbol(2)
'c'
>>> oe.get_symbol(200)
'Ŕ'
>>> oe.get_symbol(20000)
'京'
"""
if i < 52:
return einsum_symbols_base[i]
return chr(i + 140)


def gen_unused_symbols(used, n):
"""Generate ``n`` symbols that are not already in ``used``.
"""
i = cnt = 0
while cnt < n:
s = get_symbol(i)
i += 1
if s in used:
continue
yield s
cnt += 1


def convert_to_valid_einsum_chars(einsum_str):
"""Convert the str ``einsum_str`` to contain only the alphabetic characters
valid for numpy einsum.
Expand All @@ -39,7 +65,7 @@ def convert_to_valid_einsum_chars(einsum_str):
(valid if is_valid_einsum_char(x) else invalid).add(x)

# get replacements for invalid chars that are not already used
available = (x for x in einsum_symbols if x not in valid)
available = gen_unused_symbols(valid, len(invalid))

# map invalid to available and replace in the inputs
replacer = dict(zip(invalid, available))
Expand All @@ -52,8 +78,6 @@ def find_output_str(subscripts):
tmp_subscripts = subscripts.replace(",", "")
output_subscript = ""
for s in sorted(set(tmp_subscripts)):
if s not in einsum_symbols_set:
raise ValueError("Character %s is not a valid symbol." % s)
if tmp_subscripts.count(s) == 1:
output_subscript += s
return output_subscript
Expand Down Expand Up @@ -101,13 +125,6 @@ def parse_einsum_input(operands):
subscripts = operands[0].replace(" ", "")
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]

# Ensure all characters are valid
for s in subscripts:
if s in '.,->':
continue
if s not in einsum_symbols_set:
raise ValueError("Character %s is not a valid symbol." % s)

else:
tmp_operands = list(operands)
operand_list = []
Expand All @@ -125,7 +142,7 @@ def parse_einsum_input(operands):
if s is Ellipsis:
subscripts += "..."
elif isinstance(s, int):
subscripts += einsum_symbols[s]
subscripts += get_symbol(s)
else:
raise TypeError("For this input type lists must contain " "either int or Ellipsis")
if num != last:
Expand All @@ -137,7 +154,7 @@ def parse_einsum_input(operands):
if s is Ellipsis:
subscripts += "..."
elif isinstance(s, int):
subscripts += einsum_symbols[s]
subscripts += get_symbol(s)
else:
raise TypeError("For this input type lists must contain " "either int or Ellipsis")
# Check for proper "->"
Expand All @@ -149,8 +166,7 @@ def parse_einsum_input(operands):
# Parse ellipses
if "." in subscripts:
used = subscripts.replace(".", "").replace(",", "").replace("->", "")
unused = list(einsum_symbols_set - set(used))
ellipse_inds = "".join(unused)
ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
longest = 0

# Do we have an output to account for?
Expand Down
2 changes: 1 addition & 1 deletion paper/paper.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
title: opt_einsum - A Python package for optimizing contraction order for einsum-like expressions
title: opt\_einsum - A Python package for optimizing contraction order for einsum-like expressions
tags:
- array
- tensors
Expand Down

0 comments on commit 521fde9

Please sign in to comment.