In [None]:
import tensorflow as tf
import numpy as np
from numpy import size

In [None]:
class GradientPenalty:
  """Computes the gradient penalty from "Improved Training of Wasserstein GANs"
  (https://arxiv.org/abs/1704.00028).

  The penalty is ``lambdaGP * mean(((||grad D(x_hat)||_2 - gamma) / gamma) ** 2)``,
  where ``x_hat`` is a random per-sample interpolation between real and fake data
  and the norm is taken over all non-batch axes of each sample.

  Args:
    lambdaGP (float): coefficient of the gradient penalty.
    gamma (float): target gradient norm; the deviation from it is also divided by gamma.
  """

  def __init__(self, lambdaGP, gamma=1):
    self.lambdaGP = lambdaGP
    self.gamma = gamma

  def __call__(self, disc_model, real_data, fake_data):
    """Return the scalar gradient penalty for one batch.

    Args:
      disc_model: callable mapping a batch of inputs to discriminator outputs.
      real_data: batch of real samples (tensor or array); axis 0 is the batch axis.
      fake_data: batch of generated samples with the same per-sample shape; only
        the first ``batch_size`` rows are used.

    Returns:
      A scalar tensor.

    Raises:
      ValueError: if the discriminator output is not differentiable w.r.t. its input.
    """
    real = tf.convert_to_tensor(real_data)
    if not real.dtype.is_floating:
      # Integer inputs (e.g. np.array([1, 2, 3])) cannot be mixed with float weights
      # or differentiated, so promote them to float32.
      real = tf.cast(real, tf.float32)
    # Dynamic batch size so this also works when the static batch dim is unknown.
    batch_size = tf.shape(real)[0]
    fake = tf.cast(fake_data, real.dtype)[:batch_size]

    # One mixing weight per sample, broadcast over every remaining axis (any rank).
    alpha_shape = [batch_size] + [1] * (real.shape.rank - 1)
    alpha = tf.random.uniform(alpha_shape, 0.0, 1.0, dtype=real.dtype)
    inter = real + alpha * (fake - real)

    with tf.GradientTape() as tape:
      # `inter` is a plain tensor, not a tf.Variable, so it must be watched explicitly.
      tape.watch(inter)
      pred = disc_model(inter)
    grad = tape.gradient(pred, inter)
    if grad is None:
      raise ValueError("disc_model output is not differentiable w.r.t. its input")

    # Flatten each sample's gradient and take its L2 norm.
    grad = tf.reshape(grad, [batch_size, -1])
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))
    return tf.reduce_mean(((slopes - self.gamma) / self.gamma) ** 2) * self.lambdaGP



Gradient Penalty Unit Test

In [None]:
real_data = np.array([4, 6, 7, 8])
fake_data = np.array([1, 2, 3, 4])


def disc(arg):
  """Toy discriminator for the tests: returns the element-wise square of its input."""
  squared = arg * arg
  return squared


gp = GradientPenalty(lambdaGP=10)


In [None]:
# Seed TF so the random interpolation weights, and hence the penalty, are reproducible.
tf.random.set_seed(0)
penalty = gp(disc, real_data, fake_data)
assert penalty.shape.rank == 0, "gradient penalty should be a scalar"
assert float(penalty) >= 0.0, "penalty is a mean of squares scaled by lambdaGP=10"
penalty

[<tf.Tensor: shape=(4, 1, 1, 4), dtype=float32, numpy=
array([[[[ 6.7470255, 10.329368 , 12.329368 , 14.329368 ]]],


       [[[ 5.0937886,  8.1250515, 10.1250515, 12.1250515]]],


       [[[ 3.0437002,  5.3916006,  7.3916006,  9.391601 ]]],


       [[[ 5.81446  ,  9.085946 , 11.085946 , 13.085946 ]]]],
      dtype=float32)>]
tf.Tensor(3242.9624, shape=(), dtype=float32)


Original PyTorch Implementation

In [None]:
import torch
from torch.autograd import Variable, grad

class GradientPenalty2:
    """Computes the gradient penalty as defined in "Improved Training of Wasserstein GANs"
    (https://arxiv.org/abs/1704.00028). PyTorch reference implementation.

    The penalty is ``lambdaGP * mean(((||grad D(x_hat)||_2 - gamma) / gamma) ** 2)``,
    where ``x_hat = real + alpha * (fake - real)`` with one random ``alpha`` per sample.

    Args:
        lambdaGP (float): coefficient of the gradient penalty as defined in the article
        gamma (float): target gradient norm; the deviation from it is also divided by
            gamma (augmented to minimize "ghosts")
        vertex_num (int): stored on the instance; not used by the penalty computation
        device (torch.device): stored on the instance for backward compatibility; random
            weights and grad outputs are created on the device of the input data
    """

    def __init__(self, lambdaGP, gamma=1, vertex_num=2500, device=torch.device('cpu')):
        self.lambdaGP = lambdaGP
        self.gamma = gamma
        self.vertex_num = vertex_num
        self.device = device

    def __call__(self, netD, real_data, fake_data):
        """Return the scalar gradient penalty for one batch.

        Args:
            netD: callable discriminator mapping a batch of inputs to its outputs.
            real_data (Tensor): real samples; dim 0 is the batch dimension.
            fake_data (Tensor): generated samples with the same per-sample shape;
                only the first ``real_data.size(0)`` rows are used.

        Returns:
            Tensor: 0-dim penalty, built with ``create_graph=True`` so it can be
            backpropagated into netD's parameters.
        """
        batch_size = real_data.size(0)
        # Integer tensors cannot carry gradients; promote them to the default float dtype.
        if real_data.is_floating_point():
            real = real_data
        else:
            real = real_data.to(torch.get_default_dtype())
        fake = fake_data[:batch_size].to(dtype=real.dtype)

        # One mixing weight per sample, broadcast over every remaining dimension (any rank).
        alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
        alpha = torch.rand(alpha_shape, dtype=real.dtype, device=real.device)
        # randomly mix real and fake data
        interpolates = real + alpha * (fake - real)
        if not interpolates.requires_grad:
            # Neither input is part of an autograd graph: make this leaf differentiable
            # so gradients of D w.r.t. the interpolated points can be taken.
            interpolates.requires_grad_(True)

        # compute output of D for interpolated input
        disc_interpolates = netD(interpolates)
        # compute gradients of D's outputs w.r.t. the interpolated inputs
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0].reshape(batch_size, -1)

        slopes = gradients.norm(2, dim=1)
        return (((slopes - self.gamma) / self.gamma) ** 2).mean() * self.lambdaGP

In [None]:
# Note: torch.Tensor(1, 2, 3, 4) would allocate an *uninitialised* tensor of shape
# (1, 2, 3, 4); torch.tensor([...]) builds a tensor from the listed values.
torch.manual_seed(0)  # reproducible interpolation weights
gp2 = GradientPenalty2(10)
real_data2 = torch.tensor([1.0, 2.0, 3.0, 4.0])
fake_data2 = torch.tensor([5.0, 6.0, 7.0, 8.0])
penalty = gp2(disc, real_data2, fake_data2)
assert penalty.dim() == 0, "gradient penalty should be a scalar"
penalty

RuntimeError: ignored