In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
from contextlib import nullcontext
from math import cos, sin

from codec import VectorQuantizer

channels = 2

# Fix the RNG so the point clouds, the quantizer's initialisation (if it draws
# from torch's RNG) and the per-step noise are reproducible across runs.
torch.manual_seed(0)

# Three 2-D point clouds: a tight cluster near (0.1, 0.2), a wider one near
# (-0.8, 0.9), and a unit Gaussian centred at the origin.
samples = torch.concat(
    [
        torch.randn(30, channels) * 0.1 + torch.tensor([0.1, 0.2]),
        torch.randn(30, channels) * 0.3 + torch.tensor([-0.8, 0.9]),
        torch.randn(30, channels),
    ]
)
# In this experiment the samples themselves are the parameters being optimised.
samples.requires_grad_(True)


angle = 0.0  # radians the samples are rotated by after every step
noise_added = 0.01  # std of the Gaussian noise added after every step
learning_rate = 0.5
codebook_update_speed = 0.05

# Loop-invariant, so build it once: (x, y) @ rotation equals
# (x*cos(a) - y*sin(a), x*sin(a) + y*cos(a)), a counterclockwise rotation by `angle`.
rotation = torch.tensor(
    [
        [cos(angle), sin(angle)],
        [-sin(angle), cos(angle)],
    ]
)

history = []
bottleneck = VectorQuantizer(
    channels=channels, codebook_size=4, codebook_update_speed=codebook_update_speed
)
optimizer = torch.optim.SGD([samples], lr=learning_rate)

for _ in range(1000):
    # Snapshot the state *before* this step's update.
    history.append(
        {
            "samples": samples.detach().clone().numpy(),
            "codebook": bottleneck.codebook().detach().clone().numpy(),
        }
    )
    optimizer.zero_grad()
    _codes, _quantized, commitment_loss = bottleneck(samples)
    commitment_loss.backward()
    optimizer.step()

    with torch.no_grad():
        # Tricky: We need samples[:] here because otherwise we create a different
        # tensor than the one registered with the optimizer, and it wouldn't know
        # how to update it.
        samples[:] = samples @ rotation
        samples += torch.randn_like(samples) * noise_added

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML

FPS = 30  # playback frame rate of the inline animation

fig, ax = plt.subplots()

embeddings_scatter = ax.scatter(
    history[0]["samples"][:, 0], history[0]["samples"][:, 1], label="samples"
)
codes_scatter = ax.scatter(
    history[0]["codebook"][:, 0], history[0]["codebook"][:, 1], label="codebook"
)
ax.legend()


def animate(history_step):
    """Move both scatters to the positions recorded in one `history` entry."""
    embeddings_scatter.set_offsets(history_step["samples"])
    codes_scatter.set_offsets(history_step["codebook"])
    return (embeddings_scatter, codes_scatter)


# `interval` is the delay between frames in *milliseconds*, and `to_jshtml()`
# derives its playback rate from it (fps = 1000 / interval). So 1000 / FPS
# plays at FPS frames per second; a value like 1 / 30 would request 30000 fps.
ani = animation.FuncAnimation(
    fig, animate, frames=history, interval=1000 / FPS, blit=True
)
plt.close(fig)  # Don't show the static figure, just the animation

HTML(ani.to_jshtml())

In [None]:
# TODO: also save this animation to a file (e.g. via ani.save(...) or ani.to_html5_video()).

In [None]:
import tqdm.auto

from codec import VectorQuantizer


