//Copyright (c) Microsoft Corporation. All rights reserved. 
//Licensed under the MIT License.

In [None]:
import numpy as np 
import pandas as pd 
import sklearn 
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import warnings
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, MaxPooling2D
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Conv2D, Conv2DTranspose, Reshape, Flatten
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.datasets import mnist,cifar10
import matplotlib.pyplot as plt
import sys
import pickle
from tqdm import tqdm
from keras import initializers
import matplotlib.pyplot as plt
import PrivacyGAN as pg 
from keras import backend as K
warnings.filterwarnings("ignore")

### Load dataset

In [None]:
# Load CIFAR-10 and pool the official train and test images into one array
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_all = np.concatenate((X_train, X_test))

# If the last axis does not hold the 3 colour channels, move axis 1 to the end
if X_all.shape[3] != 3:
    X_all = np.moveaxis(X_all, 1, 3)

# Convert to float32 and rescale pixel values from [0, 255] to [-1, 1]
X_all = np.clip((np.float32(X_all) / 255 - 0.5) * 2, -1, 1)

# Random split: a fraction `frac` of the pooled images goes to X (passed to
# the GANs below); the remaining images go to X_comp
frac = 0.1
n = int(frac * len(X_all))
shuffled_idx = np.random.choice(np.arange(len(X_all)), len(X_all), replace=False)
X, X_comp = X_all[shuffled_idx[:n]], X_all[shuffled_idx[n:]]

for split in (X, X_comp):
    print(split.shape)

### Run simple GAN 

In [None]:
# Specify models
# One generator/discriminator pair for the plain GAN
generator, discriminator = pg.CIFAR_Generator(), pg.CIFAR_Discriminator()

# Two generator/discriminator pairs plus the private discriminator
# (OutSize = 2, matching the number of pairs)
generators = [pg.CIFAR_Generator() for _ in range(2)]
discriminators = [pg.CIFAR_Discriminator() for _ in range(2)]
pDisc = pg.CIFAR_DiscriminatorPrivate(OutSize=2)

In [None]:
# Train the plain (non-private) GAN on X; the call returns the trained
# models together with the discriminator and generator loss histories
generator, discriminator, dLosses, gLosses = pg.SimpGAN(
    X,
    generator=generator,
    discriminator=discriminator,
    epochs=500,
    batchSize=256,
)

In [None]:
# Perform the white-box attack on the plain GAN's discriminator, using
# X and X_comp as the two sample sets
Acc = pg.WBattack(
    X,
    X_comp,
    discriminator,
)

In [None]:
# Plot the distribution of discriminator scores on the training set (X)
# and on the held-out set (X_comp)
train_scores = discriminator.predict(X)[:, 0]
test_scores = discriminator.predict(X_comp)[:, 0]

# `normed` was deprecated in Matplotlib 2.1 and removed in 3.1;
# `density=True` is the equivalent replacement (each histogram integrates to 1)
plt.hist(train_scores, color='r', alpha=0.5, label='train', density=True, bins=50)
plt.hist(test_scores, color='b', alpha=0.5, label='test', density=True, bins=50)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('GAN')
plt.legend()

In [None]:
# Sample images from the trained generator and show the first 25 in a 5x5 grid
noise = np.random.normal(0, 1, size=[X.shape[0], 100])
generatedImages = generator.predict(noise)
temp = generatedImages[:25].reshape(25, 32, 32, 3)

# imshow expects float RGB data in [0, 1] and clips anything outside it.
# The training images were scaled to [-1, 1], so map the samples back to
# [0, 1] before display.
# NOTE(review): assumes the generator output range matches the training data
# range [-1, 1] - confirm against pg.CIFAR_Generator's output activation.
temp = np.clip((temp + 1) / 2, 0, 1)

plt.figure(figsize=(5, 5))
for i, img in enumerate(temp):
    plt.subplot(5, 5, i + 1)
    # No cmap: colormaps are ignored for RGB input
    plt.imshow(img, interpolation='nearest')
    plt.axis('off')
plt.tight_layout()

### Run Private GAN 

In [None]:
# Clear the Keras backend session before building the privGAN models
K.clear_session()


def _new_adam():
    """Return a fresh Adam optimizer with lr=0.0002 and beta_1=0.5."""
    return Adam(lr=0.0002, beta_1=0.5)


# Every model gets its own optimizer instance; `optim` is the one handed
# to pg.privGAN itself
optim = _new_adam()
generator = pg.CIFAR_Generator(optim=_new_adam())
discriminator = pg.CIFAR_Discriminator(optim=_new_adam())
generators = [pg.CIFAR_Generator(optim=_new_adam()) for _ in range(2)]
discriminators = [pg.CIFAR_Discriminator(optim=_new_adam()) for _ in range(2)]
pDisc = pg.CIFAR_DiscriminatorPrivate(OutSize=2, optim=_new_adam())

# Train privGAN on X; the third returned value is not used in this notebook
generators, discriminators, _, dLosses, dpLosses, gLosses = pg.privGAN(
    X,
    epochs=500,
    disc_epochs=50,
    generators=generators,
    discriminators=discriminators,
    pDisc=pDisc,
    optim=optim,
    privacy_ratio=1.0,
    batchSize=256,
)

In [None]:
# Perform the white-box attack on the privGAN discriminators, using
# X and X_comp as the two sample sets
pg.WBattack_priv(
    X,
    X_comp,
    discriminators,
)

In [None]:
# Sample images from the first privGAN generator and show the first 25
# in a 5x5 grid
noise = np.random.normal(0, 1, size=[X.shape[0], 100])
generatedImages = generators[0].predict(noise)
temp = generatedImages[:25].reshape(25, 32, 32, 3)

# imshow expects float RGB data in [0, 1] and clips anything outside it.
# The training images were scaled to [-1, 1], so map the samples back to
# [0, 1] before display.
# NOTE(review): assumes the generator output range matches the training data
# range [-1, 1] - confirm against pg.CIFAR_Generator's output activation.
temp = np.clip((temp + 1) / 2, 0, 1)

plt.figure(figsize=(5, 5))
for i, img in enumerate(temp):
    plt.subplot(5, 5, i + 1)
    # No cmap: colormaps are ignored for RGB input
    plt.imshow(img, interpolation='nearest')
    plt.axis('off')
plt.tight_layout()

In [None]:
# Plot the distribution of the first privGAN discriminator's scores on the
# training set (X) and on the held-out set (X_comp)
train_scores = discriminators[0].predict(X)[:, 0]
test_scores = discriminators[0].predict(X_comp)[:, 0]

# `normed` was deprecated in Matplotlib 2.1 and removed in 3.1;
# `density=True` is the equivalent replacement (each histogram integrates to 1)
plt.hist(train_scores, color='r', alpha=0.5, label='train', density=True, bins=50)
plt.hist(test_scores, color='b', alpha=0.5, label='test', density=True, bins=50)
plt.xlabel('Discriminator probability')
plt.ylabel('Normed frequency')
plt.title('privGAN (1.0)')
plt.legend()

In [None]:
# Run the white-box attack on the privGAN discriminators again
pg.WBattack_priv(
    X,
    X_comp,
    discriminators,
)