### Final Implementation of TT-cross

A function for truncating the SVD is implemented, but it does not work for full-rank matrices: **truncated_SVD** fails when the rank has to be truncated to a given accuracy **eps**.

In [6]:
import numpy as np
import numpy.linalg as la
from tt import *
import maxvolpy as mv

In [7]:
# Hand-written integer test tensor of shape (3, 4, 3, 3): three blocks A[i],
# each holding four 3x3 matrices A[i, j].
A = np.array([
    [  # A[0]
        [[1, 8, 3], [1, 2, 4], [1, 3, 5]],
        [[1, 8, 3], [1, 2, 4], [1, 7, 5]],
        [[1, 2, 9], [1, 5, 4], [1, 3, 5]],
        [[8, 2, 3], [1, 4, 4], [1, 3, 5]],
    ],
    [  # A[1]
        [[1, 8, 3], [1, 2, 4], [1, 3, 5]],
        [[1, 8, 3], [1, 2, 4], [1, 3, 9]],
        [[1, 2, 3], [1, 5, 4], [1, 3, 5]],
        [[6, 2, 3], [1, 4, 1], [1, 3, 5]],
    ],
    [  # A[2]
        [[1, 8, 3], [1, 2, 2], [1, 3, 5]],
        [[1, 8, 3], [1, 2, 6], [1, 3, 5]],
        [[1, 2, 3], [1, 5, 4], [1, 3, 5]],
        [[1, 2, 3], [1, 4, 4], [1, 3, 5]],
    ],
])

# Random test tensors with entries in [0, 1). They are drawn from a seeded
# generator so that Restart & Run All reproduces exactly the same data.
rng = np.random.RandomState(0)
B = rng.rand(10, 20, 13, 5, 3)
C = rng.rand(4, 4, 4, 4)

In [20]:
def _truncation_rank(s, delta, zero_tol=1e-14):
    """Number of leading singular values to keep in one TT-SVD step.

    Chooses the smallest rank ``r`` whose discarded tail satisfies
    ``sqrt(sum(s[r:] ** 2)) <= delta``. Singular values ``<= zero_tol`` are
    treated as numerical zeros and never kept. At least one value is always
    kept, so every core has a non-zero size.

    Parameters
    ----------
    s : 1-D array of singular values in non-increasing order, as returned by
        ``numpy.linalg.svd``.
    delta : allowed Frobenius norm of the discarded part of this unfolding.
    zero_tol : absolute threshold for numerically-zero singular values.

    Returns
    -------
    int
        The rank ``r >= 1``.
    """
    s = np.asarray(s, dtype=float)
    # tail_sq[r] == sum(s[r:] ** 2). The appended 0 is the r == len(s) case,
    # so the budget test below is satisfied at some index.
    tail_sq = np.append(np.cumsum(s[::-1] ** 2)[::-1], 0.0)
    r_eps = int(np.argmax(tail_sq <= delta ** 2))
    r_nnz = int(np.count_nonzero(s > zero_tol))
    return max(1, min(r_eps, r_nnz))


def _tt_svd(A, eps):
    """TT-SVD: successive truncated SVDs of the tensor unfoldings.

    Returns the list of cores ``G[k]`` of shape ``(r_{k-1}, n_k, r_k)`` with
    boundary ranks 1. Each of the ``d - 1`` truncations discards at most
    ``eps * ||A||_F / sqrt(d - 1)`` in Frobenius norm, so the train is within
    ``eps * ||A||_F`` of ``A`` (up to rounding).

    Raises
    ------
    ValueError
        If ``A`` is 0-dimensional.
    """
    n = A.shape
    d = len(n)
    if d == 0:
        raise ValueError("TT needs an array with at least one dimension")
    # A 1-D array is its own single core, so no truncation budget is needed.
    delta = eps * la.norm(A) / np.sqrt(d - 1) if d > 1 else 0.0

    cores = []
    r_prev = 1
    M = A
    for k in range(d - 1):
        # Unfolding: (previous rank * current mode) x (all remaining modes).
        M = M.reshape(r_prev * n[k], int(np.prod(n[k + 1:])))
        u, s, vh = la.svd(M, full_matrices=False)
        r = _truncation_rank(s, delta)
        cores.append(u[:, :r].reshape(r_prev, n[k], r))
        # Carry S @ Vh (truncated to rank r) forward; it is re-unfolded next step.
        M = s[:r, None] * vh[:r, :]
        r_prev = r
    cores.append(M.reshape(r_prev, n[-1], 1))
    return cores


def _tt_full(cores, shape):
    """Contract a list of TT cores back into a full array of the given shape."""
    full = cores[0]
    for core in cores[1:]:
        # Contract the trailing rank index with the leading rank index.
        full = np.tensordot(full, core, axes=1)
    return full.reshape(shape)


