From c30b59eb6ff4461f7194e91eda6e68c2535a1ab7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 1 Mar 2023 12:40:32 -0500 Subject: [PATCH] Adds pytorch to get_namespace --- array_api_compat/common/_helpers.py | 7 +++++++ tests/test_get_namespace.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 tests/test_get_namespace.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 6a4a43fd..6d41572f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -75,6 +75,13 @@ def get_namespace(*xs, _use_compat=True): else: import cupy as cp namespaces.add(cp) + elif _is_torch_array(x): + if _use_compat: + from .. import torch as torch_namespace + namespaces.add(torch_namespace) + else: + import torch + namespaces.add(torch) else: # TODO: Support Python scalars? raise ValueError("The input is not a supported array type") diff --git a/tests/test_get_namespace.py b/tests/test_get_namespace.py new file mode 100644 index 00000000..c150fde5 --- /dev/null +++ b/tests/test_get_namespace.py @@ -0,0 +1,14 @@ +import array_api_compat +import pytest + + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +def test_get_namespace(library): + lib = pytest.importorskip(library) + + array = lib.asarray([1.0, 2.0, 3.0]) + namespace = array_api_compat.get_namespace(array) + + expected_namespace = getattr(array_api_compat, library) + assert namespace is expected_namespace +