In [1]:
import numpy as np
import time
from scipy.sparse.linalg import LinearOperator
import scipy.sparse.linalg as spsalg
import backend.numpy_ext as tenpy
import CPD.standard_ALS3 as stnd_ALS
from CPD.common_kernels import solve_sys, compute_lin_sys
import CPD.common_kernels as ck
import matplotlib.pyplot as plt

In [7]:
from sympy.utilities.iterables import multiset_permutations

def measure_dist(a,b,c,A,B,C):
    """
    Compare estimated CP factors (A, B, C) with reference factors (a, b, c).

    All six matrices are scaled to unit column norm.  The estimated factors
    are additionally passed through ``flipsign``.  Every permutation of the
    columns is tried and the one minimising ``||A[:, perm] - a||`` is kept;
    that same permutation is then applied to ``B`` and ``C``.

    Parameters
    ----------
    a, b, c : ndarray
        Reference factor matrices, one column per rank-one component.
    A, B, C : ndarray
        Estimated factor matrices with the same number of columns.  They are
        not modified: ``flipsign`` works in place, so it is given copies.

    Returns
    -------
    tuple
        ``(norm, norm2, norm3)``: the norms of ``A[:, perm] - a``,
        ``B[:, perm] - b`` and ``C[:, perm] - c`` for the selected ``perm``.

    Notes
    -----
    The search enumerates all ``R!`` column permutations, so it is only
    practical for small ranks.
    """
    # flipsign negates columns in place; copy first so the caller's factors
    # (and therefore the decomposition they represent) stay untouched.
    A = flipsign(tenpy,A.copy())
    A = A/np.linalg.norm(A,axis=0)
    B = flipsign(tenpy,B.copy())
    B = B/np.linalg.norm(B,axis=0)
    C = flipsign(tenpy,C.copy())
    C = C/np.linalg.norm(C,axis=0)

    nums = np.arange(a.shape[1])

    a = a/np.linalg.norm(a,axis=0)
    b = b/np.linalg.norm(b,axis=0)
    c = c/np.linalg.norm(c,axis=0)

    # Track the best permutation directly.  Collecting [perm, norm] pairs and
    # calling np.array on them builds a ragged array, which recent NumPy
    # releases reject with a ValueError.
    best_perm = None
    best_norm = None
    for perm in multiset_permutations(nums):
        norm = np.linalg.norm(A[:,perm]-a)
        # strict '<' keeps the first minimiser, like np.argmin would
        if best_norm is None or norm < best_norm:
            best_perm = perm
            best_norm = norm

    norm = best_norm
    norm2 = np.linalg.norm(B[:,best_perm] - b)
    norm3 = np.linalg.norm(C[:,best_perm] - c)

    return(norm,norm2,norm3)

def equilibrate(A,B,C):
    """
    Rescale the columns of three factor matrices so that, column by column,
    all three end up with the same norm.

    For column ``r`` the common norm is the geometric mean of the three
    original column norms, so the product of the three norms is unchanged.

    Returns
    -------
    list
        ``[A, B, C]`` rescaled; the inputs are not modified.
    """
    factors = (A, B, C)
    col_norms = [np.linalg.norm(factor, axis=0) for factor in factors]
    # geometric mean of the three column norms
    delta = (col_norms[0]*col_norms[1]*col_norms[2])**(1/3)
    return [factor*delta/nrm for factor, nrm in zip(factors, col_norms)]

def flipsign(tenpy, U):
    """
    Fix the sign of every column of ``U`` in place.

    For each column the row returned by ``tenpy.argmax(U, axis=0)`` is
    looked up; if that entry is negative the whole column is negated.
    The intent is that the largest-magnitude entry of each column ends up
    positive -- this presumes the backend's ``argmax`` ranks entries by
    absolute value (TODO confirm against the backend).

    Returns the same object ``U`` that was passed in.
    """
    pivot_rows = tenpy.argmax(U, axis=0)
    n_cols = U.shape[1]
    for col in range(n_cols):
        pivot_is_negative = U[int(pivot_rows[col]), col] < 0
        if not pivot_is_negative:
            continue
        U[:, col] = -U[:, col]
    return U

