Zero grad in residual vq #25
Comments
Thanks for bringing that up! It seems like this won't impact the Straight-Through-Estimator gradient for the encoder, but it will kill the commitment loss for all residual VQs but the first one, right?
It seems so. But I'm not sure how much it affects the final result.
I'm a bit reluctant to introduce a change we haven't tested in this codebase, as it could change the best hyperparameters, etc. I can add a warning, however, if the model is used in training mode, pointing to this issue.
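For reference, a minimal sketch of what such a training-mode warning could look like; the wrapper class and its interface below are hypothetical illustrations, not the actual encodec code:

```python
import warnings

import torch
from torch import nn


class ResidualVQWithWarning(nn.Module):
    """Hypothetical wrapper that warns once when the quantizer runs in training mode."""

    def __init__(self, quantizer: nn.Module):
        super().__init__()
        self.quantizer = quantizer
        self._warned = False

    def forward(self, x: torch.Tensor):
        if self.training and not self._warned:
            warnings.warn(
                "Residual VQ training is affected by the zero-gradient issue "
                "discussed in this thread (the commitment loss of all but the "
                "first quantizer receives no gradient).",
                stacklevel=2,
            )
            self._warned = True
        return self.quantizer(x)
```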
@adefossez @npuichigo Could you please explain in more detail why "this won't impact the Straight-Through-Estimator gradient for the Encoder"? I think that if the residual is computed in a way that doesn't pass its real gradients, the gradient estimator may also be affected. The following code snippet may illustrate this:

```python
import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff ** 2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)               # quantize x with the first codebook
q1 = x + (q1 - x).detach()                # transplant q1's gradient to x
residual = x - q1                         # detaching q1 or not makes a difference; compute the residual for the next quantizer
q2 = quantize(residual, codebook2)        # quantize the residual with the second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 0 * q1.sum() + 1 * q2.sum()        # loss is a function of q1 and q2; here it is independent of q1
loss.backward()
print(x.grad)
```

The printed gradient is all zero, but if we replace `residual = x - q1` with `residual = x - q1.detach()`, the gradient is no longer zero.
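For concreteness, a runnable variant of the snippet above with only the residual line changed; this is a sketch of the suggested change, not code from either repository, and `x.grad` now comes out non-zero:

```python
import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff ** 2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)
q1 = x + (q1 - x).detach()
residual = x - q1.detach()                # detach q1: the residual keeps its gradient w.r.t. x
q2 = quantize(residual, codebook2)
q2 = residual + (q2 - residual).detach()

loss = 0 * q1.sum() + 1 * q2.sum()
loss.backward()
print(x.grad)                             # all ones instead of all zeros
```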
Why did you put `0 * q1.sum()` in the loss?
Oh, I think I over-complicated the problem here. In the model, all the quantization outputs are summed into the final output, so the loss is not actually independent of `q1`, and the straight-through gradient still reaches the encoder. Still, if we replace `residual = x - q1` with `residual = x - q1.detach()`, the commitment loss of the later quantizers would get a proper gradient.
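To make that point concrete, here is a self-contained variant of the same toy example (again a sketch, not code from either repository) that keeps the buggy `residual = x - q1` but puts the loss on the sum of the quantized outputs, as a residual VQ's summed output implies; the straight-through gradient still reaches `x` through `q1`, and only the later path carries none:

```python
import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff ** 2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)
q1 = x + (q1 - x).detach()
residual = x - q1                         # buggy residual: zero gradient w.r.t. x
q2 = quantize(residual, codebook2)
q2 = residual + (q2 - residual).detach()

loss = (q1 + q2).sum()                    # loss on the summed quantized outputs
loss.backward()
print(x.grad)                             # all ones: x still receives a straight-through gradient via q1
```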
🐛 Bug Report
Zero grad in the second residual VQ, as mentioned here (lucidrains/vector-quantize-pytorch#33).
encodec/encodec/quantization/core_vq.py, line 336 at commit 1943298
The fix is in lucidrains/vector-quantize-pytorch@ecf2f7c.
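A hedged sketch of what the fixed residual loop looks like; the per-layer call signature below is assumed for illustration rather than copied from either repository, but the key change is subtracting the detached quantized output when updating the running residual:

```python
import torch
from torch import nn


def residual_vq_forward(layers: nn.ModuleList, x: torch.Tensor):
    """Sketch of a residual VQ forward pass with the detached-residual fix."""
    quantized_out = torch.zeros_like(x)
    residual = x
    all_losses, all_indices = [], []
    for layer in layers:
        # Assumed per-layer interface: returns (quantized, codebook indices, commitment loss).
        quantized, indices, loss = layer(residual)
        residual = residual - quantized.detach()  # fix: detach so later layers keep gradients to x
        quantized_out = quantized_out + quantized
        all_indices.append(indices)
        all_losses.append(loss)
    return quantized_out, all_indices, all_losses
```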