Skip to content

Commit

Permalink
Deep-Type Dictionaries (#142)
Browse files Browse the repository at this point in the history
* Dictionaries are now named tuples for typing

* Fix up flake8 error

* Add dummy test to make sure we don't forget.

* Fix up missing return

* Add test to assure Select typing

* Do proper lookup when working with a nameclass
  • Loading branch information
gordonwatts committed Apr 27, 2024
1 parent 8612740 commit 5f23cfb
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
20 changes: 19 additions & 1 deletion func_adl/type_based_replacement.py
Expand Up @@ -5,7 +5,7 @@
import inspect
import logging
import sys
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass, make_dataclass
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -923,6 +923,19 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
self._found_types[node] = Any
return node

def visit_Dict(self, node: ast.Dict) -> Any:
t_node = self.generic_visit(node)
assert isinstance(t_node, ast.Dict)

fields: List[Tuple[str, type]] = [
(ast.literal_eval(f), self.lookup_type(v)) # type: ignore
for f, v in zip(t_node.keys, t_node.values)
]
dict_dataclass = make_dataclass("dict_dataclass", fields)

self._found_types[t_node] = dict_dataclass
return t_node

def visit_Constant(self, node: ast.Constant) -> Any:
self._found_types[node] = type(node.value)
return node
Expand Down Expand Up @@ -958,6 +971,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
raise ValueError(f"Key {key} not found in dict expression!!")
value = t_node.value.values[key_index[0]]
self._found_types[node] = self.lookup_type(value)
elif ((dc := self.lookup_type(t_node.value)) is not None) and is_dataclass(dc):
dc_types = get_type_hints(dc)
if node.attr not in dc_types:
raise ValueError(f"Key {node.attr} not found in dataclass/dictionary {dc}")
self._found_types[node] = dc_types[node.attr]
return t_node

tt = type_transformer(o_stream)
Expand Down
54 changes: 53 additions & 1 deletion tests/test_type_based_replacement.py
@@ -1,6 +1,8 @@
import ast
import copy
import inspect
import logging
from inspect import isclass
from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, cast

import pytest
Expand All @@ -14,6 +16,7 @@
remap_by_types,
remap_from_lambda,
)
from func_adl.util_types import is_iterable, unwrap_iterable


class Track:
Expand Down Expand Up @@ -504,6 +507,23 @@ def test_collection_Select(caplog):


def test_dictionary():
"Make sure that dictionaries turn into named types"

s = ast_lambda("{'jets': e.Jets()}")
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)

# Fix to look for the named class with the correct types.
assert isclass(expr_type)
sig = inspect.signature(expr_type.__init__)
assert len(sig.parameters) == 2
assert "jets" in sig.parameters
j_info = sig.parameters["jets"]
assert str(j_info.annotation) == "typing.Iterable[tests.test_type_based_replacement.Jet]"


def test_dictionary_sequence():
"Check that we can type-follow through dictionaries"

s = ast_lambda("{'jets': e.Jets()}.jets.Select(lambda j: j.pt())")
Expand All @@ -526,8 +546,40 @@ def test_dictionary_bad_key():
assert "jetsss" in str(e)


def test_dictionary_through_Select():
"""Make sure the Select statement carries the typing all the way through"""

s = ast_lambda("e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()})")
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

_, _, expr_type = remap_by_types(objs, "e", Event, s)

assert is_iterable(expr_type)
obj_itr = unwrap_iterable(expr_type)
assert isclass(obj_itr)
sig = inspect.signature(obj_itr.__init__)
assert len(sig.parameters) == 3
assert "pt" in sig.parameters
j_info = sig.parameters["pt"]
assert j_info.annotation == float


def test_dictionary_through_Select_reference():
"""Make sure the Select statement carries the typing all the way through,
including a later reference"""

s = ast_lambda(
"e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()}).Select(lambda info: info.pt)"
)
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

_, _, expr_type = remap_by_types(objs, "e", Event, s)

assert expr_type == Iterable[float]


def test_indexed_tuple():
"Check that we can type-follow through dictionaries"
"Check that we can type-follow through tuples"

s = ast_lambda("(e.Jets(),)[0].Select(lambda j: j.pt())")
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
Expand Down

0 comments on commit 5f23cfb

Please sign in to comment.