# Geometric Tensor Learning
---

In [1]:
import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt

import networkx as nx

from util.t2m import t2m
from util.m2t import m2t
from util.update_L import update_L
from util.update_X import update_X
from util.update_Lambda import update_Lambda

In [2]:
sizes = (10, 15, 12, 16)
n = len(sizes)

In [3]:
X = np.random.standard_normal(size=sizes) # WGN to generate a stationary signal.

In [4]:
G = [nx.erdos_renyi_graph(sizes[i],0.3) for i in range(n)] # List of graphs for each mode
Phi = [nx.laplacian_matrix(G[i]).todense() for i in range(n)]


In [5]:
# Generate data smooth in Cartesian graph G.
ranks = (5, 5, 5, 5)
W_all = 1
V_all = 1

for i in range(n):
    W, V = np.linalg.eigh(nx.laplacian_matrix(G[i]).todense())
    temp_eig = np.flip(np.sort(abs(np.random.randn(ranks[i]))))
    W_all = np.kron(W_all, np.ones(ranks[i])) + np.kron(np.ones(ranks[:i]).flatten(), temp_eig)
    V_all = np.kron(V_all, V[:,:ranks[i]])

X_smooth = np.tensordot(V_all, W_all, axes=([1],[0])).reshape(sizes)
del W_all, V_all

In [6]:
noise_ratio = 0.01
missing_ratio = 0.2
norm_X = np.sqrt(np.sum(X_smooth**2))
mask = np.random.uniform(size=np.prod(sizes)).reshape(sizes)-missing_ratio<0
Y = ma.array(X_smooth + noise_ratio * np.sqrt(norm_X) * np.random.standard_normal(sizes), mask = mask)

In [7]:
# Parameters
alpha = [[1 for i in range(n)], 
    [1 for i in range(n)], 
    [1 for i in range(n)]]
theta = [1 for i in range(n)]
gamma = [1 for i in range(n)]

# Initializations
L = np.zeros(sizes)
G_var = [np.zeros(sizes) for i in range(n)]
X = [np.zeros(sizes) for i in range(n)]
Lx = [np.zeros(sizes) for i in range(n)]
Lambda = [[np.zeros(sizes) for i in range(n)],
    [np.zeros(sizes) for i in range(n)],
    [np.zeros(sizes) for i in range(n)]]

In [8]:
def fn_val_G(G, L, Lambda, alpha, gamma):
    n = len(G)
    val_L = [gamma[i]*np.trace(t2m(G[i],i).transpose() @ Phi[i] @ t2m(G[i],i)) for i in range(i)]
    val_Lag = [alpha[i]*np.linalg.norm(L-G[i]-Lambda[i])**2 for i in range(n)]
    fn_val = sum(val_L) + sum(val_Lag)
    return fn_val, val_L, val_Lag

def fn_val_L(L, Y, Lx, G, Lambda, alpha):
    n = len(L.shape)
    val_Y = np.linalg.norm(Y[~Y.mask]-L[~Y.mask])**2
    val_Lag1 = [alpha[0][i]*np.linalg.norm(L-G[i]-Lambda[0][i])**2 for i in range(n)]
    val_Lag2 = [alpha[1][i]*np.linalg.norm(L-Lx[i]-Lambda[1][i])**2 for i in range(n)]
    fn_val = val_Y + sum(val_Lag1) + sum(val_Lag2)
    return fn_val, val_Y, val_Lag1, val_Lag2

In [9]:
G_inv = [np.linalg.inv(gamma[i]*Phi[i] + alpha[0][i]*np.eye(sizes[i])) for i in range(n)]
iter=0
while True:
    temp = np.zeros(sizes)
    for i in range(n):
        temp += alpha[1][i]*(G_var[i] + Lambda[0][i])
        temp += alpha[2][i]*(Lx[i] + Lambda[1][i])
    
    print("Function value for variable L: {}".format(fn_val_L(L, Y, Lx, G_var, Lambda[:2], alpha[:2])[0]))
    L = temp/(sum(alpha[0]) + sum(alpha[1]) + 1)
    L[~Y.mask] += Y[~Y.mask]/(sum(alpha[0]) + sum(alpha[1]) + 1)
    print("Function value for variable L after update: {}".format(fn_val_L(L, Y, Lx, G_var, Lambda[:2], alpha[:2])[0]))
    
    print("Function value for variable G: {}".format(fn_val_G(G_var, L, Lambda[0], alpha[0], gamma)[0]))
    G_var = [m2t(alpha[0][i]*G_inv[i]*t2m(L-Lambda[0][i], i), sizes, i) for i in range(n)]
    print("Function value for variable G after update: {}".format(fn_val_G(G_var, L, Lambda[0], alpha[0], gamma)[0]))
    fval_G = fn_val_G(G_var, L, Lambda[0], alpha[0], gamma)[0]
    Lx, fvals_L = update_L(Lx, L, X, Lambda[1:], Phi, alpha[1:], theta, track_fval=True)
    print("Function value for variable Lx after update: {}".format(fvals_L[-1]))
    X, fvals_X = update_X(X, Lx, Lambda[2], Phi, alpha[2], theta, track_fval=True)
    print("Function value for variable X after update: {}".format(fvals_X[0][-1]))
    
    Lambda = update_Lambda(Lambda, L, Lx, X, G_var)[0]
    iter=+1
    # if iter%10==0:
    break

Function value for variable L: 8811.654284905877
Function value for variable L after update: 7832.581586583001
Function value for variable G: 435.1434214768334
Function value for variable G after update: 189.92597174163032
Function value for variable Lx after update: 221.39835036979943
Function value for variable X after update: [460.71592294608314, 374.76746935668905, 698.0803071973568, 525.2278746766441]


In [10]:
fvals_X

[(2058.7915741767733,
  [0.0, 0.0, 0.0, 0.0],
  [460.71592294608314,
   374.76746935668905,
   698.0803071973568,
   525.2278746766441]),
 (20397145.509930324,
  [46469.447704658654,
   34017.222687721944,
   163734.84631016944,
   73824.30414488468],
  [1885926.7380964705,
   3764679.540286965,
   10173138.698795268,
   4255354.711904185]),
 (6668893517367.863,
  [393969125.39106274,
   3575340767.021408,
   4964573208.106484,
   2339589579.9964604],
  [68253089742.578125,
   3633247854259.2393,
   1425266131240.7312,
   1530852969444.799]),
 (7.909183238449129e+18,
  [41540173366083.16,
   5014246379178594.0,
   1318868071832162.2,
   1928528672706778.0],
  [1.88033300426139e+16,
   5.517473348638968e+18,
   6.514543021392238e+17,
   1.7131490743312402e+18]),
 (1.1018529755547175e+25,
  [1.3931957040386163e+19,
   7.734512279293749e+21,
   6.568157653342475e+20,
   2.2552275737567798e+21],
  [6.892301655931712e+21,
   8.579046227679336e+24,
   3.5138159754488794e+23,
   2.07054914109