Commit

address #26 (comment)
lucidrains committed May 13, 2024
1 parent 3edf4dd commit b38a4c2
Showing 2 changed files with 18 additions and 3 deletions.
pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.22"
+version = "1.14.23"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
vector_quantize_pytorch/vector_quantize_pytorch.py: 19 changes (17 additions & 2 deletions)
@@ -26,6 +26,15 @@ def identity(t):
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+    if len(modules) == 0:
+        return None
+    elif len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 def cdist(x, y):
     x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
     y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
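For reference, the Sequential helper added above drops None entries before building the container, so optional layers can be passed in conditionally. A minimal sketch of that behavior, assuming the helper is imported from this module (the dimensions are only illustrative):

import torch.nn as nn
from vector_quantize_pytorch.vector_quantize_pytorch import Sequential

# the None slot is filtered out, and a single surviving module is returned as-is
proj = Sequential(nn.Linear(256, 16), None)
assert isinstance(proj, nn.Linear)

# with two surviving modules, an ordinary nn.Sequential is returned
proj = Sequential(nn.Linear(256, 16), nn.LayerNorm(16))
assert isinstance(proj, nn.Sequential)

# with nothing surviving, None is returned
assert Sequential(None, None) is None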
@@ -702,6 +711,7 @@ def __init__(
         kmeans_iters = 10,
         sync_kmeans = True,
         use_cosine_sim = False,
+        layernorm_after_project_in = False, # proposed by @SaltyChtao here https://github.com/lucidrains/vector-quantize-pytorch/issues/26#issuecomment-1324711561
         threshold_ema_dead_code = 0,
         channel_last = True,
         accept_image_fmap = False,
@@ -721,7 +731,7 @@ def __init__(
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
         affine_param_batch_decay = 0.99,
-        affine_param_codebook_decay = 0.9,
+        affine_param_codebook_decay = 0.9,
         sync_update_v = 0. # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
     ):
         super().__init__()
@@ -733,7 +743,12 @@ def __init__(
         codebook_input_dim = codebook_dim * heads
 
         requires_projection = codebook_input_dim != dim
-        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
+
+        self.project_in = Sequential(
+            nn.Linear(dim, codebook_input_dim),
+            nn.LayerNorm(codebook_input_dim) if layernorm_after_project_in else None
+        ) if requires_projection else nn.Identity()
+
         self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
 
         self.has_projections = requires_projection
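The new flag only takes effect when an input projection exists (codebook_dim * heads != dim), since project_in otherwise stays nn.Identity(). A minimal usage sketch; the dimensions are illustrative and the top-level import follows the package's usual usage rather than anything in this commit:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    codebook_dim = 16,                   # differs from dim, so project_in / project_out are created
    layernorm_after_project_in = True    # inserts nn.LayerNorm after the project_in nn.Linear
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # quantized: (1, 1024, 256), indices: (1, 1024)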
