In [1]:
import jax
import jax.nn as nn
import jax.numpy as jnp
from jax.config import config; config.update("jax_enable_x64", True)
from jax.random import PRNGKey, split, orthogonal

from collapse import compute_feature_collapse
from network import init_net_orth, compute_outputs
from loss import create_loss
from solver import train
from data import generate_orthogonal_input, generate_labels_and_target

In [3]:
# --- Problem size ------------------------------------------------------------
num_classes = 5
num_samples_per_class = 10
input_dim = 20
total_samples = num_samples_per_class * num_classes

# --- Network configuration ---------------------------------------------------
nonlinear = False   # forwarded to compute_outputs inside e2e_loss_fn
depth = 3
init_scale = 1e-3   # forwarded to init_net_orth as `init_scale`

# Root PRNG key; every consumer below receives a fresh subkey via `split`.
key = PRNGKey(0)

# --- Data --------------------------------------------------------------------
labels, target = generate_labels_and_target(num_classes, num_samples_per_class)

key, subkey = split(key)
input_data = generate_orthogonal_input(subkey, input_dim, total_samples)

# --- Loss --------------------------------------------------------------------
loss_fn = create_loss(target)


def e2e_loss_fn(weights):
    """Training loss expressed as a function of the network weights alone.

    Runs ``compute_outputs`` on the fixed ``input_data`` with the
    module-level ``nonlinear`` flag, takes element 0 of its result and
    scores it with ``loss_fn``.

    Note: ``input_data``, ``nonlinear`` and ``loss_fn`` are globals looked up
    at call time, so rebinding any of them in a later cell silently changes
    what this function computes.
    """
    forward_out = compute_outputs(weights, input_data, nonlinear)[0]
    return loss_fn(forward_out)


# --- Initial weights ---------------------------------------------------------
key, subkey = split(key)
init_weights = init_net_orth(
    key=subkey, input_dim=input_dim, output_dim=num_classes,
    width=input_dim, depth=depth, init_scale=init_scale,
)

In [6]:
# Element 1 of compute_outputs' result is a list of arrays, presumably the
# per-layer hidden representations at initialization (TODO confirm in network.py).
# Pass `nonlinear` explicitly so this inspection uses the same forward pass as
# e2e_loss_fn, instead of relying on compute_outputs' default for that argument.
compute_outputs(init_weights, input_data, nonlinear)[1]

[Array([[ 1.78334207e-04, -1.11758648e-01,  9.79663043e-02,
         -1.34214338e-01,  2.20336752e-01,  1.61459899e-01,
          1.14632158e-02,  1.14089814e-01,  1.60006166e-02,
          6.58178472e-02,  6.00102212e-03, -6.21766520e-02,
          3.32511966e-01,  5.04452211e-03, -4.22880105e-02,
         -9.71073179e-02,  2.33955940e-01, -2.27996355e-01,
         -7.42757784e-02,  1.28378385e-01, -1.70201625e-01,
          2.63329817e-02,  3.88916983e-02,  2.62565166e-01,
          1.03559872e-01,  1.57834720e-01, -9.03861767e-02,
          3.78631890e-01, -1.23957943e-01, -1.52157176e-01,
         -3.03686548e-02,  7.53933326e-02,  3.08322261e-03,
         -6.53298532e-02,  1.33605397e-01,  1.02033642e-01,
         -1.50875234e-02,  3.58390492e-01, -9.29869309e-02,
         -5.71624662e-02, -1.97927781e-01,  3.12129531e-02,
          2.27140469e-02,  1.78575537e-01, -6.21257842e-02,
         -6.09818378e-02,  2.89239780e-02, -8.34505192e-02,
         -1.22427828e-03, -1.40191570e-0