Solutions for mitigating GAN training issues: this notebook shows sample code for implementing feature matching and minibatch discrimination.

In [None]:
from tensorflow.keras import Model, layers
import tensorflow as tf

### Feature Matching

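Feature matching trains the generator to match the statistics of real data as seen by the discriminator. Activations are extracted from an intermediate layer of the discriminator, averaged over a batch of real images and over a batch of generated images, and the generator is penalized for the distance between the two averages.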
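Written out, the feature-matching objective from Salimans et al. (2016) is shown below, with $f$ denoting the intermediate discriminator layer. In practice both expectations are estimated by means over a minibatch of $m$ samples, and this notebook adds the term to the usual adversarial generator loss rather than using it alone.

$$
\mathcal{L}_{\text{FM}} = \left\lVert \mathbb{E}_{x \sim p_{\text{data}}}\, f(x) - \mathbb{E}_{z \sim p_z}\, f(G(z)) \right\rVert_2^2
\approx \left\lVert \frac{1}{m}\sum_{i=1}^{m} f(x_i) - \frac{1}{m}\sum_{i=1}^{m} f(G(z_i)) \right\rVert_2^2
$$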

In [None]:
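def generator(noise_dim):
    z = layers.Input(shape=(noise_dim,))
    x = layers.Dense(units=4*4*1024)(z)
    x = layers.Reshape((4, 4, 1024))(x)
    # Each stride-2 transposed convolution doubles height and width: 4 -> 32
    for filter_size in [512, 256, 128]:
        x = layers.Conv2DTranspose(filters=filter_size,
               kernel_size=5, strides=2, padding='same')(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.BatchNormalization()(x)
    # Output layer: 32 -> 64, three channels squashed to [-1, 1] with tanh
    # (no LeakyReLU + batch norm here, which would distort the image range)
    x = layers.Conv2DTranspose(filters=3, kernel_size=5, strides=2,
                               padding='same', activation='tanh')(x)
    return Model(inputs=z, outputs=x)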
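As a quick sanity check on the generator's geometry: with `strides=2` and `padding='same'`, each `Conv2DTranspose` doubles height and width, so four such layers take the 4×4 seed to 64×64, the input size the discriminator expects. This sketch pushes a zero tensor through a stack with the same filter sizes; it checks shapes only and does not reuse the generator itself.

```python
import tensorflow as tf

x = tf.zeros([1, 4, 4, 1024])  # shape after the Dense + Reshape step
for filters in [512, 256, 128, 3]:
    # 'same' padding with stride 2 doubles the spatial size of the input
    x = tf.keras.layers.Conv2DTranspose(filters, kernel_size=5, strides=2,
                                        padding="same")(x)
    print(tuple(x.shape))
# (1, 8, 8, 512)
# (1, 16, 16, 256)
# (1, 32, 32, 128)
# (1, 64, 64, 3)
```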

In [None]:
def discriminator():
    img = layers.Input(shape=[64, 64, 3])
    x = layers.Conv2D(filters=128, kernel_size=5, strides=2,
                      padding='same')(img)
    x = layers.LeakyReLU(0.2)(x)
    for filter_size in [256, 512, 1024]:
        x = layers.Conv2D(filters=filter_size, kernel_size=5,
                          strides=2, padding='same')(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.BatchNormalization()(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1)(x)  # single unnormalized logit: real vs. fake
    return Model(inputs=img, outputs=x)

In [None]:
G = generator(100)
D = discriminator()
D.summary()

In [None]:
# Output of the 512-filter conv layer (auto-named 'conv2d_2' only in a fresh session)
intermediate_D = Model(D.inputs, D.get_layer('conv2d_2').output)
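Auto-generated layer names such as `conv2d_2` depend on how many layers of that type already exist in the Python session, so a lookup by such a name breaks when a cell is re-run. A more robust pattern is to name the feature layer when the model is built. The sketch below shows this on a toy discriminator; `toy_D`, `feature_extractor` and the `conv_{filters}` names are hypothetical and chosen for illustration.

```python
import tensorflow as tf
from tensorflow.keras import Model, layers

img = layers.Input(shape=(64, 64, 3))
x = img
for filters in [128, 256, 512, 1024]:
    # Explicit names such as "conv_512" stay fixed no matter how many
    # other Conv2D layers were created earlier in the session
    x = layers.Conv2D(filters, kernel_size=5, strides=2, padding="same",
                      name=f"conv_{filters}")(x)
    x = layers.LeakyReLU(0.2)(x)
toy_D = Model(img, layers.Dense(1)(layers.Flatten()(x)))

# Sub-model sharing toy_D's weights, ending at the 512-filter conv layer
feature_extractor = Model(toy_D.inputs, toy_D.get_layer("conv_512").output)
print(tuple(feature_extractor(tf.zeros([2, 64, 64, 3])).shape))  # (2, 8, 8, 512)
```

Because the sub-model reuses the same layer objects, any training of `toy_D` is reflected in `feature_extractor` automatically.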

In [None]:
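cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(real_images, fake_images):
    # Mean intermediate features over the batch axis only (one mean per feature)
    features_real = tf.reduce_mean(intermediate_D(real_images, training=True), axis=0)
    features_fake = tf.reduce_mean(intermediate_D(fake_images, training=True), axis=0)
    # Non-saturating adversarial term computed on the discriminator's logits
    fake_logits = D(fake_images, training=True)
    adversarial = cross_entropy(tf.ones_like(fake_logits), fake_logits)
    # Feature matching term: squared L2 distance between the mean features
    return adversarial + tf.reduce_sum(tf.square(features_real - features_fake))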
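Averaging only over the batch axis keeps one mean per feature; collapsing everything to a single scalar mean lets features that move in opposite directions cancel out. The minimal sketch below isolates the feature-matching term to show the difference. `feature_matching_loss` is a hypothetical helper, not a Keras API, and it takes precomputed activations so it runs without the models above.

```python
import tensorflow as tf


def feature_matching_loss(features_real, features_fake):
    """Squared L2 distance between batch-mean intermediate features.

    Axis 0 is the batch axis; every remaining axis (e.g. height, width,
    channels of a conv layer) is treated as a separate feature.
    """
    mean_real = tf.reduce_mean(features_real, axis=0)
    mean_fake = tf.reduce_mean(features_fake, axis=0)
    return tf.reduce_sum(tf.square(mean_real - mean_fake))


# Two one-sample "batches" whose features differ in opposite directions
a = tf.constant([[1.0, -1.0]])
b = tf.constant([[-1.0, 1.0]])
print(float(feature_matching_loss(a, b)))            # 8.0
print(float(tf.reduce_mean(a) - tf.reduce_mean(b)))  # 0.0
```

The per-feature loss sees the mismatch (4 + 4 = 8), while the scalar means of both batches are identical, so a loss built on them would report no difference at all.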

### Minibatch Discrimination

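Minibatch discrimination is based on the observation that a random sample from the real dataset contains a diverse set of images, so the images within a minibatch of real data are, on average, dissimilar to each other. A generator suffering from mode collapse instead produces minibatches of near-identical samples. The technique therefore measures, for each sample, how similar it is to the other samples in its minibatch and passes that score to the discriminator as extra features. A good generator should produce minibatches whose intra-batch similarity looks like that of real data.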
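In the notation of the cells below (feature dimension $L$, $K$ projections of dimension $d$, batch size $n$), and with $T \in \mathbb{R}^{L \times K \times d}$ stacking the projection matrices, the computation for each sample $x_i$ is:

$$
M_i = f(x_i)\,T \in \mathbb{R}^{K \times d}, \qquad
c_k(x_i, x_j) = \exp\!\left(-\left\lVert M_{i,k} - M_{j,k} \right\rVert_{1}\right), \qquad
o(x_i)_k = \sum_{j=1}^{n} c_k(x_i, x_j)
$$

Here $M_{i,k}$ is the $k$-th row of $M_i$, and $o(x_i) \in \mathbb{R}^{K}$ is concatenated with $f(x_i)$ before the discriminator's next layer. These quantities correspond to `projections`, `sim_scores` and `sim_score_out` in the cells that follow.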

In [None]:
L = 5   # dimension of the intermediate-layer features f(x)
K = 2   # number of low-dimensional projections (kernels)
d = 3   # dimension of each projection
n = 10  # batch size (number of samples)

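Projection matrices: one L × d matrix per projection. They are random here for illustration; in a real discriminator they are trainable weights.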

In [None]:
M1 = tf.random.uniform([L,d])
M2 = tf.random.uniform([L,d])

Stack the K projection matrices into a single tensor T of shape [L, K, d].

In [None]:
T = tf.concat([tf.expand_dims(M1, axis=1), tf.expand_dims(M2, axis=1)], axis=1)

A random stand-in for the intermediate-layer output f(x): one L-dimensional feature vector per sample, shape [n, L].

In [None]:
fx = tf.random.uniform([n,L]) 

In [None]:
# [n, L] x [L, K, d] -> [n, K, d]: every sample projected by each of the K matrices
projections = tf.einsum('ij,jkl->ikl', fx, T)
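The subscripts `'ij,jkl->ikl'` contract the feature axis `j` and keep the sample axis `i` and both projection axes `k`, `l`. The NumPy check below, using the same sizes and NumPy's `einsum` (which follows the same subscript convention as `tf.einsum`), confirms this equals projecting `f(x)` with each matrix separately. The `_np` names are illustrative and chosen so they do not overwrite the notebook's tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
fx_np = rng.uniform(size=(10, 5))                    # [n, L]
M_np = [rng.uniform(size=(5, 3)) for _ in range(2)]  # K matrices of shape [L, d]
T_np = np.stack(M_np, axis=1)                        # [L, K, d], like the tf.concat above
proj_np = np.einsum("ij,jkl->ikl", fx_np, T_np)      # [n, K, d]

# Same result, computed one projection matrix at a time
loop_np = np.stack([fx_np @ M_k for M_k in M_np], axis=1)
print(proj_np.shape)                  # (10, 2, 3)
print(np.allclose(proj_np, loop_np))  # True
```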

In [None]:
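# Pairwise absolute differences between the projections of samples i and j:
# [n, 1, K, d] - [1, n, K, d] broadcasts to [n, n, K, d]
row_wise_L1 = tf.abs(tf.expand_dims(projections, 1) - tf.expand_dims(projections, 0))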

In [None]:
# L1 distance over the d axis, then exp(-distance): similarity in (0, 1], shape [n, n, K]
sim_scores = tf.exp(-tf.reduce_sum(row_wise_L1, axis=3))

In [None]:
# Sum over the other samples j (self-similarity of 1 included): shape [n, K]
sim_score_out = tf.reduce_sum(sim_scores, axis=1)
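The score behaves as intended at the two extremes. The NumPy sketch below uses a hypothetical helper, `minibatch_similarity`, that repeats the three steps above with broadcasting, and compares a fully collapsed batch (every sample has the same projection) with a batch whose projections are far apart. The collapsed batch scores n for every kernel; the diverse batch scores about 1, which is just each sample's similarity to itself.

```python
import numpy as np


def minibatch_similarity(projections):
    """Map projections of shape [n, K, d] to similarity scores of shape [n, K]."""
    # Pairwise L1 distances per kernel: [n, n, K]
    l1 = np.abs(projections[:, None] - projections[None, :]).sum(axis=3)
    # exp(-distance), summed over the batch
    return np.exp(-l1).sum(axis=1)


collapsed = np.ones((10, 2, 3))  # mode collapse: identical projections
diverse = np.arange(60, dtype=float).reshape(10, 2, 3) * 10.0  # far apart

print(minibatch_similarity(collapsed)[0])  # [10. 10.]
print(minibatch_similarity(diverse)[0])    # [1. 1.]
```

A discriminator that receives these scores can therefore tell a collapsed batch of generated images from a diverse batch of real ones.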

In [None]:
sim_score_out
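To use this inside a discriminator, the steps can be packaged as a Keras layer with a trainable projection tensor. The sketch below is one possible implementation under stated assumptions: the class name `MinibatchDiscrimination`, its arguments and the Glorot-uniform initializer are illustrative choices, not part of `tf.keras`. It expects flattened `[batch, features]` inputs and returns them with the K similarity scores appended.

```python
import tensorflow as tf


class MinibatchDiscrimination(tf.keras.layers.Layer):
    """Hypothetical layer: appends K minibatch-similarity scores to f(x)."""

    def __init__(self, num_kernels, kernel_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_kernels = num_kernels  # K in the cells above
        self.kernel_dim = kernel_dim    # d in the cells above

    def build(self, input_shape):
        feature_dim = int(input_shape[-1])  # L in the cells above
        # Trainable projection tensor T of shape [L, K, d]
        self.projection_tensor = self.add_weight(
            name="projection_tensor",
            shape=(feature_dim, self.num_kernels, self.kernel_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, fx):
        # [n, L] x [L, K, d] -> [n, K, d]
        projections = tf.einsum("ij,jkl->ikl", fx, self.projection_tensor)
        # Pairwise L1 distances per kernel: [n, n, K]
        l1 = tf.reduce_sum(
            tf.abs(tf.expand_dims(projections, 1) - tf.expand_dims(projections, 0)),
            axis=3,
        )
        # Similarities summed over the batch, self-term included: [n, K]
        sim = tf.reduce_sum(tf.exp(-l1), axis=1)
        # Append the scores to the original features: [n, L + K]
        return tf.concat([fx, sim], axis=1)


demo_out = MinibatchDiscrimination(num_kernels=2, kernel_dim=3)(tf.random.uniform([10, 5]))
print(tuple(demo_out.shape))  # (10, 7)
```

In a discriminator like the one above, such a layer would sit between `Flatten()` and the final `Dense(1)`; `num_kernels` and `kernel_dim` are hyperparameters.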