adopt finite scalar / lookup free quantization
lucidrains committed Oct 11, 2023
1 parent 9bbd3d4 commit 8d301a7
Showing 4 changed files with 69 additions and 30 deletions.
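
This commit wires lookup-free quantization (LFQ, from the MAGVIT-v2 paper cited in the README below) into the ViT-VQGAN in place of the plain vector-quantization codebook lookup. As a rough illustration of the idea only, not code from this commit or from vector-quantize-pytorch, LFQ quantizes each latent dimension to ±1 by its sign, so the token index is just the binary number formed by the signs and no embedding lookup is needed:

```python
import torch

def lfq_quantize(z):
    # illustrative only: quantize each latent dimension to +/-1 by sign,
    # giving an implicit codebook of 2**d entries with no embedding table
    ones = torch.ones_like(z)
    quantized = torch.where(z > 0, ones, -ones)

    # straight-through estimator so gradients still reach the encoder
    quantized = z + (quantized - z).detach()

    # the token index is the integer formed by reading the signs as bits
    bits = (z > 0).long()
    powers = 2 ** torch.arange(z.shape[-1], device = z.device)
    indices = (bits * powers).sum(dim = -1)
    return quantized, indices

# e.g. a 16-dimensional latent gives a 2**16 = 65536 entry codebook,
# matching the codebook_size default introduced in this commit
quantized, indices = lfq_quantize(torch.randn(2, 16))
```

The LFQ class used in the diff below additionally exposes entropy-related regularization knobs (entropy_loss_weight, diversity_gamma) that this illustration omits.
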
22 changes: 22 additions & 0 deletions README.md
@@ -221,3 +221,25 @@ loss.backward()
year = {2021}
}
```

```bibtex
@misc{mentzer2023finite,
title = {Finite Scalar Quantization: VQ-VAE Made Simple},
author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
year = {2023},
eprint = {2309.15505},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
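
Of the two papers cited above, the first describes finite scalar quantization (FSQ), the other half of the commit title. As a rough illustration of that idea only (the code change below wires up LFQ; this sketch is not code from this repository or from vector-quantize-pytorch), FSQ bounds each latent dimension and rounds it to a small fixed number of levels, so the codebook is the implicit product grid of those levels:

```python
import torch

def round_ste(z):
    # straight-through rounding: forward pass rounds, backward pass is identity
    return z + (z.round() - z).detach()

def fsq_quantize(z, levels = (8, 5, 5, 5)):
    # illustrative only: z carries one latent dimension per entry of `levels`;
    # dimension i is squashed and rounded to levels[i] values, so the implicit
    # codebook size is the product of the levels (8 * 5 * 5 * 5 = 1000 here)
    levels = torch.tensor(levels, dtype = z.dtype, device = z.device)
    half = (levels - 1) / 2
    bounded = torch.tanh(z) * half     # each dim now lies in [-(L-1)/2, (L-1)/2]
    return round_ste(bounded) / half   # snap to integer levels, rescale to [-1, 1]

codes = fsq_quantize(torch.randn(2, 4))
```
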
2 changes: 1 addition & 1 deletion parti_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.0.18'
__version__ = '0.1.1'
71 changes: 44 additions & 27 deletions parti_pytorch/vit_vqgan.py
@@ -3,15 +3,15 @@
from math import sqrt
from functools import partial, wraps

from vector_quantize_pytorch import VectorQuantize as VQ
from vector_quantize_pytorch import VectorQuantize as VQ, LFQ

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision

from einops import rearrange, reduce, repeat
from einops import rearrange, reduce, repeat, pack, unpack
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange

@@ -487,11 +487,18 @@ def __init__(
l2_recon_loss = False,
use_hinge_loss = True,
vgg = None,
vq_codebook_dim = 64,
vq_codebook_size = 512,
vq_decay = 0.9,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
lookup_free_quantization = True,
codebook_size = 65536,
vq_kwargs: dict = dict(
codebook_dim = 64,
decay = 0.9,
commitment_weight = 1.,
kmeans_init = True
),
lfq_kwargs: dict = dict(
entropy_loss_weight = 0.1,
diversity_gamma = 2.
),
use_vgg_and_gan = True,
discr_layers = 4,
**kwargs
@@ -502,7 +509,7 @@ def __init__(

self.image_size = image_size
self.channels = channels
self.codebook_size = vq_codebook_size
self.codebook_size = codebook_size

self.enc_dec = ViTEncDec(
dim = dim,
@@ -512,17 +519,25 @@
**encdec_kwargs
)

self.vq = VQ(
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
kmeans_init = vq_kmeans_init,
accept_image_fmap = True,
use_cosine_sim = True,
**vq_kwargs
)
# offer lookup free quantization
# https://arxiv.org/abs/2310.05737

self.lookup_free_quantization = lookup_free_quantization

if lookup_free_quantization:
self.quantizer = LFQ(
dim = self.enc_dec.encoded_dim,
codebook_size = codebook_size,
**lfq_kwargs
)
else:
self.quantizer = VQ(
dim = self.enc_dec.encoded_dim,
codebook_size = codebook_size,
accept_image_fmap = True,
use_cosine_sim = True,
**vq_kwargs
)

# reconstruction loss

@@ -582,24 +597,26 @@ def state_dict(self, *args, **kwargs):
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)

@property
def codebook(self):
return self.vq.codebook

def get_fmap_from_codebook(self, indices):
codes = self.codebook[indices]
fmap = self.vq.project_out(codes)
if self.lookup_free_quantization:
indices, ps = pack([indices], 'b *')
fmap = self.quantizer.indices_to_codes(indices)
fmap, = unpack(fmap, ps, 'b * c')
else:
codes = self.quantizer.codebook[indices]
fmap = self.quantizer.project_out(codes)

return rearrange(fmap, 'b h w c -> b c h w')

def encode(self, fmap, return_indices_and_loss = True):
fmap = self.enc_dec.encode(fmap)

fmap, indices, commit_loss = self.vq(fmap)
fmap, indices, quantizer_aux_loss = self.quantizer(fmap)

if not return_indices_and_loss:
return fmap

return fmap, indices, commit_loss
return fmap, indices, quantizer_aux_loss

def decode(self, fmap):
return self.enc_dec.decode(fmap)
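
For context, a minimal usage sketch of the module after this change. The class name, import path, and constructor arguments beyond those visible in the hunks are not shown in this diff and are assumed here:

```python
import torch
from parti_pytorch.vit_vqgan import VitVQGanVAE  # class name / import path assumed, not shown in this diff

# lookup-free quantization is the default; set lookup_free_quantization = False
# to fall back to the original VectorQuantize path
vae = VitVQGanVAE(
    dim = 256,                          # values here are assumptions for illustration
    image_size = 256,
    lookup_free_quantization = True,
    codebook_size = 65536,              # replaces the old vq_codebook_size argument
    lfq_kwargs = dict(
        entropy_loss_weight = 0.1,
        diversity_gamma = 2.
    )
)

images = torch.randn(1, 3, 256, 256)

fmap, indices, aux_loss = vae.encode(images, return_indices_and_loss = True)
recon = vae.decode(fmap)
```
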
4 changes: 2 additions & 2 deletions setup.py
@@ -19,13 +19,13 @@
'text-to-image'
],
install_requires=[
'einops>=0.4',
'einops>=0.7',
'einops-exts',
'ema-pytorch',
'torch>=1.6',
'torchvision',
'transformers',
'vector-quantize-pytorch>=0.9.2'
'vector-quantize-pytorch>=1.9.4'
],
classifiers=[
'Development Status :: 4 - Beta',
