Skip to content

Commit

Permalink
adding sum and matmul to backends (google#681)
Browse files Browse the repository at this point in the history
* adding sum and matmul to backends

* formatting

* change to jitted_functions.py
  • Loading branch information
mganahl committed Jun 25, 2020
1 parent be7713a commit 6f62db9
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 15 deletions.
37 changes: 35 additions & 2 deletions tensornetwork/backends/abstract_backend.py
Expand Up @@ -342,7 +342,7 @@ def eigs(self,
numeig: int = 1,
tol: float = 1E-8,
which: Text = 'LR',
maxiter: Optional[int] = None) -> List[Tensor]:
maxiter: Optional[int] = None) -> Tuple[Tensor, List]:
"""Arnoldi method for finding the lowest eigenvector-eigenvalue pairs
of a linear operator `A`. `A` is a callable implementing the
matrix-vector product. If no `initial_state` is provided then
Expand Down Expand Up @@ -390,7 +390,7 @@ def eigsh_lanczos(self,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False) -> Tuple[List, List]:
reorthogonalize: bool = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of `A`.
Expand Down Expand Up @@ -607,3 +607,36 @@ def jit(self, fun: Callable, *args: Any, **kwargs: Any) -> Callable:
"""
raise NotImplementedError("Backend '{}' has not implemented `jit`.".format(
self.name))

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Sum elements of `tensor` along the specified `axis`.

  Args:
    tensor: An input tensor.
    axis: The axes of `tensor` to sum over. `None` (the default) is
      intended to mean a sum over all axes, as in `numpy.sum`.
    keepdims: If `True`, the summed axes are kept in the result with
      dimension 1; if `False` they are removed.
  Returns:
    tensor: The result of performing the summation. Unless `keepdims`
      is `True`, the order of the tensor is reduced by the number of
      summed axes.
  Raises:
    NotImplementedError: Always; concrete backends override this method.
  """
  raise NotImplementedError("Backend '{}' has not implemented `sum`.".format(
      self.name))

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Perform a possibly batched matrix-matrix multiplication
  between `tensor1` and `tensor2`. The following behaviour
  is similar to `numpy.matmul`:
  - If both arguments are 2-D they are multiplied like conventional
    matrices.
  - If either argument is N-D, N > 2, it is treated as a stack of
    matrices residing in the last two indexes and broadcast accordingly.
  Both arguments to `matmul` have to be tensors of order >= 2.

  Args:
    tensor1: An input tensor.
    tensor2: An input tensor.
  Returns:
    tensor: The result of performing the matmul.
  Raises:
    NotImplementedError: Always; concrete backends override this method.
  """
  # Without this raise the abstract method silently returned `None`,
  # unlike every other unimplemented method of the abstract backend.
  raise NotImplementedError(
      "Backend '{}' has not implemented `matmul`.".format(self.name))
17 changes: 14 additions & 3 deletions tensornetwork/backends/jax/jax_backend.py
Expand Up @@ -241,7 +241,7 @@ def eigs(self,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'LR',
maxiter: int = 20) -> Tuple[List, List]:
maxiter: int = 20) -> Tuple[Tensor, List]:
"""
Implicitly restarted Arnoldi method for finding the lowest
eigenvector-eigenvalue pairs of a linear operator `A`.
Expand Down Expand Up @@ -344,7 +344,7 @@ def eigsh_lanczos(
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 10,
reorthogonalize: Optional[bool] = False) -> Tuple[List, List]:
reorthogonalize: Optional[bool] = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a hermitian linear operator `A`. `A` is a function implementing
Expand Down Expand Up @@ -410,7 +410,7 @@ def A(H,x):
explicit orthogonalization (more costly than `reorthogonalize=False`)
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` lowest eigenvalues
eigvals: A jax-array containing `numeig` lowest eigenvalues
eigvecs: A list of `numeig` lowest eigenvectors
"""
if args is None:
Expand Down Expand Up @@ -510,3 +510,14 @@ def expm(self, matrix: Tensor) -> Tensor:

