|
4 | 4 | import warnings |
5 | 5 | from collections.abc import Callable, Sequence |
6 | 6 | from types import ModuleType, NoneType |
7 | | -from typing import cast, overload |
| 7 | +from typing import Literal, cast, overload |
8 | 8 |
|
9 | 9 | from ._at import at |
10 | 10 | from ._utils import _compat, _helpers |
|
16 | 16 | meta_namespace, |
17 | 17 | ndindex, |
18 | 18 | ) |
19 | | -from ._utils._typing import Array |
| 19 | +from ._utils._typing import Array, Device, DType |
20 | 20 |
|
21 | 21 | __all__ = [ |
22 | 22 | "apply_where", |
@@ -438,6 +438,44 @@ def create_diagonal( |
438 | 438 | return xp.reshape(diag, (*batch_dims, n, n)) |
439 | 439 |
|
440 | 440 |
|
| 441 | +def default_dtype( |
| 442 | + xp: ModuleType, |
| 443 | + kind: Literal[ |
| 444 | + "real floating", "complex floating", "integral", "indexing" |
| 445 | + ] = "real floating", |
| 446 | + *, |
| 447 | + device: Device | None = None, |
| 448 | +) -> DType: |
| 449 | + """ |
| 450 | + Return the default dtype for the given namespace and device. |
| 451 | +
|
| 452 | + This is a convenience shorthand for |
| 453 | + ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``. |
| 454 | +
|
| 455 | + Parameters |
| 456 | + ---------- |
| 457 | + xp : array_namespace |
| 458 | + The standard-compatible namespace for which to get the default dtype. |
| 459 | + kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional |
| 460 | + The kind of dtype to return. Default is 'real floating'. |
| 461 | + device : Device, optional |
| 462 | + The device for which to get the default dtype. Default: current device. |
| 463 | +
|
| 464 | + Returns |
| 465 | + ------- |
| 466 | + dtype |
| 467 | + The default dtype for the given namespace, kind, and device. |
| 468 | + """ |
| 469 | + dtypes = xp.__array_namespace_info__().default_dtypes(device=device) |
| 470 | + try: |
| 471 | + return dtypes[kind] |
| 472 | + except KeyError as e: |
| 473 | + domain = ("real floating", "complex floating", "integral", "indexing") |
| 474 | + assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}" |
| 475 | + msg = f"Unknown kind '{kind}'. Expected one of {domain}." |
| 476 | + raise ValueError(msg) from e |
| 477 | + |
| 478 | + |
441 | 479 | def expand_dims( |
442 | 480 | a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
443 | 481 | ) -> Array: |
@@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
728 | 766 | x = xp.reshape(x, (-1,)) |
729 | 767 | x = xp.sort(x) |
730 | 768 | mask = x != xp.roll(x, -1) |
731 | | - default_int = xp.__array_namespace_info__().default_dtypes( |
732 | | - device=_compat.device(x) |
733 | | - )["integral"] |
| 769 | + default_int = default_dtype(xp, "integral", device=_compat.device(x)) |
734 | 770 | return xp.maximum( |
735 | 771 | # Special cases: |
736 | 772 | # - array is size 0 |
|
0 commit comments