//Copyright (c) Microsoft Corporation. All rights reserved. 
//Licensed under the MIT License.

In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from tensorflow.keras import Input
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Reshape, Dense, Dropout, Flatten, LeakyReLU, Conv2D, MaxPooling2D, ZeroPadding2D, Conv2DTranspose, UpSampling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import initializers
from tensorflow.keras import backend as K
from scipy import stats
import tensorflow as tf
import warnings
import PrivacyGAN as pg 
from datetime import datetime


# Silence every Python warning for the rest of the session (keeps notebook
# output compact), then report which TensorFlow version is running.
warnings.simplefilter("ignore")
print(tf.__version__)

### Load the Fashion-MNIST dataset and create the train/test split

In [None]:
# Load Fashion-MNIST (not digit MNIST: the loader is `fashion_mnist`) and
# rescale the 8-bit pixel values from [0, 255] to [-1, 1].
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_test = (X_test.astype(np.float32) - 127.5) / 127.5
# Flatten each 2-D image into one row of height*width pixel values.
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])
# Pool the official train and test splits; the split below is drawn from this pool.
X_all = np.concatenate((X_train, X_test))


# Generate the training/test split: a random `frac` of all images trains the
# GANs (X); the remaining images form the held-out comparison set (X_comp).
# NOTE(review): no RNG seed is set, so the split differs between runs; seed
# np.random beforehand if reproducibility is needed.
frac = 0.1
n = int(frac * len(X_all))
# A random permutation of all row indices (same draw as the former
# np.random.choice(..., replace=False) over range(len(X_all))).
perm = np.random.permutation(len(X_all))
X = X_all[perm[:n]]
X_comp = X_all[perm[n:]]

print('training set size:', X.shape)
print('test set size:', X_comp.shape)

### Simple GAN

In [None]:
# Train a standard (non-private) GAN on the training split X.
generator, discriminator, dLosses, gLosses = pg.SimpGAN(
    X,
    epochs=1,
    batchSize=256,
    verbose=50,
)

In [None]:
# Run the white-box attack against the trained discriminator, comparing the
# training split X with the held-out split X_comp; keep the returned value.
Acc = pg.WBattack(X, X_comp, discriminator)

In [None]:
# Plot the distribution of discriminator scores on the training split versus
# the held-out split. A visible gap between the two suggests the model
# treats training images differently.
train_scores = discriminator.predict(X)[:, 0]
test_scores = discriminator.predict(X_comp)[:, 0]
# Use one shared set of 50 bin edges. Independent `bins=50` gives each
# overlaid histogram its own edges, which distorts the side-by-side comparison.
bin_edges = np.histogram_bin_edges(np.concatenate((train_scores, test_scores)), bins=50)
plt.hist(train_scores, color='r', alpha=0.5, label='train', density=True, bins=bin_edges)
plt.hist(test_scores, color='b', alpha=0.5, label='test', density=True, bins=bin_edges)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('GAN')
plt.legend()

In [None]:
# Render 25 samples from the trained (non-private) generator in a 5x5-inch figure.
pg.DisplayImages(generator, NoImages=25, figSize=(5, 5))

### Private GAN 

In [None]:
# Reset Keras' global session state before building the privGAN models.
K.clear_session()


def _make_adam():
    """Return a fresh Adam optimizer (learning rate 2e-4, beta_1 0.5).

    Every model gets its own instance, exactly as before, so optimizer state is
    not shared between models. The rate is passed as `learning_rate=`: the old
    `lr=` alias is deprecated. Some Keras versions silently drop it and fall back
    to the default rate, and Keras 3 rejects it outright.
    """
    return Adam(learning_rate=0.0002, beta_1=0.5)


optim = _make_adam()
# Two generator/discriminator pairs; each pair is presumably trained on its own
# portion of X inside pg.privGAN — confirm against PrivacyGAN.
generators = [pg.MNIST_Generator(optim=_make_adam()) for _ in range(2)]
discriminators = [pg.MNIST_Discriminator(optim=_make_adam()) for _ in range(2)]
# Private discriminator; OutSize=2 presumably matches the number of pairs above.
pDisc = pg.MNIST_DiscriminatorPrivate(OutSize=2, optim=_make_adam())

(generators, discriminators, _, dLosses, dpLosses, gLosses) = pg.privGAN(
    X,
    epochs=1,
    disc_epochs=1,
    batchSize=256,
    generators=generators,
    discriminators=discriminators,
    pDisc=pDisc,
    optim=optim,
    privacy_ratio=1.0,
)

In [None]:
# Run the white-box attack against the privGAN discriminators, comparing the
# training split X with the held-out split X_comp. Bind the result, then show
# it as the cell's output value.
Acc_priv = pg.WBattack_priv(X, X_comp, discriminators)
Acc_priv

In [None]:
# Render 25 samples from the second privGAN generator in a 5x5-inch figure.
second_generator = generators[1]
pg.DisplayImages(second_generator, NoImages=25, figSize=(5, 5))

In [None]:
# Plot the distribution of the first privGAN discriminator's scores on the
# training split versus the held-out split.
# NOTE(review): only discriminators[0] is inspected; other discriminators may differ.
train_scores = discriminators[0].predict(X)[:, 0]
test_scores = discriminators[0].predict(X_comp)[:, 0]
# Use one shared set of 50 bin edges. Independent `bins=50` gives each
# overlaid histogram its own edges, which distorts the side-by-side comparison.
bin_edges = np.histogram_bin_edges(np.concatenate((train_scores, test_scores)), bins=50)
plt.hist(train_scores, color='r', alpha=0.5, label='train', density=True, bins=bin_edges)
plt.hist(test_scores, color='b', alpha=0.5, label='test', density=True, bins=bin_edges)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('privGAN')
plt.legend()