### Подготовительный этап

Обоснование происходящего приведено [тут](https://raw.githubusercontent.com/johanDDC/RiemannianOptimizationTT/3af4e53147cd5a89dd53461d22775641cfddcd6e/partial_evp/evp.pdf).

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from numpy import random
import ttax
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Scratch check: pad a (3, 2) matrix with singleton axes using a column-major
# reshape, then squeeze them away again and watch the shape at every step.
m = random.randn(2, 3, 2)  # reused by a later cell as the test core
arr = jnp.asarray(random.randn(3, 2))
print(arr.shape)
arr = arr.reshape((1, arr.shape[0], 1, arr.shape[1]), order="F")
print(arr.shape)
arr = arr.squeeze()
print(arr.shape)



(3, 2)
(1, 3, 1, 2)
(3, 2)


In [3]:
def core_extend_apply_and_shrink(core, func):
    """Lift a 3-axis core to 4 axes, run ``func`` on it, squeeze the result.

    A singleton axis is inserted at position 2 with a column-major (``"F"``)
    reshape, ``func`` is applied to the 4-axis array, and every length-1 axis
    of its output is then dropped with ``jnp.squeeze``.  Shapes and values are
    printed along the way for inspection.

    Args:
        core: array whose first three axis lengths are read from ``core.shape``.
        func: callable that receives the 4-axis array and returns an array.

    Returns:
        The squeezed output of ``func``.
    """
    d0, d1, d2 = core.shape[0], core.shape[1], core.shape[2]
    print(core.shape)

    extended = jnp.reshape(core, (d0, d1, 1, d2), order="F")
    print(extended, "\n")
    print(extended.shape)

    applied = func(extended)
    print(applied.shape)
    print(applied)

    # NOTE: squeeze removes *all* singleton axes, not only the inserted one.
    shrunk = jnp.squeeze(applied)
    print(shrunk.shape)
    return shrunk

# Push a Fortran-ordered copy of `m` through the helper.  Swapping axes 1 and 2
# of the 4-axis array is the same permutation as transposing with [0, 2, 1, 3].
Core = jnp.array(np.asfortranarray(m))
Core = core_extend_apply_and_shrink(Core, lambda x: jnp.swapaxes(x, 1, 2))

(2, 3, 2)
[[[[-0.13728446  1.2043208 ]]

  [[ 0.651604   -1.1579461 ]]

  [[-3.2300212   0.44841063]]]


 [[[ 0.89420897 -0.01044754]]

  [[ 0.28185886 -0.11538029]]

  [[-0.06823526 -0.0543802 ]]]] 

(2, 3, 1, 2)
(2, 1, 3, 2)
[[[[-0.13728446  1.2043208 ]
   [ 0.651604   -1.1579461 ]
   [-3.2300212   0.44841063]]]


 [[[ 0.89420897 -0.01044754]
   [ 0.28185886 -0.11538029]
   [-0.06823526 -0.0543802 ]]]]
(2, 3, 2)


In [23]:
# Random test data: a TT tensor with modes (2, 3, 2, 2) that TT_matmul treats
# as an operator, and a TT tensor with modes (3, 2) for it to act on.
# Both tensors are drawn from the same PRNG key (42).
tt_rank = random.randint(3, 6)
modes = (2, 3, 2, 2)
TT_op = ttax.random.tensor(
    jax.random.PRNGKey(42), modes, tt_rank=tt_rank, dtype=jnp.float32
)
TT_el = ttax.random.tensor(
    jax.random.PRNGKey(42), (3, 2), tt_rank=2, dtype=jnp.float32
)
print(TT_op.shape, TT_el.shape, sep="\n")

(2, 3, 2, 2)
(3, 2)


In [24]:
def TT_matmul(tt1: ttax.base_class.TT, tt2: ttax.base_class.TT):
    """Apply a TT operator to a TT tensor, core by core.

    ``tt1`` stores the operator with two consecutive cores per mode of
    ``tt2``: cores ``2k`` and ``2k + 1`` are contracted over their shared
    rank into one 4-axis core ``(rank, out_mode, in_mode, rank)``.  Its
    ``in_mode`` axis is then contracted with the mode axis of
    ``tt2.tt_cores[k]``, and the two left ranks and the two right ranks are
    merged with a column-major reshape.

    Args:
        tt1: operator in TT format; must have exactly twice as many cores
            as ``tt2``.
        tt2: tensor in TT format the operator is applied to.

    Returns:
        A ``ttax.base_class.TT`` built from the resulting 3-axis cores.

    Raises:
        ValueError: if ``len(tt1.tt_cores) != 2 * len(tt2.tt_cores)``.
    """
    n_op_cores = len(tt1.tt_cores)
    n_el_cores = len(tt2.tt_cores)
    # Without this check an operator with surplus core pairs would be silently
    # truncated, and a short or odd-length one would fail with a bare IndexError.
    if n_op_cores != 2 * n_el_cores:
        raise ValueError(
            "operator must have two cores per core of the tensor: got "
            "{} operator cores for {} tensor cores".format(n_op_cores, n_el_cores)
        )

    cores = []
    for k, el_core in enumerate(tt2.tt_cores):
        # Fuse the pair of operator cores that belongs to mode k.
        op_core = jnp.einsum(
            'abi,icd->abcd', tt1.tt_cores[2 * k], tt1.tt_cores[2 * k + 1]
        )
        # Contract the operator's input mode with the tensor's mode.
        merged = jnp.einsum('abic,eig->aebcg', op_core, el_core)
        left_op, left_el, out_mode, right_op, right_el = merged.shape
        # Column-major merge keeps the (operator, tensor) rank ordering
        # identical on the right of core k and on the left of core k + 1.
        cores.append(
            merged.reshape(
                (left_op * left_el, out_mode, right_op * right_el), order="F"
            )
        )
    return ttax.base_class.TT(cores)

In [25]:
# Smoke test: shape of the operator-times-tensor product for the random data
# above (the recorded output is (2, 2)).
print(TT_matmul(TT_op, TT_el).shape)
# TT_matmul(TT_op, TT_el)

(2, 2)


### Основной этап

In [126]:
def make_rayleigh(A):
    """Return the map ``x -> <x, A x>`` for a fixed TT operator ``A``.

    NOTE(review): nothing here divides by ``<x, x>``, so the value is a
    Rayleigh quotient only when ``x`` has unit norm -- confirm that callers
    rely on that.
    """
    def quadratic_form(x):
        return ttax.flat_inner(x, TT_matmul(A, x))

    return quadratic_form


def norm(x):
    """Square root of ``ttax.flat_inner(x, x)``."""
    return jnp.sqrt(ttax.flat_inner(x, x))


def residual(A, x, eig):
    """Norm of ``A x - eig * x``; the difference is orthogonalized first."""
    return norm(ttax.orthogonalize(TT_matmul(A, x) + (-eig) * x))


def retraction(T):
    """Orthogonalize ``T`` and pass the result through ``ttax.round``."""
    return ttax.round(ttax.orthogonalize(T))

def armijo_backtracking(init, grad, mul, beta, func, x, max_backtracks=100):
    """Pick a step size along ``-grad`` by Armijo backtracking.

    Starting from ``init``, the step ``alpha`` is multiplied by ``beta``
    until the sufficient-decrease test

        func(x) - func(retraction(x - alpha * grad)) >= mul * alpha * ||grad||^2

    holds, or until ``max_backtracks`` reductions have been made.

    Args:
        init: initial step size.
        grad: search direction (the step taken is ``-alpha * grad``).
        mul: sufficient-decrease coefficient.
        beta: factor applied to ``alpha`` on every failed test.
        func: objective, called on ``x`` and on retracted trial points.
        x: current point.
        max_backtracks: upper bound on the number of reductions; ``None``
            removes the bound (the original, unbounded behaviour).

    Returns:
        The accepted step size, or the last one tried if the bound was hit.
    """
    # Both quantities are independent of alpha, so evaluate them once.
    f_x = func(x)
    grad_sq = norm(grad) ** 2

    alpha = init
    backtracks = 0
    while f_x - func(retraction(x + (-alpha) * grad)) < mul * alpha * grad_sq:
        # Guard against spinning forever once alpha underflows and rounding
        # noise keeps the measured decrease negative.
        if max_backtracks is not None and backtracks >= max_backtracks:
            break
        alpha *= beta
        backtracks += 1
    return alpha

In [129]:
def riemanGD(A, init, tol, max_iter=None, debug=False):
    """Riemannian gradient descent on ``make_rayleigh(A)``.

    Each step takes the gradient from ``ttax.autodiff.grad``, chooses a step
    size with ``armijo_backtracking`` (initial step 2, coefficient 1e-4,
    shrink factor 0.8) and retracts the update.

    Iteration stops when the residual drops to ``tol`` or below, when
    ``max_iter`` steps have been taken, or when two consecutive residuals are
    exactly equal (checked from the second step on).

    Args:
        A: operator in TT format.
        init: starting point; it is orthogonalized before use.
        tol: residual threshold.
        max_iter: maximum number of steps, or ``None`` for no limit.
            ``max_iter=0`` now performs no steps at all.
        debug: when truthy, print the step number and residual after every
            step.

    Returns:
        ``(x, residuals)``: the final iterate and the residual history,
        starting with the residual of the orthogonalized ``init``.
    """
    rayleigh = make_rayleigh(A)
    rieman_grad = ttax.autodiff.grad(rayleigh)
    x = ttax.orthogonalize(init)

    # The residual needs an operator product and an orthogonalization, so it
    # is evaluated once per iterate and reused for the loop test, the history
    # and the debug output.
    current = residual(A, x, rayleigh(x))
    residuals = [current]
    iters = 0
    if debug:
        print("№\tresidual")

    while current > tol:
        if max_iter is not None and iters >= max_iter:
            break
        direction = rieman_grad(x)
        alpha = armijo_backtracking(2, direction, 1E-4, 0.8, rayleigh, x)
        x = retraction(x + (-alpha) * direction)
        iters += 1

        previous, current = current, residual(A, x, rayleigh(x))
        residuals.append(current)
        if debug:
            print("{}\t{}".format(iters, current))
        # Exact equality of consecutive residuals means the iteration stalled.
        if len(residuals) > 2 and previous == current:
            break

    return x, residuals

In [136]:
# Four TT cores, each equal to 1 exactly where all of its non-singleton
# indices coincide and 0 elsewhere.
# NOTE(review): the printed full tensor has ones at [i, i, i, i], i.e. it is
# superdiagonal; with TT_matmul pairing consecutive cores this looks like a
# diagonal projector rather than an identity matrix -- confirm the name `eye`.
core1 = np.zeros((1, 4, 4))
core1[0] = np.eye(4)
core2 = np.zeros((4, 4, 4))
core2[np.diag_indices(4, ndim=3)] = 1
core3 = core2.copy()
core4 = np.zeros((4, 4, 1))
core4[:, :, 0] = np.eye(4)
I3 = jnp.array(core2)
eye = ttax.base_class.TT([core1, core2, core3, core4])
print(ttax.full(eye))

[[[[1. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 1. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 1. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0

  lax._check_user_dtype_supported(dtype, "ones")


In [137]:
# Run the solver on `eye` from a random (4, 4) TT start point, with tolerance
# 4e-4, at most 20 steps and per-step residual printing.
t = ttax.random.tensor(jax.random.PRNGKey(42), (4, 4))
v, rs = riemanGD(eye, t, tol=4E-4, max_iter=20, debug=True)

№	residual
1	13.222766876220703
2	5.255079746246338
3	2.1720805168151855
4	0.9408542513847351
5	0.4378170073032379
6	0.22417989373207092
7	0.12597201764583588
8	0.0755133107304573
9	0.046871837228536606
10	0.02956838347017765
11	0.018783463165163994
12	0.011966923251748085
13	0.007633518893271685
14	0.004871348850429058
15	0.003109496086835861
16	0.0019851604010909796
17	0.0012669748393818736
18	0.0008088692557066679
19	0.000516400090418756
20	0.0003302779223304242


In [138]:
# Value of <v, eye v> at the returned iterate, followed by the iterate itself.
print(make_rayleigh(eye)(v), v, sep="\n")

1.0583811e-07
TT(tt_cores=[DeviceArray([[[ 0.2738081 , -1.8865969 , -0.15656863,  0.02251866],
              [ 3.6740274 ,  0.37656114,  0.06663801,  0.00401304],
              [ 0.38965422, -1.8502033 ,  0.21931557, -0.02004351],
              [-0.70135784,  0.20815684,  0.40980142,  0.0186777 ]]],            dtype=float32), DeviceArray([[[ 0.6014823 ],
              [-0.08720207],
              [-0.7904917 ],
              [ 0.07575009]],

             [[ 0.04094949],
              [ 0.78495234],
              [-0.11366284],
              [-0.60766244]],

             [[ 0.6275646 ],
              [ 0.39845207],
              [ 0.47835764],
              [ 0.4675172 ]],

             [[ 0.49265072],
              [-0.46634927],
              [ 0.36520922],
              [-0.637523  ]]], dtype=float32)])
