In [1]:
#### Notes

#This file contains the code used for the simulations in Figure 1 and in Appendix E.2.

#(This file contains the code for our algorithm and for GDA, for the CIFAR dataset.  For our algorithm with acceptance rate 1/2, set the "rate" parameter to "rate = 2".  For GDA, set the "rate" parameter to "rate = 1".)

In [2]:
# example of calculating the frechet inception distance in Keras for cifar10
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import shuffle
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets.mnist import load_data
from skimage.transform import resize
from keras.datasets import cifar10
import time


# %load_ext line_profiler

Using TensorFlow backend.





In [3]:

# prepare the inception v3 model
# NOTE(review): with include_top=False and pooling='avg' the network is expected to
# return one pooled feature vector per input image rather than class scores -- confirm
# against the Keras documentation for the installed version.
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))
# load cifar10 images (the label arrays are discarded)
(train_images, _), (test_images, _) = cifar10.load_data()
# `shuffle` is numpy.random.shuffle: each array is reordered in place.
# NOTE(review): no random seed is set in the visible cells, so the order differs between runs.
shuffle(train_images)
shuffle(test_images)
















In [4]:

# scale an array of images to a new size
def scale_images(images, new_shape):
    """Resize every image in ``images`` to ``new_shape``.

    Each image is passed through ``resize(image, new_shape, 0)``; the trailing
    ``0`` selects nearest-neighbor interpolation.

    Args:
        images: iterable of images (iterated one image at a time).
        new_shape: target shape handed unchanged to ``resize``.

    Returns:
        The resized images, in input order, collected with ``asarray``.
    """
    return asarray([resize(image, new_shape, 0) for image in images])

# calculate frechet inception distance
def calculate_fid(model, images1, images2):
    """Compute the Frechet inception distance between two image batches.

    Args:
        model: object exposing ``predict(batch)`` that returns one activation
            row per image.
        images1: first batch, passed to ``model.predict`` unchanged.
        images2: second batch, passed to ``model.predict`` unchanged.

    Returns:
        The tuple ``(mu1, mu2, sigma1, sigma2, fid)``: the per-feature means and
        covariance matrices of the two activation sets, followed by the score.
    """
    # Activations for each batch (rows are observations, hence rowvar=False below).
    features_a = model.predict(images1)
    features_b = model.predict(images2)

    # First and second moments of each activation set.
    mu1 = features_a.mean(axis=0)
    sigma1 = cov(features_a, rowvar=False)
    mu2 = features_b.mean(axis=0)
    sigma2 = cov(features_b, rowvar=False)

    # Squared Euclidean distance between the two mean vectors.
    ssdiff = numpy.sum((mu1 - mu2)**2.0)

    # Matrix square root of the covariance product; numerical error can leave
    # a complex result, in which case only the real part is kept.
    covmean = sqrtm(sigma1.dot(sigma2))
    if iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
    return mu1, mu2, sigma1, sigma2, fid

In [5]:
def _activation_statistics(model, images, new_shape):
    """Return the mean and unbiased covariance of ``model`` activations over ``images``.

    Images are resized, preprocessed and pushed through ``model.predict`` one at
    a time, so only a single resized image is held in memory at once. The
    moments are accumulated as running sums.

    Args:
        model: object exposing ``predict(batch)``; called with a batch of one image.
        images: array of images, iterated along its first axis.
        new_shape: target shape handed to ``resize``.

    Returns:
        ``(mu, sigma)`` where ``mu`` has shape ``(1, d)`` and ``sigma`` has shape
        ``(d, d)``, with ``d`` the width of the activation returned by the model.

    Raises:
        ValueError: if fewer than two images are supplied (the covariance is
            divided by ``n - 1``).
    """
    count = images.shape[0]
    if count < 2:
        raise ValueError("at least two images are required to estimate a covariance")

    mu = None
    sigma = None
    for image in images:
        image = resize(image, new_shape, 0)
        image = preprocess_input(image)
        act = model.predict(asarray([image]))

        if mu is None:
            # Size the accumulators from the first activation instead of
            # hard-coding the feature width.
            width = act.shape[-1]
            mu = numpy.zeros((1, width))
            sigma = numpy.zeros((width, width))

        mu += act
        sigma += numpy.outer(act, act)

    n = float(count)
    mu /= n
    # sum(a a^T) - n * mu mu^T, then the n-1 normalisation of a sample covariance.
    sigma -= n * numpy.outer(mu, mu)
    sigma /= (n - 1)
    return mu, sigma