def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
  """
  Return a jitted version of `fun`.

  Args:
    fun: The callable to be compiled.
    *args: Positional arguments forwarded to `libjax.jit`.
    **kwargs: Keyword arguments forwarded to `libjax.jit`.
  Returns:
    Callable: The result of `libjax.jit(fun, *args, **kwargs)`.
  """
  jitted = libjax.jit(fun, *args, **kwargs)
  return jitted

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Sum elements of `tensor` along the specified `axis`.

  Args:
    tensor: An input tensor.
    axis: The axes to sum over; `None` sums over all axes.
    keepdims: If `True`, summed axes are kept with dimension 1.
  Returns:
    tensor: The result of performing the summation.
  """
  # Use the JAX implementation, consistent with `matmul` in this backend,
  # instead of routing the JAX tensor through host-side `numpy.sum`.
  return jnp.sum(tensor, axis=axis, keepdims=keepdims)

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Batched matrix-matrix multiplication of `tensor1` and `tensor2`
  via `jnp.matmul`.

  Args:
    tensor1: An input tensor of order > 1.
    tensor2: An input tensor of order > 1.
  Returns:
    tensor: The result of `jnp.matmul(tensor1, tensor2)`.
  Raises:
    ValueError: If either input has order <= 1.
  """
  if any(operand.ndim <= 1 for operand in (tensor1, tensor2)):
    raise ValueError("inputs to `matmul` have to be a tensors of order > 1,")
  product = jnp.matmul(tensor1, tensor2)
  return product
29 changes: 28 additions & 1 deletion tensornetwork/backends/jax/jax_backend_test.py
Expand Up @@ -661,5 +661,32 @@ def test_eigs_raises():
backend.eigs(lambda x: x, initial_state=[1, 2, 3])
for which in ('SI', 'LI', 'SM', 'SR'):
with pytest.raises(
ValueError, match=f'which = {which} is currently not supported.'):
ValueError, match=f"which = {which}"
f" is currently not supported."):
backend.eigs(lambda x: x, which=which)


def test_sum():
  """Check `JaxBackend.sum` against `numpy.sum`, with and without keepdims."""
  np.random.seed(10)
  backend = jax_backend.JaxBackend()
  tensor = np.random.rand(2, 3, 4)
  a = backend.convert_to_tensor(tensor)
  for keepdims in (False, True):
    actual = backend.sum(a, axis=(1, 2), keepdims=keepdims)
    # The reference value is computed from the NumPy input `tensor`, not
    # from the backend tensor `a`, so it is independent of the backend.
    expected = np.sum(tensor, axis=(1, 2), keepdims=keepdims)
    np.testing.assert_allclose(expected, actual)


def test_matmul():
  """Check `JaxBackend.matmul` on a batch of matrices against `numpy.matmul`."""
  np.random.seed(10)
  backend = jax_backend.JaxBackend()
  lhs = np.random.rand(10, 2, 3)
  rhs = np.random.rand(10, 3, 4)
  reference = np.matmul(lhs, rhs)
  result = backend.matmul(
      backend.convert_to_tensor(lhs), backend.convert_to_tensor(rhs))
  np.testing.assert_allclose(reference, result)
4 changes: 2 additions & 2 deletions tensornetwork/backends/jax/jitted_functions.py
Expand Up @@ -53,8 +53,8 @@ def jax_lanczos(matvec, arguments, init, ncv, neig, landelta, reortho):
reortho: If `True`, reorthogonalize all krylov vectors at each step.
This should be used if `neig>1`.
Returns:
list: Eigen values
list: Eigen values
jax.numpy.ndarray: Eigenvalues
list: Eigenvectors
"""

def body_modified_gram_schmidt(i, vals):
Expand Down
18 changes: 14 additions & 4 deletions tensornetwork/backends/numpy/numpy_backend.py
Expand Up @@ -34,7 +34,6 @@ def tensordot(self, a: Tensor, b: Tensor,
if (len(axes[0]) == a.ndim) and (len(axes[1]) == b.ndim):
if not len(axes[0]) == len(axes[1]):
raise ValueError("shape-mismatch for sum")

u, pos1, _ = np.intersect1d(
axes[0], axes[1], return_indices=True, assume_unique=True)
labels = int_to_string[0:len(u)]
Expand Down Expand Up @@ -206,7 +205,7 @@ def eigs(self,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'LR',
maxiter: Optional[int] = None) -> Tuple[List, List]:
maxiter: Optional[int] = None) -> Tuple[Tensor, List]:
"""
Arnoldi method for finding the lowest eigenvector-eigenvalue pairs
of a linear operator `A`. `A` can be either a
Expand Down Expand Up @@ -284,7 +283,7 @@ def matvec(vector):
if dtype:
eta = eta.astype(dtype)
U = U.astype(dtype)
return list(eta), [np.reshape(U[:, n], shape) for n in range(numeig)]
return eta, [np.reshape(U[:, n], shape) for n in range(numeig)]

