In [1]:
import cython

In [2]:
%load_ext Cython

In [92]:
%%cython

import cython
cimport numpy as np
import numpy as np
from scipy.linalg.cython_blas cimport dgemm

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[::1, :] mm3(double[::1, :] A, double[::1, :] B, double[::1, :] C):
    """Return the triple matrix product ``A @ B @ C`` via two BLAS ``dgemm`` calls.

    The product is evaluated left to right, ``(A @ B) @ C``, into freshly
    allocated Fortran-ordered buffers; the inputs are never written to.

    Parameters
    ----------
    A : double[::1, :]
        Fortran-contiguous (column-major) matrix of shape (m, k).
    B : double[::1, :]
        Fortran-contiguous matrix of shape (k, n).
    C : double[::1, :]
        Fortran-contiguous matrix of shape (n, p).

    Returns
    -------
    double[::1, :]
        Fortran-contiguous memoryview of shape (m, p). Wrap it with
        ``np.asarray`` to obtain an ndarray.

    Raises
    ------
    ValueError
        If A's column count differs from B's row count, or B's column
        count differs from C's row count.
    """
    # Dimensions, using the naming A: (m, k), B: (k, n), C: (n, p).
    cdef int m = A.shape[0]
    cdef int k = A.shape[1]
    cdef int n = B.shape[1]
    cdef int p = C.shape[1]

    cdef double alpha = 1.0
    cdef double beta = 0.0  # beta == 0: output buffers are overwritten, never read

    cdef double[::1, :] AB
    cdef double[::1, :] ABC

    # Bounds checking is disabled, so a shape mismatch could let dgemm read
    # out of bounds or silently return garbage; reject it up front.
    if B.shape[0] != k:
        raise ValueError(
            "shape mismatch between A and B: A is %d x %d but B is %d x %d"
            % (m, k, B.shape[0], n))
    if C.shape[0] != n:
        raise ValueError(
            "shape mismatch between B and C: B is %d x %d but C is %d x %d"
            % (B.shape[0], n, C.shape[0], p))

    # Degenerate product: the result is an (m, p) matrix of zeros. Skip BLAS,
    # which rejects leading dimensions < 1, and avoid taking &X[0, 0] of an
    # empty view.
    if m == 0 or k == 0 or n == 0 or p == 0:
        return np.zeros((m, p), dtype=np.float64, order='F')

    AB = np.empty((m, n), dtype=np.float64, order='F')
    ABC = np.empty((m, p), dtype=np.float64, order='F')

    # For a Fortran-contiguous (rows x cols) matrix the BLAS leading dimension
    # is its row count: A -> m, B -> k, AB -> m, C -> n, ABC -> m.

    # Step 1: AB = A @ B      (m x k) @ (k x n) -> (m x n)
    dgemm("n", "n", &m, &n, &k, &alpha, &A[0, 0], &m, &B[0, 0], &k,
          &beta, &AB[0, 0], &m)

    # Step 2: ABC = AB @ C    (m x n) @ (n x p) -> (m x p)
    dgemm("n", "n", &m, &p, &n, &alpha, &AB[0, 0], &m, &C[0, 0], &n,
          &beta, &ABC[0, 0], &m)

    return ABC

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[::1, :] mm3_loop(double[::1, :] A, double[::1, :] B, double[::1, :] C, int loops):
    """Benchmark helper: evaluate ``mm3(A, B, C)`` ``loops + 1`` times.

    The results of the first ``loops`` calls are discarded and the result of
    one final call is returned. A non-positive ``loops`` performs just that
    final product.

    Parameters
    ----------
    A, B, C : double[::1, :]
        Fortran-contiguous matrices, passed unchanged to ``mm3``.
    loops : int
        Number of discarded repetitions.

    Returns
    -------
    double[::1, :]
        The value of ``mm3(A, B, C)``.
    """
    # Typed index: the loop compiles to a plain C loop instead of boxing a
    # Python int on every iteration.
    cdef int i
    for i in range(loops):
        mm3(A, B, C)

    return mm3(A, B, C)

In file included from /Users/nielsota/.ipython/cython/_cython_magic_4708d4c54aeb17eec46d14da12811b0a.c:750:
In file included from /Users/nielsota/Downloads/Niels/Academic/Code/CythonMethods/venv/lib/python3.10/site-packages/numpy/core/include/numpy/arrayobject.h:5:
In file included from /Users/nielsota/Downloads/Niels/Academic/Code/CythonMethods/venv/lib/python3.10/site-packages/numpy/core/include/numpy/ndarrayobject.h:12:
In file included from /Users/nielsota/Downloads/Niels/Academic/Code/CythonMethods/venv/lib/python3.10/site-packages/numpy/core/include/numpy/ndarraytypes.h:1948:
 ^


In [79]:
def mm3_loop_python(A, B, C, loops):
    """Pure-NumPy counterpart of ``mm3_loop``, used as a benchmark baseline.

    Computes ``(A @ B) @ C`` with ``np.matmul`` ``loops + 1`` times, discarding
    every result except the last, which is returned.

    Parameters
    ----------
    A, B, C : array_like
        Matrices with compatible shapes (m, k), (k, n) and (n, p).
    loops : int
        Number of discarded repetitions; a non-positive value performs only
        the final product.

    Returns
    -------
    numpy.ndarray
        The (m, p) product ``A @ B @ C``.
    """
    # Local import: don't depend on a notebook-level `np` (the only visible
    # `import numpy as np` in this notebook is inside the %%cython cell).
    import numpy as np

    # Use the arguments, not notebook globals, and associate as (A @ B) @ C,
    # the same order mm3 uses, so both benchmarks perform the same work.
    for _ in range(loops):
        np.matmul(np.matmul(A, B), C)

    return np.matmul(np.matmul(A, B), C)

In [90]:
# Test inputs: N x N column-major (Fortran-order) float64 matrices, the layout
# mm3 requires (its memoryviews are declared `double[::1, :]`).
# NOTE: these are square *diagonal* matrices, so all products commute and all
# dimensions are equal; they cannot reveal shape or leading-dimension bugs.
N = 10  # matrix dimension used by every benchmark below
A1 = 2 * np.eye(N, dtype=np.float64, order='F')
A2 = np.eye(N, dtype=np.float64, order='F')
A3 = 6 * np.eye(N, dtype=np.float64, order='F')

In [72]:
# Sanity check: with A1 = 2*I, A2 = I, A3 = 6*I the expected result is 12*I.
# np.asarray wraps the memoryview returned by mm3 as an ndarray for display.
np.asarray(mm3(A1, A2, A3))

array([[12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0., 12.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0., 12.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0., 12.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0., 12.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0., 12.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 12.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 12.]])

In [73]:
# Per-call latency of the Cython/BLAS triple product. At 10x10 the call
# overhead (argument conversion, output allocation) likely dominates.
%timeit mm3(A1, A2, A3)

1.78 µs ± 3.88 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [68]:
# NumPy baseline. NOTE: this evaluates A1 @ (A2 @ A3), whereas mm3 computes
# (A1 @ A2) @ A3 — same result here, but not the same sequence of operations.
%timeit np.matmul(A1, np.matmul(A2, A3))

1.53 µs ± 1.87 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [95]:
# Amortized comparison: mm3_loop performs 1,000,000 + 1 triple products per
# call, with the loop itself running in compiled code.
%timeit mm3_loop(A1, A2, A3, 1000000)

1.17 s ± 8.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [96]:
# Pure-NumPy loop over the same number of products, for comparison with the
# mm3_loop timing above.
%timeit mm3_loop_python(A1, A2, A3, 1000000)

1.49 s ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
