In [4]:
from jax import ops
import jax.numpy as jnp
from jax.numpy import maximum
from jax import grad
import jax.numpy as np
from jax.numpy import linalg as la
from jax import jit
from jax import lax
import numpy as nnp
import math
import tt as ttpy
import time
import flax
from flax import struct

In [5]:
class TT_Tensor:
    """Plain container for a tensor stored in tensor-train (TT) form.

    Attributes:
        core: sequence of TT-cores.
        d: number of cores / tensor dimensions.
        n: mode sizes, one per core.
        r: TT-ranks (the functions in this notebook build it with d + 1 entries).
    """

    def __init__(self, core, d, n, r):
        # Store the four pieces verbatim; no copying or validation is done.
        self.core, self.d, self.n, self.r = core, d, n, r

In [6]:
@struct.dataclass
class Model:
    """flax ``struct.dataclass`` wrapper whose single field is the list of TT-cores."""

    # List of TT-cores.
    core: list

In [7]:
def tt_model(G):
    """Wrap the list of TT-cores ``G`` in a ``Model`` instance and return it."""
    wrapped = Model(G)
    return wrapped

In [77]:
def mod_to_old(ttm):
    """Convert an object exposing ``.core`` into a ``TT_Tensor``.

    Mode sizes are read from axis 1 of every core, ranks from axis 0 of
    every core plus axis 2 of the last core.
    """
    cores = ttm.core
    modes = [c.shape[1] for c in cores]
    ranks = [c.shape[0] for c in cores]
    # Closing rank comes from the trailing axis of the final core.
    ranks.append(cores[-1].shape[2])
    return TT_Tensor(cores, len(cores), modes, ranks)

In [9]:
def tt_svd(A, eps):
    """Decompose a dense tensor into tensor-train (TT) form by successive SVDs.

    For every unfolding the smallest rank is kept such that the squared
    Frobenius norm of the discarded singular values does not exceed
    ``delta**2`` with ``delta = eps * ||A||_F / sqrt(d - 1)``.

    Args:
        A: dense array to decompose.
        eps: requested relative accuracy in the Frobenius norm.

    Returns:
        TT_Tensor(G, d, n, r) where ``G`` is the list of 3-D cores of shape
        ``(r[k], n[k], r[k+1])``, ``d = len(A.shape)``, ``n = A.shape`` and
        ``r`` is the list of ranks, starting and ending with 1.
    """
    d = len(A.shape)
    n = A.shape

    # Squared error budget per unfolding.  The tail below is a sum of
    # *squared* singular values, so it has to be compared with a squared norm.
    budget = (eps * la.norm(A)) ** 2 / max(d - 1, 1)

    C = A   # remainder that is still to be factorised
    G = []  # tt-cores
    r = [1] # tt-ranks

    for k in range(1, d):
        # -1 lets reshape derive the column count and keeps everything integral.
        C = np.reshape(C, (r[k-1] * n[k-1], -1))

        # Reduced SVD: the full one would build a square factor as large as
        # the longer side of C, none of which survives the slicing below.
        u, s, vh = la.svd(C, full_matrices=False)

        # suffix[j] = sum of s[m]**2 for m >= j, accumulated from the small end.
        suffix = np.cumsum((s ** 2)[::-1])[::-1]
        # tail[j] = squared norm thrown away when keeping j + 1 singular values.
        tail = np.concatenate([suffix[1:], np.zeros((1,), dtype=suffix.dtype)])
        # The last tail entry is 0, so a True entry always exists; argmax
        # returns the first one, i.e. the smallest admissible rank (>= 1).
        rank = int(np.argmax(tail <= budget)) + 1
        r.append(rank)

        G.append(np.reshape(u[:, :rank], (r[k-1], n[k-1], rank)))
        # diag(s) @ vh restricted to the kept rows.
        C = s[:rank, None] * vh[:rank, :]

    # The remainder becomes the last core with a trailing rank of 1
    # (this also turns a 1-D input into a single (1, n, 1) core).
    C = np.reshape(C, (r[-1], -1, 1))
    r.append(1)
    G.append(C)

    return TT_Tensor(G, d, n, r)

In [10]:
def tt_to_tensor(tt, shape):
    """Contract the cores in ``tt.core`` left to right and reshape to ``shape``.

    Args:
        tt: object whose ``core`` attribute is the list of 3-D TT-cores.
        shape: shape given to the contracted result.

    Returns:
        The dense array of shape ``shape``.
    """
    cores = tt.core
    acc = cores[0]
    for core in cores[1:]:
        left_rank, mode, right_rank = core.shape[0], core.shape[1], core.shape[2]
        # Flatten the core so its left rank is the row index ...
        flat_core = np.reshape(core, (left_rank, mode * right_rank))
        # ... and fold the accumulated product so that rank is its column index.
        acc = np.reshape(acc, (int(np.size(acc) / left_rank), left_rank))
        acc = np.dot(acc, flat_core)
    return np.reshape(acc, shape)

