Skip to content

Commit

Permalink
Limit arbitrary captures
Browse files Browse the repository at this point in the history
  • Loading branch information
gordonwatts committed Oct 12, 2023
1 parent 2c4e2c7 commit 681383f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
7 changes: 7 additions & 0 deletions func_adl/util_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,13 @@ def visit_Name(self, node: ast.Name) -> Any:
# Modules should be sent on down to be dealt with by the
# backend.
if not isinstance(v, ModuleType):
legal_capture_types = [str, int, float, bool, complex, str, bytes]
if type(v) not in legal_capture_types:
raise ValueError(
f"Do not know how to capture data type '{type(v).__name__}' for "
f"variable '{node.id}' - only {', '.join([c.__name__ for c in legal_capture_types])} are "
"supported."
)
return as_literal(v)
return node

Expand Down
13 changes: 13 additions & 0 deletions tests/test_util_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,19 @@ def test_parse_lambda_capture_nested_local():
assert ast.dump(r) == ast.dump(r_true)


def test_sensible_error_with_bad_variable_capture():
class bogus:
def __init__(self):
self.my_var = 10

my_var = bogus()

with pytest.raises(ValueError) as e:
parse_as_ast(lambda x: x > my_var)

assert "my_var" in str(e)


def test_parse_simple_func():
"A oneline function defined at local scope"

Expand Down

0 comments on commit 681383f

Please sign in to comment.