Looking into the code, I'm pretty sure it is buggy for ndarrays with non-unit strides, such as those resulting from slicing, reverse slicing or broadcasting:
Lines 64 to 102 in c5df079
def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[1]), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                A.shape[0], B.shape[1], A.shape[1],
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")
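There is no check that the inputs are row-major, yet the arguments &A[0,0], A.shape[1], 1 assume a contiguous row-major layout (row stride equal to the number of columns, column stride equal to 1).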
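To make the failure mode concrete, here is a small pure-Python sketch (not the library's code; the `read` helper and the flat-list layout are made up for illustration). It reads a view equivalent to A[:, ::2] once with its true element strides and once with the strides the current call assumes:

```python
# Hypothetical illustration: a 3x4 row-major matrix stored in a flat list.
buf = list(range(12))  # rows: [0..3], [4..7], [8..11]

def read(buf, offset, shape, rs, cs):
    """Materialise a 2-D view from a flat buffer given element strides."""
    rows, cols = shape
    return [[buf[offset + i * rs + j * cs] for j in range(cols)]
            for i in range(rows)]

# View equivalent to A[:, ::2]: shape (3, 2), true element strides (4, 2).
shape = (3, 2)
correct = read(buf, 0, shape, rs=4, cs=2)

# What the current call assumes: row stride = shape[1], column stride = 1.
assumed = read(buf, 0, shape, rs=shape[1], cs=1)

print(correct)  # [[0, 2], [4, 6], [8, 10]]
print(assumed)  # [[0, 1], [2, 3], [4, 5]]
```

With the assumed strides the routine would silently multiply the wrong elements instead of raising an error.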
Instead the code should probably be:
def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    # Note: .strides is expressed in bytes, while BLIS strides count
    # elements, hence the division by the item size below.
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0],
                A.strides[0] // <Py_ssize_t>sizeof(float),
                A.strides[1] // <Py_ssize_t>sizeof(float),
                &B[0,0],
                B.strides[0] // <Py_ssize_t>sizeof(float),
                B.strides[1] // <Py_ssize_t>sizeof(float),
                beta,
                C,
                out.strides[0] // <Py_ssize_t>sizeof(float),
                out.strides[1] // <Py_ssize_t>sizeof(float))
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                # nM/nN/nK (as in the float branch) so trans1/trans2 are honoured
                nM, nN, nK,
                alpha,
                &A[0,0],
                A.strides[0] // <Py_ssize_t>sizeof(double),
                A.strides[1] // <Py_ssize_t>sizeof(double),
                &B[0,0],
                B.strides[0] // <Py_ssize_t>sizeof(double),
                B.strides[1] // <Py_ssize_t>sizeof(double),
                beta,
                C,
                out.strides[0] // <Py_ssize_t>sizeof(double),
                out.strides[1] // <Py_ssize_t>sizeof(double))
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")
The same applies to gemv.
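One detail to watch when passing strides through: buffer strides in Python (and likewise NumPy and Cython memoryview strides) are reported in bytes, whereas BLIS row and column strides count elements. A minimal stdlib-only sketch of the conversion, where `element_strides` is a hypothetical helper and not part of any library:

```python
from array import array

def element_strides(view):
    """Convert a buffer's byte strides into element strides."""
    return tuple(s // view.itemsize for s in view.strides)

data = memoryview(array('d', range(8)))

every_other = data[::2]     # slicing with a step
reversed_view = data[::-1]  # reverse slicing

# Byte strides next to the element strides a BLIS-style call would need.
print(every_other.strides, element_strides(every_other))
print(reversed_view.strides, element_strides(reversed_view))
```

The reversed view yields a negative element stride, so whether the underlying routine accepts negative (or zero, for broadcast) strides is worth verifying before relying on it.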
This has two advantages:
- it works for any strides
- it is faster than the default OpenBLAS/MKL route for such inputs, as no conversion to a contiguous array is needed.
The main draw of the BLIS API is supporting strided arrays without giving up performance, so this is the perfect use case.
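As a reference for what a general-stride GEMM computes, here is a plain-Python sketch. It is not BLIS's implementation; `gemm_ref` is a made-up name, and its argument order only loosely mirrors the call quoted above. The point is that a transposed operand is expressed purely by swapping its strides, with no copy:

```python
def gemm_ref(m, n, k, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc):
    """C := beta*C + alpha*(A @ B), every operand addressed by element strides."""
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i * rsa + p * csa] * b[p * rsb + j * csb]
            c[i * rsc + j * csc] = beta * c[i * rsc + j * csc] + alpha * acc
    return c

a = [1.0, 2.0, 3.0, 4.0]   # A = [[1, 2], [3, 4]], row-major: strides (2, 1)
bt = [5.0, 6.0, 7.0, 8.0]  # stored as [[5, 6], [7, 8]]
c = [0.0] * 4

# Swapping the strides of bt to (1, 2) views it as B = [[5, 7], [6, 8]].
gemm_ref(2, 2, 2, 1.0, a, 2, 1, bt, 1, 2, 0.0, c, 2, 1)
print(c)  # [17.0, 23.0, 39.0, 53.0]
```

A sliced or strided operand would be handled the same way, by passing its actual element strides rather than assuming (shape[1], 1).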