In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import matplotlib.animation as animation
from IPython.display import HTML


def save_animation(
    filename: str,
    history: list[dict],
    n_frames: int = 200,
    show_quantization: bool = True,
):
    """Render a training history as a video: latent scatter plus loss curves.

    The left panel shows the encoder outputs ("samples") coloured by label,
    with the codebook entries drawn as gray squares. The right panel draws the
    MSE and commitment loss curves up to the current frame.

    Args:
        filename: Output path handed to ``FuncAnimation.save`` (ffmpeg writer).
            Missing parent directories are created.
        history: One dict per training step with the keys "samples", "labels",
            "mse_loss", "commitment_loss" and either "codebook" (a single
            array) or "codebooks" (a list, of which the first entry is drawn).
        n_frames: Number of frames; the history is sampled evenly to this size.
        show_quantization: When False the codebook markers are fully
            transparent.

    Raises:
        ValueError: If ``history`` is empty.
    """
    # Local import: this cell's import block does not bring in pathlib.
    from pathlib import Path

    if not history:
        raise ValueError("history must contain at least one training step")

    def codebook_of(step: dict):
        # `train_model` stores a list under "codebooks"; a plain "codebook"
        # entry is still honoured when present.
        if "codebook" in step:
            return step["codebook"]
        return step["codebooks"][0]

    fig, axes = plt.subplots(1, 2)
    fig.set_figwidth(11)
    scatter_ax, loss_ax = axes

    first_step = history[0]
    embeddings_scatter = scatter_ax.scatter(
        first_step["samples"][:, 0], first_step["samples"][:, 1], alpha=1.0
    )
    first_codebook = codebook_of(first_step)
    codes_scatter = scatter_ax.scatter(
        first_codebook[:, 0],
        first_codebook[:, 1],
        alpha=0.8 if show_quantization else 0.0,
        s=75,
        marker="s",
        color="gray",
    )
    (mse_line,) = loss_ax.plot([], [], lw=2)
    (commitment_line,) = loss_ax.plot([], [], lw=2)

    scatter_ax.set_xlim(-15, 15)
    scatter_ax.set_ylim(-15, 15)

    downsampled_history = [
        history[round(i)] for i in np.linspace(0, len(history) - 1, n_frames)
    ]

    loss_ax.set_xlim(0, len(downsampled_history))
    loss_ax.set_ylim(0, np.percentile([hs["mse_loss"] for hs in history], 90) * 1.1)

    # Loss curves are fixed per frame, so build them once and slice per frame.
    frame_axis = np.arange(len(downsampled_history))
    mse_curve = np.array([hs["mse_loss"] for hs in downsampled_history])
    commitment_curve = np.array(
        [hs["commitment_loss"] for hs in downsampled_history]
    )

    def animate(frame_idx):
        # Frames are addressed by index rather than by a running counter:
        # FuncAnimation also invokes this function for its initial draw, which
        # would push a counter ahead of the frame actually being rendered.
        history_step = downsampled_history[frame_idx]
        shown = frame_idx + 1

        embeddings_scatter.set_offsets(history_step["samples"])
        codes_scatter.set_offsets(codebook_of(history_step))
        embeddings_scatter.set_alpha(0.8)

        embeddings_scatter.set_array(history_step["labels"])
        # For debugging:
        # embeddings_scatter.set_alpha((history_step["labels"] == 1).astype(float))
        embeddings_scatter.set_cmap("Set3")

        mse_line.set_data(frame_axis[:shown], mse_curve[:shown])
        commitment_line.set_data(frame_axis[:shown], commitment_curve[:shown])

        # Exponential moving average of the limits, so the view follows the
        # point cloud smoothly instead of jumping every frame.
        previous = [*scatter_ax.get_xlim(), *scatter_ax.get_ylim()]
        update = [
            history_step["samples"][:, 0].min(),
            history_step["samples"][:, 0].max(),
            history_step["samples"][:, 1].min(),
            history_step["samples"][:, 1].max(),
        ]
        alpha = 0.9
        updated = [alpha * p + (1 - alpha) * u for p, u in zip(previous, update)]
        scatter_ax.set_xlim(updated[0], updated[1])
        scatter_ax.set_ylim(updated[2], updated[3])

        return (embeddings_scatter, codes_scatter, mse_line, commitment_line)

    ani = animation.FuncAnimation(
        fig,
        animate,
        frames=len(downsampled_history),
        interval=1000 / 30,  # milliseconds between frames, i.e. 30 fps
        blit=True,
    )
    # plt.close(fig)  # Don't show fig

    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    ani.save(filename, writer="ffmpeg", fps=30)

