[Bug] Slice indexing in ONNX #94
The indexing of hidet tensors follows the Array API standard. Thus, we need to handle the difference between ONNX slicing semantics and the Array API standard when importing an ONNX model. Could you please provide some examples that trigger this error? If not, we can leave it for the future until an actual model triggers it.
Yes, a very simple snippet (torch->onnx->hidet) can trigger this error:
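The snippet itself was not preserved in the thread; below is a minimal sketch (assumed, not the reporter's exact code) of the torch->onnx->hidet path that hits the check. It assumes the frontend entry point `hidet.graph.frontend.from_onnx`; slicing to the end of an axis makes the torch exporter emit an ONNX `Slice` whose `ends` is `INT64_MAX`.

```python
import torch
import hidet

class SliceModel(torch.nn.Module):
    def forward(self, x):
        # "slice to the end" is exported as an ONNX Slice with ends = 2**63 - 1
        return x[:, 1:]

torch.onnx.export(SliceModel(), torch.randn(2, 8), 'slice.onnx')

# Importing the model trips hidet's Array-API-style bounds check
# on the huge `ends` value (entry point assumed, see lead-in).
module = hidet.graph.frontend.from_onnx('slice.onnx')
```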
Thanks @soodoshll, working on it.
Thanks @soodoshll, this error should be fixed in #106.
Add a graph module for using flash attention and clarify some differences between flash attention and torch sdpa. **Attention (pun intended):** softmax has a temperature-scaling option, which divides the inputs by a scalar; a good explanation of the numerical effects is [here](https://medium.com/@harshit158/softmax-temperature-5492e4007f71). It is used when the softmax inputs QK are too big for float16 (absolute value > 65504). This usually means the numbers are so large that dividing by a small (< 4) scalar has little effect. Stable Diffusion does not use this, as torch sdpa supports float32 (or somehow avoids NaNs from large values). No visual or significant numeric differences were noticed in this output layer. Towards #57.
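As a minimal sketch of the temperature scaling described above (the function name and the value 4.0 are illustrative, not the module's actual API): dividing the attention scores by a scalar before softmax shrinks their magnitude so intermediate values stay inside the float16 range (max finite value 65504).

```python
import torch

def softmax_with_temperature(scores: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Dividing by `temperature` shrinks the scores; unlike an additive shift
    # (to which softmax is invariant), this changes the output distribution,
    # flattening it as temperature grows.
    return torch.softmax(scores / temperature, dim=-1)

# QK scores large enough that a naive float16 exp() would overflow
qk = (torch.randn(2, 8, 8) * 300.0).to(torch.float16)
attn = softmax_with_temperature(qk, temperature=4.0)
print(attn.sum(dim=-1))  # each row sums to ~1
```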
Please refer to pytorch/pytorch#24251
Basically, ONNX uses extremely large numbers to represent slicing until the end of certain dimensions, which is prohibited by the defensive conditions in
hidet/python/hidet/graph/ops/definitions/transform.py
Line 481 in 80a35d6
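For context, a hypothetical sketch of the normalization an ONNX importer needs here (the function name is illustrative, not hidet's actual code in transform.py): ONNX encodes "slice to the end" as `ends = INT64_MAX`, so `ends` must be clamped to the dimension extent before any Array-API-style bounds check runs.

```python
INT64_MAX = 2**63 - 1

def normalize_slice(start: int, end: int, dim_size: int) -> tuple[int, int]:
    # Map negative indices to positive ones, as in Python/Array API slicing.
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    # Clamp into [0, dim_size]; this turns ONNX's INT64_MAX sentinel for
    # "until the end" into the actual extent of the axis.
    start = max(0, min(start, dim_size))
    end = max(0, min(end, dim_size))
    return start, end

print(normalize_slice(1, INT64_MAX, 8))  # (1, 8): equivalent to x[1:] on a size-8 axis
```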