In [None]:
%%bash
# Install baulab only when running on Google Colab. The `stat` succeeds only if the
# google/colab package exists under dist-packages; the leading `!` negates that, so
# on any other machine the cell exits before reaching `pip install`.
# NOTE(review): `!(...)` relies on bash extglob being off; `! (...)` with a space
# would be robust either way.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
pip install git+https://github.com/davidbau/baulab

# Wasserstein GAN demo

While the original GAN paper chose to optimize $G$ by minimizing the JS divergence $\text{JS}(G, T)$, the Wasserstein GAN paper by Arjovsky et al. (2017) (https://arxiv.org/pdf/1701.07875.pdf) proposes minimizing a different distance metric: the Wasserstein distance $W(G, T)$ between the two distributions.

This notebook illustrates the Wasserstein metric and how it behaves differently from the JS divergence.

First, we just set up the code.

In [None]:
from baulab import PlotWidget, Range, Numberbox, show
import torch, math

def gaussian_model(x, mean, variance):
    """Normal density N(mean, variance) evaluated elementwise at `x`.

    `variance` must be a tensor because it is passed to torch.sqrt; `x` and `mean`
    may be tensors or numbers that broadcast against it.
    """
    exponent = -(x - mean) ** 2 / variance / 2
    normalizer = torch.sqrt(2 * torch.pi * variance)
    return torch.exp(exponent) / normalizer

def make_gaussian_model(mean, variance):
    """Return a callable density x -> N(mean, variance)(x).

    The parameters are converted to tensors once and exposed on the returned
    callable as `.mean` and `.variance`.
    """
    mean_t = torch.tensor(mean)
    variance_t = torch.tensor(variance)

    def density(x):
        return gaussian_model(x, mean_t, variance_t)

    density.mean = mean_t
    density.variance = variance_t
    return density

def make_error_fn(mean, variance):
    """Return a callable x -> erf((x - mean) / sqrt(variance)).

    This is NOT the normalized Gaussian CDF, which is
    0.5 * (1 + erf((x - mean) / sqrt(2 * variance))). Because erf is strictly
    increasing, the difference of two such curves has the same sign (and the same
    zero crossings) as the difference of the corresponding CDFs, which is all the
    Wasserstein plot needs. The tensor-converted parameters are exposed on the
    returned callable as `.mean` and `.variance`.
    """
    mean = torch.tensor(mean)
    variance = torch.tensor(variance)

    def scaled_erf(x):
        return torch.erf((x - mean) / torch.sqrt(variance))

    scaled_erf.mean, scaled_erf.variance = mean, variance
    return scaled_erf

def center_search(length, start):
    """Yield grid indices alternating outward from `start`.

    The order is start-1, start, start-2, start+1, start-3, start+2, ...
    Lower candidates below 0 and upper candidates at or beyond `length` are
    skipped. The lower side is only checked against 0, so a `start` larger than
    `length` could yield out-of-range indices; grid_root always passes a start
    inside the grid.
    """
    for step in range(length):
        left, right = start - 1 - step, start + step
        if left >= 0:
            yield left
        if right < length:
            yield right

def grid_root(x, vals):
    """Find a visually pleasing sign change of `vals` on the grid `x`.

    The search starts halfway between the indices of the max and min of `vals`
    and moves outward (via center_search). A "sign change" is any pair of
    neighbours whose torch.sign differs, so 0 <-> nonzero transitions count too.
    `x` and `vals` are assumed to be 1-D tensors of equal length.

    Returns:
        (root, direction) where `root` is the zero of the straight line through
        the two bracketing grid points, and `direction` is
        sign(sign(vals[i-1]) - sign(vals[i])): +1 when vals falls across the root.
        If no sign change exists, returns (None, sign of vals at the middle index).
    """
    signs = vals.sign()
    hi_idx = int(vals.argmax(0))
    lo_idx = int(vals.argmin(0))
    midpoint = (hi_idx + lo_idx) // 2
    for i in center_search(len(signs), midpoint):
        if i <= 0 or signs[i - 1] == signs[i]:
            continue
        v_prev, v_cur = vals[i - 1], vals[i]
        # Linear interpolation between (x[i-1], v_prev) and (x[i], v_cur).
        root = (x[i] * v_prev - x[i - 1] * v_cur) / (v_prev - v_cur)
        direction = (signs[i - 1] - signs[i]).sign()
        return root, direction
    return None, signs[len(signs) // 2]

# Solve for a near-optimal Lipschitz D for two Gaussians
def make_lipshitz_fn(x, errs, dens):
    """Build a near-optimal 1-Lipschitz critic ("price" function) D on the grid `x`.

    In one dimension the optimal critic has slope -sign(F_T - F_G), so it is
    piecewise linear with slope +-1 and at most one corner, where the CDFs cross.

    Args:
        x: 1-D grid of sample points (tensor).
        errs: values on `x` whose sign matches F_T - F_G (the caller passes a
            difference of scaled erf curves).
        dens: P_T - P_G evaluated on `x`.

    Returns:
        A callable t -> D(t): a line when no CDF crossing is found on the grid,
        otherwise a V shape (up to sign) with its corner at the crossing. D is only
        defined up to an additive constant, which cancels in E_T[D] - E_G[D]; the
        constant is chosen so that D vanishes at the density crossing found by
        grid_root (no shift if the densities do not cross on the grid).
    """
    corner, sign = grid_root(x, errs)
    zero, _ = grid_root(x, dens)
    if corner is None:
        # Anchor D(zero) = 0 for either slope direction; `-sign * x - zero` only
        # did so when sign == -1.
        offset = 0.0 if zero is None else zero
        return lambda t: -sign * (t - offset)
    # Without a density crossing there is nothing to anchor to; avoid None - tensor.
    offset = 0.0 if zero is None else (zero - corner).abs()
    return lambda t: sign * ((t - corner).abs() - offset)

def discriminator(x, p_true, p_gen):
    """Bayes-optimal GAN discriminator D(x) = p_T(x) / (p_T(x) + p_G(x)).

    `p_true` and `p_gen` are density callables. Where both densities are exactly
    zero (e.g. after floating-point underflow far in the tails) the ratio is 0/0;
    D = 0.5 ("cannot tell") is returned there instead of NaN.
    """
    t = torch.as_tensor(p_true(x))
    g = torch.as_tensor(p_gen(x))
    total = t + g
    return torch.where(total > 0, t / total, torch.full_like(total, 0.5))

loghalf = math.log(0.5)

def value_fn(x, p_true, p_gen, rule=None):
    """Pointwise integrand of the GAN value V, or of its JS-divergence form.

    Returns p_T(x) * (log D(x) - c) + p_G(x) * (log(1 - D(x)) - c), where D is
    `discriminator` and c = log(0.5) when rule == 'js' (so the integral becomes
    V + log 4, i.e. twice the JS divergence), else c = 0.

    Both log arguments are floored at 1e-50 so that a zero density times log(0)
    contributes 0 rather than NaN. NOTE: that floor underflows to 0 in float32, so
    it is only effective for float64 inputs (as used by this notebook).
    """
    d = discriminator(x, p_true, p_gen)
    constant = loghalf if rule == 'js' else 0
    true_term = p_true(x) * (torch.log(d.clamp(1e-50)) - constant)
    gen_term = p_gen(x) * (torch.log((1 - d).clamp(1e-50)) - constant)
    return true_term + gen_term

def js_redraw_rule(fig, gen_mean=0.0, gen_variance=1.0, true_mean=3.0, true_variance=0.5, rule='v'):
    """Redraw `fig` to show the classic GAN game between two 1-D Gaussians.

    The primary axis shows the generator density P_G, the target density P_T and
    the optimal discriminator D = P_T / (P_T + P_G). A twin axis shows, as a red
    filled area, the pointwise integrand of the value V (rule='v'), or with
    rule='js' the integrand with log(0.5) subtracted inside each term (the JS form).

    Args:
        fig: figure holding one axis (a twin axis is added on the first call) or
            the two axes left by a previous call; presumably supplied by PlotWidget.
        gen_mean, gen_variance: parameters of the generator Gaussian.
        true_mean, true_variance: parameters of the target Gaussian.
        rule: 'js' for the JS-divergence view, otherwise the plain V view.
    """
    # First draw: add a secondary y-axis sharing x; later redraws reuse it.
    if len(fig.axes) == 1:
        fig.axes[0].twinx()
    [ax, ax2] = fig.axes
    x_range = torch.arange(-3, 6, 0.1, dtype=torch.double)
    true_model = make_gaussian_model(true_mean, true_variance)
    gen_model = make_gaussian_model(gen_mean, gen_variance)
    ax2.clear()
    ax2.set_ylim(-3.05, 3.05)
    ax2.set_ylabel('$JS(G, T)$ divergence' if rule == 'js' else 'cross-entropy $V$')
    # Raw strings: '\l' is not a valid Python escape (SyntaxWarning on 3.12+).
    # The resulting text is unchanged and still reaches mathtext as '\log'.
    constant = r'-\log{0.5}' if rule == 'js' else ''
    ax2.fill_between(x_range, value_fn(x_range, true_model, gen_model, rule=rule),
                     alpha=0.5, color='red', lw=0,
                     label=rf'${rule.upper()} = E_T[\log(D){constant}] + E_G[\log(1-D){constant}]$')
    ax.clear()
    ax.set_title(('JS divergence: minimize JSD' if rule == 'js' else 'The GAN game: minimize $V$') + ' by fooling $D$')
    ax.plot(x_range, gen_model(x_range), linewidth=3, label='$P_G$ generator distribution')
    ax.plot(x_range, true_model(x_range), label='$P_T$ target distribution')
    ax.plot(x_range, discriminator(x_range, true_model, gen_model), label='$D$ discriminator classifies T')
    ax.set_ylabel('probability density $P_G$, $P_T$, probability $D$')
    ax.set_ylim(-0.05, 1.05)
    # One combined legend for the artists on both axes.
    ax.legend(*(a + b for a, b in zip(*[axN.get_legend_handles_labels() for axN in [ax, ax2]])),
               loc='upper left')
    
def wasserstein_redraw_rule(fig, gen_mean=0.0, gen_variance=1.0, true_mean=3.0, true_variance=0.5):
    """Redraw `fig` to show the Wasserstein (earthmover) game between two 1-D Gaussians.

    Primary axis: the generator density P_G and the target density P_T.
    Twin axis: the near-optimal Lipschitz critic D from make_lipshitz_fn (green
    line) and the filled integrand D * (P_T - P_G), whose integral is
    V = E_T[D] - E_G[D].

    `fig` is expected to hold one axis (a twin axis is created on the first call)
    or the two axes left by a previous call.
    """
    if len(fig.axes) == 1:
        fig.axes[0].twinx()
    ax, ax2 = fig.axes

    grid = torch.arange(-3, 6, 0.1, dtype=torch.double)
    true_model = make_gaussian_model(true_mean, true_variance)
    gen_model = make_gaussian_model(gen_mean, gen_variance)
    true_erf = make_error_fn(true_mean, true_variance)
    gen_erf = make_error_fn(gen_mean, gen_variance)

    # Evaluate the densities and the critic on the grid once; reuse below.
    p_true = true_model(grid)
    p_gen = gen_model(grid)
    density_gap = p_true - p_gen
    critic = make_lipshitz_fn(grid, true_erf(grid) - gen_erf(grid), density_gap)
    price = critic(grid)

    # Twin axis: the critic "price" and the transport-cost integrand.
    ax2.clear()
    ax2.set_ylim(-3.05, 3.05)
    ax2.plot(grid, price, color='green', label='$D$ Lipshitz discriminator "price"')
    ax2.fill_between(grid, price * density_gap,
                     alpha=0.5, color='purple', lw=0, label='$V = E_T[D] - E_G[D]$ earthmover dist "cost"')
    ax2.set_ylabel('Lipschitz score $D$, earthmover distance $V$')

    # Primary axis: the two densities.
    ax.clear()
    ax.set_title('Wasserstein GAN game: minimize the V Earthmover distance')
    ax.plot(grid, p_gen, linewidth=3, label='$P_G$ generator distribution')
    ax.plot(grid, p_true, label='$P_T$ target distribution')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('probability density $P_G$, $P_T$')

    # A single legend listing the artists of both axes (primary first).
    main_handles, main_labels = ax.get_legend_handles_labels()
    twin_handles, twin_labels = ax2.get_legend_handles_labels()
    ax.legend(main_handles + twin_handles, main_labels + twin_labels, loc='upper left')

# Interactive Wasserstein demo: a PlotWidget driven by wasserstein_redraw_rule, plus
# a slider and a number box for the generator's mean and for its variance.
# NOTE(review): plot.prop('gen_mean') presumably yields a property bound to the
# redraw rule's `gen_mean` argument, so slider, number box and plot stay in sync;
# confirm against baulab's PlotWidget documentation.
plot = PlotWidget(wasserstein_redraw_rule, figsize=(8, 5), dpi=100)
mean_slider = Range(value=plot.prop('gen_mean'), min=-6, max=6, step=0.01)
mean_input = Numberbox(value=plot.prop('gen_mean'))
# The slider keeps the variance strictly positive; gaussian_model divides by it.
var_slider = Range(value=plot.prop('gen_variance'), min=0.1, max=1.9, step=0.01)
var_input = Numberbox(value=plot.prop('gen_variance'))
# Presumably each inner list is one layout row: the plot, then one labelled control
# row per parameter, with show.style(...) styling the items after it (TODO confirm).
show([[plot],
      [show.style(textAlign='right'), 'Generator Mean', show.style(flex=2), mean_slider, mean_input],
      [show.style(textAlign='right'), 'Generator Variance', show.style(flex=2), var_slider, var_input]])


## Earthmover Distance

The Wasserstein distance between distributions is also called the "Earthmover" distance. If you think of the distributions $G$ and $T$ as big piles of dirt, then $W(G, T)$ is the minimum average distance that a piece of dirt needs to move to transform one pile into the other.  Deciding exactly which piece of dirt to move from one spot to another is called a transport plan, and $W(G, T)$ is the minimum transport cost over all possible transport plans.

So you can see how $W(G, T)$ defines a distance metric from $G$ to $T$.  Arjovsky argues that, for learning $G$, Wasserstein is a better distance metric than JS divergence, and the demo below gives you some intuition for why. JS divergence saturates when the distributions do not overlap well with each other (they become a fixed "100% different"). Wasserstein, in contrast, continues to measure differences smoothly even when the distributions do not overlap (more distant earth piles become "more and more costly" to move).

## The Transport Price Function D

One problem is that $W$ seems totally uncomputable!  How do we come up with the optimal transport plan?  It turns out that there are some tricks with optimal transport.  An amazing theorem by Kantorovich and Rubinstein (see this [proof summary by John Thickstun](https://courses.cs.washington.edu/courses/cse599i/20au/resources/L12_duality.pdf)) says that the optimal transport can be characterized by assigning a number $D: X \rightarrow \mathbb{R}$ to every data point, subject to $D(x) - D(y) \leq |x - y|$ for all $x, y$. The optimal transport corresponds to the $D$ that maximizes $V = E_T[D(x)] - E_G[D(x)]$, and for that $D$, $V=W(G,T)$ is the cost, i.e., the Wasserstein distance.


You can think of $D$ as a pricing plan:
- the highest $D(x)$ are assigned to the most valuable pieces of earth, where the target must be piled high;
- the lowest $D(x)$ are assigned to the least valuable pieces of earth, where the source must be moved away;
- $D(x) - D(y)$ is the amount we are willing to pay to move earth from $y$ to $x$.

Some pricing schemes would not offer enough money to get every piece of dirt moved: you cannot move a piece of earth from $y$ to $x$ unless the prices make it possible to collect $|x - y|$ to move it.  For example, if you set all prices the same, you cannot move any earth.

Amazingly, the Kantorovich-Rubinstein theorem says that there is always some pricing plan that will pay for moving every piece of dirt while still never overpaying for any individual piece, i.e., while satisfying $D(x) - D(y) \leq |x - y|$.  Under any pricing plan, the total amount paid will be $V = E_T[D(x)] - E_G[D(x)]$. It turns out that the way to move every piece of earth is to find the pricing plan that pays the most while never overpaying.

## The Lipschitz constraint on D

The constraint $D(x) - D(y) \leq |x - y|$ is the <em>Lipschitz</em> condition (with constant 1). In other words, $D$ is flat enough that the magnitude of its derivative never exceeds 1.

In the optimal case, where $D$ is chosen to maximize the transport capacity, its derivative will be maximized and go to $\pm 1$, so in the one-dimensional case, it looks like a zig-zag curve like an absolute-value function.  The slope heads upward in the direction that earth needs to move towards the target.

In the interactive charts, we show the optimal $D$ as a green line.

When training a neural network to approximate $D$, it is possible but costly to enforce the Lipschitz constraint strictly (a crude bound can be imposed by clamping the model weights).  In practice, we use a soft constraint instead: $D$ is trained to stay flat by adding a regularizer that penalizes large derivatives.


## Comparing Wasserstein to JS divergence

See the demo below to compare Wasserstein to JS divergence; you can see some similarities and differences.

  * When $G$ and $T$ are close to each other, both metrics behave similarly.
  * But when $G$ and $T$ are far apart, Wasserstein distance continues to grow while JS divergence saturates.
  
Why does this mean that Wasserstein should behave better than the original JS divergence loss when training a GAN?

In [None]:
# Side-by-side comparison: the Wasserstein plot next to the JS-divergence plot
# (js_redraw_rule with rule='js').
# NOTE(review): passing plot.prop(...) as plot_js's gen_mean/gen_variance presumably
# binds both plots to the same generator parameters, so one set of controls drives
# both; confirm against baulab's PlotWidget documentation.
plot = PlotWidget(wasserstein_redraw_rule, figsize=(6, 5))
plot_js = PlotWidget(js_redraw_rule, rule='js', figsize=(6, 5),
                     gen_mean=plot.prop('gen_mean'), gen_variance=plot.prop('gen_variance'))
mean_slider = Range(value=plot.prop('gen_mean'), min=-5, max=5, step=0.01)
mean_input = Numberbox(value=plot.prop('gen_mean'))
# The slider keeps the variance strictly positive; gaussian_model divides by it.
var_slider = Range(value=plot.prop('gen_variance'), min=0.1, max=1.9, step=0.01)
var_input = Numberbox(value=plot.prop('gen_variance'))
# First row: both plots; then one labelled control row per generator parameter.
show([[plot, show.style(margin=0), plot_js],
      [show.style(textAlign='right'), 'Generator Mean', show.style(flex=2), mean_slider, mean_input],
      [show.style(textAlign='right'), 'Generator Variance', show.style(flex=2), var_slider, var_input]])