<a href="https://colab.research.google.com/github/fasghq/TT/blob/master/TT_Lib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flax

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/91/b8/ab292e363cb8758a391541b7942f175f79b3ac06a477dd4495de7c8c91f6/flax-0.2.0-py3-none-any.whl (84kB)
[K     |████████████████████████████████| 92kB 3.9MB/s 
Installing collected packages: flax
Successfully installed flax-0.2.0


In [2]:
from jax import ops
import jax.numpy as jnp
from jax.numpy import maximum
from jax import grad
import jax.numpy as np
from jax.numpy import linalg as la
from jax import jit
from jax import lax
import numpy as nnp
import math
import time
import flax
from flax import struct
import operator
import functools as ft

In [3]:
class TT_Tensor:
    def __init__(self, core, d, n, r):
        self.core = core
        self.d = d
        self.n = n
        self.r = r

In [4]:
@struct.dataclass
class Model():
    core: list

In [5]:
def mod_to_old(ttm):
    d = len(ttm.core)
    n = []
    r = []
    for X in ttm.core:
        n.append(X.shape[1])
        r.append(X.shape[0])
    r.append(ttm.core[d - 1].shape[2])
    return TT_Tensor(ttm.core, d, n, r)

In [6]:
def tt_get_d(tt):
  return len(tt.core)

In [7]:
def tt_get_n(tt):
  d = len(tt.core)
  n = []
  for X in tt.core:
        n.append(X.shape[1])
  return n

In [8]:
def tt_get_r(tt, *args):
    varargin = args
    nargin = 1 + len(varargin)
    d = len(tt.core)
    r = []
    for X in tt.core:
      r.append(X.shape[0])
    if nargin == 2:
      if varargin == "m":
        r.append(tt.core[d - 1].shape[3])
    else:
      r.append(tt.core[d - 1].shape[2])
    return r

In [9]:
def tt_to_tensor(tt, shape):
    G = tt.core
    B = G[0]
    for i in range(1, len(G)):
        X = np.reshape(G[i], (G[i].shape[0], G[i].shape[1] * G[i].shape[2]))
        B = np.reshape(B, (int(np.size(B) / G[i].shape[0]), G[i].shape[0]))
        B = np.dot(B, X)
    B = np.reshape(B, shape)
    return B

In [10]:
def tt_svd(A, delta):

    #time1 = time.perf_counter()

    d = len(A.shape)
    N = np.size(A)
    n = A.shape

    C = A
    G = [] 
    r = [] 
    rk1 = 1

    for k in range(1, d):
        C = np.reshape(C, (rk1 * n[k-1], int(N / (rk1 * n[k-1]))))      
        u, s, v = la.svd(C, full_matrices=False)
        rc = s.shape[0]  - (s < delta).sum()
        r.append(rc)
        G.append(np.reshape(u, (rk1, n[k-1], u.shape[1])))
        s = np.diag(s)
        C = np.dot(s, v)
        rk = C.shape[0]
        N = (N * rk) / (n[k-1] * rk1)
        rk1 = rk
    
    if len(C.shape) == 2:
        C = np.reshape(C, (C.shape[0], C.shape[1], 1))
        r.append(C.shape[2])
        
    G.append(C)

    #time2 = time.perf_counter()
    #print("\n time: ", time2 - time1)
    
    return Model(G), r

In [11]:
def tt_sum(tt1, tt2):
    #time1 = time.perf_counter()
    d1 = len(tt1.core)
    d2 = len(tt2.core)
    if d1 == d2:
        d = d1
    else:
        raise Exception("Different dimensions of tensors.")

    r1 = []
    n1 = []
    for X in tt1.core:
        n1.append(X.shape[1])
        r1.append(X.shape[0])
    r1.append(tt1.core[len(tt1.core) - 1].shape[2])

    r2 = []
    n2 = []
    for X in tt2.core:
        n2.append(X.shape[1])
        r2.append(X.shape[0])
    r2.append(tt2.core[len(tt2.core) - 1].shape[2])
    
    r = [x + y for x, y in zip(r1, r2)]
    
    if r1[0] == r2[0]:
        r[0] = r1[0]
    else:
        raise Exception("Different sizes of first mode.")
        
    if r1[d] == r2[d]:
        r[d] = r1[d]
    else:
        raise Exception("Different sizes of last mode.")
        
    n = n1
    G = []
    
    for i in range(0, d):
        Gi = nnp.zeros((r[i], n[i], r[i+1]))
        Gi = ops.index_update(Gi, ops.index[0 : r1[i], :, 0 : r1[i+1]], tt1.core[i])
        Gi = ops.index_update(Gi, ops.index[r[i] - r2[i] : r[i], :, r[i+1] - r2[i+1]: r[i+1]], tt2.core[i])
        #Gi[0 : r1[i], :, 0 : r1[i+1]] = tt1.core[i]
        #Gi[r[i] - r2[i] : r[i], :, r[i+1] - r2[i+1]: r[i+1]] = tt2.core[i]   
        G.append(Gi)

    #time2 = time.perf_counter()
    #print("\n time: ", time2 - time1)
    
    return Model(G)

In [12]:
def tt_round(tt, delta):  
    G = tt.core
    d = len(tt.core)
    rn = []
    #time1 = time.perf_counter()    
    
    for k in range(d-1, 0, -1):
        r1, n, r2 = G[k].shape
        G[k] = np.reshape(G[k], (r1, n * r2))
        G[k] = np.transpose(G[k])
        G[k], R = la.qr(G[k])
        G[k] = np.transpose(G[k])
        G[k-1] = np.einsum('ijk,lk->ijl', G[k-1], R) 
    
    for k in range(0, d - 1):
        
        if k == 0:
            r1, n1, r2 = G[k].shape
            r3 = G[k+2].shape[0]
            n2 = int(G[k+1].shape[1] / r3)
        elif k == d - 2:
            r1 = G[k].shape[0]
            r2 = G[k+1].shape[0]
            n1 = int(G[k].shape[1] / r2)
            r3 = 1
            n2 = G[k+1].shape[1]
        else:
            r1 = G[k].shape[0]
            r2 = G[k+1].shape[0]
            n1 = int(G[k].shape[1] / r2)
            r3 = G[k+2].shape[0]
            n2 = int(G[k+1].shape[1] / r3)

        G[k] = np.reshape(G[k], (r1 * n1, r2))

        u, s, v = la.svd(G[k], full_matrices=False)
 
        rc = s.shape[0]  - (s < delta).sum()
        rn.append(rc)
        s = np.diag(s)  
         
        v = np.dot(s, v) 
        G[k] = np.reshape(u, (r1, n1, u.shape[1]))
        G[k+1] = np.einsum('ij,ik->kj', G[k+1], v.T)
    
    
    if len(G[d-1].shape) == 2:
        G[d-1] = np.reshape(G[d-1], (G[d-1].shape[0], G[d-1].shape[1], 1))

    #time2 = time.perf_counter()
    #print("\n time: ", time2 - time1)
    
    return Model(G), rn

In [13]:
def tt_matrix(A):
  '''
  dim = 0
  a = np.prod(A.size)
  while a % 2 == 0:
    a = int(a / 2)
    dim = dim + 1

  A = A.reshape([2 for i in range(dim)])
  '''
  d = int(len(A.shape) / 2)
  nper = []
  for i in range(d):
    nper.append(i)
    nper.append(i+d)
  return A.transpose(nper)

In [14]:
def full_matrix(A):
  d = int(len(A.shape))
  d2 = int(len(A.shape) / 2)
  nper = [*range(0, d, 2)] + [*range(1, d, 2)]
  return A.transpose(nper).reshape(ft.reduce(operator.mul, A.shape[:d2], 1), ft.reduce(operator.mul, A.shape[d2:], 1))
  '''
  dim = 0
  a = np.prod(A.size)
  while a % 4 == 0:
    a = int(a / 4)
    dim = dim + 1
  return A.transpose(nper).reshape([4 for i in range(dim)])
  '''

In [35]:
def tt_matvec(ttm, ttv):
  d = len(tt.core)
  M = ttm.core
  x = ttv.core
  y = []
  for k in range(d):
    mshp = M[k].shape
    vshp = x[k].shape
    M[k] = np.reshape(M[k], (ft.reduce(operator.mul, mshp[:3], 1), mshp[3]))
    M[k] = M[k].T
    M[k] = np.reshape(M[k], (mshp[3] * mshp[0] * mshp[1], mshp[2]))
    x[k] = x[k].transpose([1,0,2])
    x[k] = np.reshape(x[k], (vshp[1], vshp[0] * vshp[2]))
    y.append(M[k].dot(x[k]))
    y[k] = np.reshape(y[k], (mshp[3], mshp[0], mshp[1], vshp[0], vshp[2]))
    y[k] = y[k].transpose([1,3,2,0,4])
    shp = y[k].shape
    y[k] = np.reshape(y[k], (shp[0] * shp[1], shp[2], shp[3] * shp[4]))
  return Model(y)



In [38]:
'''
A = []
dim = 2
x = 0
for i in range(2 ** dim):
  A.append(math.sin(x))
  x = x + math.pi / (2 ** dim - 1)
shp = [2 for i in range(dim)]
A = np.asarray(A)
A = A.reshape(shp)
'''

dim = 6
#A = np.arange(2**(dim))
A = np.eye(2**(dim))
A = np.asarray(A)
eps = 1e-2 

print(A, "def")

dim = 0
a = np.prod(A.size)
while a % 2 == 0:
  a = int(a / 2)
  dim = dim + 1
A = A.reshape([2 for i in range(dim)])

#print(A, "\n\n", A.shape, "old\n")
print(A.shape, "old\n")

A = tt_matrix(A)
shp = A.shape


#print(A, "\n\n", A.shape, "new")
print(A.shape, "new\n")

B = full_matrix(A)
print(B.shape, "new2")
#print(B, "\n\n", B.shape, "new2")

dim = 0
a = np.prod(A.size)
while a % 4 == 0:
  a = int(a / 4)
  dim = dim + 1
A = A.reshape([4 for i in range(dim)])

fast_tt_svd = jit(tt_svd, static_argnums=(1))
tt, rn = fast_tt_svd(A, eps / math.sqrt(len(A.shape) - 1))

for i in range(len(rn)-1):  
    tt.core[i] = tt.core[i][:, :, :rn[i]]
    tt.core[i+1] = tt.core[i+1][:rn[i], :, :]

ptr = 0
for i in range(len(tt.core)):
  tt.core[i] = np.reshape(tt.core[i], (tt.core[i].shape[0], shp[ptr], shp[ptr+1], tt.core[i].shape[2]))
  ptr = ptr + 2
  
for X in tt.core:
  print("c:\n", X, X.shape)

print(tt_get_r(tt, "m"))

C = tt_to_tensor(tt, A.shape)
#print(C)

D = full_matrix(C)
#print(D, "\n\n", D.shape, "new3")

dimv = 6
vec = np.arange(2**(dimv))
vec = np.asarray(vec)
vec = vec.reshape([2 for i in range(dimv)])
print(vec.shape, "vec")

ttv, rn = fast_tt_svd(vec, eps / math.sqrt(len(vec.shape) - 1))

for i in range(len(rn)-1):  
    ttv.core[i] = ttv.core[i][:, :, :rn[i]]
    ttv.core[i+1] = ttv.core[i+1][:rn[i], :, :]
  
for X in ttv.core:
  print("c:\n", X, X.shape)

print(tt_get_r(ttv))

ttmv = tt_matvec(tt, ttv)
  
for X in ttmv.core:
  print("c:\n", X, X.shape)

print(tt_get_r(ttmv))


[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]] def
(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) old

