In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def plot_values(func_values, left, right, n, logscale=False, func_name=None, ax=None):
    """Plot precomputed function values against a uniform grid on [left, right].

    Parameters
        func_values: n values, e.g. the output of ``values(func, left, right, n)``
        left, right: endpoints of the grid (built with ``np.linspace``)
        n: number of grid points; must match ``len(func_values)``
        logscale: if truthy, both axes use a logarithmic scale
            (non-positive grid points such as x = 0 cannot be shown on a log axis)
        func_name: optional plot title
        ax: optional matplotlib Axes to draw on; defaults to the current Axes

    Returns
        the Axes that was drawn on
    """
    if ax is None:
        ax = plt.gca()
    ax.plot(np.linspace(left, right, n), func_values)
    # Plain truthiness instead of `is True`, so numpy booleans and other truthy flags work too.
    if logscale:
        ax.set_xscale("log")
        ax.set_yscale("log")
    if func_name:
        ax.set_title(func_name)
    return ax

In [3]:
def values(func, left, right, n):
    """Evaluate ``func`` on ``n`` equally spaced points of [left, right], endpoints included.

    ``func`` is called once on the whole grid, so it must accept a numpy array.
    """
    grid = np.linspace(left, right, n)
    return func(grid)

Функция для TT-SVD (скопирована из домашнего задания):

In [4]:
def tt_svd(tensor, eps, max_rank):
    """
    Compute a TT (tensor-train) decomposition of `tensor` by sequential truncated SVDs.

    Input
        tensor: np array with d = tensor.ndim >= 1 modes
        eps: desired upper bound on the Frobenius norm of (tensor - TT approximation)
        max_rank: upper hard limit on each TT rank (it has priority over eps); must be >= 1

    Output
        carriages: list of d cores that give the TT decomposition of tensor. Core k has shape
            (r_{k-1}, tensor.shape[k], r_k) with r_{-1} = r_{d-1} = 1 and every r_k >= 1.
            Index groups are unfolded in Fortran order, so the tensor is recovered by
            contracting the last axis of each core with the first axis of the next one.

    Raises
        ValueError: if max_rank < 1.
    """
    if max_rank < 1:
        raise ValueError(f"max_rank must be at least 1, got {max_rank}")

    shape = tensor.shape
    d = len(shape)
    remaining = tensor
    N = tensor.size  # product of the modes that have not been split off yet
    r = 1

    # The squared TT-SVD error is at most the sum of the squared truncation errors of the
    # d - 1 SVDs. Allowing eps / sqrt(d - 1) per SVD therefore keeps the total error below eps.
    # A 1-mode tensor needs no SVD at all (and sqrt(d - 1) would be a division by zero).
    step_eps_squared = np.square(eps / np.sqrt(d - 1)) if d > 1 else 0.0

    carriages = []

    for k in range(d - 1):
        n_k = shape[k]
        matrix_to_svd = remaining.reshape((r * n_k, N // n_k), order='F')
        u, sigmas, vt = np.linalg.svd(matrix_to_svd, full_matrices=False)

        # Start from the hard rank cap. Then drop the smallest kept singular values while the
        # accumulated squared error stays under the per-step budget. At least one singular
        # value is always kept, so no core ends up with rank 0 (e.g. for an all-zero tensor).
        curr_r = min(sigmas.size, max_rank)
        error_squared = np.sum(np.square(sigmas[curr_r:]))
        while curr_r > 1 and error_squared + np.square(sigmas[curr_r - 1]) < step_eps_squared:
            error_squared += np.square(sigmas[curr_r - 1])
            curr_r -= 1

        carriages.append(u[:, :curr_r].reshape((r, n_k, curr_r), order='F'))
        # Part still to decompose: diag(sigmas) @ V^T restricted to the kept rank,
        # computed by broadcasting instead of building a diagonal matrix.
        remaining = sigmas[:curr_r, None] * vt[:curr_r, :]
        N = N // n_k
        r = curr_r

    carriages.append(remaining.reshape((r, shape[-1], 1), order='F'))

    return carriages

In [5]:
def wtt_filter(input_vector, d, modes, ranks, check_correctness=False):
    """Compute the WTT filters adapted to ``input_vector``.

    Level k unfolds the current data (Fortran order) into a matrix with
    ``r_{k-1} * modes[k]`` rows (r_{-1} = 1) and stores the square matrix of its left
    singular vectors as the k-th filter. Before moving to level k + 1, the data is
    expressed in that basis and only its first ``ranks[k]`` rows are kept.

    Parameters
        input_vector: numpy array whose size equals the product of ``modes``
        d: number of levels; expected to equal ``len(modes)``
        modes: sequence of d mode sizes
        ranks: sequence of d - 1 ranks
        check_correctness: if True, ``assert`` the size/shape invariants along the way

    Returns
        list of d square matrices of left singular vectors; the k-th one has
        ``r_{k-1} * modes[k]`` rows
    """
    size_left = input_vector.size

    if check_correctness:
        assert len(modes) == d
        assert len(ranks) == d - 1
        expected_size = 1
        for mode_size in modes:
            expected_size *= mode_size
        assert size_left == expected_size

    level_filters = []
    data = input_vector
    rank_in = 1
    for level in range(d):
        mode_size = modes[level]
        n_rows = rank_in * mode_size
        n_cols = size_left // mode_size
        data = data.reshape((n_rows, n_cols), order='F')
        # U must be square in every case. For a wide (or square) matrix the reduced SVD already
        # gives a square U. A tall matrix needs the full SVD; its V^T is then the small
        # (n_cols x n_cols) factor, so a huge V^T is never allocated.
        basis = np.linalg.svd(data, full_matrices=n_rows > n_cols)[0]
        level_filters.append(basis)

        if check_correctness:
            assert basis.shape[0] == basis.shape[1] == n_rows

        if level == d - 1:
            break

        kept = ranks[level]
        if check_correctness:
            assert kept <= n_rows
        data = (basis.T @ data)[:kept, :]
        size_left //= mode_size
        rank_in = kept

    return level_filters

Проверка на векторе из всех единиц длины $2^d$ (все моды равны 2, все ранги равны 1):

In [6]:
d = 4
modes = [2] * d
ranks = [1] * (d - 1)
# All-ones test vector of length 2**d; with every rank equal to 1, each filter is 2x2.
ones_vector = np.ones(2 ** d)
filters = wtt_filter(ones_vector, d, modes, ranks, check_correctness=True)

In [7]:
# Show the filters obtained for the all-ones vector (one per level).
filters

[array([[-0.70710678,  0.70710678],
        [-0.70710678, -0.70710678]]),
 array([[ 0.70710678, -0.70710678],
        [ 0.70710678,  0.70710678]]),
 array([[-0.70710678, -0.70710678],
        [-0.70710678,  0.70710678]]),
 array([[ 0.70710678, -0.70710678],
        [ 0.70710678,  0.70710678]])]

In [8]:
# Rescale by sqrt(2) so the entries become +-1 and are easy to compare with the Haar matrix.
[filter_k * np.sqrt(2) for filter_k in filters]

[array([[-1.,  1.],
        [-1., -1.]]),
 array([[ 1., -1.],
        [ 1.,  1.]]),
 array([[-1., -1.],
        [-1.,  1.]]),
 array([[ 1., -1.],
        [ 1.,  1.]])]

Не все полученные матрицы равны матрице
$$
\frac{1}{\sqrt{2}}
\begin{pmatrix}
1 & 1 \\ 1 & -1
\end{pmatrix},
$$
как сказано в статье: от неё они отличаются знаками отдельных столбцов.

In [9]:
# The same SVD as the first wtt_filter step for the all-ones vector of length 16 (2 x 8 unfolding).
ones_2x8 = np.ones(shape=(2, 8))
u, s, vt = np.linalg.svd(ones_2x8, full_matrices=False)

In [10]:
# Left singular vectors of the 2x8 all-ones matrix. Compare with filters[0] shown above:
# the first wtt_filter step is exactly this SVD.
u

array([[-0.70710678,  0.70710678],
       [-0.70710678, -0.70710678]])

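Видимо, дело в неоднозначности SVD: сингулярные числа определены однозначно, а сингулярные векторы — лишь с точностью до знака (пару $(u_i, v_i)$ можно заменить на $(-u_i, -v_i)$; при кратных сингулярных числах свобода ещё больше). Поэтому `np.linalg.svd` может вернуть столбцы с другими знаками, чем в статье. Видно, что первый фильтр в точности совпадает с `u`.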

In [11]:
# Sanity check of np.vstack semantics (it stacks rows), as relied on in wtt_apply's backward pass.
test_1 = np.array([[1, 2, 3],
                   [4, 5, 6]])
test_2 = np.array([[7, 8, 9],
                   [10, 11, 12]])
np.vstack((test_1, test_2))

array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])

In [12]:
def wtt_apply(input_vector, d, filters, modes, ranks, check_correctness=False):
    """Apply the WTT transform defined by ``filters`` to ``input_vector``.

    Forward pass: level k unfolds the current data (Fortran order) into
    ``r_{k-1} * modes[k]`` rows and multiplies it by ``filters[k].T``. Rows from
    ``ranks[k]`` on are stored as detail coefficients; the first ``ranks[k]`` rows go on
    to the next level. Backward pass: the detail blocks are stacked back under the coarser
    coefficients, undoing each unfolding, so the output has as many entries as the input.

    Parameters
        input_vector: numpy array whose size equals the product of ``modes``
        d: number of levels
        filters: d square matrices (e.g. produced by ``wtt_filter``)
        modes: sequence of d mode sizes
        ranks: sequence of d - 1 ranks
        check_correctness: if True, ``assert`` the size/shape invariants along the way

    Returns
        1-D array of transform coefficients (Fortran-order flattening)
    """
    size_left = input_vector.size

    if check_correctness:
        assert len(filters) == d
        assert len(modes) == d
        assert len(ranks) == d - 1
        expected_size = 1
        for mode_size in modes:
            expected_size *= mode_size
        assert size_left == expected_size

    details = []
    coeffs = input_vector
    rank_in = 1
    for level in range(d):
        mode_size = modes[level]
        n_rows = rank_in * mode_size
        unfolded = coeffs.reshape((n_rows, size_left // mode_size), order='F')
        coeffs = filters[level].T @ unfolded

        if check_correctness:
            assert coeffs.shape[0] == n_rows

        if level == d - 1:
            break

        kept = ranks[level]
        if check_correctness:
            assert kept <= n_rows
        details.append(coeffs[kept:, :])
        coeffs = coeffs[:kept, :]
        size_left //= mode_size
        rank_in = kept

    # Walk back from the deepest level, re-attaching each level's detail rows below the
    # coarse coefficients; size_left grows back to the full column count at every level.
    merged = coeffs
    for level in reversed(range(d - 1)):
        coarse = merged.reshape((ranks[level], size_left), order='F')
        merged = np.vstack([coarse, details[level]])
        size_left *= modes[level]

    return merged.flatten(order='F')

In [13]:
wtt_ones = wtt_apply(ones_vector, d, filters, modes, ranks, check_correctness=True)
# Since the filters are orthogonal, this coefficient should carry the whole norm
# ||ones(16)|| = 4 (up to sign).
wtt_ones[0]

3.999999999999999

In [14]:
# Norm of all coefficients except the first: ~0 means the vector is compressed into one coefficient.
np.linalg.norm(wtt_ones[1:])

4.976250668031886e-16

Получился один ненулевой элемент (по модулю равный $\|v\|_2 = \sqrt{16} = 4$), все остальные — нули с точностью до машинной погрешности.

In [15]:
left, right = 0., 1.

d = 20
n = 2 ** d  # 1048576 samples

v = values(lambda x: np.sin(100 * x), left, right, n)

In [16]:
# Number of samples (v is one-dimensional, so its length equals its size).
len(v)

1048576

In [17]:
filters = wtt_filter(
    v, d,
    modes=[2] * d,
    ranks=[2] * (d - 1),
    check_correctness=True,
)

In [18]:
# Filter sizes: 2x2 on the first level, (rank * 2) x (rank * 2) on the following ones.
[filter_k.shape for filter_k in filters]

[(2, 2),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4)]

In [19]:
res = wtt_apply(
    v, d, filters,
    modes=[2] * d,
    ranks=[2] * (d - 1),
    check_correctness=True,
)

In [20]:
# First coefficient, then the norm of all remaining ones (each on its own line).
print(res[0], np.linalg.norm(res[1:]), sep="\n")

725.6562012962519
4.141384974355612e-12


Модельный пример с $\sqrt{x}$.

In [21]:
left, right = 0., 1.
d = 10
n = 2 ** d

# The evaluation grid (the same points that values() samples).
linspace = np.linspace(left, right, n)
sqrt_x_values = values(np.sqrt, left, right, n)

In [22]:
# Bind modes/ranks once, so wtt_filter and wtt_apply are guaranteed to use the same structure.
sqrt_modes = [2] * d
sqrt_ranks = [2] * (d - 1)
filters = wtt_filter(sqrt_x_values, d, sqrt_modes, sqrt_ranks, True)
wtt_res = wtt_apply(sqrt_x_values, d, filters, sqrt_modes, sqrt_ranks, True)

In [23]:
# WTT coefficients of sqrt(x) with ranks = [2] * (d - 1).
wtt_res

array([ 2.26271567e+01,  3.69570201e-16, -4.29187308e-03, ...,
        1.26439345e-05, -1.82856285e-04, -3.74363753e-06])

In [24]:
# Larger ranks from the second level on. ranks[0] stays 2 because the first unfolding has only
# modes[0] = 2 rows (wtt_filter checks ranks[k] <= r_{k-1} * modes[k]).
# Modes/ranks are bound once so both calls use the same structure.
sqrt_modes = [2] * d
sqrt_ranks = [2] + [3] * (d - 2)
filters = wtt_filter(sqrt_x_values, d, sqrt_modes, sqrt_ranks, True)
wtt_res = wtt_apply(sqrt_x_values, d, filters, sqrt_modes, sqrt_ranks, True)

In [25]:
# WTT coefficients of sqrt(x) with ranks = [2] + [3] * (d - 2).
wtt_res

array([ 2.26274167e+01, -6.54458816e-16, -5.17412355e-34, ...,
       -4.43130911e-07, -4.83150669e-09, -3.74363753e-06])

In [26]:
# Ranks 4 from the second level on. ranks[0] stays 2 because the first unfolding has only
# modes[0] = 2 rows (wtt_filter checks ranks[k] <= r_{k-1} * modes[k]).
# Modes/ranks are bound once so both calls use the same structure.
sqrt_modes = [2] * d
sqrt_ranks = [2] + [4] * (d - 2)
filters = wtt_filter(sqrt_x_values, d, sqrt_modes, sqrt_ranks, True)
wtt_res = wtt_apply(sqrt_x_values, d, filters, sqrt_modes, sqrt_ranks, True)

In [27]:
# WTT coefficients of sqrt(x) with ranks = [2] + [4] * (d - 2).
wtt_res

array([ 2.26274170e+01, -2.86652590e-16,  1.03607012e-33, ...,
       -8.61662304e-09, -1.05267213e-10, -8.09322564e-13])

In [28]:
d = 10
n = 2 ** d
left, right = -1., 1.

# A linear combination of exponentials, sampled on [left, right].
exp_comb_values = values(
    lambda x: np.exp(x) + np.exp(1.5 * x) - 2 * np.exp(-2 * x) + 7 * np.exp(3 * x),
    left, right, n,
)

In [29]:
# Same structure as the rank-4 sqrt example. Modes/ranks are bound once so both calls agree.
exp_modes = [2] * d
exp_ranks = [2] + [4] * (d - 2)
filters = wtt_filter(exp_comb_values, d, exp_modes, exp_ranks, True)
wtt_res = wtt_apply(exp_comb_values, d, filters, exp_modes, exp_ranks, True)

In [30]:
# WTT coefficients of the exponential combination with ranks = [2] + [4] * (d - 2).
wtt_res

array([ 1.39016434e+03, -4.39670029e-14, -2.32907645e-32, ...,
        2.68361266e-14, -2.36288676e-15,  2.48153651e-14])

In [31]:
# Norm of all coefficients except the first; round-off level means a single coefficient remains.
np.linalg.norm(wtt_res[1:])

5.467260800634424e-13

In [32]:
# The surviving coefficient; since the filters are orthogonal, its absolute value should
# match np.linalg.norm(exp_comb_values) when the tail is ~0.
wtt_res[0]

1390.1643414965263