def _check_tt(A, cores, eps):
    """Print whether the TT cores reproduce ``A`` to relative accuracy ``eps``.

    Returns True when ``||A - TT||_F <= eps * ||A||_F``. The check allows a
    tiny slack for floating-point rounding, including the ``A == 0`` case.
    """
    nrm = la.norm(A)
    err = la.norm(_tt_full(cores, A.shape) - A)
    ok = err <= eps * nrm + 1e-10 * max(nrm, 1.0)
    if ok:
        print('\t testing: ok')
    else:
        print('\t testing: failed')
    return ok


def _skeleton(M, r, delta, max_iter=100):
    """Alternating maxvol cross (skeleton) approximation of a matrix.

    This is work in progress for the ``'maxvol'`` branch of ``TT``; nothing
    calls it yet. It looks for ``r`` row indices ``I`` and ``r`` column
    indices ``J`` such that ``M ~ M[:, J] @ inv(M[I, J]) @ M[I, :]``.

    Parameters
    ----------
    M : 2-D array with at least ``r`` columns.
    r : cross rank.
    delta : stop once two successive approximations differ by at most
        ``delta * ||M_next||_F``.
    max_iter : maximum number of alternating sweeps.

    Returns
    -------
    (C, M_hat_inv, R) = (M[:, J], inv(M[I, J]), M[I, :])
        ``C @ M_hat_inv @ R`` is the approximation.
    """
    # NOTE(review): ``maxvol`` is not bound by an explicit import in this
    # notebook (only ``maxvolpy as mv`` and ``from tt import *``). It is
    # assumed to return the r pivot row indices of a tall n x r matrix.
    # maxvolpy's ``maxvol`` returns a (pivots, coefficients) tuple instead,
    # so confirm which package provides it.
    assert M.shape[1] >= r
    J = np.arange(r)
    M_prev = np.zeros(M.shape)
    for _sweep in range(max_iter):
        Q, _ = la.qr(M[:, J])
        I = maxvol.maxvol(Q)
        Q, _ = la.qr(M[I, :].T)
        J = maxvol.maxvol(Q)
        Q_hat = Q[J, :]
        # With M[I, :].T = Q T we get inv(M[I, J]) @ M[I, :] = inv(Q_hat.T) @ Q.T.
        # A linear solve is used instead of forming the inverse explicitly.
        M_next = M[:, J] @ la.solve(Q_hat.T, Q.T)
        converged = la.norm(M_next - M_prev) <= delta * la.norm(M_next)
        M_prev = M_next
        if converged:
            break
    return M[:, J], la.inv(M[np.ix_(I, J)]), M[I, :]


def TT(A, eps=1e-3, method='maxvol', test=False, verbose=True):
    """Decompose a d-dimensional array into tensor-train (TT) cores.

    Parameters
    ----------
    A : array_like with d >= 1 dimensions, shape ``(n_0, ..., n_{d-1})``.
    eps : relative accuracy. For ``method='SVD'`` the result satisfies
        ``||A - TT(A)||_F <= eps * ||A||_F`` (up to rounding).
    method : ``'SVD'`` runs TT-SVD. ``'maxvol'`` (TT-cross) is not implemented
        yet and returns None; any other value also returns None.
    test : if true, rebuild the full tensor from the cores and print whether
        the relative error is within ``eps``.
    verbose : if true, print "CARRIAGES" followed by the shape of every core.

    Returns
    -------
    list of numpy.ndarray or None
        For ``'SVD'``: cores ``G[k]`` of shape ``(r_{k-1}, n_k, r_k)`` with
        ``r_{-1} = r_{d-1} = 1``.
    """
    A = np.asarray(A)
    if method == 'SVD':
        G = _tt_svd(A, eps)
        if verbose:
            print("CARRIAGES")
            for core in G:
                print(core.shape)
        if test:
            _check_tt(A, G, eps)
        return G
    # 'maxvol' (TT-cross built on _skeleton) is still a stub, and other method
    # names are unsupported: there are no cores to return or test.
    return None

In [21]:
# TT-SVD of the hand-written tensor A with accuracy eps=1e-3; TT prints the
# shape of every core.
G = TT(A, method='SVD', eps=1e-3, test=True)

CARRIAGES
(1, 3, 3)
(3, 4, 8)
(8, 3, 3)
(3, 3, 1)


In [None]:
# Self-contained check of the "solve instead of invert" idea used in the
# skeleton (cross) update: la.solve(Q_hat.T, Q.T).T equals Q @ inv(Q_hat),
# without forming the inverse explicitly.
Q_demo, R_demo = la.qr(np.random.RandomState(0).standard_normal((6, 3)))
Q_hat_demo = Q_demo[:3, :]
solve_matches_inverse = np.allclose(la.solve(Q_hat_demo.T, Q_demo.T).T,
                                    Q_demo @ la.inv(Q_hat_demo))
solve_matches_inverse