diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index c346c494300b..0019c4bcd63f 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3661,8 +3661,9 @@ cdef class Generator: mean = np.array(mean) cov = np.array(cov) - if np.issubdtype(cov.dtype, np.complexfloating): - raise NotImplementedError("Complex gaussians are not supported.") + if (np.issubdtype(mean.dtype, np.complexfloating) or + np.issubdtype(cov.dtype, np.complexfloating)): + raise TypeError("mean and cov must not be complex") if size is None: shape = [] diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 925ac9e2ba71..fa55ac0ee96a 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1453,7 +1453,11 @@ def test_multivariate_normal(self, method): assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) - assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]]) + @pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])]) + def test_multivariate_normal_disallow_complex(self, mean, cov): + random = Generator(MT19937(self.seed)) + with pytest.raises(TypeError, match="must not be complex"): + random.multivariate_normal(mean, cov) @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) def test_multivariate_normal_basic_stats(self, method):