diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py
index e09f1f4a..289d21e4 100644
--- a/src/array_api_extra/_delegation.py
+++ b/src/array_api_extra/_delegation.py
@@ -81,21 +81,24 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
 
 def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     """
-    Estimate a covariance matrix.
+    Estimate a covariance matrix (or a stack of covariance matrices).
 
     Covariance indicates the level to which two variables vary together.
-    If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
-    then the covariance matrix element :math:`C_{ij}` is the covariance of
+    If we examine *N*-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
+    each with *M* observations, then element :math:`C_{ij}` of the
+    :math:`N \\times N` covariance matrix is the covariance of
     :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
     of :math:`x_i`.
 
-    This provides a subset of the functionality of ``numpy.cov``.
+    With the exception of supporting batch input, this provides a subset of
+    the functionality of ``numpy.cov``.
 
     Parameters
     ----------
     m : array
-        A 1-D or 2-D array containing multiple variables and observations.
-        Each row of `m` represents a variable, and each column a single
+        An array of shape ``(..., N, M)`` whose innermost two dimensions
+        contain *M* observations of *N* variables. That is,
+        each row of `m` represents a variable, and each column a single
         observation of all those variables.
     xp : array_namespace, optional
         The standard-compatible namespace for `m`. Default: infer.
@@ -103,7 +106,8 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     Returns
     -------
     array
-        The covariance matrix of the variables.
+        An array of shape ``(..., N, N)`` whose innermost two dimensions represent
+        the covariance matrix of the variables.
 
     Examples
     --------
@@ -142,6 +146,17 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
 
     >>> xpx.cov(y, xp=xp)
     Array(2.14413333, dtype=array_api_strict.float64)
+
+    Input with more than two dimensions is treated as a stack of
+    two-dimensional inputs.
+
+    >>> stack = xp.stack((X, 2*X))
+    >>> xpx.cov(stack)
+    Array([[[ 11.71      ,  -4.286     ],
+            [ -4.286     ,   2.14413333]],
+
+           [[ 46.84      , -17.144     ],
+            [-17.144     ,   8.57653333]]], dtype=array_api_strict.float64)
 
     """
     if xp is None:
@@ -153,7 +168,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
         or is_torch_namespace(xp)
         or is_dask_namespace(xp)
         or is_jax_namespace(xp)
-    ):
+    ) and m.ndim <= 2:
         return xp.cov(m)
 
     return _funcs.cov(m, xp=xp)
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index fb124a6e..49840c0f 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -258,20 +258,20 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:  # numpydoc ignore=PR01,RT01
     m = atleast_nd(m, ndim=2, xp=xp)
     m = xp.astype(m, dtype)
 
-    avg = _helpers.mean(m, axis=1, xp=xp)
+    avg = _helpers.mean(m, axis=-1, keepdims=True, xp=xp)
 
     m_shape = eager_shape(m)
-    fact = m_shape[1] - 1
+    fact = m_shape[-1] - 1
 
     if fact <= 0:
         warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
         fact = 0
 
-    m -= avg[:, None]
-    m_transpose = m.T
+    m -= avg
+    m_transpose = xp.matrix_transpose(m)
     if xp.isdtype(m_transpose.dtype, "complex floating"):
         m_transpose = xp.conj(m_transpose)
-    c = m @ m_transpose
+    c = xp.matmul(m, m_transpose)
     c /= fact
     axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
     return xp.squeeze(c, axis=axes)
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index a120e559..ff050468 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -476,6 +476,16 @@ def test_xp(self, xp: ModuleType):
             xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
         )
 
+    def test_batch(self, xp: ModuleType):
+        rng = np.random.default_rng(8847643423)
+        batch_shape = (3, 4)
+        n_var, n_obs = 3, 20
+        m = rng.random((*batch_shape, n_var, n_obs))
+        res = cov(xp.asarray(m))
+        ref_list = [np.cov(m_) for m_ in np.reshape(m, (-1, n_var, n_obs))]
+        ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
+        xp_assert_close(res, xp.asarray(ref))
+
 
 @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
 class TestOneHot:
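
Out-of-band sanity-check sketch (not part of the patch): the batched semantics exercised by the new test_batch can be reproduced with plain NumPy using the same algebra as the patched _funcs.cov. The seed and shapes below are arbitrary, and np.swapaxes stands in for the standard's matrix_transpose.

import numpy as np

rng = np.random.default_rng(0)
m = rng.random((3, 4, 3, 20))  # batch shape (3, 4), N=3 variables, M=20 observations

# Same algebra as the patched _funcs.cov, via NumPy broadcasting:
avg = np.mean(m, axis=-1, keepdims=True)             # per-variable mean over observations
centered = m - avg
fact = m.shape[-1] - 1                               # unbiased normalization, M - 1
c = centered @ np.swapaxes(centered, -1, -2) / fact  # shape (..., N, N)

# Reference: np.cov applied to every 2-D slice, as in the new test.
ref = np.stack([np.cov(s) for s in m.reshape(-1, 3, 20)]).reshape(3, 4, 3, 3)
np.testing.assert_allclose(c, ref)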