Commit 15f9447

alewis committed Jun 3, 2020
2 parents c52b16b + 6ef56e3
Showing 20 changed files with 616 additions and 72 deletions.
4 changes: 4 additions & 0 deletions tensornetwork/backends/symmetric/symmetric_backend.py
@@ -90,6 +90,10 @@ def diag(self, tensor: Tensor) -> Tensor:
     return self.bs.diag(tensor)

   def convert_to_tensor(self, tensor: Tensor) -> Tensor:
+    if numpy.isscalar(tensor):
+      tensor = BlockSparseTensor(
+          data=tensor, charges=[], flows=[], order=[], check_consistency=False)
+
     if not isinstance(tensor, BlockSparseTensor):
       raise TypeError(
           "cannot convert tensor of type `{}` to `BlockSparseTensor`".format(
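The addition above lets the symmetric backend accept plain Python/NumPy scalars by wrapping them in a charge-free `BlockSparseTensor`. A minimal sketch of that wrapping, done directly with `BlockSparseTensor` (construction mirrors the test added further down; the import path is an assumption):

import numpy as np
from tensornetwork.block_sparse.blocksparsetensor import BlockSparseTensor

# Wrap a scalar the same way convert_to_tensor now does: no charges, no flows,
# no order -- an ndim == 0 block-sparse tensor holding a single data entry.
t = BlockSparseTensor(
    data=np.array(1.5), charges=[], flows=[], order=[], check_consistency=False)
print(t.ndim)    # expected: 0
print(t.item())  # expected: 1.5, via the item() method added in this commit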
40 changes: 26 additions & 14 deletions tensornetwork/block_sparse/blocksparsetensor.py
@@ -391,6 +391,26 @@ def __rmul__(self, number: np.number) -> "ChargeArray":
   def __truediv__(self, number: np.number) -> "ChargeArray":
     raise NotImplementedError("__truediv__ not implemented for ChargeArray")

+  def __repr__(self):
+    if len(self._charges) > 0:
+      charge_types = self._charges[0].names
+    else:
+      charge_types = 'no charge types (scalar)'
+    output = 'BlockSparseTensor\n shape: ' + repr(
+        self.shape
+    ) + '\n charge types: ' + charge_types + '\n dtype: ' + repr(
+        self.dtype.name) + '\n flat flows: ' + repr(
+            self.flat_flows) + '\n order: ' + repr(
+                self._order) + '\n data:' + repr(self.data)
+
+    return output
+
+  def item(self):
+    if self.ndim == 0:
+      if len(self.data) == 1:
+        return self.data[0]
+    raise ValueError("can only convert an array of size 1 to a Python scalar")
+

 class BlockSparseTensor(ChargeArray):
   """
@@ -499,7 +519,7 @@ def todense(self) -> np.ndarray:
     Map the sparse tensor to dense storage.
     """
-    if len(self.shape) == 0:
+    if self.ndim == 0:
       return self.data
     out = np.asarray(np.zeros(self.shape, dtype=self.dtype).flat)
     out[np.nonzero(
@@ -930,23 +950,15 @@ def tensordot(

   #checks finished

-  #special case inner product
+  #special case inner product (returns an ndim=0 tensor)
   if (len(axes1) == tensor1.ndim) and (len(axes2) == tensor2.ndim):
     t1 = tensor1.transpose(axes1).transpose_data()
     t2 = tensor2.transpose(axes2).transpose_data()
-    data = np.dot(t1.data, t2.data)
-    charge = tensor1._charges[0]
-    final_charge = charge.__new__(type(charge))
-
-    final_charge.__init__(
-        np.empty((charge.num_symmetries, 0), dtype=np.int16),
-        charge_labels=np.empty(0, dtype=np.int16),
-        charge_types=charge.charge_types)
     return BlockSparseTensor(
-        data=data,
-        charges=[final_charge],
-        flows=[False],
-        order=[[0]],
+        data=np.dot(t1.data, t2.data),
+        charges=[],
+        flows=[],
+        order=[],
         check_consistency=False)

   #in all other cases we perform a regular tensordot
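With the new special case, a full contraction via `tensordot` now returns an `ndim == 0` tensor with empty charges rather than a rank-1 tensor carrying an empty charge object. A hedged sketch of the resulting workflow (index construction follows the test file below; import paths are assumptions):

import numpy as np
from tensornetwork.block_sparse.charge import U1Charge
from tensornetwork.block_sparse.index import Index
from tensornetwork.block_sparse.blocksparsetensor import (BlockSparseTensor,
                                                          tensordot)

charges = [U1Charge.random(10, -2, 2), U1Charge.random(11, -2, 2)]
flows = [True, False]
a = BlockSparseTensor.random(
    [Index(charges[n], flows[n]) for n in range(2)], dtype=np.float64)
b = BlockSparseTensor.random(
    [Index(charges[n], not flows[n]) for n in range(2)], dtype=np.float64)
# Contract all axes: the inner-product special case now returns a scalar tensor.
res = tensordot(a, b, axes=([0, 1], [0, 1]))
print(res.ndim)    # expected: 0
print(res.item())  # the Python scalar, extracted with the new item() method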
20 changes: 20 additions & 0 deletions tensornetwork/block_sparse/blocksparsetensor_test.py
@@ -737,3 +737,23 @@ def test_flat_charges():
   a = a.transpose(order)
   for n, o in enumerate(a.flat_order):
     charge_equal(a.flat_charges[n], a._charges[o])
+
+
+def test_item():
+  t1 = BlockSparseTensor(
+      data=np.array(1.0),
+      charges=[],
+      flows=[],
+      order=[],
+      check_consistency=False)
+  assert t1.item() == 1
+  Ds = [10, 11, 12, 13]
+  charges = [U1Charge.random(Ds[n], -5, 5) for n in range(4)]
+  flows = [True, False, True, False]
+  inds = [Index(c, flows[n]) for n, c in enumerate(charges)]
+  t2 = BlockSparseTensor.random(inds, dtype=np.float64)
+  with pytest.raises(
+      ValueError,
+      match="can only convert an array of"
+      " size 1 to a Python scalar"):
+    t2.item()
8 changes: 6 additions & 2 deletions tensornetwork/block_sparse/charge.py
@@ -164,8 +164,8 @@ def degeneracies(self):
     return np.sum(exp1 == exp2, axis=0)

   def __repr__(self):
-    return str(
-        type(self)) + '\n' + 'charges: \n' + self.charges.__repr__() + '\n'
+    return 'BaseCharge object:' + '\n charge types: ' + self.names + '\n charges:' + str(
+        self.charges).replace('\n', '\n\t ') + '\n'

   def __iter__(self):
     return self.Iterator(self.unique_charges, self.charge_labels)
@@ -528,6 +528,10 @@ def isin(self, target_charges: Union[np.ndarray, "BaseCharge"]) -> np.ndarray:

     return np.isin(self.charge_labels, inds)

+  @property
+  def names(self):
+    return repr([ct.__new__(ct).__class__.__name__ for ct in self.charge_types])
+

 class U1Charge(BaseCharge):
   """Charge Class for the U1 symmetry group."""
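The new `names` property is what the `__repr__` methods above use to print the charge types. A small sketch (the printed string is a `repr` of the charge-type class names; the exact output is an assumption):

from tensornetwork.block_sparse.charge import U1Charge

c = U1Charge.random(10, -5, 5)
print(c.names)  # e.g. "['U1Charge']"
print(c)        # "BaseCharge object:" followed by charge types and charges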
5 changes: 4 additions & 1 deletion tensornetwork/block_sparse/linalg.py
@@ -602,7 +602,8 @@ def eye(column_index: Index,
 def trace(tensor: BlockSparseTensor,
           axes: Optional[Tuple[int, ...]] = None) -> BlockSparseTensor:
   """
-  Compute the trace of a matrix or tensor.
+  Compute the trace of a matrix or tensor. If input has `ndim>2`, take
+  the trace over the last two dimensions.
   Args:
     tensor: A `BlockSparseTensor`.
     axes: The axes over which the trace should be computed.
@@ -656,6 +657,8 @@ def trace(tensor: BlockSparseTensor,
       a1ar[np.logical_xor(mask_min, mask_max)] -= 1
       a0 = list(a0ar)
       a1 = list(a1ar)
+    if out.ndim == 0:
+      return out.item()
     return out  # pytype: disable=bad-return-type
   raise ValueError("trace can only be taken for tensors with ndim > 1")
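With the `out.ndim == 0` branch, tracing away every dimension now yields a Python scalar instead of a zero-dimensional `BlockSparseTensor`. A hedged sketch (import paths and index setup are assumptions):

import numpy as np
from tensornetwork.block_sparse.charge import U1Charge
from tensornetwork.block_sparse.index import Index
from tensornetwork.block_sparse.blocksparsetensor import BlockSparseTensor
from tensornetwork.block_sparse.linalg import trace

c = U1Charge.random(8, -2, 2)
m = BlockSparseTensor.random([Index(c, False), Index(c, True)],
                             dtype=np.float64)
tr = trace(m)  # trace over the (only) two dimensions, so out.ndim == 0
print(tr)      # expected: a plain scalar, not a BlockSparseTensor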
189 changes: 189 additions & 0 deletions tensornetwork/linalg/linalg.py
@@ -0,0 +1,189 @@

# Copyright 2019 The TensorNetwork Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to initialize Node using a NumPy-like syntax."""

import warnings
from typing import Optional, Sequence, Tuple, Any, Union, Type, Callable, List
from typing import Text
import numpy as np
from tensornetwork.backends import base_backend
from tensornetwork import backend_contextmanager
from tensornetwork import backends
from tensornetwork import network_components


Tensor = Any
BaseBackend = base_backend.BaseBackend


# INITIALIZATION
def initialize_node(fname: Text,
                    *fargs: Any,
                    name: Optional[Text] = None,
                    axis_names: Optional[List[Text]] = None,
                    backend: Optional[Union[Text, BaseBackend]] = None,
                    **fkwargs: Any
                   ) -> Tensor:
  """Return a Node wrapping data obtained by an initialization function
  implemented in a backend. The Node will have the same shape as the
  underlying array that function generates, with all Edges dangling.
  This function is not intended to be called directly, but doing so should
  be safe enough.
  Args:
    fname: Name of the method of backend to call (a string).
    *fargs: Positional arguments to the initialization method.
    name: Optional name of the Node.
    axis_names: Optional names of the Node's dangling edges.
    backend: The backend or its name.
    **fkwargs: Keyword arguments to the initialization method.
  Returns:
    node: A Node wrapping data generated by
      (the_backend).fname(*fargs, **fkwargs), with one dangling edge per
      axis of data.
  """
  if backend is None:
    backend_obj = backend_contextmanager.get_default_backend()
  else:
    backend_obj = backends.backend_factory.get_backend(backend)
  func = getattr(backend_obj, fname)
  data = func(*fargs, **fkwargs)
  node = network_components.Node(data, name=name, axis_names=axis_names,
                                 backend=backend)
  return node


def eye(N: int,
        dtype: Optional[Type[np.number]] = None,
        M: Optional[int] = None,
        name: Optional[Text] = None,
        axis_names: Optional[List[Text]] = None,
        backend: Optional[Union[Text, BaseBackend]] = None) -> Tensor:
  """Return a Node representing a 2D array with ones on the diagonal and
  zeros elsewhere. The Node has two dangling Edges.
  Args:
    N (int): The first dimension of the returned matrix.
    dtype, optional: dtype of array (default np.float64).
    M (int, optional): The second dimension of the returned matrix.
    name (text, optional): Name of the Node.
    axis_names (optional): List of names of the edges.
    backend (optional): The backend or its name.
  Returns:
    I : Node of shape (N, M)
      Represents an array of all zeros except for the main diagonal of all
      ones.
  """
  the_node = initialize_node("eye", N,
                             name=name, axis_names=axis_names, backend=backend,
                             dtype=dtype, M=M)
  return the_node


def zeros(shape: Sequence[int],
          dtype: Optional[Type[np.number]] = None,
          name: Optional[Text] = None,
          axis_names: Optional[List[Text]] = None,
          backend: Optional[Union[Text, BaseBackend]] = None) -> Tensor:
  """Return a Node of shape `shape` of all zeros.
  The Node has one dangling Edge per dimension.
  Args:
    shape : Shape of the array.
    dtype, optional: dtype of array (default np.float64).
    name (text, optional): Name of the Node.
    axis_names (optional): List of names of the edges.
    backend (optional): The backend or its name.
  Returns:
    the_node : Node of shape `shape`. Represents an array of all zeros.
  """
  the_node = initialize_node("zeros", shape,
                             name=name, axis_names=axis_names, backend=backend,
                             dtype=dtype)
  return the_node


def ones(shape: Sequence[int],
         dtype: Optional[Type[np.number]] = None,
         name: Optional[Text] = None,
         axis_names: Optional[List[Text]] = None,
         backend: Optional[Union[Text, BaseBackend]] = None) -> Tensor:
  """Return a Node of shape `shape` of all ones.
  The Node has one dangling Edge per dimension.
  Args:
    shape : Shape of the array.
    dtype, optional: dtype of array (default np.float64).
    name (text, optional): Name of the Node.
    axis_names (optional): List of names of the edges.
    backend (optional): The backend or its name.
  Returns:
    the_node : Node of shape `shape`
      Represents an array of all ones.
  """
  the_node = initialize_node("ones", shape,
                             name=name, axis_names=axis_names, backend=backend,
                             dtype=dtype)
  return the_node


def randn(shape: Sequence[int],
          dtype: Optional[Type[np.number]] = None,
          seed: Optional[int] = None,
          name: Optional[Text] = None,
          axis_names: Optional[List[Text]] = None,
          backend: Optional[Union[Text, BaseBackend]] = None) -> Tensor:
  """Return a Node of shape `shape` of Gaussian random floats.
  The Node has one dangling Edge per dimension.
  Args:
    shape : Shape of the array.
    dtype, optional: dtype of array (default np.float64).
    seed, optional: Seed for the RNG.
    name (text, optional): Name of the Node.
    axis_names (optional): List of names of the edges.
    backend (optional): The backend or its name.
  Returns:
    the_node : Node of shape `shape` filled with Gaussian random data.
  """
  the_node = initialize_node("randn", shape,
                             name=name, axis_names=axis_names, backend=backend,
                             seed=seed, dtype=dtype)
  return the_node


def random_uniform(shape: Sequence[int],
                   dtype: Optional[Type[np.number]] = None,
                   seed: Optional[int] = None,
                   boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
                   name: Optional[Text] = None,
                   axis_names: Optional[List[Text]] = None,
                   backend:
                   Optional[Union[Text, BaseBackend]] = None) -> Tensor:
  """Return a Node of shape `shape` of uniform random floats.
  The Node has one dangling Edge per dimension.
  Args:
    shape : Shape of the array.
    dtype, optional: dtype of array (default np.float64).
    seed, optional: Seed for the RNG.
    boundaries : Values lie in [boundaries[0], boundaries[1]).
    name (text, optional): Name of the Node.
    axis_names (optional): List of names of the edges.
    backend (optional): The backend or its name.
  Returns:
    the_node : Node of shape `shape` filled with uniform random data.
  """
  the_node = initialize_node("random_uniform", shape,
                             name=name, axis_names=axis_names, backend=backend,
                             seed=seed, boundaries=boundaries, dtype=dtype)
  return the_node
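The new module routes every initializer through `initialize_node`, which looks up the named method on the chosen backend and wraps the result in a `Node` with all edges dangling. A hedged usage sketch (whether these functions are re-exported at the package top level is not shown in this diff, so the module is imported by its path; the `"numpy"` backend name is an assumption):

import numpy as np
from tensornetwork.linalg import linalg as tn_linalg

ident = tn_linalg.eye(3, M=4, dtype=np.float64, backend="numpy")
zero_node = tn_linalg.zeros((2, 3), dtype=np.float64, backend="numpy")
ones_node = tn_linalg.ones((2, 3), dtype=np.float64, backend="numpy")
gauss = tn_linalg.randn((2, 3), dtype=np.float64, seed=10, backend="numpy")
unif = tn_linalg.random_uniform((2, 3), dtype=np.float64, seed=10,
                                boundaries=(-1.0, 1.0), backend="numpy")

# Each call returns a Node wrapping the backend-generated array, with one
# dangling edge per axis.
print(ident.shape)       # expected: (3, 4)
print(len(gauss.edges))  # expected: 2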
