Skip to content

Commit

Permalink
Add hypothesis tests for type parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dflook committed Sep 30, 2023
1 parent 981e99e commit f34ebbd
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
31 changes: 27 additions & 4 deletions hypo_test/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,23 +168,44 @@ def ImportFrom(draw) -> ast.ImportFrom:
names=draw(lists(alias(), min_size=1, max_size=3)),
level=draw(integers(min_value=0, max_value=2)))

@composite
def TypeVar(draw) -> ast.TypeVar:
return ast.TypeVar(name=draw(name()),
bound=draw(none() | expression()))

@composite
def TypeVarTuple(draw) -> ast.TypeVarTuple:
return ast.TypeVarTuple(name=draw(name()))

@composite
def ParamSpec(draw) -> ast.ParamSpec:
return ast.ParamSpec(name=draw(name()))

@composite
def TypeAlias(draw) -> ast.TypeAlias:
return ast.TypeAlias(name=draw(Name(ast.Store)),
value=draw(expression()),
type_params=draw(lists(one_of(TypeVar(), TypeVarTuple(), ParamSpec()), min_size=0, max_size=3)))

@composite
def FunctionDef(draw, statements) -> ast.FunctionDef:
n = draw(name())
args = draw(arguments())
body = draw(lists(statements, min_size=1, max_size=3))
decorator_list = draw(lists(Name(), min_size=0, max_size=2))
type_params = draw(none() | lists(one_of(TypeVar(), TypeVarTuple(), ParamSpec()), min_size=0, max_size=3))
returns = draw(none() | expression())
return ast.FunctionDef(n, args, body, decorator_list, returns)
return ast.FunctionDef(n, args, body, decorator_list, returns, type_params=type_params)

@composite
def AsyncFunctionDef(draw, statements) -> ast.AsyncFunctionDef:
n = draw(name())
args = draw(arguments())
body = draw(lists(statements, min_size=1, max_size=3))
decorator_list = draw(lists(Name(), min_size=0, max_size=2))
type_params = draw(none() | lists(one_of(TypeVar(), TypeVarTuple(), ParamSpec()), min_size=0, max_size=3))
returns = draw(none() | expression())
return ast.AsyncFunctionDef(n, args, body, decorator_list, returns)
return ast.AsyncFunctionDef(n, args, body, decorator_list, returns, type_params=type_params)

@composite
def keyword(draw) -> ast.keyword:
Expand All @@ -208,7 +229,8 @@ def ClassDef(draw, statements) -> ast.ClassDef:
bases=bases,
keywords=keywords,
body=body,
decorator_list=decorator_list
decorator_list=decorator_list,
type_params=draw(none() | lists(one_of(TypeVar(), TypeVarTuple(), ParamSpec()), min_size=0, max_size=3))
)

if hasattr(ast, 'Print'):
Expand Down Expand Up @@ -244,7 +266,8 @@ def ClassDef(draw, statements) -> ast.ClassDef:
AnnAssign(),
AugAssign(),
Import(),
ImportFrom()
ImportFrom(),
TypeAlias()
)

def suite() -> SearchStrategy:
Expand Down
45 changes: 44 additions & 1 deletion hypo_test/test_it.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from python_minifier.ast_printer import print_ast
from python_minifier.expression_printer import ExpressionPrinter
from expressions import Expression
from module import Module
from module import Module, TypeAlias
from python_minifier.rename.mapper import add_parent
from python_minifier.transforms.constant_folding import FoldConstants

Expand Down Expand Up @@ -72,3 +72,46 @@ def test_folding(node):

# The constant folder asserts the value is correct
constant_folder(node)

@given(node=TypeAlias())
@settings(report_multiple_bugs=False, deadline=timedelta(seconds=2), max_examples=1000, verbosity=Verbosity.verbose)
def test_type_alias(node):

module = ast.Module(
body=[node],
type_ignores=[]
)

printer = ModulePrinter()
code = printer(module)
note(code)
compare_ast(module, ast.parse(code, 'test_type_alias'))

@given(node=TypeAlias())
@settings(report_multiple_bugs=False, deadline=timedelta(seconds=2), max_examples=1000, verbosity=Verbosity.verbose)
def test_function_type_param(node):

module = ast.Module(
body=[ast.FunctionDef(
name='test',
args=ast.arguments(
posonlyargs=[],
args=[],
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[],
),
body=[ast.Pass()],
type_params=node.type_params,
decorator_list=[],
returns=None
)],
type_ignores=[]
)

printer = ModulePrinter()
code = printer(module)
note(code)
compare_ast(module, ast.parse(code, 'test_function_type_param'))

0 comments on commit f34ebbd

Please sign in to comment.