# Naive NLS

# In [81]:
import numpy as np
import time

def calculate_J(s, R, A, B, C):
    """Return the dense Jacobian of F(x) = vec(T - [[A, B, C]]) for an s x s x s tensor.

    Rows follow the C-order flattening of the (i, j, k) tensor index. Columns
    are the stacked parameters [vec(A); vec(B); vec(C)], each factor flattened
    column by column (Fortran order), which matches ``compute_matrices``.
    Because F is the residual T minus the model, every block carries a
    leading minus sign.

    Parameters
    ----------
    s : side length of the cubic tensor (each factor has s rows).
    R : CP rank (number of factor columns).
    A, B, C : factor matrices of shape (s, R).

    Returns
    -------
    ndarray of shape (s**3, 3*s*R).
    """
    identity = np.eye(s)
    block_width = s * R
    jac = np.empty((s ** 3, 3 * block_width))

    for r in range(R):
        a_col = A[:, r].reshape(-1, 1)
        b_col = B[:, r].reshape(-1, 1)
        c_col = C[:, r].reshape(-1, 1)
        lo, hi = s * r, s * (r + 1)

        # Derivative w.r.t. A[:, r]: identity in the i slot, b_r and c_r fixed.
        jac[:, lo:hi] = -np.kron(np.kron(identity, b_col), c_col)
        # Derivative w.r.t. B[:, r]: identity in the j slot.
        jac[:, block_width + lo:block_width + hi] = -np.kron(
            np.kron(a_col, identity), c_col
        )
        # Derivative w.r.t. C[:, r]: identity in the k slot.
        jac[:, 2 * block_width + lo:2 * block_width + hi] = -np.kron(
            np.kron(a_col, b_col), identity
        )

    return jac


def create_objective(T, A, B, C):
    """Return the flattened residual T - [[A, B, C]] of a CP model.

    The model tensor is sum_a A[i, a] * B[j, a] * C[k, a]. The residual is
    flattened in C order, so entry (i, j, k) lands at row (i*J + j)*K + k.
    """
    model = np.einsum('ia,ja,ka->ijk', A, B, C)
    residual = T - model
    return residual.reshape(-1)

def solve_system(J, F, Regu):
    """Solve the damped Gauss-Newton normal equations for one update step.

    Solves ``(J^T J + Regu * I) delta = -J^T F`` for ``delta``.

    Parameters
    ----------
    J : 2-D Jacobian array of shape (m, n).
    F : residual vector of length m.
    Regu : scalar damping added to the diagonal of J^T J.

    Returns
    -------
    ndarray of length n with the step ``delta``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the damped system is singular (for example, Regu == 0 with a
        rank-deficient J).
    """
    # Form the Gram matrix once; it was previously computed a second time
    # only to read its size.
    JTJ = J.T @ J
    n = JTJ.shape[0]
    return np.linalg.solve(JTJ + Regu * np.eye(n), -(J.T @ F))

def compute_matrices(x, s, R):
    """Unpack a stacked parameter vector into the three CP factor matrices.

    ``x`` holds [vec(A); vec(B); vec(C)]. Each segment has s*R entries and is
    reshaped column-major (Fortran order) into an (s, R) matrix.

    Returns
    -------
    list [A, B, C].
    """
    segment = s * R
    return [
        x[k * segment:(k + 1) * segment].reshape(s, R, order='F')
        for k in range(3)
    ]

# In [82]:
## For checking if solution is unique
from sympy.utilities.iterables import multiset_permutations

