In [1]:
#import jax.numpy as np
#from jax import pmap
import numpy as np
from maxnorm.maxnorm_completion import *
from maxnorm.tenalg import *
from maxnorm.graphs import *
import sparse
from itertools import product
import networkx as nx

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

%load_ext autoreload

In [2]:
# create random, low-rank tensor
t = 5
n = 30
r = 3
max_iterations = 100
epsilons = np.logspace(-4, 0, 5)
epsilons = np.concatenate((np.array(0), epsilons), axis = None)
#delta = 0.1
const = 24
sigma = 0.005
ndata =  const * r * t * n * np.log10(n)
U = kr_random(n, t, r, rvs='unif')
U = kr_rescale(U, np.sqrt(n**t), 'hs')
norm_true = np.sqrt(kr_dot(U, U) / n**t)
#U = [np.random.randn(n, r) for i in range(t)]
print("n data: %.2e" % ndata)
print("n data ** t/2: %.2e" % int(const * r * n**(t/2) * np.log10(n)))
print("true norm: %.2e" % norm_true)

n data: 1.60e+04
n data ** t/2: 5.24e+05
true norm: 1.00e+00


In [3]:
qnorm_true = max_qnorm_ub(U)
print(qnorm_true)
print(r**(t/2))

58.590081445867376
15.588457268119896


In [4]:
expander = nx.random_regular_graph(6, n)
#expander = nx.chordal_cycle_graph(n)
observation_mask = obs_mask_expander(expander, t)
#observation_mask = obs_mask_iid(tuple([n for i in range(t)]), ndata * n**(-t))

In [5]:
from run_sweep_iid import generate_data


data = generate_data(observation_mask, U, sigma)
clean_data_rmse = np.sqrt(loss(U, data) / data.nnz)
print(data.nnz)
print(ndata)
print(n**t)
print("%0.1e%%" % (float(data.nnz) / n**t * 100))

38880
15952.909550972354
24300000
1.6e-01%


In [6]:
def print_factor_norms(U):
    print("fro:   " + str([np.linalg.norm(Ui,'fro')**2 for Ui in U]))
    print("2-inf: " + str([np.max(np.linalg.norm(Ui, axis=1)) for Ui in U]))

print_factor_norms(kr_balance_factors(U))

fro:   [72.20668981939896, 72.20668981939899, 72.20668981939899, 72.20668981939899, 72.20668981939899]
2-inf: [2.291713506684371, 2.3091916772236365, 2.0989896240905535, 2.335517972704103, 2.2533043604887495]


In [7]:
delta = 1.5 * sigma
#np.sqrt(np.sum((data.data - clean_data)**2)) * 2 / np.sqrt(data.nnz) * 1.5

print("rms of data:        %f" % clean_data_rmse)
print("delta parameter:    %f" % delta)
print("rmse of U true:     %f" % np.sqrt(loss(U, data) / data.nnz))

rms of data:        0.004984
delta parameter:    0.007500
rmse of U true:     0.004984


In [8]:
#Uinit = kr_rescale(Unew1, np.sqrt(np.product(data.shape) * data.sum() ** 2 / data.nnz), 'hs')
# Uinit = kr_balance_factors(Unew1)

In [10]:
%autoreload
Unew2_list = []
cost_arr_list = []
deltas = np.logspace(-6, -3, 4)
eps = 1e-4
iterations = 300
print(deltas)
print(sigma)
# sigma = deltas[-1] * np.sqrt(data.nnz)
for delta in deltas:
    print("delta = ", delta)
    Unew2, cost_arr2, ge_arr2, ge_mse_arr2 = tensor_completion_maxnorm(U, data, 4 * r**t, delta * np.sqrt(data.nnz), epsilon=eps,
                                                init='svdrand',
                                                kappa=100, beta=1, 
                                                tol=1e-10, inner_tol=1e-12, max_iter=iterations, inner_max_iter=10, 
                                                verbosity=1, inner_line_iter=40,
                                                rebalance=True)
    Unew2_list.append(Unew2)
    cost_arr_list.append(cost_arr2)
    ge_arr_list.append(ge_arr2)
    ge_arr_mse_list.append(ge_mse_arr2)

[1.e-06 1.e-05 1.e-04 1.e-03]
0.005
delta =  1e-06
Initial cost: 1.963e+06
Initial qnorm_ub: 9.295e+03
|| r || = 1.003e+00, delta = 1.000e-06
Initial MSE : 1.000e+00

Iteration 0 complete
|| resid || = 4.255e-01
Cost : 5.007e+05
MSE : 1.026e+00


Iteration 10 complete
|| resid || = 5.299e-02
Cost : 5.086e+04
MSE : 8.802e-01

Caught KeyboardInterrupt, exiting early

finished in 11 iterations


scaled || r || = 5.299e-02, delta = 1.000e-06
Max-qnorm upper bound: 4.540e+04
Cost function: 5.086e+04
MSE : 8.802e-01


NameError: name 'ge_arr_list' is not defined

In [None]:
import matplotlib.pyplot as plt

plt.semilogy(cost_arr_list[0])
plt.semilogy(cost_arr_list[1])
plt.semilogy(cost_arr_list[2])
plt.semilogy(cost_arr_list[3])
# plt.semilogy(cost_arr_list[4])
# plt.semilogy(cost_arr_list[5])

plt.legend(["delta = 1e-6", "delta = 1e-5", "delta = 1e-4", "delta = 1e-3"])
plt.xlabel('iterate')
plt.ylabel('cost')
plt.title("GE-Curve for Various Delta Values")
plt.savefig('images/GE-Curve-Delta_300.png')

## Generalization error

In [None]:
def gen_err(Upred, Utrue):
    norm_true = kr_dot(Utrue, Utrue)
    mse_gen = kr_dot(Upred, Upred) + norm_true - 2 * kr_dot(Upred, Utrue)
    return np.sqrt(mse_gen / norm_true)

def mse_gen_err(Upred, Utrue):
    norm_true = kr_dot(Utrue, Utrue)
    mse_gen = kr_dot(Upred, Upred) + norm_true - 2 * kr_dot(Upred, Utrue)
    return np.sqrt(mse_gen / Upred[0].shape[0] ** len(Upred))


for i in range(len(Unew2_list)):
    print("delta = {}".format(deltas[i]))
    print("relative RMSE max:      %1.4e" % gen_err(Unew2_list[i], U))
    print("MSE max:                %1.4e\n" % mse_gen_err(Unew2_list[i], U))

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, axs = plt.subplots(1,t, figsize=(6, 20))
for i in range(t):
    im = axs[i].imshow(Unew2[i])
    divider = make_axes_locatable(axs[i])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    if i > 0:
        axs[i].set_yticks([])

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, axs = plt.subplots(1,t, figsize=(6, 20))
for i in range(t):
    im = axs[i].imshow(U[i])
    divider = make_axes_locatable(axs[i])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    if i > 0:
        axs[i].set_yticks([])