In [None]:
import matplotlib.pyplot as plt
import numpy as np

import wandb

In [None]:
# Placeholder: replace with your own W&B username. It is used as the prefix of
# the project path ("<wandb_usr>/collapse_dynamics") queried in the fetch cell.
wandb_usr = "your_wandb_username"

### Fetch data from wandb

In [None]:
api = wandb.Api()

project = f"{wandb_usr}/collapse_dynamics"
weight_decays = [3e-1, 1e-1, 3e-2, 1e-2, 3e-3, 1e-3]
seeds = [42, 43, 44, 45, 46]

# History columns pulled from every run. Declared once so the per-run and the
# aggregated containers can never drift apart.
metric_keys = [
    "train_accuracy",
    "test_accuracy",
    "nc1_score",
    "nc2_score",
    "within_class_variance",
    "scale_means",
    "mi_zx_compression",
    "mi_zy_compression",
    "nhsic_zx",
    "nhsic_zy",
    "time_step",
]

# all_metrics[key][i] is a stacked array (one row per usable run) for
# weight_decays[i].
all_metrics = {key: [] for key in metric_keys}

for weight_decay in weight_decays:
    runs = api.runs(
        path=project,
        filters={
            "state": "finished",
            "config.wandb.job_type": "change_weight_decay",
            "config.optimizer.weight_decay": weight_decay,
            "config.model.name": "toy_mlp",
            "config.calc_nhsic": True,
        },
        order="config.seed",
    )

    # Index the result set by seed in a single pass, keeping the first run seen
    # for each seed (same choice as scanning the list once per seed).
    runs_by_seed = {}
    for run in runs:
        runs_by_seed.setdefault(run.config.get("seed"), run)

    run_metrics = {key: [] for key in metric_keys}
    for seed in seeds:
        if seed not in runs_by_seed:
            raise ValueError(
                f"Run with seed {seed} not found for weight decay {weight_decay}",
            )
        df = runs_by_seed[seed].history(pandas=True)

        # Check for missing columns *before* indexing so the error names the run.
        missing = [key for key in metric_keys if key not in df]
        if missing:
            raise KeyError(
                f"Weight decay {weight_decay} + seed {seed} is missing {missing}",
            )

        columns = {key: df[key].to_numpy() for key in metric_keys}

        # A run is either used for every metric or for none. Appending only the
        # columns seen before the first "Infinity" would leave the metrics with
        # different sets of seeds and break element-wise arithmetic later on.
        if any(np.any(values == "Infinity") for values in columns.values()):
            print(
                f"Weight decay {weight_decay} + seed {seed} contains infinity",
            )
            continue

        for key, values in columns.items():
            run_metrics[key].append(values)

    if not run_metrics["time_step"]:
        raise ValueError(f"No usable runs for weight decay {weight_decay}")

    # Align by time_step (assume all runs have the same time steps): truncate
    # every run of this weight decay to the shortest history so they stack.
    min_len = min(len(arr) for arr in run_metrics["time_step"])
    for key in metric_keys:
        all_metrics[key].append(
            np.stack([arr[:min_len] for arr in run_metrics[key]]),
        )

### Plot figures
#### Figure showing grokking behavior with different weight decay

In [None]:
# One color per weight decay. tab10 only has 10 entries, so cycle through it
# rather than slicing: a short color list would make the zip() in the plotting
# loop silently drop the remaining weight decays.
tab10 = plt.get_cmap("tab10").colors
colors = [tab10[idx % len(tab10)] for idx in range(len(weight_decays))]


def first_cross(x, y, thr):
    """Return the x value at the first index where y reaches or exceeds thr.

    Args:
        x: Sequence of positions, indexed in parallel with ``y``.
        y: Array-like of values compared against ``thr``.
        thr: Threshold value.

    Returns:
        ``x[i]`` for the smallest ``i`` with ``y[i] >= thr``, or ``np.nan`` when
        ``y`` never reaches ``thr`` (including when ``y`` is empty).
    """
    reached = np.asarray(y) >= thr
    if not reached.any():
        # Covers both "never crosses" and empty input, where argmax would raise.
        return np.nan
    return x[int(np.argmax(reached))]