In [None]:
from pathlib import Path
import torch
from torch.utils.data import Dataset
import pandas as pd
import librosa
import numpy as np


class NSynthDataset(Dataset):
    def __init__(self, data_dir: Path | str):
        self.data_dir = Path(data_dir)

        with open(self.data_dir / "examples.json") as f:
            df = pd.read_json(f, orient="index")
        df = df.loc[df["instrument_family_str"].isin(["string", "brass", "guitar"])]
        df = df.loc[df["instrument_source_str"].isin(["acoustic"])]

        self.keys = df.index.tolist()
        self.df = df
        self.cache = {}

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]

        key = self.keys[idx]
        meta = self.df.loc[key]
        audio_path = self.data_dir / "audio" / f"{key}.wav"
        spectrogram = self.load_nsynth_audio(audio_path)

        res = torch.tensor(spectrogram, dtype=torch.float32), meta["instrument_family"]
        self.cache[idx] = res
        return res
        # return {
        #     "spectrogram": torch.tensor(spectrogram, dtype=torch.float32),
        #     "meta": meta.to_dict(),
        # }

    def load_nsynth_audio(self, path: Path):
        y, sr = librosa.load(path, sr=None)
        spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
        assert spectrogram.shape == (128, 126)
        spectrogram = spectrogram[:, :32]  # First second
        spectrogram = spectrogram[::2, ::2]  # Downsample in both axes
        spectrogram = np.log(spectrogram + 1e-8)
        return spectrogram

In [None]:
# Build the NSynth dataset from the local training split. Construction only
# reads and filters examples.json; audio is loaded per item on first access.
dataset = NSynthDataset(data_dir="data/nsynth/nsynth-train/")

In [None]:
# Shape of the first item's spectrogram tensor; (128, 126) cropped to 32
# frames and halved on both axes gives (64, 16).
dataset[0][0].shape

In [None]:
import tqdm.auto

# Touch every item once so NSynthDataset.__getitem__ fills dataset.cache;
# the samples themselves are discarded.
for _sample in tqdm.auto.tqdm(dataset):
    pass

In [None]:
# import pickle
# with open("nsynth_dataset_cache.pkl", "wb") as f:
#     pickle.dump(dataset.cache, f)

In [None]:
import itertools
import pickle
import tqdm.auto

from codec import ResidualVectorQuantizer, VectorQuantizer
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import Subset


def make_architecture(in_features: int, latent_dim: int, out_features: int):
    """Build a four-layer MLP with GELU activations between the linear layers.

    Layer widths are ``in_features -> latent_dim -> latent_dim -> latent_dim
    -> out_features``; there is no activation after the final layer.
    """
    widths = [in_features, latent_dim, latent_dim, latent_dim, out_features]

    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(in_features=fan_in, out_features=fan_out))
        layers.append(torch.nn.GELU())
    layers.pop()  # drop the activation that would follow the output layer

    return torch.nn.Sequential(*layers)


def get_filtered_fashion_mnist(classes=(0, 1, 2)):
    """Return the FashionMNIST training split restricted to some classes.

    Args:
        classes: Integer class labels to keep. Defaults to the first three
            classes.

    Returns:
        A ``Subset`` of the FashionMNIST dataset (images converted with
        ``ToTensor``) containing only examples whose label is in ``classes``.
        The data is downloaded to ``./data/fashion_mnist/`` if missing.
    """
    dataset = torchvision.datasets.FashionMNIST(
        "./data/fashion_mnist/", download=True, transform=ToTensor()
    )
    wanted = set(classes)
    # Read the labels straight from `targets`. Enumerating the dataset itself
    # would decode and transform every image only to inspect its label.
    idx = [i for i, y in enumerate(dataset.targets.tolist()) if y in wanted]
    return Subset(dataset, idx)


