In [1]:
import tensorflow as tf
import numpy as np
from numpy import linalg as la

  from ._conv import register_converters as _register_converters


In [2]:
# Start from an empty default graph so that re-running the notebook does not
# keep adding ops to a graph left over from an earlier run.
tf.reset_default_graph()
# Fix TensorFlow's graph-level random seed to 25.
# NOTE(review): this seeds TensorFlow random ops only; np.random is unaffected.
tf.set_random_seed(25)

In [None]:
def bini(A, B, steps, e=1e-8):
    """Approximate the matrix product ``A @ B`` with Bini's bilinear scheme.

    One recursion level splits ``A`` into 3x2 blocks and ``B`` into 2x2
    blocks and forms the six output blocks from 10 block products instead
    of the 12 a plain block multiplication needs.  Block notation (taken
    from Bini's 1980 paper)::

        |A1, A4|  |B1, B2|  =  |C1, C2|
        |A2, A5|  |B3, B4|     |C3, C4|
        |A3, A6|               |C5, C6|

    The result is approximate: every output block carries an error term
    proportional to ``e``, while the ``1/e`` scaling of four of the blocks
    amplifies floating-point rounding, so ``e`` trades one against the
    other.

    Static peeling is used: rows/columns that do not fit a multiple of
    ``3**steps`` (rows of ``A``) or ``2**steps`` (columns of ``A`` / rows
    and columns of ``B``) are split off once and multiplied with
    ``tf.matmul``.  Dynamic peeling (peeling at every recursion level) is
    deliberately not used: the ideal ``e`` depends on the matrix shape, and
    dynamic peeling changes that shape at every recursive step.

    Args:
        A: 2-D float tensor (or array-like convertible with
            ``tf.convert_to_tensor``) with statically known shape ``(m, n)``.
        B: 2-D float tensor (or array-like) with statically known shape
            ``(n, p)``.
        steps: number of recursion levels; ``0`` means plain ``tf.matmul``.
        e: the non-zero approximation parameter epsilon.

    Returns:
        A tensor of shape ``(m, p)`` approximating ``A @ B``.

    Raises:
        ValueError: if the inner dimensions differ, if ``steps`` is
            negative, or if the matrices are too small for ``steps``
            levels of static peeling.
    """
    # Accept NumPy arrays / nested lists as well; tensors pass through as-is.
    A = tf.convert_to_tensor(A)
    B = tf.convert_to_tensor(B)

    # Check dimensions
    (m, n) = A.get_shape().as_list()
    (nn, p) = B.get_shape().as_list()
    if n != nn:
        raise ValueError("incompatible dimensions")
    if steps < 0:
        raise ValueError("steps must be non-negative")

    # Base case: nothing left to recurse on
    if steps == 0 or min(m, n, p) <= 1:
        return tf.matmul(A, B)

    # Static peeling
    row_block = 3**steps
    col_block = 2**steps
    if (row_block > m) or (col_block > n) or (col_block > p):
        raise ValueError("Too many steps/ too small matricies for static peeling")

    # Tensors do not support slice assignment, so every peeled result is
    # assembled with tf.concat instead of being written into a buffer.
    extra_rows = m % row_block
    if extra_rows != 0:
        kept = m - extra_rows
        C_main = bini(A[:kept, :], B, steps, e)
        # Slicing with a range keeps the leftover rows 2-D, even if it is one row
        C_rows = tf.matmul(A[kept:, :], B)
        return tf.concat([C_main, C_rows], 0)

    extra_inner = n % col_block
    if extra_inner != 0:
        kept = n - extra_inner
        C_main = bini(A[:, :kept], B[:kept, :], steps, e)
        # The peeled inner slices contribute an additive correction
        return C_main + tf.matmul(A[:, kept:], B[kept:, :])

    extra_cols = p % col_block
    if extra_cols != 0:
        kept = p - extra_cols
        C_main = bini(A, B[:, :kept], steps, e)
        C_cols = tf.matmul(A, B[:, kept:])
        return tf.concat([C_main, C_cols], 1)

    # Split once the rows of A are divisible by 3 and the columns of A and
    # the rows and columns of B are divisible by 2.
    m2 = m // 3   # end of the first third of the rows of A
    m3 = 2 * m2   # end of the second third of the rows of A
    n2 = n // 2   # half of the cols of A (= half of the rows of B)
    p2 = p // 2   # half of the cols of B

    A1 = A[:m2, :n2]
    A2 = A[m2:m3, :n2]
    A3 = A[m3:, :n2]
    A4 = A[:m2, n2:]
    A5 = A[m2:m3, n2:]
    A6 = A[m3:, n2:]

    B1 = B[:n2, :p2]
    B2 = B[:n2, p2:]
    B3 = B[n2:, :p2]
    B4 = B[n2:, p2:]

    # Conquer: the 10 recursive block products
    M1 = bini(A1 + A5, e*B1 + B4, steps-1, e)
    M2 = bini(A5, -B3 - B4, steps-1, e)
    M3 = bini(A1, B4, steps-1, e)
    M4 = bini(e*A4 + A5, -e*B1 + B3, steps-1, e)
    M5 = bini(A1 + e*A4, e*B2 + B4, steps-1, e)
    M6 = bini(A2 + A6, B1 + e*B4, steps-1, e)
    M7 = bini(A2, -B1 - B2, steps-1, e)
    M8 = bini(A6, B1, steps-1, e)
    M9 = bini(A2 + e*A3, B2 - e*B4, steps-1, e)
    M10 = bini(e*A3 + A6, B1 + e*B3, steps-1, e)

    # Combine.  In C1, C2, C5 and C6 the O(1) terms cancel, so the sum is
    # rescaled by 1/e; C3 and C4 are used unscaled.
    inv_e = 1.0 / e
    C1 = inv_e * (M1 + M2 - M3 + M4)
    C2 = inv_e * (-M3 + M5)
    C3 = M4 + M6 - M10   # error from bini paper -M10 from +M10
    C4 = M1 - M5 + M9    # error from bini paper -M5 from +M5
    C5 = inv_e * (-M8 + M10)
    C6 = inv_e * (M6 + M7 - M8 + M9)

    return tf.concat(
        [
            tf.concat([C1, C2], 1),
            tf.concat([C3, C4], 1),
            tf.concat([C5, C6], 1),
        ],
        0,
    )