# Plot the across-seed mean of each metric, one curve per weight decay.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)
for i, (weight_decay, color) in enumerate(
    zip(weight_decays, colors, strict=False),
):
    # Use this weight decay's own time axis (identical across its seeds). Each
    # weight decay is truncated to its own minimum history length when the data
    # is fetched, so borrowing the x values of another weight decay could
    # produce an x/y length mismatch.
    x = all_metrics["time_step"][i][0]
    curve_label = f"$\\lambda = {weight_decay}$"

    # Train accuracy: faint dashed curve drawn underneath, no legend entry.
    train_mean = all_metrics["train_accuracy"][i].mean(axis=0)
    axs[0].plot(x, train_mean, color=color, linestyle="--", alpha=0.5, zorder=0)

    # Test accuracy: solid curve, carries the legend entry.
    test_mean = all_metrics["test_accuracy"][i].mean(axis=0)
    axs[0].plot(x, test_mean, label=curve_label, color=color)

    # Within-class variance divided by the squared mean scale; the 1e-10 term
    # guards against division by zero.
    rnc1 = all_metrics["within_class_variance"][i] / (
        all_metrics["scale_means"][i] ** 2 + 1e-10
    )
    axs[1].plot(x, rnc1.mean(axis=0), label=curve_label, color=color)

    nc2_mean = all_metrics["nc2_score"][i].mean(axis=0)
    axs[2].plot(x, nc2_mean, label=curve_label, color=color)

label_size = 13
for ax in axs:
    ax.set_xlabel("Time step", fontsize=label_size)
    ax.set_xlim(10, None)
    ax.set_xscale("log")
    ax.grid()

title_size = 15
axs[0].set_title("Test Accuracy", fontsize=title_size)
axs[1].set_title("RNC1 Score", fontsize=title_size)
axs[2].set_title("NC2 Score", fontsize=title_size)

axs[0].legend(fontsize=11, loc="upper left")
axs[1].legend(fontsize=11, loc="lower left")
axs[2].legend(fontsize=11, loc="lower left")

axs[2].set_ylim(0, None)
# Replace the tick at 0 with a tick at 1 on the NC2 panel.
axs[2].set_yticks([1 if tick == 0 else tick for tick in axs[2].get_yticks()])

plt.tight_layout()
plt.show()

#### Figure showing IB dynamics with different weight decay

In [None]:
# One color per weight decay. tab10 only has 10 entries, so cycle through it
# rather than slicing: a short color list would make the zip() in the plotting
# loop silently drop the remaining weight decays.
tab10 = plt.get_cmap("tab10").colors
colors = [tab10[idx % len(tab10)] for idx in range(len(weight_decays))]


def first_cross(x, y, thr):
    """Return the x value at the first index where y reaches or exceeds thr.

    Note: this redefines ``first_cross`` from the earlier plotting cell; once
    this cell has run, this later definition is the one in effect.

    Args:
        x: Sequence of positions, indexed in parallel with ``y``.
        y: Array-like of values compared against ``thr``.
        thr: Threshold value.

    Returns:
        ``x[i]`` for the smallest ``i`` with ``y[i] >= thr``, or ``np.nan`` when
        ``y`` never reaches ``thr`` (including when ``y`` is empty).
    """
    reached = np.asarray(y) >= thr
    if not reached.any():
        # Covers both "never crosses" and empty input, where argmax would raise.
        return np.nan
    return x[int(np.argmax(reached))]


