In [76]:
import jax
import ttax
import jax.numpy as jnp

In [78]:
dot_prod = lambda x: ttax.ops.flat_inner(x, x)
full_dot_prod = lambda x: jnp.tensordot(x, x, len(x.shape))
norm_TT = lambda x: jnp.sqrt(dot_prod(x))
norm_full = lambda x: jnp.sqrt(full_dot_prod(x))

actual_grad = lambda x: 2 * x
ttax_grad = ttax.autodiff.grad(dot_prod)
ttax_full_grad = jax.grad(full_dot_prod)

In [88]:
failed_tensor = None
def test(tensor):
    diff = ttax_grad(tensor) + (-1) * actual_grad(tensor)
    print("dot", dot_prod(diff))
    if dot_prod(diff) < 0:
        return test(ttax.orthogonalize(tensor))
    return norm_TT(diff) / norm_TT(tensor)

def test_full(tensor):
    diff = ttax_full_grad(tensor) + (-1) * actual_grad(tensor)
    return norm_full(diff) / norm_full(tensor)

def test_round(tensor):
    diff = ttax.round(ttax.round(ttax_grad(tensor)) +
                      ttax.round((-1) * ttax.round(actual_grad(tensor))))
    if dot_prod(diff) < 0:
        return test_round(ttax.orthogonalize(tensor))
    return norm_TT(diff) / norm_TT(tensor)

In [90]:
from numpy import random

n = 5
for _ in range(n):
    tt_rank = random.randint(3, 7)
    modes = random.randint(1, 30, tt_rank - 1)
    TT_tensor = ttax.random.tensor(jax.random.PRNGKey(42),
                          modes, tt_rank=tt_rank, dtype=jnp.float32)
    print("TT_rank is", tt_rank)
    print("Difference between gradients")
    print("\t\t\tTensor in TT format:", test(TT_tensor))
    # print("\t\t\tTensor in dense format:", test_full(ttax.full(TT_tensor)))
    print("\t\t\tTensor with rounding:", test_round(TT_tensor))
    # print("TT_rank is", tt_rank, " test with rounding result", test_round(ttax.full(TT_tensor)))

TT_rank is 5
Difference between gradients
dot -0.29713368
dot 0.0
			Tensor in TT format: 0.0
			Tensor with rounding: 2.5066938e-06
TT_rank is 6
Difference between gradients
dot 191.76631
			Tensor in TT format: 0.0006665873
			Tensor with rounding: 3.4745083e-06
TT_rank is 3
Difference between gradients
dot 0.00069228926
			Tensor in TT format: 0.0010049982
			Tensor with rounding: 1.1359762e-06
TT_rank is 3
Difference between gradients
dot -4.9602706e-05
dot -0.0007324219
dot -0.00044976198
dot 0.00024414062
			Tensor in TT format: 0.00045103356
			Tensor with rounding: 9.137934e-07
TT_rank is 5
Difference between gradients
dot -3.6882758
dot -1.0884006
dot -2.4634006
dot -1.5
dot -1.625
dot -2.5518012
dot -0.125
dot 1.0
			Tensor in TT format: 0.00051226997
			Tensor with rounding: 2.5806846e-06


In [92]:
n = 10
for _ in range(n):
    tt_rank = random.randint(3, 6)
    modes = random.randint(1, 30, tt_rank - 1)
    TT_tensor1 = ttax.random.tensor(jax.random.PRNGKey(42),
                          modes, tt_rank=tt_rank, dtype=jnp.float32)
    TT_tensor2 = ttax.random.tensor(jax.random.PRNGKey(42),
                          modes, tt_rank=tt_rank, dtype=jnp.float32)
    TT_tensor = TT_tensor2 + (-1) * TT_tensor1
    print("inner product in TTAX:", dot_prod(TT_tensor),
          "\tinner product in JAX:", full_dot_prod(ttax.full(TT_tensor)))
    if failed_tensor is None and dot_prod(TT_tensor) < 0:
        failed_tensor = TT_tensor

inner product in TTAX: 3.3848824e-05 	inner product in JAX: 1.4503537e-12
inner product in TTAX: 4.2504646e-05 	inner product in JAX: 4.2257725e-12
inner product in TTAX: -2.6522805e-06 	inner product in JAX: 7.4804746e-14
inner product in TTAX: -2.6527265e-05 	inner product in JAX: 2.299886e-12
inner product in TTAX: -2.355309e-05 	inner product in JAX: 3.644405e-12
inner product in TTAX: 0.00071999175 	inner product in JAX: 2.6259729e-11
inner product in TTAX: -2.0049207e-05 	inner product in JAX: 2.970988e-12
inner product in TTAX: 0.08153069 	inner product in JAX: 1.0617368e-08
inner product in TTAX: -0.13295794 	inner product in JAX: 8.215069e-08
inner product in TTAX: -3.791185e-06 	inner product in JAX: 6.906975e-13


In [97]:
def naive_tensor_inner_product(tensor : ttax.base_class.TT):
    res = jnp.einsum('rni,rnj->ij', tensor.tt_cores[0], tensor.tt_cores[0])
    for i in range(1, len(tensor.tt_cores)):
        cumul = jnp.einsum('anb,cnd->abcd', tensor.tt_cores[i], tensor.tt_cores[i])
        res = jnp.einsum('ij,iajb->ab', res, cumul)
    return jnp.squeeze(res)

def t3f_inner_product(tt_a: ttax.base_class.TT, tt_b : ttax.base_class.TT):
  axes_str = 'i'
  init_einsum_str = '{1}a{0}b,{2}c{0}d->{3}bd'.format(axes_str, '',
                                                      '',
                                                      '')
  a_core = tt_a.tt_cores[0]
  b_core = tt_b.tt_cores[0]
  # Simplest example of this operation:
  # if both arguments are TT-tensors, then it is
  # res = tf.einsum('aib,cid->bd', a_core, b_core)
  res = jnp.einsum(init_einsum_str, a_core, b_core)

  einsum_str = '{3}ac,{1}a{0}b,{2}c{0}d->{3}bd'.format(axes_str, '',
                                                       '',
                                                       '')
  for core_idx in range(1, tt_a.ndim):
    a_core = tt_a.tt_cores[core_idx]
    b_core = tt_b.tt_cores[core_idx]
    # Simplest example of this operation:
    # if both arguments are TT-tensors, then it is
    # res = tf.einsum('ac,aib,cid->bd', res, a_core, b_core)
    res = jnp.einsum(einsum_str, res, a_core, b_core)
  return jnp.squeeze(res)

new_failed_tensor = failed_tensor

print("TTAX inner product\t\t\t", ttax.ops.flat_inner(new_failed_tensor, new_failed_tensor))
print("JAX inner product\t\t\t", jnp.tensordot(ttax.full(new_failed_tensor), ttax.full(new_failed_tensor), len(new_failed_tensor.tt_cores)))
print("T3F inner product\t\t\t", t3f_inner_product(new_failed_tensor, new_failed_tensor))
print("All tensor elements are squeezed between %f and %f"
      %(jnp.amin(ttax.full(new_failed_tensor)), jnp.amax(ttax.full(new_failed_tensor))))

TTAX inner product			 -2.6522805e-06
JAX inner product			 7.4804746e-14
T3F inner product			 -2.6522805e-06
All tensor elements are squeezed between -0.000000 and 0.000000


<div style="
        background-color: #FFF3C2;
        padding: 6px;
        border-radius: 7px;
    ">
To sum up: there is an awful error in inner product of T3F and TTAX without orthogonalization.
JAX seem to do orthogonalization by itself.
</div>
