numpy.linalg.inv very slow for a stack of 3x3 arrays #17166

@bhaveshshrimali

numpy.linalg.inv seems to be significantly slower than a hard-coded version for a simple test case: an array of shape (1000, 4, 3, 3), i.e. a stack of 4000 3x3 matrices.
Any pointers as to what could be going wrong, and possible remedies to make it faster, would be appreciated. See the following MWE:

Reproducing code example:

from numpy.linalg import inv as npinv
from numpy import zeros_like, einsum, random
from timeit import timeit

# Helper function for the determinant (cofactor expansion along the first row).
# A has shape (3, 3, ...); the result has the shape of the trailing axes.
def vdet(A):
    detA = (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]) -
            A[0, 1] * (A[2, 2] * A[1, 0] - A[2, 0] * A[1, 2]) +
            A[0, 2] * (A[1, 0] * A[2, 1] - A[2, 0] * A[1, 1]))
    return detA

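# Another function for computing the inverse,
# using the Cayley-Hamilton theorem for 3x3 matrices:
#   A^-1 = (A^2 - I1*A + I2*I) / det(A),
# where I1 = tr(A) and I2 = (tr(A)^2 - tr(A^2)) / 2.
def finv(A):
    detA = vdet(A)
    I1 = einsum("ii...", A)  # trace of each matrix
    I2 = -0.5 * (einsum("ik...,ki...", A, A) - I1**2)  # second invariant
    Asq = einsum("ik...,kj...->ij...", A, A)  # matrix product A A
    eye = zeros_like(A)  # stack of identity matrices
    eye[0, 0] = 1.
    eye[1, 1] = 1.
    eye[2, 2] = 1.
    return 1./detA * (Asq - I1 * A + I2 * eye)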

# Hard-coded inverse: adjugate entries divided by the determinant
def hdinv(A):
    invA = zeros_like(A)
    detA = vdet(A)

    invA[0, 0] = (-A[1, 2] * A[2, 1] +
                  A[1, 1] * A[2, 2]) / detA
    invA[1, 0] = (A[1, 2] * A[2, 0] -
                  A[1, 0] * A[2, 2]) / detA
    invA[2, 0] = (-A[1, 1] * A[2, 0] +
                  A[1, 0] * A[2, 1]) / detA
    invA[0, 1] = (A[0, 2] * A[2, 1] -
                  A[0, 1] * A[2, 2]) / detA
    invA[1, 1] = (-A[0, 2] * A[2, 0] +
                  A[0, 0] * A[2, 2]) / detA
    invA[2, 1] = (A[0, 1] * A[2, 0] -
                  A[0, 0] * A[2, 1]) / detA
    invA[0, 2] = (-A[0, 2] * A[1, 1] +
                  A[0, 1] * A[1, 2]) / detA
    invA[1, 2] = (A[0, 2] * A[1, 0] -
                  A[0, 0] * A[1, 2]) / detA
    invA[2, 2] = (-A[0, 1] * A[1, 0] +
                  A[0, 0] * A[1, 1]) / detA
    return invA
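
Before comparing timings, it may be worth confirming that a closed-form inverse agrees with numpy.linalg.inv on the same data. The following is a minimal sketch, not part of the original MWE: `cross_inv` is a hypothetical helper that builds the same adjugate-over-determinant inverse from row cross products, and the test matrices get 3 added to the diagonal (an assumption made here) so that none of them is close to singular.

```python
import numpy as np

def cross_inv(A):
    """Closed-form inverse of 3x3 matrices stored as shape (3, 3, ...).

    Hypothetical helper: the columns of the inverse are the cross
    products of pairs of rows, divided by the determinant.
    """
    r0, r1, r2 = A[0], A[1], A[2]              # rows, each of shape (3, ...)
    c0 = np.cross(r1, r2, axis=0)
    c1 = np.cross(r2, r0, axis=0)
    c2 = np.cross(r0, r1, axis=0)
    det = np.einsum("i...,i...->...", r0, c0)  # r0 . (r1 x r2)
    # Stack along axis 1 so that c0, c1, c2 become the columns.
    return np.stack([c0, c1, c2], axis=1) / det

rng = np.random.default_rng(0)
A = rng.random((3, 3, 1000, 4))
for i in range(3):
    A[i, i] += 3.0                             # assumption: keep matrices well conditioned

# numpy.linalg.inv expects the matrix axes last: (1000, 4, 3, 3)
A_last = np.moveaxis(A, (0, 1), (-2, -1))
ref = np.moveaxis(np.linalg.inv(A_last), (-2, -1), (0, 1))

max_err = np.abs(cross_inv(A) - ref).max()
print("max abs difference:", max_err)
```

If the two results agree to rounding error, the timing comparison is at least comparing like with like.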

Timings

F = random.random((3,3,1000,4))
F2 = einsum("ij...->...ij", F)

%timeit hdinv(F) # 371 µs ± 27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit finv(F) # 5.35 ms ± 661 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit npinv(F2) # 79.2 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
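
The %timeit lines above need IPython. As a rough sketch for reproducing the numpy.linalg.inv measurement in a plain Python session, the standard-library timeit module can be used instead. The example also times a C-contiguous copy of F2, purely as something that could be checked; whether memory layout contributes to the slowdown is an assumption here, not a confirmed cause. `best_time` is a hypothetical helper name.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
F = rng.random((3, 3, 1000, 4))
F2 = np.einsum("ij...->...ij", F)      # shape (1000, 4, 3, 3)
F2c = np.ascontiguousarray(F2)         # same values, C-contiguous layout

print("F2 C-contiguous: ", F2.flags["C_CONTIGUOUS"])
print("F2c C-contiguous:", F2c.flags["C_CONTIGUOUS"])

def best_time(fn, number=5, repeat=3):
    """Best per-call wall time in seconds over `repeat` batches of `number` calls."""
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number

t_view = best_time(lambda: np.linalg.inv(F2))
t_copy = best_time(lambda: np.linalg.inv(F2c))
print(f"inv(F2):  {t_view * 1e3:.2f} ms per call")
print(f"inv(F2c): {t_copy * 1e3:.2f} ms per call")
```

Taking the minimum over a few batches reduces noise from other processes; the absolute numbers will differ between machines and NumPy builds.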

Numpy/Python version information:

Python version: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50) [GCC 7.5.0]

NumPy version: 1.19.0
