In [18]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import cvxpy as cp


import numpy as np
from jax.random import PRNGKey, split
from jax import config, grad
from time import time
# config.update("jax_enable_x64", True)

from data import generate_data, generate_observation_matrix
from loss import create_mc_loss
from network import init_net, create_network, compute_prefactor, compress_network
from solver import train
from utils import svd, compose

In [19]:
key = PRNGKey(0)

r = 100
input_dim = 5000
output_dim = 5000
depth = 3
init_type = "orth"
init_scale = 1e-3

result_dict = {}

key, target_key = split(key)
target = generate_data(key=target_key, shape=(output_dim, input_dim), rank=r)

key, weight_key = split(key)
init_weights = init_net(key=weight_key, input_dim=input_dim, output_dim=output_dim, width=input_dim, depth=depth, init_type="orth", init_scale=init_scale)
network_fn = create_network()

key, observation_key = split(key)
percent_observed = 0.20
mask = generate_observation_matrix(observation_key, percent_observed, (output_dim, input_dim))
mc_loss_fn = create_mc_loss(target, mask)
e2e_loss_fn = compose(mc_loss_fn, network_fn)

## Compressed Network

In [20]:
V = compute_prefactor(init_weights=init_weights, e2e_loss_fn=e2e_loss_fn, grad_rank=r)
comp_init_weights, V1_1, UL_1 = compress_network(init_weights, V, r)
comp_init_weights = [V1_1.T] + comp_init_weights + [UL_1]
comp_network_fn = create_network()

In [22]:
num_iters = 500000
step_size = 1e5
comp_weights, comp_loss_list, comp_time_list = train(
    init_weights=comp_init_weights, 
    e2e_loss_fn=e2e_loss_fn,
    n_epochs=num_iters,
    step_size=step_size,
    n_inner_loops=100,
    factors=True,
    save_weights=False,
    tol=1e-12
)

  0%|          | 0/5000 [00:00<?, ?it/s]

In [23]:
comp_time = comp_time_list[-1]
comp_loss = jnp.sum((network_fn(comp_weights) - target)**2) / jnp.sum(target**2)

In [24]:
result_dict['comp'] = {
    'time': comp_time,
    'loss': comp_loss
}

In [25]:
comp_time

Array(124.916084, dtype=float32)

In [26]:
comp_loss

Array(9.443752e-12, dtype=float32)

## Scaled GD (spectral initialization)

In [27]:
import jax.numpy as jnp
import jax

from jax import grad
from jax.lax import fori_loop
from tqdm.auto import tqdm
from time import time

def scaled_update_weights(weights, gradient, step_size):
    return jax.tree_map(lambda p, g: p - step_size * g @ jnp.linalg.pinv(p.T @ p), weights, gradient)

def scaled_train(init_weights, e2e_loss_fn, n_epochs, step_size, n_inner_loops=100, save_weights=False, tol=0):

    # Define fun body in lax.fori_loop
    def body_fun(_, w):
        g = grad(e2e_loss_fn)(w)
        return scaled_update_weights(w, g, step_size)
    
    # Run once to compile
    fori_loop(0, n_inner_loops, body_fun, init_weights)

    loss = e2e_loss_fn(init_weights)
    loss_list = [loss]
    time_list = [0.]
    weights_list = [init_weights]
    weights = init_weights

    num_iters = n_epochs // n_inner_loops
    pbar = tqdm(range(num_iters))

    start_time = time()
    for _ in pbar:
        pbar.set_description(f"Loss: {loss:0.2e}")
        weights = fori_loop(0, n_inner_loops, body_fun, weights)
        loss = e2e_loss_fn(weights)
        loss_list.append(loss)
        time_list.append(time() - start_time)
        if save_weights:
            weights_list.append(weights)
            
        if loss < tol:
            break

    if save_weights:
        return weights_list, jnp.array(loss_list), jnp.array(time_list)
    else:
        return weights, jnp.array(loss_list), jnp.array(time_list)

In [28]:
def scaled_network_fn(w):
    X, Y = w
    return X @ Y.T

In [29]:
lam = 0
def scaled_e2e_loss_fn(w):
    X, Y = w
    return mc_loss_fn(scaled_network_fn(w)) + lam/2 * jnp.sum((X.T @ X - Y.T @ Y)**2)

In [30]:
U, s, V = svd(mask * target)
Ur = U[:, :r]
Vr = V[:, :r]
Sr = jnp.diag(s[:r])

X_init = Ur @ jnp.sqrt(Sr)
Y_init = Vr @ jnp.sqrt(Sr)

# X_init = init_scale * jnp.asarray(np.random.randn(input_dim, r))
# Y_init = init_scale * jnp.asarray(np.random.randn(output_dim, r))
scaled_init_weights = [X_init, Y_init]

In [33]:
num_iters = 40000
step_size = 1e5
scaled_weights, scaled_loss_list, scaled_time_list = scaled_train(
    init_weights=scaled_init_weights, 
    e2e_loss_fn=scaled_e2e_loss_fn,
    n_epochs=num_iters,
    step_size=step_size,
    n_inner_loops=100,
    save_weights=False,
    tol=1e-12
)

  0%|          | 0/400 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [14]:
scaled_time = scaled_time_list[-1]
scaled_loss = jnp.sum((scaled_network_fn(scaled_weights) - target)**2) / jnp.sum(target**2)

In [15]:
result_dict['scaled'] = {
    'time': scaled_time,
    'loss': scaled_loss
}

In [16]:
result_dict

{'comp': {'time': Array(24.896193, dtype=float32),
  'loss': Array(1.6326912e-09, dtype=float32)},
 'scaled': {'time': Array(53.19235, dtype=float32),
  'loss': Array(2.360959e-09, dtype=float32)}}

In [17]:
l = result_dict['scaled']['loss']
f'{l:0.2e}'

'2.36e-09'

## Nuclear Norm Min

In [None]:
X = cp.Variable(shape=(output_dim, input_dim))
objective = cp.Minimize(cp.norm(X, "nuc"))
constraints = [cp.multiply(mask, (X - target)) == 0]
prob = cp.Problem(objective, constraints)
result = prob.solve(verbose=True)
Xpred = X.value

In [None]:
nuc_time = prob.solver_stats.solve_time
nuc_loss = jnp.sum((Xpred - target)**2) / jnp.sum(target**2)

In [None]:
result_dict['nuc'] = {
    'time': nuc_time,
    'loss': nuc_loss
}