def train_model(
    *,
    train_steps: int,
    mse_coef=0.3,
    commitment_coef=1.0,
    learning_rate: float = 0.8,
    n_codebooks: int = 1,
    device: str | None = None,
):
    """Train a small MLP autoencoder with a (residual) VQ bottleneck.

    The encoder maps flattened FashionMNIST images (first three classes) to a
    2-channel latent, which is quantized by ``ResidualVectorQuantizer`` and
    decoded back. Only encoder and decoder parameters are given to the
    optimizer.

    Args:
        train_steps: Number of optimizer steps.
        mse_coef: Weight of the reconstruction (MSE) loss.
        commitment_coef: Weight of the commitment loss. When exactly 0 the
            decoder is fed the unquantized latent instead.
        learning_rate: AdamW learning rate. Note the default is far larger
            than the 0.001 used by this notebook's calls.
        n_codebooks: Number of codebooks in the residual quantizer.
        device: Torch device string. Defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        ``(history, (encoder, bottleneck, decoder))`` where ``history`` holds
        one dict of NumPy arrays per step with keys "samples", "residuals",
        "labels", "codebooks" (one array per codebook), "codebook" (alias for
        the first codebook), "mse_loss" and "commitment_loss" (both already
        multiplied by their coefficients).
    """
    channels = 2
    latent_dim = 16
    codebook_update_speed = 0.05
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    history = []

    dataset = get_filtered_fashion_mnist()

    # dataset = NSynthDataset(data_dir="data/nsynth/nsynth-train/")
    # with open("nsynth_dataset_cache.pkl", "rb") as f:
    #     dataset.cache = pickle.load(f)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        num_workers=4,
        # The loader is re-iterated once per epoch below; keep workers alive
        # instead of respawning them each time.
        persistent_workers=True,
    )

    data_size = dataset[0][0].numel()

    torch.manual_seed(123)

    with torch.device(device):
        encoder = make_architecture(data_size, latent_dim, channels)
        decoder = make_architecture(channels, latent_dim, data_size)

        bottleneck = ResidualVectorQuantizer(
            channels=channels,
            codebook_size=36,
            n_codebooks=n_codebooks,
            codebook_update_speed=codebook_update_speed,
        )

    # AdamW works better than SGD even for this simple problem
    optimizer = torch.optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=learning_rate,
    )
    encoder.to(device)
    decoder.to(device)

    def endless_batches():
        # Re-iterate the loader for every epoch. `itertools.cycle` would cache
        # the first epoch's batches and replay them forever, so the data would
        # never be reshuffled and every batch would stay in memory.
        while True:
            yield from dataloader

    def for_history(x):
        # Detached CPU NumPy copy, safe to keep after the step finishes.
        return x.detach().clone().to("cpu").numpy()

    # `range` comes first in the zip so that no extra batch is fetched once
    # the step budget is exhausted.
    for _step, (batch, batch_y) in (
        pbar := tqdm.auto.tqdm(
            zip(range(train_steps), endless_batches()),
            desc="Training",
            total=train_steps,
        )
    ):
        optimizer.zero_grad()

        x = batch.view(batch.size(0), -1).to(device)
        z = encoder(x)

        # bottleneck.restart_unused_codes(z)
        _codes, z_quantized, commitment_loss = bottleneck(z)

        if commitment_coef == 0:
            # No commitment pressure: decode the raw latent instead.
            z_quantized = z

        x_reconstructed = decoder(z_quantized)

        mse_loss = ((x_reconstructed - x) ** 2).mean()

        # History-only diagnostic: what the first codebook leaves unexplained.
        # NOTE(review): this runs the first quantizer a second time per step;
        # confirm its forward pass has no per-call side effects.
        residuals = z - bottleneck.bottlenecks[0](z)[1]

        codebooks = [for_history(b.codebook()) for b in bottleneck.bottlenecks]
        history.append(
            {
                "samples": for_history(z),
                "residuals": for_history(residuals),
                "labels": for_history(batch_y),
                "codebooks": codebooks,
                # Alias read by `save_animation`, which plots one codebook.
                "codebook": codebooks[0],
                "mse_loss": for_history(mse_loss * mse_coef),
                "commitment_loss": for_history(commitment_loss * commitment_coef),
            }
        )
        loss = mse_loss * mse_coef + commitment_loss * commitment_coef
        loss.backward()
        pbar.set_postfix({"loss": loss.item()})

        optimizer.step()

    print(f"Final step: {history[-1]['mse_loss']=} {history[-1]['commitment_loss']=}")

    return history, (encoder, bottleneck, decoder)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import matplotlib.animation as animation
from IPython.display import HTML


