Skip to content

Commit

Permalink
Stubs for Array, ArrayOrNone, and CArray (#1682)
Browse files Browse the repository at this point in the history
This PR adds type stubs for the Array, ArrayOrNone and CArray trait types.

Fixes #1657.
  • Loading branch information
mdickinson committed Aug 9, 2022
1 parent e0c6008 commit 57139d0
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 1 deletion.
6 changes: 6 additions & 0 deletions traits-stubs/traits-stubs/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,9 @@ from .trait_handlers import (
TraitPrefixMap as TraitPrefixMap,
TraitCompound as TraitCompound,
)

from .trait_numeric import (
Array as Array,
ArrayOrNone as ArrayOrNone,
CArray as CArray,
)
68 changes: 68 additions & 0 deletions traits-stubs/traits-stubs/trait_numeric.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

from typing import Any, List, Optional, Tuple, Type, Union

import numpy as np

from .trait_type import _TraitType

# Things that are allowed as individual shape elements in the 'shape'
# tuple or list.
_ShapeElement = Union[None, int, Tuple[int, Union[None, int]]]

# Type for the shape parameter.
_Shape = Union[Tuple[_ShapeElement, ...], List[_ShapeElement]]

# The "Array" trait type is not as permissive as NumPy's asarray: it
# accepts only NumPy arrays, lists and tuples.
_ArrayLike = Union[List[Any], Tuple[Any, ...], np.ndarray[Any, Any]]

# Synonym for the "stores" type of the trait.
_Array = np.ndarray[Any, Any]

# Things that are accepted as dtypes. This doesn't attempt to cover
# all legal possibilities - only those that are common.
_DTypeLike = Union[np.dtype[Any], Type[Any], str]

class Array(_TraitType[_ArrayLike, _Array]):
def __init__(
self,
dtype: Optional[_DTypeLike] = ...,
shape: Optional[_Shape] = ...,
value: Optional[_ArrayLike] = ...,
*,
casting: str = ...,
**metadata: Any,
) -> None: ...

class ArrayOrNone(
_TraitType[Optional[_ArrayLike], Optional[_Array]]
):
def __init__(
self,
dtype: Optional[_DTypeLike] = ...,
shape: Optional[_Shape] = ...,
value: Optional[_ArrayLike] = ...,
*,
casting: str = ...,
**metadata: Any,
) -> None: ...

class CArray(_TraitType[_ArrayLike, _Array]):
def __init__(
self,
dtype: Optional[_DTypeLike] = ...,
shape: Optional[_Shape] = ...,
value: Optional[_ArrayLike] = ...,
*,
casting: str = ...,
**metadata: Any,
) -> None: ...
44 changes: 44 additions & 0 deletions traits-stubs/traits_stubs_tests/numpy_examples/Array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

import numpy as np

from traits.api import Array, ArrayOrNone, CArray, HasTraits


class HasArrayTraits(HasTraits):
spectrum = Array(shape=(None,), dtype=np.float64)
complex_shape = Array(shape=((512, None), (512, None), (3, 4)))
list_shape = Array(shape=[(512, None), (512, None), (3, 4)])
str_dtype = Array(dtype="f4")
dtype_dtype = Array(dtype=np.dtype("float"))
with_default_value = Array(value=np.zeros(5))
with_list_default = Array(value=[1, 2, 3, 4, 5])
with_tuple_default = Array(value=(1, 2, 3, 4, 5))
with_casting = Array(casting="same_kind")

maybe_image = ArrayOrNone(shape=(None, None, 3), dtype=np.float64)
cspectrum = CArray(shape=(None,), dtype=np.float64)

# Bad trait declarations
bad_dtype = Array(dtype=62) # E: arg-type
bad_default = Array(value=123) # E: arg-type
bad_shape = Array(shape=3) # E: arg-type
bad_shape_element = Array(shape=(3, (None, None))) # E: arg-type


obj = HasArrayTraits()
obj.spectrum = np.array([2, 3, 4], dtype=np.float64)
obj.spectrum = "not an array" # E: assignment
obj.spectrum = None # E: assignment

obj.maybe_image = None
obj.maybe_image = np.zeros((5, 5, 3))
obj.maybe_image = 2.3 # E: assignment
23 changes: 22 additions & 1 deletion traits-stubs/traits_stubs_tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from traits.testing.optional_dependencies import (
pkg_resources,
requires_mypy,
requires_numpy_testing,
requires_pkg_resources,
)
from traits_stubs_tests.util import MypyAssertions


@requires_pkg_resources
@requires_mypy
class TestAnnotations(TestCase, MypyAssertions):
@requires_pkg_resources
def test_all(self, filename_suffix=""):
""" Run mypy for all files contained in traits_stubs_tests/examples
directory.
Expand All @@ -41,3 +42,23 @@ def test_all(self, filename_suffix=""):
for file_path in examples_dir.glob("*{}.py".format(filename_suffix)):
with self.subTest(file_path=file_path):
self.assertRaisesMypyError(file_path)

@requires_numpy_testing
def test_numpy_examples(self):
""" Run mypy for files contained in traits_stubs_tests/numpy_examples
directory.
Lines with expected errors are marked inside these files.
Any mismatch will raise an assertion error.
Parameters
----------
filename_suffix: str
Optional filename suffix filter.
"""
examples_dir = Path(pkg_resources.resource_filename(
'traits_stubs_tests', 'numpy_examples'))

for file_path in examples_dir.glob("*.py"):
with self.subTest(file_path=file_path):
self.assertRaisesMypyError(file_path)
4 changes: 4 additions & 0 deletions traits/testing/optional_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def optional_import(name):
numpy = optional_import("numpy")
requires_numpy = unittest.skipIf(numpy is None, "NumPy not available")

numpy_testing = optional_import("numpy.testing")
requires_numpy_testing = unittest.skipIf(
numpy_testing is None, "numpy.testing not available")

pkg_resources = optional_import("pkg_resources")
requires_pkg_resources = unittest.skipIf(
pkg_resources is None, "pkg_resources not available"
Expand Down

0 comments on commit 57139d0

Please sign in to comment.