In [None]:
%%bash
# Skip the install unless running in Google Colab.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
pip install git+https://github.com/davidbau/baulab

# Non-Saturating GAN demo

Goodfellow also notes that, in practice, simply minimizing $V$ (making the discriminator's job harder as measured by $V$) does not train the generator well. Early in training, when $D$ can easily reject generated samples, the gradient of $V$ with respect to $G$ is very small. So in the [original GAN paper](https://arxiv.org/pdf/1406.2661.pdf), he also suggests using a <em>non-saturating</em> loss. This change of loss was the first in a series of clever alternative loss functions that improved GAN training.
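For reference, the minimax game in the original paper uses the value function below, written here in this notebook's notation with $E_T$ an expectation over the target distribution and $E_G$ an expectation over the generator distribution (the paper writes the second term over noise $z$ as $\log(1 - D(G(z)))$). The generator can either minimize the second term, which saturates, or instead maximize $E_G[\log D(x)]$:

$$
V(D, G) = E_T\big[\log D(x)\big] + E_G\big[\log\big(1 - D(x)\big)\big]
$$

$$
\text{saturating: } \min_G \; E_G\big[\log\big(1 - D(x)\big)\big]
\qquad
\text{non-saturating: } \max_G \; E_G\big[\log D(x)\big]
$$

According to the paper, this swap keeps the same fixed point of the $G$ and $D$ dynamics but gives much stronger gradients early in learning.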

Here you can experiment with the non-saturating loss.
  

In [None]:
from baulab import PlotWidget, Range, Numberbox, show
import torch, math

def gaussian_model(x, mean, variance):
    # Normal density N(mean, variance): exp(-(x - mean)^2 / (2 variance)) / sqrt(2 pi variance)
    return torch.exp(-(x - mean) ** 2 / (2 * variance)) / torch.sqrt(2 * torch.pi * variance)

def make_gaussian_model(mean, variance):
    return lambda x: gaussian_model(x, torch.tensor(mean), torch.tensor(variance))

def discriminator(x, p_true, p_gen):
    return p_true(x) / (p_true(x) + p_gen(x))

def value_fn(x, p_true, p_gen):
    d = discriminator(x, p_true, p_gen)
    return p_gen(x) * torch.log(d)

def nonsaturating_redraw_rule(fig, gen_mean=0.0, gen_variance=1.0, true_mean=3.0, true_variance=0.5):
    # On the first draw, add a secondary y-axis for the value curve.
    if len(fig.axes) == 1:
        fig.axes[0].twinx()
    [ax, ax2] = fig.axes
    x_range = torch.arange(-3, 6, 0.1)
    true_model = make_gaussian_model(true_mean, true_variance)
    gen_model = make_gaussian_model(gen_mean, gen_variance)
    ax2.clear()
    ax2.set_ylim(-8.2, 0.2)
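    # Raw strings keep the LaTeX backslashes from being read as escape sequences.
    ax2.set_ylabel(r'non-saturating $V_{\mathrm{nonsaturating}}$')
    # The shaded area under P_G(x) log D(x) equals E_G[log D].
    ax2.fill_between(x_range, value_fn(x_range, true_model, gen_model),
                     alpha=0.5, color='red', lw=0, label=r'$V_{\mathrm{nonsaturating}} = E_G[\log(D)]$')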
    ax.clear()
    ax.set_title('Non-Saturating GAN game: push the red curve up as much as possible')
    ax.plot(x_range, gen_model(x_range), linewidth=3, label='$P_G$ generator distribution')
    ax.plot(x_range, true_model(x_range), label='$P_T$ target distribution')
    ax.plot(x_range, discriminator(x_range, true_model, gen_model), label='$D$ discriminator classifies T')
    ax.set_ylabel('probability density $P_G$, $P_T$, probability $D$')
    ax.set_ylim(-0.05, 1.05)
    ax.legend(*(a + b for a, b in zip(*[axN.get_legend_handles_labels() for axN in [ax, ax2]])),
               loc='upper left')
    
plot = PlotWidget(nonsaturating_redraw_rule, figsize=(8, 5), dpi=100)
mean_slider = Range(value=plot.prop('gen_mean'), min=-5, max=5, step=0.01)
mean_input = Numberbox(value=plot.prop('gen_mean'))
var_slider = Range(value=plot.prop('gen_variance'), min=0.1, max=1.9, step=0.01)
var_input = Numberbox(value=plot.prop('gen_variance'))
show([[plot],
      [show.style(textAlign='right'), 'Generator Mean', show.style(flex=2), mean_slider, mean_input],
      [show.style(textAlign='right'), 'Generator Variance', show.style(flex=2), var_slider, var_input]])

## What is the non-saturating loss?

By choosing $G$ to maximize $V_{\text{nonsaturating}} = E_G[\log D(x)]$, we are saying "maximize certainty that a generated $x$ is real."

This differs from the original goal of minimizing $V = k + E_G[\log(1 - D(x))]$, where $k = E_T[\log D(x)]$ is the term on real data, which does not involve generated samples. Minimizing $V$ says "minimize certainty that a generated $x$ is fake."

The non-saturating loss is effectively teaching $G$ to actively seek out examples that directly fool $D$, rather than just looking for examples that make $D$ unsure.
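The difference shows up in the gradients. As a minimal sketch, assume that $D$ is computed as $D = \sigma(a)$ from a discriminator logit $a$. This is a common parameterization, but it is an assumption here: the demo above uses the closed-form $D = P_T/(P_T + P_G)$ instead. Then $\partial_a \log(1 - \sigma(a)) = -\sigma(a)$, which vanishes when $D$ confidently labels a sample fake ($a \to -\infty$), while $\partial_a \log \sigma(a) = 1 - \sigma(a)$, which approaches 1 there. The helper names below are illustrative, not from any library.

```python
import math

def sigmoid(a: float) -> float:
    """Numerically stable logistic function mapping a logit to D = P(real)."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)

def saturating_grad(a: float) -> float:
    """d/da of log(1 - sigmoid(a)), the term G minimizes in the original loss.
    Equals -sigmoid(a), so it shrinks toward 0 as D -> 0."""
    return -sigmoid(a)

def nonsaturating_grad(a: float) -> float:
    """d/da of log(sigmoid(a)), the term G maximizes in the non-saturating loss.
    Equals 1 - sigmoid(a), so it approaches 1 as D -> 0."""
    return 1.0 - sigmoid(a)

# Very negative logits mean D confidently labels the generated sample fake,
# which is the typical situation early in training.
for a in [-10.0, -5.0, 0.0, 5.0]:
    print(f"logit={a:6.1f}  D={sigmoid(a):.5f}  "
          f"saturating grad={saturating_grad(a):+.5f}  "
          f"non-saturating grad={nonsaturating_grad(a):+.5f}")
```

At a logit of -10, the saturating gradient is only about $-4.5 \times 10^{-5}$, while the non-saturating one is close to 1. This is the sense in which the non-saturating loss keeps pushing $G$ toward samples that fool $D$.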

## The performance difference

Notice that in the original GAN formulation, $V$ flattens out to a nearly constant value close to zero when $G$ is very far from $T$. Its gradients are then very close to zero, so it is very hard to learn $G$ at the beginning of training.

But as the plot shows, instead of saturating at zero when $G$ is far from $T$, the non-saturating objective becomes strongly negative and keeps changing as $G$ moves (at least in this Gaussian case). That means it should be possible to train $G$ to move closer to $T$ even when it starts far away.
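To check this numerically without the widget, here is a minimal stdlib-only sketch that integrates both generator objectives on a grid. It assumes the same closed-form discriminator $D = P_T/(P_T + P_G)$ and the demo's default parameters (target mean 3, target variance 0.5, generator variance 1). The function names, the integration range, and the step count are illustrative choices, not part of the demo. Everything is computed in log space because $D$ underflows to zero far from $T$.

```python
import math

def log_gaussian(x, mean, variance):
    """Log density of the normal distribution N(mean, variance) at x."""
    return -(x - mean) ** 2 / (2 * variance) - 0.5 * math.log(2 * math.pi * variance)

def logaddexp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def generator_objectives(gen_mean, gen_variance=1.0, true_mean=3.0, true_variance=0.5,
                         lo=-15.0, hi=15.0, steps=6000):
    """Midpoint-rule estimates of E_G[log(1 - D)] and E_G[log D]
    with D = p_T / (p_T + p_G)."""
    dx = (hi - lo) / steps
    v_sat = 0.0
    v_nonsat = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * dx
        log_pt = log_gaussian(x, true_mean, true_variance)
        log_pg = log_gaussian(x, gen_mean, gen_variance)
        log_denom = logaddexp(log_pt, log_pg)
        log_d = log_pt - log_denom      # log D(x)
        log_1md = log_pg - log_denom    # log(1 - D(x))
        p_g = math.exp(log_pg)          # generator density weights both integrals
        v_sat += p_g * log_1md * dx
        v_nonsat += p_g * log_d * dx
    return v_sat, v_nonsat

# Move the generator mean away from the target mean of 3.
for gen_mean in [3.0, 1.0, -1.0, -3.0, -5.0]:
    v_sat, v_ns = generator_objectives(gen_mean)
    print(f"gen_mean={gen_mean:+.1f}  E_G[log(1-D)]={v_sat:.6f}  E_G[log D]={v_ns:.3f}")
```

As the generator mean moves away from 3, the saturating term shrinks toward zero while the non-saturating term keeps growing more negative. This mirrors the flat-versus-varying behavior described above.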