Load dependencies:

In [153]:
import numpy as np
import tensorflow as tf

Implementation of Jennrich's algorithm (numerically stable version):

In [34]:
def Jennrich(M1: tf.Tensor, M2: tf.Tensor, r: int):
    """Simultaneous-diagonalisation step of Jennrich's algorithm.

    Both matrices are first compressed with the leading ``r`` left singular
    vectors of ``M1`` so that only a small projected matrix has to be
    inverted and eigen-decomposed.

    Args:
        M1, M2: d-by-d matrices, expected to have the same rank ``r``.
        r: number of leading left singular vectors of ``M1`` to keep.

    Returns:
        The real part of the eigenvectors of
        ``(W^T M1 W) (W^T M2 W)^{-1}``, mapped back through ``W``, where the
        columns of ``W`` are the kept singular vectors.
    """
    # tf.linalg.svd returns (singular values, left vectors, right vectors).
    _, left_vectors, _ = tf.linalg.svd(M1)
    W = left_vectors[:, :r]
    W_t = tf.transpose(W)

    def project(matrix):
        # Compress a matrix onto the subspace spanned by the columns of W.
        return tf.linalg.matmul(tf.linalg.matmul(W_t, matrix), W)

    ratio = tf.linalg.matmul(project(M1), tf.linalg.inv(project(M2)))
    # tf.linalg.eig yields complex outputs; only the real part is propagated.
    _, eigvecs = tf.linalg.eig(ratio)
    return tf.linalg.matmul(W, tf.math.real(eigvecs))

In [35]:
def matSys(A_r, Tx, x):
    """Recover the component weights from a contracted tensor slice.

    Solves ``A_r * diag(Xi) * diag(A_r.T * x) * A_r.T = Tx`` for ``Xi`` by
    applying the pseudo-inverse of ``A_r`` on both sides and reading off the
    diagonal.

    Args:
        A_r: a d-by-r matrix whose columns are the component directions.
        Tx: a d-by-d matrix.
        x: the vector paired with ``Tx``; the callers in this notebook pass a
            1-D tensor of length d.

    Returns:
        Xi: one entry per column of ``A_r`` (length r for a 1-D ``x``).
    """
    A_pinv = tf.linalg.pinv(A_r)
    # diag(pinv(A_r) * Tx * pinv(A_r)^T) equals Xi * (A_r^T x) entrywise.
    sandwiched = tf.linalg.matmul(tf.linalg.matmul(A_pinv, Tx), tf.transpose(A_pinv))
    numerators = tf.linalg.diag_part(sandwiched)
    denominators = tf.tensordot(tf.transpose(A_r), x, axes=[[1], [0]])
    return tf.divide(numerators, denominators)

In [36]:
def decompose(T: tf.Tensor, x: tf.Tensor, y: tf.Tensor, rank):
    """Run Jennrich's algorithm on ``T`` followed by weight recovery.

    Args:
        T: the tensor to decompose.
        x, y: vectors contracted against the first mode of ``T`` to obtain
            the two matrices handed to ``Jennrich``.
        rank: number of components to extract.

    Returns:
        A: matrix whose columns are the directions of the tensor components.
        Xi: vector holding the weight ("length") of each component.
    """
    first_mode = [[0], [0]]
    slice_x = tf.tensordot(T, x, axes=first_mode)
    slice_y = tf.tensordot(T, y, axes=first_mode)
    directions = Jennrich(slice_x, slice_y, r=rank)
    lengths = matSys(directions, slice_x, x)
    return directions, lengths

Helper function for building the symmetric tensor:

