Skip to content

Commit c60af56

Browse files
authored
fix(matchers): ensure ErrorMatching return type matches spec (#86)
1 parent dbc4813 commit c60af56

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

decoy/matchers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_logger_called(decoy: Decoy):
2424
equality comparisons (`==`) for stubbing and verification.
2525
"""
2626
from re import compile as compile_re
27-
from typing import cast, Any, List, Mapping, Optional, Pattern, Type
27+
from typing import cast, Any, List, Mapping, Optional, Pattern, Type, TypeVar
2828

2929

3030
__all__ = [
@@ -255,10 +255,10 @@ def StringMatching(match: str) -> str:
255255

256256

257257
class _ErrorMatching:
258-
_error_type: Type[Exception]
258+
_error_type: Type[BaseException]
259259
_string_matcher: Optional[_StringMatching]
260260

261-
def __init__(self, error: Type[Exception], match: Optional[str] = None) -> None:
261+
def __init__(self, error: Type[BaseException], match: Optional[str] = None) -> None:
262262
"""Initialize with the Exception type and optional message matcher."""
263263
self._error_type = error
264264
self._string_matcher = _StringMatching(match) if match is not None else None
@@ -281,7 +281,10 @@ def __repr__(self) -> str:
281281
)
282282

283283

284-
def ErrorMatching(error: Type[Exception], match: Optional[str] = None) -> Exception:
284+
ErrorT = TypeVar("ErrorT", bound=BaseException)
285+
286+
287+
def ErrorMatching(error: Type[ErrorT], match: Optional[str] = None) -> ErrorT:
285288
"""Match any error matching an Exception type and optional message matcher.
286289
287290
Arguments:
@@ -294,7 +297,7 @@ def ErrorMatching(error: Type[Exception], match: Optional[str] = None) -> Except
294297
assert ValueError("oh no!") == ErrorMatching(ValueError, match="no")
295298
```
296299
"""
297-
return cast(Exception, _ErrorMatching(error, match))
300+
return cast(ErrorT, _ErrorMatching(error, match))
298301

299302

300303
class _Captor:

tests/test_matchers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_error_matching_matcher() -> None:
9595
"""It should have an "any error that matches" matcher."""
9696
assert RuntimeError("ah!") == matchers.ErrorMatching(RuntimeError)
9797
assert RuntimeError("ah!") == matchers.ErrorMatching(RuntimeError, "ah")
98-
assert RuntimeError("ah!") != matchers.ErrorMatching(TypeError, "ah")
98+
assert RuntimeError("ah!") != matchers.ErrorMatching(TypeError, "ah") # type: ignore[comparison-overlap] # noqa: E501
9999
assert RuntimeError("ah!") != matchers.ErrorMatching(RuntimeError, "ah$")
100100

101101

tests/typing/test_typing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,5 @@
130130
main:4: note: Revealed type is "Any"
131131
main:5: note: Revealed type is "Any"
132132
main:6: note: Revealed type is "builtins.str"
133-
main:7: note: Revealed type is "builtins.Exception"
133+
main:7: note: Revealed type is "builtins.RuntimeError*"
134134
main:8: note: Revealed type is "Any"

0 commit comments

Comments
 (0)