# Matrix multiplication

## Packages (for checking)

In [1]:
import numpy as np

## Brute force

In [2]:
def multiply_matrices(A, B):
    '''Multiply matrix A by matrix B.

    A is an l x m matrix and B an m x o matrix, each given as a list of
    row lists. Returns a new l x o list of lists computed with the
    schoolbook triple loop (l * m * o scalar multiplications).
    An empty A yields [].

    Raises ValueError if the inner dimensions (columns of A, rows of B)
    differ.
    '''
    l = len(A)
    if l == 0:
        return []
    m = len(A[0])
    n = len(B)
    if m != n:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError('ERROR: Dimensions don\'t match')
    o = len(B[0]) if B else 0

    # The result has one row per row of A and one column per column of B.
    C = [[0 for col in range(o)] for row in range(l)]

    for i in range(l):
        for j in range(o):
            for k in range(m):
                C[i][j] += A[i][k] * B[k][j]

    return C

### Testing

In [3]:
A = [[0, 1, 2],
     [3, 4, 5]]

B = [[7, 8],
     [9, 10],
     [11, 12]]

result = multiply_matrices(A, B)
expected = np.matmul(A, B)
print(result)
print(expected)
# Fail loudly instead of relying on a visual comparison of the two prints.
assert result == expected.tolist(), 'multiply_matrices() disagrees with np.matmul'
print(np.matmul(A, B))

[[31, 34], [112, 124]]
[[ 31  34]
 [112 124]]


## Strassen's method