In [37]:
def constructSymTensor(components: tf.Tensor, weights = None):
    """Build a 3rd-order symmetric tensor ``sum_i w_i * a_i (x) a_i (x) a_i``.

    Args:
        components: an n-by-m matrix whose m columns ``a_i`` are the
            components, each of dimension n.
        weights: optional multiplicative constants ``w_i``, given either as an
            m-by-1 matrix or as a 1-D vector of length m. Defaults to all ones.

    Returns:
        An n-by-n-by-n tensor.
    """
    dim = components.shape[0]
    rank = components.shape[1]
    # Identity test rather than ``==``: equality is overloaded on tensors and
    # arrays, so ``weights == None`` may evaluate elementwise instead of
    # answering "was the argument omitted?".
    if weights is None:
        weights = tf.ones(shape=[rank, 1])
    elif weights.ndim == 1:
        # Promote a flat vector to a column so that weights[i, 0] works below.
        weights = tf.expand_dims(weights, axis=1)
    T = tf.zeros(shape=[dim, dim, dim])
    for i in range(rank):
        ai = tf.expand_dims(components[:, i], axis=1)
        # (n, 1) x (n, 1) contracted over the singleton axis -> w_i * a_i a_i^T.
        weighted_outer = tf.tensordot(ai, weights[i, 0] * ai, axes=[[1], [1]])
        # The einsum sums over the singleton axis 'l' of ai, which leaves
        # weighted_outer[i, j] * a_i[k].
        T = tf.add(T, tf.einsum('ij,kl->ijk', weighted_outer, ai))
    return T

Implementation of the accelerated Algorithm 3.3 in TensorFlow:

In [82]:
class OvercompleteTensorDecomposition(tf.Module):
    """Overcomplete tensor decomposition "layer".

    Holds two trainable unit-norm vectors (``x_magic``, ``y_magic``) used for
    the first decomposition and two fixed random vectors (``x_2nd``,
    ``y_2nd``) used to decompose the residual.
    """

    def __init__(self, dimension, name = None):
        super().__init__(name=name)
        draw = tf.random_normal_initializer()
        x0 = draw(shape=[dimension], dtype=tf.float32)
        y0 = draw(shape=[dimension], dtype=tf.float32)
        # The trainable vectors start on the unit sphere.
        self.x_magic = tf.Variable(initial_value=tf.divide(x0, tf.norm(x0)), name="x_magic")
        self.y_magic = tf.Variable(initial_value=tf.divide(y0, tf.norm(y0)), name="y_magic")
        # Constants, so they are not picked up as trainable variables.
        self.x_2nd = tf.constant(value=draw(shape=[dimension], dtype=tf.float32), name="x_2nd")
        self.y_2nd = tf.constant(value=draw(shape=[dimension], dtype=tf.float32), name="y_2nd")

    def __call__(self, T: tf.Tensor, tensor_rank, overcomplete_param):
        """Decompose ``T`` in two stages and return its reconstruction.

        Args:
            T: a symmetric d-by-d-by-d tensor of rank ``tensor_rank``.
            tensor_rank: total number of components of ``T``.
            overcomplete_param: number of components left for the second
                (residual) decomposition.

        Returns:
            The sum of the tensors rebuilt from the two decompositions.
        """
        # Stage 1: recover the first tensor_rank - overcomplete_param components.
        first_rank = tensor_rank - overcomplete_param
        A_r, Xi_r = decompose(T, self.x_magic, self.y_magic, rank=first_rank)

        # Deflation: subtract what stage 1 explains.
        T_first = constructSymTensor(components=A_r, weights=Xi_r)
        residual = T - T_first

        # Stage 2: decompose the residual with the fixed random vectors.
        A_k, Xi_k = decompose(residual, self.x_2nd, self.y_2nd, rank=overcomplete_param)

        # Reconstruction
        T_second = constructSymTensor(components=A_k, weights=Xi_k)
        return T_first + T_second

In [39]:
class TensorDecomposition(tf.Module):
    """Non-degenerate tensor decomposition "layer".

    Holds the two trainable vectors (``x_magic``, ``y_magic``) that are
    contracted against the input tensor inside ``decompose``.
    """

    def __init__(self, dimension, name = None):
        # Initialise the tf.Module base so that ``name`` is honoured; without
        # this call the ``name`` argument was accepted but silently dropped.
        super().__init__(name=name)
        init = tf.random_normal_initializer()
        self.x_magic = tf.Variable(initial_value=init(shape=[dimension], dtype=tf.float32), name="x_magic")
        self.y_magic = tf.Variable(initial_value=init(shape=[dimension], dtype=tf.float32), name="y_magic")

    def __call__(self, T, tensor_rank):
        """Decompose ``T`` and return the tensor rebuilt from the result.

        Args:
            T: a symmetric d-by-d-by-d tensor of rank ``tensor_rank``.
            tensor_rank: number of components to extract.

        Returns:
            The reconstruction of ``T`` from the recovered components.
        """
        # Decompose the tensor
        A_r, Xi_r = decompose(T, self.x_magic, self.y_magic, rank=tensor_rank)

        # Reconstruction
        T_prime = constructSymTensor(components=A_r, weights=Xi_r)
        return T_prime

