In [1]:
from matrix import *
from random import random
# Implementation of the basic Strassen matrix multiplication
def isPwr2(x: int) -> bool:
    """Return True if the integer ``x`` is a positive power of two (1, 2, 4, 8, ...).

    A positive power of two has exactly one set bit in binary, e.g. 16 = 0b10000.
    Subtracting 1 clears that bit and sets every lower bit: 15 = 0b01111.
    The bitwise AND ``x & (x - 1)`` is therefore 0 exactly for powers of two:
    16 & 15 = 0b10000 & 0b01111 = 0.
    It is also 0 for x == 0, which is not a power of two, so non-positive
    values are rejected explicitly.
    """
    return x > 0 and not (x & (x - 1))

def strassen_matrix_mult(A: Matrix, B: Matrix, threshold: int = 32) -> Matrix:
    """Multiply two square matrices with Strassen's algorithm.

    At each level of recursion both operands are split into four quadrants.
    Seven recursive products (P1..P7) are combined, instead of the eight the
    schoolbook method needs.
    Once the inner dimension drops below ``threshold``, the product is computed
    with ``gauss_matrix_mult`` instead, which also handles operands of any shape.

    Args:
        A: left operand.
        B: right operand; must have as many rows as A has columns.
        threshold: products whose inner dimension is smaller than this value
            are computed directly with ``gauss_matrix_mult`` (default 32).

    Returns:
        A new Matrix holding the product A * B.

    Raises:
        ValueError: if A's column count differs from B's row count.
        NotImplementedError: if recursion is required but A and B are not
            square, or their size cannot be halved exactly down to the base
            case (a power of 2 always can). Use GEN_strassen_matrix_mult for
            those shapes.
    """
    if A.num_of_cols != B.num_of_rows:
        raise ValueError("Wrong matrix shape: number of columns of A is %d, number of rows of B is %d"
                         % (A.num_of_cols, B.num_of_rows))

    n = A.num_of_cols

    # Base case: the classical product is cheaper on small operands and works
    # for any shape. The `n < 2` guard stops recursion on 1x1 blocks even if
    # threshold <= 1.
    if n < threshold or n < 2:
        return gauss_matrix_mult(A, B)

    # The quadrant split below assumes n x n operands whose size stays even at
    # every level until the base case; otherwise rows/columns would be dropped
    # and the result silently wrong.
    if A.num_of_rows != n or B.num_of_cols != n or not _halves_evenly(n, threshold):
        raise NotImplementedError(
            "This implementation only handles products of SQUARE matrices whose size "
            "halves exactly down to the base case (e.g. a power of 2); "
            "use GEN_strassen_matrix_mult instead")

    # Quadrant subdivision
    n_half = n // 2
    A11, A12, A21, A22 = _quadrants(A, n_half)
    B11, B12, B21, B22 = _quadrants(B, n_half)

    # The ten sums/differences feeding the seven products
    S1 = B12 - B22
    S2 = A11 + A12
    S3 = A21 + A22
    S4 = B21 - B11
    S5 = A11 + A22
    S6 = B11 + B22
    S7 = A12 - A22
    S8 = B21 + B22
    S9 = A11 - A21
    S10 = B11 + B12

    P1 = strassen_matrix_mult(A11, S1, threshold)
    P2 = strassen_matrix_mult(S2, B22, threshold)
    P3 = strassen_matrix_mult(S3, B11, threshold)
    P4 = strassen_matrix_mult(A22, S4, threshold)
    P5 = strassen_matrix_mult(S5, S6, threshold)
    P6 = strassen_matrix_mult(S7, S8, threshold)
    P7 = strassen_matrix_mult(S9, S10, threshold)

    # Recombine the products into the quadrants of the result
    C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7

    C = Matrix([[0 for j in range(n)] for i in range(n)])

    C.assign_submatrix(0, 0, C11)
    C.assign_submatrix(n_half, 0, C21)
    C.assign_submatrix(0, n_half, C12)
    C.assign_submatrix(n_half, n_half, C22)

    return C


def _halves_evenly(n: int, threshold: int) -> bool:
    """Return True if n can be halved exactly until it drops below the base case."""
    while n >= max(threshold, 2):
        if n % 2:
            return False
        n //= 2
    return True


def _quadrants(M: Matrix, n_half: int) -> tuple:
    """Split a (2*n_half) x (2*n_half) matrix into its (11, 12, 21, 22) quadrants."""
    # submatrix arguments, as used throughout this notebook:
    # (first_row, num_rows, first_col, num_cols)
    return (M.submatrix(0, n_half, 0, n_half),
            M.submatrix(0, n_half, n_half, n_half),
            M.submatrix(n_half, n_half, 0, n_half),
            M.submatrix(n_half, n_half, n_half, n_half))

"""
Strassen's algorithm relies on quadrant sums such as B12 + B22. If a matrix has an odd number
of rows (or columns), it cannot be split into four equal quadrants, so these sums are not defined.

First attempt: full padding. The idea is to pad both matrices with zeros up to a square of size 2^n
(the smallest power of two not smaller than any of their dimensions), apply classic Strassen, and
finally cut the result back out of the padded product.
"""

