Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3462,15 +3462,21 @@ def _get_test_info():
if test_frame is not None:
line_number = test_frame.lineno

# most inner (recent) to most outer () frames
# The frame of `patched` being called (the one and the only one calling `_get_test_info`)
# This is used to get the original method being patched in order to get the context.
frame_of_patched_obj = None

captured_frames = []
to_capture = False
# up to the test method being called
# From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method)
# Between `the test method being called` and `before entering `patched``.
for frame in reversed(stack_from_inspect):
if test_file in str(frame).replace(r"\\", "/"):
if "self" in frame.frame.f_locals and test_name == frame.frame.f_locals["self"]._testMethodName:
to_capture = True
elif "patched" in frame.frame.f_code.co_name:
# TODO: check simply with the name is not robust.
elif "patched" == frame.frame.f_code.co_name:
frame_of_patched_obj = frame
to_capture = False
break
if to_capture:
Expand All @@ -3482,11 +3488,17 @@ def _get_test_info():
tb_next = tb
test_traceback = tb

origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]

# An iterable of type `traceback.StackSummary` with each element of type `FrameSummary`
stack = traceback.extract_stack()
# The frame which calls `the original method being patched`
caller_frame = None
# From the most inner (i.e. recent) frame to the most outer frame
for frame in reversed(stack):
if origin_method_being_patched.__name__ in frame.line:
caller_frame = frame
Comment on lines +3499 to +3500
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, we check against the name of the original method being patched.


# The frame that calls this patched method (it may not be the test method)
# -1: `_get_test_info`; -2: `patched_xxx`; -3: the caller to `patched_xxx`
caller_frame = stack[-3]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a bit hardcoded, and it turns out that this causes issue when @require_read_token is applied to a test class, where we will get

stack[-1]: _get_test_info
stack[-2]: patched
stack[-3]: wrapper (from def require_read_token)
stack[-4]: the caller to the method that is wrapped (in this case, it's self.assertEqual)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example is

FAILED tests/models/mistral/test_modeling_mistral.py::MistralIntegrationTest::test_speculative_generation - TypeError: 'NoneType' object is not subscriptable

where we have

@require_read_token
class MistralIntegrationTest(unittest.TestCase):

caller_path = os.path.relpath(caller_frame.filename)
caller_lineno = caller_frame.lineno

Expand Down