In [5]:
!pip install flax

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/91/b8/ab292e363cb8758a391541b7942f175f79b3ac06a477dd4495de7c8c91f6/flax-0.2.0-py3-none-any.whl (84kB)
[K     |███▉                            | 10kB 16.1MB/s eta 0:00:01[K     |███████▊                        | 20kB 1.8MB/s eta 0:00:01[K     |███████████▋                    | 30kB 2.3MB/s eta 0:00:01[K     |███████████████▌                | 40kB 2.6MB/s eta 0:00:01[K     |███████████████████▍            | 51kB 2.1MB/s eta 0:00:01[K     |███████████████████████▎        | 61kB 2.3MB/s eta 0:00:01[K     |███████████████████████████▏    | 71kB 2.6MB/s eta 0:00:01[K     |███████████████████████████████ | 81kB 2.8MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 2.5MB/s 
Installing collected packages: flax
Successfully installed flax-0.2.0


In [21]:
from jax import ops
import jax.numpy as jnp
from jax.numpy import maximum
from jax import grad
import jax.numpy as np
from jax.numpy import linalg as la
from jax import jit
from jax import lax
import numpy as nnp
import math
import time
import flax
from flax import struct

In [7]:
class TT_Tensor:
    """Plain record describing a tensor stored in tensor-train (TT) form.

    Attributes:
        core: sequence of TT cores, stored exactly as passed in.
        d: number of cores (tensor dimensions).
        n: mode sizes, one entry per core.
        r: TT ranks, ``d + 1`` entries including both boundary ranks.
    """

    def __init__(self, core, d, n, r):
        # No validation or copying: the object only keeps references.
        self.core, self.d, self.n, self.r = core, d, n, r

In [22]:
@struct.dataclass
class Model:
    """TT cores wrapped in a flax ``struct.dataclass``.

    This is the object type that ``tt_round_model`` takes as input and
    returns as its first result.
    """
    # List of TT cores; the routines in this notebook index each core as
    # (left rank, mode size, right rank).
    core: list

In [23]:
def tt_model(G):
    """Wrap the list of TT cores ``G`` in a ``Model`` and return it.

    The list is stored by reference, not copied.
    """
    wrapped = Model(core=G)
    return wrapped

In [11]:
def mod_to_old(ttm):
    """Convert a ``Model`` (cores only) into a ``TT_Tensor`` with metadata.

    Mode sizes and ranks are read off the core shapes: axis 1 of every core
    is its mode size, axis 0 its left rank, and axis 2 of the last core is
    the closing rank.

    Args:
        ttm: object exposing ``core``, a sequence of 3-D arrays.

    Returns:
        ``TT_Tensor`` sharing the same core sequence.
    """
    cores = ttm.core
    mode_sizes = [c.shape[1] for c in cores]
    ranks = [c.shape[0] for c in cores]
    # The right rank of the final core closes the rank list.
    ranks.append(cores[-1].shape[2])
    return TT_Tensor(cores, len(cores), mode_sizes, ranks)

In [14]:
def tt_to_tensor(tt, shape):
    """Contract every TT core of ``tt`` and return the dense tensor.

    Args:
        tt: object exposing ``core``, a sequence of 3-D cores indexed as
            (left rank, mode size, right rank).
        shape: shape given to the dense result.

    Returns:
        Array of shape ``shape`` holding the product of all cores.
    """
    cores = tt.core
    dense = cores[0]
    for core in cores[1:]:
        rank = core.shape[0]
        # Unfold the next core into a (rank, mode * right_rank) matrix ...
        unfolded = np.reshape(core, (rank, core.shape[1] * core.shape[2]))
        # ... and flatten everything accumulated so far into (rest, rank).
        flat = np.reshape(dense, (int(np.size(dense) / rank), rank))
        dense = np.dot(flat, unfolded)
    return np.reshape(dense, shape)

In [55]:
def tt_svd(A, eps):
    """Decompose a dense tensor into tensor-train (TT) form by sequential SVDs.

    The tensor is unfolded one mode at a time; each unfolding is factored
    with an SVD, the leading left singular vectors become the next core and
    the remainder (singular values times right vectors) is carried forward.

    Args:
        A: dense array with at least one dimension.
        eps: accuracy parameter. Singular values below
            ``eps / sqrt(d - 1)`` are discarded at every step.
            NOTE(review): this is an absolute per-value threshold, not one
            scaled by the norm of ``A`` -- confirm that this is intended.

    Returns:
        ``TT_Tensor(G, d, n, r)`` where ``G`` holds ``d`` cores of shape
        ``(r[k], n[k], r[k+1])``, ``n`` is ``A.shape`` and ``r`` is a list of
        ``d + 1`` Python ints with ``r[0] == r[d] == 1``.

    Raises:
        ValueError: if ``A`` has no dimensions.
    """
    d = len(A.shape)
    if d < 1:
        raise ValueError("tt_svd expects a tensor with at least one dimension.")
    n = A.shape
    N = int(np.size(A))  # number of entries in the current remainder C
    # A 1-D tensor needs no SVD step; avoid dividing by sqrt(0).
    delta = eps / math.sqrt(d - 1) if d > 1 else eps

    C = A       # remainder still to be factored
    G = []      # tt-cores
    r = [1]     # tt-ranks

    for k in range(1, d):
        rows = r[k-1] * n[k-1]
        C = np.reshape(C, (rows, N // rows))

        # Reduced SVD: only the leading columns/rows are used below, so the
        # full (and for wide unfoldings very large) factors are not needed.
        u, s, v = la.svd(C, full_matrices=False)

        # Keep the singular values that are >= delta, but never fewer than
        # one: a zero rank would produce empty cores and a division by zero
        # on the next unfolding.
        rc = max(1, int(s.shape[0] - (s < delta).sum()))
        r.append(rc)

        G.append(np.reshape(u[:, :rc], (r[k-1], n[k-1], rc)))
        # Same as diag(s[:rc]) @ v[:rc, :] without building the diagonal matrix.
        C = s[:rc, None] * v[:rc, :]
        N = (N // rows) * rc

    # Whatever remains is the last core; give it an explicit trailing rank of 1.
    G.append(np.reshape(C, (r[d-1], n[d-1], 1)))
    r.append(1)

    return TT_Tensor(G, d, n, r)

In [15]:
def tt_sum(tt1, tt2):
    """Return the TT representation of ``tt1 + tt2``.

    Interior ranks of the result are the sums of the operands' ranks; the
    first and last rank are shared. Each result core is a zero array with
    ``tt1``'s core in the leading block and ``tt2``'s core in the trailing
    block.

    Args:
        tt1, tt2: objects exposing ``core``, ``d``, ``n`` and ``r`` (as
            ``TT_Tensor`` does), with equal ``d``, equal mode sizes and equal
            first/last ranks.

    Returns:
        ``TT_Tensor`` whose cores are NumPy arrays created by ``nnp.zeros``
        (default dtype).

    Raises:
        ValueError: if the number of dimensions, the mode sizes, or the
            first/last ranks of the operands differ.
    """
    if tt1.d != tt2.d:
        raise ValueError("Different dimensions of tensors.")
    d = tt1.d

    if tt1.r[0] != tt2.r[0]:
        raise ValueError("Different sizes of first mode.")
    if tt1.r[d] != tt2.r[d]:
        raise ValueError("Different sizes of last mode.")

    # Without this check a size-1 mode in one operand would be silently
    # broadcast into the block assignment below.
    if list(tt1.n) != list(tt2.n):
        raise ValueError("Different mode sizes of tensors.")
    n = tt1.n

    r = [r1 + r2 for r1, r2 in zip(tt1.r, tt2.r)]
    # Boundary ranks are shared, not added.
    r[0] = tt1.r[0]
    r[d] = tt1.r[d]

    G = []
    for i in range(d):
        Gi = nnp.zeros((r[i], n[i], r[i+1]))
        Gi[:tt1.r[i], :, :tt1.r[i+1]] = tt1.core[i]
        # Accumulate instead of assign: for d >= 2 the two blocks never
        # overlap so this equals a plain assignment, but for d == 1 both
        # operands occupy the same block and must be added, otherwise tt2
        # would overwrite tt1.
        Gi[r[i] - tt2.r[i]:, :, r[i+1] - tt2.r[i+1]:] += nnp.asarray(tt2.core[i])
        G.append(Gi)

    return TT_Tensor(G, d, n, r)

In [78]:
def tt_round_model(tt, delta):
    """Orthogonalise a TT model and compute the truncation rank of every bond.

    Step 1 sweeps right to left, replacing each core by the transposed Q
    factor of its QR decomposition and pushing R into the core on its left.
    Step 2 sweeps left to right with an SVD per core, pushing
    ``diag(s) @ v`` into the core on its right.

    The returned cores are NOT truncated: every core keeps all singular
    directions, and the number of singular values ``>= delta`` per bond is
    returned separately so that the caller can slice the cores afterwards.

    Args:
        tt: object exposing ``core``, a list of 3-D cores indexed as
            (left rank, mode size, right rank). The list is not modified.
        delta: absolute threshold; singular values below it are counted as
            discarded.

    Returns:
        ``(Model(G), rn)`` where ``G`` is the list of processed 3-D cores and
        ``rn[k]`` is the number of retained singular values between core
        ``k`` and core ``k + 1``.

    Side effects:
        Prints the elapsed wall-clock time.
        NOTE(review): when wrapped in ``jax.jit`` the timing/print code
        presumably runs only while tracing -- confirm before relying on it.
    """
    # Work on a shallow copy so the caller's core list is left untouched.
    G = list(tt.core)
    d = len(G)
    rn = []
    time1 = time.perf_counter()

    # Remember the trailing shape of the last core; it is flattened during
    # the sweeps and restored at the end.
    last_n, last_r = G[d-1].shape[1], G[d-1].shape[2]

    # Right-to-left QR sweep. Cores 1..d-1 are left as 2-D
    # (left rank, mode * right rank) matrices.
    for k in range(d-1, 0, -1):
        r1, n, r2 = G[k].shape
        Q, R = la.qr(np.transpose(np.reshape(G[k], (r1, n * r2))))
        G[k] = np.transpose(Q)
        G[k-1] = np.einsum('ijk,lk->ijl', G[k-1], R)

    # Left-to-right SVD sweep.
    for k in range(0, d - 1):
        if k == 0:
            # Core 0 is still 3-D after the QR sweep.
            r1, n1, r2 = G[k].shape
        else:
            # Later cores are 2-D: (left rank, mode * right rank).
            r1 = G[k].shape[0]
            r2 = G[k+1].shape[0]
            n1 = G[k].shape[1] // r2

        u, s, v = la.svd(np.reshape(G[k], (r1 * n1, r2)), full_matrices=False)

        # Count of singular values that survive the threshold.
        rn.append(s.shape[0] - (s < delta).sum())

        # Same as diag(s) @ v without building the diagonal matrix.
        v = s[:, None] * v
        G[k] = np.reshape(u, (r1, n1, u.shape[1]))
        G[k+1] = np.einsum('ij,ik->kj', G[k+1], v.T)

    # Restore the 3-D shape of the last core (it stays 3-D when d == 1).
    if len(G[d-1].shape) == 2:
        G[d-1] = np.reshape(G[d-1], (G[d-1].shape[0], last_n, last_r))

    time2 = time.perf_counter()
    print("\n time: ", time2 - time1)

    return Model(G), rn

In [80]:
# input tensor A
'''
A = np.array([[1/2, 1/3, 1/4], [1/3, 1/4, 1/5], [1/4, 1/5, 1/6]])
A = np.arange(24)
B = A.reshape(4, 3, 2)
A = B
'''

'''
A = np.ones(1000000)
A = A.reshape(100, 100, 100)
#4 3 2
'''

# Test tensor A[i, j, q] = 1 / (i + j + q + 3) on a 100 x 100 x 100 grid,
# built by broadcasting three index axes instead of a million-iteration
# Python loop.
idx = np.arange(100)
A = 1 / (idx[:, None, None] + idx[None, :, None] + idx[None, None, :] + 3)

'''
A = []
dim = 15
x = 0
for i in range(2 ** dim):
  A.append(math.sin(x))
  x = x + math.pi / (2 ** dim - 1)
shp = [2 for i in range(dim)]
A = np.asarray(A)
A = A.reshape(shp)
print(A)'''

# Show the input tensor.
print("A:\n", A)

# Accuracy parameter: passed to tt_svd, used to derive the rounding
# threshold, and used in the final check ||B - C|| <= eps * ||B||.
eps = 1e-2 # accuracy


'''
B = tt_to_tensor(tt, A.shape)
print("\nB:\n", B)

print("\nttr:\n")
tts = tt_sum(tt, tt)
for X in tts.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tts.r)

C = tt_to_tensor(tts, A.shape)
print("\nC:\n", C)

D = B + B
print("\nD:\n", D)

print("\nttr:\n")
ttr = tt_round(tt, eps)

for X in ttr.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print(tt.r, ttr.r)

B = tt_to_tensor(ttr, A.shape)
print("\nB:\n", B)
print(la.norm(A-B), eps*la.norm(A))
'''

# --- 1. TT-decompose A with accuracy eps and show the cores and ranks. ---
tt = tt_svd(A, eps)

print("\ntt:")
for X in tt.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tt.r)

# --- 2. Add the TT tensor to itself; tts is the TT form of tt + tt. ---
tts = tt_sum(tt, tt)

print("\ntts:")
for X in tts.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tts.r)

# --- 3. Round the sum. ---
# Wrap the cores in a Model, which is the input type of tt_round_model.
tts_s = tt_model(tts.core)

# jit-compile the rounding routine; the threshold (argument index 1) is
# declared static. The threshold passed is eps / sqrt(d - 1), with d the
# number of cores.
fast_tt_round = jit(tt_round_model, static_argnums=(1))
ttr, rn = fast_tt_round(tts_s, eps / math.sqrt(len(tts_s.core) - 1))
#ttr, rn = tt_round_model(tts_s, eps / math.sqrt(len(tts_s.core) - 1))

# tt_round_model returns untruncated cores plus the retained-rank counts rn,
# so the truncation happens here: bond i is cut to rn[i] by slicing the
# trailing axis of core i and the leading axis of core i+1.
for i in range(len(rn)):  
    ttr.core[i] = ttr.core[i][:, :, :rn[i]]
    ttr.core[i+1] = ttr.core[i+1][:rn[i], :, :]



#ttr = tt_round(tts, eps, eps / math.sqrt(tts_s.d - 1))

#ttr2 = ttpy.matrix.round(tts)
#ttr2 = ttpy.tensor.round(tts, eps)

print("\nttr:")
for X in ttr.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

# Convert back to a TT_Tensor to recover the rank list from the core shapes.
ttro = mod_to_old(ttr)

print("\nRanks: ", ttro.r)

# --- 4. Check the rounding error. ---
# B: dense tensor of the un-rounded sum; C: dense tensor of the rounded TT.
B = tt_to_tensor(tts, A.shape)
#print("\nB = A + A:\n", B)
C = tt_to_tensor(ttro, A.shape)
#print("\nC = round(B):\n", C)
# Prints ||B - C||, the bound eps * ||B||, and whether the bound holds.
print("\n", la.norm(B-C), " <= ", eps*la.norm(B), la.norm(B-C) <= eps*la.norm(B))

A:
 [[[0.33333334 0.25       0.2        ... 0.01       0.00990099 0.00980392]
  [0.25       0.2        0.16666667 ... 0.00990099 0.00980392 0.00970874]
  [0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  ...
  [0.01       0.00990099 0.00980392 ... 0.00507614 0.00505051 0.00502513]
  [0.00990099 0.00980392 0.00970874 ... 0.00505051 0.00502513 0.005     ]
  [0.00980392 0.00970874 0.00961538 ... 0.00502513 0.005      0.00497512]]

 [[0.25       0.2        0.16666667 ... 0.00990099 0.00980392 0.00970874]
  [0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  [0.16666667 0.14285715 0.125      ... 0.00970874 0.00961538 0.00952381]
  ...
  [0.00990099 0.00980392 0.00970874 ... 0.00505051 0.00502513 0.005     ]
  [0.00980392 0.00970874 0.00961538 ... 0.00502513 0.005      0.00497512]
  [0.00970874 0.00961538 0.00952381 ... 0.005      0.00497512 0.00495049]]

 [[0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  [0.16666667 0.14