Skip to content

Commit

Permalink
fix: pytorch sparse ndarray constructor (#2233)
Browse files Browse the repository at this point in the history
* fix: add proto to pytorch constructor

* test: delete unused test sparse

* test: add ndarray genetic test

* test: add ndarray genetic test

* test: add ndarray genetic test

* test: add ndarray genetic test

* fix: add docstring

* test: add ndarray genetic test
  • Loading branch information
bwanglzu committed Mar 25, 2021
1 parent 1c9949c commit 2be26c6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 3 deletions.
16 changes: 13 additions & 3 deletions jina/types/ndarray/sparse/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.sparse import FloatTensor

from . import BaseSparseNdArray
from ....proto import jina_pb2

if False:
import numpy as np
Expand All @@ -18,16 +19,25 @@ class SparseNdArray(BaseSparseNdArray):
.. seealso::
https://pytorch.org/docs/stable/sparse.html
:param proto: protobuf instance, default is None.
:param transpose_indices: in torch, the input to LongTensor is NOT a list of index tuples.
:param args: positional arguments.
:param kwargs: positional key value arguments.
If you want to write your indices this way, you should transpose before passing them to the sparse constructor
.. note::
To comply with Tensorflow, `transpose_indices` is set to True by default
"""

def __init__(self, transpose_indices: bool = True, *args, **kwargs):
"""Set constructor method."""
super().__init__(*args, **kwargs)
def __init__(
self,
proto: 'jina_pb2.SparseNdArrayProto' = None,
transpose_indices: bool = True,
*args,
**kwargs
):
super().__init__(proto, *args, **kwargs)
self.transpose_indices = transpose_indices

def sparse_constructor(
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/types/ndarray/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
import numpy as np

from jina.proto import jina_pb2
from jina.types.ndarray.generic import NdArray
from jina.types.ndarray.dense.numpy import DenseNdArray
from jina.types.ndarray.sparse.tensorflow import SparseNdArray as TFSparseNdArray
from jina.types.ndarray.sparse.pytorch import SparseNdArray as PTSparseNdArray
from jina.types.ndarray.sparse.scipy import SparseNdArray as SPSparseNdArray


@pytest.fixture
def tf_sparse_tensor():
from tensorflow import SparseTensor

return SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])


@pytest.fixture
def pt_sparse_tensor():
import torch

i = [[0, 2], [1, 0], [1, 2]]
v = [3, 4, 5]
return torch.sparse_coo_tensor(list(zip(*i)), v, (2, 3))


@pytest.fixture
def sp_sparse_tensor():
from scipy.sparse import csr_matrix

row = np.array([0, 0, 1, 2, 2, 2])
col = np.array([0, 2, 2, 0, 1, 2])
data = np.array([1, 2, 3, 4, 5, 6])
return csr_matrix((data, (row, col)), shape=(3, 3))


@pytest.fixture
def np_dense_tensor(sp_sparse_tensor):
return sp_sparse_tensor.todense()


def tf_ndarray():
return NdArray(None, True, None, TFSparseNdArray)


def sp_ndarray():
return NdArray(None, True, None, SPSparseNdArray)


def pt_ndarray():
return NdArray(None, True, None, PTSparseNdArray)


def np_ndarray():
return NdArray(None, False, DenseNdArray, None)


@pytest.fixture(
params=[
tf_ndarray,
sp_ndarray,
pt_ndarray,
np_ndarray,
]
)
def NdArrayCls(request):
yield request.param


def test_null_proto(NdArrayCls):
assert NdArrayCls().null_proto() == jina_pb2.NdArrayProto()


def test_value_get_set(
NdArrayCls, tf_sparse_tensor, pt_sparse_tensor, sp_sparse_tensor, np_dense_tensor
):
ndarray = NdArrayCls()
assert ndarray.value is None
if isinstance(ndarray, TFSparseNdArray):
ndarray.value = tf_sparse_tensor
assert ndarray.value == tf_sparse_tensor
assert ndarray.is_sparse is True
elif isinstance(ndarray, SPSparseNdArray):
ndarray.value = sp_sparse_tensor
assert ndarray.value == sp_sparse_tensor
assert ndarray.is_sparse is True
elif isinstance(ndarray, PTSparseNdArray):
ndarray.value = pt_sparse_tensor
assert ndarray.value == pt_sparse_tensor
assert ndarray.is_sparse is True
elif isinstance(ndarray, DenseNdArray):
ndarray.value = np_dense_tensor
assert ndarray.value == np_dense_tensor
assert ndarray.is_sparse is False

0 comments on commit 2be26c6

Please sign in to comment.