In [1]:
import numpy as np
import time
from scipy.sparse.linalg import LinearOperator
import scipy.sparse.linalg as spsalg
import backend.numpy_ext as tenpy
import CPD.standard_ALS3 as stnd_ALS
from CPD.common_kernels import solve_sys, compute_lin_sys
import CPD.common_kernels as ck
import matplotlib.pyplot as plt

In [11]:
def compute_Jxxdel(X,loc,delta):
    n = len(X)
    s= X[loc].shape[0]
    R = X[loc].shape[1]
    D = np.ones((R,R))
    for i in range(n):
        if i == loc:
            continue
        else:
            D = np.einsum('ij,ij->ij',np.einsum('kr,kz->rz',X[i],X[i]),D)
    
    prod = np.einsum('iz,zr->ir',delta,D)
    return prod


def compute_Jxydel(X,loc1,loc2,delta):
    n = len(X)
    R = X[loc1].shape[1]
    D = np.ones((R,R))
    for i in range(n):
        if i == loc1 or i == loc2:
            continue
        else:
            D = np.einsum('ij,ij->ij',np.einsum('kr,kz->rz',X[i],X[i]),D)
            
    temp = np.einsum("jr,jz->rz",X[loc2],delta)
    prod = np.einsum("iz,zr,rz->ir",X[loc1],D,temp)
    
    return prod


def compute_JTJdel(X,delta):
    K = np.zeros_like(delta)
    n = len(X)
    for j in range(n):
        K[j] = compute_Jxxdel(X,j,delta[j])
        for i in range(n):
            if i ==j:
                continue
            else:
                K[j]+= compute_Jxydel(X,j,i,delta[i])

    return K

def einstr(order,contract_index,wrt_index):
    if wrt_index == order-1 and contract_index == wrt_index-1:
        str1 ="".join(['R' for j in range(order-2)])
        str1= str1+ "".join([chr(ord('a')+contract_index)])+ "".join([chr(ord('a')+wrt_index)])
        
        str2 = "".join([chr(ord('a')+contract_index)])+'R'
        str3 = "".join([chr(ord('a')+wrt_index)])+'R'
        
    elif contract_index == order-1:
        str1 ="".join(['R' for j in range(order-1)])
        str1= str1+ "".join([chr(ord('a')+contract_index)])
        ls = list(str1)
        ls[wrt_index] = "".join([chr(ord('a')+wrt_index)])
        str1 = "".join(ls)
        
        str2 = "".join([chr(ord('a')+contract_index)])+'R'
        str3 = "".join([chr(ord('a')+wrt_index)])+'R'
    else:
        str1= "".join([chr(ord('a')+j) for j in range(order)])
        str2= "".join([chr(ord('a')+contract_index)])+'R'
        str3= str1.replace("".join([chr(ord('a')+contract_index)]),'R')
    
    string = str1+','+str2 +'->'+str3
    
    
    return string

def flatten_Tensor(G,order,s,R):
    g = np.zeros(order*s*R)
    for i in range(order):
        offset1 = i*s*R
        for j in range(R):
            offset2 = j*s
            start = offset1 + offset2
            end = start + s
            g[start:end] = G[i][:,j]
    return g

def gen_gradient(X,T):
    order = T.ndim
    out = np.zeros_like(X)

    for i in range(order):
        inter = T.copy()
        R = out[i].shape[1]
        D = np.ones((R,R))
        for j in range(order):
            if i != j:
                inter = np.einsum(einstr(order,j,i),inter,X[j])
                D = np.einsum('ij,ij->ij',np.einsum('kr,kz->rz',X[j],X[j]),D)
                
        out[i] = -inter + X[i]@D
        
    return flatten_Tensor(out,order,s,R)
    




def create_LinOp(X,Regu):
    n = len(X)
    s = X[0].shape[0]
    R= X[0].shape[1]
    
    def mv(delta):
        delta = compute_matrices(delta,s,R)
        K = np.zeros_like(delta)
        for j in range(n):
            K[j] = compute_Jxxdel(X,j,delta[j])
            for i in range(n):
                if i ==j:
                    continue
                else:
                    K[j]+= compute_Jxydel(X,j,i,delta[i])
                    
            K[j]+= Regu*delta[j]
        vec = flatten_Tensor(K,n,s,R)
        return vec 
    
    V = LinearOperator(shape = (n*s*R,n*s*R), matvec=mv)
    return V
def compute_matrices(x,s,R):
    A = x[:s*R].reshape(s,R,order= 'F')
    B = x[s*R:2*s*R].reshape(s,R,order='F')
    C = x[2*s*R:3*s*R].reshape(s,R,order='F')
    
    return [A,B,C]

In [12]:
s = 6

num_gen = 1

R = 4

max_iter = 100
num_init = 1
for k in range(num_gen):
    
    a = np.random.rand(s,R)
    
    b = np.random.rand(s,R)

    c = np.random.rand(s,R)


    T = np.einsum('ia,ja,ka->ijk', a,b,c)
    for j in range(num_init):

        A = np.random.rand(s,R)

        B = np.random.rand(s,R)

        C = np.random.rand(s,R)

        state = np.random.get_state()

        X = np.array([A,B,C])

        global cg_iters

        cg_iters=0

        def cg_call(v):
            global cg_iters
            cg_iters= cg_iters+1

        
        X = np.array([A,B,C])
        
        
        res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
        print('Residual is',res)


        

        print('starting NLS')
        start = time.time()
        for i in range(max_iter):
            Regu = 10**-6
            tolerance = 10**-6
            NLS_conv = 10**-6

            Op= create_LinOp(X,Regu)
            [delta,_] = spsalg.cg(Op,-gen_gradient(X,T), tol= tolerance,callback=cg_call)
            #print('Number of cg iterations',cg_iters)
            #cg_iters=0
            #print('Iteration number:',i)
            delta = np.array(compute_matrices(delta,s,R))
            X+=delta
            res = ck.get_residual3(tenpy,T,X[0],X[1],X[2])
            print('Residual is',res)
            if np.linalg.norm(delta.reshape(-1), ord= np.inf)<NLS_conv:
                print('NLS Iterations:',i)
                break

        end = time.time()

        print("Time taken",end-start)
        
        #print('state is',state)
        print('Residual is',res)
        print('Total cg iterations',cg_iters)
        print('--------------------')
    
    print('******************')


Residual is 7.869508592507158
starting NLS
Residual is 17.887435413303294
Residual is 4.883071370713876
Residual is 4.1602948480086805
Residual is 0.8086667820894337
Residual is 7.422656790405477
Residual is 0.65858393744622
Residual is 0.2221764747100079
Residual is 0.6784641946120851
Residual is 0.3200727580423783
Residual is 0.017418432512409883
Residual is 2.0017236212279952e-05
Residual is 1.411927774734272e-08
Residual is 1.411927774734272e-08
NLS Iterations: 12
Time taken 0.37099695205688477
Residual is 1.411927774734272e-08
Total cg iterations 1334
--------------------
******************