# Plot the across-seed mean of each quantity, one curve per weight decay.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)
for i, (weight_decay, color) in enumerate(
    zip(weight_decays, colors, strict=False),
):
    # Use this weight decay's own time axis (identical across its seeds). Each
    # weight decay is truncated to its own minimum history length when the data
    # is fetched, so borrowing the x values of another weight decay could
    # produce an x/y length mismatch.
    x = all_metrics["time_step"][i][0]
    curve_label = f"$\\lambda = {weight_decay}$"

    # Difference of the two logged MI columns (Z;X minus Z;Y).
    mi_gap = all_metrics["mi_zx_compression"][i] - all_metrics["mi_zy_compression"][i]
    axs[0].plot(x, mi_gap.mean(axis=0), label=curve_label, color=color)

    # Same difference for the nHSIC columns.
    nhsic_gap = all_metrics["nhsic_zx"][i] - all_metrics["nhsic_zy"][i]
    axs[1].plot(x, nhsic_gap.mean(axis=0), label=curve_label, color=color)

    # Within-class variance divided by the squared mean scale; the 1e-10 term
    # guards against division by zero.
    rnc1 = all_metrics["within_class_variance"][i] / (
        all_metrics["scale_means"][i] ** 2 + 1e-10
    )
    axs[2].plot(x, rnc1.mean(axis=0), label=curve_label, color=color)

label_size = 13
for ax in axs:
    ax.set_xlabel("Time step", fontsize=label_size)
    ax.set_xlim(10, None)
    ax.set_xscale("log")
    ax.grid()

title_size = 15
axs[0].set_title("$\\hat{I}(Z;X) - \\hat{I}(Z;Y)$", fontsize=title_size)
axs[1].set_title("$\\text{nHSIC}(Z;X) - \\text{nHSIC}(Z;Y)$", fontsize=title_size)
axs[2].set_title("RNC1 Score", fontsize=title_size)

axs[0].legend(fontsize=11, loc="lower left")
axs[1].legend(fontsize=11, loc="lower left")
axs[2].legend(fontsize=11, loc="lower left")

plt.tight_layout()
plt.show()

### Figure showing margins of training samples

In [None]:
import pandas as pd
import seaborn as sns
import torch
from torch import Tensor, nn

In [None]:
@torch.no_grad()
def get_features(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
) -> tuple[Tensor, Tensor]:
    """Collect the model's representations and the labels for a whole loader.

    The model is switched to eval mode (and left there) and is called as
    ``model(data, return_repr=True)``; the ``representation`` attribute of each
    result is gathered.

    Args:
        model: Module whose forward accepts ``return_repr=True`` and returns an
            object exposing ``.representation``.
        loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        ``(features, labels)`` concatenated along dim 0, both on the CPU so they
        can be combined with classifier weights that were moved to the CPU.
    """
    model.eval()
    # Feed each batch on the device the model lives on; a model without
    # parameters gives no device to infer, so its inputs are passed unchanged.
    first_param = next(model.parameters(), None)
    feature_batches, label_batches = [], []
    for data, target in loader:
        inputs = data if first_param is None else data.to(first_param.device)
        forward_result = model(inputs, return_repr=True)
        feature_batches.append(forward_result.representation.cpu())
        label_batches.append(target.cpu())
    features = torch.cat(feature_batches, dim=0)
    labels = torch.cat(label_batches, dim=0)
    return features, labels


def get_classifier_weights_and_biases(model: nn.Module) -> tuple[Tensor, Tensor]:
    """Return detached CPU copies of the final linear layer's weight and bias.

    Args:
        model: Module exposing its classifier as ``model.last_layer``.

    Returns:
        ``(w, b)`` where ``w`` is ``last_layer.weight`` and ``b`` is
        ``last_layer.bias``. When the layer was built with ``bias=False`` a
        zero vector with one entry per output row is returned instead.

    Raises:
        TypeError: If ``model.last_layer`` is not an ``nn.Linear``.
    """
    classifier = model.last_layer
    if not isinstance(classifier, nn.Linear):
        msg = "The last layer of the model must be a linear layer."
        raise TypeError(msg)
    w = classifier.weight.detach().cpu()
    if classifier.bias is None:
        # nn.Linear(bias=False) stores None; a zero bias is the equivalent value.
        b = torch.zeros(w.shape[0], dtype=w.dtype)
    else:
        b = classifier.bias.detach().cpu()
    return w, b