In [20]:
def compute_Jxxdel(X,loc,delta):
    """
    Apply the diagonal block of the Gauss-Newton matrix belonging to factor
    ``loc`` to ``delta``.

    The block is the Hadamard product of the Gram matrices ``X[i].T @ X[i]``
    of every factor except ``loc``; the result is ``delta`` times that
    product.

    Parameters
    ----------
    X : sequence of ndarray
        Factor matrices, each with ``R`` columns.
    loc : int
        Index of the factor the block belongs to.
    delta : ndarray
        Matrix with ``R`` columns to multiply.
    """
    rank = X[loc].shape[1]
    hadamard = np.ones((rank, rank))
    for mode, factor in enumerate(X):
        if mode == loc:
            continue
        gram = np.einsum('kr,kz->rz', factor, factor)
        hadamard = gram * hadamard
    return np.einsum('iz,zr->ir', delta, hadamard)


def compute_Jxydel(X,loc1,loc2,delta):
    """
    Apply the off-diagonal Gauss-Newton block coupling factors ``loc1`` and
    ``loc2`` to ``delta`` (a perturbation of factor ``loc2``).

    ``D`` is the Hadamard product of the Gram matrices of all factors other
    than ``loc1`` and ``loc2``.  The result has the shape of ``X[loc1]`` and
    equals ``sum_z X[loc1][i,z] * D[z,r] * (X[loc2].T @ delta)[r,z]``.
    """
    rank = X[loc1].shape[1]
    hadamard = np.ones((rank, rank))
    skipped = (loc1, loc2)
    for mode, factor in enumerate(X):
        if mode in skipped:
            continue
        gram = np.einsum('kr,kz->rz', factor, factor)
        hadamard = gram * hadamard

    # cross term between the loc2 factor and its perturbation
    cross = np.einsum("jr,jz->rz", X[loc2], delta)
    return np.einsum("iz,zr,rz->ir", X[loc1], hadamard, cross)


def compute_JTJdel(X,delta):
    """
    Apply the full (undamped) Gauss-Newton matrix to ``delta``.

    For every factor ``j`` the diagonal block contribution
    ``compute_Jxxdel(X, j, delta[j])`` is accumulated together with the
    off-diagonal contributions ``compute_Jxydel(X, j, i, delta[i])`` for all
    ``i != j``.

    Returns an array shaped like ``delta`` (built with ``np.zeros_like``).
    """
    n_modes = len(X)
    result = np.zeros_like(delta)
    for row in range(n_modes):
        result[row] = compute_Jxxdel(X, row, delta[row])
        for col in range(n_modes):
            if col != row:
                result[row] += compute_Jxydel(X, row, col, delta[col])
    return result

def einstr(order,contract_index,wrt_index):
    """
    Build the ``np.einsum`` subscript string used by ``gen_gradient`` when
    contracting mode ``contract_index`` of the (partially contracted) tensor
    with its factor matrix, on the way to the gradient for ``wrt_index``.

    Modes are labelled ``'a', 'b', ...`` by position and ``'R'`` is the rank
    index.  Three cases are distinguished:

    * ``wrt_index`` is the last mode and ``contract_index`` the one before
      it: all earlier modes are assumed to be ``'R'`` already and the result
      is the final ``<wrt>R`` matrix.
    * ``contract_index`` is the last mode: every other mode except
      ``wrt_index`` is assumed to be ``'R'`` already and the result is the
      final ``<wrt>R`` matrix.
    * otherwise: the tensor is taken as fully uncontracted and the
      contracted mode is replaced by ``'R'`` in place.

    NOTE(review): the "otherwise" case labels the input as uncontracted, so
    the strings appear to be consistent only for the order-3 sequence used
    in this notebook -- confirm before using with higher orders.
    """
    def label(position):
        return chr(ord('a') + position)

    last = order - 1
    contracted = label(contract_index)
    operand = contracted + 'R'

    if wrt_index == last and contract_index == wrt_index - 1:
        target = label(wrt_index)
        tensor = 'R'*(order-2) + contracted + target
        result = target + 'R'
    elif contract_index == last:
        target = label(wrt_index)
        chars = ['R']*(order-1) + [contracted]
        chars[wrt_index] = target
        tensor = "".join(chars)
        result = target + 'R'
    else:
        tensor = "".join(label(j) for j in range(order))
        result = tensor.replace(contracted, 'R')

    return tensor + ',' + operand + '->' + result

