@@ -24,7 +24,7 @@ def test_logger_called(decoy: Decoy):
24
24
equality comparisons (`==`) for stubbing and verification.
25
25
"""
26
26
from re import compile as compile_re
27
- from typing import cast , Any , List , Mapping , Optional , Pattern , Type
27
+ from typing import cast , Any , List , Mapping , Optional , Pattern , Type , TypeVar
28
28
29
29
30
30
__all__ = [
@@ -255,10 +255,10 @@ def StringMatching(match: str) -> str:
255
255
256
256
257
257
class _ErrorMatching :
258
- _error_type : Type [Exception ]
258
+ _error_type : Type [BaseException ]
259
259
_string_matcher : Optional [_StringMatching ]
260
260
261
- def __init__ (self , error : Type [Exception ], match : Optional [str ] = None ) -> None :
261
+ def __init__ (self , error : Type [BaseException ], match : Optional [str ] = None ) -> None :
262
262
"""Initialize with the Exception type and optional message matcher."""
263
263
self ._error_type = error
264
264
self ._string_matcher = _StringMatching (match ) if match is not None else None
@@ -281,7 +281,10 @@ def __repr__(self) -> str:
281
281
)
282
282
283
283
284
- def ErrorMatching (error : Type [Exception ], match : Optional [str ] = None ) -> Exception :
284
+ ErrorT = TypeVar ("ErrorT" , bound = BaseException )
285
+
286
+
287
+ def ErrorMatching (error : Type [ErrorT ], match : Optional [str ] = None ) -> ErrorT :
285
288
"""Match any error matching an Exception type and optional message matcher.
286
289
287
290
Arguments:
@@ -294,7 +297,7 @@ def ErrorMatching(error: Type[Exception], match: Optional[str] = None) -> Except
294
297
assert ValueError("oh no!") == ErrorMatching(ValueError, match="no")
295
298
```
296
299
"""
297
- return cast (Exception , _ErrorMatching (error , match ))
300
+ return cast (ErrorT , _ErrorMatching (error , match ))
298
301
299
302
300
303
class _Captor :
0 commit comments