
Question about KL computation #50

Closed

SerezD opened this issue Jan 12, 2024 · 1 comment

SerezD commented Jan 12, 2024

In distributions.py, the KL is computed as indicated in section 3.2 of the paper (residual normal distributions, Equation 2):

def kl(self, normal_dist):
    # self: posterior q; normal_dist: prior p
    term1 = (self.mu - normal_dist.mu) / normal_dist.sigma
    term2 = self.sigma / normal_dist.sigma

    return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)

What I don't understand is why you compute term2 = self.sigma / normal_dist.sigma. Shouldn't it be:
term2 = self.sigma - normal_dist.sigma?

SerezD (Author) commented Jan 12, 2024

Never mind, I got it. Checking section 3.2 of the paper, paragraph "Residual Normal Distributions":

In the code, self.mu and self.sigma are the parameters of the posterior distribution, while normal_dist.mu and normal_dist.sigma (written prior.mu and prior.sigma below) are the parameters of the prior distribution.

p(z_i|z_{l<i}) is the prior, defined as N(μ_p, σ_p) where both params are conditioned on all z_{l<i}

q(z_i|z_{l<i}, x) is the distribution from the encoder (self), defined as:

q = N(μ_p + Δμ_q, σ_p * Δσ_q), where Δμ_q and Δσ_q are the relative shift and scale of the posterior with respect to the prior.

So basically, self.mu and self.sigma are the parameters of the posterior:

self.mu = μ_p + Δμ_q
self.sigma = σ_p * Δσ_q
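
To make the parameterization concrete, here is a minimal sketch (not from the repo; delta_mu and log_delta_sigma are hypothetical names for the encoder outputs):

import torch

# hypothetical encoder outputs: relative shift and log-scale
delta_mu = torch.randn(4)          # Δμ_q
log_delta_sigma = torch.randn(4)   # log Δσ_q (exponentiated to keep Δσ_q > 0)

# prior parameters (conditioned on z_{l<i} in the real model)
prior_mu, prior_sigma = torch.zeros(4), torch.ones(4)

# residual posterior: q = N(μ_p + Δμ_q, σ_p * Δσ_q)
post_mu = prior_mu + delta_mu
post_sigma = prior_sigma * torch.exp(log_delta_sigma)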

The KL divergence between two normal distributions a = N(μ_1, σ_1) and b = N(μ_2, σ_2) is given by:

KL(a || b) = 0.5 [ (μ_2 - μ_1)**2 / σ_2**2 ] + 0.5 [ σ_1**2 / σ_2**2 ] - 0.5 [ ln(σ_1**2 / σ_2**2) ] - 0.5

proof: https://statproofbook.github.io/P/norm-kl.html
In our case μ_1 = self.mu, μ_2 = prior.mu, σ_1 = self.sigma, σ_2 = prior.sigma; that is, a is the posterior q and b is the prior p, so this is KL(q || p).

So the terms in the formula above become:
1. 0.5 [ (μ_p - (μ_p + Δμ_q))**2 / σ_p**2 ] = 0.5 [ Δμ_q**2 / σ_p**2 ]
2. 0.5 [ (σ_p * Δσ_q)**2 / σ_p**2 ] = 0.5 [ Δσ_q**2 ]
3. -0.5 [ ln((σ_p * Δσ_q)**2 / σ_p**2) ] = -0.5 ln(Δσ_q**2) = -ln(Δσ_q)

plus the constant -0.5. The result is exactly Equation 2, and in the code:

Δμ_q / σ_p = (self.mu - normal_dist.mu) / normal_dist.sigma = term1
Δσ_q = self.sigma / normal_dist.sigma = term2

which is why term2 is a ratio of sigmas rather than a difference.
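
As a quick sanity check (not from the repo), the residual form can be verified against the generic closed-form KL in torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mu_p, sigma_p = torch.randn(4), torch.rand(4) + 0.5   # prior N(μ_p, σ_p)
mu_q, sigma_q = torch.randn(4), torch.rand(4) + 0.5   # posterior N(μ_q, σ_q)

# residual form, as in distributions.py
term1 = (mu_q - mu_p) / sigma_p   # Δμ_q / σ_p
term2 = sigma_q / sigma_p         # Δσ_q
residual_kl = 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)

# generic KL(q || p) between the same two Gaussians
reference_kl = kl_divergence(Normal(mu_q, sigma_q), Normal(mu_p, sigma_p))
print(torch.allclose(residual_kl, reference_kl, atol=1e-6))  # True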

SerezD closed this as completed on Jan 12, 2024.