def flatten_Tensor(G,order,s,R):
    """
    Pack ``order`` matrices of shape ``(s, R)`` into one vector.

    The factors are laid out one after another and each factor is stored
    column by column, i.e. ``G[i][:, j]`` occupies the ``s`` entries starting
    at ``i*s*R + j*s``.  ``compute_matrices`` undoes this layout.

    Note: a later cell imports a different ``flatten_Tensor`` (signature
    ``(tenpy, G)``) from ``CPD.common_kernels``, which replaces this one in
    the notebook namespace once that cell has run.
    """
    flat = np.zeros(order*s*R)
    # (factor, column, row) view onto the same memory as ``flat``
    slots = flat.reshape(order, R, s)
    for mode in range(order):
        for col in range(R):
            slots[mode, col] = G[mode][:, col]
    return flat

def gen_gradient(X,T):
    """
    Gradient of the CP least-squares objective with respect to all factors,
    returned as one flat vector.

    For each mode ``i`` the tensor ``T`` is contracted with every other
    factor (subscripts supplied by ``einstr``) and the block is
    ``-contracted + X[i] @ D``, where ``D`` is the Hadamard product of the
    Gram matrices of the other factors.

    Parameters
    ----------
    X : ndarray
        Stacked factor matrices; ``np.zeros_like(X)`` must yield an array
        that can hold one ``(s, R)`` block per mode.
    T : ndarray
        The tensor being decomposed.

    Returns
    -------
    ndarray
        ``flatten_Tensor`` of the per-mode gradient blocks.
    """
    order = T.ndim
    grads = np.zeros_like(X)

    for mode in range(order):
        R = grads[mode].shape[1]
        s = grads[mode].shape[0]
        contracted = T.copy()
        hadamard = np.ones((R,R))
        others = [j for j in range(order) if j != mode]
        for j in others:
            contracted = np.einsum(einstr(order,j,mode), contracted, X[j])
            hadamard = np.einsum('kr,kz->rz', X[j], X[j]) * hadamard

        grads[mode] = -contracted + X[mode]@hadamard

    return flatten_Tensor(grads,order,s,R)
    




def create_LinOp(X,Regu):
    """
    Wrap the damped Gauss-Newton matrix ``J^T J + Regu*I`` at the point ``X``
    as a ``scipy.sparse.linalg.LinearOperator``.

    The operator acts on flat vectors in the ``flatten_Tensor`` layout: the
    input is split with ``compute_matrices``, the diagonal and off-diagonal
    block products are accumulated per factor, ``Regu`` times the input block
    is added, and the result is flattened again.

    All factors are assumed to share the shape of ``X[0]``.
    """
    n_modes = len(X)
    rows = X[0].shape[0]
    rank = X[0].shape[1]
    dim = n_modes*rows*rank

    def mv(vec):
        blocks = compute_matrices(vec, rows, rank)
        out = np.zeros_like(blocks)
        for row in range(n_modes):
            out[row] = compute_Jxxdel(X, row, blocks[row])
            for col in range(n_modes):
                if col != row:
                    out[row] += compute_Jxydel(X, row, col, blocks[col])

            # damping term
            out[row] += Regu*blocks[row]
        return flatten_Tensor(out, n_modes, rows, rank)

    return LinearOperator(shape = (dim, dim), matvec=mv)

