# Example of a Generative Adversarial Network


This notebook shows how to build and train a generative adversarial network (GAN) that generates 2D images.

The images are 2D projections of 3D calorimeter images.

In [None]:
import keras.backend as K

import h5py 
import numpy as np

# Options for GPU running: allocate GPU memory on demand instead of reserving it all up front.
import tensorflow as tf
# Set log_device_placement=True to debug which device each op runs on
# (it prints one line per op and floods the notebook output).
session_config = tf.ConfigProto(log_device_placement=False)
session_config.gpu_options.allow_growth = True
session = tf.Session(config=session_config)
K.set_session(session)

# All models below use channels-first images: (channels, rows, cols).
# Keras 2 API; replaces the legacy K.set_image_dim_ordering('th').
K.set_image_data_format('channels_first')

from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adadelta, Adam, RMSprop
from keras.utils.generic_utils import Progbar
from sklearn.model_selection import train_test_split
 

### Define the models

In [None]:
from keras.layers import (Input, Dense, Reshape, Flatten, Lambda, merge,
                          Dropout, BatchNormalization, Activation, Embedding)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import (UpSampling2D, Conv2D, ZeroPadding2D,
                                        AveragePooling2D)

from keras.models import Model, Sequential


In [None]:
# Training configuration.
nb_epochs = 20
batch_size = 1000
latent_size = 1024  # size of the generator's latent input; matches the generator cell
nevt = 50000        # number of events read from the input file

# Must be a real bool: a string such as 'false' is truthy and would always select the
# first branch of `if verbose:` in the training loop.
# True  -> Keras progress bar per epoch; False -> one printed line per batch.
verbose = False
    

### Discriminator model

Create the discriminator model using the Keras functional API.

In [None]:
# Discriminator: maps a channels-first (1, 25, 25) image to a single sigmoid score.
# The training loop uses label 1 for real images and 0 for generated ones.
input_image = Input(shape=(1, 25, 25))

x = Conv2D(32, (5, 5), data_format='channels_first', padding='same')(input_image)
x = LeakyReLU()(x)

# Three zero-padded 5x5 'valid' convolution blocks with 8 filters each.
# BatchNormalization(axis=1) normalizes over the channel axis of channels-first tensors;
# the default axis=-1 would normalize over the image columns instead.
for pad in ((2, 2), (2, 2), (1, 1)):
    x = ZeroPadding2D(pad, data_format='channels_first')(x)
    x = Conv2D(8, (5, 5), data_format='channels_first', padding='valid')(x)
    x = LeakyReLU()(x)
    x = BatchNormalization(axis=1)(x)

x = AveragePooling2D((2, 2), data_format='channels_first')(x)
h = Flatten()(x)

fake = Dense(1, activation='sigmoid')(h)

# TODO: auxiliary outputs (e.g. an energy-regression head) could be attached to `h` here.
# Keras 2 keyword names; input=/output= are the deprecated Keras 1 spelling.
discriminator = Model(inputs=input_image, outputs=[fake])
discriminator.summary()


### Create generator model

In [None]:
# Generator: maps a latent vector of size latent_size to a channels-first (1, 25, 25) image.
# Comments give the output shape (channels, rows, cols) of each stage.
gen = Sequential()

latent_size = 1024
gen.add(Dense(64 * 7, input_dim=latent_size))
gen.add(Reshape((8, 7, 8)))                                          # -> (8, 7, 8)

gen.add(Conv2D(64, (6, 8), data_format='channels_first', padding='same',
               kernel_initializer='he_uniform'))                     # -> (64, 7, 8)
gen.add(LeakyReLU())
# axis=1 is the channel axis for channels-first tensors (default -1 would be the columns).
gen.add(BatchNormalization(axis=1))

gen.add(UpSampling2D(size=(2, 2), data_format='channels_first'))    # -> (64, 14, 16)
gen.add(ZeroPadding2D((2, 0), data_format='channels_first'))        # -> (64, 18, 16)

gen.add(Conv2D(6, (5, 8), data_format='channels_first',
               kernel_initializer='he_uniform'))                     # -> (6, 14, 9)
gen.add(LeakyReLU())
gen.add(BatchNormalization(axis=1))