In [14]:
def strass(A, B, steps):
    """Multiply two matrices with Strassen's algorithm as TensorFlow ops.

    Each recursion level splits both operands into 2x2 blocks and forms the
    product from 7 recursive block products.  Odd dimensions are handled by
    dynamic peeling: the last row of ``A``, the last inner column/row pair,
    or the last column of ``B`` is split off and multiplied with
    ``tf.matmul`` before recursing on the even-sized remainder.

    Args:
        A: 2-D tensor (or array-like convertible with
            ``tf.convert_to_tensor``) with statically known shape ``(m, n)``.
        B: 2-D tensor (or array-like) with statically known shape ``(n, p)``.
        steps: number of recursion levels; at ``0`` the product is computed
            with ``tf.matmul``.

    Returns:
        A tensor of shape ``(m, p)`` holding ``A @ B``.

    Raises:
        ValueError: if the inner dimensions of ``A`` and ``B`` differ.
    """
    # Accept NumPy arrays / nested lists as well; tensors pass through as-is.
    A = tf.convert_to_tensor(A)
    B = tf.convert_to_tensor(B)

    # Check dimensions
    (m, n) = A.get_shape().as_list()
    (nn, p) = B.get_shape().as_list()
    if n != nn:
        raise ValueError("incompatible dimensions")

    # Base case: out of steps, or a dimension too small to halve
    if steps == 0 or min(m, n, p) <= 1:
        return tf.matmul(A, B)

    # Dynamic peeling.  Range slices (m-1:) keep the peeled part 2-D.
    if m % 2 == 1:
        C_top = strass(A[:m-1, :], B, steps)
        C_row = tf.matmul(A[m-1:, :], B)
        return tf.concat([C_top, C_row], 0)
    if n % 2 == 1:
        C_main = strass(A[:, :n-1], B[:n-1, :], steps)
        # Outer product of the last column of A with the last row of B
        return C_main + tf.matmul(A[:, n-1:], B[n-1:, :])
    if p % 2 == 1:
        C_left = strass(A, B[:, :p-1], steps)
        C_col = tf.matmul(A, B[:, p-1:])
        return tf.concat([C_left, C_col], 1)

    # Divide: m, n and p are all even here
    m2 = m // 2
    n2 = n // 2
    p2 = p // 2
    A11 = A[:m2, :n2]
    A12 = A[:m2, n2:]
    A21 = A[m2:, :n2]
    A22 = A[m2:, n2:]
    B11 = B[:n2, :p2]
    B12 = B[:n2, p2:]
    B21 = B[n2:, :p2]
    B22 = B[n2:, p2:]

    # Conquer: the 7 recursive block products
    M1 = strass(A11, B12 - B22, steps-1)
    M2 = strass(A11 + A12, B22, steps-1)
    M3 = strass(A21 + A22, B11, steps-1)
    M4 = strass(A22, B21 - B11, steps-1)
    M5 = strass(A11 + A22, B11 + B22, steps-1)
    M6 = strass(A12 - A22, B21 + B22, steps-1)
    M7 = strass(A11 - A21, B11 + B12, steps-1)

    # Combine the products into the four quadrants of C
    C11 = M5 + M4 - M2 + M6
    C12 = M1 + M2
    C21 = M3 + M4
    C22 = M1 + M5 - M3 - M7

    upper = tf.concat([C11, C12], 1)
    lower = tf.concat([C21, C22], 1)
    return tf.concat([upper, lower], 0)