def GEN_strassen_matrix_mult(A: Matrix, B: Matrix) -> Matrix:
    """Multiply matrices of any compatible shape with Strassen's algorithm via zero padding.

    Both operands are copied into the top-left corner of n_tot x n_tot zero
    matrices. Here n_tot is the smallest power of two that is at least every
    dimension of A and B.
    The padded product is computed with ``strassen_matrix_mult``. Its top-left
    A.num_of_rows x B.num_of_cols block is returned; this is the true product,
    because the zero padding contributes nothing to it.

    Args:
        A: left operand.
        B: right operand; must have as many rows as A has columns.

    Returns:
        A new Matrix holding the product A * B.

    Raises:
        ValueError: if A's column count differs from B's row count.
    """
    if A.num_of_cols != B.num_of_rows:
        raise ValueError("Wrong matrix shape: number of columns of A is %d, number of rows of B is %d"
                         % (A.num_of_cols, B.num_of_rows))

    n = A.num_of_cols

    if A.num_of_rows == n and B.num_of_cols == n and isPwr2(n):
        # Already the shape classical Strassen expects: no padding needed.
        return strassen_matrix_mult(A, B)

    # Base case: with an inner dimension below 2 padding would only add overhead.
    if n < 2:
        return gauss_matrix_mult(A, B)

    # Smallest power of two covering every dimension of both operands.
    n_max = max(A.num_of_rows, n, B.num_of_cols)
    n_tot = 1
    while n_tot < n_max:
        n_tot *= 2

    Ap = Matrix([[0 for j in range(n_tot)] for i in range(n_tot)])
    Bp = Matrix([[0 for j in range(n_tot)] for i in range(n_tot)])
    Ap.assign_submatrix(0, 0, A)
    Bp.assign_submatrix(0, 0, B)

    # The padded operands are square with a power-of-two size, which is
    # exactly what classical Strassen expects.
    C = strassen_matrix_mult(Ap, Bp)

    # Cut the true product out of the padded result.
    return C.submatrix(0, A.num_of_rows, 0, B.num_of_cols)




In [2]:
# Small 2x2 operands with uniformly random entries, handy for quick manual checks.
nr = nc = 2

A = Matrix([[random() for _ in range(nc)] for _ in range(nr)])
B = Matrix([[random() for _ in range(nc)] for _ in range(nr)])

In [None]:
 
from time import perf_counter

# Benchmark square products of size 2**MIN_EXP .. 2**MAX_EXP.
# Both implementations are pure Python, and in the recorded timings each
# doubling of the size costs roughly 7x more time. The original upper bound
# of 2**12 would therefore run for hours; the recorded run stopped after the
# 512 Strassen timing.
# Raise MAX_EXP deliberately if larger sizes are really needed.
MIN_EXP = 2
MAX_EXP = 9

for n in range(MIN_EXP, MAX_EXP + 1):
    size = 2**n
    print("n = %d" % size)
    A = Matrix([[random() for j in range(size)] for i in range(size)])
    B = Matrix([[random() for j in range(size)] for i in range(size)])

    t0 = perf_counter()
    c = strassen_matrix_mult(A, B)
    t1 = perf_counter()
    print("strassen elapsed: %.4f" % (t1 - t0))

    t0 = perf_counter()
    c = gauss_matrix_mult(A, B)
    t1 = perf_counter()
    print("gauss elapsed: %.4f" % (t1 - t0))
    print("-------")
    
   



n = 4
strassen elapsed: 0.0001
gauss elapsed: 0.0000
-------
n = 8
strassen elapsed: 0.0003
gauss elapsed: 0.0003
-------
n = 16
strassen elapsed: 0.0026
gauss elapsed: 0.0024
-------
n = 32
strassen elapsed: 0.0182
gauss elapsed: 0.0178
-------
n = 64
strassen elapsed: 0.1264
gauss elapsed: 0.1014
-------
n = 128
strassen elapsed: 0.7064
gauss elapsed: 0.7827
-------
n = 256
strassen elapsed: 4.7643
gauss elapsed: 5.8855
-------
n = 512
strassen elapsed: 33.5261


In [None]:
# Correctness check on a non-square product: (11 x 10) @ (10 x 10).
# The padded Strassen result should match the classical product up to
# floating-point rounding, so every entry of the difference should be close to 0.
nrC = 11
ncC = 10
nrD = ncC
ncD = 10

C = Matrix([[random() for j in range(ncC)] for i in range(nrC)])
D = Matrix([[random() for j in range(ncD)] for i in range(nrD)])

c3 = gauss_matrix_mult(C, D)
c4 = GEN_strassen_matrix_mult(C, D)
c3 - c4

In [None]:
# Scratch 8x5 matrix whose every row is [0, 1, 2, 3, 4].
C = Matrix([list(range(5)) for _ in range(8)])

In [None]:
# Scratch check of Matrix indexing on the matrix built in the previous cell.
# Presumably C[1] is row 1 and [0] is its first column — confirm against Matrix.__getitem__.
C[1][0]

In [None]:
from time import perf_counter

# Non-square benchmark: C is (2**n + 1) x 2**n and D is 2**n x 2**n. This is
# the worst case for padding: the single extra row makes
# GEN_strassen_matrix_mult pad everything up to 2**(n + 1).
# As in the square benchmark, each step costs several times the previous one,
# so the original upper bound of 2**12 (padded size 8192) is impractical in
# pure Python. Raise PAD_MAX_EXP deliberately if needed.
PAD_MIN_EXP = 2
PAD_MAX_EXP = 8

for n in range(PAD_MIN_EXP, PAD_MAX_EXP + 1):
    print("n = %d" % 2**n)
    nrC = 2**n + 1
    ncC = 2**n
    nrD = ncC
    ncD = 2**n

    C = Matrix([[random() for j in range(ncC)] for i in range(nrC)])
    D = Matrix([[random() for j in range(ncD)] for i in range(nrD)])

    t0 = perf_counter()
    c = GEN_strassen_matrix_mult(C, D)
    t1 = perf_counter()
    print("strassen elapsed: %.4f" % (t1 - t0))

    t0 = perf_counter()
    c = gauss_matrix_mult(C, D)
    t1 = perf_counter()
    print("gauss elapsed: %.4f" % (t1 - t0))
    print("-------")