add ability to fetch all codes across all quantization layers in RQ, …
lucidrains committed Oct 26, 2022
1 parent 146810e commit ec24746
Showing 3 changed files with 40 additions and 3 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -48,6 +48,12 @@ quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 8), (1, 8)
# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# all_codes - (quantizer, batch, seq, dim)
```
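
A quick sanity check of the new return value (a sketch, not part of this commit; the constructor arguments are assumed to match the shapes in the example above). Since each per-layer code is the codebook vector at the returned index, summing `all_codes` over the quantizer axis should reproduce the quantized output under default settings:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# assumed constructor arguments, matching the (1, 1024, 256) example shapes
residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024
)

# eval mode so the codebooks are not updated mid-forward
residual_vq.eval()

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# all_codes is (quantizer, batch, seq, dim); summing over the quantizer axis
# should recover the quantized output, since it accumulates one code per layer
assert torch.allclose(all_codes.sum(dim = 0), quantized, atol = 1e-5)
```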

Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '0.9.2',
version = '0.10.0',
license='MIT',
description = 'Vector Quantization - Pytorch',
long_description_content_type = 'text/markdown',
35 changes: 33 additions & 2 deletions vector_quantize_pytorch/residual_vq.py
@@ -1,18 +1,24 @@
from functools import partial

import torch
from torch import nn
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat

class ResidualVQ(nn.Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
self,
*,
num_quantizers,
shared_codebook = False,
heads = 1,
**kwargs
):
super().__init__()
assert heads == 1, 'residual vq is not compatible with multi-headed codes'

self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])

if not shared_codebook:
@@ -24,7 +30,18 @@ def __init__(
for vq in rest_vq:
vq._codebook = codebook

def forward(self, x):
@property
def codebooks(self):
codebooks = [layer._codebook.embed for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
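# each layer's codebook embed is (1, codebook_size, dim) since heads == 1 is enforced; the rearrange below drops that singleton axis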
codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
return codebooks

def forward(
self,
x,
return_all_codes = False
):
quantized_out = 0.
residual = x

@@ -40,4 +57,18 @@ def forward(self, x):
all_losses.append(loss)

all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))
return quantized_out, all_indices, all_losses

ret = (quantized_out, all_indices, all_losses)

if return_all_codes:
# return the codes from all codebooks across the quantization layers

codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = x.shape[0])
gather_indices = repeat(all_indices, 'b n q -> q b n d', d = codebooks.shape[-1])

all_codes = codebooks.gather(2, gather_indices) # gather all codes

# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
ret = (*ret, all_codes)

return ret
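
To make the gather step concrete, here is a small standalone illustration (toy shapes, not part of the commit) of how the expanded codebooks and indices line up:

```python
import torch
from einops import repeat

# toy shapes: 2 quantizers, codebook size 5, dim 3, batch 1, sequence length 4
q, c, d, b, n = 2, 5, 3, 1, 4
codebooks = torch.randn(q, c, d)
indices = torch.randint(0, c, (b, n, q))

codebooks_expanded = repeat(codebooks, 'q c d -> q b c d', b = b)
gather_indices = repeat(indices, 'b n q -> q b n d', d = d)

all_codes = codebooks_expanded.gather(2, gather_indices)  # (q, b, n, d)

# spot check: the gathered vector for layer 1, position 2 is that layer's indexed codebook entry
assert torch.equal(all_codes[1, 0, 2], codebooks[1, indices[0, 2, 1]])
```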
