# Bini Fast Matrix Multiplication O(2.7799)

In [1]:
import numpy as np

In [2]:
from numpy import linalg as la

In [3]:
np.set_printoptions(precision=2)

In [4]:
#without dynamic peeling:
# row of A need to be divisible by 3 (power of is preferable)
# columns of A need do be divisble by 2

# rows and cols of B need to be divisible by two

In [115]:
A = np.random.rand(17, 14)
B = np.random.rand(14, 21)

In [120]:
def bini(A, B, steps, e=1e-8):
    
    #Check Dimensions
    (m, n) = A.shape
    #rn assuming that m is bigger than n, nn and p
    (nn, p) = B.shape
    if n != nn: raise ValueError("incompatible dimensions")
    
    #pre-allocate output matrix
    C = np.zeros((m,p))
    
    """
    This is the notation I use from Bini's 1980 paper

    |A1, A4|  |B1, B2|  =  |C1, C2|
    |A2, A5|  |B3, B4|     |C3, C4|
    |A3, A6|               |C5, C6|
    """
    
    #Base case
    if steps == 0 or m == 1 or n == 1 or p == 1:
        C = np.dot(A,B)
        return C
    
    #Dynamic peeling
    if m % 3 == 1:
        C[:m-1, :] = bini(A[:m-1,:],B, steps)
        C[m-1,:] = A[m-1,:]@B
        return C
    if m % 3 == 2:
        C[:m-2, :] = bini(A[:m-2,:],B, steps)
        C[m-2:,:] = A[m-2:,:]@B
        return C
    if n % 2 == 1:
        C = bini(A[:, :n-1], B[:n-1,:], steps)
        C = C + np.outer(A[:,n-1],B[n-1,:])
        return C
    if p % 2 == 1:
        C[:, :p-1] = bini(A, B[:,:p-1], steps)
        C[:,p-1] = A@B[:,p-1]
        return C
 

    # split up the matricies once rows of A are divisible by 3
    # and cols of A and rows and cols of are divisible by 2
    m2 = int(m/3) #first third of the rows of A
    m3 = m2*2     #second third of the rows of A
    n2 = int(n/2) #half of the cols of A
    p2 = int(p/2) #half of the cols of B
    #nn2 = int(nn/2) # half of the rows of B
    
    A1 = A[:m2, :n2]
    A2 = A[m2:m3, :n2]
    A3 = A[m3:, :n2]
    A4 = A[:m2, n2:]
    A5 = A[m2:m3, n2:]
    A6 = A[m3:, n2:]
    
    B1 = B[:n2, :p2]
    B2 = B[:n2, p2:]
    B3 = B[n2:, :p2]
    B4 = B[n2:, p2:]
    
    #bini(A, B, steps, e=0.1)
    # conquer
    M1 = bini(A1 + A5, e*B1 + B4, steps-1, e) 
    M2 = bini(A5, -B3-B4, steps-1, e)
    M3 = bini(A1, B4, steps-1, e)
    M4 = bini(e*A4+A5, -e*B1 + B3, steps-1, e)
    M5 = bini(A1 + e*A4, e*B2 + B4, steps-1, e)
    M6 = bini(A2 + A6, B1 + e*B4, steps-1, e)
    M7 = bini(A2, -B1 - B2, steps-1, e) #
    M8 = bini(A6, B1, steps-1, e)
    M9 = bini(A2 + e*A3, B2 - e*B4, steps-1, e)
    M10 = bini(e*A3 + A6, B1 + e*B3, steps-1, e)
    
    # put C together
    C[:m2, :p2] = (1/e)*(M1+M2-M3+M4) #C1
    C[:m2, p2:] = (1/e)*(-M3+M5)      #C2
    C[m2:m3, :p2] = M4+M6-M10         #C3 error from bini paper -M10 from +M10
    C[m2:m3, p2:] = M1-M5+M9          #C4 error from bini paper -M5 from +M5
    C[m3:, :p2] = (1/e)*(-M8+M10)     #C5
    C[m3:, p2:] = (1/e)*(M6+M7-M8+M9) #C6
    
    return C

In [153]:
C = bini(A,B, 2, e=1e-4)

In [154]:
la.norm(C-A@B, 'fro')/la.norm(C)

0.36658522301461355

In [155]:
C-A@B