(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) new

(64, 64) new2
c:
 [[[[-0.70710677]
   [-0.        ]]

  [[-0.        ]
   [-0.70710677]]]] (1, 2, 2, 1)
c:
 [[[[ 7.0710671e-01]
   [-5.5134647e-22]]

  [[ 0.0000000e+00]
   [ 7.0710677e-01]]]] (1, 2, 2, 1)
c:
 [[[[-0.7071069 ]
   [ 0.        ]]

  [[ 0.        ]
   [-0.70710677]]]] (1, 2, 2, 1)
c:
 [[[[-7.0710683e-01]
   [-1.0193026e-29]]

  [[ 1.6974263e-22]
   [-7.0710677e-01]]]] (1, 2, 2, 1)
c:
 [[[[-7.0710677e-01]
   [-1.3272725e-29]]

  [[ 4.5703190e-23]
   [-7.0710677e-01]]]] (1, 2, 2, 1)
c:
 [[[[5.656854]
   [0.      ]]

  [[0.      ]
   [5.656853]]]] (1, 2, 2, 1)
[1, 1, 1, 1, 1, 1]
(2, 2, 2, 2, 2, 2) vec
c:
 [[[-0.33528185  0.94211787]
  [-0.94211787 -0.33528185]]] (1, 2, 2)
