In [1]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import cvxpy as cp

from jax.random import PRNGKey, split
from jax import config, grad
from time import time
# config.update("jax_enable_x64", True)

from data import generate_data, generate_observation_matrix
from loss import create_mc_loss
from network import init_net, create_network, compute_prefactor, compress_network
from solver import train
from utils import svd, compose

In [3]:
key = PRNGKey(0)

r = 5
input_dim = 200
output_dim = 200
depth = 3
init_type = "orth"
init_scale = 1e-3

key, target_key = split(key)
target = generate_data(key=target_key, shape=(output_dim, input_dim), rank=r)

key, weight_key = split(key)
init_weights = init_net(key=weight_key, input_dim=input_dim, output_dim=output_dim, width=input_dim, depth=depth, init_type="orth", init_scale=init_scale)
network_fn = create_network()

key, observation_key = split(key)
percent_observed = 0.1
mask = generate_observation_matrix(observation_key, percent_observed, (output_dim, input_dim))
mc_loss_fn = create_mc_loss(target, mask)
e2e_loss_fn = compose(mc_loss_fn, network_fn)

## Compressed Network

In [5]:
V = compute_prefactor(init_weights=init_weights, e2e_loss_fn=e2e_loss_fn, grad_rank=r)
comp_init_weights, V1_1, UL_1 = compress_network(init_weights, V, r)
comp_init_weights = [V1_1.T] + comp_init_weights + [UL_1]
comp_network_fn = create_network()

In [6]:
num_iters = 80000
step_size = 5e2
comp_weights, comp_loss_list, comp_time_list = train(
    init_weights=comp_init_weights, 
    e2e_loss_fn=e2e_loss_fn,
    n_epochs=num_iters,
    step_size=step_size,
    n_inner_loops=500,
    factors=True,
    save_weights=False
)

  0%|          | 0/160 [00:00<?, ?it/s]

In [47]:
comp_time = comp_time_list[jnp.argmax(comp_loss_list < 1e-8)]
comp_loss = jnp.sum((network_fn(comp_weights) - target)**2) / jnp.sum(target**2)

In [48]:
result_dict['comp'] = {
    'time': comp_time,
    'loss': comp_loss
}

## Nuclear Norm Min

In [18]:
X = cp.Variable(shape=(output_dim, input_dim))
objective = cp.Minimize(cp.norm(X, "nuc"))
constraints = [cp.multiply(mask, (X - target)) == 0]
prob = cp.Problem(objective, constraints)
result = prob.solve(verbose=True)
Xpred = X.value

                                     CVXPY                                     
                                     v1.3.1                                    
(CVXPY) May 04 03:04:45 PM: Your problem has 40000 variables, 1 constraints, and 0 parameters.
(CVXPY) May 04 03:04:45 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) May 04 03:04:45 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)
(CVXPY) May 04 03:04:45 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) May 04 03:04:45 PM: Compiling problem (target solver=SCS).
(CVXPY) May 04 03:04:45 PM: Reduction chain: Dcp2Cone -> CvxAttr2Constr -> ConeMatrixStuffing 

In [37]:
nuc_time = prob.solver_stats.solve_time
nuc_loss = jnp.sum((Xpred - target)**2) / jnp.sum(target**2)

In [49]:
result_dict['nuc'] = {
    'time': nuc_time,
    'loss': nuc_loss
}

## Dual factors with balancing (and spectral initialization)

In [50]:
def dual_network_fn(w):
    X, Y = w
    return X @ Y.T

In [51]:
lam = 1e-3
def dual_e2e_loss_fn(w):
    X, Y = w
    return mc_loss_fn(dual_network_fn(w)) + lam/2 * jnp.sum((X.T @ X - Y.T @ Y)**2)

In [52]:
U, s, V = svd(mask * target)
Ur = U[:, :r]
Vr = V[:, :r]
Sr = jnp.diag(s[:r])

X_init = Ur @ jnp.sqrt(Sr)
Y_init = Vr @ jnp.sqrt(Sr)
dual_init_weights = [X_init, Y_init]

In [58]:
num_iters = 500000
step_size = 5
dual_weights, dual_loss_list, dual_time_list = train(
    init_weights=dual_init_weights, 
    e2e_loss_fn=dual_e2e_loss_fn,
    n_epochs=num_iters,
    step_size=step_size,
    n_inner_loops=100,
    factors=False,
    save_weights=False
)

  0%|          | 0/5000 [00:00<?, ?it/s]

In [59]:
dual_time = dual_time_list[jnp.argmax(dual_loss_list < 1e-8)]
dual_loss = jnp.sum((dual_network_fn(dual_weights) - target)**2) / jnp.sum(target**2)

In [61]:
result_dict['dual'] = {
    'time': dual_time,
    'loss': dual_loss
}

In [74]:
l = result_dict['dual']['loss']
f'{l:0.2e}'

'1.33e-04'