Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FSQ implementation #74

Merged
merged 14 commits into from
Sep 29, 2023
2 changes: 1 addition & 1 deletion examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, **vq_kwargs):
def forward(self, x):
for layer in self.layers:
if isinstance(layer, VectorQuantize):
x_flat, indices, commit_loss = layer(x)
x, indices, commit_loss = layer(x)
else:
x = layer(x)

Expand Down
94 changes: 94 additions & 0 deletions examples/autoencoder_fsq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# FashionMnist VQ experiment with various settings, using FSQ.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import FSQ


lr = 3e-4
train_iter = 1000
levels = [8, 6, 5] # target size 2^8, actual size 240
num_codes = math.prod(levels)
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleFSQAutoEncoder(nn.Module):
def __init__(self, levels: list[int]):
super().__init__()
self.layers = nn.ModuleList(
[
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.GELU(),
nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
nn.Conv2d(8, 8, kernel_size=6, stride=3, padding=0),
FSQ(levels),
nn.ConvTranspose2d(8, 8, kernel_size=6, stride=3, padding=0),
nn.Conv2d(8, 16, kernel_size=4, stride=1, padding=2),
nn.GELU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=2),
]
)
return

def forward(self, x):
for layer in self.layers:
if isinstance(layer, FSQ):
x, indices = layer(x)
else:
x = layer(x)

return x.clamp(-1, 1), indices


def train(model, train_loader, train_iterations=1000):
def iterate_dataset(data_loader):
data_iter = iter(data_loader)
while True:
try:
x, y = next(data_iter)
except StopIteration:
data_iter = iter(data_loader)
x, y = next(data_iter)
yield x.to(device), y.to(device)

for _ in (pbar := trange(train_iterations)):
opt.zero_grad()
x, _ = next(iterate_dataset(train_loader))
out, indices = model(x)
rec_loss = (out - x).abs().mean()
rec_loss.backward()

opt.step()
pbar.set_description(
f"rec loss: {rec_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)
return


transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
datasets.FashionMNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
3 changes: 2 additions & 1 deletion vector_quantize_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
65 changes: 65 additions & 0 deletions vector_quantize_pytorch/finite_scalar_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""

import torch
import torch.nn as nn


def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()


class FSQ(nn.Module):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for porting this! Do you mind if we link this repo in the next version and our own public code release?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LMK if you are also planning to update the README and I can send some figs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please do! and i believe @lucidrains and @sekstini will appreciate it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, go ahead 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fab-jul

LMK if you are also planning to update the README and I can send some figs.

That would be great 🙏

def __init__(self, levels: list[int]):
super().__init__()
_levels = torch.tensor(levels, dtype=torch.int32)
self.register_buffer("_levels", _levels)

_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
self.register_buffer("_basis", _basis)

codebook_size = self._levels.prod()
implicit_codebook = self.indices_to_codes(torch.arange(codebook_size))
self.register_buffer("implicit_codebook", implicit_codebook)

def forward(self, z: torch.Tensor) -> torch.Tensor:
zhat = self.quantize(z)
indices = self.codes_to_indices(zhat)
return zhat, indices

def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 - eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
sekstini marked this conversation as resolved.
Show resolved Hide resolved
shift = (offset / half_l).tan()
return (z + shift).tanh() * half_l - offset

def quantize(self, z: torch.Tensor) -> torch.Tensor:
"""Quanitzes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width

def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width

def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width

def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == len(self._levels)
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(torch.int32)

def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
"""Inverse of `codes_to_indices`."""
indices = indices.unsqueeze(-1)
codes_non_centered = (indices // self._basis) % self._levels
return self._scale_and_shift_inverse(codes_non_centered)