def get_orthogonal_basis_from_weights(
    weight: Tensor,
    class_1: int,
    class_2: int,
) -> Tensor:
    """Get an orthogonal basis from the weights of two classes.

    The first vector is the unit direction of ``w1 - w2``; the second is the
    normalized component of ``w2`` orthogonal to the first (one Gram-Schmidt
    step).

    Args:
        weight: Classifier weight matrix with one row per class.
        class_1: Row index of the first class.
        class_2: Row index of the second class.

    Returns: Tensor of shape (2, D) where D is the dimension of the weights.

    Raises:
        ValueError: If an index is out of range, if the two weight rows coincide
            (which includes ``class_1 == class_2``), or if ``w2`` has no
            component orthogonal to ``w1 - w2``.
    """
    eps = 1e-12
    n_classes = weight.shape[0]
    if not (0 <= class_1 < n_classes and 0 <= class_2 < n_classes):
        msg = "class_1 and class_2 must be valid class indices."
        raise ValueError(msg)
    w1 = weight[class_1]
    w2 = weight[class_2]

    diff = w1 - w2
    diff_norm = diff.norm()
    if diff_norm < eps:
        # Without this guard e1 would be the zero vector and the result would
        # not be a basis at all.
        msg = "The weights of class_1 and class_2 must differ."
        raise ValueError(msg)
    e1 = diff / (diff_norm + eps)

    w2_orth = w2 - e1 * (w2 @ e1)
    w2_orth_norm = w2_orth.norm()
    if w2_orth_norm < eps:
        msg = "The two classes are not linearly separable."
        raise ValueError(msg)
    e2 = w2_orth / (w2_orth_norm + eps)
    return torch.stack([e1, e2], dim=0)


def plot_decision_boundary_with_two_classes(
    ax: plt.Axes,
    w: Tensor,
    b: Tensor,
    class_1: int,
    class_2: int,
) -> None:
    """Draw the class_1 / class_2 decision boundary as a vertical line on ``ax``.

    The two logits are equal where ``f @ (w1 - w2) + (b1 - b2) == 0``, i.e. where
    the projection of ``f`` onto the unit vector along ``w1 - w2`` equals
    ``(b2 - b1) / ||w1 - w2||``. The line is drawn at that x position and spans
    the axes' current y-limits.

    Args:
        ax: Axes to draw on.
        w: Classifier weight matrix, indexed by class.
        b: Classifier bias vector, indexed by class.
        class_1: Index of the first class.
        class_2: Index of the second class.
    """
    w1, w2 = w[class_1], w[class_2]
    b1, b2 = b[class_1], b[class_2]

    # Convert to a plain float up front so plotting does not rely on implicit
    # torch -> numpy coercion (which fails for grad-tracking or non-CPU tensors).
    x = float((b2 - b1) / ((w1 - w2).norm() + 1e-12))
    y_low, y_high = ax.get_ylim()
    ax.plot([x, x], [y_low, y_high])


