
Typing support for shapes #16544

Open
mitar opened this issue Dec 6, 2017 · 38 comments

@mitar

mitar commented Dec 6, 2017

See how the contracts package is trying to provide support for something similar to shapes. It extends annotations in a different way than standard typing, and maybe something like that could be done here as well. So instead of providing a specific extension (PEP) to typing to allow things like shapes, it might be more useful to define a syntax for general constraints on types and use that, in addition to standard types, through typing.

@ethanhs

ethanhs commented Dec 7, 2017

This will take quite some work here and in mypy. At a discussion held several weeks ago, the agreement was on the plan in the README. So while this will happen, adding support for checking shapes will not happen soon. There is much to decide, such as how to define the syntax for declaring array shapes, and how an actual type-checking implementation should work.

@rmcgibbo
Contributor

rmcgibbo commented Dec 12, 2017

@shoyer: I'm not sure where this should be noted, but the framing of numpy/numpy-stubs#6 and python/typing#516 seems to indicate that you're favoring putting the full shape in the type specification: that is, for a two-dimensional array of size (M,N), the proposal seems to be that the type annotation would convey both the value of M and N.

I think we should consider a more modest version in which only the number of dimensions is tracked via the type, but not the shape itself. That is, instead of ndarray[Shape[100, 200], Dtype] for an array of shape=(100,200), we go for ndarray[2, Dtype], to indicate only the fact that the array has ndim=2.

@rmcgibbo
Contributor

Tracking only the rank (as opposed to the full shape) of the array through the type system is, for what it's worth, how all of the C++ tensor libraries work to the best of my knowledge. It's more consistent with the rest of the PEP484 container types as well, in the sense that the type List[int] doesn't track the number of elements.

@shoyer
Member

shoyer commented Dec 12, 2017

@rmcgibbo Yes, this needs justification.

My reasoning for favoring specifying shapes in types comes down to:

  • Rank and shape are not mutually exclusive: if we do shape, we get rank as well. We could even have an explicit Rank type, which we tell type checkers should be treated as equivalent to anonymous dimensions, e.g., ndarray[Rank[2]] could be the same as ndarray[Shape[:, :]].
  • For code directly using NumPy, rank checks would not be anywhere near as valuable as shape checks. The problem is that nearly every NumPy operation is valid on inputs with an arbitrary number of dimensions (even between arguments, with broadcasting). Giving up on shape checks would mean we couldn't type ufuncs in a useful way, or catch the common errors with broadcasting. This would basically mean reserving shape/rank checks for the (many) applications and libraries that only use particular ranks (e.g., vectors and matrices).

With regards to C++ libraries, Eigen at least does allow for either runtime or statically defined shapes. That said, I certainly would not bother with static/compile-time shapes on their own, because it is indeed rare that you know the exact shape of arrays before running a program. But with generics, we have the possibility of doing much more sophisticated pattern matching. This isn't feasible in C++, where you would need to use templates for every possible dimension size. In contrast, we can probably do something more sophisticated in Python to allow for typing shapes without enumerating all of them (similar to what you see in Haskell).

@rmcgibbo
Contributor

rmcgibbo commented Dec 12, 2017

Shape is clearly more expressive than rank. So in that sense, I agree with both of your bulleted points above (not mutually exclusive, and static shape checks, if they worked, would provide more value to users). There are two factors on the other side that come to mind.

  • In almost any real code, it will be impossible for mypy to statically infer the shapes. For example, take the following:
M = 5
N = 10
arr = np.zeros((M, N))
reveal_type(arr)

Mypy tracks the type of M and N, but not their values. It can infer the rank of arr, but not the full shape. Want to do something more complicated?

def compute_length() -> int:
    return <something>

def compute_width() -> int:
    return <something>

arr = np.zeros((compute_length(), compute_width()))
reveal_type(arr)

Not going to work.

  • If the shape is to be tracked explicitly, then to the best of my understanding there is no way to express even the simplest NumPy function signatures with the current generation of PEP 484/mypy. For example, take
# requires a type-level addition function to express the return type
def concatenate_1d_arrays(arr1: ndarray[D, Shape[M]], arr2: ndarray[D, Shape[N]]) -> ?

@shoyer
Member

shoyer commented Dec 12, 2017

Take another look at my doc on shape typing. You're right that functions like concatenate are hard, but I would claim that even mere matching of integer sizes (without arithmetic) along with dimension identity would be enough to be valuable:

  • You could catch broadcasting errors when using ufuncs.
  • You could catch errors with regard to which dimension is removed by aggregation or indexing.
  • You could verify that the shapes of non-concatenated axes of arrays in concatenate match, and are propagated to the result, e.g., stack_matrices(x: ndarray[Shape[:, N]], y: ndarray[Shape[:, N]]) -> ndarray[Shape[:, N]].

In almost any real code, it will be impossible for mypy to statically infer the shapes. For example, take the following:

M = 5
N = 10
arr = np.zeros((M, N))
reveal_type(arr)

I'm not certain that mypy will be unable to infer shapes in this case. If typing adds support for literals (which is almost certainly necessary regardless for NumPy), then I imagine it will also support literals in variables, in the same way that it supports type aliases.
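(As a runnable illustration of the literals mechanism, not anything numpy-specific: typing.Literal values are already recoverable at runtime via typing.get_args, which is the kind of machinery a type checker could use to track concrete sizes. Names here are illustrative.)

```python
from typing import Literal, get_args

# Literal types carry their concrete values, retrievable at runtime;
# a shape checker could match these against declared dimensions.
Height = Literal[500]
Width = Literal[1000]

print(get_args(Height))   # (500,)

# A shape could then be spelled as a tuple of Literal sizes:
Shape2D = tuple[Height, Width]
print(get_args(Shape2D))  # (Literal[500], Literal[1000])
```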

@rmcgibbo
Contributor

rmcgibbo commented Dec 12, 2017

My suspicions are that shape tracking is (1) going to add a small marginal value, since it will rarely be possible to statically infer shapes in real-world code (e.g. my two examples above), whereas it will often be possible to statically infer rank; and (2) going to add a ton of complexity. That suggests to me it's not the best place to start for the "1.0".

If there's a way to design the syntax so that it is initially rank-only, build a viable package around that, and add shape information afterwards in a second stage, then that would be my preference.

@hameerabbasi
Contributor

hameerabbasi commented Feb 19, 2018

Maybe for v1.0 we could have ndarray[2, Dtype] be an alias for ndarray[Shape[Any, Any], Dtype], ignoring the individual sizes in the shape but keeping the ndim in mind. This would probably be a good-ish idea to preserve forward compatibility.

Edit: It might also be a good idea to have a warning if the shape contains anything other than Any (or similar), for now.

@shoyer
Member

shoyer commented Feb 19, 2018

Yes, I agree that it could make sense to add shape support incrementally. We'll probably still need upstream work in typing even to make that work well, though (e.g., to support integer values in types).

@BvB93
Member

BvB93 commented Oct 7, 2020

It seems like there is some interesting work being done on variadic generics:
https://mail.python.org/archives/list/typing-sig@python.org/thread/SQVTQYWIOI4TIO7NNBTFFWFMSMS2TA4J/

Depending on how this progresses from here it is possible we could end up with either
ndarray[Dtype, [i, j]] (or ndarray[[i, j], Dtype]) without the need for a dedicated Shape object.

@shoyer changed the title from "Support for shapes" to "Typing support for shapes" on Oct 7, 2020
@BvB93
Member

BvB93 commented Jan 27, 2021

It seems like we now have a (draft) pep for variadic generics,
i.e. generics with not one but an arbitrary number of variables (like typing.Tuple): PEP 646.

@mrahtz

mrahtz commented Mar 20, 2021

I think we're almost there now with PEP 646. One concern I have, though, is that a related PEP (PEP 637, indexing with keyword arguments) was just rejected, and in some discussion on python-dev, Guido van Rossum noted:

[The steering council's rejection] seems to imply that in order for a proposal like this to fare better in the future, the authors would need to line up support from specific, important communities like the scientific, data science or machine learning communities.

I therefore wanted to ask: What's the current feeling around how likely NumPy would be to use PEP 646 if it were accepted? A vote of support from NumPy folks would be a good sign that we're on the right track. (On the other hand, if the answer is "We're not sure enough yet about how shape-typing as a whole would work to want to commit to anything", that would be a very reasonable answer too - and would suggest that perhaps we need to do more prototyping first.)

@SimpleArt

SimpleArt commented Feb 13, 2022

Realized the other day that most of what is wanted here can already be supported using typing.Tuple and typing.Literal, assuming typing.TypeVar behaves well with typing.Literal (I don't think it does with mypy). For many cases, the usual use of @typing.overload will be sufficient.

For example, one could type-hint matmul as follows:

from typing import Tuple, TypeVar, overload
import numpy as np
from numpy.typing import DType, NDArray

T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)
T3 = TypeVar("T3", bound=int)
DTypeVar = TypeVar("DTypeVar", bound=DType)

@overload
def matmul(x1: NDArray[Tuple[T1], DTypeVar], x2: NDArray[Tuple[T1], DTypeVar], /) -> DTypeVar:
    ...

@overload
def matmul(x1: NDArray[Tuple[T1], DTypeVar], x2: NDArray[Tuple[T1, T2], DTypeVar], /) -> NDArray[Tuple[T2], DTypeVar]:
    ...

@overload
def matmul(x1: NDArray[Tuple[T1, T2], DTypeVar], x2: NDArray[Tuple[T2], DTypeVar], /) -> NDArray[Tuple[T1], DTypeVar]:
    ...

@overload
def matmul(x1: NDArray[Tuple[T1, T2], DTypeVar], x2: NDArray[Tuple[T2, T3], DTypeVar], /) -> NDArray[Tuple[T1, T3], DTypeVar]:
    ...

If it works with typing.Literal, then one would be able to track the full shape. Otherwise, only the number of dimensions can be tracked.

from typing import Literal

HEIGHT = Literal[500]
WIDTH = Literal[1000]

x: NDArray[Tuple[HEIGHT, WIDTH], int] = ...
A: NDArray[Tuple[WIDTH, WIDTH], int] = ...
y: NDArray[Tuple[HEIGHT, WIDTH], int] = matmul(x, A)

To make it more readable, one could also alias npt.Len = Literal and npt.Shape = Tuple to be able to write like this:

x: NDArray[Shape[Len[500], Len[1000]], int]

Although a little clunkier, it avoids the issue of type-checkers failing to recognize constants like:

HEIGHT = 500
WIDTH = 1000
x: NDArray[[WIDTH, HEIGHT], DType]

It's also probably easier to create generics like this using typing.Literal to ensure it's an actual type. It also allows arbitrary lengths to be typed using int. Perhaps npt.AnyLen = int could be used as well?

# dataset[index][row][col]
dataset: NDArray[Tuple[AnyLen, HEIGHT, WIDTH], DType]

Semantically I would also argue that it makes a lot of sense, since this is the literal type-hint for ndarray.shape being used here.
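(To ground that last point, a minimal runtime check, plain numpy with no typing involved, confirming that ndarray.shape really is a tuple of ints:)

```python
import numpy as np

# At runtime, .shape is a plain tuple of Python ints, which is why
# typing it as a Tuple of Literal ints is semantically natural.
arr = np.zeros((500, 1000))
print(arr.shape)                                   # (500, 1000)
print(all(isinstance(n, int) for n in arr.shape))  # True
```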

@phiresky

Variadic generics (PEP 646) were accepted on 19 Jan 2022: https://mail.python.org/archives/list/python-dev@python.org/message/OR5RKV7GAVSGLVH3JAGQ6OXFAXIP5XDX/

Variadic generics support in typing_extensions seems to have been merged 7 days ago: python/typing#963
Support in CPython is still pending: python/cpython#31018 python/cpython#31019 python/cpython#31021

Does that mean it's possible for numpy to finally implement support for this now? 🤑

@rossbar
Contributor

rossbar commented Oct 25, 2023

Is this feature still otherwise blocked?

According to SPEC 0, Python 3.10 will be supported until ~October 2024, so it's generally not possible to adopt Python 3.11 language features until that support window expires.

@NeilGirdhar
Contributor

NeilGirdhar commented Oct 25, 2023

@rossbar Is there some missing necessary runtime behavior? Why can't things like TypeVarTuple be imported from typing-extensions?

If it's complicated, I guess waiting one year is not so bad.

@rossbar
Contributor

rossbar commented Oct 25, 2023

Is there some missing necessary runtime behavior? Why can't things like TypeVarTuple be imported from typing-extensions?

I'm not sure - I defer to others with more knowledge of the typing landscape. I just wanted to point out the support window in the event that there are any dependencies.

@jorenham
Member

Having typing-extensions as a python<=3.10-restricted dependency seems like a good idea to me. I also strongly suspect that the vast majority of numpy users on Python 3.10 and below will have it installed already.

@BvB93
Member

BvB93 commented Oct 28, 2023

The runtime aspect of shape-typing is not exactly the issue here as we're already using types.GenericAlias for handling the __class_getitem__ methods (though it is possible that this will require a from __future__ import annotations statement if older generic alias iterations are incompatible with TypeVarTuple).

It's very much been the lack of variadic support in mypy that's been the blocker previously, as the lack of type checker support made me very reluctant to enable PEP 646 in the stub files. Still, we'll have to wait a bit until the next mypy release wherein python/mypy#16242 is included.

@kaczmarj

kaczmarj commented Nov 9, 2023

We have to wait a bit until the likes of mypy actually support it (xref python/mypy#12280), but afterwards we should be good to go.

python/mypy#12280 has been closed as fixed via python/mypy#16354 🚀 💯

Could someone please provide guidance on how to type-annotate numpy array shapes? It seems that there are several proposals here, but is there a best practice?

@jorenham
Member

jorenham commented Nov 9, 2023

@kaczmarj You can currently do something like this:

from typing import Literal, TypeAlias
import numpy as np

PauliMatrix: TypeAlias = np.ndarray[tuple[Literal[2], Literal[2]], np.dtype[np.complex128]]

Although this isn't really all that practical, since the first "shape" type parameter isn't really used within numpy at the moment, last time I checked. Instead, it's always cast to something like np.ndarray[Any, T]. E.g. numpy.typing.NDArray is an alias for np.ndarray[Any, np.dtype[T]].
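(For completeness, the subscription above also works at runtime; nothing is validated, it just builds a generic alias. A minimal sketch, assuming numpy is installed:)

```python
from typing import Literal, get_args
import numpy as np
import numpy.typing as npt

# np.ndarray implements __class_getitem__, so this builds a generic
# alias at runtime; no shape or dtype checking actually happens.
PauliMatrix = np.ndarray[tuple[Literal[2], Literal[2]], np.dtype[np.complex128]]
print(PauliMatrix)
print(len(get_args(PauliMatrix)))  # 2: the shape and dtype parameters

# numpy.typing.NDArray[T] is currently an alias for ndarray[Any, dtype[T]]
print(npt.NDArray[np.complex128])
```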

@vnmabus

vnmabus commented Nov 14, 2023

Does anyone have an idea of which syntax will be used for typing shapes? I wanted to type a related structure using similar syntax.

@marcospgp

marcospgp commented Mar 18, 2024

Is there an update on this, maybe an expected release date?

Also, I'm currently confused about numpy.typing.NDArray vs numpy.ndarray for type annotations. (Found some info on this above.)

@jorenham
Member

@marcospgp

type NDArray[T: np.generic] = np.ndarray[Any, np.dtype[T]]

https://github.com/numpy/numpy/blob/v1.26.4/numpy/_typing/_array_like.py#L32

@jorenham
Member

jorenham commented Mar 18, 2024

I've been thinking about this for a while now, and I believe that the following proposal could work.

For the sake of brevity and readability, I'll be using the Python 3.12+ PEP 695 syntax. Additionally, I'll prefix type parameter declarations with either a ~, + or - to explicitly indicate their in-, co- or contra-variance, respectively.

The ndarray shape type parameter

Currently, its type signature is ndarray[~Shape: Any, +DType: dtype[Any]], and ndarray[Shape, DType].shape is typed as tuple[int, ...].
I propose to

  1. restrict the upper bound of Shape to tuple[int, ...]
  2. change ~Shape to +Shape, i.e. make it covariant instead of invariant (TYP: Make array _ShapeType bound and covariant, #26081)
  3. change the type of ndarray.shape from tuple[int, ...] to Shape

To illustrate, what is currently

class ndarray[~Shape: Any, +DType: dtype[Any]]:
    shape: tuple[int, ...] 
    [...]

will become

class ndarray[+Shape: tuple[int, ...], +DType: dtype[Any]]:
    shape: Shape
    [...]

(in practice shape will be typed as a @property)

This is almost always backwards compatible. The only backwards incompatible case is for annotations that bind Shape@ndarray to a specific type that isn't a tuple[int, ...]. Note that having np.ndarray[Any, ...] is still allowed, so that e.g. numpy.typing.NDArray is fully backwards-compatible.

The upper bound of the +Shape type param can be made tighter, by noting the ndim limit of max 64, i.e. +Shape: tuple[()] | tuple[int] | tuple[int, int] | ... # up to tuples of length 64. However, this would become rather messy, and could negatively impact the performance of type checkers.

Add a shape-aware alias like NDArray

For the sake of backwards compatibility, numpy.typing.NDArray should be kept as-is.
I propose adding a new type alias to numpy.typing, that has an additional variadic type parameter to indicate the shape of an array.
Since the semantics of the ND prefix are now reflected through the presence of the type parameter for the shape, let's call it Array for now:

# in numpy.typing

# equivalent to the current definition
type NDArray[+Scalar: generic] = ndarray[Any, dtype[Scalar]]

# new
type Array[*Shape: int, +Scalar: generic] = ndarray[tuple[*Shape], dtype[Scalar]]

Unfortunately, variadic type parameters (e.g. typing.TypeVarTuple) cannot have an upper bound (bound=...). However, type checkers should be able to infer that restriction automatically, since tuple[*Shape] will bind to tuple[int, ...], and therefore Shape: int.

Some examples:

from typing import Literal
import numpy as np
import numpy.typing as npt

# ndarray[tuple[Literal[1]], dtype[float64]]
array_0d: npt.Array[Literal[1], np.float64] = np.array(1 / 137)

# ndarray[tuple[N], dtype[T]]
type Vector[N: int, T: np.generic] = npt.Array[N, T]

# ndarray[tuple[N, N], dtype[floating[B]]]
type Covariance[N: int, B: npt.NBitBase] = npt.Array[N, N, np.floating[B]]

# ndarray[tuple[Literal[2], Literal[2]], dtype[complexfloating[B, B]]]
type Pauli[B: npt.NBitBase] = npt.Array[2, 2, np.complexfloating[B, B]]

Updates to callable signatures

The annotations of many methods, operators, ufuncs, and other functions can now be made more specific.

For instance, numpy.outer(a, b, out=None) can now be annotated as

@overload
def outer[M, N, T](
    a: Array[M, Any], 
    b: Array[N, Any],
    out: Array[M, N, T] = ...,
) -> Array[M, N, T]: ...

@overload
def outer[M, N, T](
    a: Array[M, T], 
    b: Array[N, T],
    out: None = ...,
) -> Array[M, N, T]: ...

...  # overloads for the remaining dtype combinations

For functions with a more dynamic shape mapping, this might not be possible. For example, numpy.transpose effectively reverses the shape of the input array. In this case, @overloads could be used for the first few dimensions (e.g. 4), assuming that lower-dimensional arrays are used most often.

@baterflyrity

@jorenham, good proposal; however, consider that annotating types should be simple, easy, and fast:

def convert_image_rgb8_grayscale1[W, H](image: Array[H, W, 3, uint8]) -> Array[H, W, uint1]:
    return (image.sum(axis=2)/3).astype(uint1)

@vnmabus

vnmabus commented Mar 19, 2024

Is it not possible to keep the parentheses around the shape dimensions? (Array[(H, W), uint8] instead of Array[H, W, uint8]). I think it would make it more consistent with functions and classes that need to take more than one shape, and thus require that the dimensions are grouped somehow.

@vnmabus

vnmabus commented Mar 19, 2024

I just checked and that syntax is not allowed, so it should be Array[tuple[H, W], uint8], which I admit is not very pretty.

@baterflyrity

baterflyrity commented Mar 19, 2024

I just checked and that syntax is not allowed,

I've checked and there is no syntax error. Perhaps you ran into Array's restrictions?

Array = ... # define type
def f[W,H](a: Array[(H,W,3),int]):...

@vnmabus

vnmabus commented Mar 19, 2024

I did not mean a syntax error from Python (which does not care about annotations), but from MyPy.

@baterflyrity

baterflyrity commented Mar 19, 2024

I did not mean a syntax error from Python (which does not care about annotations), but from MyPy.

Ah, sure. We can discuss any interface here, but it all runs into the mypy wall. Tbh, I don't follow it closely and don't know the current way to achieve Array[(H,W,3),int].

@Jacob-Stevens-Haas
Contributor

Jacob-Stevens-Haas commented Mar 19, 2024

Edit: Forgot the new 3.12 syntax. FWIW, here's the current code in pre-3.12 syntax:

_ShapeType = TypeVar("_ShapeType", bound=Any)
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])

class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):

If you're talking about modifying _ShapeType, I def agree! #25729

in practice shape will be typed as a @property

Isn't it already?

The annotations of many methods, operators, ufuncs, and other functions can now be made more specific... For functions with a more dynamic shape mapping, this might not be possible

I'd also add that stepping down this path could help illuminate the right way to resolve a lot of the "left for the future" stuff in PEP 646.

EDIT: It would be nice to understand how/whether this applies to record arrays and masked arrays, which I don't have much experience with.

@jorenham
Member

There appears to be some confusion about the PEP 695 syntax, and variadic type parameters (i.e. typing.TypeVarTuple), in general. For those unfamiliar with them, I personally learned a lot by playing around with them in the pyright playground; perhaps it can help you too.

@baterflyrity You use an integer directly in your example; that is not allowed. Instead, a typing.Literal[3] should be used. Furthermore, your Array[(H, W, 3), int] example is invalid in two ways; I believe you meant Array[H, W, Literal[3], int64].

@vnmabus You correctly noticed that using a tuple for the shape in Array wouldn't be pretty. But for those that still want to do so, they could use ndarray[tuple[H, W], dtype[uint8]], which is exactly equivalent to Array[H, W, uint8].

@Jacob-Stevens-Haas I agree that this proposal highlights several missing features within Python typing. But I'd prefer to tackle one problem at a time (I know from experience how deep the typing-PEP rabbit hole goes).
Additionally, I must admit that I am not all too familiar with record arrays and masked arrays. But since there is no mention of these within numpy.typing (last time I checked), I believe it might be best to leave them out of this proposal and tackle them elsewhere.

@Jacob-Stevens-Haas
Contributor

I fully agree - I meant only that taking these small steps would, as a bonus, help make future PEP considerations clearer, i.e. by providing examples of @overloads that can't be written without additional typing capability.

Andrej730 added a commit to IfcOpenShell/IfcOpenShell that referenced this issue Apr 25, 2024
On older numpy versions, np.ndarray was less forgiving and didn't allow passing 1 argument instead of the required 2.

And it turned out numpy doesn't yet have typing for shapes (numpy/numpy#16544), so all matrices and other shapes are specified as `npt.NDArray[np.float64]`.

Fixed type discrepancies for `get_edges` and `get_faces` and also had to fix `import_ifc` as Blender apparently has problems with storing np.int32 in custom attributes (https://projects.blender.org/blender/blender/issues/121072), tested that Blender is okay with np.int32 in other cases we had (addressing BMesh.verts[i] where `i` is np.int32).