In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
from contextlib import nullcontext
from math import cos, sin

from codec import VectorQuantizer

channels = 2

# Seed torch's global RNG so the sampled clusters, the per-step noise and anything
# else drawing from it give the same history on every Restart & Run All.
torch.manual_seed(0)

# Three synthetic 2-D clusters of 30 points each: a tight one, a wider one, and
# unit-variance background noise around the origin.
samples = torch.concat(
    [
        torch.randn(30, channels) * 0.1 + torch.Tensor([0.1, 0.2]),
        torch.randn(30, channels) * 0.3 + torch.Tensor([-0.8, 0.9]),
        torch.randn(30, channels),
    ]
)
# The samples themselves are the parameters being optimised.
samples.requires_grad_(True)


angle = 0.0
noise_added = 0.01
learning_rate = 0.5
codebook_update_speed = 0.05

# Rotation applied to the samples after every optimisation step. It depends only
# on `angle`, so it is built once instead of on each loop iteration.
rotation = torch.Tensor(
    [
        [cos(angle), sin(angle)],
        [-sin(angle), cos(angle)],
    ]
)

history = []
bottleneck = VectorQuantizer(
    channels=channels, codebook_size=4, codebook_update_speed=codebook_update_speed
)
optimizer = torch.optim.SGD([samples], lr=learning_rate)

for i in range(1000):
    # Snapshot the state *before* the update so frame 0 is the initial layout.
    history.append(
        {
            "samples": samples.detach().clone().numpy(),
            "codebook": bottleneck.codebook().detach().clone().numpy(),
        }
    )
    optimizer.zero_grad()
    _codes, _quantized, commitment_loss = bottleneck(samples)
    commitment_loss.backward()
    optimizer.step()

    with torch.no_grad():
        # Tricky: We need samples[:] here because otherwise we create a different
        # tensor than the one registered with the optimizer, and it wouldn't know
        # how to update it.
        samples[:] = samples @ rotation
        samples += torch.randn_like(samples) * noise_added

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML

fig, ax = plt.subplots()

# Draw the first recorded step; the animation callback moves these two artists.
# Transposing each array and unpacking it passes its two columns as x and y.
embeddings_scatter = ax.scatter(*history[0]["samples"][:, :2].T)
codes_scatter = ax.scatter(*history[0]["codebook"][:, :2].T)


def animate(history_step):
    """Move both scatter plots to the positions stored in one history entry.

    ``history_step`` is a dict whose "samples" and "codebook" values are passed
    to ``set_offsets`` on the matching scatter artist. Both artists are
    returned as a tuple.
    """
    for artist, key in (
        (embeddings_scatter, "samples"),
        (codes_scatter, "codebook"),
    ):
        artist.set_offsets(history_step[key])
    return (embeddings_scatter, codes_scatter)


# `interval` is the delay between frames in *milliseconds*, and `to_jshtml()` takes
# its playback rate from it when no fps is given. 30 fps is therefore 1000 / 30 ms
# per frame (1 / 30 would request roughly 30000 fps).
ani = animation.FuncAnimation(
    fig, animate, frames=history, interval=1000 / 30, blit=True
)
plt.close(fig)  # Don't show fig, just the animation


HTML(ani.to_jshtml())

Fancier animation: latent scatter plot next to the loss curves, saved as a video


In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML


