Skip to content

Commit

Permalink
Fix #85 (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
naegelejd committed Oct 23, 2023
1 parent 77c0171 commit ab66994
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 38 deletions.
26 changes: 26 additions & 0 deletions python/tests/test_generated_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@ def test_get_dtype():

with pytest.raises(RuntimeError, match="Generic type arguments not provided"):
tm.get_dtype(tm.MyTuple)
with pytest.raises(RuntimeError, match="Generic type arguments not provided"):
tm.get_dtype(tm.AliasedTuple)
with pytest.raises(RuntimeError, match="Generic type arguments not provided"):
tm.get_dtype(tm.AliasedOpenGeneric)
with pytest.raises(RuntimeError, match="Generic type arguments not provided"):
tm.get_dtype(tm.AliasedGenericUnion2)

assert tm.get_dtype(tm.RecordWithAliasedGenerics) == np.dtype(
[
("my_strings", tm.get_dtype(tm.MyTuple[str, str])),
("aliased_strings", tm.get_dtype(tm.AliasedTuple[str, str])),
],
align=True,
)

assert tm.get_dtype(tm.RecordWithFixedArrays) == np.dtype(
[
("ints", tm.get_dtype(tm.Int32), (2, 3)),
("fixed_simple_record_array", tm.get_dtype(tm.SimpleRecord), (3, 2)),
("fixed_record_with_vlens_array", tm.get_dtype(tm.RecordWithVlens), (2, 2)),
],
align=True,
)

assert tm.get_dtype(tm.MyTuple[tm.Int32, tm.Float32]) == np.dtype(
[("v1", "<i4"), ("v2", "<f4")], align=True
)
Expand All @@ -80,6 +104,8 @@ def test_get_dtype():
)
assert tm.get_dtype(tm.Int32 | tm.Float32) == np.object_

assert tm.get_dtype(tm.AliasedGenericUnion2[tm.SimpleRecord, bool]) == np.object_

assert tm.get_dtype(typing.Optional[tm.SimpleRecord]) == np.dtype(
[("has_value", "?"), ("value", tm.get_dtype(tm.SimpleRecord))], align=True
)
Expand Down
56 changes: 18 additions & 38 deletions tooling/internal/python/static_files/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if sys.version_info >= (3, 10):
from types import UnionType

from typing import Any, Callable, Union, cast, get_args, get_origin
from typing import Any, Callable, Union, cast, get_args, get_origin, TypeVar
import numpy as np
from . import yardl_types as yardl

Expand Down Expand Up @@ -44,54 +44,34 @@ def make_get_dtype_func(
dtype_map[float] = np.dtype(np.float64)
dtype_map[complex] = np.dtype(np.complex128)

annotatedRuntimeType = type(yardl.Int32)

def get_dtype_impl(
dtype_map: dict[
Union[type, GenericAlias],
Union[np.dtype[Any], Callable[[tuple[type, ...]], np.dtype[Any]]],
],
t: Union[type, GenericAlias],
) -> np.dtype[Any]:
if sys.version_info >= (3, 10):
if (
isinstance(t, type)
or isinstance(t, UnionType)
or isinstance(t, annotatedRuntimeType)
):
if (res := dtype_map.get(t, None)) is not None:
if callable(res):
raise RuntimeError(
f"Generic type arguments not provided for {t}"
)
return res
# type_args = list(filter(lambda t: type(t) != TypeVar, get_args(t)))
origin = get_origin(t)

if isinstance(t, UnionType):
return _get_union_dtype(get_args(t))
else:
if isinstance(t, type) or isinstance(t, annotatedRuntimeType):
if (res := dtype_map.get(t, None)) is not None:
if callable(res):
raise RuntimeError(
f"Generic type arguments not provided for {t}"
)
return res
if origin == Union or (
sys.version_info >= (3, 10) and isinstance(t, UnionType)
):
return _get_union_dtype(get_args(t))

origin = get_origin(t)
if origin == np.ndarray:
if (res := dtype_map.get(cast(GenericAlias, t), None)) is not None:
if callable(res):
raise RuntimeError(f"Unexpected generic type arguments for {t}")
# If t is found in dtype_map here, t is either a Python type
# or t is a types.GenericAlias with missing type arguments
if (res := dtype_map.get(t, None)) is not None:
if callable(res):
raise RuntimeError(f"Generic type arguments not provided for {t}")
else:
return res

if origin is not None:
if (res := dtype_map.get(origin, None)) is not None:
if callable(res):
return res(get_args(t))

if origin == Union:
# A union specified with syntax Union[A, B] or Optional[A]
return _get_union_dtype(get_args(t))
# Here, t is either invalid (no dtype registered)
# or t is a types.GenericAlias with type arguments specified
if origin is not None and (res := dtype_map.get(origin, None)) is not None:
if callable(res):
return res(get_args(t))

raise RuntimeError(f"Cannot find dtype for {t}")

Expand Down

0 comments on commit ab66994

Please sign in to comment.