[Feature] Metal inverse (mx.linalg.inv) #1238
Also not working in mlx-swift, which is where I'm using it from, but I reproduced it in Python, so I filed it here. The CPU backend appears to work, though.
Yes, this isn't a bug: the GPU back end for this operation is not implemented yet. It's most likely going to take some time before we have GPU support for matrix inversion. I changed this to be a feature request rather than a bug, and we can leave the issue open.
My recommendation is to use the CPU for now. You can do something like `out = mx.linalg.inv(x, stream=mx.cpu)` just for that operation.
By the way, if all you want to do is invert 3x3 matrices, it is much faster to write the inversion out explicitly and compile it with MLX. It would be as simple as the following:

```python
from functools import partial

import mlx.core as mx


# Closed-form 3x3 inverse: the adjugate (transposed cofactors) divided by the
# determinant. shapeless=True avoids recompiling when the batch shape changes,
# which is safe here because the function only uses elementwise ops.
@partial(mx.compile, shapeless=True)
def _inverse_3x3(a11, a12, a13, a21, a22, a23, a31, a32, a33):
    # Determinant via the rule of Sarrus.
    det = (
        a11 * a22 * a33
        + a12 * a23 * a31
        + a13 * a21 * a32
        - a11 * a23 * a32
        - a12 * a21 * a33
        - a13 * a22 * a31
    )
    c11 = (a22 * a33 - a23 * a32) / det
    c12 = (a13 * a32 - a12 * a33) / det
    c13 = (a12 * a23 - a13 * a22) / det
    c21 = (a23 * a31 - a21 * a33) / det
    c22 = (a11 * a33 - a13 * a31) / det
    c23 = (a13 * a21 - a11 * a23) / det
    c31 = (a21 * a32 - a22 * a31) / det
    c32 = (a12 * a31 - a11 * a32) / det
    c33 = (a11 * a22 - a12 * a21) / det
    return c11, c12, c13, c21, c22, c23, c31, c32, c33


def inverse_3x3(A):
    # A has shape (..., 3, 3). Flatten the last two axes into 9 entries,
    # split them into 9 arrays of shape (..., 1), invert, then reassemble.
    shape = A.shape
    return mx.concatenate(
        _inverse_3x3(*mx.split(A.reshape(*shape[:-2], -1), 9, -1)), -1
    ).reshape(shape)
```

For inverting thousands of 3x3 matrices, the improvement over the CPU is pretty great on my puny M2 Air.
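To sanity-check the cofactor formulas without a Mac or a GPU, here is a NumPy port of the same closed-form inverse, compared against `np.linalg.inv`. This is a sketch, not part of MLX; the name `inverse_3x3_np` is hypothetical. Like the MLX version, it divides by the determinant without any guard, so singular inputs produce `inf`/`nan` and nearly singular ones lose accuracy. The test batch is built to be strictly diagonally dominant, which guarantees every matrix is invertible.

```python
import numpy as np


def inverse_3x3_np(A):
    """Closed-form inverse for 3x3 matrices of shape (..., 3, 3).

    Each entry of the inverse is a transposed cofactor divided by the
    determinant (hypothetical NumPy port of the MLX code above).
    """
    a = A.reshape(*A.shape[:-2], 9)
    # Move the 9 entries to the front so they can be unpacked by name.
    a11, a12, a13, a21, a22, a23, a31, a32, a33 = np.moveaxis(a, -1, 0)
    det = (
        a11 * a22 * a33 + a12 * a23 * a31 + a13 * a21 * a32
        - a11 * a23 * a32 - a12 * a21 * a33 - a13 * a22 * a31
    )
    cof = np.stack([
        a22 * a33 - a23 * a32, a13 * a32 - a12 * a33, a12 * a23 - a13 * a22,
        a23 * a31 - a21 * a33, a11 * a33 - a13 * a31, a13 * a21 - a11 * a23,
        a21 * a32 - a22 * a31, a12 * a31 - a11 * a32, a11 * a22 - a12 * a21,
    ], axis=-1)
    return (cof / np.expand_dims(det, -1)).reshape(A.shape)


rng = np.random.default_rng(0)
# Diagonal entries in [3, 5], off-diagonal entries in [-1, 1]: every row is
# strictly diagonally dominant, so each matrix is invertible.
A = 4.0 * np.eye(3) + rng.uniform(-1.0, 1.0, size=(1000, 3, 3))
B = inverse_3x3_np(A)
print("max abs difference vs np.linalg.inv:", np.abs(B - np.linalg.inv(A)).max())
```

The same function also accepts a single `(3, 3)` array, since the batch dimensions are simply empty in that case.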
For a single matrix, using the GPU is obviously overkill, but if you want to do 3x3 matmuls, for instance, writing them out explicitly like I did above may be significantly faster. The same goes for triangle intersection math, etc.
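As an illustration of writing a 3x3 product out explicitly, here is a NumPy sketch (the helper name `matmul_3x3_np` is hypothetical) that computes each entry c_ij = a_i1 b_1j + a_i2 b_2j + a_i3 b_3j as elementwise multiply-adds over the whole batch. An MLX port would presumably reuse the split/concatenate wrapping and `mx.compile` decorator from `inverse_3x3`; whether it actually beats `mx.matmul` on a given machine is an assumption to benchmark, not a given.

```python
import numpy as np


def matmul_3x3_np(A, B):
    """Batched C = A @ B for shapes (..., 3, 3), one expression per entry."""
    # After flattening, index 3 * i + k holds entry (i, k) (0-based).
    a = np.moveaxis(A.reshape(*A.shape[:-2], 9), -1, 0)
    b = np.moveaxis(B.reshape(*B.shape[:-2], 9), -1, 0)
    # c_ij = a_i0 * b_0j + a_i1 * b_1j + a_i2 * b_2j, in row-major order.
    c = [
        a[3 * i] * b[j] + a[3 * i + 1] * b[3 + j] + a[3 * i + 2] * b[6 + j]
        for i in range(3)
        for j in range(3)
    ]
    # Batch dimensions of A and B broadcast just like in np.matmul.
    return np.stack(c, axis=-1).reshape(np.broadcast_shapes(A.shape, B.shape))


rng = np.random.default_rng(1)
A = rng.standard_normal((500, 3, 3))
B = rng.standard_normal((500, 3, 3))
print("max abs difference vs A @ B:", np.abs(matmul_3x3_np(A, B) - A @ B).max())
```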
Describe the bug
When trying to invert a small 3x3 matrix (a camera intrinsics matrix), MLX crashes.
To Reproduce
Expected behavior
Works properly in NumPy.