def compute_matrices(x,s,R):
    """
    Split a flat vector into ``(s, R)`` matrices (inverse of the
    ``flatten_Tensor`` layout).

    Consecutive chunks of ``s*R`` entries are reshaped column-major, so the
    first ``s`` entries of a chunk form the first column.  The number of
    matrices is inferred from ``len(x)`` instead of being fixed at three, so
    the function matches callers that work with ``n = len(X)`` factors.

    Parameters
    ----------
    x : array_like, 1-D
        Flat vector whose length is a multiple of ``s*R``.
    s, R : int
        Rows and columns of every matrix.

    Returns
    -------
    list of ndarray
        ``len(x) // (s*R)`` matrices of shape ``(s, R)``.

    Raises
    ------
    ValueError
        If ``len(x)`` is not a multiple of ``s*R``.
    """
    x = np.asarray(x)
    block = s*R
    if block == 0:
        # Nothing to infer from an empty block size; keep the historical
        # behaviour of returning three empty matrices.
        n_factors = 3
    else:
        n_factors, leftover = divmod(len(x), block)
        if leftover:
            raise ValueError(
                "vector of length %d cannot be split into (%d, %d) blocks"
                % (len(x), s, R))

    return [x[k*block:(k+1)*block].reshape(s,R,order='F')
            for k in range(n_factors)]

In [21]:
# Experiment: decompose a random exact rank-R tensor with damped Gauss-Newton
# (NLS), once with a purely relative CG tolerance and once with the previous
# step norm used as an absolute CG tolerance, and compare time, residual and
# the number of inner CG iterations.
s =20

#for R in [40,41,42]:
num_gen = 1

R = 10
als_iter = 10000

max_iter = 250
num_init = 1

# Damping added to the Gauss-Newton system and relative tolerance of the
# inner CG solve.  Both are constant, so they are set once instead of on
# every sweep (later cells read ``Regu`` from the notebook namespace).
Regu = 10**-6
tolerance = 10**-6

# scipy.sparse.linalg.cg renamed its relative-tolerance keyword from ``tol``
# to ``rtol`` and newer releases no longer accept ``tol``; use whichever
# keyword the installed SciPy understands.
import inspect
cg_tol_kw = 'rtol' if 'rtol' in inspect.signature(spsalg.cg).parameters else 'tol'

cg_iters = 0

def cg_call(v):
    """CG callback: count one inner iteration in the global ``cg_iters``."""
    global cg_iters
    cg_iters= cg_iters+1