gen.add(UpSampling2D(size=(2, 3), data_format='channels_first'))    # -> (6, 28, 27)
gen.add(ZeroPadding2D((0, 3), data_format='channels_first'))        # -> (6, 28, 33)

gen.add(Conv2D(6, (3, 8), data_format='channels_first',
               kernel_initializer='he_uniform'))                     # -> (6, 26, 26)
gen.add(LeakyReLU())

gen.add(Conv2D(1, (2, 2), data_format='channels_first', use_bias=False,
               kernel_initializer='glorot_normal'))                  # -> (1, 25, 25)
# ReLU keeps generated pixel values non-negative, like the cleaned real images.
gen.add(Activation('relu'))

generator = gen
generator.summary()

### Compile models


In [None]:

print('[INFO] Building discriminator')
# discriminator.load_weights('veganweights/params_discriminator_epoch_019.hdf5')
# The combined-model cell sets discriminator.trainable = False, and Keras only applies the
# trainable flag at compile time. Reset it so re-running this cell after that one still
# compiles a trainable discriminator.
discriminator.trainable = True
discriminator.compile(optimizer=RMSprop(), loss='binary_crossentropy')

discriminator.summary()

# Build the generator. It is only used via predict() and inside the combined model.
print('[INFO] Building generator')
# generator.load_weights('veganweights/params_generator_epoch_019.hdf5')
generator.compile(optimizer=RMSprop(), loss='binary_crossentropy')


### Make combined model

In [None]:
# Combined model: latent vector -> generator -> discriminator.
# The discriminator is frozen before this model is compiled, so training `combined`
# updates only the generator weights.
discriminator.trainable = False

latent = Input(shape=(latent_size, ), name='combined_z')
fake_image = generator(latent)
fake_disc_out = discriminator(fake_image)

combined = Model(inputs=[latent], outputs=[fake_disc_out], name='combined_model')
combined.compile(optimizer=RMSprop(), loss='binary_crossentropy')
combined.summary()

### Get input data

In [None]:
# Full data set (~20 GB) lives at /home/moneta/data/Ele_v1_1_2.h5; this notebook uses
# the smaller 2D file. The file stays open because the datasets are sliced in the next cell.
d = h5py.File("Ele_GAN_2D.h5", 'r')

e = d.get('target')   # NOTE(review): per-event target; exact shape not checked here — confirm
xd = d.get('ECAL')    # images; later cells index them as (event, 0, row, col)
print(xd.shape, e.shape)

image_shape = xd.shape
print('Number of events in file', image_shape[0])
print('image size is :', image_shape[2], ' x ', image_shape[3])

In [None]:
# Read the first nevt events from the HDF5 datasets into in-memory NumPy arrays.
X = np.array(xd[:nevt, :, :, :])
y = np.array(e[:nevt])

print('*** Input data shapes ***')
for loaded in (X, y):
    print(loaded.shape)
 

##### Preprocess the data and define variables before training

In [None]:
# Remove unphysical values: cell contents below 1e-6 are set to zero (in place on X;
# the operation is idempotent).
X[X < 1e-6] = 0

# Split into 80% train / 20% test. A fixed random_state makes the split, and therefore
# every downstream loss and histogram, reproducible across kernel restarts.
split_seed = 42
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=split_seed)

# NOTE(review): targets are rescaled by 1/100, presumably a unit change — confirm.
y_train = y_train / 100
y_test = y_test / 100

nb_train, nb_test = X_train.shape[0], X_test.shape[0]

# Convert to float32, Keras' default float type.
X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)

# Per-event sum over the two image axes (axes 2 and 3).
ecal_train = np.sum(X_train, axis=(2, 3))
ecal_test = np.sum(X_test, axis=(2, 3))

for arr in (X_train, X_test, ecal_train, ecal_test):
    print(arr.shape)

print('*************************************************************************************')


### Book Histograms

In [None]:
import ROOT

