From 8d301a7f837fcd58623aed21a1d4dcefffe9fba1 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 11 Oct 2023 09:44:34 -0700
Subject: [PATCH] adopt finite scalar / lookup free quantization

---
 README.md                  | 22 ++++++++++++
 parti_pytorch/version.py   |  2 +-
 parti_pytorch/vit_vqgan.py | 71 +++++++++++++++++++++++---------------
 setup.py                   |  4 ++--
 4 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/README.md b/README.md
index c34a69b..0f92db3 100644
--- a/README.md
+++ b/README.md
@@ -221,3 +221,25 @@ loss.backward()
     year    = {2021}
 }
 ```
+
+```bibtex
+@misc{mentzer2023finite,
+    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
+    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
+    year    = {2023},
+    eprint  = {2309.15505},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@misc{yu2023language,
+    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
+    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
+    year    = {2023},
+    eprint  = {2310.05737},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
diff --git a/parti_pytorch/version.py b/parti_pytorch/version.py
index d9fc5d6..df9144c 100644
--- a/parti_pytorch/version.py
+++ b/parti_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.0.18'
+__version__ = '0.1.1'
diff --git a/parti_pytorch/vit_vqgan.py b/parti_pytorch/vit_vqgan.py
index b7f1921..ba56ce7 100644
--- a/parti_pytorch/vit_vqgan.py
+++ b/parti_pytorch/vit_vqgan.py
@@ -3,7 +3,7 @@
 from math import sqrt
 from functools import partial, wraps
 
-from vector_quantize_pytorch import VectorQuantize as VQ
+from vector_quantize_pytorch import VectorQuantize as VQ, LFQ
 
 import torch
 from torch import nn, einsum
@@ -11,7 +11,7 @@
 from torch.autograd import grad as torch_grad
 import torchvision
 
-from einops import rearrange, reduce, repeat
+from einops import rearrange, reduce, repeat, pack, unpack
 from einops_exts import rearrange_many
 from einops.layers.torch import Rearrange
 
@@ -487,11 +487,18 @@ def __init__(
         l2_recon_loss = False,
         use_hinge_loss = True,
         vgg = None,
-        vq_codebook_dim = 64,
-        vq_codebook_size = 512,
-        vq_decay = 0.9,
-        vq_commitment_weight = 1.,
-        vq_kmeans_init = True,
+        lookup_free_quantization = True,
+        codebook_size = 65536,
+        vq_kwargs: dict = dict(
+            codebook_dim = 64,
+            decay = 0.9,
+            commitment_weight = 1.,
+            kmeans_init = True
+        ),
+        lfq_kwargs: dict = dict(
+            entropy_loss_weight = 0.1,
+            diversity_gamma = 2.
+        ),
         use_vgg_and_gan = True,
         discr_layers = 4,
         **kwargs
@@ -502,7 +509,7 @@ def __init__(
 
         self.image_size = image_size
         self.channels = channels
-        self.codebook_size = vq_codebook_size
+        self.codebook_size = codebook_size
 
         self.enc_dec = ViTEncDec(
             dim = dim,
@@ -512,17 +519,25 @@ def __init__(
             **encdec_kwargs
         )
 
-        self.vq = VQ(
-            dim = self.enc_dec.encoded_dim,
-            codebook_dim = vq_codebook_dim,
-            codebook_size = vq_codebook_size,
-            decay = vq_decay,
-            commitment_weight = vq_commitment_weight,
-            kmeans_init = vq_kmeans_init,
-            accept_image_fmap = True,
-            use_cosine_sim = True,
-            **vq_kwargs
-        )
+        # offer look up free quantization
+        # https://arxiv.org/abs/2310.05737
+
+        self.lookup_free_quantization = lookup_free_quantization
+
+        if lookup_free_quantization:
+            self.quantizer = LFQ(
+                dim = self.enc_dec.encoded_dim,
+                codebook_size = codebook_size,
+                **lfq_kwargs
+            )
+        else:
+            self.quantizer = VQ(
+                dim = self.enc_dec.encoded_dim,
+                codebook_size = codebook_size,
+                accept_image_fmap = True,
+                use_cosine_sim = True,
+                **vq_kwargs
+            )
 
         # reconstruction loss
 
@@ -582,24 +597,26 @@ def state_dict(self, *args, **kwargs):
     def load_state_dict(self, *args, **kwargs):
         return super().load_state_dict(*args, **kwargs)
 
-    @property
-    def codebook(self):
-        return self.vq.codebook
-
     def get_fmap_from_codebook(self, indices):
-        codes = self.codebook[indices]
-        fmap = self.vq.project_out(codes)
+        if self.lookup_free_quantization:
+            indices, ps = pack([indices], 'b *')
+            fmap = self.quantizer.indices_to_codes(indices)
+            fmap, = unpack(fmap, ps, 'b * c')
+        else:
+            codes = self.quantizer.codebook[indices]
+            fmap = self.quantizer.project_out(codes)
+
         return rearrange(fmap, 'b h w c -> b c h w')
 
     def encode(self, fmap, return_indices_and_loss = True):
         fmap = self.enc_dec.encode(fmap)
 
-        fmap, indices, commit_loss = self.vq(fmap)
+        fmap, indices, quantizer_aux_loss = self.quantizer(fmap)
 
         if not return_indices_and_loss:
             return fmap
 
-        return fmap, indices, commit_loss
+        return fmap, indices, quantizer_aux_loss
 
     def decode(self, fmap):
         return self.enc_dec.decode(fmap)
diff --git a/setup.py b/setup.py
index 345504d..49b34a5 100644
--- a/setup.py
+++ b/setup.py
@@ -19,13 +19,13 @@
     'text-to-image'
   ],
   install_requires=[
-    'einops>=0.4',
+    'einops>=0.7',
     'einops-exts',
     'ema-pytorch',
     'torch>=1.6',
     'torchvision',
     'transformers',
-    'vector-quantize-pytorch>=0.9.2'
+    'vector-quantize-pytorch>=1.9.4'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
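
Below is a minimal usage sketch, not part of the patch itself, showing how the new `lookup_free_quantization` path might be exercised. It assumes the class keeps its existing `VitVQGanVAE` name in `parti_pytorch/vit_vqgan.py` and that `dim` and `image_size` remain constructor arguments with the values shown here chosen purely for illustration; only the keyword arguments and methods touched by this diff (`lookup_free_quantization`, `codebook_size`, `lfq_kwargs`, `use_vgg_and_gan`, `encode`, `decode`, `get_fmap_from_codebook`) are taken from the patch.

```python
# hypothetical usage sketch for the LFQ path added above; class name and
# constructor defaults are assumed, not confirmed by this patch
import torch
from parti_pytorch.vit_vqgan import VitVQGanVAE

vae = VitVQGanVAE(
    dim = 256,
    image_size = 256,
    lookup_free_quantization = True,  # new flag: use LFQ instead of VQ
    codebook_size = 65536,            # 2 ** 16 codes
    lfq_kwargs = dict(
        entropy_loss_weight = 0.1,
        diversity_gamma = 2.
    ),
    use_vgg_and_gan = False           # skip VGG + discriminator for this quick check
)

images = torch.randn(1, 3, 256, 256)

# quantized feature map, code indices, and the quantizer's auxiliary loss
# (entropy losses for LFQ, commitment loss for VQ)
fmap, indices, aux_loss = vae.encode(images, return_indices_and_loss = True)

recon = vae.decode(fmap)                     # reconstruct from the quantized feature map
fmap2 = vae.get_fmap_from_codebook(indices)  # or rebuild the feature map from indices alone
```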
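
For context on what "lookup free" means in the MagViT-2 paper cited in the README addition (arXiv:2310.05737), here is a toy illustration, independent of the `LFQ` module from vector-quantize-pytorch: each latent channel is quantized to {-1, +1} by its sign, so the code index is simply the binary number formed by the sign bits and no codebook lookup or nearest-neighbour search is needed. The function name and shapes below are illustrative only, and the straight-through gradient and entropy losses used in training are omitted.

```python
# toy illustration of lookup-free quantization (not the library implementation)
import torch

def lfq_quantize(z: torch.Tensor):
    # z: (..., d) latents, with d = log2(codebook_size) channels
    bits = (z > 0).long()                            # one sign bit per channel
    quantized = bits.float() * 2 - 1                 # values in {-1., +1.}
    powers = 2 ** torch.arange(z.shape[-1], device = z.device)
    indices = (bits * powers).sum(dim = -1)          # code index = binary number of the bits
    return quantized, indices

z = torch.randn(4, 16)                               # 16 channels -> 2 ** 16 = 65536 codes
quantized, indices = lfq_quantize(z)
print(quantized.shape, indices.shape)                # torch.Size([4, 16]) torch.Size([4])
```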