array([[-5.43e+00, -4.29e+00,  1.89e+00,  3.14e+00, -6.31e+00, -4.40e-01,
        -2.57e-01,  1.06e+00,  1.36e+00,  6.60e-01,  1.57e+00, -3.56e+00,
        -8.59e-01, -3.70e+00,  3.41e+00, -2.21e-02,  6.55e-01,  2.23e+00,
         6.03e-01, -8.59e-02,  4.44e-16],
       [ 4.53e-08,  1.59e-08,  4.20e-08,  4.23e-09,  6.38e-09,  5.24e-08,
         1.00e-08, -9.00e-09, -1.17e-07, -5.09e-08,  2.53e-08,  9.84e-09,
         1.78e-08,  4.57e-09, -1.80e-08,  1.38e-08, -4.33e-08,  1.05e-08,
        -3.75e-08,  2.87e-08,  0.00e+00],
       [ 3.15e+00,  9.44e-01, -2.43e+00, -8.31e-01, -3.90e-01,  5.65e+00,
         4.87e+00, -1.43e+00, -6.36e+00,  7.14e-03, -1.48e-01,  2.26e+00,
        -2.11e+00, -8.72e-02,  1.31e+00,  2.50e+00, -6.92e-01, -5.19e+00,
        -7.52e-01, -2.32e+00, -4.44e-16],
       [ 5.01e-08,  3.97e-08,  3.54e-08,  1.11e-08, -2.34e-08,  2.99e-08,
        -6.44e-09, -1.30e-08,  4.39e-09,  3.10e-08,  3.74e-08, -1.38e-08,
         1.15e-08,  2.14e-08,  1.99e-09,  7.80e-09,  3.06e-0

In [145]:
C

array([[2.31, 2.63, 2.  , 2.83, 1.36, 2.09, 2.66, 2.82, 3.15, 3.28, 3.11,
        2.35, 3.71, 2.16, 2.77, 3.43, 3.34, 2.58, 2.97, 3.64, 2.21],
       [2.02, 2.43, 2.01, 2.67, 1.72, 1.81, 2.53, 2.73, 2.5 , 2.72, 2.48,
        1.61, 3.09, 1.98, 2.61, 2.82, 2.98, 2.22, 2.75, 3.03, 2.21],
       [3.45, 4.26, 3.11, 4.38, 2.57, 3.64, 4.25, 4.45, 4.3 , 4.69, 4.97,
        2.76, 5.22, 3.64, 3.85, 4.9 , 4.92, 3.46, 4.41, 5.47, 3.36],
       [2.22, 2.53, 1.82, 2.04, 2.01, 1.98, 2.62, 2.6 , 2.95, 2.13, 2.64,
        1.3 , 3.09, 2.4 , 2.2 , 2.69, 2.54, 1.76, 2.47, 3.3 , 2.39],
       [2.69, 3.58, 2.3 , 3.85, 2.6 , 3.16, 4.16, 4.33, 4.18, 3.46, 4.16,
        2.86, 4.96, 3.15, 3.29, 4.38, 4.21, 3.41, 4.22, 4.66, 3.25],
       [1.95, 1.97, 1.56, 2.36, 1.65, 1.68, 2.49, 2.14, 2.48, 2.37, 2.6 ,
        1.39, 2.96, 2.27, 1.93, 2.18, 2.52, 2.11, 2.79, 2.79, 1.92],
       [3.12, 4.03, 2.75, 3.83, 2.92, 2.93, 4.  , 4.68, 3.53, 3.83, 3.74,
        2.6 , 4.85, 3.69, 3.01, 4.35, 4.04, 2.86, 3.97, 4.63, 3.47],

In [132]:
np.dot(A,B)

array([[2.31, 2.63, 2.  , 2.83, 1.36, 2.09, 2.66, 2.82, 3.15, 3.28, 3.11,
        2.35, 3.71, 2.16, 2.77, 3.43, 3.34, 2.58, 2.97, 3.64, 2.21],
       [2.02, 2.43, 2.01, 2.67, 1.72, 1.81, 2.53, 2.73, 2.5 , 2.72, 2.48,
        1.61, 3.09, 1.98, 2.61, 2.82, 2.98, 2.22, 2.75, 3.03, 2.21],
       [3.45, 4.26, 3.11, 4.38, 2.57, 3.64, 4.25, 4.45, 4.3 , 4.69, 4.97,
        2.76, 5.22, 3.64, 3.85, 4.9 , 4.92, 3.46, 4.41, 5.47, 3.36],
       [2.22, 2.53, 1.82, 2.04, 2.01, 1.98, 2.62, 2.6 , 2.95, 2.13, 2.64,
        1.3 , 3.09, 2.4 , 2.2 , 2.69, 2.54, 1.76, 2.47, 3.3 , 2.39],
       [2.69, 3.58, 2.3 , 3.85, 2.6 , 3.16, 4.16, 4.33, 4.18, 3.46, 4.16,
        2.86, 4.96, 3.15, 3.29, 4.38, 4.21, 3.41, 4.22, 4.66, 3.25],
       [1.95, 1.97, 1.56, 2.36, 1.65, 1.68, 2.49, 2.14, 2.48, 2.37, 2.6 ,
        1.39, 2.96, 2.27, 1.93, 2.18, 2.52, 2.11, 2.79, 2.79, 1.92],
       [3.12, 4.03, 2.75, 3.83, 2.92, 2.93, 4.  , 4.68, 3.53, 3.83, 3.74,
        2.6 , 4.85, 3.69, 3.01, 4.35, 4.04, 2.86, 3.97, 4.63, 3.47],

In [119]:
int(5/3)

1