def save_animation_rvq(
    filename: str,
    history: list[dict],
    n_frames: int = 200,
    show_quantization: bool = True,
):
    """Render a two-codebook residual-VQ training history as a video.

    The left panel shows the encoder outputs ("samples") with the first
    codebook; the right panel shows the "residuals" with the second codebook.
    Points are coloured by label and codebook entries are gray squares.

    Args:
        filename: Output path handed to ``FuncAnimation.save`` (ffmpeg writer).
            Missing parent directories are created.
        history: One dict per training step with the keys "samples",
            "residuals", "labels" and "codebooks" (exactly two arrays).
        n_frames: Number of frames; the history is sampled evenly to this size.
        show_quantization: When False the codebook markers are fully
            transparent.

    Raises:
        ValueError: If ``history`` is empty.
        AssertionError: If the history does not hold exactly two codebooks.
    """
    # Local import: this cell's import block does not bring in pathlib.
    from pathlib import Path

    if not history:
        raise ValueError("history must contain at least one training step")
    assert len(history[0]["codebooks"]) == 2, "expected exactly two codebooks"

    # Panel j plots history[panel_keys[j]] against codebooks[j].
    panel_keys = ("samples", "residuals")

    fig, axes = plt.subplots(1, 2)
    fig.set_figwidth(11)
    axes[0].set_title("Embeddings")
    axes[1].set_title("Residuals")

    embeddings_scatter = []
    codes_scatter = []

    for j, samples_key in enumerate(panel_keys):
        embeddings_scatter.append(
            axes[j].scatter(
                history[0][samples_key][:, 0],
                history[0][samples_key][:, 1],
                alpha=1.0,
            )
        )
        codes_scatter.append(
            axes[j].scatter(
                history[0]["codebooks"][j][:, 0],
                history[0]["codebooks"][j][:, 1],
                alpha=0.8 if show_quantization else 0.0,
                s=75,
                marker="s",
                color="gray",
            )
        )

    downsampled_history = [
        history[round(i)] for i in np.linspace(0, len(history) - 1, n_frames)
    ]

    def animate(history_step):
        for j, samples_key in enumerate(panel_keys):
            embeddings_scatter[j].set_offsets(history_step[samples_key])
            codes_scatter[j].set_offsets(history_step["codebooks"][j])
            embeddings_scatter[j].set_alpha(0.8)

            embeddings_scatter[j].set_array(history_step["labels"])
            # For debugging:
            # embeddings_scatter[j].set_alpha((history_step["labels"] == 1).astype(float))
            embeddings_scatter[j].set_cmap("Set3")

            # Exponential moving average of the limits, so each view follows
            # its point cloud smoothly instead of jumping every frame.
            previous = [*axes[j].get_xlim(), *axes[j].get_ylim()]
            update = [
                history_step[samples_key][:, 0].min(),
                history_step[samples_key][:, 0].max(),
                history_step[samples_key][:, 1].min(),
                history_step[samples_key][:, 1].max(),
            ]
            alpha = 0.9
            updated = [alpha * p + (1 - alpha) * u for p, u in zip(previous, update)]
            axes[j].set_xlim(updated[0], updated[1])
            axes[j].set_ylim(updated[2], updated[3])

        return (*embeddings_scatter, *codes_scatter)

    ani = animation.FuncAnimation(
        fig,
        animate,
        frames=downsampled_history,
        interval=1000 / 30,  # milliseconds between frames, i.e. 30 fps
        blit=True,
    )
    # plt.close(fig)  # Don't show fig

    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    ani.save(filename, writer="ffmpeg", fps=30)

In [None]:
# Train the two-codebook model in this cell so that `history` is defined here.
# With the call commented out, the cell only worked when `history` was left
# over from another cell, and `save_animation_rvq` needs exactly two codebooks.
history, (encoder, bottleneck, decoder) = train_model(
    train_steps=3000,
    mse_coef=1.0,
    commitment_coef=1.0,
    learning_rate=0.001,
    n_codebooks=2,
)
save_animation_rvq("for-video/rvq.mp4", history, n_frames=250)

In [None]:
# Inspect the record of the first training step (a dict of NumPy arrays).
history[0]

