# Matrix Multiplication

In [None]:
from typing import List
from random import Random
from datetime import datetime
from math import log, sqrt
from functools import partial
from collections import defaultdict
import sys
import plotly.graph_objects as go
from IPython.display import display, HTML, Markdown, Image

In [None]:
# True: render charts inline with fig.show(); False: save a static image and embed that.
INTERACTIVE_CHARTS = True

def show_figure(fig, image_path: str):
    """Display a plotly figure in the notebook.

    If INTERACTIVE_CHARTS is false, the figure is written to ``image_path``
    and the saved image is displayed instead of the interactive figure.
    """
    if not INTERACTIVE_CHARTS:
        fig.write_image(image_path)
        display(Image(image_path))
        return
    fig.show()

def create_line_chart(title, xs, result, width=800, height=800, xscale='log', yscale='log',
                      image_path='img/04_2_Matrix.png'):
    """Plot one line per entry of ``result`` against ``xs`` and display the chart.

    Args:
        title: chart title.
        xs: x-axis values (the matrix sizes n), shared by every trace.
        result: mapping of trace name -> sequence of y values (running times
            in seconds) aligned with ``xs``.
        width, height: figure size passed to the plotly layout.
        xscale, yscale: plotly axis types, e.g. 'log' or 'linear'.
        image_path: passed to ``show_figure``; the static image is written
            there when interactive charts are disabled.
    """
    fig = go.Figure()

    # one trace per series in result
    for name, ys in result.items():
        fig.add_trace(go.Scatter(x=xs, y=ys, mode='lines', name=name))

    fig.update_xaxes(title='n', type=xscale)
    fig.update_yaxes(title='running time (sec)', type=yscale)
    fig.update_layout(title=title, showlegend=True, width=width, height=height)

    show_figure(fig, image_path)

In [None]:
# Type alias for a matrix: a list of rows, each row a list of ints (indexed as X[row][col]).
M = List[List[int]]

In [None]:
# Fixed seed so every run benchmarks the same sequence of matrices.
rand = Random(12345)

def create_instance(n: int) -> M:
    """Return an n x n matrix of random integers in [1, 100] (inclusive).

    Values are drawn row by row from the module-level seeded ``rand``.
    """
    size = range(n)
    return [[rand.randint(1, 100) for _col in size] for _row in size]

In [None]:
def add(X: M, Y: M) -> M:
    """Return the element-wise sum X + Y of two matrices of the same shape.

    Works for any rectangular shape, not only square matrices.

    Raises:
        ValueError: if X and Y do not have the same shape.
    """
    if len(X) != len(Y) or any(len(rx) != len(ry) for rx, ry in zip(X, Y)):
        raise ValueError('matrices must have the same shape')
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def subtract(X: M, Y: M) -> M:
    """Return the element-wise difference X - Y of two matrices of the same shape.

    Works for any rectangular shape, not only square matrices.

    Raises:
        ValueError: if X and Y do not have the same shape.
    """
    if len(X) != len(Y) or any(len(rx) != len(ry) for rx, ry in zip(X, Y)):
        raise ValueError('matrices must have the same shape')
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def mul_naive(X: M, Y: M) -> M:
    """Return the product X * Y using the schoolbook triple loop.

    X may be r x m and Y m x c (not only square); the result is r x c.
    This is the reference implementation the experiment validates the
    other algorithms against.

    Raises:
        ValueError: if the column count of X does not equal the row count of Y.
    """
    rows = len(X)
    inner = len(Y)
    cols = len(Y[0]) if Y else 0
    if any(len(row) != inner for row in X):
        raise ValueError('inner dimensions do not match')

    ret = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                ret[i][j] += X[i][k] * Y[k][j]
    return ret

