diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 15b32c5fe45c..3002a6ee0cf0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -3462,15 +3462,21 @@ def _get_test_info(): if test_frame is not None: line_number = test_frame.lineno - # most inner (recent) to most outer () frames + # The frame of `patched` being called (the one and the only one calling `_get_test_info`) + # This is used to get the original method being patched in order to get the context. + frame_of_patched_obj = None + captured_frames = [] to_capture = False - # up to the test method being called + # From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method) + # Between `the test method being called` and `before entering `patched``. for frame in reversed(stack_from_inspect): if test_file in str(frame).replace(r"\\", "/"): if "self" in frame.frame.f_locals and test_name == frame.frame.f_locals["self"]._testMethodName: to_capture = True - elif "patched" in frame.frame.f_code.co_name: + # TODO: check simply with the name is not robust. + elif "patched" == frame.frame.f_code.co_name: + frame_of_patched_obj = frame to_capture = False break if to_capture: @@ -3482,11 +3488,17 @@ def _get_test_info(): tb_next = tb test_traceback = tb + origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"] + + # An iterable of type `traceback.StackSummary` with each element of type `FrameSummary` stack = traceback.extract_stack() + # The frame which calls `the original method being patched` + caller_frame = None + # From the most inner (i.e. recent) frame to the most outer frame + for frame in reversed(stack): + if origin_method_being_patched.__name__ in frame.line: + caller_frame = frame - # The frame that calls this patched method (it may not be the test method) - # -1: `_get_test_info`; -2: `patched_xxx`; -3: the caller to `patched_xxx` - caller_frame = stack[-3] caller_path = os.path.relpath(caller_frame.filename) caller_lineno = caller_frame.lineno