# Check Neural Collapse in Standard Training

In [None]:
from jaxl.constants import *
from jaxl.datasets.mnist import construct_mnist
from jaxl.datasets.wrappers import StandardSupervisedDataset
from jaxl.learning_utils import get_learner
from jaxl.models.common import get_activation
from jaxl.models.modules import CNNModule, MLPModule
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict, get_device

import _pickle as pickle
import argparse
import jax
import jax.random as jrandom
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision.datasets as torch_datasets

from collections import OrderedDict
from functools import partial
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager
from torch.utils.data import DataLoader
from types import SimpleNamespace

In [None]:
# Request the CPU device through jaxl's ``get_device`` helper
# (presumably this configures JAX's default platform — confirm in jaxl.utils).
device = "cpu"
get_device(device)

In [None]:
# Document width in points, passed to ``set_size`` when sizing figures.
doc_width_pt = 750.0

# Root of the local jaxl checkout. Set the JAXL_ROOT environment variable to run
# this notebook on another machine; the default reproduces the original path.
learner_path = os.path.join(
    os.environ.get("JAXL_ROOT", "/Users/chanb/research/personal/jaxl"),
    "jaxl/logs/nc-mnist/cnn-01-19-24_19_37_18-8c04f541-01bc-4346-be64-5581bceb4cc2",
)

In [None]:
def load_model(learner_path: str):
    """Load a trained run's config, model, and checkpointed parameters.

    Args:
        learner_path: Run directory containing ``config.json`` and a
            ``models`` checkpoint directory managed by orbax.

    Returns:
        ``(params, model, config, all_params)``:
        ``params`` is the parameter tree of the latest checkpoint;
        ``model`` is the learner's ``_model``;
        ``config`` is the parsed run configuration;
        ``all_params`` is a list of ``(step, params)`` for every checkpoint,
        in the order reported by ``CheckpointManager.all_steps()``.

    Raises:
        FileNotFoundError: if the checkpoint directory holds no checkpoints.
    """
    config_path = os.path.join(learner_path, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = parse_dict(config_dict)

    learner = get_learner(
        config.learner_config, config.model_config, config.optimizer_config
    )

    checkpoint_dir = os.path.join(learner_path, "models")
    checkpoint_manager = CheckpointManager(
        checkpoint_dir,
        PyTreeCheckpointer(),
    )

    latest_step = checkpoint_manager.latest_step()
    if latest_step is None:
        # Without this check we would call restore(None) and get an opaque error.
        raise FileNotFoundError(f"No checkpoints found under {checkpoint_dir}")

    # Restore every checkpoint exactly once; the latest one is reused from this
    # list instead of being restored a second time.
    all_params = [
        (step, checkpoint_manager.restore(step))
        for step in checkpoint_manager.all_steps()
    ]
    params_by_step = dict(all_params)
    if latest_step in params_by_step:
        params = params_by_step[latest_step]
    else:
        # Defensive fallback in case all_steps() and latest_step() disagree.
        params = checkpoint_manager.restore(latest_step)

    model = learner._model
    return params, model, config, all_params

In [None]:
# Restore the run: latest parameters, model, parsed config, and every checkpoint.
(params, model, config, all_params) = load_model(learner_path)

In [None]:
# Display the parsed run configuration as this cell's output.
config

In [None]:
# Wrap the MNIST train split (first) and test split (second), both read from
# the save path recorded in the run's dataset config.
train_dataset, test_dataset = (
    StandardSupervisedDataset(
        construct_mnist(
            config.learner_config.dataset_config.dataset_kwargs.save_path,
            train=is_train,
        )
    )
    for is_train in (True, False)
)

In [None]:
def get_latent(params, inputs, carries):
    """Return every latent sown by the CNN trunk and the MLP head.

    The CNN is applied to ``inputs``. Its output keeps the leading axis,
    is flattened over the remaining axes, and is then fed to the MLP. Both
    modules are rebuilt from the global ``config``.

    Args:
        params: Restored parameter tree. The CNN and MLP sub-trees are read
            from ``params[CONST_MODEL_DICT][CONST_MODEL]``.
        inputs: Batch of model inputs for the CNN.
        carries: Unused; accepted so the dataloader's batch tuple can be
            passed straight through.

    Returns:
        An ``OrderedDict`` mapping latent name to its sown value. Entries from
        the ``"cnn_latents"`` collection come first, then ``"mlp_latents"``;
        a repeated name in the latter overwrites the former.
    """
    model_params = params[CONST_MODEL_DICT][CONST_MODEL]

    cnn_module = CNNModule(
        config.model_config.features,
        config.model_config.kernel_sizes,
        get_activation(CONST_RELU),
    )
    cnn_out, cnn_state = cnn_module.apply(
        model_params[CONST_CNN],
        inputs,
        capture_intermediates=True,
        mutable=["cnn_latents"],
    )

    # Keep the leading axis, flatten the rest for the fully-connected head.
    flat_cnn_out = cnn_out.reshape((len(cnn_out), -1))

    mlp_module = MLPModule(
        config.model_config.layers,
        get_activation(CONST_RELU),
        get_activation(CONST_IDENTITY),
    )
    _, mlp_state = mlp_module.apply(
        model_params[CONST_MLP],
        flat_cnn_out,
        capture_intermediates=True,
        mutable=["mlp_latents"],
    )

    latents = OrderedDict()
    for state, collection_name in (
        (cnn_state, "cnn_latents"),
        (mlp_state, "mlp_latents"),
    ):
        latents.update(state[collection_name].items())
    return latents

In [None]:
def plot(mlp_1_latents, labels, step):
    """Scatter-plot the first two latent coordinates, one series per class.

    Args:
        mlp_1_latents: Array of latents; columns 0 and 1 give x and y.
        labels: Per-example labels; each value of ``np.unique(labels)`` becomes
            one scatter series with its own legend entry.
        step: Checkpoint step, shown in the figure title as the epoch.
    """
    grid_shape = (1, 1)
    fig, ax = plt.subplots(
        *grid_shape,
        figsize=set_size(doc_width_pt, 0.95, grid_shape, False),
        layout="constrained",
    )

    for class_label in np.unique(labels):
        members = np.where(labels == class_label)[0]
        ax.scatter(
            mlp_1_latents[members, 0],
            mlp_1_latents[members, 1],
            label="Class {}".format(class_label),
            alpha=0.5,
        )

    fig.legend()
    fig.suptitle("Model @ epoch {}".format(step))
    fig.show()

In [None]:
batch_size = 300
# Fixed iteration order (no shuffling); an incomplete final batch is dropped.
train_dataloader = DataLoader(
    train_dataset, batch_size, shuffle=False, drop_last=True
)

In [None]:
# For every checkpoint, collect the "mlp_1" latents over the whole training
# set and plot them (one figure per checkpoint).
# The loop variable is ``step_params`` rather than ``params`` so the latest
# parameters loaded by ``load_model`` are not silently overwritten.
for step, step_params in all_params:
    mlp_1_latents = []
    labels = []
    for inputs, carries, outputs, _ in train_dataloader:
        latents = get_latent(step_params, inputs, carries)
        # NOTE(review): sown values appear to be stored as a tuple; take the
        # first entry — confirm against how the MLP module sows "mlp_1".
        mlp_1_latents.append(latents["mlp_1"][0])
        labels.append(outputs)

    mlp_1_latents = np.concatenate(mlp_1_latents, axis=0)
    labels = np.concatenate(labels, axis=0)
    plot(mlp_1_latents, labels, step)