def save_animation(
    filename: str,
    history: list[dict],
    n_frames: int = 200,
    show_quantization: bool = True,
):
    """Render a training history as a two-panel video and write it to ``filename``.

    The left panel is a scatter plot of the samples (coloured by label) with the
    codebook entries drawn as gray squares. The right panel shows the MSE and
    commitment loss curves growing frame by frame.

    Args:
        filename: Output path passed to ``Animation.save`` with the ffmpeg
            writer. Missing parent directories are created.
        history: One dict per training step with the keys "samples" and
            "codebook" (2-D arrays whose first two columns are plotted as x/y),
            "labels", "mse_loss" and "commitment_loss".
        n_frames: Number of frames, picked evenly spaced from ``history``.
        show_quantization: When False the codebook markers are drawn fully
            transparent.
    """
    from pathlib import Path

    fig, axes = plt.subplots(1, 2)
    fig.set_figwidth(11)

    scatter_ax = axes[0]
    loss_ax = axes[1]

    embeddings_scatter = scatter_ax.scatter(
        history[0]["samples"][:, 0], history[0]["samples"][:, 1], alpha=1.0
    )
    codes_scatter = scatter_ax.scatter(
        history[0]["codebook"][:, 0],
        history[0]["codebook"][:, 1],
        alpha=1.0 if show_quantization else 0.0,
        s=75,
        marker="s",
        color="gray",
    )
    (mse_line,) = loss_ax.plot([], [], lw=2)
    (commitment_line,) = loss_ax.plot([], [], lw=2)

    scatter_ax.set_xlim(-1.3, 1.3)
    scatter_ax.set_ylim(-1.3, 1.3)

    downsampled_history = [
        history[int(round(i))] for i in np.linspace(0, len(history) - 1, n_frames)
    ]

    # Build the loss series once instead of re-slicing the history on every frame.
    frame_numbers = np.arange(len(downsampled_history))
    mse_losses = np.array([hs["mse_loss"] for hs in downsampled_history])
    commitment_losses = np.array(
        [hs["commitment_loss"] for hs in downsampled_history]
    )

    loss_ax.set_xlim(0, len(downsampled_history))
    loss_ax.set_ylim(0, 0.15)

    def animate(frame_index):
        # Frames are addressed by index rather than by counting calls:
        # FuncAnimation also invokes this callback for its initial draw when no
        # init_func is given, so a call counter runs ahead of the frame shown.
        history_step = downsampled_history[frame_index]
        embeddings_scatter.set_offsets(history_step["samples"])
        codes_scatter.set_offsets(history_step["codebook"])
        embeddings_scatter.set_alpha(0.5)

        embeddings_scatter.set_array(history_step["labels"])
        embeddings_scatter.set_cmap("tab10")

        # Show the losses up to and including the current frame.
        shown = frame_index + 1
        mse_line.set_data(frame_numbers[:shown], mse_losses[:shown])
        commitment_line.set_data(frame_numbers[:shown], commitment_losses[:shown])

        # exponential moving average of the limits
        previous = [*scatter_ax.get_xlim(), *scatter_ax.get_ylim()]
        update = [
            history_step["samples"][:, 0].min(),
            history_step["samples"][:, 0].max(),
            history_step["samples"][:, 1].min(),
            history_step["samples"][:, 1].max(),
        ]
        alpha = 0.9
        updated = [alpha * p + (1 - alpha) * u for p, u in zip(previous, update)]
        scatter_ax.set_xlim(updated[0], updated[1])
        scatter_ax.set_ylim(updated[2], updated[3])

        # Return every artist that changed, including the two loss lines.
        return (embeddings_scatter, codes_scatter, mse_line, commitment_line)

    ani = animation.FuncAnimation(
        fig,
        animate,
        frames=range(len(downsampled_history)),
        interval=1000 / 30,  # milliseconds per frame, i.e. 30 fps
        blit=True,
    )
    # plt.close(fig)  # Don't show fig

    # Create the target directory up front so the save does not fail on a
    # path that does not exist yet.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    ani.save(filename, writer="ffmpeg", fps=30)

In [None]:
import itertools
import tqdm.auto

from codec import VectorQuantizer
import torchvision
from torchvision.transforms import ToTensor


def make_architecture(in_features: int, latent_dim: int, out_features: int):
    """Build a four-layer MLP with a GELU between consecutive linear layers.

    Layer widths run in_features -> latent_dim -> latent_dim -> latent_dim ->
    out_features. No activation follows the final linear layer.
    """
    widths = [in_features, latent_dim, latent_dim, latent_dim, out_features]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        if layers:
            # Activation sits between linear layers, never before the first one.
            layers.append(torch.nn.GELU())
        layers.append(torch.nn.Linear(in_features=fan_in, out_features=fan_out))
    return torch.nn.Sequential(*layers)


