| 
1 | 1 | from __future__ import annotations  | 
2 | 2 | 
 
  | 
 | 3 | +import warnings  | 
3 | 4 | from typing import TYPE_CHECKING  | 
4 | 5 | 
 
  | 
5 | 6 | if TYPE_CHECKING:  | 
6 | 7 |     from ._typing import Array, ModuleType  | 
7 | 8 | 
 
  | 
8 |  | -__all__ = ["atleast_nd", "expand_dims", "kron"]  | 
 | 9 | +__all__ = ["atleast_nd", "cov", "expand_dims", "kron"]  | 
9 | 10 | 
 
  | 
10 | 11 | 
 
  | 
11 | 12 | def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:  | 
@@ -48,6 +49,117 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:  | 
48 | 49 |     return x  | 
49 | 50 | 
 
  | 
50 | 51 | 
 
  | 
 | 52 | +def cov(m: Array, /, *, xp: ModuleType) -> Array:  | 
 | 53 | +    """  | 
 | 54 | +    Estimate a covariance matrix.  | 
 | 55 | +
  | 
 | 56 | +    Covariance indicates the level to which two variables vary together.  | 
 | 57 | +    If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,  | 
 | 58 | +    then the covariance matrix element :math:`C_{ij}` is the covariance of  | 
 | 59 | +    :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance  | 
 | 60 | +    of :math:`x_i`.  | 
 | 61 | +
  | 
 | 62 | +    This provides a subset of the functionality of ``numpy.cov``.  | 
 | 63 | +
  | 
 | 64 | +    Parameters  | 
 | 65 | +    ----------  | 
 | 66 | +    m : array  | 
 | 67 | +        A 1-D or 2-D array containing multiple variables and observations.  | 
 | 68 | +        Each row of `m` represents a variable, and each column a single  | 
 | 69 | +        observation of all those variables.  | 
 | 70 | +    xp : array_namespace  | 
 | 71 | +        The standard-compatible namespace for `m`.  | 
 | 72 | +
  | 
 | 73 | +    Returns  | 
 | 74 | +    -------  | 
 | 75 | +    res : array  | 
 | 76 | +        The covariance matrix of the variables.  | 
 | 77 | +
  | 
 | 78 | +    Examples  | 
 | 79 | +    --------  | 
 | 80 | +    >>> import array_api_strict as xp  | 
 | 81 | +    >>> import array_api_extra as xpx  | 
 | 82 | +
  | 
 | 83 | +    Consider two variables, :math:`x_0` and :math:`x_1`, which  | 
 | 84 | +    correlate perfectly, but in opposite directions:  | 
 | 85 | +
  | 
 | 86 | +    >>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T  | 
 | 87 | +    >>> x  | 
 | 88 | +    Array([[0, 1, 2],  | 
 | 89 | +           [2, 1, 0]], dtype=array_api_strict.int64)  | 
 | 90 | +
  | 
 | 91 | +    Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance  | 
 | 92 | +    matrix shows this clearly:  | 
 | 93 | +
  | 
 | 94 | +    >>> xpx.cov(x, xp=xp)  | 
 | 95 | +    Array([[ 1., -1.],  | 
 | 96 | +           [-1.,  1.]], dtype=array_api_strict.float64)  | 
 | 97 | +
  | 
 | 98 | +
  | 
 | 99 | +    Note that element :math:`C_{0,1}`, which shows the correlation between  | 
 | 100 | +    :math:`x_0` and :math:`x_1`, is negative.  | 
 | 101 | +
  | 
 | 102 | +    Further, note how `x` and `y` are combined:  | 
 | 103 | +
  | 
 | 104 | +    >>> x = xp.asarray([-2.1, -1,  4.3])  | 
 | 105 | +    >>> y = xp.asarray([3,  1.1,  0.12])  | 
 | 106 | +    >>> X = xp.stack((x, y), axis=0)  | 
 | 107 | +    >>> xpx.cov(X, xp=xp)  | 
 | 108 | +    Array([[11.71      , -4.286     ],  | 
 | 109 | +           [-4.286     ,  2.14413333]], dtype=array_api_strict.float64)  | 
 | 110 | +
  | 
 | 111 | +    >>> xpx.cov(x, xp=xp)  | 
 | 112 | +    Array(11.71, dtype=array_api_strict.float64)  | 
 | 113 | +
  | 
 | 114 | +    >>> xpx.cov(y, xp=xp)  | 
 | 115 | +    Array(2.14413333, dtype=array_api_strict.float64)  | 
 | 116 | +
  | 
 | 117 | +    """  | 
 | 118 | +    m = xp.asarray(m, copy=True)  | 
 | 119 | +    dtype = (  | 
 | 120 | +        xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)  | 
 | 121 | +    )  | 
 | 122 | + | 
 | 123 | +    m = atleast_nd(m, ndim=2, xp=xp)  | 
 | 124 | +    m = xp.astype(m, dtype)  | 
 | 125 | + | 
 | 126 | +    avg = _mean(m, axis=1, xp=xp)  | 
 | 127 | +    fact = m.shape[1] - 1  | 
 | 128 | + | 
 | 129 | +    if fact <= 0:  | 
 | 130 | +        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)  | 
 | 131 | +        fact = 0.0  | 
 | 132 | + | 
 | 133 | +    m -= avg[:, None]  | 
 | 134 | +    m_transpose = m.T  | 
 | 135 | +    if xp.isdtype(m_transpose.dtype, "complex floating"):  | 
 | 136 | +        m_transpose = xp.conj(m_transpose)  | 
 | 137 | +    c = m @ m_transpose  | 
 | 138 | +    c /= fact  | 
 | 139 | +    axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)  | 
 | 140 | +    return xp.squeeze(c, axis=axes)  | 
 | 141 | + | 
 | 142 | + | 
 | 143 | +def _mean(  | 
 | 144 | +    x: Array,  | 
 | 145 | +    /,  | 
 | 146 | +    *,  | 
 | 147 | +    axis: int | tuple[int, ...] | None = None,  | 
 | 148 | +    keepdims: bool = False,  | 
 | 149 | +    xp: ModuleType,  | 
 | 150 | +) -> Array:  | 
 | 151 | +    """  | 
 | 152 | +    Complex mean, https://github.com/data-apis/array-api/issues/846.  | 
 | 153 | +    """  | 
 | 154 | +    if xp.isdtype(x.dtype, "complex floating"):  | 
 | 155 | +        x_real = xp.real(x)  | 
 | 156 | +        x_imag = xp.imag(x)  | 
 | 157 | +        mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)  | 
 | 158 | +        mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)  | 
 | 159 | +        return mean_real + (mean_imag * xp.asarray(1j))  | 
 | 160 | +    return xp.mean(x, axis=axis, keepdims=keepdims)  | 
 | 161 | + | 
 | 162 | + | 
51 | 163 | def expand_dims(  | 
52 | 164 |     a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType  | 
53 | 165 | ) -> Array:  | 
 | 
0 commit comments