Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,29 +81,33 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array

def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix.
Estimate a covariance matrix (or a stack of covariance matrices).

Covariance indicates the level to which two variables vary together.
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
then the covariance matrix element :math:`C_{ij}` is the covariance of
If we examine *N*-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
each with *M* observations, then element :math:`C_{ij}` of the
:math:`N \times N` covariance matrix is the covariance of
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
of :math:`x_i`.

This provides a subset of the functionality of ``numpy.cov``.
With the exception of supporting batch input, this provides a subset of
the functionality of ``numpy.cov``.

Parameters
----------
m : array
A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
An array of shape ``(..., N, M)`` whose innermost two dimensions
contain *M* observations of *N* variables. That is,
each row of `m` represents a variable, and each column a single
observation of all those variables.
xp : array_namespace, optional
The standard-compatible namespace for `m`. Default: infer.

Returns
-------
array
The covariance matrix of the variables.
An array having shape ``(..., N, N)`` whose innermost two dimensions represent
the covariance matrix of the variables.

Examples
--------
Expand Down Expand Up @@ -142,6 +146,17 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:

>>> xpx.cov(y, xp=xp)
Array(2.14413333, dtype=array_api_strict.float64)

Input with more than two dimensions is treated as a stack of
two-dimensional inputs.

>>> stack = xp.stack((X, 2*X))
>>> xpx.cov(stack)
Array([[[ 11.71 , -4.286 ],
[ -4.286 , 2.14413333]],

[[ 46.84 , -17.144 ],
[-17.144 , 8.57653333]]], dtype=array_api_strict.float64)
"""

if xp is None:
Expand All @@ -153,7 +168,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
or is_torch_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
):
) and m.ndim <= 2:
return xp.cov(m)

return _funcs.cov(m, xp=xp)
Expand Down
10 changes: 5 additions & 5 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,20 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _helpers.mean(m, axis=1, xp=xp)
avg = _helpers.mean(m, axis=-1, keepdims=True, xp=xp)

m_shape = eager_shape(m)
fact = m_shape[1] - 1
fact = m_shape[-1] - 1

if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
fact = 0

m -= avg[:, None]
m_transpose = m.T
m -= avg
m_transpose = xp.matrix_transpose(m)
if xp.isdtype(m_transpose.dtype, "complex floating"):
m_transpose = xp.conj(m_transpose)
c = m @ m_transpose
c = xp.matmul(m, m_transpose)
c /= fact
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
return xp.squeeze(c, axis=axes)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def test_xp(self, xp: ModuleType):
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
)

def test_batch(self, xp: ModuleType):
    """Batched input must agree with ``np.cov`` applied to each 2-D slice."""
    # Fixed seed keeps the random fixture reproducible across runs.
    rng = np.random.default_rng(8847643423)
    lead_dims = (3, 4)
    n_var, n_obs = 3, 20
    samples = rng.random(lead_dims + (n_var, n_obs))

    actual = cov(xp.asarray(samples))

    # Reference: collapse the batch axes, run NumPy on every matrix,
    # then restore the leading dimensions.
    flat = np.reshape(samples, (-1, n_var, n_obs))
    per_slice = [np.cov(matrix) for matrix in flat]
    expected = np.reshape(np.stack(per_slice), lead_dims + (n_var, n_var))
    xp_assert_close(actual, xp.asarray(expected))


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
class TestOneHot:
Expand Down