Skip to content

Commit

Permalink
allow backend=None (#196)
Browse files Browse the repository at this point in the history
* allow backend=None

* update backend typing

* optimize `infer_backend`

* tweak to help mypy
  • Loading branch information
jcmgray committed Jul 18, 2022
1 parent aa3dd9a commit 3824e4f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
18 changes: 12 additions & 6 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections import namedtuple
from decimal import Decimal
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from . import backends, blas, helpers, parser, paths, sharing
Expand Down Expand Up @@ -542,15 +543,20 @@ def contract(*operands_: Any, **kwargs: Any) -> ArrayType:
return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)


@lru_cache(None)
def _infer_backend_class_cached(cls: type) -> str:
return cls.__module__.split(".")[0]


def infer_backend(x: Any) -> str:
return x.__class__.__module__.split(".")[0]
return _infer_backend_class_cached(x.__class__)


def parse_backend(arrays: Sequence[ArrayType], backend: str) -> str:
def parse_backend(arrays: Sequence[ArrayType], backend: Optional[str]) -> str:
"""Find out what backend we should use, dipatching based on the first
array if ``backend='auto'`` is specified.
"""
if backend != "auto":
if (backend != "auto") and (backend is not None):
return backend
backend = infer_backend(arrays[0])

Expand All @@ -565,7 +571,7 @@ def parse_backend(arrays: Sequence[ArrayType], backend: str) -> str:
def _core_contract(
operands_: Sequence[ArrayType],
contraction_list: ContractionListType,
backend: str = "auto",
backend: Optional[str] = "auto",
evaluate_constants: bool = False,
**einsum_kwargs: Any,
) -> ArrayType:
Expand Down Expand Up @@ -703,7 +709,7 @@ def __init__(
self._evaluated_constants: Dict[str, Any] = {}
self._backend_expressions: Dict[str, Any] = {}

def evaluate_constants(self, backend: str = "auto") -> None:
def evaluate_constants(self, backend: Optional[str] = "auto") -> None:
"""Convert any constant operands to the correct backend form, and
perform as many contractions as possible to create a new list of
operands, stored in ``self._evaluated_constants[backend]``. This also
Expand Down Expand Up @@ -746,7 +752,7 @@ def _contract(
self,
arrays: Sequence[ArrayType],
out: Optional[ArrayType] = None,
backend: str = "auto",
backend: Optional[str] = "auto",
evaluate_constants: bool = False,
) -> ArrayType:
"""The normal, core contraction."""
Expand Down
1 change: 1 addition & 0 deletions opt_einsum/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def test_auto_backend_custom_array_no_tensordot():
# Shaped is an array-like object defined by opt_einsum - which has no TDOT
assert infer_backend(x) == "opt_einsum"
assert parse_backend([x], "auto") == "numpy"
assert parse_backend([x], None) == "numpy"


@pytest.mark.parametrize("string", tests)
Expand Down

0 comments on commit 3824e4f

Please sign in to comment.