def scale_and_calculate_FID(model, images1, images2, new_shape=(299,299,3)):
    """Resize two image sets and compute the Frechet inception distance between them.

    Both sets are cast to float32, then each image is resized to ``new_shape``,
    run through ``preprocess_input`` and fed to ``model`` individually.
    NOTE(review): the cast presumably keeps pixel values on their original scale
    through ``resize`` before ``preprocess_input`` rescales them -- confirm
    against the scikit-image and Keras documentation.

    Args:
        model: feature extractor exposing ``predict``.
        images1: array of images (first axis indexes images).
        images2: array of images (first axis indexes images).
        new_shape: shape each image is resized to before prediction.

    Returns:
        The FID score as a scalar.

    Raises:
        ValueError: if either set contains fewer than two images.
    """
    images1 = images1.astype('float32')
    images2 = images2.astype('float32')

    mu1, sigma1 = _activation_statistics(model, images1, new_shape)
    mu2, sigma2 = _activation_statistics(model, images2, new_shape)

    # calculate sum squared difference between means
    ssdiff = numpy.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)

    return fid

In [6]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler

import random


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from keras.optimizers import Adam#, SGD
import sys
sys.path.append("../")

from utils_CIFAR import *       # utils file has the filler code and helper functions

from tqdm.notebook import tqdm
from functools import partial

In [7]:
# NOTE(review): presumably the length of the generator's noise input; it is not
# referenced by name in the visible cells -- confirm against utils_CIFAR.
NOISE_SIZE = 100
# Image shape passed to the generator/discriminator constructors and to the
# plotting and FID-sampling helpers in the cells below.
IMAGE_SHAPE = (32,32,3)

In [8]:
#set filter = False to include the entire CIFAR-10 dataset
# NOTE(review): `load_data` is also imported from keras.datasets.mnist at the top of
# this notebook; because this call passes a `filter` keyword it presumably resolves to
# a helper brought in by `from utils_CIFAR import *` -- confirm.
# Only the first of the four returned values is kept.
X, _, _, _ = load_data(filter=False)

In [9]:
# One Adam optimizer instance, handed to create_generator, create_discriminator and
# create_gan in the next cell.
# NOTE(review): the bare name `keras` is not imported in the visible cells (only `Adam`
# is); it presumably arrives through `from utils_CIFAR import *` -- confirm.
adam_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)

In [10]:
# Pre-bind the training data X (and k=1 for the discriminator helper) so these
# callables can be used without passing those arguments again.
# NOTE(review): how Players invokes them is defined in utils_CIFAR -- confirm there.
take_discriminator_steps_2 = partial(take_discriminator_steps, X_train=X, k=1)
getLoss2 = partial(getLoss, X_train=X)
create_gan2 = partial(create_gan, opt=adam_optimizer)

def create_GAN_player():
    """Build and return a new ``Players`` object for ``IMAGE_SHAPE`` images.

    A generator and a discriminator are created first (both receive the shared
    ``adam_optimizer``) and are then handed to ``Players`` together with the GAN
    factory, the two step functions, ``change_network`` (passed twice) and
    ``perturb_generator``, in that positional order.
    """
    generator = create_generator(IMAGE_SHAPE, opt=adam_optimizer)
    discriminator = create_discriminator(INPUT_SHAPE=IMAGE_SHAPE, opt=adam_optimizer)
    return Players(
        generator,
        discriminator,
        create_gan2,
        take_generator_steps,
        take_discriminator_steps_2,
        change_network,
        change_network,
        perturb_generator,
    )


