# Benchmarks on matrix multiplication

Matrix multiplication can be defined as $(AB)_{ik} = a_i \cdot b_k = \sum_{j=1}^{J} A_{ij} B_{jk}$,

where $A$ is an $I \times J$ matrix,

$B$ is a $J \times K$ matrix,

$a_i$ is the $i^{th}$ row vector of $A$,

and $b_k$ is the $k^{th}$ column vector of $B$, so $AB$ is an $I \times K$ matrix.
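A small numeric check of the definition: each entry of the product is the dot product of a row of $A$ with a column of $B$. The matrices below are arbitrary examples chosen only for illustration.

```python
import numpy as np

# Arbitrary small example: A is 2 x 3 (I x J), B is 3 x 2 (J x K)
A = np.array([[1, 2, 3],
              [4, 5, 6]])
B = np.array([[7, 8],
              [9, 10],
              [11, 12]])

# (AB)_{ik} = a_i . b_k: row i of A dotted with column k of B
C = np.array([[A[i, :] @ B[:, k] for k in range(B.shape[1])]
              for i in range(A.shape[0])])

print(C.tolist())                 # [[58, 64], [139, 154]]
print(np.array_equal(C, A @ B))   # True
```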

We can write a matrix multiplication function directly from this definition:

In [2]:
import numpy as np

In [4]:
def matrix_mult(a, b):
    # (AB)_ik = a_i . b_k: loop over the rows of a and the columns of b
    if a.shape[1] != b.shape[0]:
        raise ValueError('a and b must have compatible shapes')
    I, J = a.shape
    J, K = b.shape
    c = np.zeros((I, K))
    for i in range(I):
        ai = a[i, :]
        for k in range(K):
            bk = b[:, k]
            c[i, k] = (ai * bk).sum()  # numpy step: inner product over j
    return c
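The `# numpy step` in `matrix_mult` replaces an innermost loop over $j$. For comparison, a sketch of the same definition written with plain Python lists and no NumPy at all (the name `matrix_mult_pure` is hypothetical and is not part of the benchmark):

```python
def matrix_mult_pure(a, b):
    """Multiply an I x J matrix by a J x K matrix, both given as lists of lists."""
    I, J = len(a), len(a[0])
    if len(b) != J:
        raise ValueError('a and b must have compatible shapes')
    K = len(b[0])
    c = [[0] * K for _ in range(I)]
    for i in range(I):
        for k in range(K):
            # innermost loop: the dot product a_i . b_k, summed over j
            c[i][k] = sum(a[i][j] * b[j][k] for j in range(J))
    return c


print(matrix_mult_pure([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Every multiply-add here goes through the Python interpreter, which is the overhead that the NumPy inner product and the compiled versions try to remove.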

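We want to benchmark this against `np.dot` and against two Numba-compiled copies of the same function: one decorated with `@njit(parallel=True)` and one with `@jit`.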
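`%timeit` is an IPython magic. Outside a notebook, a rough equivalent can be built on the standard-library `timeit` module. The helper `mean_time` below is a hypothetical sketch: it reports the mean time per call over several repeats, which is roughly what `%timeit` prints as its mean.

```python
import timeit

import numpy as np


def mean_time(func, *args, number=10, repeat=7):
    """Mean seconds per call over `repeat` runs of `number` calls each (hypothetical helper)."""
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    return sum(per_call) / len(per_call)


a = np.arange(100 * 101).reshape(100, 101)
b = np.arange(101 * 102).reshape(101, 102)
t = mean_time(np.dot, a, b)
print(f"np.dot: {t * 1e6:.1f} µs per call")
```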

In [7]:
from numba import njit, jit

In [8]:
@njit(parallel=True)
def matrix_mult_numba_jit(a, b):
    # NB: despite the "_jit" suffix, this function is compiled with njit(parallel=True)
    if a.shape[1] != b.shape[0]:
        raise ValueError('a and b must have compatible shapes')
    I, J = a.shape
    J, K = b.shape
    c = np.zeros((I, K))
    for i in range(I):
        ai = a[i, :]
        for k in range(K):
            bk = b[:, k]
            c[i, k] = (ai * bk).sum()  # numpy step
    return c

In [9]:
@jit
def matrix_mult_numba_njit(a, b):
    # NB: despite the "_njit" suffix, this function uses the plain @jit decorator
    if a.shape[1] != b.shape[0]:
        raise ValueError('a and b must have compatible shapes')
    I, J = a.shape
    J, K = b.shape
    c = np.zeros((I, K))
    for i in range(I):
        ai = a[i, :]
        for k in range(K):
            bk = b[:, k]
            c[i, k] = (ai * bk).sum()  # numpy step
    return c

In [10]:
i, j, k = 100,101,102
a = np.arange(i*j).reshape(i,j)
b = np.arange(j*k).reshape(j,k)
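Note that `np.arange` returns integer arrays here. NumPy generally hands `np.dot` off to BLAS only for floating-point and complex dtypes, so these timings likely measure NumPy's non-BLAS integer path and may understate `np.dot` on float data. Also, `matrix_mult` returns `float64` (from `np.zeros`) while `np.dot` keeps the integer dtype. A sketch for building float copies of the same inputs and checking that the results agree (`a_f`, `b_f`, `c_int`, `c_float` are illustrative names):

```python
import numpy as np

i, j, k = 100, 101, 102
a = np.arange(i * j).reshape(i, j)
b = np.arange(j * k).reshape(j, k)

# Float copies of the same inputs, for a BLAS-backed comparison
a_f = a.astype(np.float64)
b_f = b.astype(np.float64)

c_int = np.dot(a, b)
c_float = np.dot(a_f, b_f)
print(c_int.dtype, c_float.dtype)
# The values agree exactly: every partial sum stays far below 2**53
print(np.array_equal(c_int, c_float))
```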

In [11]:
%timeit np.dot(a,b)

585 µs ± 3.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
%timeit matrix_mult(a,b)

34.2 ms ± 322 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%timeit matrix_mult_numba_jit(a,b)

291 ms ± 8.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%timeit matrix_mult_numba_njit(a,b)

2.13 ms ± 27.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Benchmark different matrix sizes

In [34]:
%%time
d_list = []
size_list = [10, 100, 1000]
for i in size_list:
    for j in size_list:
        a = np.arange(i*j).reshape(i, j)
        for k in size_list:
            if i >= j >= k:  # only non-increasing shapes
                b = np.arange(j*k).reshape(j, k)
                # -o returns a TimeitResult object, -q suppresses the printed summary
                t_np = %timeit -oq np.dot(a,b)
                t_plain = %timeit -oq matrix_mult(a,b)
                t_numba_jit = %timeit -oq matrix_mult_numba_jit(a,b)
                t_numba_njit = %timeit -oq matrix_mult_numba_njit(a,b)
                # .average is the mean time per loop, in seconds
                d_list.append({'t_np': t_np.average,
                               't_plain': t_plain.average,
                               't_numba_jit': t_numba_jit.average,
                               't_numba_njit': t_numba_njit.average,
                               'i': i,
                               'j': j,
                               'k': k})

CPU times: user 8min 43s, sys: 11min 12s, total: 19min 56s
Wall time: 9min 47s
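The condition `i>=j and j>=k` keeps only non-increasing shape triples. A quick sketch with `itertools.product` (same iteration order as the nested loops) lists them along with the number of multiply-adds, $I \cdot J \cdot K$, that each product needs. It yields the 10 shapes that appear as rows of the benchmark table.

```python
from itertools import product

size_list = [10, 100, 1000]
# product(..., repeat=3) iterates i outermost and k innermost, like the nested loops
shapes = [(i, j, k) for i, j, k in product(size_list, repeat=3) if i >= j >= k]

for i, j, k in shapes:
    # each of the i*k output entries needs j multiply-adds
    print(f"{i:>5} x {j:>5} x {k:>5}: {i * j * k:>13,} multiply-adds")

print(len(shapes))  # 10
```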


In [18]:
import pandas as pd

In [51]:
benchmark_df = pd.DataFrame(d_list)

In [61]:
benchmark_df['ratio_np'] = benchmark_df['t_plain'] / benchmark_df['t_np']
benchmark_df['ratio_numba_jit'] = benchmark_df['t_plain'] / benchmark_df['t_numba_jit']
benchmark_df['ratio_numba_njit'] = benchmark_df['t_plain'] / benchmark_df['t_numba_njit']
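Each ratio is `t_plain` divided by the other method's time, so a value above 1 means that method is faster than `matrix_mult` and a value below 1 means it is slower. A complementary view is throughput. The sketch below assumes the usual count of $2 \cdot I \cdot J \cdot K$ operations (one multiply and one add per term) and copies two rows of timings from the benchmark table for illustration; since the inputs are integers, it reports giga-operations per second rather than GFLOP/s.

```python
import pandas as pd

# Two rows copied from the benchmark results (times in seconds)
df = pd.DataFrame({'i': [100, 1000], 'j': [100, 1000], 'k': [100, 1000],
                   't_np': [0.000569, 1.764958],
                   't_plain': [0.036374, 12.465579]})

df['ratio_np'] = df['t_plain'] / df['t_np']
# 2*i*j*k operations per product, reported in giga-operations per second
df['gops_np'] = 2 * df['i'] * df['j'] * df['k'] / df['t_np'] / 1e9
print(df[['i', 'j', 'k', 'ratio_np', 'gops_np']])
```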


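`np.dot` gives by far the largest speedups over the uncompiled `matrix_mult` (about 6x to 340x). Note that the Numba function names are swapped relative to their decorators: `matrix_mult_numba_jit` is compiled with `@njit(parallel=True)` and `matrix_mult_numba_njit` with `@jit`. The `@njit(parallel=True)` version is actually about 2.5x to 8.5x *slower* than `matrix_mult` (`ratio_numba_jit` between 0.12 and 0.40), likely because `parallel=True` parallelizes each small array expression `(ai*bk).sum()` and the threading overhead dominates, while the loops themselves use `range` rather than `numba.prange` and are not parallelized. The `@jit` version is about 5x to 26x faster than `matrix_mult`. The speedups of both `np.dot` and the `@jit` version shrink as the inner dimension `j` grows.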

In [62]:
benchmark_df

| | i | j | k | t_np (s) | t_plain (s) | t_numba_jit (s) | t_numba_njit (s) | ratio_np | ratio_numba_jit | ratio_numba_njit |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10 | 10 | 10 | 2e-06 | 0.000351 | 0.002896 | 1.4e-05 | 172.573985 | 0.121252 | 24.338782 |
| 1 | 100 | 10 | 10 | 1.1e-05 | 0.003535 | 0.030061 | 0.000135 | 308.795863 | 0.117588 | 26.232226 |
| 2 | 100 | 100 | 10 | 5.8e-05 | 0.003716 | 0.028459 | 0.000231 | 64.034474 | 0.130574 | 16.084315 |
| 3 | 100 | 100 | 100 | 0.000569 | 0.036374 | 0.291665 | 0.002333 | 63.954127 | 0.124711 | 15.594155 |
| 4 | 1000 | 10 | 10 | 0.000103 | 0.034925 | 0.282053 | 0.001346 | 340.723281 | 0.123823 | 25.938937 |
| 5 | 1000 | 100 | 10 | 0.000572 | 0.036643 | 0.279909 | 0.002338 | 64.020513 | 0.130911 | 15.675439 |
| 6 | 1000 | 100 | 100 | 0.005677 | 0.361903 | 2.826907 | 0.023659 | 63.751701 | 0.128021 | 15.296839 |
| 7 | 1000 | 1000 | 10 | 0.008541 | 0.067103 | 0.357186 | 0.011957 | 7.856758 | 0.187864 | 5.611985 |
| 8 | 1000 | 1000 | 100 | 0.099013 | 0.627378 | 2.968083 | 0.13203 | 6.336345 | 0.211375 | 4.751782 |
| 9 | 1000 | 1000 | 1000 | 1.764958 | 12.465579 | 30.809161 | 2.460941 | 7.062821 | 0.404606 | 5.06537 |
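If the aim of `parallel=True` is to spread the work across cores, Numba expects the loop to be parallelized to be written with `numba.prange`. A minimal sketch under that assumption (the name `matrix_mult_prange` is hypothetical and was not part of the benchmarks above): the outer loop over rows runs in parallel, and the inner product is an explicit scalar loop, so no temporary arrays are allocated per output entry. The `try`/`except` fallback only exists so the sketch also runs, serially, when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs (serially) without numba
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(parallel=True)
def matrix_mult_prange(a, b):
    I, J = a.shape
    J2, K = b.shape
    if J != J2:
        raise ValueError('a and b must have compatible shapes')
    c = np.zeros((I, K))
    # prange marks the outer loop for parallel execution; each i writes its own row of c
    for i in prange(I):
        for k in range(K):
            s = 0.0
            for j in range(J):  # inner product a_i . b_k
                s += a[i, j] * b[j, k]
            c[i, k] = s
    return c


a = np.arange(100 * 101).reshape(100, 101)
b = np.arange(101 * 102).reshape(101, 102)
print(np.array_equal(matrix_mult_prange(a, b), np.dot(a, b)))
```

Whether this beats the `@jit` version depends on the core count and matrix sizes; it would need to be timed the same way as the functions above.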
