In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
from contextlib import nullcontext
from math import cos, sin

from codec import Bottleneck

# --- Experiment configuration -------------------------------------------------
seed = 0  # fixed so that Restart & Run All reproduces the same trajectory
channels = 2
samples_per_cluster = 30
num_steps = 1000

angle = 0.01  # rotation applied to all samples after every step, in radians
noise_added = 0.01  # scale of the Gaussian jitter added after every step
learning_rate = 1.0
codebook_update_speed = 0.05

# Seed before the first random draw so the samples (and anything random that
# Bottleneck draws from torch's global generator) are reproducible.
torch.manual_seed(seed)

# Three groups of points: two tight clusters with different centers and spreads,
# plus one unit-variance group around the origin.
samples = torch.concat(
    [
        torch.randn(samples_per_cluster, channels) * 0.1 + torch.tensor([0.1, 0.2]),
        torch.randn(samples_per_cluster, channels) * 0.3 + torch.tensor([-0.8, 0.9]),
        torch.randn(samples_per_cluster, channels),
    ]
)
# The samples themselves are the parameters being optimized.
samples.requires_grad_(True)

history = []
bottleneck = Bottleneck(
    channels=channels, codebook_size=4, codebook_update_speed=codebook_update_speed
)
optimizer = torch.optim.SGD([samples], lr=learning_rate)

# The rotation is the same on every step, so build the matrix once.
# Samples are row vectors, hence `samples @ rotation`.
rotation = torch.tensor(
    [
        [cos(angle), sin(angle)],
        [-sin(angle), cos(angle)],
    ]
)

for i in range(num_steps):
    # Snapshot the state *before* this step's update; the animation cell replays
    # these snapshots in order.
    history.append(
        {
            "samples": samples.detach().clone().numpy(),
            "codebook": bottleneck.codebook().detach().clone().numpy(),
        }
    )
    optimizer.zero_grad()
    # `q` is not used in this cell; only the commitment loss drives the update.
    # NOTE(review): presumably the forward pass also moves the codebook at
    # `codebook_update_speed` -- confirm against codec.Bottleneck.
    q, commitment_loss = bottleneck(samples)
    commitment_loss.backward()
    optimizer.step()

    with torch.no_grad():
        # Tricky: We need samples[:] here because otherwise we create a different
        # tensor than the one registered with the optimizer, and it wouldn't know
        # how to update it.
        samples[:] = samples @ rotation
        samples += torch.randn_like(samples) * noise_added

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML

fig, ax = plt.subplots()

# Both scatter artists start at the first recorded step; `animate` later moves
# their points with `set_offsets`. Only the first two columns are plotted.
first_step = history[0]
embeddings_scatter = ax.scatter(
    first_step["samples"][:, 0], first_step["samples"][:, 1], label="samples"
)
codes_scatter = ax.scatter(
    first_step["codebook"][:, 0], first_step["codebook"][:, 1], label="codebook"
)

# Label the figure so the two series can be told apart without reading the code.
ax.set(
    title="Samples and codebook entries over time",
    xlabel="channel 0",
    ylabel="channel 1",
)
# A fixed location keeps the legend from being re-placed as the points move.
ax.legend(loc="upper right")


def animate(history_step):
    """Move both scatter plots to the positions stored in one history entry.

    Args:
        history_step: Mapping with "samples" and "codebook" entries; each value
            is handed unchanged to the matching artist's ``set_offsets``.

    Returns:
        Tuple ``(embeddings_scatter, codes_scatter)`` of the updated artists.
    """
    artists = (embeddings_scatter, codes_scatter)
    for artist, key in zip(artists, ("samples", "codebook")):
        artist.set_offsets(history_step[key])
    return artists


# `interval` is the delay between frames in *milliseconds*, so 30 frames per
# second is 1000 / 30 ms. The previous value of 1 / 30 asked for ~0.03 ms.
ani = animation.FuncAnimation(
    fig, animate, frames=history, interval=1000 / 30, blit=True
)
plt.close(fig)  # Don't show fig, just the animation

# NOTE(review): with this many frames the embedded HTML may reach matplotlib's
# `animation.embed_limit`; confirm that all frames are rendered.
# Keep this as the last expression so the notebook displays the player.
HTML(ani.to_jshtml())