In [15]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import cvxpy as cp

from jax.random import PRNGKey, split
from jax import config, grad
from time import time
# config.update("jax_enable_x64", True)

from data import generate_data, generate_observation_matrix, generate_sensing_matrices
from loss import create_mc_loss
from network import init_net, create_network, compute_prefactor, compress_network
from solver import train
from utils import svd, sensing_operator

In [21]:
# Root PRNG key; every random draw below uses its own subkey split off from it.
key = PRNGKey(0)

# Problem / network configuration.
r = 2
input_dim = 100
output_dim = 100
depth = 3
init_type = "orth"
init_scale = 1e-3

# Target matrix of shape (output_dim, input_dim), generated with rank=r.
key, target_key = split(key)
target = generate_data(key=target_key, shape=(output_dim, input_dim), rank=r)

# Initial weights; hidden width is set equal to input_dim.
key, weight_key = split(key)
init_weights = init_net(
    key=weight_key,
    input_dim=input_dim,
    output_dim=output_dim,
    width=input_dim,
    depth=depth,
    init_type=init_type,  # use the configured value instead of repeating the literal
    init_scale=init_scale,
)
network_fn = create_network()

# Observation mask over the target and the matrix-completion loss built from it.
# NOTE(review): 0.1 is presumably a fraction of entries despite the "percent" name — confirm in data.py.
key, observation_key = split(key)
percent_observed = 0.1
mask = generate_observation_matrix(observation_key, percent_observed, (output_dim, input_dim))
mc_loss_fn = create_mc_loss(target, mask)


def e2e_loss_fn(w):
    """End-to-end loss: apply the network to weights `w`, then the matrix-completion loss."""
    return mc_loss_fn(network_fn(w))

## Compressed Network

In [22]:
# Network handle used for the compressed factors; built with the same call as `network_fn`.
comp_network_fn = create_network()

# Prefactor derived from the full initial weights, the loss and the network, with r passed through.
V = compute_prefactor(init_weights, mc_loss_fn, network_fn, r)

comp_init_weights, V1_1, UL_1 = compress_network(
    init_weights,
    V,
    r,
)

# Final factor list: V1_1 transposed first, the compressed factors in the middle, UL_1 last.
comp_init_weights = [V1_1.T] + comp_init_weights + [UL_1]

In [26]:
# Optimisation settings for the compressed network.
num_iters = 100_000
step_size = 1e2

# Collect the arguments once so the call itself stays short.
train_kwargs = dict(
    init_weights=comp_init_weights,
    loss_fn=mc_loss_fn,
    network_fn=comp_network_fn,
    n_epochs=num_iters,
    step_size=step_size,
    n_inner_loops=500,
    factors=True,
    save_weights=False,
)

# Returns the trained weights plus per-run loss and time lists.
comp_weights, comp_loss_list, comp_time_list = train(**train_kwargs)

  0%|          | 0/200 [00:00<?, ?it/s]

In [34]:
# Mean squared error over all entries between the trained compressed network's output and the target.
# Evaluated through comp_network_fn, the same handle the weights were trained with.
jnp.mean((comp_network_fn(comp_weights) - target)**2)

Array(1.39968705e-11, dtype=float32)

## Nuclear Norm Minimization

In [28]:
# Minimise the nuclear norm of X subject to X matching `target` wherever `mask` is non-zero.
X = cp.Variable(shape=(output_dim, input_dim))
objective = cp.Minimize(cp.norm(X, "nuc"))
constraints = [cp.multiply(mask, (X - target)) == 0]
prob = cp.Problem(objective, constraints)
result = prob.solve(verbose=True)

# X.value is only populated when the solver returns a solution; fail here with a
# clear message instead of letting a None propagate into the error computation below.
if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    raise RuntimeError(f"Nuclear norm minimization did not solve: status={prob.status!r}")
Xpred = X.value

                                     CVXPY                                     
                                     v1.3.1                                    
(CVXPY) May 03 06:11:19 PM: Your problem has 10000 variables, 1 constraints, and 0 parameters.
(CVXPY) May 03 06:11:19 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) May 03 06:11:19 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)
(CVXPY) May 03 06:11:19 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) May 03 06:11:19 PM: Compiling problem (target solver=SCS).
(CVXPY) May 03 06:11:19 PM: Reduction chain: Dcp2Cone -> CvxAttr2Constr -> ConeMatrixStuffing 

In [35]:
# Time, in seconds, that the solver reported for solving the nuclear-norm problem above.
prob.solver_stats.solve_time

14.271667806

In [36]:
# Mean squared error over all entries between the nuclear-norm solution and the target.
jnp.mean(jnp.square(Xpred - target))

Array(0.01298427, dtype=float32)

## Dual factors with balancing (and spectral initialization)

In [44]:
# SVD of the element-wise product mask * target (presumably the observed entries only — confirm mask is 0/1).
# NOTE(review): this rebinds `V`, which an earlier cell assigned from compute_prefactor(...);
# once this cell has run, `V` no longer holds that prefactor.
# `svd` is the helper from the local utils module; the (U, s, V) return order is assumed from the unpacking — TODO confirm.
U, s, V = svd(mask * target)

In [None]:
W2 = 