Skip to content

Commit 9ac3e41

Browse files
adriagarpAdrián García Pitarchlucascolley
authored
ENH: cov delegation (#451)
Co-authored-by: Adrián García Pitarch <adrian.garcia-pitarch@lht.dlh.de> Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 0c9620e commit 9ac3e41

File tree

4 files changed

+109
-77
lines changed

4 files changed

+109
-77
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._delegation import (
44
atleast_nd,
5+
cov,
56
expand_dims,
67
isclose,
78
nan_to_num,
@@ -13,7 +14,6 @@
1314
from ._lib._funcs import (
1415
apply_where,
1516
broadcast_shapes,
16-
cov,
1717
create_diagonal,
1818
default_dtype,
1919
kron,

src/array_api_extra/_delegation.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,95 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad", "sinc"]
21+
__all__ = [
22+
"cov",
23+
"expand_dims",
24+
"isclose",
25+
"nan_to_num",
26+
"one_hot",
27+
"pad",
28+
"sinc",
29+
]
30+
31+
32+
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
33+
"""
34+
Estimate a covariance matrix.
35+
36+
Covariance indicates the level to which two variables vary together.
37+
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
38+
then the covariance matrix element :math:`C_{ij}` is the covariance of
39+
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
40+
of :math:`x_i`.
41+
42+
This provides a subset of the functionality of ``numpy.cov``.
43+
44+
Parameters
45+
----------
46+
m : array
47+
A 1-D or 2-D array containing multiple variables and observations.
48+
Each row of `m` represents a variable, and each column a single
49+
observation of all those variables.
50+
xp : array_namespace, optional
51+
The standard-compatible namespace for `m`. Default: infer.
52+
53+
Returns
54+
-------
55+
array
56+
The covariance matrix of the variables.
57+
58+
Examples
59+
--------
60+
>>> import array_api_strict as xp
61+
>>> import array_api_extra as xpx
62+
63+
Consider two variables, :math:`x_0` and :math:`x_1`, which
64+
correlate perfectly, but in opposite directions:
65+
66+
>>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T
67+
>>> x
68+
Array([[0, 1, 2],
69+
[2, 1, 0]], dtype=array_api_strict.int64)
70+
71+
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
72+
matrix shows this clearly:
73+
74+
>>> xpx.cov(x, xp=xp)
75+
Array([[ 1., -1.],
76+
[-1., 1.]], dtype=array_api_strict.float64)
77+
78+
Note that element :math:`C_{0,1}`, which shows the correlation between
79+
:math:`x_0` and :math:`x_1`, is negative.
80+
81+
Further, note how `x` and `y` are combined:
82+
83+
>>> x = xp.asarray([-2.1, -1, 4.3])
84+
>>> y = xp.asarray([3, 1.1, 0.12])
85+
>>> X = xp.stack((x, y), axis=0)
86+
>>> xpx.cov(X, xp=xp)
87+
Array([[11.71 , -4.286 ],
88+
[-4.286 , 2.14413333]], dtype=array_api_strict.float64)
89+
90+
>>> xpx.cov(x, xp=xp)
91+
Array(11.71, dtype=array_api_strict.float64)
92+
93+
>>> xpx.cov(y, xp=xp)
94+
Array(2.14413333, dtype=array_api_strict.float64)
95+
"""
96+
97+
if xp is None:
98+
xp = array_namespace(m)
99+
100+
if (
101+
is_numpy_namespace(xp)
102+
or is_cupy_namespace(xp)
103+
or is_torch_namespace(xp)
104+
or is_dask_namespace(xp)
105+
or is_jax_namespace(xp)
106+
):
107+
return xp.cov(m)
108+
109+
return _funcs.cov(m, xp=xp)
22110

23111

24112
def expand_dims(

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -248,73 +248,8 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
248248
return tuple(out)
249249

250250

251-
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
252-
"""
253-
Estimate a covariance matrix.
254-
255-
Covariance indicates the level to which two variables vary together.
256-
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
257-
then the covariance matrix element :math:`C_{ij}` is the covariance of
258-
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
259-
of :math:`x_i`.
260-
261-
This provides a subset of the functionality of ``numpy.cov``.
262-
263-
Parameters
264-
----------
265-
m : array
266-
A 1-D or 2-D array containing multiple variables and observations.
267-
Each row of `m` represents a variable, and each column a single
268-
observation of all those variables.
269-
xp : array_namespace, optional
270-
The standard-compatible namespace for `m`. Default: infer.
271-
272-
Returns
273-
-------
274-
array
275-
The covariance matrix of the variables.
276-
277-
Examples
278-
--------
279-
>>> import array_api_strict as xp
280-
>>> import array_api_extra as xpx
281-
282-
Consider two variables, :math:`x_0` and :math:`x_1`, which
283-
correlate perfectly, but in opposite directions:
284-
285-
>>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T
286-
>>> x
287-
Array([[0, 1, 2],
288-
[2, 1, 0]], dtype=array_api_strict.int64)
289-
290-
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
291-
matrix shows this clearly:
292-
293-
>>> xpx.cov(x, xp=xp)
294-
Array([[ 1., -1.],
295-
[-1., 1.]], dtype=array_api_strict.float64)
296-
297-
Note that element :math:`C_{0,1}`, which shows the correlation between
298-
:math:`x_0` and :math:`x_1`, is negative.
299-
300-
Further, note how `x` and `y` are combined:
301-
302-
>>> x = xp.asarray([-2.1, -1, 4.3])
303-
>>> y = xp.asarray([3, 1.1, 0.12])
304-
>>> X = xp.stack((x, y), axis=0)
305-
>>> xpx.cov(X, xp=xp)
306-
Array([[11.71 , -4.286 ],
307-
[-4.286 , 2.14413333]], dtype=array_api_strict.float64)
308-
309-
>>> xpx.cov(x, xp=xp)
310-
Array(11.71, dtype=array_api_strict.float64)
311-
312-
>>> xpx.cov(y, xp=xp)
313-
Array(2.14413333, dtype=array_api_strict.float64)
314-
"""
315-
if xp is None:
316-
xp = array_namespace(m)
317-
251+
def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
252+
"""See docstring in array_api_extra._delegation."""
318253
m = xp.asarray(m, copy=True)
319254
dtype = (
320255
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)

tests/test_funcs.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -419,46 +419,55 @@ def test_none(self, args: tuple[tuple[float | None, ...], ...]):
419419
class TestCov:
420420
def test_basic(self, xp: ModuleType):
421421
xp_assert_close(
422-
cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T),
422+
cov(xp.asarray([[0, 2], [1, 1], [2, 0]], dtype=xp.float64).T),
423423
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
424424
)
425425

426426
def test_complex(self, xp: ModuleType):
427-
actual = cov(xp.asarray([[1, 2, 3], [1j, 2j, 3j]]))
427+
actual = cov(xp.asarray([[1, 2, 3], [1j, 2j, 3j]], dtype=xp.complex128))
428428
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
429429
xp_assert_close(actual, expect)
430430

431+
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296")
431432
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877")
432433
def test_empty(self, xp: ModuleType):
433434
with warnings.catch_warnings(record=True):
434435
warnings.simplefilter("always", RuntimeWarning)
435-
xp_assert_equal(cov(xp.asarray([])), xp.asarray(xp.nan, dtype=xp.float64))
436+
warnings.simplefilter("always", UserWarning)
436437
xp_assert_equal(
437-
cov(xp.reshape(xp.asarray([]), (0, 2))),
438+
cov(xp.asarray([], dtype=xp.float64)),
439+
xp.asarray(xp.nan, dtype=xp.float64),
440+
)
441+
xp_assert_equal(
442+
cov(xp.reshape(xp.asarray([], dtype=xp.float64), (0, 2))),
438443
xp.reshape(xp.asarray([], dtype=xp.float64), (0, 0)),
439444
)
440445
xp_assert_equal(
441-
cov(xp.reshape(xp.asarray([]), (2, 0))),
446+
cov(xp.reshape(xp.asarray([], dtype=xp.float64), (2, 0))),
442447
xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]], dtype=xp.float64),
443448
)
444449

445450
def test_combination(self, xp: ModuleType):
446-
x = xp.asarray([-2.1, -1, 4.3])
447-
y = xp.asarray([3, 1.1, 0.12])
451+
x = xp.asarray([-2.1, -1, 4.3], dtype=xp.float64)
452+
y = xp.asarray([3, 1.1, 0.12], dtype=xp.float64)
448453
X = xp.stack((x, y), axis=0)
449454
desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]], dtype=xp.float64)
450455
xp_assert_close(cov(X), desired, rtol=1e-6)
451456
xp_assert_close(cov(x), xp.asarray(11.71, dtype=xp.float64))
452457
xp_assert_close(cov(y), xp.asarray(2.144133, dtype=xp.float64), rtol=1e-6)
453458

459+
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="array-api-extra#455")
454460
def test_device(self, xp: ModuleType, device: Device):
455461
x = xp.asarray([1, 2, 3], device=device)
456462
assert get_device(cov(x)) == device
457463

458464
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
459465
def test_xp(self, xp: ModuleType):
460466
xp_assert_close(
461-
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
467+
cov(
468+
xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]], dtype=xp.float64).T,
469+
xp=xp,
470+
),
462471
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
463472
)
464473

0 commit comments

Comments
 (0)