Commit ebce893

make sure codebook is learnable in the presence of orthogonal regularization loss

lucidrains committed Dec 17, 2021
1 parent 4b3b209 commit ebce893
Showing 2 changed files with 27 additions and 8 deletions.
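
As context for the diff below: enabling the orthogonal regularization loss is what now switches the codebook over to a learnable parameter. A minimal usage sketch, not taken from the repository itself — the constructor arguments are the ones visible in this diff, and the (quantized, indices, loss) return convention is assumed from the project README:

import torch
from vector_quantize_pytorch import VectorQuantize

# orthogonal_reg_weight > 0 sets has_codebook_orthogonal_loss, which this commit
# forwards to the codebook as learnable_codebook = True
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    orthogonal_reg_weight = 10,                 # weight of the orthogonal regularization term
    orthogonal_reg_active_codes_only = False,
    orthogonal_reg_max_codes = None
)

x = torch.randn(1, 1024, 256)
quantized, indices, loss = vq(x)
loss.backward()   # the codebook, now an nn.Parameter, receives gradients from the regularizer
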
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.4.9',
+  version = '0.4.10',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

33 changes: 26 additions & 7 deletions vector_quantize_pytorch/vector_quantize_pytorch.py

@@ -86,7 +86,8 @@ def __init__(
         decay = 0.8,
         eps = 1e-5,
         threshold_ema_dead_code = 2,
-        use_ddp = False
+        use_ddp = False,
+        learnable_codebook = False
     ):
         super().__init__()
         self.decay = decay
@@ -99,11 +100,17 @@ def __init__(
         self.threshold_ema_dead_code = threshold_ema_dead_code
 
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
+
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
-        self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
+        self.learnable_codebook = learnable_codebook
+        if learnable_codebook:
+            self.embed = nn.Parameter(embed)
+        else:
+            self.register_buffer('embed', embed)
+
     @torch.jit.ignore
     def init_embed_(self, data):
         if self.initted:
@@ -137,10 +144,12 @@ def expire_codes_(self, batch_samples):
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
-        embed = self.embed.t()
 
         self.init_embed_(flatten)
 
+        embed = self.embed if not self.learnable_codebook else self.embed.detach()
+        embed = self.embed.t()
+
         dist = -(
             flatten.pow(2).sum(1, keepdim=True)
             - 2 * flatten @ embed
@@ -179,7 +188,8 @@ def __init__(
         decay = 0.8,
         eps = 1e-5,
         threshold_ema_dead_code = 2,
-        use_ddp = False
+        use_ddp = False,
+        learnable_codebook = False
     ):
         super().__init__()
         self.decay = decay
@@ -197,7 +207,12 @@ def __init__(
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
-        self.register_buffer('embed', embed)
 
+        self.learnable_codebook = learnable_codebook
+        if learnable_codebook:
+            self.embed = nn.Parameter(embed)
+        else:
+            self.register_buffer('embed', embed)
+
     @torch.jit.ignore
     def init_embed_(self, data):
@@ -237,7 +252,9 @@ def forward(self, x):
 
         self.init_embed_(flatten)
 
-        embed = l2norm(self.embed)
+        embed = self.embed if not self.learnable_codebook else self.embed.detach()
+        embed = l2norm(embed)
+
         dist = flatten @ embed.t()
         embed_ind = dist.max(dim = -1).indices
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
@@ -303,6 +320,7 @@ def __init__(
         self.eps = eps
         self.commitment_weight = default(commitment_weight, commitment)
 
+        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
         self.orthogonal_reg_weight = orthogonal_reg_weight
         self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
         self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
@@ -318,7 +336,8 @@ def __init__(
             decay = decay,
             eps = eps,
             threshold_ema_dead_code = threshold_ema_dead_code,
-            use_ddp = sync_codebook
+            use_ddp = sync_codebook,
+            learnable_codebook = has_codebook_orthogonal_loss
         )
 
         self.codebook_size = codebook_size
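
The pattern the diff introduces can be summarized outside the library: when the codebook must receive gradients from an auxiliary objective (here the orthogonal regularization term, computed on the codebook tensor elsewhere in VectorQuantize, outside this diff), it is stored as an nn.Parameter; otherwise it stays a registered buffer that only the EMA updates touch. Inside forward, a learnable codebook is detached before the nearest-neighbour lookup, so the distance computation contributes no gradient to it and the regularizer remains its only gradient path. A self-contained sketch of that pattern, not the library's exact code:

import torch
from torch import nn
import torch.nn.functional as F

class ToyCodebook(nn.Module):
    """Sketch of the buffer-vs-parameter pattern from the diff above (names are illustrative)."""

    def __init__(self, dim, codebook_size, learnable_codebook = False):
        super().__init__()
        embed = torch.randn(codebook_size, dim)

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            # gradients may reach the codebook, e.g. from an orthogonal regularization loss
            self.embed = nn.Parameter(embed)
        else:
            # EMA-style codebook: kept out of the optimizer entirely
            self.register_buffer('embed', embed)

    def forward(self, x):
        # detach a learnable codebook for the lookup itself, so the only gradient path
        # into self.embed is an auxiliary loss computed on it directly
        embed = self.embed if not self.learnable_codebook else self.embed.detach()
        dist = torch.cdist(x, embed)               # (n, codebook_size)
        indices = dist.argmin(dim = -1)            # nearest code per input vector
        quantized = F.embedding(indices, embed)    # (n, dim)
        return quantized, indices

# hypothetical orthogonal regularizer, standing in for the library's orthogonal loss
def orthogonal_loss(codebook):
    n = codebook.shape[0]
    normed = F.normalize(codebook, dim = -1)
    cosine_sim = normed @ normed.t()
    return (cosine_sim - torch.eye(n, device = codebook.device)).pow(2).mean()

codebook = ToyCodebook(dim = 4, codebook_size = 8, learnable_codebook = True)
quantized, indices = codebook(torch.randn(16, 4))

loss = orthogonal_loss(codebook.embed)
loss.backward()    # grad flows into codebook.embed only through the regularizer

With learnable_codebook = False the same module keeps its codebook as a buffer, which matches the default path in the diff when no orthogonal regularization weight is set.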