c:
 [[[ 0.5506019  -0.3240034 ]
  [ 0.8292731   0.17468734]]

 [

In [None]:
# input tensor A
'''
A = np.array([[1/2, 1/3, 1/4], [1/3, 1/4, 1/5], [1/4, 1/5, 1/6]])
A = np.arange(24)
B = A.reshape(4, 3, 2)
A = B
'''

'''
A = np.ones(1000000)
A = A.reshape(100, 100, 100)
#4 3 2
'''
'''
A = []
for i in range(100):
  for j in range(100):
    for q in range(100):
      A.append(1 / (i + j + q + 3))
A = np.asarray(A)
A = A.reshape(100, 100, 100)
'''

A = []
dim = 15
x = 0
for i in range(2 ** dim):
  A.append(math.sin(x))
  x = x + math.pi / (2 ** dim - 1)
shp = [2 for i in range(dim)]
A = np.asarray(A)
A = A.reshape(shp)

eps = 1e-2 # accuracy


fast_tt_svd = jit(tt_svd, static_argnums=(1))
tt, rn = fast_tt_svd(A, eps / math.sqrt(len(A.shape) - 1))
#ttnew, rn = tt_svd(A, eps / math.sqrt(len(A.shape) - 1))

for i in range(len(rn)-1):  
    tt.core[i] = tt.core[i][:, :, :rn[i]]
    tt.core[i+1] = tt.core[i+1][:rn[i], :, :]

fast_tt_sum = jit(tt_sum)
%timeit tts = fast_tt_sum(tt, tt)
#%timeit tts = tt_sum(tt, tt)
#for X in tts.core:
#  print(X)

ttsold = mod_to_old(tts)
print("\nRanks tts: ", ttsold.r)

fast_tt_round = jit(tt_round, static_argnums=(1))
ttr, rn = fast_tt_round(tts, eps / math.sqrt(len(tts.core) - 1))
#ttr, rn = tt_round(tts, eps / math.sqrt(len(tts.core) - 1))

for i in range(len(rn)):  
    ttr.core[i] = ttr.core[i][:, :, :rn[i]]
    ttr.core[i+1] = ttr.core[i+1][:rn[i], :, :]


#ttr2 = ttpy.matrix.round(tts)
#ttr2 = ttpy.tensor.round(tts, eps)
'''
print("\nttr:")
for X in ttr.core:
    print("\nCore:\n", X, "\nShape: ", X.shape)
'''
ttro = mod_to_old(ttr)
print("\nRanks ttr: ", ttro.r)
'''
B = tt_to_tensor(tts, A.shape)
#print("\nB = A + A:\n", B)
C = tt_to_tensor(ttro, A.shape)
#print("\nC = round(B):\n", C)
print("\n", la.norm(B-C), " <= ", eps*la.norm(B), la.norm(B-C) <= eps*la.norm(B))
'''


The slowest run took 199.50 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 870 µs per loop

Ranks tts:  [1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1]

Ranks ttr:  [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]


'\nB = tt_to_tensor(tts, A.shape)\n#print("\nB = A + A:\n", B)\nC = tt_to_tensor(ttro, A.shape)\n#print("\nC = round(B):\n", C)\nprint("\n", la.norm(B-C), " <= ", eps*la.norm(B), la.norm(B-C) <= eps*la.norm(B))\n'