<a href="https://colab.research.google.com/github/lmcanavals/acomplex/blob/main/03_03_matrix_multiplication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

In [4]:
def matmul(a, b):
  """Multiply two 2-D arrays with the textbook row-by-column triple loop.

  Args:
    a: 2-D array of shape (r, m).
    b: 2-D array of shape (m, s).

  Returns:
    A new float array of shape (r, s) holding a @ b, or None when the
    number of columns of ``a`` differs from the number of rows of ``b``.
  """
  n_rows, inner = a.shape
  inner_b, n_cols = b.shape
  if inner != inner_b:
    return None

  result = np.zeros((n_rows, n_cols))
  for row in range(n_rows):
    for col in range(n_cols):
      # Dot product of row `row` of a with column `col` of b.
      result[row, col] = sum(a[row, k] * b[k, col] for k in range(inner))
  return result

In [5]:
a = np.random.randint(10, size=(3, 4))
b = np.random.randint(10, size=(4, 2))
c = matmul(a, b)

# Compare element-wise against NumPy's reference product. `.all()` would
# collapse each array to a single "every entry is non-zero" boolean, which
# says nothing about whether the two products agree.
assert np.array_equal(c, np.matmul(a, b))

print(a)
print(b)
print(c)

[[1 8 4 3]
 [3 9 5 3]
 [0 9 8 7]]
[[9 2]
 [4 3]
 [1 2]
 [5 7]]
[[60. 55.]
 [83. 64.]
 [79. 92.]]


In [8]:
def mm1(a, b, c, rowi, rowf, coli, colf):
  """Fill c[rowi:rowf+1, coli:colf+1] with the matching entries of a @ b.

  Divide and conquer over the output: the row range and the column range
  are each halved and the four resulting quadrants are filled recursively.
  A single cell is the dot product of a row of ``a`` with a column of ``b``.

  Args:
    a, b: 2-D arrays with a.shape[1] == b.shape[0].
    c: preallocated output array, written in place.
    rowi, rowf: first and last (inclusive) output row to fill.
    coli, colf: first and last (inclusive) output column to fill.
  """
  # Empty ranges occur when one range has shrunk to a single index while
  # the other is still being split (any size that is not a power of two),
  # and for 0x0 inputs. Returning here also stops (0, -1) from re-splitting
  # into itself forever.
  if rowi > rowf or coli > colf:
    return

  # Only a single cell is a base case. If just the rows have collapsed,
  # the remaining columns still need to be visited by the split below.
  if rowi == rowf and coli == colf:
    temp = 0
    for k in range(a.shape[1]):
      temp += a[rowi, k] * b[k, coli]
    c[rowi, coli] = temp
    return

  rowmid = (rowi + rowf) // 2
  colmid = (coli + colf) // 2
  mm1(a, b, c, rowi, rowmid, coli, colmid)
  mm1(a, b, c, rowi, rowmid, colmid + 1, colf)
  mm1(a, b, c, rowmid + 1, rowf, coli, colmid)
  mm1(a, b, c, rowmid + 1, rowf, colmid + 1, colf)

In [9]:
def dvmatmul1(a, b):
  """Multiply two n x n matrices using the quadrant-splitting helper ``mm1``.

  Args:
    a, b: square 2-D arrays of the same size n.

  Returns:
    A new float array of shape (n, n) holding a @ b, or None when ``a`` and
    ``b`` are not square matrices of the same size. This matches ``matmul``
    and ``dvmatmul2``, which return None for incompatible shapes.
  """
  n = len(a)
  # The output is allocated as n x n from len(a) alone, so a rectangular
  # argument would otherwise yield a silently truncated product.
  if a.shape != (n, n) or b.shape != (n, n):
    return None

  c = np.zeros((n, n))
  if n == 0:
    # Nothing to fill; mm1 is only called on non-empty ranges.
    return c
  mm1(a, b, c, 0, n-1, 0, n-1)
  return c

In [11]:
a = np.random.randint(10, size=(8, 8))
b = np.random.randint(10, size=(8, 8))
c = dvmatmul1(a, b)

# Compare element-wise against NumPy's reference product. `.all()` would
# collapse each array to a single "every entry is non-zero" boolean, which
# says nothing about whether the two products agree.
assert np.array_equal(c, np.matmul(a, b))

print(a)
print(b)
print(c)

[[0 0 0 9 7 5 2 4]
 [1 3 3 6 9 8 6 1]
 [2 4 0 9 2 0 4 2]
 [6 3 6 3 6 2 1 5]
 [6 1 6 4 7 1 3 7]
 [6 9 9 1 0 5 0 2]
 [9 5 7 6 3 6 4 1]
 [7 8 0 7 0 8 0 2]]
