In [None]:
!pip install ttax==0.0.2
!pip install ttpy

In [2]:
import jax
import tt
import jax.numpy as jnp
from numpy import random
import numpy as np
import matplotlib.pyplot as plt
import ttax
%matplotlib inline

Хотим считать проекцию градиента для матрицы в задаче решения линейной системы

$$A = Q_1^TD_1Q_1 + Q_2^TD_2Q_2$$


где $~rk(D_i) = 10$, $~~rk(Q_i) = 35$

Можно воспользоваться автодифом из ttax. Тогда самым содержательным будет вопрос подсчета $$(Ax, x)$$ из функционала энергии.



Попробуем.

$$(Ax, x) = (D_1Q_1x, Q_1x) + (D_2Q_2x, Q_2x)$$

Сначала умножим $~y = Qx~$, затем $~z = Dy~$, и в конце $~inner(z, y)$ 

После того, как ко мне вернется память о том, как умножать матрицы в einsum, скомпоную эти шаги в один.

In [102]:
seed = 42
jaxkey = jax.random.PRNGKey(seed)

In [103]:
def generate_example(shape, rkq = 6, rkd = 2, rkx = 8):
  """
  Returns:
      matrices Q, D with ranks of rkq, rkd respectively 
    in a TT representation each
      vector X with rank rkx in a TT representation
  """
  assert len(shape) == 4
  
  Xshape = (shape[0], shape[2])

  # Q, D - tensor operators from R^xshape to R^xshape
  # feelsgood xd 

  Q = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkq)
  D = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkd)
  X = ttax.random.tensor(jaxkey, shape=Xshape, tt_rank=rkx)
  
  return Q, D, X

In [359]:
def ttax_make_operator(A):
  """
    converts a TT representation of A with 2n 3d factors into 
    a n-factor 4d representation convenient for TT-TT multiplication
  """
  return [
          jnp.einsum('abi,icd->abcd', A.tt_cores[i], A.tt_cores[i + 1],
                        optimize=True, precision=jax.lax.Precision.HIGHEST)
          for i in range(0, len(A.tt_cores), 2)
  ]

def ttax_matmul(operator_cores, vector_cores, ttax_format=True):
  """
    transforms a vector represented by vector_cores under an operator
    formed by operator_cores 

    input: 
          . 4d cores operator_cores
          . 3d cores vector_cores
    
    outout:
          . 3d cores 
  """
  tmp = [
      jnp.einsum("abic,eig->aebcg", operator_cores[i], vector_cores.tt_cores[i],
                 optimize=True, precision=jax.lax.Precision.HIGHEST)
      for i in range(len(vector_cores.tt_cores))
  ]

  if ttax_format:
    tmp = ttax.base_class.TT([
       core.reshape((
            operator_cores[i].shape[0] * vector_cores.tt_ranks[i], operator_cores[i].shape[1],
            operator_cores[i].shape[3] * vector_cores.tt_ranks[i + 1]), order="F",
      )   
      for i, core in enumerate(tmp)
    ])

  return tmp

def ttax_transpose_operator(operator_cores):
    cores = [jnp.einsum('aijb->ajib', core) for core in operator_cores]
    return cores

Тензор-тензор умножения и скалярное умножение последовательно

In [256]:
@jax.jit
def objective_1(Q, D, X):
  """
  objective function:
        (DQx, Qx)
  """

  Qx = ttax.orthogonalize(ttax_matmul(Q, X))

  DQx = ttax_matmul(D, Qx)

  return ttax.flat_inner(Qx, ttax.orthogonalize(DQx))

попробуем по другому



In [360]:
@jax.jit
def objective_2(Q, D, X):
  """
  objective function:
        (DQx, Qx)
  """

  Qx = ttax.orthogonalize(ttax_matmul(Q, X))

  return jnp.array([
    jnp.einsum( 
        "abic,eig,xiy->",
         D[i], core, core,
         optimize='greedy',
         precision=jax.lax.Precision.HIGHEST
    )
    for i, core in enumerate(Qx.tt_cores)
  ]).prod()

In [340]:
@jax.jit
def objective_3(Q, D, X):
  """
  objective function:
        (DQx, Qx)
  """

  Qx = ttax.orthogonalize(ttax_matmul(Q, X))

  return jnp.array([
    jnp.einsum( 
        "abic,eig,eix->",
         D[i], core, core,
         optimize='greedy'
    )
    for i, core in enumerate(Qx.tt_cores)
  ]).prod()

In [341]:
n = 4
shape = (n, n, n, n)

rkq, rkd, rkx = 6, 2, 3
Q, D, X = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)

