# Strassen matrix multiplication algorithm

In [1]:
import numpy as np
from numba import njit

### Test data

In [2]:
n = 1024
# Two independent n-by-n matrices with standard-normal entries (A is drawn first, then B).
A, B = (np.random.randn(n, n) for _ in range(2))

### Baseline benchmark: NumPy matrix multiplication

In [3]:
# Baseline: time NumPy's built-in matmul operator on the test matrices.
%timeit A @ B

10 loops, best of 3: 21.3 ms per loop


### Naive IKJ multiplication

In [4]:
@njit
def matrix_multiply(A, B):
    """Multiply two matrices with a triple loop in i-k-j order.

    The innermost loop runs over ``j``, the last index of both ``B`` and
    ``C``, so for C-ordered (row-major) inputs it walks contiguous memory.

    Parameters
    ----------
    A : 2-D array of shape (n, m)
    B : 2-D array of shape (m, p)

    Returns
    -------
    C : new 2-D array of shape (n, p) with the dtype of ``A``.

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ. The check matters
        because Numba does not bounds-check indexing by default, so a
        mismatch would otherwise read past the end of the arrays.
    """
    n, m = A.shape
    if B.shape[0] != m:
        raise ValueError("matrix_multiply: inner dimensions of A and B do not match")
    p = B.shape[1]
    C = np.zeros((n, p), dtype=A.dtype)

    for i in range(n):
        for k in range(m):
            # A[i, k] is invariant across the inner loop; read it once.
            a_ik = A[i, k]
            for j in range(p):
                C[i, j] += a_ik * B[k, j]
    return C

In [5]:
# Time the Numba-compiled naive IKJ kernel on the same inputs as the NumPy baseline.
%timeit matrix_multiply(A, B)

1 loop, best of 3: 1.05 s per loop


In [6]:
# Correctness check: the naive kernel should agree with NumPy's result up to floating-point tolerance.
np.allclose(A @ B, matrix_multiply(A, B))

True

### Naive Strassen

In [7]:
@njit
def matrix_merge(A, B, C, D):
    """Assemble four square blocks into one matrix laid out as [[A, B], [C, D]].

    The block size ``h`` is taken from ``A.shape[0]``. Only the leading
    (h, h) part of each block is copied; each block is assumed to be at
    least that large (not checked). The result is a new float64 array of
    shape (2h, 2h).
    """
    h = A.shape[0]
    merged = np.empty(shape=(2 * h, 2 * h))

    # Top row of blocks.
    merged[:h, :h] = A[:h, :h]
    merged[:h, h:] = B[:h, :h]
    # Bottom row of blocks.
    merged[h:, :h] = C[:h, :h]
    merged[h:, h:] = D[:h, :h]

    return merged


@njit
def matrix_split(M):
    """Split a matrix into its four quadrants (top-left, top-right, bottom-left, bottom-right).

    With ``h = M.shape[0] // 2``, the quadrants are copied into new float64
    arrays of shape (h, h):
    ``M[:h, :h]``, ``M[:h, h:2h]``, ``M[h:2h, :h]`` and ``M[h:2h, h:2h]``.

    If ``M`` has an odd size, its last row and column are dropped. Callers
    that need an exact decomposition must pass even-sized matrices.

    ``parallel=True`` was removed. This function has no ``prange`` loop, so
    Numba could not parallelize anything and only emitted a performance
    warning. The copy is memory-bound anyway.
    """
    h = M.shape[0] // 2

    A = np.empty(shape=(h, h))
    B = np.empty_like(A)
    C = np.empty_like(A)
    D = np.empty_like(A)

    # Copy each quadrant into its own freshly allocated float64 buffer.
    A[:, :] = M[:h, :h]
    B[:, :] = M[:h, h:2 * h]
    C[:, :] = M[h:2 * h, :h]
    D[:, :] = M[h:2 * h, h:2 * h]

    return A, B, C, D

