David Kercher, Nov 12 2021

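Matrix multiplication using Strassen's algorithm for faster computation. Each level of recursion forms the product of two $2 \times 2$ block matrices from 7 half-size multiplications instead of 8. Multiplying two $n \times n$ matrices therefore costs $O(n^{\log_2 7}) \approx O(n^{2.81})$ scalar multiplications rather than $O(n^3)$. The speed-up is asymptotic; for small matrices the recursion overhead dominates.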

In [66]:
import numpy as np

def sm(A, B):
    """Multiply two matrices with Strassen's algorithm.

    Each operand is split into four equal quadrants

        A = [[A_1, A_2],      B = [[B_1, B_2],
             [A_3, A_4]]           [B_3, B_4]]

    and the product is assembled from 7 half-size multiplications instead of
    the 8 used by naive block multiplication, for O(n**log2(7)) ~ O(n**2.81)
    scalar multiplications.

    Parameters
    ----------
    A : array_like, shape (m, k)
    B : array_like, shape (k, p)

    Returns
    -------
    numpy.ndarray, shape (m, p)
        The product ``A @ B`` (up to floating-point rounding).

    Raises
    ------
    ValueError
        If either operand is not 2-D or the inner dimensions do not match.

    Notes
    -----
    Operands that are not square with a power-of-two size are zero-padded up
    to the next power of two; the padding is sliced off the result. Zero
    padding does not change the top-left (m, p) block of the product.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("sm expects two 2-D matrices")
    m, k = A.shape
    k_b, p = B.shape
    if k != k_b:
        raise ValueError(f"inner dimensions do not match: {A.shape} @ {B.shape}")

    # Smallest power of two that holds every dimension (at least 1).
    n = 1 << (max(m, k, p, 1) - 1).bit_length()
    A_pad = np.pad(A, ((0, n - m), (0, n - k)))
    B_pad = np.pad(B, ((0, n - k), (0, n - p)))

    return _strassen(A_pad, B_pad)[:m, :p]


def _strassen(A, B):
    """Recursive Strassen step for square n x n ndarrays, n a power of two."""
    n = A.shape[0]
    if n == 1:
        # 1 x 1 product; elementwise multiply keeps the 2-D shape and dtype.
        return A * B

    h = n // 2
    A_1, A_2, A_3, A_4 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B_1, B_2, B_3, B_4 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]

    # Each product is computed exactly once and reused below; evaluating any
    # of them more than once would raise the branching factor above 7 and
    # lose Strassen's asymptotic advantage.
    M_1 = _strassen(A_1 + A_4, B_1 + B_4)
    M_2 = _strassen(A_3 + A_4, B_1)
    M_3 = _strassen(A_1, B_2 - B_4)
    M_4 = _strassen(A_4, B_3 - B_1)
    M_5 = _strassen(A_1 + A_2, B_4)
    M_6 = _strassen(A_3 - A_1, B_1 + B_2)
    M_7 = _strassen(A_2 - A_4, B_3 + B_4)

    return np.block([
        [M_1 + M_4 - M_5 + M_7, M_3 + M_5],
        [M_2 + M_4, M_1 - M_2 + M_3 + M_6],
    ])
    
# returns A_1, A_2, A_3, A_4
def sub_matrices(A):
    """Split a matrix into its four equal quadrants.

    Returns ``(A_1, A_2, A_3, A_4)`` laid out as

        [[A_1, A_2],
         [A_3, A_4]]

    i.e. top-left, top-right, bottom-left, bottom-right.

    Raises
    ------
    ValueError
        If the number of rows or columns is odd (``np.vsplit``/``np.hsplit``
        require an equal division).
    """
    # Split each axis once instead of re-running vsplit for every quadrant.
    top, bottom = np.vsplit(A, 2)
    A_1, A_2 = np.hsplit(top, 2)
    A_3, A_4 = np.hsplit(bottom, 2)
    return A_1, A_2, A_3, A_4

In [68]:
# Numerical evidence that the program works: compare the Strassen product
# against numpy's built-in matrix product on random input.

SEED = 0  # fixed seed so "Restart & Run All" reproduces the printed output
N = 32    # 2^5 x 2^5 random matrices

rng = np.random.default_rng(SEED)
A = rng.random((N, N))
B = rng.random((N, N))

# Each product is computed once and reused for the error, the check and the printout.
AB_strassen = sm(A, B)
AB_numpy = A @ B

# Frobenius norm of the difference between Strassen's result and numpy's.
print("Approximate Error: ", np.linalg.norm(AB_strassen - AB_numpy))
np.testing.assert_allclose(AB_strassen, AB_numpy)

print("AB using Strassens: ")
print(AB_strassen)
print("AB using numpy matrix product: ")
print(AB_numpy)

Approximate Error:  4.359526761800664e-13
AB using Strassens: 
[[8.24791855 6.67559104 9.3531669  ... 7.26695202 7.80138361 8.59493935]
 [8.18864659 7.21354252 8.96482375 ... 7.49994067 8.92178546 8.46522131]
 [7.83414183 7.69057819 8.65345902 ... 7.09272127 8.59392466 8.89903197]
 ...
 [8.67810122 6.60079573 8.5007169  ... 6.88216332 8.52965373 8.35958002]
 [7.34983747 7.6320347  8.03917518 ... 6.39931948 8.88188757 8.18935908]
 [8.48045171 7.98437571 9.50635482 ... 9.04513987 8.47587778 7.99269994]]
AB using numpy matrix product: 
[[8.24791855 6.67559104 9.3531669  ... 7.26695202 7.80138361 8.59493935]
 [8.18864659 7.21354252 8.96482375 ... 7.49994067 8.92178546 8.46522131]
 [7.83414183 7.69057819 8.65345902 ... 7.09272127 8.59392466 8.89903197]
 ...
 [8.67810122 6.60079573 8.5007169  ... 6.88216332 8.52965373 8.35958002]
 [7.34983747 7.6320347  8.03917518 ... 6.39931948 8.88188757 8.18935908]
 [8.48045171 7.98437571 9.50635482 ... 9.04513987 8.47587778 7.99269994]]
