# f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization

In the [last note](Week%201.%20Generative%20Adversarial%20Networks.ipynb), we saw how to learn a distribution $p_g$ that approximates some unknown distribution $p_d$. We did this by introducing two neural networks: the generator and the discriminator. The two networks played a game in which the discriminator had to distinguish between real data sampled from $p_d$ and fake data produced by the generator. The generator, on the other hand, had to fool the discriminator by transforming random noise $z_i \sim \mathcal{N}(0, 1)$ into something indistinguishable from the real data.

In essence, fooling the discriminator corresponds to minimizing the "distance" between $p_d$ and $p_g$ for some notion of distance. We presented the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) as one such measure and argued that we were not able to calculate this measure for several reasons. The divergence is defined as

$$
D_{KL}(P || Q) = \int_\mathcal{X} p(x) \log \frac{p(x)}{q(x)} dx \quad \quad (1)
$$

over the domain $\mathcal{X}$ for continuous probability density functions $p$ and $q$. As we could not calculate this directly, we introduced the discriminator in order to be able to optimize the generator. More formally, the discriminator and the generator played the following game:

$$
\min_g \max_d\; \mathbb{E}_{x\sim p_d}[\log d(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - d(g(z)))]. \quad \quad (2)
$$

When playing this game, [[3]](#References) showed that under certain constraints, the optimal
discriminator $d^\ast_g(x)$ is the following

$$
d^\ast_g(x) = \frac{p_d(x)}{p_d(x) + p_g(x)}. \quad \quad (3)
$$

Furthermore, it turned out that playing this game would approximately minimize the [Jensen-Shannon divergence](https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence)

$$
D_{JS}(P || Q) = \frac{1}{2} D_{KL}\left(P \,||\, \tfrac{1}{2}(P + Q)\right) + \frac{1}{2} D_{KL}\left(Q \,||\, \tfrac{1}{2}(P + Q)\right). \quad \quad (4)
$$
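As a sanity check of Eqns. (1) and (4), the following sketch approximates both divergences with a Riemann sum for two Gaussians. The helper names (`normal_pdf`, `kl_divergence`, `js_divergence`), the choice of $N(4, 1)$ and $N(6, 1)$, and the grid are illustrative assumptions and not part of the original notes. The JS-divergence is computed against the mixture $\frac{1}{2}(P + Q)$.

```python
import math

def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def kl_divergence(p, q, xs, dx):
    """Riemann-sum approximation of D_KL(P || Q) on the grid xs."""
    total = 0.0
    for x in xs:
        px, qx = p(x), q(x)
        if px > 0.0:  # the integrand is taken to be 0 where p(x) = 0
            total += px * math.log(px / qx) * dx
    return total

def js_divergence(p, q, xs, dx):
    """JS-divergence as the average KL-divergence to the mixture (P + Q) / 2."""
    m = lambda x: 0.5 * (p(x) + q(x))
    return 0.5 * kl_divergence(p, m, xs, dx) + 0.5 * kl_divergence(q, m, xs, dx)

dx = 0.01
xs = [-6.0 + k * dx for k in range(2201)]  # grid on [-6, 16]
p = lambda x: normal_pdf(x, 4.0, 1.0)      # assumed example density
q = lambda x: normal_pdf(x, 6.0, 1.0)      # assumed example density

kl_pq = kl_divergence(p, q, xs, dx)
js_pq = js_divergence(p, q, xs, dx)
print(kl_pq)  # close to the closed form (4 - 6)^2 / 2 = 2 for equal variances
print(js_pq)  # symmetric in p and q and bounded above by log(2)
```

Unlike the KL-divergence, the JS-divergence stays bounded even when the two distributions barely overlap.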

The goal of this note is to generalize Eqn. (4) to a broader class of divergences. Conceptually, this will allow us to "change the cost function" (which is essentially Eqn. (2)). It will also generalize Eqn. (3) such that the optimal discriminator will get different forms according to which divergence the GAN setup is minimizing.

A simple motivation for why we may want to change the training objective is the gif in Figure 1, which is very similar to the one we saw in the previous note. The only difference is that the real data is now generated from a bimodal distribution. The figure shows how the learned distribution $p_g$ (which is now underspecified, since it can only represent "simple" unimodal distributions) may learn to simply span the two modes of the data. It may be that we would rather want the generator to focus on just one of the modes of the true distribution. In that case, we could hope that a change in training objective would favour such a property.

<div style="text-align: center;">
    <img src="gifs/dual-mode/gif.gif">
    <strong>Figure 1</strong> Animation of how $p_g$ converges towards $p_d$.
</div>


## Generalizing GANs to f-divergences

As stated above, the goal of this note is to generalize GANs to minimize a broader class of divergences. Specifically, we will show that we can generalize GANs to minimize any [f-divergence](https://en.wikipedia.org/wiki/F-divergence). This result was shown in [[1]](#References) and will be reproduced here. In order to do so, we need three ingredients.

1. Not surprisingly, the first ingredient is the f-divergence. 
2. The second ingredient is the [convex conjugate](https://en.wikipedia.org/wiki/Convex_conjugate) of convex lower semi-continuous functions.
3. [Jensen's inequality](https://en.wikipedia.org/wiki/Jensen%27s_inequality) is the third and final ingredient.

The structure of the following sections is as follows. First, we describe what the three ingredients are and how we will use them. Second, we will use the ingredients to show a lower bound on the f-divergence, which is very close to the GAN objective in (2). Finally, we will discuss how to use this lower bound to minimize different f-divergences.  

### f-divergences
Until now, we have seen two divergences. First, we saw the KL-divergence (Eqn. (1)) and second, we saw the JS-divergence (Eqn. (4)). It turns out that these divergences are both part of a broader class of divergences called the f-divergences. An f-divergence is characterized by a convex and continuous function $f$ with $f(1) = 0$. The f-divergence is defined as follows

$$
D_f(P||Q) = \int_\mathcal{X} q(x) f\left( \frac{p(x)}{q(x)} \right) dx \quad \quad (5)
$$

Let's verify that the KL-divergence is actually a part of this family. Let $f(u) = u \log u$ and plug it into Eqn. (5):

$$
\begin{align}
D_f(P || Q) &= \int q(x) \left( \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} \right) dx\\
&= \int p(x) \log \frac{p(x)}{q(x)} dx \\
&= D_{KL}(P || Q)
\end{align}
$$

Similarly, if we let $f(u) = u \log u - (u + 1) \log (u + 1)$, we recover the JS-divergence (up to a scaling and an additive constant). For more divergences, please see Table 1 further down the note.
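To make Eqn. (5) concrete, here is a minimal sketch that evaluates an f-divergence numerically for an arbitrary generator function $f$. The function names, the two example Gaussians, and the grid are assumptions made for illustration only. For two Gaussians with equal variance, the KL-divergence has the closed form $(\mu_p - \mu_q)^2 / (2\sigma^2)$, which gives something to compare against.

```python
import math

def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def f_divergence(f, p, q, xs, dx):
    """Riemann-sum approximation of D_f(P || Q) = int q(x) f(p(x) / q(x)) dx."""
    return sum(q(x) * f(p(x) / q(x)) * dx for x in xs)

def f_kl(u):
    """Generator of the KL-divergence."""
    return u * math.log(u)

def f_reverse_kl(u):
    """Generator of the reverse KL-divergence."""
    return -math.log(u)

dx = 0.01
xs = [-6.0 + k * dx for k in range(2201)]  # grid on [-6, 16]
p = lambda x: normal_pdf(x, 4.0, 1.0)      # assumed example density
q = lambda x: normal_pdf(x, 6.0, 1.0)      # assumed example density

d_kl = f_divergence(f_kl, p, q, xs, dx)
d_rkl = f_divergence(f_reverse_kl, p, q, xs, dx)
d_same = f_divergence(f_kl, p, p, xs, dx)
print(d_kl)    # close to (4 - 6)^2 / 2 = 2
print(d_rkl)   # also close to 2, since the variances are equal
print(d_same)  # 0, because f(1) = 0
```

Swapping the divergence only requires swapping `f`, which is the property the rest of the note builds on.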

### The Convex Conjugate
The convex conjugate is defined for convex lower semi-continuous functions, but in the f-GAN setting you can simply think of these as convex and continuous functions. For any such function $f$, the _convex conjugate_ $f^\ast$ is defined as

$$
f^\ast(t) = \sup_{u \in \mathrm{dom}_f} \{ut - f(u)\}.
$$

In the definition of $f^\ast$, we use the [supremum](https://en.wikipedia.org/wiki/Infimum_and_supremum). Intuitively, this can be thought of as the maximum. $f^\ast$ has two nice properties. 

1. $f^\ast$ itself is convex and continuous, so it has a convex conjugate $f^{\ast\ast}$
2. $f$ and $f^\ast$ are each other's dual, which means that $f = f^{\ast\ast}$

Note how 2. allows us to reformulate $f$ in terms of its convex conjugate:

$$
f(u) = \sup_{t \in \mathrm{dom}_{f^\ast}} \left( ut - f^\ast(t)\right). \quad \quad (6)
$$

We will use these properties in a moment.
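The definition of the conjugate and the duality $f = f^{\ast\ast}$ can be checked numerically. The sketch below uses the KL generator $f(u) = u \log u$, whose conjugate has the closed form $f^\ast(t) = e^{t - 1}$, and replaces the supremum by a maximum over a finite grid. The helper `conjugate` and the grids are assumptions made for illustration.

```python
import math

def f(u):
    """Generator of the KL-divergence."""
    return u * math.log(u)

def f_star_exact(t):
    """Closed-form convex conjugate of u log u."""
    return math.exp(t - 1.0)

def conjugate(func, s, grid):
    """Approximate sup_v {v * s - func(v)} by a maximum over a finite grid."""
    return max(v * s - func(v) for v in grid)

u_grid = [0.001 * k for k in range(1, 20001)]         # u in (0, 20]
t_grid = [-5.0 + 0.001 * k for k in range(0, 10001)]  # t in [-5, 5]

t = 1.5
f_star_approx = conjugate(f, t, u_grid)

u = 2.0
f_biconj_approx = conjugate(f_star_exact, u, t_grid)  # should recover f(u)

print(f_star_approx, f_star_exact(t))
print(f_biconj_approx, f(u))
```

The grid maximum can only underestimate the true supremum, so the approximations approach the exact values from below as the grid is refined.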

### Jensen's Inequality
This inequality is very simple but also very powerful. It says that for any convex function $f$, the following inequality holds:

$$
\mathbb{E} [ f(x) ] \geq f( \mathbb{E} x )
$$

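We will use this inequality to lower bound f-divergences. The relevant convex operation is the supremum in Eqn. (6): the expectation of a supremum is at least as large as the supremum of the expectations.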
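A small numerical illustration of the inequality, first in its classic form with the convex function $\exp$, and then in the "expectation of a supremum versus supremum of expectations" form that the lower bound relies on. The sample size, the function $g(x, t) = xt - t^2$, and the candidate values of $t$ are arbitrary choices made for this sketch.

```python
import math
import random

rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(10000)]

def mean(values):
    """Empirical expectation of a list of numbers."""
    return sum(values) / len(values)

# Classic form: E[f(x)] >= f(E[x]) for the convex function f = exp.
lhs = mean([math.exp(x) for x in xs])
rhs = math.exp(mean(xs))

# Supremum form: E[max_t g(x, t)] >= max_t E[g(x, t)].
ts = [-1.0, -0.5, 0.0, 0.5, 1.0]
mean_of_sup = mean([max(x * t - t * t for t in ts) for x in xs])
sup_of_mean = max(mean([x * t - t * t for x in xs]) for t in ts)

print(lhs, rhs)
print(mean_of_sup, sup_of_mean)
```

Both inequalities hold exactly for the empirical distribution of the samples, not just in the limit of many samples.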

### Lower bounding f-divergences
Let's consider the f-divergence and use the ingredients above to derive a lower bound on the divergence.

$$
\begin{align}
D_f (P||Q) &= \int_\mathcal{X} q(x) f\left( \frac{p(x)}{q(x)} \right) dx \\
&= \int_\mathcal{X} q(x) \sup_{t \in \mathrm{dom}_{f^\ast}} \left\{ \frac{p(x)}{q(x)} t - f^\ast(t)\right\} dx\\
&= \mathbb{E}_{x \sim q} \left[ \sup_{t \in \mathrm{dom}_{f^\ast}} \left\{ \frac{p(x)}{q(x)} t - f^\ast(t)\right\} \right]\\
&\geq \sup_{D \in \mathcal{D}} \left\{ \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} D(x) - f^\ast(D(x))\right] \right\} \\
&= \sup_{D \in \mathcal{D}} \left\{ \int_\mathcal{X} q(x) \frac{p(x)}{q(x)} D(x) dx -\mathbb{E}_{x \sim q} \left[ f^\ast(D(x))\right] \right\}\\
&= \sup_{D \in \mathcal{D}} \left\{ \mathbb{E}_{x \sim p} \left[D(x)\right] -\mathbb{E}_{x \sim q} \left[ f^\ast(D(x))\right] \right\} \quad \quad (7)
\end{align}
$$

Here $\mathcal{D}$ is any class of functions $D: \mathcal{X} \to \mathrm{dom}_{f^\ast}$. The first step uses the dual formulation of $f$ (Eqn. (6)), and the second step uses the definition of an expectation. The inequality arises when the supremum is moved outside the expectation. Inside the expectation, the supremum is taken separately for every $x$, so the best $t$ is allowed to depend on $x$, i.e. it is a function $D(x)$. Moving the supremum outside (the step that [[1]](#References) attributes to Jensen's inequality) and restricting $D$ to the class $\mathcal{D}$ can only decrease the value. Afterwards, the $q(x)$s cancel in the first term, and finally we use the definition of an expectation again.

This is a lot of math, but the takeaway is that we now have a lower bound on the f-divergence that looks a lot like the original GAN setup in Eqn. (2) if we let $D(x)$ be our discriminator and think of $q$ as the distribution of our generator.

As a side note, it turns out that under certain mild conditions [[2]](#References), the bound above is tight when

$$
D_{opt}(x) = f'\left( \frac{p(x)}{q(x)} \right). 
$$

This means that if $D(x) = f'\left(\frac{p(x)}{q(x)}\right)$, then the f-divergence is exactly equal to the last line in (7).
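The tightness claim can be checked numerically for the KL-divergence, where $f(u) = u \log u$, $f'(u) = 1 + \log u$ and $f^\ast(t) = e^{t - 1}$. The sketch below evaluates the bound $\mathbb{E}_{x \sim p}[D(x)] - \mathbb{E}_{x \sim q}[f^\ast(D(x))]$ on a grid for the optimal $D$ and for two deliberately suboptimal choices. The densities, the grid, and the names `lower_bound`, `d_opt`, `d_shifted` and `d_zero` are assumptions made for illustration.

```python
import math

def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

dx = 0.01
xs = [-6.0 + k * dx for k in range(2201)]  # grid on [-6, 16]
p = lambda x: normal_pdf(x, 4.0, 1.0)      # assumed example density
q = lambda x: normal_pdf(x, 6.0, 1.0)      # assumed example density

def f_prime(u):
    """Derivative of the KL generator u log u."""
    return 1.0 + math.log(u)

def f_star(t):
    """Convex conjugate of the KL generator."""
    return math.exp(t - 1.0)

def lower_bound(d):
    """Riemann-sum approximation of E_{x~p}[d(x)] - E_{x~q}[f*(d(x))]."""
    return sum((p(x) * d(x) - q(x) * f_star(d(x))) * dx for x in xs)

d_opt = lambda x: f_prime(p(x) / q(x))  # optimal discriminator
d_shifted = lambda x: d_opt(x) + 0.5    # suboptimal: shifted by a constant
d_zero = lambda x: 0.0                  # suboptimal: ignores its input

kl = sum(p(x) * math.log(p(x) / q(x)) * dx for x in xs)
print(kl, lower_bound(d_opt))  # the two values agree
print(lower_bound(d_shifted))  # strictly smaller than kl
print(lower_bound(d_zero))     # strictly smaller than kl
```

Any other choice of $D$ gives a smaller value, which is why maximizing over the discriminator tightens the bound.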

Also, we should note that this result was originally presented in [[2]](#References) along with a way to estimate the actual value of any f-divergence (called Variational Divergence Estimation). [[1]](#References), on the other hand, extends this to actually minimize the f-divergences by the use of GANs. We shall see how in the following section.

## The actual GAN
To minimize the f-divergence, we use the same trick as we did in the last note. We let $D_\omega(x)$ in Eqn. (7) be a neural network parameterized by $\omega$ and let the distribution $q$ be defined by another neural network $G_\theta$ applied to some random noise from some simple distribution $Z$. In that case, we get the following objective:

$$
F(\omega, \theta) = \mathbb{E}_{x\sim p_d}[{D_\omega}(x)] - \mathbb{E}_{x\sim p_g} [f^\ast({D_\omega}(x))] \quad = \quad \mathbb{E}_{x\sim p_d}[{D_\omega}(x)] - \mathbb{E}_{z \sim p_z} [f^\ast({D_\omega}({G_\theta}(z)))]
$$

The important thing to notice here is that we now have a training objective for a GAN for each and every f-divergence. We alternate between a gradient ascent step on $F(\omega, \theta)$ with respect to $\omega$ (the discriminator maximizes the bound) and a gradient descent step with respect to $\theta$ (the generator minimizes it). In [[1]](#References) they present an algorithm for training these networks, which is very similar to the original algorithm and therefore not very interesting. Furthermore, they show that, under certain smoothness assumptions about the norm of the gradient of $F(\omega, \theta)$, their algorithm converges to a saddle point.
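In practice, the expectations in $F(\omega, \theta)$ are replaced by sample averages. The following sketch shows such a Monte Carlo estimate for the KL-divergence, using the conjugate $f^\ast(t) = e^{t - 1}$. The affine `discriminator` and `generator` are hypothetical stand-ins for the neural networks $D_\omega$ and $G_\theta$, and the data distribution is the bimodal mixture of $N(4, 1)$ and $N(8, 1)$ used in the code section of this note.

```python
import math
import random

def f_star_kl(t):
    """Convex conjugate of the KL generator u log u."""
    return math.exp(t - 1.0)

def sample_real(rng):
    """One draw from the bimodal data distribution 0.5 N(4, 1) + 0.5 N(8, 1)."""
    mu = 4.0 if rng.random() < 0.5 else 8.0
    return rng.gauss(mu, 1.0)

def estimate_objective(discriminator, generator, f_star, rng, n=1000):
    """Monte Carlo estimate of F = E_{x~p_d}[D(x)] - E_{z~p_z}[f*(D(G(z)))]."""
    real_term = sum(discriminator(sample_real(rng)) for _ in range(n)) / n
    fake_term = sum(
        f_star(discriminator(generator(rng.gauss(0.0, 1.0)))) for _ in range(n)
    ) / n
    return real_term - fake_term

rng = random.Random(0)
generator = lambda z: 1.0 * z + 6.0        # hypothetical G_theta (affine map)
discriminator = lambda x: 0.1 * (6.0 - x)  # hypothetical D_omega (affine map)

value = estimate_objective(discriminator, generator, f_star_kl, rng)
print(value)

# A discriminator that always outputs 0 gives F = -f*(0) = -exp(-1).
baseline = estimate_objective(lambda x: 0.0, generator, f_star_kl, rng)
print(baseline)
```

During training, the discriminator's parameters would be updated to increase this estimate and the generator's parameters to decrease it.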

Below, we show a couple of f-divergences along with their generator functions and optimal discriminators.


<table style="width: 85%; margin-left:auto; margin-right:auto; text-align: left;">
    <col width="15%">
    <col width="45%">
    <col width="25%">
    <col width="15%">
    <tr>
        <th>Name</th>
        <th>$D_f(P || Q)$</th>
        <th>Generator $f(u)$</th>
        <th>$D_{opt}(x)$</th>
    </tr>
    <tr>
        <td>Kullback-Leibler</td>
        <td>$\int p(x) \log \frac{p(x)}{q(x)} dx$</td>
        <td>$u \log u$</td>
        <td>$1 + \log \frac{p(x)}{q(x)}$</td>
    </tr>
    <tr>
        <td>Reverse KL</td>
        <td>$\int q(x) \log \frac{q(x)}{p(x)} dx$</td>
        <td align="left">$- \log u$</td>
        <td>$- \frac{q(x)}{p(x)}$</td>
    </tr>
    <tr>
        <td>Jensen-Shannon</td>
        <td>$\frac{1}{2} \int p(x) \log \frac{2p(x)}{p(x) + q(x)} + q(x) \log \frac{2q(x)}{p(x) + q(x)} dx$</td>
        <td>$u \log u - (u + 1) \log \frac{u + 1}{2}$</td>
        <td>$\log \frac{2p(x)}{p(x) + q(x)}$</td>
    </tr>
    <tr>
        <td>GAN</td>
        <td>$\int p(x) \log \frac{2p(x)}{p(x) + q(x)} + q(x) \log \frac{2q(x)}{p(x) + q(x)} dx - \log (4)$</td>
        <td>$u \log u - (u + 1) \log (u + 1) $</td>
        <td>$\log \frac{p(x)}{p(x) + q(x)}$</td>
    </tr>
</table>
<p style="text-align: center"><b>Table 1</b> Overview of different f-divergences</p>
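The rows of Table 1 can be written down as code. Table 1 does not list the convex conjugates, so the closed forms for $f^\ast$ below are taken from the corresponding table in [[1]](#References) and should be read as an assumption of this sketch. They are checked numerically through the identity $f^\ast(f'(u)) = u f'(u) - f(u)$, which holds whenever $f^\ast$ is the conjugate of a differentiable convex $f$. The dictionary layout and the helper `fenchel_gap` are illustrative choices.

```python
import math

# name -> (generator f, derivative f', convex conjugate f*)
DIVERGENCES = {
    "Kullback-Leibler": (
        lambda u: u * math.log(u),
        lambda u: 1.0 + math.log(u),
        lambda t: math.exp(t - 1.0),
    ),
    "Reverse KL": (
        lambda u: -math.log(u),
        lambda u: -1.0 / u,
        lambda t: -1.0 - math.log(-t),
    ),
    "Jensen-Shannon": (
        lambda u: u * math.log(u) - (u + 1.0) * math.log((u + 1.0) / 2.0),
        lambda u: math.log(2.0 * u / (u + 1.0)),
        lambda t: -math.log(2.0 - math.exp(t)),
    ),
    "GAN": (
        lambda u: u * math.log(u) - (u + 1.0) * math.log(u + 1.0),
        lambda u: math.log(u / (u + 1.0)),
        lambda t: -math.log(1.0 - math.exp(t)),
    ),
}

def fenchel_gap(f, f_prime, f_star, u):
    """|f*(f'(u)) - (u f'(u) - f(u))|, which is zero if f* is the conjugate of f."""
    t = f_prime(u)
    return abs(f_star(t) - (u * t - f(u)))

gaps = {
    name: max(fenchel_gap(f, f_prime, f_star, u) for u in (0.5, 1.0, 3.0))
    for name, (f, f_prime, f_star) in DIVERGENCES.items()
}
for name, gap in gaps.items():
    print(name, gap)
```

Evaluating `f_prime` at the density ratio $p(x) / q(x)$ gives the $D_{opt}(x)$ column of Table 1.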

# Code

## Learning a multimodal distribution from an underspecified GAN

In this section we use a GAN to translate $N(0, 1)$ into a bimodal distribution composed of two normal distributions, $N(4, 1)$ and $N(8,1)$. The network that we use is identical to the one we used last week: <br><br>


<div style="text-align: center">
<img src="figs/week1_architecture.png" style="width: 80%;">
<b>Figure 2: </b>A schematic of the simple model for this example.  
</div>

The following animation shows how, during training, the generated data from the fake data distribution $p_g$ approaches the real data distribution $p_d$. However, since the generator is underspecified, in the sense that it cannot model a bimodal distribution with only the two parameters $w$ and $b$, it cannot learn the real distribution. Notice also how the decision boundary of the discriminator ends up spiking right between the two modes, as there is more fake data than real data in that region. A final thing to notice in the animation is how much the generator and the discriminator depend on each other: when the generator moves its data into a new region, the discriminator changes to reflect that, which makes the generator move again, and so on. This is exactly what we would expect from the formulation of the game that they are playing.

<br/>
<br/>

<div style="text-align: center">
<img src="gifs/dual-mode/gif.gif">
<b>Figure 3: </b>Animation of how $p_g$ converges towards $p_d$.
</div>


<b>Abstract Idea of Code: </b><br>
The code is implemented using Keras, which we use to define the generator and the discriminator networks. The discriminator is compiled with its own optimizer and trained on real and fake data. The generator needs to use the discriminator as a cost function, so we stack the two networks into a combined model and freeze the discriminator's weights (`discriminator.trainable = False`) before compiling it. This way, the optimizer of the combined model only updates the generator.

All plots are made using matplotlib. 

```python
%matplotlib notebook
import os

import numpy as np
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

n          = 1000  # number of samples per batch
iterations = 200   # number of training iterations
repeat     = 10    # updates of each network per iteration

os.makedirs("gifs/dual-mode", exist_ok=True)  # plot() saves one frame per iteration here
fig, ax = plt.subplots(figsize=(5, 3))

def plot(i, noise, real_data):
    """Plot histograms of fake and real data together with the discriminator output."""
    ax.cla()
    ax.hist(generator.predict_on_batch(noise), alpha=0.3, label="Fake p_g")
    ax.hist(real_data, alpha=0.3, label="Real p_d")
    xs = np.arange(-8, 16, 0.1).reshape(-1, 1)  # the discriminator expects shape (batch, 1)
    pred = discriminator.predict_on_batch(xs)
    ax.plot(xs, pred * 250, label="D(x)")  # scaled by 250 to be visible next to the histograms
    ax.legend(loc=3)
    ax.set_ylim([0, 280])
    ax.set_xlim([-9, 13])

    fig.suptitle("Iteration: [%i / %i]" % (i, iterations))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.canvas.draw()
    plt.savefig("gifs/dual-mode/%i.png" % i)
    plt.pause(.01)

# define generator: a single linear neuron (weight w and bias b) without an activation
generator = Sequential()
generator.add(Dense(1, input_dim=1))

# define discriminator
discriminator = Sequential()
discriminator.add(Dense(10, input_dim=1, activation="relu"))  # the non-linearity is useful here
discriminator.add(Dense(1, activation="sigmoid"))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.005))

# Combine models
gan = Sequential()
gan.add(generator)
discriminator.trainable = False  # inside gan the discriminator is frozen and only acts as a loss function
gan.add(discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.005))

# Alternate discriminator and generator updates, plotting fake / real data in every iteration.
for i in range(iterations):
    noise     = np.random.normal(0, 1, size=(n, 1))
    fake_data = generator.predict_on_batch(noise)

    x1        = np.random.normal(4, 1, size=(n//2, 1))
    x2        = np.random.normal(8, 1, size=(n//2, 1))
    real_data = np.concatenate((x1, x2), axis=0)

    disc_X = np.concatenate((real_data, fake_data), axis=0)
    # labels: real = 0, fake = 1 (flipped since we min instead of max)
    disc_y = np.concatenate((np.zeros(n), np.ones(n)), axis=0)

    plot(i, noise, real_data)
    for j in range(repeat): discriminator.train_on_batch(x=disc_X, y=disc_y)
    # the generator is trained to make the discriminator label its samples as real (0)
    for j in range(repeat): gan.train_on_batch(x=noise, y=np.zeros(n))
```


# References
[[1]](http://papers.nips.cc/paper/6066-f-gan-training-generative-neural-samplers-using-variational-divergence-minimization.pdf) Nowozin, S., Cseke, B. and Tomioka, R., 2016. f-gan: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems (pp. 271-279).

[[2]](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=5605355) Nguyen, X., Wainwright, M.J. and Jordan, M.I., 2010. Estimating divergence functionals and the likelihood ratio by convex risk minimization. IEEE Transactions on Information Theory, 56(11), pp.5847-5861.

[[3]](https://papers.nips.cc/paper/5423-generative-adversarial-nets) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. and Bengio, Y., 2014. Generative adversarial nets. In Advances in neural information processing systems (pp. 2672-2680).