In [11]:
def tt_sum(tt1, tt2):
    """Add two tensors given in TT form without leaving the TT format.

    Interior ranks add up; the first and last rank are shared.  The first
    core is laid out as ``[G1 G2]``, the last as ``[G1; G2]`` and every
    interior core is block-diagonal.

    Args:
        tt1, tt2: objects with ``core``, ``d``, ``n`` and ``r`` attributes
            (as built by ``TT_Tensor``).

    Returns:
        TT_Tensor holding the cores of the sum (NumPy arrays), with
        ``n = tt1.n`` and the new rank list.

    Raises:
        ValueError: if the number of dimensions, the boundary ranks or the
            mode sizes of the two operands differ.
    """
    if tt1.d != tt2.d:
        raise ValueError("Different dimensions of tensors.")
    d = tt1.d

    if tt1.r[0] != tt2.r[0]:
        raise ValueError("Different sizes of first mode.")
    if tt1.r[d] != tt2.r[d]:
        raise ValueError("Different sizes of last mode.")
    # Without this check a mismatch only surfaces as a broadcasting error below.
    if tuple(tt1.n) != tuple(tt2.n):
        raise ValueError("Different mode sizes of tensors.")

    r = [r1 + r2 for r1, r2 in zip(tt1.r, tt2.r)]
    # Boundary ranks are shared between the operands, not stacked.
    r[0] = tt1.r[0]
    r[d] = tt1.r[d]

    n = tt1.n
    G = []

    for i in range(0, d):
        Gi = nnp.zeros((r[i], n[i], r[i+1]))
        if d == 1:
            # Both boundary ranks are shared, so the two blocks coincide:
            # the cores have to be added, not written over each other.
            Gi[...] = nnp.asarray(tt1.core[i]) + nnp.asarray(tt2.core[i])
        else:
            # tt1 occupies the leading block, tt2 the trailing block.
            Gi[0 : tt1.r[i], :, 0 : tt1.r[i+1]] = tt1.core[i]
            Gi[r[i] - tt2.r[i] : r[i], :, r[i+1] - tt2.r[i+1] : r[i+1]] = tt2.core[i]
        G.append(Gi)

    return TT_Tensor(G, d, n, r)
    

In [19]:
def tt_round(tt, delta):
    """Recompress a TT tensor: orthogonalise right to left, then truncate left to right.

    Args:
        tt: object with ``core`` (list of 3-D cores), ``n`` (mode sizes) and
            ``r`` (ranks, ``len(core) + 1`` entries).  It is not modified.
        delta: absolute threshold; singular values smaller than ``delta`` are
            dropped in every truncation step (at least one is always kept).

    Returns:
        TT_Tensor with the rounded 3-D cores, the same ``n`` and the new ranks.

    Side effects:
        Prints the wall-clock time spent in the two sweeps.
    """
    # Work on copies so the caller's core list and rank list stay intact.
    G = list(tt.core)
    d = len(G)
    n = tt.n
    r = list(tt.r)
    time1 = time.perf_counter()

    # Sweep 1 (right to left): make every core but the first row-orthogonal
    # and push the triangular factor into the left neighbour.
    for k in range(d-1, 0, -1):
        unfolded = np.reshape(G[k], (r[k], n[k] * r[k+1]))
        Q, R = la.qr(np.transpose(unfolded))
        G[k] = np.transpose(Q)                          # (new r[k], n[k] * r[k+1])
        G[k-1] = np.einsum('ijk,lk->ijl', G[k-1], R)    # G[k-1] @ R^T
        r[k] = G[k-1].shape[2]

    # Sweep 2 (left to right): truncated SVD of each core, remainder pushed right.
    for k in range(0, d - 1):
        left = np.reshape(G[k], (r[k] * n[k], r[k+1]))
        right = np.reshape(G[k+1], (r[k+1], n[k+1] * r[k+2]))
        # svd returns vh (rows are the right singular vectors), not V.
        u, s, vh = la.svd(left, full_matrices=False)

        # s is sorted in descending order, so this is the index of the first
        # value below delta; never let the rank collapse to zero.
        rc = max(1, int(np.sum(s >= delta)))

        G[k] = np.reshape(u[:, :rc], (r[k], n[k], rc))
        # (diag(s) @ vh)[:rc] @ right keeps the product G[k] * G[k+1] intact
        # up to the discarded singular values.
        G[k+1] = np.dot(s[:rc, None] * vh[:rc, :], right)
        r[k+1] = rc

    # After the sweeps the last core is an (r[d-1], n[d-1] * r[d]) matrix.
    G[d-1] = np.reshape(G[d-1], (r[d-1], n[d-1], r[d]))

    time2 = time.perf_counter()
    print("\n time: ", time2 - time1)

    return TT_Tensor(G, d, n, r)

        