# Per-epoch loss histograms: bin k holds the loss measured at the end of epoch k.
hl_disc_test = ROOT.TH1D("hldtest","Loss Discriminator test",nb_epochs,0,nb_epochs)
hl_disc_train = ROOT.TH1D("hldtrain","Loss Discriminator training",nb_epochs,0,nb_epochs)
hl_gen_test = ROOT.TH1D("hlgtest","Loss Generator test",nb_epochs,0,nb_epochs)
hl_gen_train = ROOT.TH1D("hlgtrain","Loss Generator train",nb_epochs,0,nb_epochs)

# Reference histograms of the real test images, filled once here: summed 2D image,
# X and Y projections, and the per-event total.
htrue = ROOT.TH2D("htrue","true image",25,0,25,25,0,25)
htrueX = ROOT.TH1D("htruex","true image in X",25,0,25)
htrueY = ROOT.TH1D("htruey","true image in Y",25,0,25)
htrueE = ROOT.TH1D("htrueE","true tot energy",100,0,10)

nevents = X_test.shape[0]
for index in range(nevents):
    image = X_test[index, 0]
    for i in range(25):
        for j in range(25):
            value = image[i, j]
            htrue.Fill(i + .5, j + .5, value)
            htrueX.Fill(i + .5, value)
            htrueY.Fill(j + .5, value)
    # Computed with numpy instead of a running `sum` variable, which shadowed the builtin.
    htrueE.Fill(float(image.sum()))
        



In [None]:
# Draw the reference (true) distributions: 2D image, X projection, Y projection.
cname = "canvas_image"
c1 = ROOT.TCanvas(cname,cname,1500,500)
c1.Divide(3,1)
c1.cd(1)
# Use the largest bin content rather than GetMaximum(): after SetMaximum() has been
# called, GetMaximum() returns that value, so re-running would keep inflating the range.
htrue.SetMaximum(1.2 * htrue.GetBinContent(htrue.GetMaximumBin()))
htrue.Draw("COLZ")
c1.cd(2)
htrueX.Draw('HIST')
c1.cd(3)
htrueY.Draw('HIST')
c1.Draw()

# Canvas for the four loss histograms, drawn after training.
closs = ROOT.TCanvas("closs","closs",1200,1200)
closs.Divide(2,2)



## GAN Training

In [None]:

def sample_generator_input(n_samples):
    """Return an (n_samples, latent_size) array to feed the generator.

    Standard-normal noise is scaled row by row by a factor drawn uniformly from [1, 5).
    """
    noise = np.random.normal(0, 1, (n_samples, latent_size))
    sampled_energies = np.random.uniform(1, 5, (n_samples, 1))
    return np.multiply(sampled_energies, noise)


nb_batches = int(X_train.shape[0] / batch_size)

