diff --git a/cupy/sparse/coo.py b/cupy/sparse/coo.py index 92ca5f0acf2..5b25fe46dfb 100644 --- a/cupy/sparse/coo.py +++ b/cupy/sparse/coo.py @@ -5,10 +5,12 @@ except ImportError: _scipy_available = False +import cupy from cupy import cusparse from cupy.sparse import base from cupy.sparse import csr from cupy.sparse import data as sparse_data +from cupy.sparse import util class coo_matrix(sparse_data._data_matrix): @@ -17,6 +19,10 @@ class coo_matrix(sparse_data._data_matrix): Now it has only one initializer format below: + ``coo_matrix((M, N), [dtype])`` + It constructs an empty matrix whose shape is ``(M, N)``. Default dtype + is float64. + ``coo_matrix((data, (row, col))`` All ``data``, ``row`` and ``col`` are one-dimenaional :class:`cupy.ndarray`. @@ -39,7 +45,17 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False): raise ValueError( 'Only two-dimensional sparse arrays are supported.') - if isinstance(arg1, tuple) and len(arg1) == 2: + if util.isshape(arg1): + m, n = arg1 + m, n = int(m), int(n) + data = cupy.zeros(0, dtype if dtype else 'd') + row = cupy.zeros(0, dtype='i') + col = cupy.zeros(0, dtype='i') + # shape and copy argument is ignored + shape = (m, n) + copy = False + + elif isinstance(arg1, tuple) and len(arg1) == 2: try: data, (row, col) = arg1 except (TypeError, ValueError): diff --git a/cupy/sparse/util.py b/cupy/sparse/util.py index 4269cb1ac8d..4a4bbe3c373 100644 --- a/cupy/sparse/util.py +++ b/cupy/sparse/util.py @@ -16,4 +16,4 @@ def isshape(x): if not isinstance(x, tuple) or len(x) != 2: return False m, n = x - return int(m) == m and int(n) == n + return isintlike(m) and isintlike(n) diff --git a/tests/cupy_tests/sparse_tests/test_coo.py b/tests/cupy_tests/sparse_tests/test_coo.py index 7e53488b5f5..f6a27d8afaa 100644 --- a/tests/cupy_tests/sparse_tests/test_coo.py +++ b/tests/cupy_tests/sparse_tests/test_coo.py @@ -39,6 +39,10 @@ def _make_empty(xp, sp, dtype): return sp.coo_matrix((data, (row, col)), shape=(3, 4)) +def _make_shape(xp, sp, dtype): + return sp.coo_matrix((3, 4)) + + @testing.parameterize(*testing.product({ 'dtype': [numpy.float32, numpy.float64], })) @@ -247,7 +251,7 @@ def test_unsupported_dtype(self): @testing.parameterize(*testing.product({ - 'make_method': ['_make', '_make_unordered', '_make_empty'], + 'make_method': ['_make', '_make_unordered', '_make_empty', '_make_shape'], 'dtype': [numpy.float32, numpy.float64], })) @unittest.skipUnless(scipy_available, 'requires scipy') @@ -257,6 +261,11 @@ class TestCooMatrixScipyComparison(unittest.TestCase): def make(self): return globals()[self.make_method] + @testing.numpy_cupy_equal(sp_name='sp') + def test_dtype(self, xp, sp): + m = self.make(xp, sp, self.dtype) + return m.dtype + @testing.numpy_cupy_equal(sp_name='sp') def test_nnz(self, xp, sp): m = self.make(xp, sp, self.dtype)