In [97]:
def loss_func(T1: tf.Tensor, T2: tf.Tensor):
    """Return ``l2_loss(T1 - T2)`` scaled by the norm of the reference ``T1``."""
    residual = T1 - T2
    reference_norm = tf.norm(T1)
    return tf.nn.l2_loss(residual) / reference_norm

In [41]:
# Dimension of the non-degenerate experiment and the model under test.
d = 10
model = TensorDecomposition(dimension=d)
print(model.trainable_variables)

(<tf.Variable 'x_magic:0' shape=(10,) dtype=float32, numpy=
array([-0.07788632,  0.02765081, -0.01590446, -0.03060139, -0.05913058,
       -0.06392466, -0.02336155, -0.01771414,  0.00486755, -0.05758289],
      dtype=float32)>, <tf.Variable 'y_magic:0' shape=(10,) dtype=float32, numpy=
array([ 6.6815086e-02,  8.8878140e-02,  2.6275540e-02, -5.4046657e-02,
       -4.7551796e-02, -1.8458270e-02,  5.9509831e-03,  8.7893255e-02,
        6.2441934e-02, -2.4321165e-05], dtype=float32)>)


Define the training process

In [42]:
def train(model, T, rank_T, learning_rate = 1e-5):
    """Take one plain gradient-descent step on ``model.x_magic``/``y_magic``.

    Args:
        model: callable as ``model(T, rank_T)`` and exposing the variables
            ``x_magic`` and ``y_magic``.
        T: the target tensor.
        rank_T: rank passed through to the model.
        learning_rate: step size of the update.
    """
    magic_vectors = (model.x_magic, model.y_magic)
    with tf.GradientTape() as tape:
        current_loss = loss_func(T, model(T, rank_T))
    gradients = tape.gradient(current_loss, list(magic_vectors))
    for variable, gradient in zip(magic_vectors, gradients):
        variable.assign_sub(learning_rate * gradient)

Generate a random tensor (non-degenerate):

In [43]:
# Generate a random tensor
# Random symmetric tensor built from `rank` Gaussian components in dimension d.
rank = 5
A = tf.random.normal(shape=(d, rank))
T = constructSymTensor(components=A)
print(T.shape)

(10, 10, 10)


Test the non-degenerate tensor decomposition on this random tensor:

In [44]:
# Relative reconstruction error of the untrained model.
T_prime = model(T=T, tensor_rank=rank)
print(tf.norm(T_prime - T) / tf.norm(T))

tf.Tensor(4.241754e-05, shape=(), dtype=float32)


Now let us see if training changes anything:

In [58]:
def train_loop(model, T, rank_T, num_epochs = 10):
    """Run ``num_epochs`` gradient steps, printing the state after each one.

    Args:
        model: model accepted by ``train``.
        T: the target tensor.
        rank_T: rank passed through to the model.
        num_epochs: number of gradient steps to take.
    """
    x_history, y_history = [], []
    for epoch in range(num_epochs):
        train(model, T, rank_T)
        x_history.append(model.x_magic.numpy())
        y_history.append(model.y_magic.numpy())
        # Loss is re-evaluated after the update so the log shows the new state.
        current_loss = loss_func(T, model(T, rank_T))
        print("Epoch {}: x= {}, y= {}, loss= {}".format(epoch, x_history[-1], y_history[-1], current_loss))

In [46]:
# The keyword must match train_loop's signature (`num_epochs`); the previous
# spelling `num_epoch` raises TypeError on a fresh kernel.
train_loop(model, T, rank, num_epochs = 100)

90475 -0.03060123 -0.05913024 -0.06392422
 -0.02336191 -0.01771467  0.0048679  -0.05758258], y= [ 6.6814624e-02  8.8879801e-02  2.6275033e-02 -5.4046374e-02
 -4.7551226e-02 -1.8457511e-02  5.9503578e-03  8.7892354e-02
  6.2442522e-02 -2.3765293e-05], loss= 1.8368396013102029e-06
