`numpy.linalg.inv` seems to be significantly slower than a hard-coded version for a simple test case: a stack of 3x3 matrices with shape (1000, 4, 3, 3).
Are there any pointers as to what could be going on, and are there any remedies to make it faster? See the following MWE:
Reproducing code example:
```python
from numpy.linalg import inv as npinv
from numpy import zeros_like, einsum, random
from timeit import timeit

# Helper function for the determinant.
# A has shape (3, 3, ...): matrix indices first, batch dimensions last.
def vdet(A):
    detA = A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]) -\
        A[0, 1] * (A[2, 2] * A[1, 0] - A[2, 0] * A[1, 2]) +\
        A[0, 2] * (A[1, 0] * A[2, 1] - A[2, 0] * A[1, 1])
    return detA

# Another function for computing the inverse
# using the Cayley-Hamilton corollary
def finv(A):
    detA = vdet(A)
    I1 = einsum("ii...", A)                            # trace of A
    I2 = -0.5 * (einsum("ik...,ki...", A, A) - I1**2)  # second invariant
    Asq = einsum("ik...,kj...->ij...", A, A)           # A @ A, batched
    eye = zeros_like(A)
    eye[0, 0] = 1.
    eye[1, 1] = 1.
    eye[2, 2] = 1.
    return 1./detA * (Asq - I1 * A + I2 * eye)

# Hard-coded inverse (adjugate divided by the determinant)
def hdinv(A):
    invA = zeros_like(A)
    detA = vdet(A)
    invA[0, 0] = (-A[1, 2] * A[2, 1] +
                  A[1, 1] * A[2, 2]) / detA
    invA[1, 0] = (A[1, 2] * A[2, 0] -
                  A[1, 0] * A[2, 2]) / detA
    invA[2, 0] = (-A[1, 1] * A[2, 0] +
                  A[1, 0] * A[2, 1]) / detA
    invA[0, 1] = (A[0, 2] * A[2, 1] -
                  A[0, 1] * A[2, 2]) / detA
    invA[1, 1] = (-A[0, 2] * A[2, 0] +
                  A[0, 0] * A[2, 2]) / detA
    invA[2, 1] = (A[0, 1] * A[2, 0] -
                  A[0, 0] * A[2, 1]) / detA
    invA[0, 2] = (-A[0, 2] * A[1, 1] +
                  A[0, 1] * A[1, 2]) / detA
    invA[1, 2] = (A[0, 2] * A[1, 0] -
                  A[0, 0] * A[1, 2]) / detA
    invA[2, 2] = (-A[0, 1] * A[1, 0] +
                  A[0, 0] * A[1, 1]) / detA
    return invA
```

Timings:
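```python
# F has shape (3, 3, 1000, 4) for hdinv/finv; F2 moves the matrix axes
# last, giving the (1000, 4, 3, 3) layout that numpy.linalg.inv expects.
F = random.random((3, 3, 1000, 4))
F2 = einsum("ij...->...ij", F)
# IPython %timeit magic:
%timeit hdinv(F)   # 371 µs ± 27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit finv(F)    # 5.35 ms ± 661 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit npinv(F2)  # 79.2 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```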
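For reference, `finv` relies on the Cayley-Hamilton theorem: a 3x3 matrix $A$ satisfies its own characteristic polynomial, written here with the invariants computed in the code ($I_1$ is `I1`, $I_2$ is `I2`, and $I_3 = \det A$ is `detA`):

$$
A^3 - I_1 A^2 + I_2 A - I_3 I = 0,\qquad I_1 = \operatorname{tr} A,\quad I_2 = \tfrac{1}{2}\left[(\operatorname{tr} A)^2 - \operatorname{tr}(A^2)\right],\quad I_3 = \det A .
$$

Multiplying by $A^{-1}$ (valid when $\det A \neq 0$) and solving for it gives the expression returned by `finv`:

$$
A^{-1} = \frac{1}{\det A}\left(A^2 - I_1 A + I_2 I\right).
$$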
Numpy/Python version information:
Python version: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50) [GCC 7.5.0]
NumPy version: 1.19.0
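A minimal sketch for checking that the Cayley-Hamilton route agrees with `numpy.linalg.inv`, and for timing `numpy.linalg.inv` on the permuted array versus a C-contiguous copy of it. `einsum("ij...->...ij", F)` typically returns a strided, non-contiguous view; whether that layout affects `inv` here is an assumption to test, not a known cause of the slowdown. `det3`, `ch_inv` and `check_and_time` are hypothetical helper names that restate `vdet`/`finv` from the MWE in self-contained form, and no timing results are claimed.

```python
import timeit

import numpy as np


def det3(A):
    """Determinant of a stack of 3x3 matrices with shape (3, 3, ...)."""
    return (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1])
            - A[0, 1] * (A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0])
            + A[0, 2] * (A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]))


def ch_inv(A):
    """Cayley-Hamilton inverse (same idea as finv) for shape (3, 3, ...)."""
    I1 = np.einsum("ii...->...", A)                    # tr(A)
    A2 = np.einsum("ik...,kj...->ij...", A, A)          # A @ A per matrix
    I2 = 0.5 * (I1**2 - np.einsum("ii...->...", A2))   # second invariant
    # Identity broadcast over the trailing batch dimensions.
    eye = np.eye(3).reshape((3, 3) + (1,) * (A.ndim - 2))
    return (A2 - I1 * A + I2 * eye) / det3(A)


def check_and_time(n_outer=1000, n_inner=4, number=10, seed=0):
    """Return (agree, timings) for a random stack shaped like the MWE's F."""
    rng = np.random.default_rng(seed)
    F = rng.random((3, 3, n_outer, n_inner))
    F2 = np.einsum("ij...->...ij", F)   # matrix axes last; usually a strided view
    F2c = np.ascontiguousarray(F2)      # same values, C-contiguous copy

    # Compare against numpy.linalg.inv after moving the matrix axes last.
    ref = np.linalg.inv(F2c)
    agree = np.allclose(np.einsum("ij...->...ij", ch_inv(F)), ref)

    candidates = {
        "ch_inv on (3, 3, ...) layout": lambda: ch_inv(F),
        "np.linalg.inv on einsum view": lambda: np.linalg.inv(F2),
        "np.linalg.inv on contiguous copy": lambda: np.linalg.inv(F2c),
    }
    # Mean seconds per call for each candidate.
    timings = {name: timeit.timeit(func, number=number) / number
               for name, func in candidates.items()}
    return agree, timings


if __name__ == "__main__":
    agree, timings = check_and_time()
    print("results agree:", agree)
    for name, seconds in timings.items():
        print(f"{name}: {seconds * 1e3:.3f} ms per call")
```

If the contiguous copy turns out to be much faster on a given NumPy build, calling `np.ascontiguousarray` before `inv` would be a cheap thing to try; if not, the layout can be ruled out.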