Skip to content

Commit

Permalink
Do proper lookup when working with a nameclass
Browse files Browse the repository at this point in the history
  • Loading branch information
gordonwatts committed Apr 27, 2024
1 parent aea1ae2 commit b0c1b8f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
7 changes: 6 additions & 1 deletion func_adl/type_based_replacement.py
Expand Up @@ -5,7 +5,7 @@
import inspect
import logging
import sys
from dataclasses import dataclass, make_dataclass
from dataclasses import dataclass, is_dataclass, make_dataclass
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -971,6 +971,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
raise ValueError(f"Key {key} not found in dict expression!!")
value = t_node.value.values[key_index[0]]
self._found_types[node] = self.lookup_type(value)
elif ((dc := self.lookup_type(t_node.value)) is not None) and is_dataclass(dc):
dc_types = get_type_hints(dc)
if node.attr not in dc_types:
raise ValueError(f"Key {node.attr} not found in dataclass/dictionary {dc}")
self._found_types[node] = dc_types[node.attr]
return t_node

tt = type_transformer(o_stream)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_type_based_replacement.py
Expand Up @@ -564,6 +564,20 @@ def test_dictionary_through_Select():
assert j_info.annotation == float


def test_dictionary_through_Select_reference():
"""Make sure the Select statement carries the typing all the way through,
including a later reference"""

s = ast_lambda(
"e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()}).Select(lambda info: info.pt)"
)
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

_, _, expr_type = remap_by_types(objs, "e", Event, s)

assert expr_type == Iterable[float]


def test_indexed_tuple():
"Check that we can type-follow through tuples"

Expand Down

0 comments on commit b0c1b8f

Please sign in to comment.