In [11]:
def training_gd(create_player_function, create_player_function2, T=50010, rate=2):
    """Train a GAN with simultaneous gradient steps and a periodic accept/reject test.

    On every iteration the generator takes one update and the discriminator
    takes ``k = 1`` updates. On iterations where ``j % rate != 0`` the weights
    are snapshotted beforehand and the step is kept only if the loss did not
    increase; otherwise both networks are restored from the snapshot. With
    ``rate = 1`` the test never runs (plain GDA); with ``rate = 2`` it runs on
    every odd iteration.

    Side effects: prints progress, saves generated-image grids, the loss history
    and the FID history under ``results_supplementary``, and draws the loss curve
    on the current matplotlib figure. Reads the notebook-level names ``getLoss2``,
    ``IMAGE_SHAPE``, ``train_images``, ``model``, ``plot_generated_images``,
    ``generate_fake_FID_image_input`` and ``scale_and_calculate_FID``.

    Args:
        create_player_function: zero-argument factory for the player being trained.
        create_player_function2: zero-argument factory for a second player that is
            used only to hold the snapshot of the last weights.
        T: number of iterations.
        rate: the accept/reject test is skipped on iterations divisible by ``rate``.

    Returns:
        The trained player object.

    Raises:
        ValueError: if ``rate`` is smaller than 1.
    """
    import os  # local import: only needed to make sure the output folder exists

    if rate < 1:
        raise ValueError("rate must be a positive integer")

    # this will create a Players object, with two players
    player = create_player_function()
    # second Players object, used as storage for the pre-step weights
    player2 = create_player_function2()

    Loss = []
    FID_scores = []

    # Output locations are fixed for the whole run, so define them once up
    # front instead of inside one of the conditional branches of the loop.
    folder_name = 'results_supplementary'
    filename = '/results'
    os.makedirs(folder_name, exist_ok=True)

    # Always assigned before it is compared: iteration 0 is never a test
    # iteration (0 % rate == 0) and every later iteration refreshes it.
    loss_old = None
    player.update_y()

    for j in tqdm(range(T)):

        test_this_step = j % rate != 0

        if test_this_step:
            print("\nIteration ", j)

            # snapshot both networks so a rejected step can be undone
            player2.change_x(player.get_x())
            player2.change_y(player.get_y())

        if j > 0:
            loss_old = player.value(getLoss2)
            print("Old Loss: ", loss_old)
            Loss.append(loss_old)

        # perform one gradient update for the generator and k gradient updates for the
        # discriminator (we only use "k=1" discriminator gradient steps for CIFAR)
        player.update_x()
        k = 1
        for _ in range(k):
            player.update_y()

        # Accept/reject step
        if test_this_step:
            loss_new = player.value(getLoss2)

            if loss_new > loss_old:
                print("Reject")
                player.change_x(player2.get_x())
                player.change_y(player2.get_y())
            else:
                print("Accept")

        # save a grid of generated images: every 100 iterations early on, then every 1000
        if (j % 100 == 0 and j < 3001) or j % 1000 == 0:
            loss = player.value(getLoss2)
            print("Ending Loss:", loss)
            plot_generated_images(j,
                                  player.get_x(),
                                  folder=folder_name,
                                  save=True,
                                  image_shape=IMAGE_SHAPE,
                                  name=filename + ' %d.png')

        if j > 0 and j % 10 == 0:
            # Clear before redrawing; otherwise every call stacks another,
            # longer copy of the curve on the same figure.
            plt.clf()
            plt.plot(Loss)
            np.save(folder_name + filename + '_loss_values', Loss)

        # compute FID scores
        if j > 0 and j % 2500 == 0:
            FID_sample_size = 10000

            # Sample (with replacement) from the whole training set; randint's
            # upper bound is exclusive, so use the array length itself.
            images1_a = train_images[np.random.randint(train_images.shape[0], size=FID_sample_size)]

            FakeImages = generate_fake_FID_image_input(generator=player.get_x(), examples=FID_sample_size, image_shape=IMAGE_SHAPE)
            # NOTE(review): assumes the helper returns pixel values in [0, 1] -- confirm.
            images2_a = 255 * FakeImages

            start = time.time()
            fid_a = scale_and_calculate_FID(model, images1_a, images2_a)
            end = time.time()
            print(end - start)
            FID_scores.append(fid_a)
            print('FID_scores')
            print(FID_scores)
            np.save(folder_name + filename + '_FID_scores', FID_scores)

    return player

In [None]:
# %lprun -f training_gd 
# Start training. The same factory is passed twice, so the player being trained and
# the second player created inside training_gd are built the same way.
training_gd(create_GAN_player, create_GAN_player)



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.




HBox(children=(FloatProgress(value=0.0, max=50010.0), HTML(value='')))

Ending Loss: -1.2932832

Iteration  1
Old Loss:  -1.2905293
Reject
Old Loss:  -1.2895181

Iteration  3
Old Loss:  -1.2325437
Reject
Old Loss:  -1.2278309

Iteration  5
Old Loss:  -1.1441116
Reject
Old Loss:  -1.1467028

Iteration  7
Old Loss:  -1.0965085
Accept
Old Loss:  -1.1214663

Iteration  9
Old Loss:  -1.0599692
Reject
Old Loss:  -1.0832694

Iteration  11
Old Loss:  -0.9607671
Reject
Old Loss:  -0.93141186

Iteration  13
Old Loss:  -0.79709804
Reject
Old Loss:  -0.7663938

Iteration  15
Old Loss:  -0.77199113
Reject
Old Loss:  -0.7431572

Iteration  17
Old Loss:  -0.74580073
Accept
Old Loss:  -0.76233447

Iteration  19
Old Loss:  -0.78657377
Accept
Old Loss:  -0.81262165

Iteration  21
Old Loss:  -0.9446621
Accept
Old Loss:  -1.0516528

Iteration  23
Old Loss:  -0.99243045
Reject
Old Loss:  -0.95154834

Iteration  25
Old Loss:  -0.81944823
Reject
Old Loss:  -0.8163549

Iteration  27
Old Loss:  -0.70802903
Reject
Old Loss:  -0.7330277

Iteration  29
Old Loss:  -0.6686849
Accept
Ol