From b0c1b8fdd414ffad50ec1a8fe9e38db0a391c74a Mon Sep 17 00:00:00 2001 From: Gordon Watts Date: Sat, 27 Apr 2024 00:05:58 -0700 Subject: [PATCH] Do proper lookup when working with a nameclass --- func_adl/type_based_replacement.py | 7 ++++++- tests/test_type_based_replacement.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/func_adl/type_based_replacement.py b/func_adl/type_based_replacement.py index 043fe04..1ccc51d 100644 --- a/func_adl/type_based_replacement.py +++ b/func_adl/type_based_replacement.py @@ -5,7 +5,7 @@ import inspect import logging import sys -from dataclasses import dataclass, make_dataclass +from dataclasses import dataclass, is_dataclass, make_dataclass from typing import ( Any, Callable, @@ -971,6 +971,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: raise ValueError(f"Key {key} not found in dict expression!!") value = t_node.value.values[key_index[0]] self._found_types[node] = self.lookup_type(value) + elif ((dc := self.lookup_type(t_node.value)) is not None) and is_dataclass(dc): + dc_types = get_type_hints(dc) + if node.attr not in dc_types: + raise ValueError(f"Key {node.attr} not found in dataclass/dictionary {dc}") + self._found_types[node] = dc_types[node.attr] return t_node tt = type_transformer(o_stream) diff --git a/tests/test_type_based_replacement.py b/tests/test_type_based_replacement.py index 30e3f6a..ca39992 100644 --- a/tests/test_type_based_replacement.py +++ b/tests/test_type_based_replacement.py @@ -564,6 +564,20 @@ def test_dictionary_through_Select(): assert j_info.annotation == float +def test_dictionary_through_Select_reference(): + """Make sure the Select statement carries the typing all the way through, + including a later reference""" + + s = ast_lambda( + "e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()}).Select(lambda info: info.pt)" + ) + objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load())) + + _, _, expr_type = remap_by_types(objs, "e", Event, s) + + assert expr_type == Iterable[float] + + def test_indexed_tuple(): "Check that we can type-follow through tuples"