In [16]:
# Sanity check: one Strassen step against tf.matmul on a random 4x4 matrix
# multiplied by the identity (so both products should reproduce A).
with tf.Session() as sess:
    # Local seeded generator: the printed numbers are the same on every run
    # and NumPy's global random state is left untouched.
    rng = np.random.RandomState(25)

    a = tf.constant(rng.rand(4, 4), dtype=tf.float64)
    w = tf.constant(np.identity(4))

    # Evaluate the input, the reference product and the Strassen product
    # in a single run call.
    a_val, m, s = sess.run([a, tf.matmul(a, w), strass(a, w, 1)])

    print("A: \n", a_val, '\n')
    print("matmul: \n", m)
    print("Strass: \n", s)
    print("\n m-s: \n", m-s, '\n')
    # Relative Frobenius error, normalised by the reference result (m)
    # rather than by the value under test.
    print("Strassen Error: ", la.norm(s-m, 'fro')/la.norm(m, 'fro'))


[[0.30192839 0.40296055 0.15443776 0.06376608]
 [0.68119201 0.2294643  0.49874755 0.86329657]]
A: 
 [[0.30192839 0.40296055 0.15443776 0.06376608]
 [0.68119201 0.2294643  0.49874755 0.86329657]
 [0.51075172 0.21686642 0.33115609 0.26726707]
 [0.92909124 0.42198669 0.16615135 0.41729923]] 

matmul: 
 [[0.30192839 0.40296055 0.15443776 0.06376608]
 [0.68119201 0.2294643  0.49874755 0.86329657]
 [0.51075172 0.21686642 0.33115609 0.26726707]
 [0.92909124 0.42198669 0.16615135 0.41729923]]
Strass: 
 [[0.30192839 0.40296055 0.15443776 0.06376608]
 [0.68119201 0.2294643  0.49874755 0.86329657]
 [0.51075172 0.21686642 0.33115609 0.26726707]
 [0.92909124 0.42198669 0.16615135 0.41729923]]

 m-s: 
 [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  1.11022302e-16  0.00000000e+00 -1.11022302e-16]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 1.11022302e-16  0.00000000e+00 -2.22044605e-16  0.00000000e+00]] 

Strassen Error:  1.564695914969

In [4]:
# Two-step Strassen product of a random 4x4 matrix with the identity.
# The session is opened here because a session from an earlier `with` block
# is already closed once that block exits.
with tf.Session() as sess:
    a = np.random.rand(4,4)
    w = np.identity(4)

    # strass builds graph ops from tensors; sess.run turns the result back
    # into a NumPy array, leaving a, w and s as plain arrays afterwards.
    s = sess.run(strass(tf.constant(a, dtype=tf.float64), tf.constant(w), 2))

In [5]:
# Display s, the two-step Strassen result assigned in the cell above.
s

array([[0.914128  , 0.04061079, 0.37017392, 0.2666174 ],
       [0.68150857, 0.72203561, 0.02497239, 0.55707602],
       [0.84485791, 0.30971141, 0.5360951 , 0.19091652],
       [0.92230492, 0.078655  , 0.69188986, 0.423988  ]])