Non unit-stride ndarrays #16

@mratsim

Looking at the code, I'm fairly sure it is buggy for ndarrays with non-unit strides, such as those produced by slicing, reverse slicing or broadcasting:

cython-blis/blis/py.pyx

Lines 64 to 102 in c5df079

def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[1]), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                A.shape[0], B.shape[1], A.shape[1],
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")

Nothing checks that the inputs are row-major (C-contiguous), yet passing &A[0,0], A.shape[1], 1 as pointer, row stride and column stride assumes exactly that layout.
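To make the problem concrete, here is a small NumPy sketch (not code from cython-blis; the helper `element_strides` is a hypothetical name introduced for illustration). It prints the real strides of a few common views next to the `(shape[1], 1)` pair the current code would pass, and then shows which elements get read when a column slice is treated as row-major. Note that NumPy reports strides in bytes, so they are divided by the item size here.

```python
import numpy as np

base = np.arange(12, dtype=np.float32).reshape(3, 4)

def element_strides(a):
    """Return strides in elements (NumPy stores them in bytes)."""
    return tuple(s // a.itemsize for s in a.strides)

views = {
    "contiguous": base,
    "column slice": base[:, ::2],
    "reversed rows": base[::-1],
    "transposed": base.T,
    "broadcast row": np.broadcast_to(base[0], (3, 4)),
}
for name, view in views.items():
    # Actual strides vs. the (shape[1], 1) pair assumed for row-major data.
    print(name, element_strides(view), (view.shape[1], 1))

# Reading the column slice as if it were row-major, starting at &v[0, 0]:
v = base[:, ::2]
flat = base.ravel()
assumed = np.array([[flat[i * v.shape[1] + j] for j in range(v.shape[1])]
                    for i in range(v.shape[0])])
print(v)        # the values the caller meant
print(assumed)  # the values a (shape[1], 1) read would pick up
```

Only the contiguous case matches the assumed strides; every other view would be read from the wrong addresses.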

Instead the code should probably be:

def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    # Buffer strides are byte counts, while BLIS takes row/column strides
    # in elements, so divide by the item size before the call.
    cdef Py_ssize_t size = A.itemsize
    cdef Py_ssize_t rsA = A.strides[0] // size
    cdef Py_ssize_t csA = A.strides[1] // size
    cdef Py_ssize_t rsB = B.strides[0] // size
    cdef Py_ssize_t csB = B.strides[1] // size
    cdef Py_ssize_t rsC, csC
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        rsC = out.strides[0] // size
        csC = out.strides[1] // size
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], rsA, csA,
                &B[0,0], rsB, csB,
                beta,
                C, rsC, csC)
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            # Use the transposition-aware sizes, as in the float branch.
            out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        rsC = out.strides[0] // size
        csC = out.strides[1] // size
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], rsA, csA,
                &B[0,0], rsB, csB,
                beta,
                C, rsC, csC)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")

The same applies to gemv.

This has several advantages:

  • works for any strides
  • faster than the usual OpenBLAS/MKL path, since no copy into a contiguous array is needed.
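The copy mentioned in the second point can be seen directly in NumPy. The snippet below is only an illustration of that cost, assuming the usual approach of calling `np.ascontiguousarray` before handing data to a BLAS that expects a unit stride in one dimension; it does not call any BLAS itself.

```python
import numpy as np

base = np.zeros((4, 6), dtype=np.float64)
view = base[:, ::2]                     # non-unit column stride, no copy
packed = np.ascontiguousarray(view)     # contiguous, but a fresh allocation

print(view.flags["C_CONTIGUOUS"])       # False
print(np.shares_memory(base, view))     # True: the view aliases base
print(np.shares_memory(base, packed))   # False: the data was copied
# An already contiguous array is passed through without copying.
print(np.shares_memory(base, np.ascontiguousarray(base)))  # True
```

Passing the view's own strides to BLIS avoids the extra allocation and the pass over the data that `packed` requires.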

The main draw of the BLIS API is that it supports strided arrays without giving up performance, so this is the perfect use case.
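As a rough model of what "works for any strides" means, here is a pure-Python reference for a stride-aware gemm. It is a sketch, not the BLIS implementation: `gemm_strided` and its explicit offset arguments are hypothetical, chosen to mirror the pointer, row stride, column stride triple used in the calls above, with strides counted in elements.

```python
def gemm_strided(m, n, k, alpha,
                 a, a0, rsa, csa,
                 b, b0, rsb, csb,
                 beta,
                 c, c0, rsc, csc):
    """Compute C = alpha * A @ B + beta * C on flat buffers.

    Element (i, j) of each matrix lives at
    buffer[offset + i * row_stride + j * col_stride].
    """
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[a0 + i * rsa + p * csa] * b[b0 + p * rsb + j * csb]
            idx = c0 + i * rsc + j * csc
            c[idx] = alpha * acc + beta * c[idx]

# A 2x3 row-major buffer holding [[1, 2, 3], [4, 5, 6]].
buf = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
identity = [1.0, 0.0, 0.0, 1.0]  # 2x2, row-major

# Every other column, [[1, 3], [4, 6]]: offset 0, strides (3, 2).
out1 = [0.0] * 4
gemm_strided(2, 2, 2, 1.0, buf, 0, 3, 2, identity, 0, 2, 1, 0.0, out1, 0, 2, 1)

# Same view with rows reversed, [[4, 6], [1, 3]]: offset 3, strides (-3, 2).
out2 = [0.0] * 4
gemm_strided(2, 2, 2, 1.0, buf, 3, -3, 2, identity, 0, 2, 1, 0.0, out2, 0, 2, 1)

print(out1)
print(out2)
```

Both views are multiplied straight out of `buf` with no intermediate copy; only the offset and the two strides change between the calls.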