def measure_dist(a, b, c, A, B, C):
    """Compare recovered CP factors with reference factors up to column permutation.

    All six factors are column-normalized; the recovered A, B, C are first
    passed through ``flipsign``. Every column permutation of the recovered A
    is tried. The one closest to ``a`` in Frobenius norm is kept and applied
    to B and C as well.

    Parameters
    ----------
    a, b, c : reference factor matrices (presumably the ground truth; TODO confirm).
    A, B, C : recovered factor matrices with matching shapes.

    Returns
    -------
    tuple (norm, norm2, norm3)
        Frobenius distances for the A, B and C factors, all under the best
        permutation found for A.

    Notes
    -----
    All R! permutations are enumerated, so this is only practical for small R.
    """
    # flipsign/tenpy are defined outside this block; presumably they fix a
    # per-column sign convention before comparison. NOTE(review): confirm.
    A = flipsign(tenpy, A)
    A = A / np.linalg.norm(A, axis=0)
    B = flipsign(tenpy, B)
    B = B / np.linalg.norm(B, axis=0)
    C = flipsign(tenpy, C)
    C = C / np.linalg.norm(C, axis=0)

    a = a / np.linalg.norm(a, axis=0)
    b = b / np.linalg.norm(b, axis=0)
    c = c / np.linalg.norm(c, axis=0)

    # Track the best permutation directly. The previous approach packed
    # [perm_list, float] pairs into np.array, which is ragged and raises
    # ValueError on NumPy >= 1.24.
    best_perm = None
    norm = None
    for perm in multiset_permutations(np.arange(a.shape[1])):
        dist = np.linalg.norm(A[:, perm] - a)
        # A strict '<' keeps the first minimum, matching np.argmin tie-breaking.
        if best_perm is None or dist < norm:
            best_perm, norm = perm, dist

    norm2 = np.linalg.norm(B[:, best_perm] - b)
    norm3 = np.linalg.norm(C[:, best_perm] - c)

    return (norm, norm2, norm3)

# NLS With Tensor Contraction

# In [196]:
import numpy as np


def compute_Jxxdel(X, loc, delta):
    """Apply the diagonal block J_loc^T J_loc of a CP Gauss-Newton matrix.

    Builds D, the elementwise product of the Gram matrices X[i]^T X[i] over
    every mode i except ``loc``, and returns ``delta @ D``.

    Parameters
    ----------
    X : sequence of 2-D factor matrices that all have R columns.
    loc : index of the mode whose diagonal block is applied.
    delta : direction for mode ``loc``, with R columns.

    Returns
    -------
    ndarray with the same shape as ``delta``.
    """
    R = X[loc].shape[1]
    hadamard = np.ones((R, R))
    for mode, factor in enumerate(X):
        if mode == loc:
            continue
        # Hadamard-accumulate the Gram matrix factor^T factor.
        hadamard = np.einsum('kr,kz->rz', factor, factor) * hadamard
    return np.einsum('iz,zr->ir', delta, hadamard)


def compute_Jxydel(X, loc1, loc2, delta):
    """Apply the off-diagonal block J_loc1^T J_loc2 to a direction for mode loc2.

    Let D be the elementwise product of the Gram matrices X[i]^T X[i] over all
    modes other than ``loc1`` and ``loc2``, and let P = X[loc2]^T delta. The
    result is X[loc1] @ (D * P^T).

    Parameters
    ----------
    X : sequence of 2-D factor matrices that all have R columns.
    loc1 : mode of the output (row block).
    loc2 : mode of ``delta`` (column block).
    delta : direction for mode ``loc2``, shaped like X[loc2].

    Returns
    -------
    ndarray shaped like X[loc1].
    """
    R = X[loc1].shape[1]
    skipped = (loc1, loc2)
    hadamard = np.ones((R, R))
    for mode in range(len(X)):
        if mode not in skipped:
            gram = np.einsum('kr,kz->rz', X[mode], X[mode])
            hadamard = gram * hadamard

    projected = np.einsum("jr,jz->rz", X[loc2], delta)  # X[loc2]^T @ delta
    return np.einsum("iz,zr,rz->ir", X[loc1], hadamard, projected)


def compute_JTJdel(X, delta):
    """Apply the Gauss-Newton matrix J^T J of a CP model to a direction.

    For each mode j, the result is the diagonal-block term
    ``compute_Jxxdel(X, j, delta[j])`` plus the off-diagonal contribution
    ``compute_Jxydel(X, j, i, delta[i])`` from every other mode i.

    The output is allocated with ``np.zeros_like(delta)``. The per-mode
    directions must therefore stack into one array, and the result has
    delta's dtype. NOTE(review): this presumably assumes all factors share a
    shape (cubic tensor); confirm before using with non-cubic tensors.
    """
    num_modes = len(X)
    K = np.zeros_like(delta)
    for j in range(num_modes):
        K[j] = compute_Jxxdel(X, j, delta[j])
        others = [i for i in range(num_modes) if i != j]
        for i in others:
            K[j] += compute_Jxydel(X, j, i, delta[i])
    return K



