<a href="https://colab.research.google.com/github/lmcanavals/algorithmic_complexity2023/blob/main/0301_divide_and_conquer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Divide and conquer

### Multiplication

In [None]:
def mult(a, b, n):
    """Multiply two integers by recursively splitting them into digit halves.

    With ``m = n // 2`` each factor is written as ``high * 10**m + low``,
    which gives

        a * b = ai*bi * 10**(2m) + (ai*bd + ad*bi) * 10**m + ad*bd

    where ``ai``/``bi`` are the high (left) digits and ``ad``/``bd`` the
    low (right) digits. The identity holds for any split point ``m``, so
    the result is correct for every ``n`` (odd values included) even when a
    half does not have exactly ``m`` digits. With four recursive products
    per level the cost is still quadratic in ``n``.

    Args:
        a: first factor.
        b: second factor.
        n: digit count used to choose the split point (normally the number
           of decimal digits of the larger factor).

    Returns:
        The product ``a * b``.
    """
    # n == 0 would split at 10**0 == 1 and recurse forever; treat it like
    # the single-digit base case.
    if n <= 1:
        return a * b

    m = n // 2
    base = 10**m
    ai, ad = divmod(a, base)
    bi, bd = divmod(b, base)
    # The high*high product must be shifted by 2*m digits, not n: the two
    # differ whenever n is odd.
    z1 = mult(ai, bi, m) * 10**(2 * m)
    z2 = (mult(ai, bd, m) + mult(ad, bi, m)) * base
    z3 = mult(ad, bd, m)
    return z1 + z2 + z3

# Compare the recursive product against Python's built-in multiplication.
assert all(
    mult(x, y, digits) == x * y
    for x, y, digits in ((1234, 4321, 4), (94494994, 38383848, 8))
)
print("all tests successful")

all tests successful


### Matrix Multiplication

#### Brute force

In [3]:
import numpy as np

def matmul(a, b):
    """Multiply two matrices with the classic triple loop.

    Each output cell ``c[i, j]`` is accumulated as the sum over ``k`` of
    ``a[i, k] * b[k, j]``.

    Args:
        a: 2-D array of shape (n, m).
        b: 2-D array of shape (m, p).

    Returns:
        A new (n, p) array holding ``a @ b``, or ``None`` when the inner
        dimensions differ.
    """
    nrowsa, ncolsa = a.shape
    nrowsb, ncolsb = b.shape
    if ncolsa != nrowsb:
        return None

    # Promote to a type that can hold both operands (at least the default
    # int): integer inputs keep the old int result, while float/complex
    # inputs are no longer truncated by every ``+=`` into an int array.
    dtype = np.result_type(a.dtype, b.dtype, np.dtype(int))
    c = np.zeros((nrowsa, ncolsb), dtype=dtype)
    for i in range(nrowsa):
        for j in range(ncolsb):
            for k in range(ncolsa):
                c[i, j] += a[i, k] * b[k, j]

    return c

np.random.seed(42)  # fixed seed -> reproducible operands and printed product
# Draw a (4x3) then b (3x2), entries in 1..9, in the same RNG order as before.
a, b = (np.random.randint(1, 10, size=shape) for shape in [(4, 3), (3, 2)])

c = matmul(a, b)
assert np.array_equal(c, a @ b)  # must agree with NumPy's own product
print(c)

[[105  98]
 [ 74  62]
 [101  88]
 [116  88]]


#### Divide and conquer

In [11]:
def matmuldc(a, b):
    """Multiply two matrices by recursively partitioning the output.

    The inclusive index ranges ``[i0, i1] x [j0, j1]`` of the result are
    halved (rows and columns together, or whichever still spans more than
    one index) until a single cell remains. That cell is the dot product of
    row ``i0`` of ``a`` with column ``j0`` of ``b``. Every cell still costs
    ``ncolsa`` multiplications, so the total work matches the triple-loop
    version; the recursion only changes the order in which cells are filled.

    Args:
        a: 2-D array of shape (n, m).
        b: 2-D array of shape (m, p).

    Returns:
        A new (n, p) array holding ``a @ b``, or ``None`` when the inner
        dimensions differ.
    """
    nrowsa, ncolsa = a.shape
    nrowsb, ncolsb = b.shape
    if ncolsa != nrowsb:
        return None

    # Promote to a type that can hold both operands (at least the default
    # int): integer inputs keep the old int result, while float/complex
    # inputs are no longer truncated by every ``+=`` into an int array.
    dtype = np.result_type(a.dtype, b.dtype, np.dtype(int))
    c = np.zeros((nrowsa, ncolsb), dtype=dtype)

    def solve(i0, i1, j0, j1):
        # Empty range: happens when a or b has a zero-length dimension.
        # Without this guard the column-splitting branch recursed forever.
        if i0 > i1 or j0 > j1:
            return
        # Base case: a single output cell.
        if i0 == i1 and j0 == j1:
            for k in range(ncolsa):
                c[i0, j0] += a[i0, k] * b[k, j0]
            return
        if i0 < i1 and j0 < j1:
            # Both ranges splittable: recurse into the four quadrants.
            mi = (i0 + i1) // 2
            mj = (j0 + j1) // 2
            solve(i0, mi, j0, mj)
            solve(i0, mi, mj+1, j1)
            solve(mi+1, i1, j0, mj)
            solve(mi+1, i1, mj+1, j1)
        elif i0 < i1:
            # Single column left: split the rows only.
            mi = (i0 + i1) // 2
            solve(i0, mi, j0, j1)
            solve(mi+1, i1, j0, j1)
        else:
            # Single row left: split the columns only.
            mj = (j0 + j1) // 2
            solve(i0, i1, j0, mj)
            solve(i0, i1, mj+1, j1)

    solve(0, nrowsa-1, 0, ncolsb-1)
    return c

np.random.seed(42)  # same seed as the brute-force demo -> same operands
# Draw a (4x3) then b (3x2), entries in 1..9, in the same RNG order as before.
a, b = (np.random.randint(1, 10, size=shape) for shape in [(4, 3), (3, 2)])

c = matmuldc(a, b)
assert np.array_equal(c, a @ b)  # must agree with NumPy's own product
print(c)

[[105  98]
 [ 74  62]
 [101  88]
 [116  88]]
