In [4]:
from jax import grad
import jax.numpy as np
from jax.numpy import linalg as la
import math

In [6]:
def tt_svd(A, eps):
    """Decompose a dense tensor into tensor-train (TT) cores with TT-SVD.

    Sequential truncated-SVD scheme (Oseledets, 2011): every unfolding is
    truncated so that the discarded Frobenius energy is at most
    delta**2 = eps**2 * ||A||_F**2 / (d - 1). Summed over the d - 1 steps
    this bounds the reconstruction error by eps * ||A||_F, independently of
    the overall scale of A.

    Args:
        A: dense array with d = len(A.shape) >= 1 dimensions.
        eps: target relative accuracy in the Frobenius norm.

    Returns:
        (G, r) where
          G: list of d cores; core k (k < d - 1) has shape
             (r[k], A.shape[k], r[k + 1]) and the last core has shape
             (r[d - 1], A.shape[d - 1], 1).
          r: list [1, r_1, ..., r_{d-1}] of leading TT-ranks (the trailing
             rank 1 of the last core is not included).

    Raises:
        ValueError: if A is 0-dimensional.
    """
    n = A.shape
    d = len(n)
    if d == 0:
        raise ValueError("tt_svd needs a tensor with at least one dimension")

    # Squared per-step error budget. Singular values are compared in squared
    # form, so the norm must be squared as well; otherwise the selected ranks
    # change when A is rescaled. max(..., 1) keeps d == 1 (no SVD steps) safe.
    delta2 = (eps ** 2) * (la.norm(A) ** 2) / max(d - 1, 1)

    C = A  # remainder that still has to be factorised
    G = []  # tt-cores
    r = [1]  # tt-ranks, r[0] = 1 by definition

    for k in range(1, d):
        # Unfold the remainder: rows = (previous rank, current mode);
        # -1 lets reshape infer the column count exactly (no float division).
        C = np.reshape(C, (r[k-1] * n[k-1], -1))
        u, s, vh = la.svd(C, full_matrices=False)

        # tail[i] = energy discarded when only the first i singular values
        # are kept; the appended 0 means "keep all of them" always qualifies.
        sq = s ** 2
        tail = np.concatenate([np.cumsum(sq[::-1])[::-1], np.zeros(1, dtype=sq.dtype)])
        # Smallest rank >= 1 whose discarded energy fits the budget
        # (argmax returns the first True entry).
        rank = int(np.argmax(tail[1:] <= delta2)) + 1
        r.append(rank)

        G.append(np.reshape(u[:, :rank], (r[k-1], n[k-1], rank)))
        # Carry S_r V_r^T forward; scaling the rows of vh avoids diag(s).
        C = s[:rank, None] * vh[:rank, :]

    # The remainder is the last core (r_{d-1}, n_d, 1). Reshaping with r/n
    # instead of C.shape also covers d == 1, where C is still the 1-D input.
    G.append(np.reshape(C, (r[-1], n[-1], 1)))

    return G, r

In [8]:
def tt_to_tensor(G, shape=None):
    """Contract a list of tensor-train cores back into a dense tensor.

    Args:
        G: non-empty list of 3-D cores; core i has shape
           (r_i, n_i, r_{i+1}), i.e. each core's trailing rank matches the
           next core's leading rank (as produced by tt_svd).
        shape: shape of the result. Defaults to the cores' mode sizes
           (G[0].shape[1], ..., G[-1].shape[1]).

    Returns:
        Dense array reshaped to `shape`.
    """
    if shape is None:
        shape = tuple(core.shape[1] for core in G)

    B = G[0]
    for core in G[1:]:
        rank = core.shape[0]
        # Contract B's trailing rank index with the core's leading rank
        # index; -1 lets reshape infer the other dimension exactly.
        B = np.dot(np.reshape(B, (-1, rank)), np.reshape(core, (rank, -1)))
    return np.reshape(B, shape)

In [14]:
# Input tensor A: Hilbert-like tensor with A[i, j, q] = 1 / (i + j + q + 3).
# (Inputs tried earlier: np.arange(24).reshape(4, 3, 2) and np.ones((4, 3, 2)).)
index_sum = (np.arange(4)[:, None, None]
             + np.arange(3)[None, :, None]
             + np.arange(2)[None, None, :])
A = 1.0 / (index_sum + 3)

print("A:\n", A)

eps = 1e-2  # accuracy (target relative Frobenius error)

G, r = tt_svd(A, eps)

for X in G:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", r)

B = tt_to_tensor(G, A.shape)
print("\nB:\n", B)

# TT-SVD is meant to give ||A - B||_F <= eps * ||A||_F; check it explicitly
# instead of eyeballing B against A.
rel_err = float(la.norm(A - B) / la.norm(A))
print("\nRelative error: ", rel_err)
assert rel_err <= eps, "reconstruction error exceeds the requested accuracy"

A:
 [[[0.33333334 0.25      ]
  [0.25       0.2       ]
  [0.2        0.16666667]]

 [[0.25       0.2       ]
  [0.2        0.16666667]
  [0.16666667 0.14285715]]

 [[0.2        0.16666667]
  [0.16666667 0.14285715]
  [0.14285715 0.125     ]]

 [[0.16666667 0.14285715]
  [0.14285715 0.125     ]
  [0.125      0.11111111]]]

Core:
 [[[-0.6449674  -0.68758214]
  [-0.5144456   0.09107864]
  [-0.42883325  0.42677027]
  [-0.36805514  0.58034706]]] 
Shape:  (1, 4, 2)

Core:
 [[[-0.6883633   0.40954968]
  [-0.55492496 -0.16872749]
  [-0.46614918 -0.4165662 ]]

 [[-0.02050409  0.66778165]
  [ 0.00807783  0.35927123]
  [ 0.02082663  0.23508413]]] 
Shape:  (2, 3, 2)

Core:
 [[[ 0.704035  ]
  [ 0.57473695]]

 [[-0.01484055]
  [ 0.01817922]]] 
Shape:  (2, 2, 1)

Ranks:  [1, 2, 2]

B:
 [[[0.3332316  0.2501208 ]
  [0.2501208  0.19999859]
  [0.19999862 0.16651078]]

 [[0.2502261  0.19973116]
  [0.19973119 0.16667086]
  [0.16667086 0.14320199]]

 [[0.2000426  0.16661721]
  [0.16661723 0.14285436]
  [0.