Skip to content

Commit

Permalink
MAINT: random: Update to disallowing complex inputs to multivariate_n…
Browse files Browse the repository at this point in the history
…ormal.

* Disallow both mean and cov from being complex.
* Raise a TypeError instead of a NotImplementedError if mean or cov is
  complex.
* Expand and fix the unit test.
  • Loading branch information
WarrenWeckesser committed Jun 16, 2022
1 parent c8b5124 commit 979d288
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
5 changes: 3 additions & 2 deletions numpy/random/_generator.pyx
Expand Up @@ -3661,8 +3661,9 @@ cdef class Generator:
mean = np.array(mean)
cov = np.array(cov)

if np.issubdtype(cov.dtype, np.complexfloating):
raise NotImplementedError("Complex gaussians are not supported.")
if (np.issubdtype(mean.dtype, np.complexfloating) or
np.issubdtype(cov.dtype, np.complexfloating)):
raise TypeError("mean and cov must not be complex")

if size is None:
shape = []
Expand Down
6 changes: 5 additions & 1 deletion numpy/random/tests/test_generator_mt19937.py
Expand Up @@ -1453,7 +1453,11 @@ def test_multivariate_normal(self, method):
assert_raises(ValueError, random.multivariate_normal,
mu, np.eye(3))

assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]])
@pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])])
def test_multivariate_normal_disallow_complex(self, mean, cov):
random = Generator(MT19937(self.seed))
with pytest.raises(TypeError, match="must not be complex"):
random.multivariate_normal(mean, cov)

@pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
def test_multivariate_normal_basic_stats(self, method):
Expand Down

0 comments on commit 979d288

Please sign in to comment.