def compute_negJTF(X, T):
    """Return A_n Gamma_n - MTTKRP_n(T) for each of the three CP factors.

    For mode n, Gamma_n is the elementwise product of the Gram matrices of
    the other two factors. MTTKRP_n contracts T with those two factors. The
    stacked result is the gradient of 0.5 * ||T - [[X[0], X[1], X[2]]]||^2
    with respect to each factor.

    NOTE(review): the sign of this quantity relative to "-J^T F" depends on
    the convention. It equals -J^T F when J is the Jacobian of the model
    [[X]] and F = T - [[X]]. It equals +J^T F when J is the Jacobian of the
    residual, as built by ``calculate_J``. Confirm the sign callers expect.

    The output is allocated with ``np.zeros_like(X)``, so the three factors
    must stack into one array and the result has X's dtype.
    """
    A, B, C = X[0], X[1], X[2]
    gram_A = A.T @ A
    gram_B = B.T @ B
    gram_C = C.T @ C

    out = np.zeros_like(X)
    out[0] = A @ (gram_B * gram_C) - np.einsum('ijk,ja,ka->ia', T, B, C)
    out[1] = B @ (gram_A * gram_C) - np.einsum('ijk,ia,ka->ja', T, A, C)
    out[2] = C @ (gram_A * gram_B) - np.einsum('ijk,ia,ja->ka', T, A, B)
    return out

 

# In [197]:
def compute_coefficient_matrix(G, n1, n2):
    """Return the elementwise product of every G[i] whose index is not n1 or n2.

    The accumulator starts as a float array of ones with G[0]'s shape. If
    every entry is excluded, that array of ones is returned. The result is
    symmetric in (n1, n2).
    """
    excluded = (n1, n2)
    product = np.ones(G[0].shape)
    for idx, gram in enumerate(G):
        if idx in excluded:
            continue
        product = np.einsum("ij,ij->ij", product, gram)
    return product

def fast_hessian_contract(A, X, G=None):
    """Contract the Gauss-Newton matrix J^T J of a CP model with a direction.

    For each output mode n this sums, over every mode p:
      * n == p : X[n] @ M_nn
      * n != p : A[n] @ (M_np * (X[p]^T A[p]))
    M_np is the elementwise product of the Gram matrices of all modes other
    than n and p (see ``compute_coefficient_matrix``).

    Parameters
    ----------
    A : sequence of factor matrices, all with the same number of columns R.
    X : sequence of direction matrices; X[n] is shaped like A[n].
    G : optional precomputed Gram matrices [a.T @ a for a in A]. Passing them
        lets callers hoist this preprocessing out of repeated contractions.
        When None they are computed here, so existing callers are unaffected.

    Returns
    -------
    list of arrays; entry n is shaped like A[n].
    """
    if G is None:
        G = [mat.T.dot(mat) for mat in A]

    N = len(A)
    # M_np depends only on the unordered pair {n, p}; compute each one once.
    coefficients = {}
    ret = []
    for n in range(N):
        total = None
        for p in range(N):
            key = (min(n, p), max(n, p))
            if key not in coefficients:
                coefficients[key] = compute_coefficient_matrix(G, n, p)
            M = coefficients[key]

            if n == p:
                Y = np.einsum("iz,zr->ir", X[p], M)
            else:
                B = np.einsum("jr,jz->rz", A[p], X[p])  # A[p]^T @ X[p]
                Y = np.einsum("iz,zr,rz->ir", A[n], M, B)

            # Out-of-place add: no in-place dtype-casting failure if terms differ.
            total = Y if total is None else total + Y
        ret.append(total)
    return ret