diff --git a/python/tests/test_generated_types.py b/python/tests/test_generated_types.py index aed13f71..7c64dfc4 100644 --- a/python/tests/test_generated_types.py +++ b/python/tests/test_generated_types.py @@ -64,6 +64,30 @@ def test_get_dtype(): with pytest.raises(RuntimeError, match="Generic type arguments not provided"): tm.get_dtype(tm.MyTuple) + with pytest.raises(RuntimeError, match="Generic type arguments not provided"): + tm.get_dtype(tm.AliasedTuple) + with pytest.raises(RuntimeError, match="Generic type arguments not provided"): + tm.get_dtype(tm.AliasedOpenGeneric) + with pytest.raises(RuntimeError, match="Generic type arguments not provided"): + tm.get_dtype(tm.AliasedGenericUnion2) + + assert tm.get_dtype(tm.RecordWithAliasedGenerics) == np.dtype( + [ + ("my_strings", tm.get_dtype(tm.MyTuple[str, str])), + ("aliased_strings", tm.get_dtype(tm.AliasedTuple[str, str])), + ], + align=True, + ) + + assert tm.get_dtype(tm.RecordWithFixedArrays) == np.dtype( + [ + ("ints", tm.get_dtype(tm.Int32), (2, 3)), + ("fixed_simple_record_array", tm.get_dtype(tm.SimpleRecord), (3, 2)), + ("fixed_record_with_vlens_array", tm.get_dtype(tm.RecordWithVlens), (2, 2)), + ], + align=True, + ) + assert tm.get_dtype(tm.MyTuple[tm.Int32, tm.Float32]) == np.dtype( [("v1", "= (3, 10): from types import UnionType -from typing import Any, Callable, Union, cast, get_args, get_origin +from typing import Any, Callable, Union, cast, get_args, get_origin, TypeVar import numpy as np from . import yardl_types as yardl @@ -44,8 +44,6 @@ def make_get_dtype_func( dtype_map[float] = np.dtype(np.float64) dtype_map[complex] = np.dtype(np.complex128) - annotatedRuntimeType = type(yardl.Int32) - def get_dtype_impl( dtype_map: dict[ Union[type, GenericAlias], @@ -53,45 +51,27 @@ def get_dtype_impl( ], t: Union[type, GenericAlias], ) -> np.dtype[Any]: - if sys.version_info >= (3, 10): - if ( - isinstance(t, type) - or isinstance(t, UnionType) - or isinstance(t, annotatedRuntimeType) - ): - if (res := dtype_map.get(t, None)) is not None: - if callable(res): - raise RuntimeError( - f"Generic type arguments not provided for {t}" - ) - return res + # type_args = list(filter(lambda t: type(t) != TypeVar, get_args(t))) + origin = get_origin(t) - if isinstance(t, UnionType): - return _get_union_dtype(get_args(t)) - else: - if isinstance(t, type) or isinstance(t, annotatedRuntimeType): - if (res := dtype_map.get(t, None)) is not None: - if callable(res): - raise RuntimeError( - f"Generic type arguments not provided for {t}" - ) - return res + if origin == Union or ( + sys.version_info >= (3, 10) and isinstance(t, UnionType) + ): + return _get_union_dtype(get_args(t)) - origin = get_origin(t) - if origin == np.ndarray: - if (res := dtype_map.get(cast(GenericAlias, t), None)) is not None: - if callable(res): - raise RuntimeError(f"Unexpected generic type arguments for {t}") + # If t is found in dtype_map here, t is either a Python type + # or t is a types.GenericAlias with missing type arguments + if (res := dtype_map.get(t, None)) is not None: + if callable(res): + raise RuntimeError(f"Generic type arguments not provided for {t}") + else: return res - if origin is not None: - if (res := dtype_map.get(origin, None)) is not None: - if callable(res): - return res(get_args(t)) - - if origin == Union: - # A union specified with syntax Union[A, B] or Optional[A] - return _get_union_dtype(get_args(t)) + # Here, t is either invalid (no dtype registered) + # or t is a types.GenericAlias with type arguments specified + if origin is not None and (res := dtype_map.get(origin, None)) is not None: + if callable(res): + return res(get_args(t)) raise RuntimeError(f"Cannot find dtype for {t}")