for k in range(num_gen):

    # Ground-truth factors and the tensor they generate.
    a = np.random.randn(s,R)
    b = np.random.randn(s,R)
    c = np.random.randn(s,R)

    T = np.einsum('ia,ja,ka->ijk', a,b,c)

    for j in range(num_init):
        # Random starting point.  The copies give every solver (including
        # the one in a later cell, which uses K, O, M) the same start.
        A = np.random.randn(s,R)
        K = A.copy()
        P = A.copy()

        B = np.random.randn(s,R)
        O = B.copy()
        Q = B.copy()

        C = np.random.randn(s,R)
        M = C.copy()
        N = C.copy()

        state = np.random.get_state()

        # ---- Run 1: CG stopped by the relative tolerance only ----
        # np.array copies A, B, C, so the in-place update below leaves the
        # starting point intact.
        X = np.array([A,B,C])
        cg_iters=0

        res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
        print('Residual is',res)

        print('starting NLS')
        start = time.time()
        for i in range(max_iter):
            L= create_LinOp(X,Regu)
            [delta,_] = spsalg.cg(L,-gen_gradient(X,T),callback=cg_call,**{cg_tol_kw: tolerance})
            delta = np.array(compute_matrices(delta,s,R))
            X+=delta
            #res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
            #print('Residual is',res)
            # stop once the largest entry of the step is negligible
            if np.linalg.norm(delta.reshape(-1), ord= np.inf)<10**-6:
                print('NLS Iterations:',i)
                break

        end = time.time()

        print("Time taken",end-start)
        res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
        #print('state is',state)
        print('Residual is',res)
        print('Total cg iterations',cg_iters)
        print('--------------------')

        # ---- Run 2: previous step norm used as absolute CG tolerance ----
        X = np.array([P,Q,N])
        cg_iters=0

        res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
        print('Residual is',res)

        last_norm = None

        print('starting NLS')
        start = time.time()
        for i in range(max_iter):
            L= create_LinOp(X,Regu)
            g = gen_gradient(X,T)
            # No previous step on the first sweep: fall back to atol=0 so only
            # the relative tolerance applies.  (Passing None is rejected by
            # newer SciPy releases.)
            cg_atol = 0.0 if last_norm is None else last_norm
            [delta,_] = spsalg.cg(L,-g,callback=cg_call,atol=cg_atol,**{cg_tol_kw: tolerance})
            delta = np.array(compute_matrices(delta,s,R))
            last_norm = np.linalg.norm(delta)
            X+=delta
            #res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
            #print('Residual is',res)
            # this run stops on the gradient norm rather than the step size
            if np.linalg.norm(g.reshape(-1))<10**-4:
                print('NLS Iterations:',i)
                break

        end = time.time()

        print("Time taken for atol",end-start)
        res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
        #print('state is',state)
        print('Residual with atol is',res)
        print('Total cg iterations with atol',cg_iters)
        print('--------------------')

    print('******************')


('Residual computation took', 0.0008265972137451172, 'seconds')
Residual is 378.19795630318845
starting NLS
NLS Iterations: 11
Time taken 1.200024127960205
('Residual computation took', 0.0, 'seconds')
Residual is 2.3843969824633834e-12
Total cg iterations 1918
--------------------
('Residual computation took', 0.0009980201721191406, 'seconds')
Residual is 378.19795630318845
starting NLS
NLS Iterations: 14
Time taken for atol 0.44803881645202637
('Residual computation took', 0.0009615421295166016, 'seconds')
Residual with atol is 2.6964940720154707e-06
Total cg iterations with atol 530
--------------------
******************


In [22]:
from CPD.common_kernels import compute_number_of_variables, flatten_Tensor, reshape_into_matrices, solve_sys
from scipy.sparse.linalg import LinearOperator
import scipy.sparse.linalg as spsalg

try:
    import Queue as queue
except ImportError:
    import queue

def fast_hessian_contract(tenpy,X,A,gamma,regu=1):
    """
    Multiply the damped Gauss-Newton matrix by ``X`` block by block.

    Parameters
    ----------
    tenpy : backend module
        Must provide ``einsum``.
    X : list
        One matrix per factor: the vector being multiplied.
    A : list
        Current factor matrices.
    gamma : nested list
        ``gamma[n][p]`` is the coefficient matrix coupling factors ``n`` and
        ``p``.
    regu : scalar
        Damping; ``regu*X[n]`` is added to block ``n``.

    Returns
    -------
    list
        Block ``n`` is ``X[n] @ gamma[n][n]`` plus, for every ``p != n``,
        ``sum_z A[n][i,z] * gamma[n][p][z,r] * (A[p].T @ X[p])[r,z]``, plus
        the damping term.
    """
    n_modes = len(A)
    ret = []
    for n in range(n_modes):
        acc = None
        for p in range(n_modes):
            coeff = gamma[n][p]
            if n != p:
                term = tenpy.einsum("iz,zr,jr,jz->ir",A[n],coeff,A[p],X[p])
            else:
                term = tenpy.einsum("iz,zr->ir",X[p],coeff)
            if acc is None:
                acc = term
            else:
                acc += term
        ret.append(acc)

    for n in range(n_modes):
        ret[n] += regu*X[n]
    return ret