In [4]:
def add_matrices(A, B):
    '''Add two matrices elementwise.

    A and B are lists of row lists with identical shape. Returns a new
    matrix of that shape; neither input is modified.

    Raises ValueError if the shapes differ (row count or any row length).
    '''
    if len(A) != len(B) or any(len(row_a) != len(row_b) for row_a, row_b in zip(A, B)):
        raise ValueError('ERROR: Dimensions don\'t match')

    return [[a + b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]

In [5]:
def strassen(A, B):
    '''Multiply matrix A by matrix B using Strassen's recursive method.

    A is an l x m matrix and B an m x o matrix, each a list of row lists.
    Returns a new l x o list of lists. Any shape is accepted:
    strassen_helper zero-pads the operands to the smallest power-of-two
    square that holds them and crops the padding off the result.
    An empty A yields [].

    Raises ValueError if the inner dimensions (columns of A, rows of B)
    differ.
    '''
    if not A:
        return []
    return strassen_helper(A, B, 0, len(A), 0, len(A[0]),
                           0, len(B), 0, len(B[0]) if B else 0)


def strassen_helper(A, B,
                    A_row_top, A_row_btm, A_col_lft, A_col_rgt,
                    B_row_top, B_row_btm, B_col_lft, B_col_rgt):
    '''Helper function for strassen().

    Multiplies the block of A spanning rows [A_row_top, A_row_btm) and
    columns [A_col_lft, A_col_rgt) by the block of B spanning rows
    [B_row_top, B_row_btm) and columns [B_col_lft, B_col_rgt). All bounds
    are half-open. Returns a new list of lists with as many rows as A's
    block and as many columns as B's block.

    Raises ValueError if A's block width differs from B's block height.
    '''
    rows = A_row_btm - A_row_top
    inner = A_col_rgt - A_col_lft
    cols = B_col_rgt - B_col_lft
    if inner != B_row_btm - B_row_top:
        raise ValueError('ERROR: Dimensions don\'t match')
    if rows == 0 or inner == 0 or cols == 0:
        # Degenerate product: nothing to recurse on, every entry is 0.
        return [[0] * cols for _ in range(rows)]

    # Strassen halves both operands at every level, so work on a
    # power-of-two square; zero padding does not change the top-left block.
    size = 1
    while size < max(rows, inner, cols):
        size *= 2
    X = _strassen_pad(A, A_row_top, A_col_lft, rows, inner, size)
    Y = _strassen_pad(B, B_row_top, B_col_lft, inner, cols, size)

    product = _strassen_square(X, Y)
    return [row[:cols] for row in product[:rows]]


def _strassen_square(X, Y):
    '''Return X times Y for n x n list-of-lists matrices, n a power of two.

    Uses seven half-size recursive products per level (instead of the
    eight a plain block decomposition needs), which gives
    O(n^log2(7)) scalar multiplications.
    '''
    n = len(X)
    if n == 1:
        return [[X[0][0] * Y[0][0]]]

    h = n // 2
    X11, X12, X21, X22 = _strassen_split(X, h)
    Y11, Y12, Y21, Y22 = _strassen_split(Y, h)

    M1 = _strassen_square(_strassen_add(X11, X22), _strassen_add(Y11, Y22))
    M2 = _strassen_square(_strassen_add(X21, X22), Y11)
    M3 = _strassen_square(X11, _strassen_sub(Y12, Y22))
    M4 = _strassen_square(X22, _strassen_sub(Y21, Y11))
    M5 = _strassen_square(_strassen_add(X11, X12), Y22)
    M6 = _strassen_square(_strassen_sub(X21, X11), _strassen_add(Y11, Y12))
    M7 = _strassen_square(_strassen_sub(X12, X22), _strassen_add(Y21, Y22))

    C11 = _strassen_add(_strassen_sub(_strassen_add(M1, M4), M5), M7)
    C12 = _strassen_add(M3, M5)
    C21 = _strassen_add(M2, M4)
    C22 = _strassen_add(_strassen_add(_strassen_sub(M1, M2), M3), M6)

    # Stitch the quadrants back together row by row.
    return ([left + right for left, right in zip(C11, C12)] +
            [left + right for left, right in zip(C21, C22)])


def _strassen_pad(M, top, lft, height, width, size):
    '''Copy the height x width block of M at (top, lft) into a size x size zero matrix.'''
    padded = [[0] * size for _ in range(size)]
    for i in range(height):
        src = M[top + i]
        dst = padded[i]
        for j in range(width):
            dst[j] = src[lft + j]
    return padded


def _strassen_split(M, h):
    '''Return the four h x h quadrants of the 2h x 2h matrix M: (TL, TR, BL, BR).'''
    top, bottom = M[:h], M[h:]
    return ([row[:h] for row in top], [row[h:] for row in top],
            [row[:h] for row in bottom], [row[h:] for row in bottom])


def _strassen_add(P, Q):
    '''Elementwise sum of two equally-shaped matrices.'''
    return [[p + q for p, q in zip(row_p, row_q)] for row_p, row_q in zip(P, Q)]


def _strassen_sub(P, Q):
    '''Elementwise difference P - Q of two equally-shaped matrices.'''
    return [[p - q for p, q in zip(row_p, row_q)] for row_p, row_q in zip(P, Q)]

### Testing

In [6]:
A = [[10]]
B = [[7]]
result = strassen(A, B)
expected = np.matmul(A, B)
print(result)
print(expected)
# Fail loudly instead of relying on a visual comparison of the two prints.
assert result == expected.tolist(), 'strassen() disagrees with np.matmul'

[[70]]
[[70]]


In [10]:
A = [[0, 1],
     [3, 4]]

B = [[7, 8],
     [9, 10]]

result = strassen(A, B)
expected = np.matmul(A, B)
print(result)
print(expected)
# Fail loudly instead of relying on a visual comparison of the two prints.
assert result == expected.tolist(), 'strassen() disagrees with np.matmul'

[[9, 10], [57, 64]]
[[ 9 10]
 [57 64]]


In [8]:
A = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]

B = [[10, 11, 12, 13],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]
result = strassen(A, B)
expected = np.matmul(A, B)
print(result)
print(expected)
# Fail loudly instead of relying on a visual comparison of the two prints.
assert result == expected.tolist(), 'strassen() disagrees with np.matmul'

[[56, 62, 68, 74], [192, 214, 236, 258], [328, 366, 404, 442], [464, 518, 572, 626]]
[[ 56  62  68  74]
 [192 214 236 258]
 [328 366 404 442]
 [464 518 572 626]]
