Skip to content

Commit

Permalink
Use testing.parametrize
Browse files Browse the repository at this point in the history
  • Loading branch information
rezoo committed Apr 17, 2017
1 parent a460de3 commit 45f1cfa
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions tests/cupy_tests/linalg_tests/test_decomposition.py
Expand Up @@ -61,6 +61,9 @@ def test_mode(self):
self.check_mode(numpy.random.randn(5, 4), mode=self.mode)


@testing.parameterize(*testing.product({
'full_matrices': [True, False],
}))
@unittest.skipUnless(
cuda.cusolver_enabled, 'Only cusolver in CUDA 8.0 is supported')
@testing.gpu
Expand All @@ -69,39 +72,30 @@ class TestSVD(unittest.TestCase):
_multiprocess_can_split_ = True

@testing.for_float_dtypes(no_float16=True)
def check_usv(self, array, full_matrices, dtype):
def check_usv(self, array, dtype):
a_cpu = numpy.asarray(array, dtype=dtype)
a_gpu = cupy.asarray(array, dtype=dtype)
result_cpu = numpy.linalg.svd(a_cpu, full_matrices=full_matrices)
result_gpu = cupy.linalg.svd(a_gpu, full_matrices=full_matrices)
result_cpu = numpy.linalg.svd(a_cpu, full_matrices=self.full_matrices)
result_gpu = cupy.linalg.svd(a_gpu, full_matrices=self.full_matrices)
for b_cpu, b_gpu in zip(result_cpu, result_gpu):
# Use abs to support an inverse vector
cupy.testing.assert_allclose(
numpy.abs(b_cpu), cupy.abs(b_gpu), atol=1e-4)

@testing.for_float_dtypes(no_float16=True)
@testing.numpy_cupy_allclose(atol=1e-4)
def check_singular(self, array, xp, dtype, full_matrices):
def check_singular(self, array, xp, dtype):
a = xp.asarray(array, dtype=dtype)
result = xp.linalg.svd(
a, full_matrices=full_matrices, compute_uv=False)
a, full_matrices=self.full_matrices, compute_uv=False)
return result

def test_svd_full_matrices(self):
self.check_usv(numpy.random.randn(2, 3), full_matrices=True)
self.check_usv(numpy.random.randn(2, 2), full_matrices=True)
self.check_usv(numpy.random.randn(3, 2), full_matrices=True)

def test_svd_no_full_matrices(self):
self.check_usv(numpy.random.randn(2, 3), full_matrices=False)
self.check_usv(numpy.random.randn(2, 2), full_matrices=False)
self.check_usv(numpy.random.randn(3, 2), full_matrices=False)
def test_svd(self):
self.check_usv(numpy.random.randn(2, 3))
self.check_usv(numpy.random.randn(2, 2))
self.check_usv(numpy.random.randn(3, 2))

def test_svd_no_uv(self):
self.check_singular(numpy.random.randn(2, 3), full_matrices=True)
self.check_singular(numpy.random.randn(2, 2), full_matrices=True)
self.check_singular(numpy.random.randn(3, 2), full_matrices=True)

self.check_singular(numpy.random.randn(2, 3), full_matrices=False)
self.check_singular(numpy.random.randn(2, 2), full_matrices=False)
self.check_singular(numpy.random.randn(3, 2), full_matrices=False)
self.check_singular(numpy.random.randn(2, 3))
self.check_singular(numpy.random.randn(2, 2))
self.check_singular(numpy.random.randn(3, 2))

0 comments on commit 45f1cfa

Please sign in to comment.