[[1 5 6 7 0 9 3 8]
 [6 5 9 5 5 4 2 7]
 [9 9 9 4 0 1 7 6]
 [2 9 6 7 5 0 0 1]
 [9 5 7 3 0 1 8 2]
 [4 2 6 1 7 1 3 9]
 [6 8 2 2 8 2 8 9]
 [1 3 6 2 1 5 3 4]]
[[117. 154. 161. 101. 100.  36.  99. 102.]
 [208. 213. 225. 125. 150.  58. 177. 201.]
 [ 88. 159. 136. 115.  99.  54.  68. 101.]
 [157. 183. 221. 134.  57. 107. 143. 167.]
 [166. 207. 226. 141.  63. 113. 166. 173.]
 [165. 181. 246. 139.  87. 114. 120. 219.]
 [190. 249. 269. 183. 130. 130. 163. 255.]
 [103. 160. 216. 150. 133. 113.  67. 199.]]


In [12]:
def mm2(a, b, c, rowi, rowf, coli, colf):
  """Fill c[rowi:rowf+1, coli:colf+1] with the matching entries of a @ b.

  Divide and conquer over the output that works for rectangular operands.
  Only the dimensions that still span more than one index are halved:
  columns only, rows only, or both (four quadrants). A single cell is the
  dot product of a row of ``a`` with a column of ``b``.

  Args:
    a, b: 2-D arrays with a.shape[1] == b.shape[0].
    c: preallocated output array, written in place.
    rowi, rowf: first and last (inclusive) output row to fill.
    coli, colf: first and last (inclusive) output column to fill.
  """
  # Empty range (e.g. a has 0 rows or b has 0 columns): nothing to fill.
  # Without this guard, splitting (0, -1) at its midpoint reproduces
  # (0, -1) and the recursion never terminates.
  if rowi > rowf or coli > colf:
    return

  aCols = a.shape[1]
  if rowi == rowf and coli == colf:
    temp = 0
    for k in range(aCols):
      temp += a[rowi, k] * b[k, coli]
    c[rowi, coli] = temp
  elif rowi == rowf:
    colmid = (coli + colf) // 2
    mm2(a, b, c, rowi, rowf, coli, colmid)
    mm2(a, b, c, rowi, rowf, colmid + 1, colf)
  elif coli == colf:
    rowmid = (rowi + rowf) // 2
    mm2(a, b, c, rowi, rowmid, coli, colf)
    mm2(a, b, c, rowmid + 1, rowf, coli, colf)
  else:
    rowmid = (rowi + rowf) // 2
    colmid = (coli + colf) // 2
    mm2(a, b, c, rowi, rowmid, coli, colmid)
    mm2(a, b, c, rowi, rowmid, colmid + 1, colf)
    mm2(a, b, c, rowmid + 1, rowf, coli, colmid)
    mm2(a, b, c, rowmid + 1, rowf, colmid + 1, colf)

In [14]:
def dvmatmul2(a, b):
  """Multiply two 2-D arrays using the divide-and-conquer helper ``mm2``.

  Args:
    a: 2-D array of shape (r, m).
    b: 2-D array of shape (m, s).

  Returns:
    A new float array of shape (r, s) holding a @ b, or None when the
    number of columns of ``a`` differs from the number of rows of ``b``.
  """
  aRows, aCols = a.shape
  bRows, bCols = b.shape
  if aCols != bRows:
    return None

  c = np.zeros((aRows, bCols))
  # An empty result has nothing to fill. mm2 is only called with non-empty
  # row and column ranges, since it splits them at their midpoints.
  if aRows == 0 or bCols == 0:
    return c
  mm2(a, b, c, 0, aRows-1, 0, bCols-1)
  return c

In [16]:
a = np.random.randint(10, size=(4, 8))
b = np.random.randint(10, size=(8, 3))
c = dvmatmul2(a, b)

# Compare element-wise against NumPy's reference product. `.all()` would
# collapse each array to a single "every entry is non-zero" boolean, which
# says nothing about whether the two products agree.
assert np.array_equal(c, np.matmul(a, b))

print(a)
print(b)
print(c)

[[5 5 7 9 5 4 8 4]
 [6 3 9 8 9 5 3 3]
 [3 0 5 7 9 1 3 6]
 [7 8 4 4 8 5 8 4]]
[[9 8 1]
 [3 8 2]
 [4 5 4]
 [8 8 7]
 [3 7 3]
 [2 3 4]
 [6 2 1]
 [3 3 5]]
[[243. 262. 165.]
 [227. 274. 169.]
 [168. 195. 136.]
 [229. 271. 139.]]