In [None]:
# Baseline without quantization pressure: commitment_coef=0.0 makes
# train_model decode the raw latent, so only the MSE loss drives training.
# The returned models are rebound at notebook scope and reused by later cells.
history, (encoder, bottleneck, decoder) = train_model(
    train_steps=3000,
    mse_coef=1.0,
    commitment_coef=0.0,
    learning_rate=0.001,
)
# save_animation(
#     "for-video/vq_unquantized.mp4", history, show_quantization=False, n_frames=250
# )
# Same run rendered with the codebook markers visible.
save_animation(
    "for-video/vq_unquantized_with_clustering.mp4",
    history,
    show_quantization=True,
    n_frames=250,
)

In [None]:
# Render the same `history` again, this time with the codebook markers hidden.
save_animation(
    "for-video/vq_unquantized.mp4", history, show_quantization=False, n_frames=250
)

In [None]:
# train_model returns (history, (encoder, bottleneck, decoder)); unpack it so
# that save_animation receives the list of per-step records, not the tuple.
history, _ = train_model(
    train_steps=30000, mse_coef=1.0, commitment_coef=1.0, learning_rate=0.001
)
save_animation("for-video/vq_balanced.mp4", history, n_frames=250)

In [None]:
# Commitment loss weighted 3x relative to MSE; the trained models are discarded.
history, _ = train_model(
    train_steps=30000, mse_coef=1.0, commitment_coef=3.0, learning_rate=0.001
)
save_animation("for-video/vq_commitment_wins.mp4", history)

In [None]:
# train_model returns (history, (encoder, bottleneck, decoder)); unpack it so
# that save_animation receives the list of per-step records, not the tuple.
history, _ = train_model(
    train_steps=30000, mse_coef=1.0, commitment_coef=0.1, learning_rate=0.001
)
save_animation("for-video/vq_mse_wins.mp4", history)

In [None]:
# `dataloader` is only a local variable inside `train_model`, so it does not
# exist at notebook scope; build an equivalent loader here to peek at a batch.
dataloader = torch.utils.data.DataLoader(
    get_filtered_fashion_mnist(), batch_size=256, shuffle=True
)
next(iter(dataloader))

In [None]:
# Rebinds `dataset` (earlier an NSynthDataset) to the filtered FashionMNIST
# subset used by the reconstruction cells below.
dataset = get_filtered_fashion_mnist()

In [None]:
import torchvision.transforms.functional as F

from torchvision.utils import make_grid

# Global matplotlib setting: use a tight bounding box when saving figures.
plt.rcParams["savefig.bbox"] = "tight"


def show_images(imgs):
    """Draw one image tensor, or a list of them, side by side in one row.

    Each tensor is detached, converted to a PIL image and shown with a gray
    colormap and no axis ticks. Nothing is returned.
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    # squeeze=False keeps `axs` two-dimensional even for a single image.
    _fig, axs = plt.subplots(ncols=len(images), squeeze=False)
    for ax, tensor in zip(axs[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image), cmap="gray")
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Four hand-picked examples joined along dim 0 and moved to the GPU.
# Presumably each image is (1, 28, 28), giving a (4, 28, 28) batch — the
# "b w h" rearrange below expects a 3-D tensor; TODO confirm.
example_batch = torch.concat([dataset[i][0] for i in [5, 7, 4, 6]], dim=0).to("cuda")
# example_batch = torch.concat([dataset[i][0] for i in range(16)], dim=0).to("cuda")

In [None]:
import einops

In [None]:
# Encode -> quantize -> decode the example batch.
# NOTE(review): `encoder`, `bottleneck` and `decoder` come from whichever
# training cell last rebound them; confirm that is the intended model.
z = encoder(einops.rearrange(example_batch, "b w h -> b (w h)").to("cuda"))
# Only the first quantizer of the residual stack is applied here.
_codes, z_quantized, commitment_loss = bottleneck.bottlenecks[0](z)
decoded_flat = decoder(z_quantized)
decoded = einops.rearrange(decoded_flat, "b (w h) -> b w h", w=28, h=28).to("cpu")
# Clamp to [0, 1] before the reconstructions are displayed as images.
decoded = torch.clamp(decoded, 0.0, 1.0)

In [None]:
from torchvision.utils import make_grid

# Originals followed by their reconstructions; nrow equals the batch size, so
# the originals fill the first row and the reconstructions the second.
grid = make_grid(
    einops.rearrange(torch.concat([example_batch.cpu(), decoded]), "b w h -> b 1 w h"),
    nrow=len(example_batch),
    pad_value=0.0,
)
show_images(grid)

In [None]:
# First value returned by the quantizer call above (presumably the selected
# code indices for the example batch — verify against `codec`).
_codes