In [102]:
# probabilistic matrix multiply
import numpy as np

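From https://arxiv.org/pdf/1712.08880.pdf — approximate the product $AB$ using only $c$ of the $n$ column-row pairs, where $A_{*k}$ is column $k$ of $A$, $B_{k*}$ is row $k$ of $B$, and $i_t$ is the index chosen at step $t$: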

$$ prob_{k} = \frac{\|A_{*k}\|_{2} \cdot \|B_{k*}\|_{2}}{\sum_{k'=1}^{n} \|A_{*k'}\|_{2} \cdot \|B_{k'*}\|_{2}} $$

$$ C_{*t} = \frac{1}{\sqrt{c \cdot prob_{i_t}}} A_{*i_t}$$
$$ R_{t*} = \frac{1}{\sqrt{c \cdot prob_{i_t}}} B_{i_t*}$$

$$ CR = \sum_{t=1}^{c} \frac{1}{c \cdot prob_{i_t}} A_{*i_t} B_{i_t*} = 
 \frac{1}{c} \sum_{t=1}^{c} \frac{1}{prob_{i_t}} A_{*i_t} B_{i_t*}$$

In [103]:
# Reproducible random operands for the experiment.
#   A: m rows x n cols
#   B: n rows x p cols   (so A @ B is m x p)
np.random.seed(100)
m = 50
n = 100
p = 80
# A is drawn before B so the values match the seeded stream order.
A, B = np.random.rand(m, n), np.random.rand(n, p)

In [104]:
def norm(X, axis=0):
    """Return the Euclidean (L2) norm of ``X`` reduced along ``axis``.

    For a 2-D array, ``axis=0`` yields one norm per column and ``axis=1``
    yields one norm per row.
    """
    return np.sqrt(np.sum(X ** 2, axis=axis))


def approx_matmul(A, B, factor=1, rng=None):
    """Approximate ``A @ B`` from a subset of column-row outer products.

    Each inner index ``k`` gets the weight
    ``prob[k] = ||A[:, k]|| * ||B[k, :]|| / sum_k' ||A[:, k']|| * ||B[k', :]||``.
    ``c`` indices are then chosen and the result is
    ``sum_t A[:, i_t] B[i_t, :] / (c * prob[i_t])``.

    Parameters
    ----------
    A : array_like, shape (m, n)
    B : array_like, shape (n, p)
    factor : float, default 1
        Fraction of the ``n`` inner indices to use. ``c = int(n * factor)``,
        clamped to the range ``[1, n]``.
    rng : optional
        Object with a NumPy-style ``choice(n, size=, replace=, p=)`` method,
        e.g. ``np.random.default_rng(seed)``. When given, the ``c`` indices
        are drawn i.i.d. (with replacement) according to ``prob``, which makes
        the ``1 / (c * prob)`` scaling an unbiased estimator of ``A @ B``.
        When ``None`` (default) the ``c`` highest-probability indices are used
        deterministically; with that selection the scaled sum is generally
        biased (it equals ``A @ B`` at ``factor=1`` only if ``prob`` is
        uniform).

    Returns
    -------
    numpy.ndarray, shape (m, p)

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ, or ``factor <= 0``.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    m, n = A.shape
    n_rows_b, p = B.shape
    if n != n_rows_b:
        raise ValueError(
            "inner dimensions do not match: A is %s, B is %s" % (A.shape, B.shape)
        )
    if factor <= 0:
        raise ValueError("factor must be positive, got %r" % (factor,))

    weight = norm(A, axis=0) * norm(B, axis=1)
    total = np.sum(weight)
    if total == 0:
        # Every column of A or matching row of B is zero (or n == 0), so the
        # exact product is the zero matrix; avoid dividing by a zero total.
        return np.zeros((m, p))

    prob = weight / total
    np.testing.assert_almost_equal(np.sum(prob), 1, 5)  # probability should sum to 1, within floating point error

    # Clamp so that a tiny factor still keeps one index and factor > 1 cannot
    # request more indices than exist.
    c = min(max(int(n * factor), 1), n)

    if rng is None:
        # argsort is ascending, so reverse before taking the c largest.
        picked = np.flip(np.argsort(prob))[:c]
    else:
        picked = rng.choice(n, size=c, replace=True, p=prob)

    # A zero probability means the column of A or the row of B is all zeros,
    # so that outer product is zero; dropping it avoids 1/0 -> inf * 0 = NaN.
    picked = picked[prob[picked] > 0]

    # Scale the selected columns once, then a single matrix product performs
    # the whole sum of rescaled outer products.
    scaled_cols = A[:, picked] / (c * prob[picked])
    return scaled_cols @ B[picked, :]

In [105]:
# Approximate A @ B with factor=0.8, i.e. using int(0.8 * n) of the n inner indices.
approx_matmul(A, B, factor=0.8)

array([[20.69161515, 25.82614528, 22.3773683 , ..., 24.59434283,
        20.86593701, 23.38318224],
       [27.30487591, 31.58371029, 25.7934565 , ..., 28.63853213,
        26.74889688, 26.87976057],
       [26.48694919, 32.76078426, 26.77218696, ..., 27.8878618 ,
        28.16856331, 26.69426382],
       ...,
       [24.46240435, 29.38974387, 25.45994266, ..., 24.85479098,
        26.06522319, 26.53235013],
       [22.1735354 , 25.83860564, 21.00780044, ..., 23.24765748,
        21.41353967, 23.25660577],
       [22.69847497, 27.78549586, 21.97999475, ..., 23.6478906 ,
        23.56302511, 25.57350114]])

In [109]:
# How did we do? Mean squared error between the exact product and the
# approximation computed with factor=0.7.
mse = np.square(A @ B - approx_matmul(A, B, 0.7)).mean()
mse

1.991670784133436