In [177]:
import numpy as np
import matplotlib.pyplot as plt
from tensor.operation.kruskal import kruskal
from tensor.operation.khatri_rao import khatri_rao
from tensor.operation.matricize import matricize

## CP decomposition

Algorithm: CP decomposition ALS

$\text{Input:} \quad \text{Tensor } X \in \mathbb{R}^{J_1 \times J_2 \times \cdots \times J_N}, \text{ rank } R$

$\text{Output:} \quad \text{Factor matrices } U_1, U_2, \cdots, U_N, \text{ with } U_n \in \mathbb{R}^{J_n \times R}$

\begin{align*}
&\text{Initialize } U_1, U_2, \cdots, U_N \text{ with random matrices} \\
&\text{while } \text{not converged } \text{do} \\
& \quad \quad \text{for } n = 1, 2, \cdots, N \text{ do} \\
& \quad \quad \quad \text{Compute } \mathbf{A}^{(n)^T} = \left( \mathbf{A}^{(N)} \odot \mathbf{A}^{(N-1)} \odot \cdots \odot \mathbf{A}^{(n+1)} \odot \mathbf{A}^{(n-1)} \cdot \mathbf{A}^{(1)} \right )^{\dag} X^T_{(n)} \\
& \quad \quad \text{end for} \\
&\text{end while} \\
\end{align*}

In [178]:
def cpDecomposition(X: np.ndarray, rank: int, maxIter: int = 5, tol: float = 1e-4,
                    initialisation: str = 'random', verbose: bool = True) -> np.ndarray:
    """Fit a CP (CANDECOMP/PARAFAC) model to ``X`` with alternating least squares.

    Each sweep updates one factor matrix per mode by solving the linear
    least-squares problem obtained when all other factors are held fixed:
    ``U[n] = X_(n) @ pinv(Z).T`` where ``Z`` is the Khatri-Rao product of the
    remaining factors taken from the last mode down to the first.

    The unfolding, Khatri-Rao and reconstruction helpers (``matricize``,
    ``khatri_rao``, ``kruskal``) come from the project's ``tensor.operation``
    package; this function assumes their ordering conventions agree with
    each other.

    Args:
        X (np.ndarray): Tensor to be decomposed.
        rank (int): Number of rank-one components; must be at least 1.
        maxIter (int, optional): Maximum number of ALS sweeps. Defaults to 5.
        tol (float, optional): Stop as soon as the Frobenius norm of the
            residual ``X - kruskal(*U)`` drops below this value. This is an
            absolute threshold, so noisy data that cannot be fitted exactly
            will always run for ``maxIter`` sweeps. Defaults to 1e-4.
        initialisation (str, optional): ``'random'`` draws the starting
            factors uniformly from [0, 1); ``'hosvd'`` uses the leading left
            singular vectors of each mode unfolding. Defaults to ``'random'``.
        verbose (bool, optional): Print factor shapes and the loss after
            every sweep. Defaults to True.

    Returns:
        np.ndarray: The factor matrices, indexable by mode (``result[n]`` has
        shape ``(X.shape[n], rank)``). When all modes have the same size this
        is a regular array of shape ``(X.ndim, J, rank)``; otherwise it is a
        1-D object array holding one matrix per mode.

    Raises:
        ValueError: If ``rank`` is smaller than 1 or ``initialisation`` is
            not one of ``'random'`` / ``'hosvd'``.
    """
    X = np.asarray(X)
    if rank < 1:
        raise ValueError(f"rank must be a positive integer, got {rank}")
    if initialisation not in ('random', 'hosvd'):
        raise ValueError(
            f"Invalid initialisation method: {initialisation!r} "
            "(expected 'random' or 'hosvd')"
        )

    # Initialize factor matrices
    if initialisation == 'random':
        U = [np.random.rand(X.shape[i], rank) for i in range(X.ndim)]
    else:
        U = []
        for i in range(X.ndim):
            M = np.linalg.svd(matricize(X, i))[0]
            if M.shape[1] < rank:
                # Pad with random columns, not zeros: a zero column makes the
                # matching Khatri-Rao column zero, so the pseudo-inverse
                # update can never move that component away from zero.
                filler = np.random.rand(M.shape[0], rank - M.shape[1])
                M = np.concatenate((M, filler), axis=1)
            else:
                M = M[:, :rank]
            if verbose:
                print("M", M.shape)
            U.append(M)

    if verbose:
        for i in range(X.ndim):
            print(i, U[i].shape, matricize(X, i).shape)

    # Iterate until convergence
    for itr in range(maxIter):

        for i in range(X.ndim):
            # Khatri-Rao product of every factor except mode i, accumulated
            # from the last mode down to the first.
            khatriRaoProd = np.ones((1, rank))
            for j in reversed(range(X.ndim)):
                if j != i:
                    khatriRaoProd = khatri_rao(khatriRaoProd, U[j])

            U[i] = matricize(X, i) @ np.linalg.pinv(khatriRaoProd).T

        # Reconstruct once per sweep and reuse the value for both the log
        # line and the convergence check.
        loss = np.linalg.norm(X - kruskal(*U))
        if verbose:
            print("Iteration ", itr+1, " completed. loss =", loss)

        # Check for convergence
        if loss < tol:
            break

    # Equal-shaped factors stack into one regular array (the historical
    # return value). Factors of different shapes cannot be stacked, and
    # np.array() on such a list raises on recent NumPy, so they are packed
    # into an object array that still supports result[n].
    if len({factor.shape for factor in U}) <= 1:
        return np.array(U)
    packed = np.empty(len(U), dtype=object)
    for mode, factor in enumerate(U):
        packed[mode] = factor
    return packed