def fast_block_diag_precondition(tenpy,X,P):
    """
    Apply a block-diagonal preconditioner to ``X``.

    For every block ``i`` two triangular solves with ``P[i]`` are chained via
    ``tenpy.solve_tri``: first with flags ``(True, False, True)``, then with
    ``(True, False, False)`` on the intermediate result.  Presumably this
    applies the inverse of the matrix whose Cholesky factor is ``P[i]`` --
    the meaning of the flags is defined by the backend (TODO confirm).

    Returns a list with one preconditioned block per entry of ``X``.
    """
    def _apply(factor, block):
        half = tenpy.solve_tri(factor, block, True, False, True)
        return tenpy.solve_tri(factor, half, True, False, False)

    return [_apply(P[i], X[i]) for i in range(len(X))]

class CP_fastNLS_Optimizer():
    """Fast Nonlinear Least Square Method for CP is a novel method of
    computing the CP decomposition of a tensor by utilizing tensor contractions
    and preconditioned conjugate gradient to speed up the process of solving
    damped Gauss-Newton problem of CP decomposition.

    One call to ``step`` performs one damped Gauss-Newton update of the
    factor matrices held in ``self.A``.  The block-diagonal preconditioner is
    implemented (``compute_block_diag_preconditioner`` and
    ``create_block_precondition_LinOp``) but ``step`` currently runs plain CG.
    """

    def __init__(self,tenpy,T,A,cg_tol=1e-4,args=None):
        """
        Parameters
        ----------
        tenpy : backend module
            Provides ``einsum``, ``ones``, ``eye``, ``cholesky``, ``vecnorm``.
        T : tensor
            The tensor to decompose.
        A : sequence of matrices
            Initial factors.  Stored by reference and updated in place.
        cg_tol : float
            Relative tolerance of the inner CG solve.
        args : unused
        """
        self.tenpy = tenpy
        self.T = T
        self.A = A
        self.cg_tol = cg_tol
        self.G = None                # Gram matrices, set by compute_G
        self.gamma = None            # coefficient matrices, set by compute_gamma
        self.last_step_norm = None   # norm of the previous step, set by step


    def _einstr_builder(self,M,s,ii):
        """
        Subscript string contracting mode ``ii`` of ``M`` with a factor.

        ``s`` is the intermediate stack of ``gradient``; only its length is
        used.  When it holds more than the root entry, the last axis of ``M``
        is treated as the rank axis ``'R'``.
        """
        ci = ""
        nd = M.ndim
        if len(s) != 1:
            ci ="R"
            nd = M.ndim-1

        str1 = "".join([chr(ord('a')+j) for j in range(nd)])+ci
        str2 = (chr(ord('a')+ii))+"R"
        str3 = "".join([chr(ord('a')+j) for j in range(nd) if j != ii])+"R"
        einstr = str1 + "," + str2 + "->" + str3
        return einstr

    def compute_G(self):
        """Store the Gram matrix ``A[i].T @ A[i]`` of every factor in ``self.G``."""
        G = []
        for i in range(len(self.A)):
            G.append(self.tenpy.einsum("ij,ik->jk",self.A[i],self.A[i]))
        self.G = G

    def compute_coefficient_matrix(self,n1,n2):
        """Hadamard product of all Gram matrices except those of ``n1`` and ``n2``."""
        ret = self.tenpy.ones(self.G[0].shape)
        for i in range(len(self.G)):
            if i!=n1 and i!=n2:
                ret = self.tenpy.einsum("ij,ij->ij",ret,self.G[i])
        return ret

    def compute_gamma(self):
        """
        Fill ``self.gamma`` with the coefficient matrix of every factor pair.

        Only pairs with ``j >= i`` are computed; ``gamma[i][j]`` for ``j < i``
        refers to the same object as ``gamma[j][i]``.  Requires ``compute_G``.
        """
        N = len(self.A)
        result = []
        for i in range(N):
            result.append([])
            for j in range(N):
                if j>=i:
                    M = self.compute_coefficient_matrix(i,j)
                    result[i].append(M)
                else:
                    M = result[j][i]
                    result[i].append(M)
        self.gamma = result

    def compute_block_diag_preconditioner(self,Regu):
        """
        Cholesky factors of the damped diagonal blocks ``gamma[i][i] + Regu*I``.

        The damping must match the ``Regu`` used in the Hessian operator;
        previously the identity was added unscaled and ``Regu`` was ignored.
        """
        n = self.gamma[0][0].shape[0]
        P = []
        for i in range(len(self.A)):
            P.append(self.tenpy.cholesky(self.gamma[i][i]+Regu*self.tenpy.eye(n)))
        return P


    def gradient(self):
        """
        Gradient with respect to all factors, flattened into one vector.

        Partially contracted copies of ``self.T`` are kept on a stack of
        ``(remaining modes, tensor)`` pairs so that contractions shared by
        several modes are reused.  Requires ``compute_gamma`` to have run.
        """
        grad = []
        q = queue.Queue()
        for i in range(len(self.A)):
            q.put(i)
        # root of the stack: no mode contracted yet
        s = [(list(range(len(self.A))),self.T)]
        while not q.empty():
            i = q.get()
            # drop intermediates in which mode i has already been contracted
            while i not in s[-1][0]:
                s.pop()
                assert(len(s) >= 1)
            # contract remaining modes until only mode i is left
            while len(s[-1][0]) != 1:
                M = s[-1][1]
                idx = s[-1][0].index(i)
                # contract the last remaining mode, unless that is i itself
                ii = len(s[-1][0])-1
                if idx == len(s[-1][0])-1:
                    ii = len(s[-1][0])-2

                einstr = self._einstr_builder(M,s,ii)

                N = self.tenpy.einsum(einstr,M,self.A[ii])

                ss = s[-1][0][:]
                ss.remove(ii)
                s.append((ss,N))
            M = s[-1][1]
            g = -1*M + self.A[i].dot(self.gamma[i][i])
            grad.append(g)
        return flatten_Tensor(self.tenpy,grad)


    def create_fast_hessian_contract_LinOp(self,Regu):
        """
        ``LinearOperator`` applying ``fast_hessian_contract`` with damping
        ``Regu`` to flat vectors, using the current ``self.A`` and
        ``self.gamma``.
        """
        num_var = compute_number_of_variables(self.A)
        A = self.A
        gamma = self.gamma
        tenpy = self.tenpy
        template = self.A

        def mv(delta):
            delta = reshape_into_matrices(tenpy,delta,template)
            result = fast_hessian_contract(tenpy,delta,A,gamma,Regu)
            vec = flatten_Tensor(tenpy,result)
            return vec

        V = LinearOperator(shape = (num_var,num_var), matvec=mv)
        return V

    def create_block_precondition_LinOp(self,P):
        """
        ``LinearOperator`` applying ``fast_block_diag_precondition`` with the
        factors ``P`` to flat vectors.
        """
        num_var = compute_number_of_variables(self.A)
        tenpy = self.tenpy
        template = self.A

        def mv(delta):

            delta = reshape_into_matrices(tenpy,delta,template)
            result = fast_block_diag_precondition(tenpy,delta,P)
            vec = flatten_Tensor(tenpy,result)
            return vec

        V = LinearOperator(shape = (num_var,num_var), matvec=mv)
        return V

    def update_A(self,delta):
        """Add ``delta[i]`` to ``self.A[i]`` in place for every factor."""
        for i in range(len(delta)):
            self.A[i] += delta[i]

    @staticmethod
    def _solve_cg(operator,rhs,rel_tol,abs_tol,M=None,callback=None):
        """
        Call ``scipy.sparse.linalg.cg`` with explicit tolerances.

        The relative-tolerance keyword of ``cg`` was renamed from ``tol`` to
        ``rtol`` and newer SciPy releases no longer accept ``tol``; the
        keyword supported by the installed version is selected here.

        Returns ``(x, info)`` exactly as ``cg`` does.
        """
        import inspect
        tol_kw = 'rtol' if 'rtol' in inspect.signature(spsalg.cg).parameters else 'tol'
        return spsalg.cg(operator,rhs,M=M,callback=callback,atol=abs_tol,**{tol_kw: rel_tol})

    def step(self,Regu):
        """
        Perform one damped Gauss-Newton update.

        Recomputes the Gram and coefficient matrices, solves
        ``(J^T J + Regu*I) delta = -gradient`` with CG and adds ``delta`` to
        the factors.  The norm of the previous step is used as the absolute
        CG tolerance; on the first call there is none, so only the relative
        tolerance ``cg_tol`` applies.

        Returns the step as produced by ``reshape_into_matrices``.
        """
        self.compute_G()
        self.compute_gamma()
        g = self.gradient()
        mult_LinOp = self.create_fast_hessian_contract_LinOp(Regu)
        #P = self.compute_block_diag_preconditioner(Regu)
        #precondition_LinOp = self.create_block_precondition_LinOp(P)
        # atol=None is rejected by newer SciPy releases; 0.0 means "relative
        # tolerance only".
        atol = 0.0 if self.last_step_norm is None else self.last_step_norm
        [delta,_] = self._solve_cg(mult_LinOp,-1*g,self.cg_tol,atol)
        self.last_step_norm = self.tenpy.vecnorm(delta)
        delta = reshape_into_matrices(self.tenpy,delta,self.A)
        self.update_A(delta)
        return delta

