# Approximate tight frame for classification

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import optax

In [None]:
hidden_dim = 64
num_classes = 1000
# Seed the initialisation so the frame that is later pickled to
# tight_frame.pkl (and anything that loads it) is reproducible across runs.
seed = 0
rng = np.random.default_rng(seed)
# One random initial vector in R^hidden_dim per class.
a = rng.standard_normal((num_classes, hidden_dim))

In [None]:
# Shape of the stack of per-row outer products a_i a_i^T: (n, d, d).
np.einsum("ni,nj->nij", a, a).shape

In [None]:
@jax.jit
def loss_fn(a):
    """Squared Frobenius distance between the frame operator of ``a`` and I.

    The rows a_i of ``a`` are the frame vectors. Their frame operator
    sum_i a_i a_i^T equals ``a.T @ a``. The loss is therefore
    ||a^T a - I||_F^2, which is zero exactly when the rows form a Parseval
    frame (a tight frame with bound 1).

    Args:
        a: 2-D array of shape (num_vectors, dim).

    Returns:
        Scalar loss.
    """
    # Same sum of outer products as jnp.sum(a[..., None] @ a[:, None], axis=0),
    # but without materialising the (num_vectors, dim, dim) intermediate.
    gram = a.T @ a
    # Size the identity from the input instead of the global ``hidden_dim``.
    return jnp.sum((gram - jnp.eye(a.shape[-1], dtype=gram.dtype)) ** 2)

In [None]:
# Plain SGD with a fixed step size; its state is initialised from the
# current frame ``a``.
optimizer = optax.sgd(learning_rate=1e-4)
opt_state = optimizer.init(a)

In [None]:
num_epochs = 20000
log_every = 1000
losses = []


@jax.jit
def _train_step(params, state):
    """Take one optimizer step on ``loss_fn``.

    Returns (loss before the update, updated params, updated optimizer state).
    """
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, state = optimizer.update(grads, state)
    return loss, optax.apply_updates(params, updates), state


# The whole step (gradient + update) is compiled once and reused, rather than
# re-wrapping loss_fn with value_and_grad and running the update eagerly
# on every iteration.
for step in range(1, num_epochs + 1):
    loss, a, opt_state = _train_step(a, opt_state)
    if step % log_every == 0:
        # Store a Python float instead of a device array; this also forces
        # a sync only at logging points.
        loss = float(loss)
        losses.append(loss)
        print(step, loss)

In [None]:
import matplotlib.pyplot as plt

# The training loop records one loss every 1000 steps. Plot against the
# actual step number rather than the index into ``losses``; keep this
# interval in sync with the training cell.
steps = np.arange(1, len(losses) + 1) * 1000
plt.plot(steps, losses)
plt.xlabel("training step")
plt.ylabel("loss")
plt.show()

In [None]:
# Range (max - min) of the Euclidean row norms; 0 means an equal-norm frame.
norms = np.linalg.norm(a, axis=1)
np.ptp(norms)

In [None]:
# Shape that broadcasting a[:, None] against a[None] produces, computed without
# materialising the (n, n, d) difference array.
np.broadcast_shapes(a[:, None].shape, a[None].shape)

In [None]:
# sum_k (a_ik - a_jk) == s_i - s_j, where s holds the row sums of a. Building
# the (n, n) matrix from row sums avoids the (n, n, d) difference tensor
# (hundreds of MB for a 1000 x 1000 x 64 float array).
row_sums = np.sum(a, axis=-1)
row_sums[:, None] - row_sums[None]

In [None]:
# Find pairs of *distinct* indices (i < j) whose frame vectors coincide
# exactly. Rows are grouped by exact equality with np.unique. This avoids two
# problems with testing whether the coordinates of a_i - a_j sum to zero:
# false positives from different rows with equal coordinate sums, and the
# (n, n, d) temporary. k=1 drops the trivial i == j matches on the diagonal.
frame = np.asarray(a)
_, group = np.unique(frame, axis=0, return_inverse=True)
group = group.reshape(-1)  # defensive: NumPy 2.0.0 returned a 2-D inverse when axis is given
res = np.nonzero(np.triu(group[:, None] == group[None], k=1))
print("duplicate pairs:", len(res[0]))
res[0], res[1]

In [None]:
import pickle

# Persist the learned frame for TightFrameClassification.
# - The public ``pickle`` module already uses the C-accelerated ``_pickle``
#   implementation when it is available.
# - ``with`` guarantees the file is flushed and closed.
# - After training ``a`` may be a JAX array. It is saved as a plain NumPy
#   array so that loading the pickle does not depend on JAX's reconstruction
#   hooks.
with open("tight_frame.pkl", "wb") as f:
    pickle.dump(np.asarray(a), f)

In [None]:
# Shape of the learned frame: (num_vectors, dim).
np.shape(a)

In [None]:
from jaxl.datasets.tight_frame_classification import TightFrameClassification

In [None]:
# Build a TightFrameClassification dataset from the frame pickled above
# ("train" presumably selects the training split).
# NOTE(review): the meanings of the positional arguments 1000, 9, 10 and 1.0
# are defined by TightFrameClassification, which is not visible here. 1000
# presumably matches num_classes. Confirm against the class signature and
# consider passing these as keyword arguments.
dataset = TightFrameClassification(
    "tight_frame.pkl",
    1000,
    9,
    10,
    "train",
    1.0,
)

In [None]:
import numpy as np

# Position of the largest entry along the last axis of element 1 of the
# first example. Element 1 is presumably the target; if it is one-hot, this
# recovers the class index.
np.argmax(dataset[0][1], -1)

In [None]:
# Sum along the last axis of element 0 of the first example (presumably the
# input sequence).
np.sum(dataset[0][0], -1)