In [179]:
# Seed NumPy's global RNG so that the random test tensor and the random
# factor initialisation are identical on every Restart & Run All.
np.random.seed(42)

# Alternative small fixed test tensor:
# X = np.array([[[1, -1], [0, 0]], [[0, 0], [1, 1]]])
X = np.random.randn(3, 3, 3)
ans = cpDecomposition(X, 2, maxIter=100, initialisation='random')

0 (3, 2) (3, 9)
1 (3, 2) (3, 9)
2 (3, 2) (3, 9)
Iteration  1  completed. loss = 3.121962667675004
Iteration  2  completed. loss = 3.042200831869479
Iteration  3  completed. loss = 3.022234971950612
Iteration  4  completed. loss = 3.0111868639556714
Iteration  5  completed. loss = 3.0039591341283014
Iteration  6  completed. loss = 2.9992241502438355
Iteration  7  completed. loss = 2.9962067916565633
Iteration  8  completed. loss = 2.99434329508674
Iteration  9  completed. loss = 2.9932205026531005
Iteration  10  completed. loss = 2.9925525870755396
Iteration  11  completed. loss = 2.9921538706293638
Iteration  12  completed. loss = 2.9919099209457793
Iteration  13  completed. loss = 2.991753003431649
Iteration  14  completed. loss = 2.9916441883991554
Iteration  15  completed. loss = 2.991561550330812
Iteration  16  completed. loss = 2.991492888488279
Iteration  17  completed. loss = 2.991431436127014
Iteration  18  completed. loss = 2.991373408537754
Iteration  19  completed. loss = 2.

In [180]:
# Show the three estimated factor matrices, one per tensor mode.
for mode, header in enumerate(("A = \n", "B = \n", "C = \n")):
    print(header, ans[mode])

A = 
 [[ -3.55787548  -3.19350492]
 [ -1.05576725 -12.75620143]
 [  1.05679598  -4.15441272]]
B = 
 [[ 0.4074453   0.16899306]
 [ 0.69986873 -0.1788886 ]
 [ 0.21453157  0.07221839]]
C = 
 [[0.27789052 0.66825859]
 [0.73250864 0.20892182]
 [0.83599293 0.07792074]]