def mul_naive_rec(X: M, Y: M) -> M:
    """Multiply two n x n matrices by recursive block decomposition.

    Each operand is split into four n/2 x n/2 quadrants, X = [[A, B], [C, D]]
    and Y = [[E, F], [G, H]]. The result quadrants are built from eight
    recursive quadrant products (AE+BG, AF+BH, CE+DG, CF+DH).

    Args:
        X, Y: square matrices of the same size n, where n is 0 or a power of 2.

    Returns:
        The n x n product X * Y.

    Raises:
        ValueError: if n is not a power of 2 (raised at the recursion level
            where an odd size greater than 1 is reached).
    """
    n = len(X)
    if n == 0:
        # Without this guard nn would be 0 and the recursion would never end.
        return []
    if n == 1:
        return [[X[0][0] * Y[0][0]]]

    # divide
    nn = n // 2
    if nn * 2 != n:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError('n must be a power of 2')

    x_top, x_bottom = X[:nn], X[nn:]
    A = [row[:nn] for row in x_top]
    B = [row[nn:] for row in x_top]
    C = [row[:nn] for row in x_bottom]
    D = [row[nn:] for row in x_bottom]

    y_top, y_bottom = Y[:nn], Y[nn:]
    E = [row[:nn] for row in y_top]
    F = [row[nn:] for row in y_top]
    G = [row[:nn] for row in y_bottom]
    H = [row[nn:] for row in y_bottom]

    AE = mul_naive_rec(A, E)
    BG = mul_naive_rec(B, G)
    AF = mul_naive_rec(A, F)
    BH = mul_naive_rec(B, H)
    CE = mul_naive_rec(C, E)
    DG = mul_naive_rec(D, G)
    CF = mul_naive_rec(C, F)
    DH = mul_naive_rec(D, H)

    # conquer: sum the paired products into the four result quadrants
    ret = [[0] * n for _ in range(n)]
    for i in range(nn):
        for j in range(nn):
            ret[i][j] = AE[i][j] + BG[i][j]
            ret[i][j + nn] = AF[i][j] + BH[i][j]
            ret[i + nn][j] = CE[i][j] + DG[i][j]
            ret[i + nn][j + nn] = CF[i][j] + DH[i][j]

    return ret
    
def mul_fast(X: M, Y: M) -> M:
    """Multiply two n x n matrices with Strassen's algorithm.

    Each operand is split into four n/2 x n/2 quadrants, X = [[A, B], [C, D]]
    and Y = [[E, F], [G, H]]. The result is assembled from seven recursive
    products P1..P7 plus quadrant additions and subtractions, rather than
    the eight products of the plain block recursion.

    Args:
        X, Y: square matrices of the same size n, where n is 0 or a power of 2.

    Returns:
        The n x n product X * Y.

    Raises:
        ValueError: if n is not a power of 2 (raised at the recursion level
            where an odd size greater than 1 is reached).
    """
    n = len(X)
    if n == 0:
        # Without this guard nn would be 0 and the recursion would never end.
        return []
    if n == 1:
        return [[X[0][0] * Y[0][0]]]

    # divide
    nn = n // 2
    if nn * 2 != n:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError('n must be a power of 2')

    x_top, x_bottom = X[:nn], X[nn:]
    A = [row[:nn] for row in x_top]
    B = [row[nn:] for row in x_top]
    C = [row[:nn] for row in x_bottom]
    D = [row[nn:] for row in x_bottom]

    y_top, y_bottom = Y[:nn], Y[nn:]
    E = [row[:nn] for row in y_top]
    F = [row[nn:] for row in y_top]
    G = [row[:nn] for row in y_bottom]
    H = [row[nn:] for row in y_bottom]

    # Strassen's seven products
    P1 = mul_fast(A, subtract(F, H))
    P2 = mul_fast(add(A, B), H)
    P3 = mul_fast(add(C, D), E)
    P4 = mul_fast(D, subtract(G, E))
    P5 = mul_fast(add(A, D), add(E, H))
    P6 = mul_fast(subtract(B, D), add(G, H))
    P7 = mul_fast(subtract(A, C), add(E, F))

    # conquer: combine the seven products into the four result quadrants
    ret = [[0] * n for _ in range(n)]
    for i in range(nn):
        for j in range(nn):
            ret[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j]
            ret[i][j + nn] = P1[i][j] + P2[i][j]
            ret[i + nn][j] = P3[i][j] + P4[i][j]
            ret[i + nn][j + nn] = P1[i][j] + P5[i][j] - P3[i][j] - P7[i][j]

    return ret

In [None]:
# run experiments
from time import perf_counter  # monotonic, high-resolution clock meant for timing code

algs = [('naive', mul_naive), ('rec', mul_naive_rec), ('fast', mul_fast)]
ns = [2 ** i for i in range(3, 8)]  # n = 8, 16, 32, 64, 128 (powers of 2, required by rec/fast)
result = defaultdict(list)  # algorithm name -> list of running times (seconds), one per n

for n in ns:
    X = create_instance(n)
    Y = create_instance(n)

    # validation: every other algorithm must agree with the first (naive) one
    _, reference = algs[0]
    expected_result = reference(X, Y)
    assert all(f(X, Y) == expected_result for _, f in algs[1:]), 'Algorithm is not correct'
    sys.stdout.write(f'Running measurement (n={n:4d}): ')
    sys.stdout.flush()

    # measurement
    for name, f in algs:
        sys.stdout.write('.')
        sys.stdout.flush()  # show progress immediately; a single run can take a while
        start = perf_counter()
        f(X, Y)
        result[name].append(perf_counter() - start)
    print()

In [None]:
# Plot the measured running times for each algorithm against n.
create_line_chart(title='Matrix Multiplication', xs=ns, result=result)