Epoch 39: x= [-0.0778866   0.02765182 -0.01590476 -0.03060122 -0.05913023 -0.06392421
 -0.02336193 -0.01771469  0.00486791 -0.05758256], y= [ 6.6814609e-02  8.8879868e-02  2.6275013e-02 -5.4046363e-02
 -4.7551204e-02 -1.8457482e-02  5.9503326e-03  8.7892316e-02
  6.2442545e-02 -2.3743192e-05], loss= 4.380010523163946e-06
Epoch 40: x= [-0.07788662  0.02765191 -0.01590479 -0.03060121 -0.0591302  -0.06392417
 -0.02336196 -0.01771474  0.00486794 -0.05758253], y= [ 6.6814564e-02  8.8880017e-02  2.6274966e-02 -5.4046337e-02
 -4.7551151e-02 -1.8457413e-02  5.9502753e-03  8.7892234e-02
  6.2442601e-02 -2.3692866e-05], loss= 1.6557555682084057e-06
Epoch 41: x= [-0.07788663  0.02765194 -0.0159048  -0.0306012  -0.05913018

The non-degenerate case does not need gradient descent. Now we turn to the overcomplete case:

In [134]:
# Overcomplete setting: more components (rank) than dimensions (dim).
dim = 10
rank = 11
A = tf.random.normal(shape=(dim, rank))
T = constructSymTensor(A)

In [162]:
def train_loop_overcomplete(model_overcomplete, T, rank, overcomplete_param, num_epochs = 100, learning_rate = 1e-6, diminishing_rate = False, log_every = 100):
    """Gradient-descent loop for the overcomplete decomposition model.

    Args:
        model_overcomplete: callable as
            ``model(T, tensor_rank, overcomplete_param)`` and exposing the
            variables ``x_magic`` and ``y_magic``.
        T: the target tensor.
        rank: total rank passed to the model as ``tensor_rank``.
        overcomplete_param: number of components handled by the model's
            second decomposition.
        num_epochs: number of gradient steps.
        learning_rate: base step size.
        diminishing_rate: when True, the step at (0-based) epoch ``e`` is
            ``learning_rate / (e + 1)``; otherwise it stays constant.
        log_every: print a progress line whenever ``epoch % log_every == 0``.
            Must be a positive integer.
    """
    variables = [model_overcomplete.x_magic, model_overcomplete.y_magic]
    for epoch in range(num_epochs):
        # The two schedules differ only in the step size, so compute it once
        # instead of duplicating the whole update.
        step = learning_rate / float(epoch + 1) if diminishing_rate else learning_rate

        with tf.GradientTape() as t:
            current_loss = loss_func(T, model_overcomplete(T= T, tensor_rank = rank, overcomplete_param = overcomplete_param))
        dx, dy = t.gradient(current_loss, variables)
        model_overcomplete.x_magic.assign_sub(step * dx)
        model_overcomplete.y_magic.assign_sub(step * dy)

        if epoch % log_every == 0:
            # The post-update loss and the numpy snapshots are only needed
            # for logging, so they are computed only on logging epochs.
            x_now = model_overcomplete.x_magic.numpy()
            y_now = model_overcomplete.y_magic.numpy()
            dx_now = dx.numpy()
            dy_now = dy.numpy()
            current_loss = loss_func(T, model_overcomplete(T, rank, overcomplete_param))
            print("Epoch {}: x= {}, y= {}, loss= {}, dx = {}, dy = {}, dx_norm = {}, dy_norm = {}".format(epoch + 1, x_now, y_now, current_loss, dx_now, dy_now, np.linalg.norm(dx_now), np.linalg.norm(dy_now)))

In [163]:
# One component beyond the dimension is left to the second decomposition.
model_overcomplete = OvercompleteTensorDecomposition(dimension=dim)
overcomplete_param = rank - dim
train_loop_overcomplete(
    model_overcomplete, T, rank, overcomplete_param,
    num_epochs=500, learning_rate=1e-1,
)

Epoch 1: x= [ 0.38853985  0.19724813  0.736214    0.09377964 -0.3380325   0.4594486
 -0.17199281 -0.34954566  0.49484527  0.43105042], y= [ 0.37388813 -2.426167    1.8121659  -0.8972436  -1.2933408   4.0304503
 -1.5983369   1.0135648   0.8357092  -0.23720054], loss= 13.515523910522461, dx = [-4.0202928  -0.8287794  -1.3978753   1.217087    0.3960721  -5.3751144
  3.5441403   0.6933507  -1.4035525   0.70913374], dy = [ -2.4681168   23.190847   -15.002649     4.41963      7.478448
 -40.130196    21.420576   -12.0588455   -8.804589     0.49234867], dx_norm = 8.052839279174805, dy_norm = 56.00802993774414
Epoch 101: x= [ -2.7986634 -19.758486   16.741667  -40.7492     73.0789     14.039474
  12.9976015 -42.426697  -14.59312    20.816542 ], y= [ 23.477768   10.299561  -13.8942175 -21.947342   13.912668   19.646713
 -11.553485    7.126324   35.455296   -9.720668 ], loss= 1.0434242486953735, dx = [-0.00475306 -0.01826433  0.00440744  0.06275612  0.04034863 -0.00034143
 -0.05416786 -0.011414  

In [164]:
# Continue training the same model from its current state.
train_loop_overcomplete(
    model_overcomplete, T, rank, overcomplete_param,
    num_epochs=5001, learning_rate=1e-1,
)

 = [ 2.3676702e-03  1.4003960e-04  1.5319552e-03  7.6136948e-03
  3.3943523e-03 -8.0754515e-04 -3.5742875e-03 -4.3092412e-03
  7.9824310e-04 -5.9047015e-05], dy = [-0.00079209  0.0007177  -0.004044   -0.00020295 -0.00413778  0.00587654
  0.00254792 -0.00570007 -0.00020912 -0.00055352], dx_norm = 0.010492865927517414, dy_norm = 0.010417554527521133
Epoch 1901: x= [ -3.181981  -18.97912    15.72421   -45.503983   69.97427    14.6321945
  16.52827   -41.38327   -15.290918   21.763924 ], y= [ 25.999348   10.515712  -12.5727005 -17.914877   16.76667    17.49985
 -13.925492    6.6874394  35.66315    -9.165723 ], loss= 0.3391835391521454, dx = [ 2.3927507e-03  2.1048635e-04  1.4549308e-03  7.3388689e-03
  3.2151230e-03 -7.8128558e-04 -3.3738082e-03 -4.2565856e-03
  7.4689090e-04  6.5353233e-06], dy = [-0.00063205  0.00074363 -0.00389457  0.00025949 -0.00399104  0.00597394
  0.00227088 -0.00603029 -0.00019315 -0.00059834], dx_norm = 0.010136655531823635, dy_norm = 0.01047474890947342
Epoch 200

In [166]:
# Another round of training on the same model, again from its current state.
train_loop_overcomplete(
    model_overcomplete, T, rank, overcomplete_param,
    num_epochs=5001, learning_rate=1e-1,
)

1292], loss= 0.011770645156502724, dx = [ 0.0012609  -0.00010637  0.00048674  0.00364389  0.00155568  0.00014425
 -0.00108154 -0.00263286  0.0005623   0.00037001], dy = [-0.02921378 -0.00520536 -0.00845924 -0.04908286 -0.00791062 -0.01474363
  0.01404963  0.05612287 -0.00545368 -0.00784992], dx_norm = 0.005109965335577726, dy_norm = 0.08414101600646973
Epoch 2001: x= [ -5.4953084 -19.053171   14.868563  -51.24507    67.985916   14.233114
  18.631903  -36.96894   -16.031303   21.35634  ], y= [ 25.598614   9.759948 -12.005696 -17.743155  18.715433  14.836512
 -14.260324   9.258775  36.060852  -8.841014], loss= 0.01061057299375534, dx = [ 2.4162277e-03  4.0028244e-05  8.4495475e-04  5.7358975e-03
  1.9750837e-03  7.0792949e-04 -1.6593318e-03 -4.8739803e-03
  8.3058386e-04  7.0697116e-04], dy = [-0.00775873 -0.00127608 -0.00230808 -0.01341804 -0.00253477 -0.00356439
  0.00386403  0.01483169 -0.00153965 -0.00213174], dx_norm = 0.008459024131298065, dy_norm = 0.02254221960902214
Epoch 2101: 

Verify that we have the correct magic vectors:

In [168]:
print(model_overcomplete.x_magic, model_overcomplete.y_magic)
# Use the same first-stage rank as the model (tensor_rank - overcomplete_param)
# instead of a hard-coded 10, so this check follows the experiment settings.
A_r, Xi_r = decompose(T, model_overcomplete.x_magic, model_overcomplete.y_magic, rank = rank - overcomplete_param)
print(A_r)

<tf.Variable 'x_magic:0' shape=(10,) dtype=float32, numpy=
array([ -5.9696712, -19.067682 ,  14.700633 , -52.320652 ,  67.626854 ,
        14.081567 ,  18.927778 , -36.020607 , -16.18811  ,  21.206348 ],
      dtype=float32)> <tf.Variable 'y_magic:0' shape=(10,) dtype=float32, numpy=
array([ 25.6511  ,   9.782886, -11.992485, -17.698637,  18.70946 ,
        14.84865 , -14.284846,   9.167983,  36.058006,  -8.831934],
      dtype=float32)>
tf.Tensor(
[[ 0.08843441 -0.19704808  0.2235612  -0.2583426   0.6635888  -0.2526739
  -0.17949477 -0.38038832  0.3997382  -0.21994777]
 [ 0.07108357 -0.16983761  0.07088809  0.17949645  0.03512893  0.33813894
  -0.28792596 -0.47485015  0.21703234  0.09849288]
 [-0.09376454  0.20533454 -0.12981196 -0.27900356 -0.20335223  0.2669789
  -0.0847471  -0.6081371  -0.07520815  0.38948455]
 [-0.23393172  0.29528874  0.03627873  0.04959017 -0.0993012   0.11166164
  -0.06488676  0.30398586  0.07744135 -0.5494448 ]
 [-0.57449895 -0.34911084 -0.15225846 -0.13923524

In [174]:
# Normalise every true component (column of A) to unit length.
A_normalized, A_norm = tf.linalg.normalize(A, axis=0)
print(A_normalized)

tf.Tensor(
[[ 0.39915055 -0.24594928  0.08908025  0.3232107  -0.17958044 -0.29740816
  -0.25273916 -0.3804448  -0.21985348 -0.22352764 -0.6635542 ]
 [ 0.21796511 -0.21263486  0.07250353 -0.22504432 -0.28853622 -0.0082253
   0.33809313 -0.47478753  0.09858713 -0.0708352  -0.03513132]
 [-0.07511914  0.25701267 -0.0949136   0.3498101  -0.08395986 -0.11116836
   0.26710045 -0.6088567   0.38946506  0.12978077  0.20266382]
 [ 0.07365488  0.36890128 -0.23513165 -0.06152458 -0.06323229 -0.66188323
   0.11174922  0.30367872 -0.549547   -0.03633233  0.09882505]
 [-0.55097574 -0.43682876 -0.57343113  0.17404373 -0.35252157 -0.22519593
  -0.6312445  -0.11562872  0.36131364  0.15220949 -0.00293561]
 [ 0.44476533 -0.68451256 -0.17537679  0.6968552   0.30106437 -0.09560683
  -0.46489236 -0.08915525 -0.17068948  0.310508    0.27272016]
 [-0.09950895  0.13268483  0.62851197 -0.07177387 -0.53424835  0.1706571
   0.16400743 -0.06767393  0.2105676   0.13344264 -0.40540206]
 [-0.50774926 -0.07943285 -0.268

In [172]:
# Inner products between the true unit components and the recovered columns;
# presumably an entry close to +/-1 marks a recovered component.
print(tf.matmul(tf.transpose(A_normalized), A_r))

tf.Tensor(
[[ 0.34491718 -0.14739618 -0.10463247 -0.23556773  0.1462432  -0.02358371
   0.02030449 -0.1621542   0.99965674 -0.17984775]
 [ 0.3431466   0.80007327  0.21734327  0.38269773  0.00733831  0.7246169
  -0.11256169  0.28018075 -0.18278483 -0.03248984]
 [ 0.9999964   0.27221626 -0.02395384  0.1172271   0.5845563   0.45028615
  -0.44510022 -0.06642469  0.34277248  0.2176605 ]
 [-0.14672795 -0.38323796 -0.05012612 -0.79896814  0.07997853 -0.4601581
   0.00741411 -0.24864005  0.29530078  0.12519506]
 [-0.44660982 -0.08958185  0.19236274 -0.00649383 -0.64734805  0.0069145
   0.9999949   0.33082396  0.01919465 -0.5693862 ]
 [ 0.20083196 -0.05703301  0.0060225   0.27810305 -0.11163344  0.25102112
   0.37461963 -0.05172858 -0.38691086  0.03449243]
 [ 0.45130554  0.5797775   0.4131079   0.368054   -0.05495837  0.9999961
   0.0077643  -0.00446488 -0.02467888 -0.05683722]
 [-0.06471471  0.22432928  0.33524716  0.19911566 -0.2456652  -0.00415079
   0.33057138  0.9995945  -0.16005059 -0.384

In [175]:
# Side-by-side view of one true component and one recovered column.
print(A_normalized[:, 0], A_r[:, -2])

tf.Tensor(
[ 0.39915055  0.21796511 -0.07511914  0.07365488 -0.55097574  0.44476533
 -0.09950895 -0.50774926 -0.11057036 -0.02782716], shape=(10,), dtype=float32) tf.Tensor(
[ 0.3997382   0.21703234 -0.07520815  0.07744135 -0.5494324   0.44459853
 -0.10072564 -0.508204   -0.10979505 -0.02776888], shape=(10,), dtype=float32)


Save the notebook session:

In [177]:
# Persist the whole interpreter session to disk.
import dill

dill.dump_session('first_success_save.db')

The remaining cells test individual functions:

In [48]:
# Check tensor-vector product
# Tensor-vector product with the contracted axes spelled out explicitly.
a = tf.ones(shape=[2, 2, 2])
b = tf.ones(shape=[2])
c = tf.tensordot(a, b, axes=[[0], [0]])
print(c)

tf.Tensor(
[[2. 2.]
 [2. 2.]], shape=(2, 2), dtype=float32)


In [49]:
# Check tensor-vector product
# Tensor-vector product using the integer form of `axes`.
a = tf.ones(shape=(2, 2, 2))
b = tf.ones(shape=(2,))
c = tf.tensordot(a, b, axes=1)
print(c)

tf.Tensor(
[[2. 2.]
 [2. 2.]], shape=(2, 2), dtype=float32)


In [50]:
# With unit weights, entry [0, 0, 0] of the tensor should equal the sum of
# the cubes of the first coordinates of the two components.
b = tf.random.normal(shape=(3, 2))
B = constructSymTensor(components=b)
print(B)
print(b[0, 0] ** 3 + b[0, 1] ** 3)

tf.Tensor(
[[[-0.11650289 -0.01634141  0.29392678]
  [-0.01634141  0.04184418 -0.20585784]
  [ 0.29392678 -0.20585784  0.64169407]]

 [[-0.01634141  0.04184418 -0.20585784]
  [ 0.04184419 -0.04064367  0.15482171]
  [-0.20585784  0.15482171 -0.5090162 ]]

 [[ 0.29392678 -0.20585784  0.64169407]
  [-0.20585784  0.15482171 -0.5090162 ]
  [ 0.64169407 -0.5090162   1.7345427 ]]], shape=(3, 3, 3), dtype=float32)
tf.Tensor(-0.11650288, shape=(), dtype=float32)


In [51]:
# Small fixed matrix and a random vector for the checks in the next cell.
a = tf.constant(value=[[1, 1], [0, -1]], dtype=float)
b = tf.random.normal(shape=[2])
print(a, b)

tf.Tensor(
[[ 1.  1.]
 [ 0. -1.]], shape=(2, 2), dtype=float32) tf.Tensor([ 0.03729795 -0.8495061 ], shape=(2,), dtype=float32)


In [52]:
# Matrix-vector product via tensordot, then the diagonal of the matrix.
print(tf.tensordot(a, b, axes=[[1], [0]]))
print(tf.linalg.diag_part(a))

tf.Tensor([-0.8122081  0.8495061], shape=(2,), dtype=float32)
tf.Tensor([ 1. -1.], shape=(2,), dtype=float32)
