
Clarification on the KL Divergence term in the Generator loss for the SVHN->MNIST model #42

Closed · hsm207 opened this issue on Jan 14, 2018 · 1 comment

@hsm207 (Contributor) commented on Jan 14, 2018:

I have a question about the _compute_kl function in the COCOGANDAContextTrainer class. These are the relevant parts of the code:

  def _compute_kl(self, mu, sd):
    mu_2 = torch.pow(mu, 2)
    sd_2 = torch.pow(sd, 2)
    encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
    return encoding_loss
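
If I read the code right, with N = mu.size(0) the batch size and the sum running over all samples n and latent dimensions j, this computes (my own restatement, not from the repo):

    encoding_loss = (1/N) * sum_{n,j} ( mu_{n,j}^2 + sd_{n,j}^2 - log(sd_{n,j}^2) )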

This function is used in gen_update:

    for i, lt in enumerate(lt_codes):
      encoding_loss += 2 * self._compute_kl(*lt)
    total_loss = hyperparameters['gan_w'] * ad_loss + \
                 hyperparameters['kl_normalized_direct_w'] * encoding_loss + \
                 hyperparameters['ll_normalized_direct_w'] * (ll_loss_a + ll_loss_b)

My question is: how did you derive this formula for the KL divergence term?

I thought it was based on the Auto-Encoding Variational Bayes paper (Kingma & Welling, 2013), which includes the term -D_KL(q(z|x) || p(z)) in its variational lower bound, and whose Appendix B gives the closed-form solution when both distributions are Gaussian:

    -D_KL(q(z) || p(z)) = 1/2 * sum_j ( 1 + log(sigma_j^2) - mu_j^2 - sigma_j^2 )

I note the following differences between the code and the paper:

  1. The KL divergence term is multiplied by 2 instead of 1/2. I guess this does not matter much since it just rescales the loss.

  2. There is no -1 term in the encoding_loss. Did you choose not to include it because it does not change the optimum point anyway? (I try to verify both differences numerically in the sketch below.)
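
Here is a minimal numerical check of both points (my own sketch, not code from this repo; the helper names kl_as_in_code and kl_as_in_paper are mine, and I assume mu and sd hold per-sample means and standard deviations). It verifies that, with the extra factor of 2 from gen_update, the code's term equals 4 * D_KL plus the constant 2 * latent_dim:

    import torch

    torch.manual_seed(0)
    mu = torch.randn(8, 16)       # a batch of 8 samples, latent dim 16
    sd = torch.rand(8, 16) + 0.1  # positive standard deviations

    def kl_as_in_code(mu, sd):
        # same computation as _compute_kl above
        mu_2 = torch.pow(mu, 2)
        sd_2 = torch.pow(sd, 2)
        return (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)

    def kl_as_in_paper(mu, sd):
        # D_KL(N(mu, sd^2) || N(0, 1)) from Appendix B, averaged over the batch
        sd_2 = torch.pow(sd, 2)
        return 0.5 * (torch.pow(mu, 2) + sd_2 - torch.log(sd_2) - 1).sum() / mu.size(0)

    latent_dim = mu.size(1)
    lhs = 2 * kl_as_in_code(mu, sd)    # the term added in gen_update
    rhs = 4 * kl_as_in_paper(mu, sd) + 2 * latent_dim
    print(torch.allclose(lhs, rhs))    # prints True

If that is right, the two versions differ only by a positive scale and an additive constant, so they have the same minimizers and the same gradients up to a factor of 4.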

@mingyuliutw (Owner) commented:

Yes, those were my considerations.