In [81]:
def tt_round_model(tt, delta):
    """Shape-static TT rounding that can be wrapped in ``jax.jit``.

    The cores are orthogonalised right to left and then refactored left to
    right with an SVD.  Nothing is sliced here, so every array shape is
    independent of the data; instead the number of singular values that
    reach ``delta`` is reported per bond and the caller trims
    ``core[i][:, :, :rn[i]]`` and ``core[i+1][:rn[i], :, :]`` afterwards.

    Args:
        tt: object whose ``core`` attribute is a list of 3-D TT-cores.  The
            list is copied, so the caller's list is not modified.
        delta: absolute threshold on the singular values.

    Returns:
        (Model(G), rn): the untrimmed 3-D cores wrapped in ``Model`` and the
        list of ``len(core) - 1`` ranks to keep (each at least 1).

    Side effects:
        Prints the wall-clock time spent in the function body.
    """
    G = list(tt.core)
    d = len(G)
    n = [X.shape[1] for X in G]     # mode sizes, read before cores get flattened
    r_last = G[d-1].shape[2]        # trailing rank, restored at the very end
    rn = []
    time1 = time.perf_counter()

    # Sweep 1 (right to left): QR of the transposed unfolding makes G[k]
    # row-orthogonal; the triangular factor moves into G[k-1].
    for k in range(d-1, 0, -1):
        r1, nk, r2 = G[k].shape
        Q, R = la.qr(np.transpose(np.reshape(G[k], (r1, nk * r2))))
        G[k] = np.transpose(Q)                          # 2-D: (new r1, nk * r2)
        G[k-1] = np.einsum('ijk,lk->ijl', G[k-1], R)    # G[k-1] @ R^T, still 3-D

    # Sweep 2 (left to right).  G[k+1] is always the 2-D matrix left by
    # sweep 1, so its row count is the rank shared with G[k]; this also
    # covers d == 2, where there is no G[k+2] to look at.
    for k in range(0, d - 1):
        r1 = G[k].shape[0]
        r2 = G[k+1].shape[0]
        left = np.reshape(G[k], (r1 * n[k], r2))
        # Reduced SVD keeps u's column count equal to vh's row count.
        u, s, vh = la.svd(left, full_matrices=False)

        # Number of singular values that are not below delta, at least one.
        rc = np.maximum(s.shape[0] - (s < delta).sum(), 1)
        rn.append(rc)

        G[k] = np.reshape(u, (r1, n[k], u.shape[1]))
        # svd returns vh, so the factor to push right is diag(s) @ vh.
        G[k+1] = np.dot(s[:, None] * vh, G[k+1])

    if len(G[d-1].shape) == 2:
        G[d-1] = np.reshape(G[d-1], (G[d-1].shape[0], n[d-1], r_last))

    time2 = time.perf_counter()
    print("\n time: ", time2 - time1)

    return Model(G), rn


In [83]:
# Demo cell: build a test tensor A, decompose it with tt_svd, add it to itself
# with tt_sum, round the sum with the jitted tt_round_model and compare the
# rounded result against the unrounded sum.
# The triple-quoted strings below are disabled experiments kept as string
# literals; they are evaluated as expressions and discarded.
'''
A = np.array([[1/2, 1/3, 1/4], [1/3, 1/4, 1/5], [1/4, 1/5, 1/6]])
A = np.arange(24)
B = A.reshape(4, 3, 2)
A = B
'''

'''
A = np.ones(1000000)
A = A.reshape(100, 100, 100)
#4 3 2
'''

# A[i, j, q] = 1 / (i + j + q + 3) on a 100 x 100 x 100 grid, collected in a
# flat Python list first and reshaped afterwards.
A = []
for i in range(100):
  for j in range(100):
    for q in range(100):
      A.append(1 / (i + j + q + 3))
A = np.asarray(A)
A = A.reshape(100, 100, 100)

print("A:\n", A)

eps = 1e-2 # accuracy


'''
B = tt_to_tensor(tt, A.shape)
print("\nB:\n", B)

print("\nttr:\n")
tts = tt_sum(tt, tt)
for X in tts.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tts.r)

C = tt_to_tensor(tts, A.shape)
print("\nC:\n", C)

D = B + B
print("\nD:\n", D)

print("\nttr:\n")
ttr = tt_round(tt, eps)

for X in ttr.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print(tt.r, ttr.r)

B = tt_to_tensor(ttr, A.shape)
print("\nB:\n", B)
print(la.norm(A-B), eps*la.norm(A))
'''

