In [None]:
!pip install ttax==0.0.2
!pip install ttpy

In [2]:
import jax
import tt
import jax.numpy as jnp
from numpy import random
import numpy as np
import matplotlib.pyplot as plt
import ttax
%matplotlib inline

In [118]:
ort = ttax.orthogonalize
rkey = lambda x: jax.random.PRNGKey(x)

In [119]:
jaxkey = rkey(42)

Сначала напишем кучу вспомогательных функций, затем будем генерировать примеры нужного вида, и запускать от них GD

In [187]:
def ttax_make_operator(A):
  """
    converts a TT representation of A with 2n 3d factors into 
    a n-factor 4d representation convenient for TT-TT multiplication
  """
  return [
          jnp.einsum('abi,icd->abcd', A.tt_cores[i], A.tt_cores[i + 1],
                        precision=jax.lax.Precision.HIGHEST)
          for i in range(0, len(A.tt_cores), 2)
  ]

def ttax_matmul(operator_cores, vector_cores):
  """
  applies operator represented by its 4d tensor cores operator_cores 
  to vector given its 3d TT cores vector_cores
  """

  return ttax.base_class.TT([
      jnp.einsum('abic,eig->aebcg', operator_cores[i], vector_cores.tt_cores[i], precision=jax.lax.Precision.HIGHEST).reshape((
            operator_cores[i].shape[0] * vector_cores.tt_ranks[i], operator_cores[i].shape[1],
            operator_cores[i].shape[3] * vector_cores.tt_ranks[i + 1]), order="F",
      )
      for i in range(len(vector_cores.tt_cores))
  ])

In [188]:
@jax.jit
def simultaneous_convolution(Q, D, X):
  """
  "efficiently" evaluates the objective function:
        (DQx, Qx)
  """

  # ttax.ort would be of good here, but unimplemented in ttax
  Qx = ttax_matmul(Q, X)

  return jnp.array([
    jnp.einsum( 
        "abic,eig,xiy->",
         D[i], core, core,
         optimize='greedy',
         precision=jax.lax.Precision.HIGHEST
    )
    for i, core in enumerate(Qx.tt_cores)
  ]).prod()

def simultaneous_apply(example, X):
  """
  evaluates the objective function:
        (D1 Q1 X, Q1 X) + (D2 Q2 X, Q2 X) 
  with matrices provided by the given example
  """

  Q1, Q2, D1, D2 = example.Q1, example.Q2, example.D1, example.D2
  return simultaneous_convolution(Q1, D1, X) + simultaneous_convolution(Q2, D2, X)

In [189]:
def ttax_make_objective(example):
  """
  returns a callable functional of x,
  where A and b are provided by the example
    0.5 * x*A*x - x*b 
  """
  B = example.B
  return lambda x: 0.5 * simultaneous_apply(example, x) - ttax.flat_inner(x, B)

def ttax_norm(x):
  """
   returns a norm of a tensor x 
    <x, x>^0.5 
  """
  return jnp.sqrt(ttax.flat_inner(x, x))

def ttax_make_residual(example):
  """
   returns a callable functional of x
   where A is a generic callable operator on x' space
   and is provided by the given example
    |A*x - b| 
  """
  B = example.B
  return lambda x: ttax_norm(ort(ttax.round(example.evaluate(x) + (-1)*B, example.rkx)))

In [190]:
def ttax_retract(TT, v, r):
  """
  returns retraction R(T, v) = T + v
  while rounding result to the closest tensor of rank r
  """
  return ort(ttax.round(TT + v, max_tt_rank=r))

def ttax_armijo_backtracking(init, grad, mul, beta, func, x, rk):
    """
    Standard Armijo line search algorithm
    """
    alpha = init
    while func(x) < func(ttax_retract(x, -alpha*grad, rk)) + mul*alpha*ttax_norm(grad)**2:
        alpha *= beta
    return alpha

def transpose_operator(operator_cores):
    """
    Transposes the given operator
    """
    cores = [jnp.einsum('aijb->ajib', core) for core in operator_cores]
    return cores

Генерация примера из условия

In [191]:
class Example:
  def __init__(self, Q1, D1, Q2, D2, X, x0, B, evaluate, rkx):
    self.Q1 = Q1
    self.D1 = D1
    self.Q2 = Q2
    self.D2 = D2
    self.X = X
    self.x0 = x0
    self.B = B
    self.evaluate = evaluate
    self.rkx = rkx

def generate_example(shape, rkq = 6, rkd = 2, rkx = 8):
  """
  Returns:
      matrices Q, D with ranks of rkq, rkd respectively 
    in a TT representation each
      vector X with rank rkx in a TT representation
  """
  assert len(shape) == 4
  
  Xshape = (shape[0], shape[2])

  # Q, D - tensor operators from R^xshape to R^xshape
  # feelsgood xd 

  Q1 = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkq)
  D1 = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkd)

  Q2 = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkq)
  D2 = ttax.random.tensor(jaxkey, shape=shape, tt_rank=rkd)

  x0 = ttax.random.tensor(jax.random.PRNGKey(43), shape=Xshape, tt_rank=rkx)
  X = ttax.random.tensor(jaxkey, shape=Xshape, tt_rank=rkx)

  Q1o = ttax_make_operator(Q1)
  Q2o = ttax_make_operator(Q2)
  D1o = ttax_make_operator(D1)
  D2o = ttax_make_operator(D2)

  Q1oT = transpose_operator(Q1o)
  Q2oT = transpose_operator(Q2o)

  # looks good
  evaluate = lambda x: ort(ttax.round(ttax_matmul(Q1oT, ttax_matmul(D1o, ttax_matmul(Q1o, x))) \
                        + ttax_matmul(Q2oT, ttax_matmul(D2o, ttax_matmul(Q2o, x))), max_tt_rank=rkx))
  
  B = evaluate(X)
  
  return Example(Q1o, D1o, Q2o, D2o, X, x0, B, evaluate, rkx)

In [193]:
def ttax_solve(objective, residual, x0, rk, iters=500, tol=1e-2, debug=False):

  riemann_grad_at = ttax.autodiff.grad(objective)

  x = x0

  residuals = []

  i = 0

  while residual(x) > tol and i < iters:

    riemann_grad = ort(riemann_grad_at(x))
    alpha = ttax_armijo_backtracking(1, riemann_grad, 1e-4, 0.8, objective, x, rk)
    
    x = ttax_retract(x, (-alpha)*riemann_grad, rk)

    residuals.append( residual(x) )   
    if debug:
      print(f"{i} th iteration: {residuals[-1]}")

    i += 1

  return x, residuals

In [199]:
n = 4
shape = (n, n, n, n)

rkq, rkd, rkx = 6, 2, 3
example = generate_example(shape, rkq = rkq, rkd = rkd, rkx = rkx)

In [200]:
objective = ttax_make_objective(example)
residual = ttax_make_residual(example)

sanity checks

In [201]:
assert residual(example.X) < 1e-1
assert ttax_norm(example.evaluate(example.X) + (-1)*example.B) < 1e-1

In [202]:
y, res = ttax_solve(objective, residual, x0=example.x0, rk=rkx, debug=True, tol=1e-1)

0 th iteration: 610272083968.0
1 th iteration: nan


In [203]:
res

[DeviceArray(6.102721e+11, dtype=float32), DeviceArray(nan, dtype=float32)]