def eigsh_lanczos(self,
A: Callable,
Expand All @@ -297,7 +296,7 @@ def eigsh_lanczos(self,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False) -> Tuple[List, List]:
reorthogonalize: bool = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a linear operator `A`.
Expand Down Expand Up @@ -474,3 +473,14 @@ def expm(self, matrix: Tensor) -> Tensor:

def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
  """
  Return `fun` unchanged; this backend performs no compilation.

  Args:
    fun: The callable to be "jitted".
    *args: Accepted for interface compatibility and ignored.
    **kwargs: Accepted for interface compatibility and ignored.
  Returns:
    Callable: The very same object `fun`.
  """
  # No compilation step exists here; hand the callable straight back.
  return fun

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Sum elements of `tensor` along the specified `axis`.

  Args:
    tensor: An input tensor.
    axis: The axes to sum over (a sequence of ints or a single int).
      `None` (the default) sums over all axes.
    keepdims: If `True`, summed axes are kept with dimension 1.
  Returns:
    tensor: The result of performing the summation.
  """
  # `numpy.sum` needs a tuple (not a list) for multiple axes, but
  # `tuple(None)` and `tuple(<int>)` raise TypeError, so only a real
  # sequence is converted.
  if axis is not None and not isinstance(axis, int):
    axis = tuple(axis)
  return np.sum(tensor, axis=axis, keepdims=keepdims)

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Batched matrix-matrix multiplication of `tensor1` and `tensor2`
  via `numpy.matmul`.

  Args:
    tensor1: An input tensor of order > 1.
    tensor2: An input tensor of order > 1.
  Returns:
    tensor: The result of `np.matmul(tensor1, tensor2)`.
  Raises:
    ValueError: If either input has order <= 1.
  """
  if any(operand.ndim <= 1 for operand in (tensor1, tensor2)):
    raise ValueError("inputs to `matmul` have to be a tensors of order > 1,")
  product = np.matmul(tensor1, tensor2)
  return product
26 changes: 26 additions & 0 deletions tensornetwork/backends/numpy/numpy_backend_test.py
Expand Up @@ -743,3 +743,29 @@ def test_tensordot_inner():
actual = backend.tensordot(a, b, ((0, 1, 2), (1, 2, 0)))
expected = np.tensordot(a, b, ((0, 1, 2), (1, 2, 0)))
np.testing.assert_allclose(expected, actual)


def test_sum():
  """Check `NumPyBackend.sum` against `numpy.sum`, with and without keepdims."""
  np.random.seed(10)
  backend = numpy_backend.NumPyBackend()
  tensor = np.random.rand(2, 3, 4)
  a = backend.convert_to_tensor(tensor)
  for keepdims in (False, True):
    actual = backend.sum(a, axis=(1, 2), keepdims=keepdims)
    # The reference value is computed from the NumPy input `tensor`, not
    # from the backend tensor `a`, so it is independent of the backend.
    expected = np.sum(tensor, axis=(1, 2), keepdims=keepdims)
    np.testing.assert_allclose(expected, actual)


def test_matmul():
  """Check `NumPyBackend.matmul` on a batch of matrices against `numpy.matmul`."""
  np.random.seed(10)
  backend = numpy_backend.NumPyBackend()
  lhs = np.random.rand(10, 2, 3)
  rhs = np.random.rand(10, 3, 4)
  reference = np.matmul(lhs, rhs)
  result = backend.matmul(
      backend.convert_to_tensor(lhs), backend.convert_to_tensor(rhs))
  np.testing.assert_allclose(reference, result)
14 changes: 13 additions & 1 deletion tensornetwork/backends/pytorch/pytorch_backend.py
Expand Up @@ -185,7 +185,7 @@ def eigsh_lanczos(self,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False) -> Tuple[List, List]:
reorthogonalize: bool = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a `LinearOperator` `A`.
Expand Down Expand Up @@ -346,3 +346,15 @@ def broadcast_left_multiplication(self, tensor1: Tensor,

def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
  """
  Return `fun` unchanged; this backend performs no compilation.

  Args:
    fun: The callable to be "jitted".
    *args: Accepted for interface compatibility and ignored.
    **kwargs: Accepted for interface compatibility and ignored.
  Returns:
    Callable: The very same object `fun`.
  """
  # No compilation step exists here; hand the callable straight back.
  return fun

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Sum elements of `tensor` along the specified `axis`.

  Args:
    tensor: An input tensor.
    axis: The axes to sum over. `None` (the default) sums over all axes.
    keepdims: If `True`, summed axes are kept with dimension 1.
  Returns:
    tensor: The result of performing the summation.
  """
  if axis is None:
    # NOTE(review): older torch releases reject `dim=None` in `torch.sum`;
    # spelling out every axis works regardless of the torch version.
    axis = tuple(range(tensor.ndim))
  return torchlib.sum(tensor, axis=axis, keepdim=keepdims)

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Perform a possibly batched matrix-matrix multiplication between
  `tensor1` and `tensor2`, with `numpy.matmul`-like behaviour.

  Args:
    tensor1: An input tensor of order > 1.
    tensor2: An input tensor of order > 1.
  Returns:
    tensor: The result of performing the matmul.
  Raises:
    ValueError: If either input has order <= 1.
  """
  if (tensor1.ndim <= 1) or (tensor2.ndim <= 1):
    raise ValueError("inputs to `matmul` have to be a tensors of order > 1,")

  # The previous einsum 'mab,mbc->mac' only accepted order-3 inputs even
  # though the check above admits any order >= 2. `matmul` treats the last
  # two axes as matrices and broadcasts over the leading ones.
  return torchlib.matmul(tensor1, tensor2)
22 changes: 22 additions & 0 deletions tensornetwork/backends/pytorch/pytorch_backend_test.py
Expand Up @@ -537,3 +537,25 @@ def test_sparse_shape():
backend = pytorch_backend.PyTorchBackend()
tensor = backend.randn((2, 3, 4), dtype=dtype, seed=10)
np.testing.assert_allclose(backend.sparse_shape(tensor), tensor.shape)


def test_sum():
  """Check `PyTorchBackend.sum` over two axes against `numpy.sum`."""
  np.random.seed(10)
  backend = pytorch_backend.PyTorchBackend()
  reference_input = np.random.rand(2, 3, 4)
  summed_axes = (1, 2)
  result = backend.sum(
      backend.convert_to_tensor(reference_input), axis=summed_axes)
  reference = np.sum(reference_input, axis=summed_axes)
  np.testing.assert_allclose(reference, result)


def test_matmul():
  """Check `PyTorchBackend.matmul` on a batch of matrices against `numpy.matmul`."""
  np.random.seed(10)
  backend = pytorch_backend.PyTorchBackend()
  lhs = np.random.rand(10, 2, 3)
  rhs = np.random.rand(10, 3, 4)
  reference = np.matmul(lhs, rhs)
  result = backend.matmul(
      backend.convert_to_tensor(lhs), backend.convert_to_tensor(rhs))
  np.testing.assert_allclose(reference, result)
24 changes: 22 additions & 2 deletions tensornetwork/backends/shell/shell_backend.py
Expand Up @@ -157,6 +157,7 @@ def trace(self, tensor: Tensor) -> Tensor:

def outer_product(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Return a `ShellTensor` for the outer product of the inputs.

  Args:
    tensor1: An input tensor.
    tensor2: An input tensor.
  Returns:
    ShellTensor: A tensor whose shape is `tensor1.shape + tensor2.shape`.
  """
  combined_shape = tensor1.shape + tensor2.shape
  return ShellTensor(combined_shape)

#pylint: disable=unused-argument
def einsum(self,
expression: str,
Expand Down Expand Up @@ -239,7 +240,7 @@ def eigs(self,
numeig: Optional[int] = 1,
tol: Optional[float] = 1E-8,
which: Optional[Text] = 'LR',
maxiter: Optional[int] = None) -> Tuple[List, List]:
maxiter: Optional[int] = None) -> Tuple[Tensor, List]:
if args is None:
args = []

Expand Down Expand Up @@ -270,7 +271,7 @@ def eigsh_lanczos(self,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False) -> Tuple[List, List]:
reorthogonalize: bool = False) -> Tuple[Tensor, List]:
if args is None:
args = []
if num_krylov_vecs < numeig:
Expand Down Expand Up @@ -347,3 +348,22 @@ def broadcast_left_multiplication(self, tensor1: Tensor,

def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
  """
  Return `fun` unchanged; this backend performs no compilation.

  Args:
    fun: The callable to be "jitted".
    *args: Accepted for interface compatibility and ignored.
    **kwargs: Accepted for interface compatibility and ignored.
  Returns:
    Callable: The very same object `fun`.
  """
  # No compilation step exists here; hand the callable straight back.
  return fun

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Compute the shape that results from summing `tensor` over `axis`.

  Args:
    tensor: An input tensor; only its `shape` is used.
    axis: The axes to sum over (negative values count from the end).
      `None` (the default) sums over all axes.
    keepdims: If `True`, summed axes are kept with dimension 1;
      otherwise they are removed.
  Returns:
    ShellTensor: A tensor carrying the resulting shape as a tuple.
  Raises:
    IndexError: If an entry of `axis` is out of bounds for `tensor`.
  """
  shape = tuple(tensor.shape)
  ndim = len(shape)
  if axis is None:
    # Default: reduce over every axis (previously this case crashed).
    summed = set(range(ndim))
  else:
    summed = set()
    for ax in np.atleast_1d(axis):
      ax = int(ax)
      if not -ndim <= ax < ndim:
        raise IndexError(
            "axis {} is out of bounds for a tensor of order {}".format(
                ax, ndim))
      summed.add(ax % ndim)
  # Build a tuple (not an ndarray) so that `shape + shape` elsewhere in
  # this backend keeps meaning concatenation.
  if keepdims:
    new_shape = tuple(1 if n in summed else dim for n, dim in enumerate(shape))
  else:
    new_shape = tuple(dim for n, dim in enumerate(shape) if n not in summed)
  return ShellTensor(new_shape)

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Compute the shape of a batched matrix-matrix multiplication of
  `tensor1` and `tensor2`.

  Args:
    tensor1: An input tensor of order > 1; only its `shape` is used.
    tensor2: An input tensor of order > 1; only its `shape` is used.
  Returns:
    ShellTensor: A tensor carrying the resulting shape as a tuple: the
      common batch dimensions followed by `tensor1.shape[-2]` and
      `tensor2.shape[-1]`.
  Raises:
    ValueError: If either input has order <= 1, or if the batch
      dimensions (all but the last two) of the inputs differ.
  """
  shape1 = tuple(tensor1.shape)
  shape2 = tuple(tensor2.shape)
  # Same order check (and message) as the other backends; previously an
  # order-1 input failed with an IndexError further down.
  if (len(shape1) <= 1) or (len(shape2) <= 1):
    raise ValueError("inputs to `matmul` have to be a tensors of order > 1,")
  if shape1[:-2] != shape2[:-2]:
    raise ValueError("shape mismatch for matmul")
  # Build a tuple (not an ndarray) so that `shape + shape` elsewhere in
  # this backend keeps meaning concatenation.
  return ShellTensor(shape1[:-2] + (shape1[-2], shape2[-1]))
22 changes: 22 additions & 0 deletions tensornetwork/backends/shell/shell_backend_test.py
Expand Up @@ -434,3 +434,25 @@ def test_index_update():
actual = backend.index_update(matrix, matrix, matrix)
assert isinstance(actual, shell_backend.ShellTensor)
assert actual.shape == (4, 4, 4)


def test_sum():
  """Check the shapes produced by `ShellBackend.sum`, with and without keepdims."""
  np.random.seed(10)
  backend = shell_backend.ShellBackend()
  a = backend.randn((2, 3, 4), seed=10)
  cases = (
      ({}, [2]),
      ({"keepdims": True}, [2, 1, 1]),
  )
  for extra_kwargs, expected_shape in cases:
    result = backend.sum(a, axis=(1, 2), **extra_kwargs)
    np.testing.assert_allclose(result.shape, expected_shape)


def test_matmul():
  """Check the shape produced by `ShellBackend.matmul` for batched inputs."""
  np.random.seed(10)
  backend = shell_backend.ShellBackend()
  lhs = backend.randn((10, 2, 3), seed=10)
  rhs = backend.randn((10, 3, 4), seed=10)
  result = backend.matmul(lhs, rhs)
  expected_shape = [10, 2, 4]
  np.testing.assert_allclose(result.shape, expected_shape)
12 changes: 12 additions & 0 deletions tensornetwork/backends/tensorflow/tensorflow_backend.py
Expand Up @@ -255,3 +255,15 @@ def expm(self, matrix: Tensor) -> Tensor:
def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
  """
  Return `fun` unchanged; this backend performs no compilation.

  Args:
    fun: The callable to be "jitted".
    *args: Accepted for interface compatibility and ignored.
    **kwargs: Accepted for interface compatibility and ignored.
  Returns:
    Callable: The very same object `fun`.
  """
  # `tf.function` is deliberately not applied here: it is considered
  # slow and bad.
  return fun

def sum(self,
        tensor: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False) -> Tensor:
  """
  Sum elements of `tensor` along the specified `axis` via
  `tf.math.reduce_sum`.

  Args:
    tensor: An input tensor.
    axis: The axes to sum over, forwarded to `tf.math.reduce_sum`.
    keepdims: Forwarded to `tf.math.reduce_sum`.
  Returns:
    tensor: The result of performing the summation.
  """
  reduced = tf.math.reduce_sum(tensor, axis=axis, keepdims=keepdims)
  return reduced

def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
  """
  Batched matrix-matrix multiplication of `tensor1` and `tensor2`
  via `tf.matmul`.

  Args:
    tensor1: An input tensor of order > 1.
    tensor2: An input tensor of order > 1.
  Returns:
    tensor: The result of `tf.matmul(tensor1, tensor2)`.
  Raises:
    ValueError: If either input has order <= 1.
  """
  if any(operand.ndim <= 1 for operand in (tensor1, tensor2)):
    raise ValueError("inputs to `matmul` have to be a tensors of order > 1,")
  product = tf.matmul(tensor1, tensor2)
  return product
26 changes: 26 additions & 0 deletions tensornetwork/backends/tensorflow/tensorflow_backend_test.py
Expand Up @@ -481,3 +481,29 @@ def fun(x, A, y):
res3 = fun_jit(x, y=y, A=A)
np.testing.assert_allclose(res1, res2)
np.testing.assert_allclose(res1, res3)


def test_sum():
  """Check `TensorFlowBackend.sum` against `numpy.sum`, with and without keepdims."""
  np.random.seed(10)
  backend = tensorflow_backend.TensorFlowBackend()
  tensor = np.random.rand(2, 3, 4)
  a = backend.convert_to_tensor(tensor)
  for keepdims in (False, True):
    actual = backend.sum(a, axis=(1, 2), keepdims=keepdims)
    # The reference value is computed from the NumPy input `tensor`, not
    # from the backend tensor `a`, so it is independent of the backend.
    expected = np.sum(tensor, axis=(1, 2), keepdims=keepdims)
    np.testing.assert_allclose(expected, actual)


def test_matmul():
  """Check `TensorFlowBackend.matmul` on a batch of matrices against `numpy.matmul`."""
  np.random.seed(10)
  backend = tensorflow_backend.TensorFlowBackend()
  lhs = np.random.rand(10, 2, 3)
  rhs = np.random.rand(10, 3, 4)
  reference = np.matmul(lhs, rhs)
  result = backend.matmul(
      backend.convert_to_tensor(lhs), backend.convert_to_tensor(rhs))
  np.testing.assert_allclose(reference, result)

0 comments on commit 6f62db9

Please sign in to comment.