# TT decomposition of A.
tt = tt_svd(A, eps)

print("\ntt:")
for X in tt.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tt.r)

# A + A in TT form.
tts = tt_sum(tt, tt)

print("\ntts:")
for X in tts.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

print("\nRanks: ", tts.r)

# Wrap the cores in the flax dataclass that is passed to the jitted function.
tts_s = tt_model(tts.core)

# delta (argument 1) is marked static for jit.
fast_tt_round = jit(tt_round_model, static_argnums=(1))
ttr, rn = fast_tt_round(tts_s, eps / math.sqrt(len(tts_s.core) - 1))
#ttr, rn = tt_round_model(tts_s, eps / math.sqrt(len(tts_s.core) - 1))

# tt_round_model returns untrimmed cores; cut both sides of every bond down to
# the reported rank here, outside of jit.
for i in range(len(rn)):  
    ttr.core[i] = ttr.core[i][:, :, :rn[i]]
    ttr.core[i+1] = ttr.core[i+1][:rn[i], :, :]



#ttr = tt_round(tts, eps, eps / math.sqrt(tts_s.d - 1))

#ttr2 = ttpy.matrix.round(tts)
#ttr2 = ttpy.tensor.round(tts, eps)

print("\nttr:")
for X in ttr.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)

# Convert back to TT_Tensor so the ranks can be printed.
ttro = mod_to_old(ttr)

print("\nRanks: ", ttro.r)

# Dense reconstructions of the sum before (B) and after (C) rounding; the last
# line prints the error, the bound eps * ||B|| and whether the bound holds.
B = tt_to_tensor(tts, A.shape)
print("\nB = A + A:\n", B)
C = tt_to_tensor(ttro, A.shape)
print("\nC = round(B):\n", C)
print("\n", la.norm(B-C), " <= ", eps*la.norm(B), la.norm(B-C) <= eps*la.norm(B))







A:
 [[[0.33333334 0.25       0.2        ... 0.01       0.00990099 0.00980392]
  [0.25       0.2        0.16666667 ... 0.00990099 0.00980392 0.00970874]
  [0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  ...
  [0.01       0.00990099 0.00980392 ... 0.00507614 0.00505051 0.00502513]
  [0.00990099 0.00980392 0.00970874 ... 0.00505051 0.00502513 0.005     ]
  [0.00980392 0.00970874 0.00961538 ... 0.00502513 0.005      0.00497512]]

 [[0.25       0.2        0.16666667 ... 0.00990099 0.00980392 0.00970874]
  [0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  [0.16666667 0.14285715 0.125      ... 0.00970874 0.00961538 0.00952381]
  ...
  [0.00990099 0.00980392 0.00970874 ... 0.00505051 0.00502513 0.005     ]
  [0.00980392 0.00970874 0.00961538 ... 0.00502513 0.005      0.00497512]
  [0.00970874 0.00961538 0.00952381 ... 0.005      0.00497512 0.00495049]]

 [[0.2        0.16666667 0.14285715 ... 0.00980392 0.00970874 0.00961538]
  [0.16666667 0.14


ttr:

Core:
 [[[-0.18944347  0.423647    0.5245455  -0.4898473 ]
  [-0.18175687  0.34709617  0.280628   -0.03517883]
  [-0.17531592  0.2901255   0.13226435  0.1529472 ]
  [-0.16972654  0.24541616  0.0362777   0.22418745]
  [-0.16476701  0.20910303 -0.02800601  0.2398188 ]
  [-0.16029747  0.17888495 -0.0718234   0.22823888]
  [-0.15622269  0.15327851 -0.10181919  0.20366187]
  [-0.15247469  0.13127197 -0.1221652   0.1735315 ]
  [-0.1490024   0.11214315 -0.13560955  0.141865  ]
  [-0.14576681  0.09536071 -0.14402148  0.11082199]
  [-0.14273696  0.08052143 -0.14871411  0.08153443]
  [-0.13988796  0.06731328 -0.15062986  0.05456273]
  [-0.13719937  0.05548969 -0.1504579   0.03012324]
  [-0.13465433  0.04485274 -0.14871056  0.00825027]
  [-0.13223836  0.035242   -0.14577343 -0.01113945]
  [-0.12993948  0.02652448 -0.14194229 -0.02817926]
  [-0.12774715  0.01859068 -0.13744149 -0.04303871]
  [-0.12565233  0.01134744 -0.13244773 -0.05588675]
  [-0.12364708  0.00471675 -0.1270975  -0.0669034 