In [342]:
# пара тестов на корректность примеров
assert np.linalg.matrix_rank( ttax.full(Q).reshape((n**2, n**2)) ) == rkq
assert np.linalg.matrix_rank( ttax.full(D).reshape((n**2, n**2)) ) == rkd
assert np.linalg.matrix_rank( ttax.full(X).reshape((n, n)) )       == rkx

In [343]:
Qo = ttax_make_operator(Q)
Do = ttax_make_operator(D)

In [344]:
assert [core.shape for core in Qo] == [(1, n, n, rkq), (rkq, n, n, 1)]
assert [core.shape for core in Do] == [(1, n, n, rkd), (rkd, n, n, 1)]
assert [core.shape for core in X.tt_cores] == [(1, n, rkx), (rkx, n, 1)]

In [361]:
Qx_5d = ttax_matmul(Qo, X, ttax_format=False)
Qx = ttax_matmul(Qo, X)
DQx_5d = ttax_matmul(Do, Qx, ttax_format=False)
DQx = ttax_matmul(Do, Qx)

In [346]:
[core.shape for core in Do]

[(1, 4, 4, 2), (2, 4, 4, 1)]

In [347]:
[core.shape for core in X.tt_cores]

[(1, 4, 3), (3, 4, 1)]

In [348]:
 [core.shape for core in Qx_5d], [core.shape for core in Qx.tt_cores]

([(1, 1, 4, 6, 3), (6, 3, 4, 1, 1)], [(1, 4, 18), (18, 4, 1)])

In [349]:
 [core.shape for core in DQx_5d], [core.shape for core in DQx.tt_cores]

([(1, 1, 4, 2, 18), (2, 18, 4, 1, 1)], [(1, 4, 36), (36, 4, 1)])

In [350]:
ttax.full(Qx).shape

(4, 4)

In [351]:
ttax.full(DQx).shape

(4, 4)

### проверим на двух матрицах

матрица $~2^{12} \times 2^{12}$

In [352]:
n = 64
shape = (n, n, n, n)

rkq, rkd, rkx = 35, 10, 40
Q1, D1, X1 = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)
# преобразовывать такое в матрицу уже не получится

матрица $~2^{20} \times 2^{20}$

In [353]:
n = 2**10
shape = (n, n, n, n)

Q2, D2, X2 = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)

алгосы

In [362]:
Qo1 = ttax_make_operator(Q1)
Do1 = ttax_make_operator(D1)

Qo2 = ttax_make_operator(Q2)
Do2 = ttax_make_operator(D2)

In [363]:
objective_2(Qo, Do, X)

DeviceArray(85903.02, dtype=float32)

In [364]:
objective_1(Qo, Do, X)

DeviceArray(80900.11, dtype=float32)

In [365]:
objective_3(Qo, Do, X)

DeviceArray(172864.81, dtype=float32)

In [358]:
xs = objective_2(Qo2, Do2, X2)

In [80]:
objective_1(Qo2, Do2, X2)

DeviceArray(1.1699865e+19, dtype=float32)

In [258]:
%timeit c = objective_1(Qo1, Do1, X1)
%timeit c = objective_1(Qo2, Do2, X2)

The slowest run took 37.99 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 7.96 ms per loop
1 loop, best of 5: 2.5 s per loop


In [259]:
%timeit c = objective_2(Qo1, Do1, X1)
%timeit c = objective_2(Qo2, Do2, X2)

The slowest run took 61.37 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 5.14 ms per loop
1 loop, best of 5: 824 ms per loop


60ms на матрицу n = 4000 и 4s на n = 2^20

In [None]:

# end of file
#####################################################


In [None]:
A = ttax.random.tensor(jaxkey, shape=(2, 2, 2, 2), tt_rank=2)
Af = ttax.full(A)
B = ttax.random.tensor(jaxkey, shape=(2, 2, 2, 2), tt_rank=2)
Bf = ttax.full(B)

In [None]:
np.inner(Af, Bf).sum()

122.364174

In [None]:
ttax.flat_inner(ttax.orthogonalize(A), ttax.orthogonalize(B))

DeviceArray(97.19842, dtype=float32)

In [122]:
A = np.array([
              [1, 2],
              [3, 4]
])
x = np.array([-1, 2])
B = np.array([
              [2, 4],
              [10, 1]
])

In [129]:
np.einsum("ij,j,i->", A, x, x)

7

In [35]:
A = random.rand(2, 2, 2)
B = random.rand(2, 2, 2)

In [260]:
jnp.inner(A @ x, x)

DeviceArray(7, dtype=int32)

In [261]:
# hmmmm
jnp.einsum("ij,ij->", A, B)

DeviceArray(44, dtype=int32)