def plot_violin_with_two_classes(  # noqa: PLR0913
    model: nn.Module,
    features: Tensor,
    labels: Tensor,
    ax: plt.Axes,
    class_1: int,
    class_2: int,
    label_size: int = 12,
) -> None:
    """Draw a split violin of signed distances to the class_1/class_2 boundary.

    For every sample labelled ``class_1`` or ``class_2`` the signed distance of
    its feature vector to the classifier's pairwise decision boundary is
    computed (positive on the ``class_1`` side) and shown as one half of a split
    violin per class. Samples of any other class are left out.

    Args:
        model: Model whose ``last_layer`` is the linear classifier.
        features: Feature vectors, one row per sample.
        labels: Class label of each sample.
        ax: Axes to draw on.
        class_1: Index of the first class.
        class_2: Index of the second class.
        label_size: Font size of the axis labels.
    """
    w, b = get_classifier_weights_and_biases(model)
    w1, w2 = w[class_1], w[class_2]
    b1, b2 = b[class_1], b[class_2]

    # The classifier weights live on the CPU, so bring the inputs there too.
    features = features.detach().cpu()
    labels = labels.detach().cpu()

    # logit_1 - logit_2 = f @ (w1 - w2) + (b1 - b2); dividing by ||w1 - w2||
    # turns that margin into a signed distance to the boundary.
    scores = (features @ (w1 - w2) + (b1 - b2)) / (torch.norm(w1 - w2) + 1e-12)

    # Keep only the two classes being compared, and hand pandas plain numpy
    # arrays rather than tensors.
    keep = (labels == class_1) | (labels == class_2)
    class_names = {class_1: f"class {class_1}", class_2: f"class {class_2}"}
    df = pd.DataFrame(
        {
            "score": scores[keep].numpy(),
            "class": pd.Series(labels[keep].numpy()).map(class_names),
            "axis": "baseline",
        },
    )
    sns.violinplot(
        x="score",
        y="axis",
        hue="class",
        hue_order=[f"class {class_1}", f"class {class_2}"],
        data=df,
        ax=ax,
        split=True,
        density_norm="area",
        inner="quartile",
        palette=sns.color_palette("tab10", n_colors=2),
        bw_adjust=0.8,
        linewidth=0.8,
        common_norm=True,
    )
    ax.set_xbound(-2.2, 2.2)
    # The decision boundary itself sits at signed distance 0.
    ax.axvline(0, color="black", linestyle="-")
    ax.set_xlabel(
        rf"Signed distance to decision boundary (class {class_1} vs. class {class_2})",
        fontsize=label_size,
    )
    ax.set_ylabel("Density", fontsize=label_size)
    ax.set_yticklabels([])
    ax.legend()
    ax.grid(True)

In [None]:
# Prepare data loader

import torchvision
from torch.utils.data import ConcatDataset, DataLoader, Subset

# Number of leading MNIST training images used as the training subset.
train_size = 1000

image_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ],
)
full_train_dataset = torchvision.datasets.MNIST(
    root="~/pytorch_datasets",
    train=True,
    transform=image_transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root="~/pytorch_datasets",
    train=False,
    transform=image_transform,
    download=True,
)

# Both subsets must be cut from the *full* training split. Slicing the already
# truncated subset would give range(train_size, train_size), i.e. an empty
# "unseen" part.
train_dataset = Subset(full_train_dataset, range(train_size))
train_unseen_dataset = Subset(
    full_train_dataset,
    range(train_size, len(full_train_dataset)),
)
# "Unseen" = the rest of the training split plus the whole test split.
combined_dataset = ConcatDataset([train_unseen_dataset, test_dataset])

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)
unseen_loader = DataLoader(combined_dataset, batch_size=128, shuffle=False)

In [None]:
from src.models.toy_mlp import MLPModel

from pathlib import Path

# torch.load opens the path exactly as given and does not expand "~", so
# resolve the home directory explicitly.
checkpoint_path = Path(
    "~/collapse-dynamics/saved_models/toy_mlp/mnist/model_step_99900.pt",
).expanduser()

model = MLPModel()
# map_location="cpu" lets a checkpoint saved on another device load here; the
# data loaders feed CPU tensors.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

f_train, l_train = get_features(model, train_loader)
f_unseen, l_unseen = get_features(model, unseen_loader)

class_1, class_2 = 0, 1  # Example classes to visualize
fig, axs = plt.subplots(1, 2, figsize=(12, 2.5), sharex=True)
plot_violin_with_two_classes(model, f_train, l_train, axs[0], class_1, class_2)
plot_violin_with_two_classes(model, f_unseen, l_unseen, axs[1], class_1, class_2)

axs[0].set_title("Train Examples", fontsize=15)
axs[1].set_title("Unseen Examples", fontsize=15)
fig.tight_layout()