def train_model(
    *,
    train_steps: int,
    mse_coef: float = 0.3,
    commitment_coef: float = 1.0,
    learning_rate: float = 0.8,
    codebook_size: int = 4,
    batch_size: int = 128,
    seed=None,
):
    """Train a small MLP whose output is vector-quantized, recording every step.

    Each step draws a batch of 2-D standard-normal points and passes them
    through the MLP and a ``VectorQuantizer``. The quantized output is then
    regressed onto the inputs projected onto the unit circle. The optimised
    loss is ``mse_coef * MSE + commitment_coef * commitment_loss``.

    Args:
        train_steps: Number of optimisation steps.
        mse_coef: Weight of the reconstruction (MSE) term.
        commitment_coef: Weight of the quantizer's commitment loss.
        learning_rate: SGD learning rate for the MLP parameters.
        codebook_size: Number of codebook vectors in the quantizer.
        batch_size: Number of points drawn per step.
        seed: If not None, passed to ``torch.manual_seed`` before anything
            else. This makes the MLP initialisation and the data batches
            reproducible, and also the quantizer's if it uses torch's RNG.

    Returns:
        A list with one dict per step, each recorded *before* that step's
        parameter update. Keys:
            "samples": MLP outputs as a numpy array of shape (batch_size, 2).
            "codebook": numpy copy of ``bottleneck.codebook()``.
            "mse_loss": weighted scalar MSE term.
            "commitment_loss": weighted scalar commitment term.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # The quantized latents are compared with 2-D targets derived from the 2-D
    # inputs, so the input, latent and target dimensions all equal `channels`.
    channels = 2
    hidden_dim = 8
    codebook_update_speed = 0.05

    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=channels, out_features=hidden_dim),
        torch.nn.GELU(),
        torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        torch.nn.GELU(),
        torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        torch.nn.GELU(),
        torch.nn.Linear(in_features=hidden_dim, out_features=channels),
    )

    history = []
    bottleneck = VectorQuantizer(
        channels=channels,
        codebook_size=codebook_size,
        codebook_update_speed=codebook_update_speed,
    )
    # Only the MLP's parameters go to SGD. The codebook is presumably updated
    # inside VectorQuantizer at `codebook_update_speed` (TODO: confirm in codec).
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for _ in tqdm.auto.trange(train_steps, desc="Training"):
        x = torch.randn(batch_size, channels)
        # Target: each input point projected onto the unit circle. normalize()
        # clamps the norm with a tiny epsilon, so a zero vector cannot yield NaN.
        y = torch.nn.functional.normalize(x, dim=1)
        z = model(x)
        optimizer.zero_grad()
        _codes, quantized, commitment_loss = bottleneck(z)

        weighted_mse = ((quantized - y) ** 2).mean() * mse_coef
        weighted_commitment = commitment_loss * commitment_coef

        history.append(
            {
                "samples": z.detach().clone().numpy(),
                "codebook": bottleneck.codebook().detach().clone().numpy(),
                "mse_loss": weighted_mse.item(),
                "commitment_loss": weighted_commitment.item(),
            }
        )
        (weighted_mse + weighted_commitment).backward()
        optimizer.step()

    return history

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML


def save_animation(
    filename: str,
    history: list[dict],
    downsampling: int = 1,
    *,
    limits_smoothing: float = 1.0,
):
    """Render a training ``history`` as a video and write it to ``filename``.

    Left panel: the recorded "samples" (semi-transparent) and "codebook"
    points of each kept step. Right panel: the "mse_loss" and
    "commitment_loss" curves, revealed up to the current frame.

    Args:
        filename: Output path. It is written with matplotlib's ffmpeg writer
            at 30 fps, so ffmpeg must be available. Missing parent
            directories are created.
        history: Per-step dicts, as recorded by ``train_model``. Each has
            "samples" and "codebook" arrays of (x, y) rows (used as scatter
            offsets) plus scalar "mse_loss" and "commitment_loss" entries.
        downsampling: Keep every ``downsampling``-th step, starting with the
            first one. Must be in ``[1, 30]``.
        limits_smoothing: Weight of the previous scatter-axis limits in an
            exponential moving average towards the current samples' bounding
            box. The default ``1.0`` keeps the axes fixed at [-1, 1].

    Raises:
        ValueError: If ``downsampling`` is outside ``[1, 30]`` or ``history``
            is empty.
    """
    import os

    fps = 30
    # Validate before creating a figure, so bad arguments don't leak one.
    if not 1 <= downsampling <= fps:
        raise ValueError(
            f"downsampling must be between 1 and {fps}, got {downsampling}"
        )
    if not history:
        raise ValueError("history is empty; nothing to animate")

    # Every `downsampling`-th step, starting at step 0. A filter such as
    # `i % downsampling == 1` would skip step 0 and keep nothing at all when
    # downsampling == 1.
    downsampled_history = history[::downsampling]
    n_frames = len(downsampled_history)
    # Build the loss curves once rather than re-collecting them on every frame.
    frame_numbers = np.arange(n_frames)
    mse_losses = np.array([record["mse_loss"] for record in downsampled_history])
    commitment_losses = np.array(
        [record["commitment_loss"] for record in downsampled_history]
    )

    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig, (scatter_ax, loss_ax) = plt.subplots(1, 2)
    fig.set_figwidth(15)
    try:
        first = downsampled_history[0]
        embeddings_scatter = scatter_ax.scatter(
            first["samples"][:, 0], first["samples"][:, 1], alpha=0.5
        )
        codes_scatter = scatter_ax.scatter(
            first["codebook"][:, 0], first["codebook"][:, 1]
        )
        (mse_line,) = loss_ax.plot([], [], lw=2, label="mse_loss")
        (commitment_line,) = loss_ax.plot([], [], lw=2, label="commitment_loss")
        loss_ax.legend()

        scatter_ax.set_xlim(-1, 1)
        scatter_ax.set_ylim(-1, 1)
        loss_ax.set_xlim(0, n_frames)
        loss_ax.set_ylim(0, 1)

        def animate(frame_index):
            # Drive the animation by frame *index*. A separate call counter
            # would also be bumped by FuncAnimation's initial draw and drift.
            record = downsampled_history[frame_index]
            embeddings_scatter.set_offsets(record["samples"])
            codes_scatter.set_offsets(record["codebook"])

            shown = frame_index + 1  # loss points up to and including this frame
            mse_line.set_data(frame_numbers[:shown], mse_losses[:shown])
            commitment_line.set_data(
                frame_numbers[:shown], commitment_losses[:shown]
            )

            # Exponential moving average of the scatter limits. With weight 1.0
            # the update is a no-op, so skip it. Skipping also avoids
            # 0 * nan = nan limits if the samples ever diverge.
            if limits_smoothing < 1.0:
                samples = record["samples"]
                previous = (*scatter_ax.get_xlim(), *scatter_ax.get_ylim())
                target = (
                    samples[:, 0].min(),
                    samples[:, 0].max(),
                    samples[:, 1].min(),
                    samples[:, 1].max(),
                )
                x_lo, x_hi, y_lo, y_hi = (
                    limits_smoothing * p + (1.0 - limits_smoothing) * t
                    for p, t in zip(previous, target)
                )
                scatter_ax.set_xlim(x_lo, x_hi)
                scatter_ax.set_ylim(y_lo, y_hi)

            # Return every artist that changes, as blitting requires.
            return (embeddings_scatter, codes_scatter, mse_line, commitment_line)

        ani = animation.FuncAnimation(
            fig, animate, frames=n_frames, interval=1000 / fps, blit=True
        )
        ani.save(filename, writer="ffmpeg", fps=fps)
    finally:
        # Close only after saving: matplotlib cannot save an animation of an
        # already-closed figure. Closing stops repeated calls from leaking
        # figures or rendering stray static plots inline.
        plt.close(fig)

In [None]:
defaults = {"train_steps": 1000}

# One entry per loss-weighting regime to visualise:
# (output video path, loss weights passed to train_model).
experiments = [
    ("for-video/vq_mse_wins.mp4", {"mse_coef": 0.3, "commitment_coef": 0.01}),
    ("for-video/vq_commitment_wins.mp4", {"mse_coef": 0.01, "commitment_coef": 1.0}),
    ("for-video/vq_balanced.mp4", {"mse_coef": 0.3, "commitment_coef": 1.0}),
]

for video_path, loss_weights in experiments:
    history = train_model(**loss_weights, learning_rate=0.8, **defaults)
    save_animation(video_path, history, downsampling=10)