In [23]:
# Run the CP_fastNLS_Optimizer from the same starting point as the two NLS
# runs of the earlier driver cell: K, O, M are the copies of the initial
# A, B, C taken there.  T, max_iter and Regu also come from that cell.
start = time.time()
X = np.array([K,O,M])

# NOTE(review): cg_iters is reset here but opt.step() calls cg with
# callback=None, so it is never incremented in this cell.
cg_iters=0

# The optimizer stores X by reference and updates it in place, so X below
# always holds the current factors.
opt = CP_fastNLS_Optimizer(tenpy,T,X,cg_tol=1e-6)

for i in range(max_iter):

    delta= np.array(opt.step(Regu))

    res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
    print('Residual is',res)

    # stop once the largest entry of the step is negligible
    if np.linalg.norm(delta.reshape(-1),ord = np.inf)<10**-8:
        print('NLS Iterations:',i)
        break
end = time.time()
# NOTE(review): res is only defined if the loop ran at least once.
print("Time taken for atol",end-start)
print("Residual is",res)

('Residual computation took', 0.000993967056274414, 'seconds')
Residual is 279.700007593367
('Residual computation took', 0.0, 'seconds')
Residual is 222.6140909461351
('Residual computation took', 0.0011854171752929688, 'seconds')
Residual is 230.82363911698064
('Residual computation took', 0.0, 'seconds')
Residual is 146.94323739734898
('Residual computation took', 0.0010228157043457031, 'seconds')
Residual is 186.3851770777031
('Residual computation took', 0.0009398460388183594, 'seconds')
Residual is 61.814473846521174
('Residual computation took', 0.0009391307830810547, 'seconds')
Residual is 26.303711915364385
('Residual computation took', 0.0009999275207519531, 'seconds')
Residual is 3.9885964694190004
('Residual computation took', 0.0, 'seconds')
Residual is 0.2972249597778691
('Residual computation took', 0.0006496906280517578, 'seconds')
Residual is 0.02615007104552042
('Residual computation took', 0.0014495849609375, 'seconds')
Residual is 0.007258679049673539
('Residual com