/
typing.py
38 lines (25 loc) · 1.06 KB
/
typing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding: utf-8 -*-
# Copyright (C) 2021 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""Type definitions."""
from typing import Any, Tuple, Union
import numpy as np
import jax
import jax.numpy as jnp
__author__ = """Luke Pfister <luke.pfister@gmail.com>"""
JaxArray = Union[jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
"""A jax array."""
Array = Union[np.ndarray, JaxArray]
"""Either a numpy or jax array."""
PRNGKey = jnp.ndarray
"""A key for jax random number generators (see :mod:`jax.random`)."""
DType = Any # TODO: can we do better than this? Maybe with the new numpy typing?
"""A numpy or jax dtype."""
Shape = Tuple[int, ...] # shape of an array
"""A shape of a numpy or jax array."""
BlockShape = Tuple[Tuple[int, ...], ...] # shape of a BlockArray
"""A shape of a :class:`.BlockArray`."""
Axes = Union[int, Tuple[int, ...]] # one or more axes