@njit
def matrix_add(A, B):
    """Return a new float64 matrix holding the element-wise sum A + B.

    The output is square with side ``A.shape[0]``. ``B`` is assumed to
    have at least that shape (not checked).
    """
    size = A.shape[0]
    total = np.empty(shape=(size, size))

    for row in range(size):
        for col in range(size):
            total[row, col] = A[row, col] + B[row, col]

    return total


@njit
def matrix_subtract(A, B):
    """Return a new float64 matrix holding the element-wise difference A - B.

    The output is square with side ``A.shape[0]``. ``B`` is assumed to
    have at least that shape (not checked).
    """
    size = A.shape[0]
    diff = np.empty(shape=(size, size))

    for row in range(size):
        for col in range(size):
            diff[row, col] = A[row, col] - B[row, col]

    return diff


@njit
def strassen_naive(A, B):
    """Multiply two square matrices of equal size with Strassen's algorithm.

    Each call works in three steps:
    1. Split ``A`` and ``B`` into quadrants.
    2. Form the ten helper sums/differences S1..S10.
    3. Compute the seven products P1..P7.

    The products are computed recursively while the current size ``m``
    exceeds ``cutoff``, and with the naive ``matrix_multiply`` otherwise.

    A size that cannot be halved exactly (odd, or smaller than 2) is
    multiplied directly with ``matrix_multiply``. Quadrant splitting would
    otherwise drop the last row/column and return a wrongly-shaped result.

    Parameters
    ----------
    A, B : square arrays of the same size (m, m). Not checked.

    Returns
    -------
    A new float64 array of shape (m, m) approximating ``A @ B``.
    """
    m = A.shape[0]
    cutoff = 256

    if m < 2 or m % 2 == 1:
        # Cast to float64 so this path returns the same array type as the
        # matrix_merge path below; Numba must unify all return types.
        return matrix_multiply(A.astype(np.float64), B.astype(np.float64))

    A11, A12, A21, A22 = matrix_split(A)
    B11, B12, B21, B22 = matrix_split(B)
    S1 = matrix_subtract(B12, B22)
    S2 = matrix_add(A11, A12)
    S3 = matrix_add(A21, A22)
    S4 = matrix_subtract(B21, B11)
    S5 = matrix_add(A11, A22)
    S6 = matrix_add(B11, B22)
    S7 = matrix_subtract(A12, A22)
    S8 = matrix_add(B21, B22)
    S9 = matrix_subtract(A11, A21)
    S10 = matrix_add(B11, B12)

    if m > cutoff:
        P1 = strassen_naive(A11, S1)
        P2 = strassen_naive(S2, B22)
        P3 = strassen_naive(S3, B11)
        P4 = strassen_naive(A22, S4)
        P5 = strassen_naive(S5, S6)
        P6 = strassen_naive(S7, S8)
        P7 = strassen_naive(S9, S10)
    else:
        P1 = matrix_multiply(A11, S1)
        P2 = matrix_multiply(S2, B22)
        P3 = matrix_multiply(S3, B11)
        P4 = matrix_multiply(A22, S4)
        P5 = matrix_multiply(S5, S6)
        P6 = matrix_multiply(S7, S8)
        P7 = matrix_multiply(S9, S10)

    # C11 = P5 + P6 + P4 - P2, C12 = P1 + P2,
    # C21 = P3 + P4,           C22 = P5 + P1 - P3 - P7
    C11 = matrix_add(matrix_add(P5, P6), matrix_subtract(P4, P2))
    C12 = matrix_add(P1, P2)
    C21 = matrix_add(P3, P4)
    C22 = matrix_subtract(matrix_add(P5, P1), matrix_add(P3, P7))

    return matrix_merge(C11, C12, C21, C22)


In [11]:
# Time the Strassen implementation; compare with the naive IKJ and NumPy timings above.
%timeit strassen_naive(A, B)

1 loop, best of 3: 225 ms per loop


In [9]:
# Correctness check: Strassen's result should agree with NumPy's matmul up to floating-point tolerance.
np.allclose(A @ B, strassen_naive(A, B))

True