## Approximating KL divergence 

Reference:

http://joschu.net/blog/kl-approx.html

https://towardsdatascience.com/approximating-kl-divergence-4151c8c85ddd

### Brief introduction 

Why do we want an approximation? 

* The KL divergence may not have a closed-form solution, for example between Gaussian mixture distributions. 
* The exact sum or integral can be computationally expensive. 
* It is easier to cache the old log-probabilities than the whole distribution. 

#### Approximate the forward KL divergence 

The simplest approximation of the KL divergence is a Monte Carlo average of the log-ratio over samples $x_i$ drawn from $p$.  

$$\sum_x p(x) \log \frac{p(x)}{q(x)} \approx \frac{1}{N}\sum_{i=1}^N \log \frac{p(x_i)}{q(x_i)}, \quad x_i \sim p$$

When $N=1$, the estimate reduces to $\log\frac{p(x_i)}{q(x_i)} = \log r$, where $r = p(x_i)/q(x_i)$. 
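The averaging above can be sketched numerically. This is a minimal illustration, assuming the two unit-variance Gaussians used later in this note; the helper names `normal_logpdf` and `naive_forward_kl` are hypothetical and the seed is arbitrary.

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2), written out to avoid extra dependencies."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def naive_forward_kl(mu_p, mu_q, sigma, n, seed=0):
    """Monte Carlo estimate of KL(p || q) from n samples x_i ~ p."""
    rng = np.random.default_rng(seed)
    x = rng.normal(mu_p, sigma, size=n)  # samples are drawn from p
    log_r = normal_logpdf(x, mu_p, sigma) - normal_logpdf(x, mu_q, sigma)
    return log_r.mean()                  # average of log p(x_i)/q(x_i)

# Closed form for two Gaussians that share the same standard deviation.
kl_exact = (0.0 - 0.1) ** 2 / (2 * 1.0**2)
kl_naive = naive_forward_kl(0.0, 0.1, 1.0, n=1_000_000)
print(kl_exact, kl_naive)
```

The two printed numbers should agree up to Monte Carlo noise, which shrinks as `n` grows.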

Next, we can apply importance sampling and rewrite the sum as an expectation under $q$,

$$
\begin{aligned}
\sum_x p(x) \log \frac{p(x)}{q(x)} = &\sum_x q(x)\frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)}\\
=& \sum_x q(x) r\log r \\
\approx& r_i \log r_i, \quad x_i \sim q
\end{aligned}
$$
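The importance-sampling identity above can be checked with a short sketch. It assumes the same pair of unit-variance Gaussians as the example later in this note, draws the samples from $q$ rather than $p$, and also reports how often a single-sample value $r_i \log r_i$ is negative. The helper `normal_logpdf` and the variable names are illustrative.

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

rng = np.random.default_rng(0)
mu_p, mu_q, sigma = 0.0, 0.1, 1.0

x = rng.normal(mu_q, sigma, size=1_000_000)  # x_i ~ q, not p
log_r = normal_logpdf(x, mu_p, sigma) - normal_logpdf(x, mu_q, sigma)
per_sample = np.exp(log_r) * log_r           # r_i * log r_i

kl_exact = (mu_p - mu_q) ** 2 / (2 * sigma**2)  # closed form, equal variances
kl_is = per_sample.mean()                       # importance-sampling estimate
frac_negative = (per_sample < 0).mean()         # share of negative samples
print(kl_exact, kl_is, frac_negative)
```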

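However, this approximation has high variance. A single-sample value is also negative whenever $p(x_i) < q(x_i)$, which happens for a large share of the samples, even though the KL divergence is always non-negative. One way to prevent negative values is to add a control variate: a term that has zero mean under $q$, here $\lambda(r-1)$, and is negatively correlated with the estimator. 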

$$
\begin{aligned}
\sum_x p(x) \log \frac{p(x)}{q(x)} =& \sum_x q(x) r\log r \\
=& \sum_x q(x) \left[ r\log r + \lambda(r-1) \right]
\end{aligned}
$$

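This is a little bit like adding the constraint $\sum_x q(x)(r - 1) = 0$ into a Lagrangian with multiplier $\lambda$: the extra term has zero mean under $q$, so it leaves the expectation unchanged for any $\lambda$. Since $\log x \leq x - 1$, applying the inequality to $x = 1/r$ gives $r\log r \geq r - 1$, so choosing $\lambda = -1$ ensures every per-sample value is $\geq 0$. Averaging over samples $x_i \sim q$, we can write, 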

$$
\begin{aligned}
\sum_x p(x) \log \frac{p(x)}{q(x)} =& \sum_x q(x) r\log r \\
\approx& \frac{1}{N}\sum_{i=1}^N \left[ r_i\log r_i + \lambda(r_i-1) \right], \quad x_i \sim q
\end{aligned} 
$$

When the samples $x_i$ are drawn from $q(x)$, the weight $q(x)$ is absorbed by the sampling, so each sample contributes $r_i\log r_i + \lambda(r_i-1)$.
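The effect of the coefficient $\lambda$ can be explored with a small sketch. It assumes samples from $q$ and the same two unit-variance Gaussians as the example below, and compares $\lambda \in \{0, +1, -1\}$; the `stats` dictionary and the seed are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
mu_p, mu_q, sigma = 0.0, 0.1, 1.0
x = rng.normal(mu_q, sigma, size=1_000_000)  # x_i ~ q

# With equal variances the normalising constants cancel in log r = log p - log q.
log_r = ((x - mu_q) ** 2 - (x - mu_p) ** 2) / (2 * sigma**2)
r = np.exp(log_r)

kl_exact = (mu_p - mu_q) ** 2 / (2 * sigma**2)

stats = {}
for lam in (0.0, 1.0, -1.0):
    k = r * log_r + lam * (r - 1.0)          # per-sample estimator
    stats[lam] = (k.mean(), k.std(), k.min())
    print(f"lambda={lam:+.0f}  mean={k.mean():.5f}  std={k.std():.5f}  min={k.min():.5f}")
```

All three choices target the same expectation because $r - 1$ has zero mean under $q$; only the spread and the sign of the individual samples change with $\lambda$.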



### A simple example 

In [2]:
import numpy as np
from scipy.stats import norm 

import matplotlib.pyplot as plt 
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [12]:
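eps_ = 1e-12
mu1, mu2, sig = 0, .1, 1
p = norm(loc=mu1, scale=sig)   # p = N(0, 1)
q = norm(loc=mu2, scale=sig)   # q = N(0.1, 1)
x = q.rvs(size=(10_000_000,))  # samples are drawn from q
# Closed-form KL(p || q) between two Gaussians (both have scale sig here)
KL_true = np.log(sig/sig+eps_) + (sig**2 + (mu1-mu2)**2)/(2*sig**2) - 1/2
print(KL_true)

logr = p.logpdf(x) - q.logpdf(x)  # log r = log p(x) - log q(x)
r    = np.exp(logr)
k    = r*logr + (r-1)   # r*log(r) + lambda*(r-1) with lambda = +1
k2   = (r-1) - logr     # non-negative estimator of KL(q || p), which equals KL(p || q) in this example

# "Var" below is the standard deviation divided by KL_true, not the variance
print(f'Bias: {(k.mean() - KL_true)/KL_true}, Var: {k.std()/KL_true}')
print(f'Bias: {(k2.mean() - KL_true)/KL_true}, Var: {k2.std()/KL_true}')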

0.005000000001000093
Bias: -0.00100530397138113, Var: 40.33597307511491
Bias: 0.00044284607723210385, Var: 1.4167130974890887


In [14]:
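import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))  # samples are drawn from q
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)  # log r = log p(x) - log q(x)
k1 = -logr  # naive estimator of KL(q || p), which equals KL(p || q) in this example
# The extra factor q(x) is not needed when x ~ q; it accounts for the large bias of k2 in the output
k2 = (logr.exp()*logr + (logr.exp()-1)) * q.log_prob(x).exp()
k3 = (logr.exp() - 1) - logr  # non-negative estimator of KL(q || p)
for k in (k1, k2, k3):
    # relative bias, then standard deviation relative to the true KL
    print((k.mean() - truekl) / truekl, k.std() / truekl)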


true tensor(0.0050)
tensor(0.0127) tensor(20.0023)
tensor(-1.1449) tensor(6.9880)
tensor(0.0003) tensor(1.4177)
