Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex number support to linalg.trace #541

Merged
merged 1 commit into from Dec 13, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 4 additions & 5 deletions spec/API_specification/array_api/linalg.py
Expand Up @@ -447,14 +447,12 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr

- If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum).

For floating-point operands,

- If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.add`.

Parameters
----------
x: array
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices. Should have a real-valued data type.
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices. Should have a numeric data type.
offset: int
offset specifying the off-diagonal relative to the main diagonal.

Expand All @@ -466,8 +464,9 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
dtype: Optional[dtype]
data type of the returned array. If ``None``,

- if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).

Expand Down