In [1]:
import numpy as np

Функция для TT-SVD (копирую из домашки):

In [2]:
def tt_svd(tensor, eps, max_rank):
    """
    Input
        tensor: np array
        eps: desired difference in frobenius norm between tensor and TT approximation
        max_rank: upper hard limit on each TT rank (it has priority over eps)

    Output
        carriages: list of cores that give TT decomposition of tensor
    """

    remaining = tensor
    d = len(tensor.shape)
    N = tensor.size
    r = 1

    eps = eps / np.sqrt(d - 1) #потому что ошибка в tt_svd составляет
    #sqrt(sum_{k <= d - 1} квадрат ошибки в svd для A_k) = sqrt(d - 1) * ошибка в каждом svd

    carriages = []

    for k in range(d - 1):
        matrix_to_svd = remaining.reshape((r * tensor.shape[k], N // tensor.shape[k]), order='F')
        u, sigmas, vt = np.linalg.svd(matrix_to_svd, full_matrices=False)

        curr_r = min(sigmas.size, max_rank)
        error_squared = np.sum(np.square(sigmas[curr_r:]))
        while curr_r >= 1 and error_squared + np.square(sigmas[curr_r - 1]) < np.square(eps):
            error_squared = error_squared + np.square(sigmas[curr_r - 1])
            curr_r -= 1

        carriages.append(u[:,:curr_r].reshape((r, tensor.shape[k], curr_r), order='F'))
        remaining = np.diag(sigmas[:curr_r]) @ vt[:curr_r,:]
        N = N // tensor.shape[k]
        r = curr_r

    carriages.append(remaining.reshape((r, tensor.shape[-1], 1), order='F'))

    return carriages

WTT Filter algorithm из статьи:

In [75]:
def wtt_filter(tensor, ranks):
    """
    Input
        tensor: np array
        ranks: list of rank parameters

    Output
        filters: list of linear filters U_1, ..., U_{d - 1} defining the WTT transform
        final ranks: list of final ranks of an approximation
    """

    d = len(tensor.shape)
    N = tensor.size
    r_prev = 1
    A = tensor.reshape((tensor.shape[0], N // tensor.shape[0]), order='F')

    filters = []
    final_ranks = []

    for k in range(d - 1):
        s1, s2 = A.shape
        r = min(ranks[k], s1, s2)
        u, sigmas, vt = np.linalg.svd(A, full_matrices=False)
        filters.append(u)

        A = (u.T @ A)[:r,:]
        N = N * r // (tensor.shape[k] * r_prev)
        A = A.reshape((tensor.shape[k + 1] * r, N // (tensor.shape[k + 1] * r)), order='F')
        final_ranks.append(r)
        r_prev = r

    return filters, final_ranks

WTT Application algorithm из статьи:

In [169]:
def wtt_application(v, filters, shapes, ranks):
    """
    Input
        v: a vector of size n = n_1...n_d
        filters: list of linear filters U_1, ..., U_{d - 1} defining the WTT transform
        shapes: a list of numbers n_1, ..., n_d
        ranks: list of ranks of the transform

    Output
        w: transformed vector
    """
    
    d = len(filters) + 1
    N = v.size
    r_prev = 1
    w = v.reshape((shapes[0], N // shapes[0]), order='F')
    tails = []
    
    for k in range(d - 1):
        
        print("k =", k)
        
        r = ranks[k]
        w = filters[k].T @ w
        
        print("w.shape=", w.shape)
        
        tails.append(w[r:,:])
        
        print("tails[k].shape=", tails[k].shape)
        print("Current N:", N)

        N = N * r // (shapes[k] * r_prev)
        
        print("New N:", N)
        
        if k < d - 2:
            w = w[:r,:].reshape((shapes[k + 1] * r, N // (shapes[k + 1] * r)), order='F')
            r_prev = r
        else:
            w = w[:r,:]
        
    for k in range(d - 2, -1, -1):
        
        print("k =", k)
        print("w.shape=", w.shape)
        print("tails[k].shape=", tails[k].shape)
        
        print("Current N:", N)
        
        r_prev = 1 if k == 0 else ranks[k - 1]
        r = ranks[k]
        N = N * r_prev * shapes[k] // r
        
        print("New N:", N)
        
        w = np.vstack([w, tails[k]]).reshape((r_prev, N // r_prev), order='F')
    
    return w.flatten(order='F')

In [170]:
def func(x):
    return np.exp(x) + np.exp(1.5 * x) - 2 * np.exp(-2 * x) + 7 * np.exp(3 * x)

In [171]:
def values(func, left, right, n):
    return func(np.linspace(left, right, n))

In [179]:
left = -1
right = 1

d = 5
n = 2 ** d

In [180]:
v = values(func, left, right, n)

In [181]:
v_tensor = v.reshape([2] * d, order='F')

In [182]:
filters, ranks = wtt_filter(v_tensor, [2] * d)

In [183]:
ranks

[2, 2, 2, 2]

In [186]:
[U.shape for U in filters]

[(2, 2), (4, 4), (4, 4), (4, 2)]

In [184]:
w = wtt_application(v, filters, [2] * d, ranks)

k = 0
w.shape= (2, 16)
tails[k].shape= (0, 16)
Current N: 32
New N: 32
k = 1
w.shape= (4, 8)
tails[k].shape= (2, 8)
Current N: 32
New N: 16
k = 2
w.shape= (4, 4)
tails[k].shape= (2, 4)
Current N: 16
New N: 8
k = 3
w.shape= (2, 2)
tails[k].shape= (0, 2)
Current N: 8
New N: 4
k = 3
w.shape= (2, 2)
tails[k].shape= (0, 2)
Current N: 4
New N: 8


ValueError: cannot reshape array of size 4 into shape (2,4)

In [135]:
filters

[array([[ 0.63902701, -0.76918429],
        [ 0.76918429,  0.63902701]]),
 array([[ 5.67927720e-01, -7.30419940e-01, -4.81611321e-02,
          3.76331398e-01],
        [-2.86275249e-04,  3.65981793e-01, -6.92043693e-01,
          6.22199945e-01],
        [ 8.23078274e-01,  5.03976456e-01,  3.26213099e-02,
         -2.59780171e-01],
        [ 4.22560260e-04,  2.80270886e-01,  7.19508084e-01,
          6.35418106e-01]]),
 array([[-4.29000441e-01, -7.57212035e-01, -1.19251865e-01,
         -3.41142297e-01],
        [ 4.66611066e-04,  4.71487291e-01, -6.48329317e-01,
         -4.35946607e-01],
        [ 1.22911071e-04,  2.75419469e-04,  1.31578842e-01,
          4.54896800e-01],
        [ 2.12181293e-08, -1.43509592e-07, -2.69774339e-05,
         -4.72445414e-02],
        [-9.03302667e-01,  3.59365463e-01,  5.50133349e-02,
          1.62627193e-01],
        [-1.63599134e-03,  2.74200053e-01,  7.13681864e-01,
         -4.07276978e-01],
        [-6.07434014e-05, -6.22091508e-04,  1.89121196