
Conversation

@titaiwangms (Contributor) commented Nov 30, 2023

Fix #1160
Follow-up to #1043: this is another scaled_dot_product_XXX_attention variant, aten::_scaled_dot_product_efficient_attention.

It only supports CUDA (see the meta registration):

https://github.com/pytorch/pytorch/blob/38ae17d166a001ef6837553d1ddffa111624df27/torch/_meta_registrations.py#L5195-L5236

NOTE: This PR also enables CUDA tests.
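
For reference, here is a minimal, runnable PyTorch sketch of the contract this decomposition has to reproduce: output 0 carries the real attention result, while the auxiliary outputs are CUDA-kernel bookkeeping. The function name, the simplified logsumexp shape, and the zero placeholders are illustrative assumptions, not the torchlib code.

```python
import torch

def efficient_attention_reference(
    query, key, value, attn_bias=None, compute_log_sumexp=False,
    dropout_p=0.0, is_causal=False, scale=None,
):
    # Output[0]: the actual attention result, via the public SDPA API.
    output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_bias,
        dropout_p=dropout_p, is_causal=is_causal, scale=scale,
    )
    # Outputs 1-3 (logsumexp, philox seed/offset) are CUDA-kernel
    # bookkeeping; the shapes below are simplified placeholders.
    batch, num_heads, seq_len = query.shape[0], query.shape[1], query.shape[2]
    logsumexp = torch.zeros(
        (batch, num_heads, seq_len if compute_log_sumexp else 0),
        dtype=torch.float32, device=query.device,
    )
    philox_seed = torch.zeros((), dtype=torch.int64, device=query.device)
    philox_offset = torch.zeros((), dtype=torch.int64, device=query.device)
    return output, logsumexp, philox_seed, philox_offset

# Usage: shapes are (batch, num_heads, seq_len, head_dim).
q = k = v = torch.randn(2, 4, 8, 16)
out, lse, seed, offset = efficient_attention_reference(q, k, v, compute_log_sumexp=True)
```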

@titaiwangms titaiwangms added the module: torchlib Related to the torch/aten function lib in development label Nov 30, 2023
codecov bot commented Nov 30, 2023

Codecov Report

Attention: 30 lines in your changes are missing coverage. Please review.

Comparison is base (649bdff) 78.66% compared to head (05c5b0f) 78.60%.
Report is 1 commit behind head on main.

Files                                                   Patch %   Lines
onnxscript/function_libs/torch_lib/ops/nn.py            30.00%    14 Missing ⚠️
...ript/tests/function_libs/torch_lib/extra_opinfo.py   20.00%    12 Missing ⚠️
onnxscript/evaluator.py                                 50.00%    1 Missing and 1 partial ⚠️
...nxscript/function_libs/torch_lib/graph_building.py   75.00%    1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1197      +/-   ##
==========================================
- Coverage   78.66%   78.60%   -0.07%     
==========================================
  Files         118      118              
  Lines       15445    15473      +28     
  Branches     2428     2431       +3     
==========================================
+ Hits        12150    12162      +12     
- Misses       2897     2915      +18     
+ Partials      398      396       -2     


@titaiwangms (Contributor Author)

@justinchuby I tried to run the tests on dev10, but it seems they are somehow pinned to CPU. Could you point me in the right direction for testing with GPU?

@justinchuby (Collaborator)

Maybe try changing the device here? Apparently we assumed cpu everywhere:

del kwargs

make = opinfo_core.partial(
    opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
(Collaborator)

You may also control the device here.
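
For illustration, a self-contained sketch of the OpInfo-style sample-inputs pattern being discussed; the function name and shapes are hypothetical, and the point is that device= is forwarded rather than hard-coded to CPU:

```python
import functools

import torch
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import SampleInput


def sample_inputs_my_attention_op(op_info, device, dtype, requires_grad, **kwargs):
    # Hypothetical sample-inputs function (the name and shapes are made up).
    # Binding device= in the partial lets the harness generate identical
    # samples on "cpu" or "cuda"; hard-coding device="cpu" here is exactly
    # how tests end up pinned to CPU.
    del op_info, kwargs  # unused, following the convention in extra_opinfo.py
    make = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # query, key, value: (batch, num_heads, seq_len, head_dim).
    yield SampleInput(
        make((2, 4, 8, 16)), args=(make((2, 4, 8, 16)), make((2, 4, 8, 16)))
    )


# Standalone usage: materialize CPU samples; swap "cpu" for "cuda" on a GPU box.
samples = list(sample_inputs_my_attention_op(None, "cpu", torch.float32, False))
```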

(Contributor Author)
I enabled CUDA for the whole test suite wherever a test needs it.

nn_ops.aten__scaled_dot_product_efficient_attention,
trace_only=True,
tolerance={torch.float32: (3e-4, 1.5e-5)},
# Output[0] is OK, but other outputs just have the same shape with zero values
(Collaborator)

Use the other compare option instead

(Collaborator)

compare_shape_only_for_output
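
For context, the updated registration plausibly looks like the fragment below. This is a sketch of an entry in the repo's registration list (not standalone code): TorchLibOpInfo, nn_ops, and the fields quoted above come from the snippet under review, while the opinfo name string and the output indices (1, 2, 3) are assumptions for illustration.

```python
TorchLibOpInfo(
    "ops.aten._scaled_dot_product_efficient_attention",  # name string assumed
    nn_ops.aten__scaled_dot_product_efficient_attention,
    trace_only=True,
    tolerance={torch.float32: (3e-4, 1.5e-5)},
    # Output[0] is OK, but other outputs just have the same shape with
    # zero values, so compare only their shapes (indices assumed).
    compare_shape_only_for_output=(1, 2, 3),
)
```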

(Contributor Author)

Done

@titaiwangms (Contributor Author)

@justinchuby You might want to review again. I changed the function implementation and enabled CUDA tests.

@justinchuby (Collaborator)

LGTM. Thanks!

@titaiwangms titaiwangms merged commit 3c05276 into microsoft:main Dec 1, 2023

Labels

module: torchlib Related to the torch/aten function lib in development

Development

Successfully merging this pull request may close these issues:

[torchlib] Implement aten::_scaled_dot_product_efficient_attention.default