def train_model(
    *,
    train_steps: int,
    mse_coef=0.3,
    commitment_coef=1.0,
    learning_rate: float = 0.8,
    device=None,
):
    """Train a small vector-quantized autoencoder on three FashionMNIST classes.

    The encoder maps flattened 28x28 images to a 2-channel latent, which goes
    through a ``VectorQuantizer`` bottleneck with 36 codes and is decoded back
    to pixels. Only the encoder and decoder parameters are given to the
    optimizer.

    Args:
        train_steps: Number of optimisation steps (batches) to run.
        mse_coef: Weight of the reconstruction (MSE) loss.
        commitment_coef: Weight of the commitment loss. When it is 0 the decoder
            is fed the unquantized latent instead of the quantized one.
        learning_rate: AdamW learning rate.
        device: Device to train on. Defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        A list with one dict per step holding "samples" (latents), "labels",
        "codebook", "mse_loss" and "commitment_loss" (both losses already
        multiplied by their coefficient), all as NumPy arrays / Python floats.
    """
    from torch.utils.data import Subset

    channels = 2
    latent_dim = 16
    codebook_update_speed = 0.05
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    history = []

    dataset = torchvision.datasets.FashionMNIST(
        "./data/fashion_mnist/", download=True, transform=ToTensor()
    )
    # Filter to only first 3 classes. Read the labels from `dataset.targets`
    # rather than enumerating the dataset, which would decode and transform
    # every image just to look at its label.
    keep = torch.isin(dataset.targets, torch.tensor([0, 1, 2]))
    idx = keep.nonzero(as_tuple=True)[0].tolist()

    filtered_dataset = Subset(dataset, idx)
    dataloader = torch.utils.data.DataLoader(
        filtered_dataset, batch_size=256, shuffle=True
    )

    torch.manual_seed(123)

    # Modules are created directly on the target device.
    with torch.device(device):
        encoder = make_architecture(28 * 28, latent_dim, channels)
        decoder = make_architecture(channels, latent_dim, 28 * 28)

        bottleneck = VectorQuantizer(
            channels=channels,
            codebook_size=36,
            codebook_update_speed=codebook_update_speed,
        )

    # AdamW works better than SGD even for this simple problem
    optimizer = torch.optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=learning_rate,
    )

    def endless_batches():
        # Start a fresh pass over the DataLoader for every epoch so it reshuffles.
        # itertools.cycle would cache the first epoch's batches and replay them
        # in the same order, while also keeping all of them in memory.
        while True:
            yield from dataloader

    for _step, (batch, batch_y) in (
        pbar := tqdm.auto.tqdm(
            zip(range(train_steps), endless_batches()),
            desc="Training",
            total=train_steps,
        )
    ):
        optimizer.zero_grad()

        x = batch.view(batch.size(0), -1).to(device)
        z = encoder(x)

        bottleneck.restart_unused_codes(z)
        _codes, z_quantized, commitment_loss = bottleneck(z)

        if commitment_coef == 0:
            # Quantization disabled: decode straight from the encoder output.
            z_quantized = z

        x_reconstructed = decoder(z_quantized)

        mse_loss = ((x_reconstructed - x) ** 2).mean()

        history.append(
            {
                "samples": z.detach().clone().to("cpu").numpy(),
                "labels": batch_y.detach().clone().to("cpu").numpy(),
                "codebook": bottleneck.codebook().detach().clone().to("cpu").numpy(),
                "mse_loss": (mse_loss * mse_coef).detach().item(),
                "commitment_loss": (commitment_loss * commitment_coef).detach().item(),
            }
        )
        loss = mse_loss * mse_coef + commitment_loss * commitment_coef
        loss.backward()
        pbar.set_postfix({"loss": loss.item()})

        optimizer.step()

    # With train_steps == 0 there is no final step to report.
    if history:
        print(
            f"Final step: {history[-1]['mse_loss']=} {history[-1]['commitment_loss']=}"
        )

    return history

In [None]:
# Baseline run: reconstruction loss only (commitment_coef=0.0), rendered with
# the codebook markers hidden.
history = train_model(
    train_steps=2500, mse_coef=1.0, commitment_coef=0.0, learning_rate=0.001
)
save_animation(
    filename="for-video/vq_unquantized.mp4",
    history=history,
    show_quantization=False,
)

# history = train_model(
#     train_steps=100000, mse_coef=1.0, commitment_coef=3.0, learning_rate=0.001
# )
# save_animation("for-video/vq_balanced_short.mp4", history, n_frames=250)


# history = train_model(
#     train_steps=250, mse_coef=1.0, commitment_coef=1.0, learning_rate=0.001
# )
# save_animation("for-video/vq_balanced.mp4", history)

# history = train_model(
#     train_steps=2500,
#     mse_coef=1.0,
#     commitment_coef=0.0,
#     learning_rate=0.001,
# )
# save_animation(
#     "for-video/vq_unquantized.mp4", history, n_frames=300, show_quantization=False
# )

# history = train_model(mse_coef=0.01, commitment_coef=1.0, learning_rate=0.8, **defaults)
# save_animation("for-video/vq_commitment_wins.mp4", history, downsampling=10)


In [None]:
dataset = torchvision.datasets.FashionMNIST(
    "./data/fashion_mnist/", download=True, transform=ToTensor()
)
# Filter to only first 3 classes. The labels are read from `dataset.targets`
# rather than by enumerating the dataset, which would decode and transform
# every image just to look at its label.
idx = (
    torch.isin(dataset.targets, torch.tensor([0, 1, 2]))
    .nonzero(as_tuple=True)[0]
    .tolist()
)
from torch.utils.data import Subset

filtered_dataset = Subset(dataset, idx)
dataloader = torch.utils.data.DataLoader(filtered_dataset, batch_size=256, shuffle=True)

In [None]:
# Peek at one batch: show the tensor shapes and a few labels instead of
# dumping the pixel values of a whole 256-image batch into the cell output.
sample_images, sample_labels = next(iter(dataloader))
sample_images.shape, sample_labels[:10]

In [None]:
import datasets

In [None]:
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's
# own loading script run on this machine — confirm the "jg583/NSynth" repo is
# trusted (and ideally pin a revision) before running this cell.
ds = datasets.load_dataset("jg583/NSynth", trust_remote_code=True)