Thanks for your work. When we checked the code, we found that no gradient flows back through the residual for any layer after the second one. Could you please confirm?
We changed the code from: residual = residual - quantized ---> to: residual = residual - quantized.detach()
Here is the verification we ran:
if __name__ == "__main__":
    quantizer = ResidualVQ(
        num_quantizers=4,
        dim=256,
        codebook_size=16,
        kmeans_init=True,
        kmeans_iters=10,
        threshold_ema_dead_code=2,
        channel_last=False,
    )
    for i in range(4):
        input = torch.rand((2, 256, 30), requires_grad=True)
        quantized, indices, losses = quantizer(input)
        print(quantized.shape, indices.shape, losses.shape)
        losses[0, i].backward()
        print(input.grad)
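For anyone who wants to reproduce the effect without the library, here is a minimal sketch of why the subtraction can cancel the gradient. It assumes each quantizer returns its output through a straight-through estimator, as is common in VQ implementations; `toy_quantize` and `second_stage_grad` are hypothetical helpers written for this illustration (rounding to a grid stands in for the codebook lookup) and are not the repository's code.

```python
import torch


def toy_quantize(x, step):
    # Hypothetical stand-in for one VQ layer: snap values to a grid of width `step`.
    q = torch.round(x / step) * step
    # Straight-through estimator: forward value is q, backward gradient is identity w.r.t. x.
    quantized = x + (q - x).detach()
    # Commitment-style loss that pulls the layer input toward its (detached) code.
    commit_loss = ((q.detach() - x) ** 2).mean()
    return quantized, commit_loss


def second_stage_grad(detach):
    # Gradient of the second layer's loss with respect to the original input.
    x = torch.tensor([0.3, 1.2, -0.7], requires_grad=True)
    quantized, _ = toy_quantize(x, step=0.5)
    if detach:
        residual = x - quantized.detach()  # proposed change
    else:
        residual = x - quantized  # original: the identity gradients cancel
    _, loss2 = toy_quantize(residual, step=0.25)
    loss2.backward()
    return x.grad


grad_plain = second_stage_grad(detach=False)
grad_detached = second_stage_grad(detach=True)
print("without detach:", grad_plain)
print("with detach:   ", grad_detached)
```

Without the detach, the residual depends on the input once directly and once (with opposite sign) through the straight-through path, so the two contributions sum to zero. With the detach, only the direct path remains and the later layer's loss reaches the input.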
address #33
ecf2f7c
@aijianiula0601 i do believe you are correct! thank you for spotting this! 🙏