for epoch in range(nb_epochs):
    print('Epoch ',epoch + 1,' of ', nb_epochs)

    if verbose:
        progress_bar = Progbar(target=nb_batches)

    epoch_gen_loss = []
    epoch_disc_loss = []
    for index in range(nb_batches):
        image_batch = X_train[index * batch_size:(index + 1) * batch_size]

        # Discriminator step: real batch labelled 1, generated batch labelled 0.
        generated_images = generator.predict(sample_generator_input(batch_size), verbose=0)
        real_batch_loss = discriminator.train_on_batch(image_batch, np.ones(batch_size))
        fake_batch_loss = discriminator.train_on_batch(generated_images, np.zeros(batch_size))
        epoch_disc_loss.append((real_batch_loss + fake_batch_loss) / 2)

        # Generator step: train through the combined model (discriminator frozen) with
        # label 1, so the generator is pushed towards images the discriminator calls real.
        generator_loss = combined.train_on_batch(sample_generator_input(batch_size), np.ones(batch_size))
        epoch_gen_loss.append(generator_loss)

        # Report progress after the batch has actually been processed.
        if verbose:
            progress_bar.update(index + 1)
        else:
            print('processed {}/{} batches'.format(index + 1, nb_batches))

    #### TESTING
    print('\nTesting for epoch :' , epoch + 1)

    # Discriminator test loss: real test images plus the same number of generated images.
    generated_images = generator.predict(sample_generator_input(nb_test), verbose=False)
    # Dedicated names so the loaded data set (X, y) is not overwritten by the evaluation set.
    X_disc_eval = np.concatenate((X_test, generated_images))
    y_disc_eval = np.array([1] * nb_test + [0] * nb_test)
    discriminator_test_loss = discriminator.evaluate(X_disc_eval, y_disc_eval, verbose=False, batch_size=batch_size)
    discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)

    # Generator test loss: combined-model loss on fresh samples with target label 1.
    trick = np.ones(2 * nb_test)
    generator_test_loss = combined.evaluate(sample_generator_input(2 * nb_test), trick, verbose=False, batch_size=batch_size)
    generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)

    print('generator  train and test loss', generator_train_loss, generator_test_loss)
    print('discrimin. train and test loss', discriminator_train_loss, discriminator_test_loss)

    # Save losses in the per-epoch histograms.
    hl_disc_test.SetBinContent(epoch+1,discriminator_test_loss)
    hl_disc_train.SetBinContent(epoch+1,discriminator_train_loss)
    hl_gen_test.SetBinContent(epoch+1,generator_test_loss)
    hl_gen_train.SetBinContent(epoch+1,generator_train_loss)

    # Save weights every epoch.
    generator.save_weights('generator_weights_epoch_'+ str(epoch)+'.h5', overwrite=True)
    discriminator.save_weights('discriminator_weights_epoch_'+ str(epoch)+'.h5', overwrite=True)

    # Histograms of this epoch's generated test images (drawn in red against the true ones).
    hname = "hgen_" + str(epoch+1)
    htitle = "Generated image_" + str(epoch+1)
    hgen = ROOT.TH2D(hname,htitle,25,0,25,25,0,25)
    hgenX = ROOT.TH1D("hgenX"+str(epoch+1),htitle,25,0,25)
    hgenY = ROOT.TH1D("hgenY"+str(epoch+1),htitle,25,0,25)
    hgenE = ROOT.TH1D("hgenE"+str(epoch+1),htitle,100,0,10)
    hgenX.SetLineColor(ROOT.kRed)
    hgenY.SetLineColor(ROOT.kRed)
    hgenE.SetLineColor(ROOT.kRed)

    # htrue was filled once when the histograms were booked, so it is not refilled here.
    for index in range(generated_images.shape[0]):
        image = generated_images[index, 0]
        for i in range(25):
            for j in range(25):
                value = image[i, j]
                hgen.Fill(i + .5, j + .5, value)
                hgenX.Fill(i + .5, value)
                hgenY.Fill(j + .5, value)
        # numpy sum instead of a running `sum` variable, which shadowed the builtin.
        hgenE.Fill(float(image.sum()))
            


In [None]:
### Plot Loss
# One pad per loss histogram on the 2x2 loss canvas, in pad order 1..4.
loss_panels = (hl_disc_test, hl_disc_train, hl_gen_test, hl_gen_train)
for pad_number, loss_hist in enumerate(loss_panels, start=1):
    closs.cd(pad_number)
    loss_hist.Draw()
closs.Draw()
closs.Update()


In [None]:
# Plot images: compare the last epoch's generated distributions (red) with the true ones.
def _max_bin_content(hist):
    """Largest bin content of a ROOT histogram, ignoring any SetMaximum() override."""
    return hist.GetBinContent(hist.GetMaximumBin())


c1 = ROOT.TCanvas(cname,cname,1500,1000)
c1.Divide(3,2)

# Use bin contents rather than GetMaximum(): once SetMaximum() has been called,
# GetMaximum() returns that value, so re-running would keep compounding the factor.
true_image_max = _max_bin_content(htrue)
c1.cd(1)
htrue.SetMaximum(1.2 * true_image_max)
htrue.Draw("COLZ")
c1.cd(2)
hgen.SetMaximum(1.2 * true_image_max)  # same colour scale as the true image
hgen.DrawCopy("COLZ")

# 1D comparisons; the y range covers both curves so neither is clipped.
comparisons = ((hgenX, htrueX), (hgenY, htrueY), (hgenE, htrueE))
for pad_number, (h_gen, h_true) in enumerate(comparisons, start=3):
    c1.cd(pad_number)
    h_gen.SetMaximum(1.3 * max(_max_bin_content(h_gen), _max_bin_content(h_true)))
    h_gen.DrawCopy("HIST")
    h_true.DrawCopy("SAME HIST")
c1.Update()
c1.Draw()