[torchlib] Implement aten_bilinear function using Einsum #2574
Conversation
Codecov Report: All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##             main    #2574   +/-   ##
=======================================
  Coverage   70.10%   70.10%
=======================================
  Files         222      222
  Lines       26184    26188     +4
  Branches     2581     2582     +1
=======================================
+ Hits        18355    18360     +5
+ Misses       6929     6928     -1
  Partials      900      900
Can you re-express it using Transpose and MatMul? That may produce a graph that is easier to optimize. @copilot

You may not need to flatten for MatMul. It supports N-dim inputs. @copilot
Done! I've re-implemented the function using Transpose and MatMul.

You're absolutely right! I've simplified the implementation by removing the unnecessary flattening and leveraging N-dimensional MatMul support. The new approach uses two clean MatMul operations without creating outer products or flattening tensors, making it much more efficient and optimization-friendly. See commit 095967e.
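For context, here is a minimal PyTorch sketch of the two-MatMul formulation described above; the function name is illustrative and this is not the PR's actual code:

```python
import torch

def bilinear_two_matmul(x1, x2, weight, bias=None):
    # x1: (..., in1), x2: (..., in2), weight: (out, in1, in2)
    # (..., 1, 1, in1) @ (out, in1, in2) broadcasts to (..., out, 1, in2)
    tmp = torch.matmul(x1[..., None, None, :], weight)
    # (..., out, 1, in2) @ (..., 1, in2, 1) -> (..., out, 1, 1)
    y = torch.matmul(tmp, x2[..., None, :, None])
    y = y[..., 0, 0]  # drop the two trailing singleton dims -> (..., out)
    if bias is not None:
        y = y + bias
    return y
```

The unsqueeze/squeeze bookkeeping around the two MatMuls is what makes this variant harder to read, which is consistent with the decision below to revert to the einsum form.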
Yeah, the implementation is too complex. I will re-do this.
Force-pushed 095967e to 946a15c.
I reverted to the original einsum implementation. It seems clean.
This PR implements the `aten_bilinear` function, which previously raised `NotImplementedError`. The bilinear transformation computes `y = x1^T A x2 + b`, where:
- `input1` has shape `(..., in1_features)`
- `input2` has shape `(..., in2_features)`
- `weight` has shape `(out_features, in1_features, in2_features)`
- `bias` has shape `(out_features)` (optional)
- the output has shape `(..., out_features)`
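In index notation (a restatement of the formula above, not text from the PR), the computation is `y[..., o] = sum_{i,j} input1[..., i] * weight[o, i, j] * input2[..., j] + bias[o]`, which maps directly onto an einsum equation such as `...i,oij,...j->...o`.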
Implementation Details
The implementation expresses the bilinear transformation as a single einsum contraction, with the optional bias added afterwards.
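For illustration, a minimal PyTorch sketch of that einsum formulation; the function name and the exact equation string are assumptions for this example, not necessarily the PR's code:

```python
import torch

def bilinear_einsum(x1, x2, weight, bias=None):
    # y[..., o] = sum over i, j of x1[..., i] * weight[o, i, j] * x2[..., j]
    y = torch.einsum("...i,oij,...j->...o", x1, weight, x2)
    if bias is not None:
        y = y + bias
    return y

# Quick check against PyTorch's reference implementation:
x1, x2 = torch.randn(3, 5), torch.randn(3, 4)
w, b = torch.randn(2, 5, 4), torch.randn(2)
expected = torch.nn.functional.bilinear(x1, x2, w, b)
assert torch.allclose(bilinear_einsum(x1, x2, w, b), expected, atol=1e-5)
```

In ONNX terms, this presumably lowers to a single `Einsum` node plus an `Add` for the optional bias.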