# Wasserstein GAN (WGAN)

Problem: __instability of GAN training__.

## GAN problems

GAN training is __unstable__; two failure modes are common.

### Vanishing gradient

When the discriminator converges,  
__gradients propagated to the generator become flat__.

Main cause:  
  _the GAN loss is based on the __KL-divergence__ and __JS-divergence__, which saturate (zero gradient) when the real and generated distributions barely overlap._

### Mode collapse

The generator learns to generate only one mode of the dataset (say, the digit '1').

Main cause:  
  the loss function does not enforce full pdf coverage.

## WGAN-GP improvements

Many improvements to GANs have been proposed.

### The critic

One major cause of __vanishing gradients__ is the __sigmoid function__.

The discriminator becomes a critic: the __critic does not output probabilities__ (in $[0,1]$), only an unbounded score.

This is intended to __replace the binary cross-entropy__ in the loss.
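
To make the sigmoid argument concrete, the sketch below compares the derivative of a sigmoid output with that of a linear (critic-style) output as the pre-activation grows. The function names are illustrative only, and NumPy is assumed to be available.

```python
import numpy as np

def sigmoid(a):
    """Logistic function, the usual last layer of a GAN discriminator."""
    return 1.0 / (1.0 + np.exp(-a))

def sigmoid_grad(a):
    """Derivative of the sigmoid with respect to its input."""
    s = sigmoid(a)
    return s * (1.0 - s)

def linear_grad(a):
    """Derivative of an identity (critic-style) output: constant 1."""
    return np.ones_like(a)

# The more confident the discriminator, the larger |a| and the flatter the sigmoid.
pre_activations = np.array([0.0, 2.0, 5.0, 10.0])
for a, gs, gl in zip(pre_activations,
                     sigmoid_grad(pre_activations),
                     linear_grad(pre_activations)):
    print(f"a={a:5.1f}  sigmoid'={gs:.2e}  linear'={gl:.1f}")
```

The sigmoid derivative peaks at 0.25 and decays exponentially, whereas the linear output keeps passing gradient regardless of how confident the critic is.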

### The Lipschitz condition

__Uncontrolled gradients prevent convergence__.

- __JSD__: Jensen-Shannon divergence.  
  __Symmetrized KL__-divergence.
- __EMD__: Earth-Mover distance.  
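  Intuitively, the minimum cost (__probability mass moved__ times distance) of transforming one distribution into another.
  
  $$\Large W(p_r, p_g) = \inf_{\gamma\in\Pi(p_r, p_g)}\mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]$$
  
  where $\Pi(p_r, p_g)$ is the set of joint distributions with marginals $p_r$ and $p_g$.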

<img src="images/WGAN_EM.png" width="600pt"/>

Comparison of EMD and JSD:

<img src="images/WGAN_gradients.png" width="600pt"/>
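
The comparison can also be reproduced numerically. The following is a minimal sketch, assuming two point masses on a 1-D grid (a toy setup chosen for illustration, not the exact figure above): the JSD jumps to $\log 2$ as soon as the supports are disjoint, while the EMD grows linearly with the shift. In 1-D the EMD reduces to the area between the two CDFs, which is what `emd_1d` uses.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p == 0 contribute 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence: symmetrized KL against the mixture m."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def emd_1d(p, q, grid):
    """1-D earth-mover distance: area between the two CDFs over the grid."""
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))
    return float(np.sum(cdf_gap[:-1] * np.diff(grid)))

grid = np.linspace(0.0, 1.0, 11)

def point_mass(index, size=len(grid)):
    """Distribution with all its mass on grid[index]."""
    d = np.zeros(size)
    d[index] = 1.0
    return d

p_r = point_mass(0)
for shift in (0, 1, 5, 10):
    p_g = point_mass(shift)
    print(f"shift={grid[shift]:.1f}  JSD={jsd(p_r, p_g):.4f}  "
          f"EMD={emd_1d(p_r, p_g, grid):.4f}")
```

Because the JSD is constant for every non-zero shift, it gives the generator no direction in which to move; the EMD does.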

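Condition for the EMD to be continuous, and differentiable almost everywhere, in the generator parameters (under mild regularity assumptions): __$g(z)$ locally Lipschitz__.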

<img src="images/WGAN_theorem.png" width="600pt"/>

### The loss

__WGAN__: skipping the derivation (see theorem 3 in the [paper](https://arxiv.org/pdf/1701.07875.pdf)).

$$\Large
\min_{\|D\|_L\leq1}\mathcal{L}(D) = \min_{\|D\|_L\leq1}
  \mathbb{E}_{\tilde{x}\sim p_g}[D(\tilde{x})] 
- \mathbb{E}_{x\sim p_r}[D(x)]
$$

Problem: $D$ must be 1-Lipschitz.

__Enforced in WGAN by weight clipping__.
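
Weight clipping can be sketched in a few lines. This is an illustrative NumPy version, assuming the critic parameters live in a dictionary of arrays (a hypothetical layout) and a clipping threshold `c = 0.01`, taken here as an assumed default; it would be applied after every critic update.

```python
import numpy as np

def clip_weights(params, c=0.01):
    """Return a copy of every parameter array with entries clamped to [-c, c]."""
    return {name: np.clip(w, -c, c) for name, w in params.items()}

# Hypothetical critic parameters, only used to exercise the function.
rng = np.random.default_rng(0)
critic_params = {"W1": rng.normal(size=(4, 2)), "b1": rng.normal(size=4)}

clipped = clip_weights(critic_params)
print({name: float(np.abs(w).max()) for name, w in clipped.items()})
```

Bounding the weights bounds the slope of the critic, but only to some constant that depends on `c` and the architecture, which is one motivation for the gradient penalty that follows.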

__WGAN-GP__: the idea is to penalize, in the loss, deviations of the critic's gradient norm from 1.

$$\Large
\mathcal{L}(D) = 
\mathbb{E}_{\tilde{x}\sim p_g}[D(\tilde{x})] 
- \mathbb{E}_{x\sim p_r}[D(x)]
+\lambda \mathbb{E}_{\hat{x}\sim p_{\hat{x}}}[(\|\nabla_{\hat{x}}D(\hat{x})\|_2-1)^2]
$$

where $p_{\hat{x}}$ samples uniformly along straight lines between pairs of points drawn from $p_r$ and $p_g$.

<img src="images/WGAN_penalty.png" width="600pt"/>
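
The penalty term can be sketched as follows. To stay self-contained, this example assumes a toy critic $D(x) = w_2 \cdot \tanh(W_1 x + b_1)$ whose input gradient is written out by hand; in practice an automatic-differentiation framework would compute that gradient. All names, and the value `lam = 10.0` for $\lambda$, are assumptions made for illustration.

```python
import numpy as np

def critic(x, W1, b1, w2):
    """Toy critic D(x) = w2 . tanh(W1 x + b1); x has shape (batch, dim)."""
    return np.tanh(x @ W1.T + b1) @ w2

def critic_input_grad(x, W1, b1, w2):
    """Analytic gradient of D with respect to x, shape (batch, dim)."""
    h = np.tanh(x @ W1.T + b1)           # (batch, hidden)
    return ((1.0 - h ** 2) * w2) @ W1    # chain rule through tanh

def gradient_penalty(x_real, x_fake, W1, b1, w2, rng):
    """Mean of (||grad D(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    eps = rng.uniform(size=(x_real.shape[0], 1))   # one epsilon per pair
    x_hat = eps * x_real + (1.0 - eps) * x_fake    # point on the segment
    grad_norm = np.linalg.norm(critic_input_grad(x_hat, W1, b1, w2), axis=1)
    return float(np.mean((grad_norm - 1.0) ** 2))

def wgan_gp_critic_loss(x_real, x_fake, W1, b1, w2, rng, lam=10.0):
    """E[D(fake)] - E[D(real)] + lam * gradient penalty."""
    wasserstein_term = (critic(x_fake, W1, b1, w2).mean()
                        - critic(x_real, W1, b1, w2).mean())
    return float(wasserstein_term
                 + lam * gradient_penalty(x_real, x_fake, W1, b1, w2, rng))

rng = np.random.default_rng(0)
W1, b1, w2 = rng.normal(size=(8, 2)), rng.normal(size=8), rng.normal(size=8)
x_real = rng.normal(loc=2.0, size=(16, 2))   # stand-in for real samples
x_fake = rng.normal(loc=-2.0, size=(16, 2))  # stand-in for generated samples
print(wgan_gp_critic_loss(x_real, x_fake, W1, b1, w2, rng))
```

The penalty is two-sided: it pushes the gradient norm towards 1 on the interpolates rather than merely keeping it below 1.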

### The architecture

Minor architecture improvements:

- striding instead of max-pooling;
- transposed convolutions;
- batch normalization.
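
How strided and transposed convolutions change the feature-map size can be checked with the standard output-size formulas (no dilation, no output padding). The kernel size 4, stride 2 and padding 1 used below are illustrative assumptions, applied to a 28-pixel input such as an MNIST image.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a strided convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel

# Critic side: each stride-2 convolution halves the resolution (no pooling).
size = 28
for _ in range(2):
    size = conv_out(size, kernel=4, stride=2, padding=1)
    print("down:", size)

# Generator side: each stride-2 transposed convolution doubles it back.
for _ in range(2):
    size = conv_transpose_out(size, kernel=4, stride=2, padding=1)
    print("up:", size)
```

With these settings the two operations are exact inverses in terms of shape, so a generator built from transposed convolutions can mirror the critic.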

### The optimization

- Many steps on $D$ before updating $G$;
- Adam with the first-moment exponential decay rate removed ($\beta_1 = 0$) and the second-moment rate reduced ($\beta_2 = 0.9$).
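
The update schedule above can be sketched as a plain loop. The step functions here are hypothetical stand-ins that only count calls, and the hyperparameter values (5 critic steps per generator step, Adam with learning rate 1e-4, $\beta_1 = 0$, $\beta_2 = 0.9$) are assumed defaults rather than tuned settings.

```python
# Assumed defaults; adjust for your own problem.
ADAM_CONFIG = {"lr": 1e-4, "beta1": 0.0, "beta2": 0.9}
N_CRITIC = 5

def train(n_generator_steps, critic_step, generator_step, n_critic=N_CRITIC):
    """Alternate n_critic critic updates with one generator update."""
    for _ in range(n_generator_steps):
        for _ in range(n_critic):
            critic_step()      # update D on a fresh batch, G frozen
        generator_step()       # update G once, D frozen

# Hypothetical stand-ins: a real step would compute a loss and apply Adam.
calls = {"critic": 0, "generator": 0}

def count_critic_step():
    calls["critic"] += 1

def count_generator_step():
    calls["generator"] += 1

train(3, count_critic_step, count_generator_step)
print(calls)
```

Training the critic several times per generator step keeps it close to optimal, so the score it provides remains a useful estimate of the distance being minimized.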

In [None]:
# Training on MNIST:
# TODO