# Lab 3: Mode Collapse and WGANs

The aim of this lab is to:


1.   See mode collapse in action on a synthetic dataset
2.   Implement a WGAN on the synthetic dataset to see what changes



### 1. Mode Collapse

First, we build a vanilla GAN on the following synthetic dataset, which has 4 modes.
<center>
<figure>
<img src="https://i.ibb.co/dPbCXHW/multimodal-gaussianmix.png" width="400" />
<figcaption>Four Gaussian modes centred at (±2, ±2), each with standard deviation 0.3</figcaption>
</figure>
</center>



Here is the architecture we'll be using:
<br><br>
<center>
<img src="https://i.ibb.co/fYzWhGc/lab4-arch.jpg" width="600" />
</center>


## Vanilla GAN on synthetic data

In [None]:
#Importing libraries
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from matplotlib.animation import FuncAnimation
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from IPython.display import HTML
from matplotlib.animation import FuncAnimation 
from keras.layers import LeakyReLU
from sklearn.datasets import make_blobs
from tqdm.notebook import tqdm
import tensorflow.keras.backend as K

In [None]:
# Generate the synthetic dataset, already given!
def get_real_data(n_samples):
  """Draw `n_samples` 2-D points from four Gaussian blobs.

  The blobs sit at the corners (2, 2), (-2, 2), (-2, -2) and (2, -2) of a
  square, each with standard deviation 0.3. Cluster labels are discarded;
  only the (n_samples, 2) array of coordinates is returned.
  """
  corners = [(2, 2), (-2, 2), (-2, -2), (2, -2)]
  samples, _cluster_ids = make_blobs(
      n_samples=n_samples,
      n_features=2,
      centers=corners,
      cluster_std=0.3,
  )
  return samples

In [None]:
# Training hyperparameters: number of training steps (each loop iteration
# processes one batch), samples per step, and generator noise dimension.
epochs, batch_size, latent_dim = 1000, 512, 2

In [None]:
def build_discriminator(dim, hidden_units=64, n_hidden_layers=2):
  """Build and compile a real/fake classifier for `dim`-dimensional points.

  Args:
    dim: dimensionality of each input sample.
    hidden_units: width of every hidden Dense layer (default 64).
    n_hidden_layers: number of LeakyReLU hidden layers (default 2).

  Returns:
    A Sequential model ending in a single sigmoid unit, compiled with Adam,
    binary cross-entropy and an accuracy metric.
  """
  model = Sequential()
  # Declare the input shape once up front; Keras infers every later layer's
  # input size, so hidden layers need no `input_dim` of their own.
  model.add(Input(shape=(dim,)))
  for _ in range(n_hidden_layers):
    model.add(Dense(hidden_units, activation=LeakyReLU(alpha=0.1)))
  model.add(Dense(1, activation='sigmoid'))

  model.compile('adam', loss='binary_crossentropy', metrics=['accuracy'])
  return model

In [None]:
def build_generator(latent_dim, output_dim, hidden_units=16, n_hidden_layers=4):
  """Build the (uncompiled) generator mapping latent noise to data space.

  Args:
    latent_dim: dimensionality of the input noise vectors.
    output_dim: dimensionality of the generated samples.
    hidden_units: width of every hidden Dense layer (default 16).
    n_hidden_layers: number of LeakyReLU hidden layers (default 4).

  Returns:
    A Sequential model with a linear `output_dim`-wide output layer. It is
    not compiled here; it is trained through a combined model that stacks it
    with a frozen discriminator/critic.
  """
  model = Sequential()
  # Declare the input shape once up front instead of repeating `input_dim`
  # on every hidden layer (Keras only honours it on the first layer).
  model.add(Input(shape=(latent_dim,)))
  for _ in range(n_hidden_layers):
    model.add(Dense(hidden_units, activation=LeakyReLU(alpha=0.1)))
  # Linear output so generated points can take any real coordinates.
  model.add(Dense(output_dim))
  return model

In [None]:
# Given a generator and a discriminator, build a GAN
def build_GAN(G, D, latent_dim):
  """Stack G and D into the combined model used to update the generator.

  D is marked non-trainable before the stacked model is compiled, so calling
  train_on_batch on the returned model adjusts only G's weights.

  NOTE(review): the flag is set on D itself and never reset. Whether a later
  D.train_on_batch still updates D depends on how the installed Keras
  version treats `trainable` changes made after D.compile() - confirm.

  Returns:
    A Model mapping (latent_dim,) noise to D's output, compiled with Adam,
    binary cross-entropy and an accuracy metric.
  """
  D.trainable = False
  z = tf.keras.layers.Input((latent_dim,))
  generated = G(z)
  verdict = D(generated)
  combined = Model(z, verdict)
  combined.compile(
      optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy'],
  )
  return combined

In [None]:
# Generate standard-normal (Gaussian) noise to input to the generator
def generate_input_noise(batch_size, latent_dim, rng=None):
    """Sample a (batch_size, latent_dim) array of i.i.d. N(0, 1) noise.

    Args:
      batch_size: number of noise vectors (rows).
      latent_dim: dimensionality of each noise vector (columns).
      rng: optional ``np.random.Generator`` for reproducible draws. When
        omitted, NumPy's global RNG is used, so ``np.random.seed`` controls
        the output.

    Returns:
      A float ndarray of shape (batch_size, latent_dim).
    """
    if rng is None:
        return np.random.randn(batch_size, latent_dim)
    return rng.standard_normal((batch_size, latent_dim))

In [None]:
# Build the GAN: a fresh generator G (latent_dim -> 2-D points), a
# discriminator D over 2-D points, and the combined model GAN = D(G(z))
# that is used to train G.
G = build_generator(latent_dim, 2)
D = build_discriminator(2)
GAN = build_GAN(G, D, latent_dim)

In [None]:
# Training the GAN

### Plotting stuff, do not touch ###
x_min = -4; x_max = 4; y_min = -4; y_max = 4
xx, yy = np.mgrid[(x_min):(x_max):.1, (y_min):(y_max):.1]
grid = np.c_[xx.ravel(), yy.ravel()]
D_loss = []
G_loss = []
G_predict=[]
D_contours = []
### --------------------- ###

# Each "epoch" below is a single training step on one batch.
half_batch = batch_size // 2
# Labels never change between steps, so build them once:
# D sees real samples as 1 and generated samples as 0; the generator update
# uses all-1 labels so G is rewarded when D calls its samples real.
d_labels = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))), axis=0)
g_labels = np.ones((batch_size, 1))

for step in tqdm(range(epochs)):

    # Train discriminator on half real / half generated samples.
    # verbose=0 stops Keras printing a progress bar on every predict call.
    real_data = get_real_data(half_batch)
    fake_data = G.predict(generate_input_noise(half_batch, latent_dim), batch_size=half_batch, verbose=0)
    data = np.concatenate((real_data, fake_data), axis=0)
    _D_loss, _ = D.train_on_batch(data, d_labels)

    # Train generator through the combined model (D frozen inside it).
    noise = generate_input_noise(batch_size, latent_dim)
    _G_loss, _ = GAN.train_on_batch(noise, g_labels)

    D_loss.append(_D_loss)
    G_loss.append(_G_loss)

    # Record D's output over the plotting grid and 500 generator samples
    # for the animation cell.
    probs = D.predict(grid, verbose=0).reshape(xx.shape)
    D_contours.append(probs)
    test_noise = generate_input_noise(500, latent_dim)
    fake_samples = G.predict(test_noise, batch_size=len(test_noise), verbose=0)
    G_predict.append(fake_samples)



  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# See the training process
fig, ax = plt.subplots(1, 2, figsize=(14,6))
plt.close(fig)  # close the static figure so only the animation is displayed
ax1 = ax[0]
ax2 = ax[1]

FRAME_STRIDE = 4  # render every 4th recorded training step

def animate(i):
  """Draw frame i: D's decision regions and samples (left), losses (right)."""
  step = i * FRAME_STRIDE
  # Threshold into a new array so the recorded D_contours entry is not
  # overwritten: 1 where D predicts "real" (p >= 0.5), else 0.
  probs = (D_contours[step] >= 0.5).astype(float)
  ax1.clear()
  ax2.clear()
  ax1.contourf(xx, yy, probs, 25, alpha = 0.4)
  fake_data = G_predict[step]
  real_data = get_real_data(500)
  # Left panel: generated vs. real samples over D's decision regions.
  ax1.scatter(fake_data[:, 0], fake_data[:, 1], s = 10, label = 'Fake data')
  ax1.scatter(real_data[:, 0], real_data[:, 1], s = 10, label = 'Real data')
  ax1.set(xlim=(x_min, x_max), ylim=(y_min, y_max))
  ax1.legend(loc="lower right")
  ax1.set_title("Epochs: {}".format(step+1))
  # Right panel: generator and discriminator loss curves up to this step.
  ax2.plot(np.arange(step), G_loss[0:step],label='G loss',c='darkred',zorder=50,alpha=0.8)
  ax2.plot(np.arange(step), D_loss[0:step],label='D loss',c='darkblue',zorder=55,alpha=0.8)
  ax2.set_xlim(-5, epochs+5)
  ax2.legend()
  ax2.set_xlabel('Epoch')

anim = FuncAnimation(fig, animate, frames = epochs // FRAME_STRIDE, interval=100, repeat = True)
HTML(anim.to_html5_video())

## WGAN 
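$$
W(p_r,p_g) = \frac{1}{K}\sup_{\|f\|_L \leq K} \Big( E_{x \sim p_r}[f(x)] - E_{x \sim p_g}[f(x)] \Big)
$$
where the supremum is taken over all $K$-Lipschitz functions $f$, i.e. functions such that, for all $x_1, x_2 \in \mathbb{R}^n$ (here $n = 2$):
$$
|f(x_1) - f(x_2)| \leq K\,\|x_1 - x_2\|
$$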
<br>
Main changes:


1.   The output of our critic is not restricted to $[0,1]$: it can take any **real value** and is interpreted as a **score** rather than a probability.
2.   Our loss function is based on the **Wasserstein distance**.
3. We **clip the critic's weights** to a fixed range ($[-0.01, 0.01]$) to (approximately) enforce Lipschitz continuity.
4. We update the critic more often than the generator (here `n_critic` times per generator update).

$$
\begin{align}
&\mathcal{L}_{C} = -\bigg(E_{x \sim p_r}[C(x)] - E_{z \sim p_z}[C(G(z))]\bigg) \\
&\mathcal{L}_{G} = -E_{z \sim p_z}[C(G(z))]
\end{align}
$$

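To stay consistent with the label-based convention used for the vanilla GAN, we use labels $y_{real} = +1$ and $y_{fake} = -1$ and implement the loss as the negated mean of $y_{true} \times y_{pred}$. Minimizing it pushes the critic's scores up on real samples and down on fake ones. For the generator, all labels are $+1$, so minimizing it pushes the critic's score on generated samples up.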

In [None]:
# The WGAN is trained for more steps than the vanilla GAN above.
epochs = 3_000

In [None]:
def wasserstein(y_true, y_pred):
    """Wasserstein loss for +1 (real) / -1 (fake) labels.

    Returns -mean(y_true * y_pred). With real samples labelled +1 and fake
    samples -1, minimizing it raises the critic's scores on real samples and
    lowers them on fake ones. With all labels +1 (generator update), it
    raises the critic's score on generated samples.
    """
    signed_scores = y_true * y_pred
    return -K.mean(signed_scores)

In [None]:
# Build the critic network (the WGAN's replacement for the discriminator)
def build_critic(dim, hidden_units=64, n_hidden_layers=2):
  """Build and compile a WGAN critic for `dim`-dimensional points.

  Args:
    dim: dimensionality of each input sample.
    hidden_units: width of every hidden Dense layer (default 64).
    n_hidden_layers: number of LeakyReLU hidden layers (default 2).

  Returns:
    A Sequential model with a single linear output unit, compiled with
    Adam(learning_rate=0.002, beta_1=0.5) and the `wasserstein` loss.
  """
  model = Sequential()
  # Declare the input shape once; Keras infers the rest of the layers' sizes.
  model.add(Input(shape=(dim,)))
  for _ in range(n_hidden_layers):
    model.add(Dense(hidden_units, activation=LeakyReLU(alpha=0.1)))
  # No activation: the critic outputs an unbounded score, not a probability.
  model.add(Dense(1, activation = None))
  model.compile(Adam(learning_rate=0.002, beta_1=0.5),loss=wasserstein)
  return model

In [None]:
def build_WGAN(G, C, latent_dim):
  """Stack G and the critic C into the model used to update the generator.

  C is marked non-trainable before the stacked model is compiled, so calling
  train_on_batch on the returned model adjusts only G's weights. The stacked
  model uses the `wasserstein` loss with Adam(learning_rate=0.005,
  beta_1=0.5).

  NOTE(review): the flag is set on C itself and never reset. Whether a later
  C.train_on_batch still updates C depends on how the installed Keras version
  treats `trainable` changes made after C.compile() - confirm.
  """
  C.trainable = False
  z = tf.keras.layers.Input((latent_dim,))
  generated = G(z)
  score = C(generated)
  combined = Model(z, score)
  generator_opt = Adam(learning_rate=0.005, beta_1=0.5)
  combined.compile(generator_opt, loss=wasserstein)
  return combined

In [None]:
# Fresh generator G and critic C for the WGAN. G and GAN rebind the names
# used in the vanilla-GAN section; GAN now holds the combined model C(G(z)).
G = build_generator(latent_dim, 2)
C = build_critic(2)
GAN = build_WGAN(G, C, latent_dim)

In [None]:
#####
n_critic = 3  # critic updates per generator update
#####

# Training the GAN
C_loss = []
G_loss = []
G_predict=[]

half_batch = batch_size // 2
# Constant labels, built once: +1 for real and -1 for generated samples when
# training the critic; all +1 for the generator update.
c_labels = np.concatenate((np.ones((half_batch, 1)), -np.ones((half_batch, 1))), axis=0)
g_labels = np.ones((batch_size, 1))
# Range every critic weight is clipped to, to (approximately) keep the
# critic Lipschitz.
clip_threshold = 0.01

for step in tqdm(range(epochs)):

    # Train the critic n_critic times per generator step.
    for _ in range(n_critic):
      real_data = get_real_data(half_batch)
      # verbose=0 stops Keras printing a progress bar on every predict call.
      fake_data = G.predict(generate_input_noise(half_batch, latent_dim), batch_size=half_batch, verbose=0)
      data = np.concatenate((real_data, fake_data), axis=0)
      _C_loss = C.train_on_batch(data, c_labels)

      # Clip all critic weights (kernels and biases). get/set_weights is used
      # because build_WGAN set C.trainable = False, so C.trainable_weights
      # would be empty here.
      for layer in C.layers:
          layer.set_weights([np.clip(w, -clip_threshold, clip_threshold) for w in layer.get_weights()])

    noise = generate_input_noise(batch_size, latent_dim)
    _G_loss = GAN.train_on_batch(noise, g_labels)

    # Only the last critic loss of each step is recorded.
    C_loss.append(_C_loss)
    G_loss.append(_G_loss)

    # Record 500 generator samples per step for the animation cell.
    test_noise = generate_input_noise(500, latent_dim)
    fake_samples = G.predict(test_noise, batch_size=len(test_noise), verbose=0)
    G_predict.append(fake_samples)

  0%|          | 0/3000 [00:00<?, ?it/s]

In [None]:

fig, ax = plt.subplots(1, 2, figsize=(14,6))
plt.close(fig)  # close the static figure so only the animation is displayed
ax1 = ax[0]
ax2 = ax[1]

WGAN_FRAME_STRIDE = 10  # render every 10th recorded training step

def animate(i):
  """Draw frame i: generated vs. real samples (left), losses (right)."""
  step = i * WGAN_FRAME_STRIDE
  ax1.clear()
  ax2.clear()
  fake_data = G_predict[step]
  real_data = get_real_data(500)
  ax1.scatter(fake_data[:, 0], fake_data[:, 1], s = 10, label = 'Fake data')
  ax1.scatter(real_data[:, 0], real_data[:, 1], s = 10, label = 'Real data')
  ax1.set(xlim=(x_min, x_max), ylim=(y_min, y_max))
  ax1.legend(loc="lower right")
  ax1.set_title("Epochs: {}".format(step+1))
  # The second curve is the critic's loss, so label it as such.
  ax2.plot(np.arange(step), G_loss[0:step],label='G loss',c='darkred',zorder=50,alpha=0.8)
  ax2.plot(np.arange(step), C_loss[0:step],label='C loss',c='darkblue',zorder=55,alpha=0.8)
  ax2.set_xlim(-5, epochs+5)
  ax2.legend()
  ax2.set_xlabel('Epoch')

anim = FuncAnimation(fig, animate, frames = epochs // WGAN_FRAME_STRIDE, interval=100, repeat = True)
HTML(anim.to_html5_video())