In [1]:
import numpy as np

In [2]:
def split(matrix):
    """Split a 2-D array into its four quadrants.

    Returns ``(top_left, top_right, bottom_left, bottom_right)``. The split
    point is ``dim // 2`` on each axis, so for odd sizes the bottom/right
    quadrants are one row/column larger than the top/left ones. The returned
    quadrants are NumPy views, not copies.
    """
    row, col = matrix.shape
    row2, col2 = row // 2, col // 2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]


def _strassen_square(matrix_a, matrix_b):
    """Multiply two n x n arrays with Strassen's algorithm; n must be a power of two."""
    # Base case: a 1x1 product is a single scalar multiplication.
    if matrix_a.shape[0] == 1:
        return matrix_a * matrix_b

    # n is a power of two, so every quadrant is exactly (n/2) x (n/2).
    a11, a12, a21, a22 = split(matrix_a)
    b11, b12, b21, b22 = split(matrix_b)

    # Strassen's seven recursive products (instead of the naive eight).
    p1 = _strassen_square(a11 + a22, b11 + b22)
    p2 = _strassen_square(a21 + a22, b11)
    p3 = _strassen_square(a11, b12 - b22)
    p4 = _strassen_square(a22, b21 - b11)
    p5 = _strassen_square(a11 + a12, b22)
    p6 = _strassen_square(a21 - a11, b11 + b12)
    p7 = _strassen_square(a12 - a22, b21 + b22)

    # Recombine the products into the four quadrants of the result.
    c11 = p1 + p4 - p5 + p7
    c12 = p3 + p5
    c21 = p2 + p4
    c22 = p1 - p2 + p3 + p6

    return np.block([[c11, c12], [c21, c22]])


def strassen(matrix_a, matrix_b):
    """Return the matrix product ``matrix_a @ matrix_b`` using Strassen's algorithm.

    Any aligned pair of 2-D inputs is accepted (shape ``(m, k)`` times
    ``(k, n)``), not only square power-of-two matrices. Both operands are
    zero-padded to a common power-of-two size, multiplied recursively, and
    the ``(m, n)`` result is cropped back out. Zero padding does not change
    the product.

    Args:
        matrix_a: Left operand; anything ``np.asarray`` turns into a 2-D array.
        matrix_b: Right operand; anything ``np.asarray`` turns into a 2-D array.

    Returns:
        A new ``(m, n)`` ndarray whose dtype is ``np.result_type`` of the two
        inputs. For floating-point inputs, the result may differ from
        ``a @ b`` in the last bits because the operations are ordered
        differently.

    Raises:
        ValueError: If either input is not 2-D or the inner dimensions differ.
    """
    a = np.asarray(matrix_a)
    b = np.asarray(matrix_b)
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("strassen expects two 2-D matrices")
    rows, inner = a.shape
    inner_b, cols = b.shape
    if inner != inner_b:
        raise ValueError(f"shapes {a.shape} and {b.shape} are not aligned")

    dtype = np.result_type(a, b)
    if 0 in (rows, inner, cols):
        # Empty product: nothing to recurse on (a 0x0 split would never
        # reach the 1x1 base case).
        return np.zeros((rows, cols), dtype=dtype)

    # Smallest power of two that fits every dimension, so each recursive
    # split yields equally sized quadrants.
    size = 1 << (max(rows, inner, cols) - 1).bit_length()
    a_pad = np.zeros((size, size), dtype=dtype)
    b_pad = np.zeros((size, size), dtype=dtype)
    a_pad[:rows, :inner] = a
    b_pad[:inner, :cols] = b

    return _strassen_square(a_pad, b_pad)[:rows, :cols]

# Test the implementation: multiply two 4x4 matrices and check the result
# against NumPy's built-in matrix product.
A = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [9, 10, 11, 12],
              [13, 14, 15, 16]])

B = np.array([[17, 18, 19, 20],
              [21, 22, 23, 24],
              [25, 26, 27, 28],
              [29, 30, 31, 32]])

C = strassen(A, B)
print(C)

# Raises AssertionError with a diff if Strassen disagrees with A @ B.
np.testing.assert_array_equal(C, A @ B)

[[ 250  260  270  280]
 [ 618  644  670  696]
 [ 986 1028 1070 1112]
 [1354 1412 1470 1528]]
