
Conversation

@sayakpaul (Member)

What does this PR do?

Adds a lightweight test suite for popular attention backends. By default this won't be run on our CI.
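For a sense of the core check: the suite compares a short slice of the generated image against stored reference values. The helper below is copied from the pytest output quoted later in this thread:

    import torch

    def _check_if_slices_match(output, expected_slice):
        # Flatten the generated image and compare its first and last 8 values
        # against a stored reference slice.
        img = output.images.detach().cpu()
        generated_slice = img.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        assert torch.allclose(generated_slice, expected_slice, atol=1e-4)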

@sayakpaul requested a review from DN6 on September 25, 2025.


FORWARD_CASES = [
    ("flash_hub", None),
]

@sayakpaul (Member Author) commented on the ("flash_hub", None) entry:

Will add this once #12387 is merged.

COMPILE_CASES = [
    ("flash_hub", None, True),
]

@sayakpaul (Member Author) commented on the ("flash_hub", None, True) entry:

Will add the test slices after #12387 is merged.
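For readers skimming the diff: these case lists feed pytest parametrization. A minimal sketch of how COMPILE_CASES might plug into the test visible further down (the pipe fixture and the ids= argument are assumptions here, inferred from the test ID test_forward_with_compile[native] in the failure output below):

    import pytest

    @pytest.mark.parametrize(
        "backend_name,expected_slice,error_on_recompile",
        COMPILE_CASES,
        ids=[case[0] for case in COMPILE_CASES],  # assumed; yields e.g. [native]
    )
    def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
        ...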

@lauri9 commented on Oct 22, 2025

Hi! I'm implementing a new attention backend, and in preparation for that I tried the unit tests from this PR. I was working in an environment with a torch version different from the nightly 2.10.0.dev20250924+cu128 indicated in the unit test file. In my environment, the native backend produces results that diverge numerically from the expected values, as seen in the following pytest output snippet:

_____________________________________________________________________________________________________________________ test_forward_with_compile[native] ______________________________________________________________________________________________________________________
output = FluxPipelineOutput(images=tensor([[[[0.0391, 0.0391, 0.0410,  ..., 0.2090, 0.2090, 0.2070],
          [0.0391, 0.0586,...8],
          [0.0879, 0.0801, 0.0801,  ..., 0.2930, 0.2891, 0.3184]]]],
       device='cuda:0', dtype=torch.bfloat16))
expected_slice = tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344,
        0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],
       dtype=torch.bfloat16)

    def _check_if_slices_match(output, expected_slice):
        img = output.images.detach().cpu()
        generated_slice = img.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
>       assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
E       assert False
E        +  where False = <built-in method allclose of type object at 0x7f4b31c218c0>(tensor([0.0391, 0.0391, 0.0410, 0.0488, 0.0449, 0.0566, 0.0586, 0.0566, 0.2422,\n        0.2539, 0.2656, 0.2871, 0.2969, 0.2930, 0.2891, 0.3184],\n       dtype=torch.bfloat16), tensor(
[0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344,\n        0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],\n       dtype=torch.bfloat16), atol=0.0001)
E        +    where <built-in method allclose of type object at 0x7f4b31c218c0> = torch.allclose

tests/others/test_attention_backends.py:103: AssertionError

There were also differences in the eager-mode tests.

Is it expected that the values diverge between versions? If the values are expected to vary, could there be a better way to test than exact numerical comparison?

@sayakpaul (Member Author)

The hardware might matter too: I ran the tests on an H100, actually. The CUDA version could matter as well.

I will try to swap out the exact assertion for cosine similarity-based checks, which should be a little more reliable and robust to these small numerical differences.
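A minimal sketch of what such a check could look like, keeping the same flattened slices (the threshold value here is an illustrative choice, not necessarily what will land):

    import torch
    import torch.nn.functional as F

    def _check_if_slices_match(output, expected_slice, threshold=0.9995):
        img = output.images.detach().cpu()
        generated_slice = img.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        # Cosine similarity tolerates small element-wise drift across torch/CUDA/
        # hardware versions better than an exact element-wise allclose.
        sim = F.cosine_similarity(
            generated_slice.float().unsqueeze(0), expected_slice.float().unsqueeze(0)
        )
        assert sim.item() > threshold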

@sayakpaul requested a review from dg845 on October 22, 2025.
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

Collaborator commented on the pytest.xfail line:

Suggested change:
-        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+        pytest.xfail(f"Test with {backend_name} is compatible with a higher version of torch.")

nit: typo?
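For reference, {backend_name=} is not strictly a typo: it is Python 3.8+ f-string debug syntax that renders the variable name along with its value, so the two spellings produce different messages:

    backend_name = "native"
    print(f"{backend_name=}")  # backend_name='native'
    print(f"{backend_name}")   # native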

@dg845 (Collaborator) left a comment:

LGTM :)

@sayakpaul merged commit ccdd96c into main on Oct 23, 2025; 9 of 11 checks passed. The test-attention-backends branch was then deleted.