In [None]:
!pip install ttax==0.0.2
!pip install ttpy

In [2]:
import jax
import tt
import jax.numpy as jnp
from numpy import random
import numpy as np
import matplotlib.pyplot as plt
import ttax
%matplotlib inline

Хотим считать проекцию градиента для матрицы в задаче решения линейной системы

$$A = Q_1^TD_1Q_1 + Q_2^TD_2Q_2$$


где $~rk(D_i) = 10$, $~~rk(Q_i) = 35$

Можно воспользоваться автодифом из ttax. Тогда самым содержательным будет вопрос подсчета $$(Ax, x)$$ из функционала энергии.



Попробуем.
По определению билинейной формы перенесем транспонирование
$$(Ax, x) = (D_1Q_1x, Q_1x) + (D_2Q_2x, Q_2x)$$

$(D_1Q_1x, Q_1x)$ 

In [3]:
seed = 42
jaxkey = jax.random.PRNGKey(seed)



In [4]:
def generate_example(shape, rkq = 6, rkd = 2, rkx = 8):
  """
  Returns:
      matrices Q, D with ranks of rkq, rkd respectively 
    in a TT representation each
      vector X with rank rkx in a TT representation
  """
  assert len(shape) == 4
  
  Xshape = (shape[0], shape[2])

  # Q, D - tensor operators from R^xshape to R^xshape
  # feelsgood xd 

  Q = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkq)
  D = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkd)
  X = ttax.random.tensor(jaxkey, shape=Xshape, tt_rank=rkx)
  
  return Q, D, X

In [5]:
def ttax_make_operator(A):
  """
    converts a TT representation of A with 2n 3d factors into 
    a n-factor 4d representation convenient for TT-TT multiplication
  """
  return [
          jnp.einsum('abi,icd->abcd', A.tt_cores[i], A.tt_cores[i + 1],
                        optimize=True)
          for i in range(0, len(A.tt_cores), 2)
  ]

def ttax_matmul(operator_cores, vector_cores):
  """
    transforms a vector represented by vector_cores under an operator
    formed by operator_cores 

    input: 
          . 4d cores operator_cores
          . 3d cores vector_cores
    
    outout:
          . 3d cores 
  """
  return ttax.base_class.TT([
      jnp.einsum("abic,eig->aebcg", operator_cores[i], vector_cores.tt_cores[i], optimize=True).reshape((
            operator_cores[i].shape[0] * vector_cores.tt_ranks[i], operator_cores[i].shape[1],
            operator_cores[i].shape[3] * vector_cores.tt_ranks[i + 1]), order="F",
      )
      for i in range(len(vector_cores.tt_cores))
  ])

def ttax_transpose_operator(operator_cores):
    cores = [jnp.einsum('aijb->ajib', core) for core in operator_cores]
    return cores

Тензор-тензор умножения и скалярное умножение последовательно

In [6]:
@jax.jit
def objective_1(Q, D, X):
  """
  objective function:
        (DQx, Qx)
  """

  Qx = ttax_matmul(Q, X)

  DQx = ttax_matmul(D, Qx)

  return ttax.flat_inner(ttax.orthogonalize(Qx), ttax.orthogonalize(DQx))

попробуем по другому (что-то не заработало)



In [7]:
@jax.jit
def objective_2(DQ, Q, X):
  """
  objective function:
        (DQx, Qx)
  """

  Qx = ttax_matmul(Q, X)

  # (Dy, y)

  ### hmmm , как бы написать einsum
  return jnp.einsum("ijk,xyz,abc->", DQ, Qx, X, optimize=True)

In [8]:
@jax.jit
def objective_3(QDQ, X):
  """
  objective function:
        (DQx, Qx)
  """

  return jnp.flat_inner(QDQ, ttax.orthogonalize(X))

In [9]:
n = 4
shape = (n, n, n, n)

rkq, rkd, rkx = 6, 2, 2
Q, D, X = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)

In [10]:
# пара тестов на корректность примеров
assert np.linalg.matrix_rank( ttax.full(Q).reshape((n**2, n**2)) ) == rkq
assert np.linalg.matrix_rank( ttax.full(D).reshape((n**2, n**2)) ) == rkd
assert np.linalg.matrix_rank( ttax.full(X).reshape((n, n)) )       == rkx

### проверим на двух матрицах

матрица $~2^{12} \times 2^{12}$

In [11]:
n = 64
shape = (n, n, n, n)

rkq, rkd, rkx = 35, 10, 40
Q1, D1, X1 = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)
# преобразовывать такое в матрицу уже не получится

матрица $~2^{20} \times 2^{20}$

In [12]:
n = 2**10
shape = (n, n, n, n)

Q2, D2, X2 = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)

алгосы

In [13]:
Qo1 = ttax_make_operator(Q1)
Do1 = ttax_make_operator(D1)

Qo2 = ttax_make_operator(Q2)
Do2 = ttax_make_operator(D2)

In [14]:
%timeit c = objective_1(Qo1, Do1, X1)[...]
%timeit c = objective_1(Qo2, Do2, X2)[...]

The slowest run took 6.63 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 59.6 ms per loop
1 loop, best of 5: 3.67 s per loop


60ms на матрицу n = 4000 и 4s на n = 2^20

In [22]:

# end of file
#####################################################


In [16]:
A = ttax.random.tensor(jaxkey, shape=(2, 2, 2, 2), tt_rank=2)
Af = ttax.full(A)
B = ttax.random.tensor(jaxkey, shape=(2, 2, 2, 2), tt_rank=2)
Bf = ttax.full(B)

In [17]:
np.inner(Af, Bf).sum()

122.364174

In [18]:
ttax.flat_inner(ttax.orthogonalize(A), ttax.orthogonalize(B))

DeviceArray(97.19842, dtype=float32)

In [19]:
A = np.array([
              [1, 1],
              [2, 3]
])
B = np.array([
              [2, 4],
              [10, 1]
])

In [20]:
jnp.inner(A, B)

DeviceArray([[ 6, 11],
             [16, 23]], dtype=int32)

In [21]:
# hmmmm
jnp.einsum("ij,ij->", A, B)

DeviceArray(29, dtype=int32)