Skip to content

Commit

Permalink
fix 1065: Verify input object type and layout + Supporting tests (#1066)
Browse files Browse the repository at this point in the history
* fix 1065: Verify input object type and layout + Supporting tests

* Updated documentation

* Included example with a non-sparse input argument

---------

Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
  • Loading branch information
Mystic-Slice and ClaudiaComito committed Jan 30, 2023
1 parent 6e5ab15 commit 8597417
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
27 changes: 23 additions & 4 deletions heat/sparse/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from scipy.sparse import csr_matrix as scipy_csr_matrix

from typing import Optional, Type, Union
from typing import Optional, Type, Iterable
import warnings

from ..core import devices
Expand All @@ -21,7 +21,7 @@


def sparse_csr_matrix(
obj: Union[torch.Tensor, scipy_csr_matrix],
obj: Iterable,
dtype: Optional[Type[datatype]] = None,
split: Optional[int] = None,
is_split: Optional[int] = None,
Expand All @@ -33,8 +33,9 @@ def sparse_csr_matrix(
Parameters
----------
obj : :class:`torch.Tensor` (layout ==> torch.sparse_csr) or :class:`scipy.sparse.csr_matrix`
Sparse tensor that needs to be distributed
obj : array_like
A tensor or array, any object exposing the array interface, an object whose ``__array__`` method returns an
array, or any (nested) sequence. Sparse tensor that needs to be distributed.
dtype : datatype, optional
The desired data-type for the sparse matrix. If not given, then the type will be determined as the minimum type required
to hold the objects in the sequence. This argument can only be used to ‘upcast’ the array. For downcasting, use
Expand Down Expand Up @@ -87,6 +88,10 @@ def sparse_csr_matrix(
>>> heat_sparse_csr = ht.sparse.sparse_csr_matrix(local_torch_sparse_csr, is_split=0)
>>> heat_sparse_csr
(indptr: tensor([0, 2, 3, 6]), indices: tensor([0, 2, 2, 0, 1, 2]), data: tensor([1., 2., 3., 4., 5., 6.]), dtype=ht.float32, device=cpu:0, split=0)
Create a :class:`~heat.sparse.DCSR_matrix` from List
>>> ht.sparse.sparse_csr_matrix([[0, 0, 1], [1, 0, 2], [0, 0, 3]])
(indptr: tensor([0, 1, 3, 4]), indices: tensor([2, 0, 2, 2]), data: tensor([1, 1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None)
"""
# version check
if int(torch.__version__.split(".")[1]) < 10:
Expand All @@ -110,6 +115,20 @@ def sparse_csr_matrix(
size=obj.shape,
)

if not isinstance(obj, torch.Tensor):
try:
obj = torch.tensor(
obj,
device=device.torch_device
if device is not None
else devices.get_device().torch_device,
)
except RuntimeError:
raise TypeError(f"Invalid data of type {type(obj)}")

if obj.layout != torch.sparse_csr:
obj = obj.to_sparse_csr()

# infer dtype from obj if not explicitly given
if dtype is None:
dtype = types.canonical_heat_type(obj.dtype)
Expand Down
52 changes: 52 additions & 0 deletions heat/sparse/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def setUpClass(self):
[4, 0, 0, 5, 0]
[0, 0, 0, 0, 6]]
"""
self.matrix_list = [
[0, 0, 1, 0, 2],
[0, 0, 0, 0, 0],
[0, 3, 0, 0, 0],
[4, 0, 0, 5, 0],
[0, 0, 0, 0, 6],
]
self.ref_indptr = torch.tensor(
[0, 2, 2, 3, 5, 6], dtype=torch.int, device=self.device.torch_device
)
Expand Down Expand Up @@ -459,6 +466,51 @@ def test_sparse_csr_matrix(self):
).all()
)

"""
Input: torch.Tensor
"""
torch_tensor = torch.tensor(
self.matrix_list, dtype=torch.float, device=self.device.torch_device
)
heat_sparse_csr = ht.sparse.sparse_csr_matrix(torch_tensor)

self.assertIsInstance(heat_sparse_csr, ht.sparse.DCSR_matrix)
self.assertEqual(heat_sparse_csr.dtype, ht.float32)
self.assertEqual(heat_sparse_csr.indptr.dtype, torch.int64)
self.assertEqual(heat_sparse_csr.indices.dtype, torch.int64)
self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape)
self.assertEqual(heat_sparse_csr.lshape, self.ref_torch_sparse_csr.shape)
self.assertEqual(heat_sparse_csr.split, None)
self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all())
self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all())
self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all())
self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all())
self.assertTrue((heat_sparse_csr.data == self.ref_data).all())
self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all())

"""
Input: List[int]
"""
heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.matrix_list)

self.assertIsInstance(heat_sparse_csr, ht.sparse.DCSR_matrix)
self.assertEqual(heat_sparse_csr.dtype, ht.int64)
self.assertEqual(heat_sparse_csr.indptr.dtype, torch.int64)
self.assertEqual(heat_sparse_csr.indices.dtype, torch.int64)
self.assertEqual(heat_sparse_csr.shape, self.ref_torch_sparse_csr.shape)
self.assertEqual(heat_sparse_csr.lshape, self.ref_torch_sparse_csr.shape)
self.assertEqual(heat_sparse_csr.split, None)
self.assertTrue((heat_sparse_csr.indptr == self.ref_indptr).all())
self.assertTrue((heat_sparse_csr.lindptr == self.ref_indptr).all())
self.assertTrue((heat_sparse_csr.indices == self.ref_indices).all())
self.assertTrue((heat_sparse_csr.lindices == self.ref_indices).all())
self.assertTrue((heat_sparse_csr.data == self.ref_data).all())
self.assertTrue((heat_sparse_csr.ldata == self.ref_data).all())

with self.assertRaises(TypeError):
# Passing an object which cant be converted into a torch.Tensor
heat_sparse_csr = ht.sparse.sparse_csr_matrix(self)

# Errors (torch.Tensor)
with self.assertRaises(ValueError):
heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr, split=1)
Expand Down

0 comments on commit 8597417

Please sign in to comment.