# Matrix multiplication

## Brute force

In [1]:
def multiply_matrices(A, B):
    '''Multiply matrix A by matrix B.'''
    
    l = len(A)
    m = len(A[0])
    n = len(B)
    o = len(B[0])
    
    assert m == n, 'ERROR: Dimensions don\'t match'
    
    C = [[0 for row in range(l)] for col in range(o)]
    
    for i in range(l):
        for j in range(o):
            for k in range(m):
                C[i][j] += A[i][k] * B[k][j]
    
    return C

### Testing

In [2]:
A = [[0, 1, 2],
     [3, 4, 5]]

B = [[7, 8],
     [9, 10],
     [11, 12]]

print(multiply_matrices(A, B))

[[31, 34], [112, 124]]


## Strassen's method

In [None]:
def strassen(A, B):
    '''Multiply matrix A by matrix B using Strassen\'s recursive method.'''
    
    return strassen_helper(A, B, 0, len(A) - 1, 0, len(A[0]) - 1, 0, len(B) - 1, 0, len(B[0]) - 1)
    
def strassen_helper(A, B,
                    A_row_top, A_row_btm, A_col_lft, A_col_rgt,
                    B_row_top, B_row_btm, B_col_lft, B_col_rgt):
    '''Helper function for strassen().'''
    
    print(A_row_top, A_row_btm, A_col_lft, A_col_rgt, B_row_top, B_row_btm, B_col_lft, B_col_rgt)
    
    n = A_row_btm - A_row_top
    
    if n == 0:
        C = A[A_row_top][A_col_lft] * B[B_row_top][B_col_lft]
        return [[C]]
    else:
        A_row_mid = (A_row_top + A_row_btm) // 2
        A_col_mid = (A_col_lft + A_col_rgt) // 2
        B_row_mid = (B_row_top + B_row_btm) // 2
        B_col_mid = (B_row_top + B_row_btm) // 2
        #print(A_row_mid, A_col_mid, B_row_mid, B_col_mid)
        
        C11 = \
        strassen_helper(A, B,
                        A_row_top, A_row_mid - 1, A_col_lft, A_col_mid - 1,
                        B_row_top, B_row_mid - 1, B_col_lft, B_col_mid - 1) +\
        strassen_helper(A, B,
                        A_row_top, A_row_mid - 1, A_col_mid + 1, A_col_rgt,
                        B_row_mid + 1, B_row_btm, B_col_lft, B_col_mid - 1)
        
        C12 = \
        strassen_helper(A, B,
                        A_row_top, A_row_mid - 1, A_col_lft, A_col_mid - 1,
                        B_row_top, B_row_mid - 1, B_col_mid + 1, B_col_rgt) +\
        strassen_helper(A, B,
                        A_row_top, A_row_mid - 1, A_col_mid + 1, A_col_rgt,
                        B_row_mid + 1, B_row_btm, B_col_mid + 1, B_col_rgt)
        
        C21 = \
        strassen_helper(A, B,
                        A_row_mid + 1, A_row_btm, A_col_lft, A_col_mid - 1,
                        B_row_top, B_row_mid - 1, B_col_lft, B_col_mid - 1) +\
        strassen_helper(A, B,
                        A_row_mid + 1, A_row_btm, A_col_mid + 1, A_col_rgt,
                        B_row_mid + 1, B_row_btm, B_col_lft, B_col_mid - 1)
        
        C22 = \
        strassen_helper(A, B,
                        A_row_mid + 1, A_row_btm, A_col_lft, A_col_mid - 1,
                        B_row_top, B_row_mid - 1, B_col_mid + 1, B_col_rgt) +\
        strassen_helper(A, B,
                        A_row_mid + 1, A_row_btm, A_col_mid + 1, A_col_rgt,
                        B_row_mid + 1, B_row_btm, B_col_mid + 1, B_col_rgt)
        
        C = [C11[i] + C12[i] for i in range(n)] + [C21[i] + C22[i] for i in range(n)]
        #print(C)
        return C

In [77]:
A = [[0, 1],
     [3, 4]]
B = [[5, 6],
     [7, 8]]
n = 2

C = [A[i] + B[i] for i in range(n)] + [A[i] + B[i] for i in range(n)]
print(C)

[[0, 1, 5, 6], [3, 4, 7, 8], [0, 1, 5, 6], [3, 4, 7, 8]]


In [None]:
A = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]

B = [[10, 11, 12, 13],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]
print(strassen(A, B))

In [70]:
A = [[0, 1],
     [3, 4]]

B = [[7, 8],
     [9, 10]]

print(strassen(A, B))

0 1 0 1 0 1 0 1
0 -1 0 -1 0 -1 0 -1
0 -1 1 1 1 1 0 -1
0 -1 0 -1 0 -1 1 1
0 -1 1 1 1 1 1 1
1 1 0 -1 0 -1 0 -1
1 0 0 -2 0 -2 0 -2
1 0 0 -1 0 -1 0 -2
1 0 0 -2 0 -2 0 -1
1 0 0 -1 0 -1 0 -1
2 1 0 -2 0 -2 0 -2


IndexError: list index out of range

In [13]:
A = [[10]]

B = [[10]]

print(strassen(A, B))

0 0 0 0 0 0 0 0
100


In [19]:
A + B

[[0, 1, 2, 3],
 [4, 5, 6, 7],
 [8, 9, 10, 11],
 [12, 13, 14, 15],
 [10, 11, 12, 13],
 [4, 5, 6, 7],
 [8, 9, 10, 11],
 [12, 13, 14, 15]]

In [26]:
A[0] + A[1]

[0, 1, 2, 3, 4, 5, 6, 7]