//Copyright (c) Microsoft Corporation. All rights reserved. 
//Licensed under the MIT License.

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import keras 
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist,cifar10
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers
from scipy import stats
import warnings
import PrivacyGAN as pg 
from keras.models import load_model
from datetime import datetime
from sklearn.datasets import fetch_lfw_people


# Silence all library warnings to keep notebook output readable. Note that
# this also hides deprecation warnings coming from Keras / matplotlib.
warnings.filterwarnings("ignore")
# Use channels-first image layout (what the legacy Keras API called 'th').
# set_image_data_format is available throughout Keras 2.x, whereas the older
# set_image_dim_ordering was deprecated and later removed.
K.set_image_data_format('channels_first')

### Load the LFW dataset and split it into training and held-out sets

In [None]:
# Load Labeled Faces in the Wild via scikit-learn (downloaded and cached on
# first use). `data` holds one flattened face image per row.
lfw_people = fetch_lfw_people()
X_all = lfw_people['data'].astype(np.float32)

# Rescale pixels to [-1, 1]. Depending on the scikit-learn version, LFW pixel
# values may come back in [0, 255] or already in [0, 1]
# (NOTE(review): confirm for the installed version). A fixed 127.5 offset
# would squash [0, 1] data into roughly [-1, -0.99], so derive the half-range
# from the data instead.
half_range = 127.5 if X_all.max() > 1.0 else 0.5
X_all = (X_all - half_range) / half_range

# Random train/held-out split: only `frac` of the images are used to train
# the GANs; the remaining images (X_comp) are the comparison set passed to
# the white-box attacks. No seed is set, so the split differs between runs.
frac = 0.1
n = int(frac * len(X_all))
perm = np.random.permutation(len(X_all))
X = X_all[perm[:n]]
X_comp = X_all[perm[n:]]

print('training set size:', X.shape)
print('test set size:', X_comp.shape)

### Simple (non-private) GAN baseline

In [None]:
# Build a fresh generator/discriminator pair and train them as an ordinary
# (non-private) GAN on the training split X. SimpGAN returns the trained
# models together with the discriminator and generator loss histories.
generator = pg.LFW_Generator()
discriminator = pg.LFW_Discriminator()
generator, discriminator, dLosses, gLosses = pg.SimpGAN(
    X,
    epochs=500,
    batchSize=256,
    discriminator=discriminator,
    generator=generator,
)

In [None]:
# Perform the white-box attack against the simple GAN: the trained
# discriminator is queried on the training split X and on the held-out
# split X_comp. The result is stored in `Acc` (presumably the attack
# accuracy -- confirm against PrivacyGAN.WBattack).
Acc = pg.WBattack(X,X_comp, discriminator)

In [None]:
# Plot the distribution of discriminator scores on the training split X (red)
# and the held-out split X_comp (blue). Heavily overlapping histograms mean
# the discriminator scores training and held-out images alike.
# `density=True` replaces `normed=1`: matplotlib deprecated `normed` in 2.1
# and removed it in 3.1.
plt.hist(discriminator.predict(X)[:, 0], color='r', alpha=0.5,
         label='train', density=True, bins=50)
plt.hist(discriminator.predict(X_comp)[:, 0], color='b', alpha=0.5,
         label='test', density=True, bins=50)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('GAN')
plt.legend()

In [None]:
# Generate synthetic faces from the simple GAN: draw one 100-dimensional
# standard-normal latent vector per training image, then show the first 25
# outputs in a 5x5 grid. 62x47 is the LFW image size for fetch_lfw_people's
# default settings.
noise = np.random.normal(0, 1, size=[X.shape[0], 100])
generatedImages = generator.predict(noise)
temp = generatedImages[:25].reshape(25, 62, 47)
plt.figure(figsize=(5, 5))
for i, img in enumerate(temp):
    plt.subplot(5, 5, i + 1)
    # 'gray' (not the MNIST-style 'gray_r') so bright pixels render bright;
    # the reversed colormap displays the faces as photographic negatives.
    plt.imshow(img, interpolation='nearest', cmap='gray')
    plt.axis('off')
plt.tight_layout()

### Private GAN (privGAN)

In [None]:
# Reset Keras' global state before building the privGAN models.
# NOTE(review): after this, the simple-GAN models above presumably can no
# longer be used for prediction -- confirm for the installed backend.
K.clear_session()

# Each model receives its own Adam instance with identical hyperparameters,
# so no optimizer state is shared between models. Construction order
# (shared optim, generators, discriminators, private discriminator) matches
# the original cell.
optim = Adam(lr=0.0002, beta_1=0.5)
generators = [pg.LFW_Generator(optim=Adam(lr=0.0002, beta_1=0.5))
              for _ in range(2)]
discriminators = [pg.LFW_Discriminator(optim=Adam(lr=0.0002, beta_1=0.5))
                  for _ in range(2)]
pDisc = pg.LFW_DiscriminatorPrivate(
    OutSize=2,
    optim=Adam(lr=0.0002, beta_1=0.5),
)

generators, discriminators, _, dLosses, dpLosses, gLosses = pg.privGAN(
    X,
    epochs=500,
    disc_epochs=50,
    batchSize=256,
    generators=generators,
    discriminators=discriminators,
    pDisc=pDisc,
    optim=optim,
    privacy_ratio=1.0,
)

In [None]:
# Perform the white-box attack against privGAN, passing the list of trained
# privGAN discriminators together with the training split X and the
# held-out split X_comp. The return value is not stored; as the last
# expression in the cell it is displayed by the notebook.
pg.WBattack_priv(X,X_comp, discriminators)

In [None]:
# Generate synthetic faces from privGAN's first generator: draw one
# 100-dimensional standard-normal latent vector per training image, then
# show the first 25 outputs in a 5x5 grid. 62x47 is the LFW image size for
# fetch_lfw_people's default settings.
noise = np.random.normal(0, 1, size=[X.shape[0], 100])
generatedImages = generators[0].predict(noise)
temp = generatedImages[:25].reshape(25, 62, 47)
plt.figure(figsize=(5, 5))
for i, img in enumerate(temp):
    plt.subplot(5, 5, i + 1)
    # 'gray' (not the MNIST-style 'gray_r') so bright pixels render bright;
    # the reversed colormap displays the faces as photographic negatives.
    plt.imshow(img, interpolation='nearest', cmap='gray')
    plt.axis('off')
plt.tight_layout()

In [None]:
# Plot the distribution of privGAN's first discriminator's scores on the
# training split X (red) and the held-out split X_comp (blue). Heavily
# overlapping histograms mean the discriminator scores training and held-out
# images alike.
# `density=True` replaces `normed=1`: matplotlib deprecated `normed` in 2.1
# and removed it in 3.1.
plt.hist(discriminators[0].predict(X)[:, 0], color='r', alpha=0.5,
         label='train', density=True, bins=50)
plt.hist(discriminators[0].predict(X_comp)[:, 0], color='b', alpha=0.5,
         label='test', density=True, bins=50)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('privGAN')
plt.legend()