EMA update on CosineCodebook #26

Open · roomo7time opened this issue Sep 27, 2022 · 7 comments

roomo7time commented Sep 27, 2022

The original ViT-VQGAN paper does not seem to use an EMA update for codebook learning, since its codebook vectors are unit-normalized.

In particular, to my understanding, an EMA update does not quite make sense when the encoder outputs and codebook vectors are unit-normalized.

What's your take on this? Should we NOT use EMA update with CosineCodebook?
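To make the concern concrete, here is a minimal sketch (my own illustration, not this repo's actual implementation) of a standard VQ-VAE-style EMA codebook update applied to unit-normalized encoder outputs and codes. Because the EMA accumulators average unit vectors, the running mean is generally not unit-norm, so each updated code has to be re-normalized back onto the sphere, and the magnitude information the EMA statistics normally track is thrown away:

```python
import torch
import torch.nn.functional as F

codebook_size, dim, decay = 512, 64, 0.99

codes = F.normalize(torch.randn(codebook_size, dim), dim=-1)  # unit-norm codes
ema_cluster_size = torch.zeros(codebook_size)
ema_embed_sum = codes.clone()

@torch.no_grad()
def ema_update(z_e):
    """z_e: (batch, dim) raw encoder outputs."""
    global codes
    z = F.normalize(z_e, dim=-1)  # a cosine codebook also normalizes its inputs
    assign = F.one_hot((z @ codes.t()).argmax(dim=-1), codebook_size).float()
    ema_cluster_size.mul_(decay).add_(assign.sum(0), alpha=1 - decay)
    ema_embed_sum.mul_(decay).add_(assign.t() @ z, alpha=1 - decay)
    # the running mean of unit vectors is generally NOT unit-norm,
    # so the code must be re-normalized, discarding the tracked magnitudes
    mean = ema_embed_sum / ema_cluster_size.clamp(min=1e-5).unsqueeze(-1)
    codes = F.normalize(mean, dim=-1)

ema_update(torch.randn(8, dim))  # toy usage
```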


pengzhangzhi commented Oct 29, 2022

Would you like to explain why EMA does not work for a unit-normalized codebook?

Saltychtao commented

I found that when using EMA with the cosine codebook, the l2-norm of the input to the VQ module grows gradually, from 22 to 20000, leading to a growing training loss. Has anyone else met this problem?

Saltychtao commented

> I found that when using EMA with the cosine codebook, the l2-norm of the input to the VQ module grows gradually, from 22 to 20000, leading to a growing training loss. Has anyone else met this problem?

In case anyone else hits this problem: I added a LayerNorm layer after the vq_in projection, and the growing-norm problem is largely solved.
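For anyone who wants to try this, here is a rough sketch of what I mean, assuming you do the projection into the codebook dimension yourself instead of relying on the quantizer's internal project_in (the dimensions are made up, and constructor arguments may differ across versions of vector-quantize-pytorch):

```python
import torch
from torch import nn
from vector_quantize_pytorch import VectorQuantize

enc_dim, code_dim = 512, 32  # hypothetical encoder / codebook dimensions

# do the "vq_in" projection manually and bound its output with LayerNorm
project_in = nn.Sequential(
    nn.Linear(enc_dim, code_dim),
    nn.LayerNorm(code_dim),
)

vq = VectorQuantize(
    dim = code_dim,          # features are already in the codebook dimension
    codebook_size = 1024,
    use_cosine_sim = True,   # cosine codebook
    decay = 0.99,            # EMA codebook update
)

x = torch.randn(1, 196, enc_dim)  # e.g. a batch of ViT patch features
quantized, indices, commit_loss = vq(project_in(x))
```

With dim equal to the codebook dimension, the quantizer's own input projection should be effectively a no-op, so the LayerNorm sits directly in front of the codebook lookup and keeps the pre-quantization norm bounded.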


jzhang38 commented Mar 9, 2023

@Saltychtao I also encountered a similar issue. Does vq_in refer to VectorQuantize.project_in?

Saltychtao commented

> @Saltychtao I also encountered a similar issue. Does vq_in refer to VectorQuantize.project_in?

Yes.


santisy commented May 13, 2024

> I found that when using EMA with the cosine codebook, the l2-norm of the input to the VQ module grows gradually, from 22 to 20000, leading to a growing training loss. Has anyone else met this problem?
>
> In case anyone else hits this problem: I added a LayerNorm layer after the vq_in projection, and the growing-norm problem is largely solved.

@Saltychtao Hi, just to make sure: the current version of the implementation here seems to put one normalization (l2norm) after project_in. I also encounter the training-loss explosion issue with the current version.

lucidrains added a commit that referenced this issue May 13, 2024

lucidrains (Owner) commented May 13, 2024

@santisy want to try turning this on (following @Saltychtao's solution)?

Let me know if it helps.
