
Zero grad in residual vq #25

Open
npuichigo opened this issue Dec 26, 2022 · 8 comments
Labels
bug Something isn't working

Comments

@npuichigo

🐛 Bug Report

Zero gradient in the second residual VQ, as mentioned here (lucidrains/vector-quantize-pytorch#33), caused by this line:

residual = residual - quantized

The fix is in lucidrains/vector-quantize-pytorch@ecf2f7c.
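A minimal sketch of the symptom, assuming a simplified residual quantizer with nearest-neighbour lookup and a straight-through estimator (the helpers `nearest_code`, `residual_vq`, `second_stage_grad` and the `detach_residual` flag are hypothetical, not the API of EnCodec or vector-quantize-pytorch). It backpropagates only from the second stage's output and compares the update `residual = residual - quantized` with a variant that detaches `quantized` before subtracting:

```python
import torch

def nearest_code(x, codebook):
    # x: (batch, dim), codebook: (n_code, dim) -> nearest codeword per row
    idx = torch.cdist(x, codebook).argmin(dim=1)  # (batch,)
    return codebook[idx]                           # (batch, dim)

def residual_vq(x, codebooks, detach_residual=False):
    """Return the straight-through output of every stage (hypothetical helper).

    detach_residual=False mirrors `residual = residual - quantized`;
    detach_residual=True mirrors `residual = residual - quantized.detach()`.
    """
    residual = x
    outputs = []
    for codebook in codebooks:
        code = nearest_code(residual, codebook)
        # straight-through estimator: forward value is `code`, gradient goes to `residual`
        quantized = residual + (code - residual).detach()
        outputs.append(quantized)
        if detach_residual:
            residual = residual - quantized.detach()
        else:
            residual = residual - quantized
    return outputs

def second_stage_grad(detach_residual, seed=0):
    torch.manual_seed(seed)
    codebooks = [torch.randn(16, 4) for _ in range(3)]
    x = torch.randn(2, 4, requires_grad=True)
    outputs = residual_vq(x, codebooks, detach_residual=detach_residual)
    outputs[1].sum().backward()  # gradient coming only from the second stage
    return x.grad

print(second_stage_grad(False))  # all zeros
print(second_stage_grad(True))   # all ones
```

With the original update, the second stage's input is `x - (x + const)`, whose derivative with respect to `x` cancels exactly to zero.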

@npuichigo npuichigo added the bug Something isn't working label Dec 26, 2022
@npuichigo
Author

@adefossez

@adefossez
Contributor

Thanks for bringing that up!

It seems like this won't impact the straight-through estimator (STE) gradient for the encoder, but it will kill the commitment loss for all residual VQ layers except the first one, right?
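A sketch of the commitment-loss effect, assuming each layer's commitment loss is an MSE between the layer input and its selected codeword (the exact loss form here is an assumption, not copied from core_vq.py; `nearest_code` and `commitment_grads` are hypothetical helpers). It computes the gradient of each layer's commitment loss with respect to `x`, which stands in for the encoder output:

```python
import torch
import torch.nn.functional as F

def nearest_code(x, codebook):
    # pick the closest codeword for each row of x
    return codebook[torch.cdist(x, codebook).argmin(dim=1)]

def commitment_grads(detach_residual, n_layers=3, seed=0):
    """Gradient of each layer's commitment loss w.r.t. the encoder output x."""
    torch.manual_seed(seed)
    codebooks = [torch.randn(16, 4) for _ in range(n_layers)]
    x = torch.randn(2, 4, requires_grad=True)  # stands in for the encoder output
    residual = x
    losses = []
    for codebook in codebooks:
        code = nearest_code(residual, codebook)
        # assumed commitment loss: pull the layer input towards its (constant) codeword
        losses.append(F.mse_loss(residual, code))
        quantized = residual + (code - residual).detach()  # straight-through estimator
        residual = residual - (quantized.detach() if detach_residual else quantized)
    grads = []
    for loss in losses:
        (g,) = torch.autograd.grad(loss, x, retain_graph=True)
        grads.append(g)
    return grads

for detach in (False, True):
    grads = commitment_grads(detach)
    print(detach, [g.abs().sum().item() for g in grads])
```

Under the original update only the first layer's commitment loss produces a non-zero gradient for the encoder; the straight-through path itself is untouched.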

@npuichigo
Author

It seems so. But I'm not sure how much it affects the final result.

@adefossez
Contributor

I'm a bit reluctant to introduce a change we haven't tested in this codebase, as it could change the best hyperparameters, etc. However, I can add a warning pointing to this issue when the model is used in training mode.
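One way such a warning could look, as a hypothetical sketch only (the wrapper class `ResidualVQWithWarning` and its message are illustrative and not the content of the commit that references this issue): it warns once, and only while the module is in training mode.

```python
import warnings

import torch
from torch import nn

class ResidualVQWithWarning(nn.Module):
    """Hypothetical wrapper: warns once when the quantizer runs in training mode."""

    def __init__(self, rvq: nn.Module):
        super().__init__()
        self.rvq = rvq
        self._warned = False

    def forward(self, x):
        if self.training and not self._warned:
            warnings.warn(
                "Residual VQ in training mode: the commitment losses of all but the "
                "first quantizer receive no gradient (see issue #25).",
                UserWarning,
            )
            self._warned = True
        return self.rvq(x)

model = ResidualVQWithWarning(nn.Identity())  # nn.Identity stands in for the real quantizer
model.train()
model(torch.zeros(2, 4))  # emits the warning once
```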

adefossez added a commit that referenced this issue Jan 24, 2023
@cantabile-kwok

cantabile-kwok commented Apr 4, 2023

@adefossez @npuichigo Could you please explain in more detail why "this won't impact the Straight-Through-Estimator gradient for the Encoder"? I think that if the residual is computed in a way that doesn't pass its real gradients, then the gradient estimator may also be affected. The following code snippet illustrates this:

import torch
def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1  # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 0*q1.sum() + 1*q2.sum()  # loss is a function of q1 and q2, now it is independent of q1.
loss.backward()
print(x.grad)

The printed gradient is all zero, but if we replace residual = x - q1 with residual = x - q1.detach(), the gradient will be non-zero.

@adefossez
Contributor

Why did you put 0 * q1.sum()? That is what breaks the STE gradient. With the current code, d q1 / d x = Id and d q_i / d x = 0 for all i > 1, which is okay because the overall gradient d (sum q_i) / d x = Id, which is what we want. The only thing that is impacted is the commitment loss.
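The Jacobian claims above can be checked numerically. A minimal sketch, assuming three stages with random codebooks and the current update `residual = residual - quantized` (the function `stage_outputs` is a hypothetical helper); it uses `torch.autograd.functional.jacobian`:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
dim, n_q = 4, 3
codebooks = [torch.randn(16, dim) for _ in range(n_q)]

def stage_outputs(x):
    # x: (dim,); current update rule `residual = residual - quantized`
    residual = x
    outs = []
    for codebook in codebooks:
        code = codebook[((codebook - residual) ** 2).sum(dim=1).argmin()]
        quantized = residual + (code - residual).detach()  # straight-through estimator
        outs.append(quantized)
        residual = residual - quantized
    return outs

x = torch.randn(dim)
J_sum = jacobian(lambda v: sum(stage_outputs(v)), x)  # d(sum q_i)/dx
J_q2 = jacobian(lambda v: stage_outputs(v)[1], x)     # d q_2 / dx
print(J_sum)  # identity matrix
print(J_q2)   # zero matrix
```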

@cantabile-kwok

cantabile-kwok commented Apr 4, 2023

Oh, I think I over-complicated the problem here. In the model, all the quantization outputs q_i are simply summed to feed the decoder, so the relation d (sum q_i) / d x = Id keeps the STE working. In my code snippet, I assumed the loss could be an arbitrary function of q1 and q2. In that case, the gradient from q2 never reaches the upstream network, which may not be good.

Still, if we replace residual = x - q1 with residual = x - q1.detach(), it seems that d (sum q_i) / d x = n * Id, where n is the number of quantizers. Thus the effective scale of the loss gradients reaching the encoder may change. Thanks for the clarification!
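A short derivation of the n * Id factor, writing sg(·) for stop-gradient (`.detach()`), c_i for the codeword selected at stage i, r_1 = x, and assuming the detached update r_{i+1} = r_i - sg(q_i):

$$
\begin{aligned}
q_i &= r_i + \mathrm{sg}(c_i - r_i), \qquad r_{i+1} = r_i - \mathrm{sg}(q_i) = x - \sum_{j \le i} \mathrm{sg}(q_j),\\
\frac{\partial r_i}{\partial x} &= I \;\Rightarrow\; \frac{\partial q_i}{\partial x} = I \;\Rightarrow\; \frac{\partial}{\partial x}\sum_{i=1}^{n} q_i = n\,I .
\end{aligned}
$$

By contrast, the undetached update gives r_{i+1} = r_i - q_i = -sg(c_i - r_i), so the derivative of r_i with respect to x is 0 for i > 1 and the summed output keeps the Jacobian I.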

@DingWeiPeng

DingWeiPeng commented Jan 9, 2024

@adefossez @cantabile-kwok

If residual = residual - quantized, then the second codebook can be updated with gradients, but its gradient cannot affect the first codebook.
If residual = residual - quantized.detach(), then the second codebook's gradient will affect the first codebook.

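In core_vq.py, the VectorQuantization class contains the following code:
[screenshot of code from the VectorQuantization class]

The ResidualVectorQuantization class currently contains the following code:
[screenshot of code from the ResidualVectorQuantization class]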

So this problem is equivalent to the following one, which this code snippet illustrates:

```python
import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1.detach()  # detaching q1 or not makes a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 1*q2.sum()  # loss depends only on q2 here
loss.backward()
print(x.grad)
```

If residual = x - q1, x.grad is all zeros.
If residual = x - q1.detach(), x.grad = tensor([[1., 1., 1., 1., 1.]]).

Dinglet pushed a commit to Dinglet/encodec that referenced this issue Jan 25, 2024
thatsvenyouknow pushed a commit to thatsvenyouknow/neuro-encodec that referenced this issue Jun 20, 2024