From eaf1e5260d41058ee261484e5b3f00f816964374 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 8 Mar 2023 15:15:12 -0700 Subject: [PATCH] Rename get_namespace to array_namespace get_namespace is maintained as a backwards compatible alias. Fixes #19. --- CHANGELOG.md | 14 +++++++--- README.md | 10 +++---- array_api_compat/common/_aliases.py | 4 +-- array_api_compat/common/_helpers.py | 10 ++++--- tests/test_array_namespace.py | 41 +++++++++++++++++++++++++++++ tests/test_get_namespace.py | 37 -------------------------- 6 files changed, 65 insertions(+), 51 deletions(-) create mode 100644 tests/test_array_namespace.py delete mode 100644 tests/test_get_namespace.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d971dfc7..b82a9845 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # 1.1.1 (2023-03-08) +## Major Changes + +- Rename `get_namespace()` to `array_namespace()` (`get_namespace()` is + maintained as a backwards compatible alias). + ## Minor Changes - The minimum supported NumPy version is now 1.21. Fixed a few issues with @@ -8,11 +13,14 @@ - Add `api_version` to `get_namespace()`. -- `get_namespace()` now works correctly with `torch` tensors. +- `array_namespace()` (*née* `get_namespace()`) now works correctly with + `torch` tensors. -- `get_namespace()` now works correctly with `numpy.array_api` arrays. +- `array_namespace()` (*née* `get_namespace()`) now works correctly with + `numpy.array_api` arrays. -- `get_namespace()` now raises `TypeError` instead of `ValueError`. +- `array_namespace()` (*née* `get_namespace()`) now raises `TypeError` instead + of `ValueError`. - Fix the `torch.std` wrapper. diff --git a/README.md b/README.md index 6f839866..ad16b348 100644 --- a/README.md +++ b/README.md @@ -21,11 +21,11 @@ later this year. ## Usage The typical usage of this library will be to get the corresponding array API -compliant namespace from the input arrays using `get_namespace()`, like +compliant namespace from the input arrays using `array_namespace()`, like ```py def your_function(x, y): - xp = array_api_compat.get_namespace(x, y) + xp = array_api_compat.array_namespace(x, y) # Now use xp as the array library namespace return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) ``` @@ -88,7 +88,7 @@ part of the specification but which are useful for using the array API: - `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array object. -- `get_namespace(*xs)`: Get the corresponding array API namespace for the +- `array_namespace(*xs)`: Get the corresponding array API namespace for the arrays `xs`. For example, if the arrays are NumPy arrays, the returned namespace will be `array_api_compat.numpy`. Note that this function will also work for namespaces that aren't supported by this compat library but @@ -133,7 +133,7 @@ specification: don't want to monkeypatch or wrap it. The helper functions `device()` and `to_device()` are provided to work around these missing methods (see above). `x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`. - `get_namespace(x)` should be used instead of `x.__array_namespace__`. + `array_namespace(x)` should be used instead of `x.__array_namespace__`. - Value-based casting for scalars will be in effect unless explicitly disabled with the environment variable `NPY_PROMOTION_STATE=weak` or @@ -168,7 +168,7 @@ version. - Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the `__array_namespace__` and `to_device` methods, so the corresponding helper - functions `get_namespace()` and `to_device()` in this library should be + functions `array_namespace()` and `to_device()` in this library should be used instead (see above). - The `x.size` attribute on `torch.Tensor` is a function that behaves diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index bf8d25f7..8875f2c2 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -13,7 +13,7 @@ from types import ModuleType import inspect -from ._helpers import _check_device, _is_numpy_array, get_namespace +from ._helpers import _check_device, _is_numpy_array, array_namespace # These functions are modified from the NumPy versions. @@ -293,7 +293,7 @@ def _asarray( """ if namespace is None: try: - xp = get_namespace(obj, _use_compat=False) + xp = array_namespace(obj, _use_compat=False) except ValueError: # TODO: What about lists of arrays? raise ValueError("A namespace must be specified for asarray() with non-array input") diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index ad162ccf..e6adc948 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -53,7 +53,7 @@ def _check_api_version(api_version): if api_version is not None and api_version != '2021.12': raise ValueError("Only the 2021.12 version of the array API specification is currently supported") -def get_namespace(*xs, api_version=None, _use_compat=True): +def array_namespace(*xs, api_version=None, _use_compat=True): """ Get the array API compatible namespace for the arrays `xs`. @@ -62,7 +62,7 @@ def get_namespace(*xs, api_version=None, _use_compat=True): Typical usage is def your_function(x, y): - xp = array_api_compat.get_namespace(x, y) + xp = array_api_compat.array_namespace(x, y) # Now use xp as the array library namespace return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) @@ -72,7 +72,7 @@ def your_function(x, y): namespaces = set() for x in xs: if isinstance(x, (tuple, list)): - namespaces.add(get_namespace(*x, _use_compat=_use_compat)) + namespaces.add(array_namespace(*x, _use_compat=_use_compat)) elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) elif _is_numpy_array(x): @@ -113,6 +113,8 @@ def your_function(x, y): return xp +# backwards compatibility alias +get_namespace = array_namespace def _check_device(xp, device): if xp == sys.modules.get('numpy'): @@ -224,4 +226,4 @@ def size(x): return None return math.prod(x.shape) -__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device', 'size'] +__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py new file mode 100644 index 00000000..4b5bb07c --- /dev/null +++ b/tests/test_array_namespace.py @@ -0,0 +1,41 @@ +import array_api_compat +from array_api_compat import array_namespace +import pytest + + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("api_version", [None, '2021.12']) +def test_array_namespace(library, api_version): + lib = pytest.importorskip(library) + + array = lib.asarray([1.0, 2.0, 3.0]) + namespace = array_api_compat.array_namespace(array, api_version=api_version) + + if 'array_api' in library: + assert namespace == lib + else: + assert namespace == getattr(array_api_compat, library) + +def test_array_namespace_multiple(): + import numpy as np + + x = np.asarray([1, 2]) + assert array_namespace(x, x) == array_namespace((x, x)) == \ + array_namespace((x, x), x) == array_api_compat.numpy + +def test_array_namespace_errors(): + pytest.raises(TypeError, lambda: array_namespace([1])) + pytest.raises(TypeError, lambda: array_namespace()) + + import numpy as np + import torch + x = np.asarray([1, 2]) + y = torch.asarray([1, 2]) + + pytest.raises(TypeError, lambda: array_namespace(x, y)) + + pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12')) + +def test_get_namespace(): + # Backwards compatible wrapper + assert array_api_compat.get_namespace is array_api_compat.array_namespace diff --git a/tests/test_get_namespace.py b/tests/test_get_namespace.py deleted file mode 100644 index e4e4fb60..00000000 --- a/tests/test_get_namespace.py +++ /dev/null @@ -1,37 +0,0 @@ -import array_api_compat -from array_api_compat import get_namespace -import pytest - - -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) -@pytest.mark.parametrize("api_version", [None, '2021.12']) -def test_get_namespace(library, api_version): - lib = pytest.importorskip(library) - - array = lib.asarray([1.0, 2.0, 3.0]) - namespace = array_api_compat.get_namespace(array, api_version=api_version) - - if 'array_api' in library: - assert namespace == lib - else: - assert namespace == getattr(array_api_compat, library) - -def test_get_namespace_multiple(): - import numpy as np - - x = np.asarray([1, 2]) - assert get_namespace(x, x) == get_namespace((x, x)) == \ - get_namespace((x, x), x) == array_api_compat.numpy - -def test_get_namespace_errors(): - pytest.raises(TypeError, lambda: get_namespace([1])) - pytest.raises(TypeError, lambda: get_namespace()) - - import numpy as np - import torch - x = np.asarray([1, 2]) - y = torch.asarray([1, 2]) - - pytest.raises(TypeError, lambda: get_namespace(x, y)) - - pytest.raises(ValueError, lambda: get_namespace(x, api_version='2022.12'))