Skip to content

Commit

Permalink
tensorflow-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Mar 24, 2021
1 parent c6ff814 commit 8be6349
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 11 deletions.
2 changes: 1 addition & 1 deletion opt_einsum/backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import object_arrays
from . import cupy as _cupy
from . import jax as _jax
from . import tensorflow as _tensorflow
from . import tensorflow2 as _tensorflow
from . import theano as _theano
from . import torch as _torch

Expand Down
52 changes: 52 additions & 0 deletions opt_einsum/backends/tensorflow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Required functions for optimized contractions of numpy arrays using tensorflow.
"""
import numpy as np

from ..sharing import to_backend_cache_wrap

__all__ = ["build_expression", "evaluate_constants"]


_tensorflow = None


def _get_tensorflow_and_to_tensorflow():
global _tensorflow
if _tensorflow is None:
import tensorflow

@to_backend_cache_wrap
def to_tensorflow(array):
if isinstance(array, np.ndarray):
return tensorflow.convert_to_tensor(array)
return array

_tensorflow = tensorflow, to_tensorflow

return _tensorflow


def build_expression(_, expr): # pragma: no cover
"""Build a tensorflow function based on ``arrays`` and ``expr``.
"""
tensorflow, to_tensorflow = _get_tensorflow_and_to_tensorflow()
tensorflow_expr = tensorflow.function(
expr._contract,
autograph=False,
experimental_compile=True,
)

def tensorflow_contract(*arrays):
tf_arrays = tuple(map(to_tensorflow, arrays))
return tensorflow_expr(tf_arrays).numpy()

return tensorflow_contract


def evaluate_constants(const_arrays, expr): # pragma: no cover
"""Convert constant arguments to tensorflow arrays, and perform any possible
constant contractions.
"""
_, to_tensorflow = _get_tensorflow_and_to_tensorflow()
return expr(*[to_tensorflow(x) for x in const_arrays], backend='tensorflow', evaluate_constants=True)
34 changes: 24 additions & 10 deletions opt_einsum/tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import numpy as np
import pytest

Expand All @@ -13,9 +14,15 @@
try:
import tensorflow as tf
# needed so tensorflow doesn't allocate all gpu mem
_TF_CONFIG = tf.ConfigProto()
_TF_CONFIG.gpu_options.allow_growth = True
found_tensorflow = True
try:
_TF_CONFIG = tf.ConfigProto()
_TF_CONFIG.gpu_options.allow_growth = True
found_tensorflow = "v1"
except AttributeError:
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
tf.config.experimental.set_memory_growth(device, True)
found_tensorflow = "v2"
except ImportError:
found_tensorflow = False

Expand Down Expand Up @@ -58,6 +65,17 @@
]


@contextmanager
def maybe_tensorflow_session():
if found_tensorflow != 'v1':
yield
else:
sess = tf.Session(config=_TF_CONFIG)
with sess.as_default() as x:
yield x
sess.close()


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string):
Expand All @@ -68,10 +86,8 @@ def test_tensorflow(string):
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)

sess = tf.Session(config=_TF_CONFIG)
with sess.as_default():
with maybe_tensorflow_session():
expr(*views, backend='tensorflow', out=opt)
sess.close()

assert np.allclose(ein, opt)

Expand All @@ -93,7 +109,7 @@ def test_tensorflow_with_constants(constants):
expr = contract_expression(eq, *ops, constants=constants)

# check tensorflow
with tf.Session(config=_TF_CONFIG).as_default():
with maybe_tensorflow_session():
res_got = expr(var, backend='tensorflow')
assert all(array is None or infer_backend(array) == 'tensorflow'
for array in expr._evaluated_constants['tensorflow'])
Expand All @@ -117,9 +133,7 @@ def test_tensorflow_with_sharing(string):
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)

sess = tf.Session(config=_TF_CONFIG)

with sess.as_default(), sharing.shared_intermediates() as cache:
with maybe_tensorflow_session(), sharing.shared_intermediates() as cache:
tfl1 = expr(*views, backend='tensorflow')
assert sharing.get_sharing_cache() is cache
cache_sz = len(cache)
Expand Down

0 comments on commit 8be6349

Please sign in to comment.