From f9cadc0c30a4010ac4f98b09e1a9cecb71577262 Mon Sep 17 00:00:00 2001 From: William Wang Date: Tue, 16 Aug 2022 10:04:57 -0700 Subject: [PATCH] Handle type validation for string type hints (#333) Summary: With `from __future__ import annotations`, all type hints become strings. The current type validator assumes that what is being checked is a type and not a string, so no types are being validated in these cases. The fix is to use Python's typing.get_type_hints to evaluate the type hints at runtime - [x] Added tests, if you've added code that should be tested - [x] Ensured the test suite passes - [x] Made sure your code lints - [x] Completed the Contributor License Agreement ("CLA") Pull Request resolved: https://github.com/facebook/TestSlide/pull/333 Test Plan: Imported from GitHub, without a `Test Plan:` line. New unit tests were added to check that function parameters and return types were being interpreted correctly when annotated as a string (which is the case with from __future__ import annotations) Reviewed By: deathowl Differential Revision: D35778959 Pulled By: williamlw999-fb fbshipit-source-id: 20ad9fbbb22c8189a9ba650729b99f5c59a24a16 --- tests/mock_callable_testslide.py | 67 ++++++++++++++++++++++++++++++++ tests/sample_module.py | 6 +++ testslide/lib.py | 21 +++++++--- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/tests/mock_callable_testslide.py b/tests/mock_callable_testslide.py index 419f4be..fda46ba 100644 --- a/tests/mock_callable_testslide.py +++ b/tests/mock_callable_testslide.py @@ -184,6 +184,36 @@ def passes_with_invalid_argument_type(self): } self.callable_target(*call_args, **call_kwargs) + @context.example + def passes_with_valid_str_types(self): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).for_call(*args, **kwargs).to_return_value("hello") + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + + @context.example + def raises_TypeCheckError_for_invalid_str_types(self): + args = (1234, 1234, 1234) + kwargs = {"kwarg1": "str val"} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).for_call(*args, **kwargs).to_return_value("hello") + with self.assertRaisesRegex( + TypeCheckError, + r"(?ms)type of arg1 must be str.*type of arg3 must be a dict.*", + ): + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + if has_return_value: @context.sub_context @@ -206,6 +236,43 @@ def raises_TypeCheckError(self): *self.call_args, **self.call_kwargs ) + @context.example + def passes_with_valid_str_return_types(self): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).to_return_value("hello") + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + + @context.example + def raises_TypeCheckError_for_invalid_str_return_types( + self, + ): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).to_return_value(1234) + with self.assertRaisesRegex( + TypeCheckError, + r"(?ms)type of return must be one of \(str, NoneType\); " + "got int instead: 1234.*", + ): + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + @context.sub_context(".for_call(*args, **kwargs)") def for_call_args_kwargs(context): @context.sub_context diff --git a/tests/sample_module.py b/tests/sample_module.py index f4fd38e..df5ae6b 100644 --- a/tests/sample_module.py +++ b/tests/sample_module.py @@ -136,6 +136,12 @@ def test_function_returns_coroutine( return async_test_function(arg1, arg2, kwarg1, kwarg2) +def instance_method_with_str_types( + arg1: "str", arg2: "Any", arg3: "UnionArgType", kwarg1: "int" +) -> "Optional[str]": + return "original response" + + UnionArgType = Dict[str, Union[str, int]] diff --git a/testslide/lib.py b/testslide/lib.py index d0a35c0..a5f824c 100644 --- a/testslide/lib.py +++ b/testslide/lib.py @@ -11,7 +11,17 @@ from functools import wraps from inspect import Traceback from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + get_type_hints, +) from unittest.mock import Mock import typeguard @@ -245,6 +255,7 @@ def _validate_callable_arg_types( kwargs: Dict[str, Any], ) -> None: argspec = inspect.getfullargspec(callable_template) + type_hints = get_type_hints(callable_template) idx_offset = 1 if skip_first_arg else 0 type_errors = [] for idx in range(0, len(args)): @@ -255,7 +266,7 @@ def _validate_callable_arg_types( raise TypeError("Extra argument given: ", repr(args[idx])) argname = argspec.args[idx + idx_offset] try: - expected_type = argspec.annotations.get(argname) + expected_type = type_hints.get(argname) if not expected_type: continue @@ -265,7 +276,7 @@ def _validate_callable_arg_types( for argname, value in kwargs.items(): try: - expected_type = argspec.annotations.get(argname) + expected_type = type_hints.get(argname) if not expected_type: continue @@ -359,10 +370,10 @@ def _validate_return_type( unwrap_template_awaitable: bool = False, ) -> None: try: - argspec = inspect.getfullargspec(template) + type_hints = get_type_hints(template) + expected_type = type_hints.get("return") except TypeError: return - expected_type = argspec.annotations.get("return") if expected_type: if unwrap_template_awaitable: type_origin = get_origin(expected_type)