Skip to content

Commit

Permalink
feat: Do not infer dtype when data is None in tensor.Box (#261)
Browse files Browse the repository at this point in the history
* feat: Do not infer dtype when `data` is `None` in `tensor.Box`

Tests have been added accordingly.

Fixes weird behavior where tensor boxes that should be equal are different.

* Implement changes suggested by @toumix
  • Loading branch information
boldar99 committed Jan 17, 2024
1 parent ca51877 commit 909e203
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
14 changes: 7 additions & 7 deletions discopy/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,6 @@ class Box(frobenius.Box, Diagram):
dom : The domain of the box, i.e. its input dimension.
cod : The codomain of the box, i.e. its output dimension.
data : The array inside the tensor box.
dtype : The datatype for the entries of the array.
Example
-------
Expand All @@ -541,16 +540,17 @@ def __setstate__(self, state):
state['data'] = state['_array']
del state["_array"]
super().__setstate__(state)
if self.dtype is None:
if self.dtype is None and self.data is not None:
self.data, self.dtype = self._get_data_dtype(self.data)
self.__class__ = self.__class__[self.dtype]

def __new__(cls, *args, **kwargs):
if not args and not kwargs or cls.dtype is not None:
def __new__(
cls, name=None, dom=None, cod=None, data=None, *args, **kwargs):
if cls.dtype is not None or data is None:
return object.__new__(cls)
data, dtype = cls._get_data_dtype(kwargs.get("data", []))
kwargs["data"] = data
return cls.__new__(cls[dtype], *args, **kwargs)
data, dtype = cls._get_data_dtype(data)
return cls.__new__(
cls[dtype], name, dom, cod, data, *args, **kwargs)

@staticmethod
def _get_data_dtype(data):
Expand Down
9 changes: 8 additions & 1 deletion test/semantics/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_Box():


def test_Spider():
assert repr(Spider(1, 2, Dim(3))) == "tensor.Spider[float64](1, 2, Dim(3))"
assert repr(Spider(1, 2, Dim(3))) == "tensor.Spider(1, 2, Dim(3))"
assert Spider(1, 2, Dim(2)).dagger() == Spider(2, 1, Dim(2))
with raises(ValueError):
Spider(1, 2, Dim(2, 3))
Expand Down Expand Up @@ -260,6 +260,13 @@ def test_Tensor_adjoint_eval():
assert tensor1 == tensor2


def test_Tensor_dtype_inference():
assert Box("F(A)", Dim(1), Dim(1), data=None).dtype is None
assert Box("X", Dim(1), Dim(1), data=[0]) == Box[np.int64]("X", Dim(1), Dim(1), data=[0])
assert Box("Y", Dim(1), Dim(1), data=[1.]) == Box[np.float64]("Y", Dim(1), Dim(1), data=[1.])
assert Box("Y", Dim(1), Dim(1), data=[1]) != Box("Y", Dim(1), Dim(1), data=[1.])


def test_non_numpy_eval():
with backend('pytorch'):
with raises(Exception):
Expand Down

0 comments on commit 909e203

Please sign in to comment.