In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import pickle
import math



# Weight initializers referenced by the layer specifications (e_sizes / d_sizes).
_INIT_STDDEV = 0.02  # shared stddev of the plain and truncated normal initializers

# Uniform family.
uniform = tf.random_uniform_initializer()
glorot_uniform = tf.glorot_uniform_initializer()

# Normal family.
normal = tf.random_normal_initializer(stddev=_INIT_STDDEV)
trunc_normal = tf.truncated_normal_initializer(stddev=_INIT_STDDEV)
glorot_normal = tf.glorot_normal_initializer()

from NN_architectures import *

  return f(*args, **kwds)


In [2]:
# ---- Optimisation ---------------------------------------------------------
LEARNING_RATE = 0.01
BETA1 = 0.5  # forwarded to DCVAE as `beta1`

# ---- Training loop --------------------------------------------------------
BATCH_SIZE = 128
EPOCHS = 15
SAVE_SAMPLE_PERIOD = 5  # forwarded as `save_sample`; a generated sample is saved every N epochs

# ---- Run mode & output location -------------------------------------------
# 'TRAIN' fits and checkpoints the model; set to 'TEST' to restore the
# checkpoint and inspect generated samples instead.
task = 'TRAIN'
PATH = 'DCVAE_test'  # directory for the checkpoint and saved sample images

In [3]:
# Encoder / decoder layer specifications passed to DCVAE (from NN_architectures).
# NOTE(review): the tuple layouts are defined in NN_architectures, which is not
# visible here. They presumably read
#   conv:  (n_filters, kernel_size, stride, batch_norm, ?, activation, initializer)
#   dense: (n_units, batch_norm, ?, activation, initializer)
# — confirm there, in particular what the constant `1` field means.

e_sizes = {
    'conv_layers': [
        # NOTE(review): only 2 filters in the first encoder conv is unusually
        # few for a DCVAE — confirm this is intended.
        (2, 5, 2, False, 1, lrelu, trunc_normal),
        (64, 5, 2, True, 1, lrelu, trunc_normal),
    ],
    'dense_layers': [
        (1024, True, 1, lrelu, trunc_normal),
        (512, True, 1, tf.nn.relu, trunc_normal),
    ],
    'z': 100,  # presumably the latent-space dimension
}

d_sizes = {
    'projection': 128,
    'bn_after_project': False,
    'conv_layers': [
        (128, 5, 2, True, 1, tf.nn.relu, glorot_normal),
        # NOTE(review): the final conv uses relu while 'output_activation' is
        # sigmoid — confirm which one NN_architectures applies to the last layer.
        (1, 5, 2, False, 1, tf.nn.relu, glorot_normal),
    ],
    'dense_layers': [
        (512, True, 1, tf.nn.relu, glorot_normal),
        (1024, True, 1, tf.nn.relu, glorot_normal),
    ],
    'output_activation': tf.sigmoid,
}

In [4]:
def _load_binarized_mnist():
    """Load the MNIST train/test images and binarize them.

    Returns
    -------
    X_train, X_test : np.ndarray
        float32 arrays reshaped to (n_samples, 28, 28, 1) with values in {0, 1};
        a pixel becomes 1 when its intensity exceeds 0.5.
    """
    from tensorflow.examples.tutorials.mnist import input_data
    # Labels are never used below; only the images are needed.
    data = input_data.read_data_sets("MNIST_data/", one_hot=False)

    def binarize(images):
        # NOTE(review): assumes read_data_sets already returns intensities in
        # [0, 1], which is why no /255 is applied before thresholding at 0.5.
        images = images.reshape(len(images), 28, 28, 1)
        return (images > 0.5).astype(np.float32)

    return binarize(data.train.images), binarize(data.test.images)


def _show_test_samples(model, X_test):
    """Interactively compare random test images with the model's output.

    Each round picks a random test image, plots it next to
    ``model.generate_sample([x])``, saves the figure to
    ``PATH/sample_<index>.png`` and asks whether to continue. An answer
    starting with 'n' or 'N' ends the loop.
    """
    n_rows, n_cols = X_test.shape[1], X_test.shape[2]
    while True:
        i = np.random.choice(len(X_test))
        x = X_test[i]
        # NOTE(review): assumes the model exposes generate_sample(batch) and
        # returns one value per input pixel — confirm against NN_architectures.
        im = np.asarray(model.generate_sample([x]))
        print('Original pixel sum: ' + str(np.sum(x)) + '\n' +
              'NN simulated pixel sum: ' + str(np.sum(im)))

        fig, (ax_orig, ax_gen) = plt.subplots(1, 2, figsize=(20, 10))
        ax_orig.imshow(x.reshape(n_rows, n_cols), cmap='gray')
        ax_orig.set_title('Original\n Pixel sum: ' + str(np.sum(x)))
        ax_orig.axis('off')
        ax_gen.imshow(im.reshape(n_rows, n_cols), cmap='gray')
        ax_gen.set_title('NN Generated\n Pixel sum: ' + str(np.sum(im)))
        ax_gen.axis('off')

        # Save before show(): with the inline backend show() releases the
        # figure, so a later savefig can write a blank image.
        fig.savefig(os.path.join(PATH, 'sample_%d.png' % i), dpi=100)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across rounds

        ans = input("Generate another?")
        # A tuple is required: ('n' or 'N') would evaluate to just 'n'.
        if ans and ans[0] in ('n', 'N'):
            break


def mnist():
    """Train or evaluate a DCVAE on binarized MNIST.

    The mode is taken from the module-level ``task``:

    * ``'TRAIN'`` — initialise all variables, resume from
      ``PATH/CNN_model.ckpt`` if that checkpoint exists, fit on the training
      images and save the checkpoint.
    * ``'TEST'`` — restore ``PATH/CNN_model.ckpt`` and interactively display
      random test images next to the model's output.

    Also reads the module-level ``PATH``, ``e_sizes``, ``d_sizes`` and the
    hyper-parameter constants.

    Raises
    ------
    ValueError
        If ``task`` is neither ``'TRAIN'`` nor ``'TEST'``.
    """
    if task not in ('TRAIN', 'TEST'):
        raise ValueError("task must be 'TRAIN' or 'TEST', got %r" % (task,))

    X_train, X_test = _load_binarized_mnist()

    # NOTE(review): for (N, H, W, C) arrays axis 1 is the height; the original
    # assignment order is kept because MNIST is square and DCVAE's expected
    # argument order is not visible here.
    n_W, n_H, n_C = X_train.shape[1:]

    tf.reset_default_graph()
    dcvae = DCVAE(n_W, n_H, n_C, e_sizes, d_sizes,
                  lr=LEARNING_RATE, beta1=BETA1,
                  batch_size=BATCH_SIZE, epochs=EPOCHS,
                  save_sample=SAVE_SAMPLE_PERIOD, path=PATH)

    if task == 'TRAIN':
        init_op = tf.global_variables_initializer()
    else:
        # Trainable variables come from the checkpoint; initialise only the rest.
        vars_to_train = tf.trainable_variables()
        vars_to_init = list(set(tf.global_variables()) - set(vars_to_train))
        init_op = tf.variables_initializer(vars_to_init)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    ckpt_path = os.path.join(PATH, 'CNN_model.ckpt')

    with tf.Session() as sess:
        sess.run(init_op)

        if task == 'TRAIN':
            print('\n Training...')
            # Resume from an existing checkpoint rather than starting from scratch.
            if os.path.exists(ckpt_path + '.index'):
                saver.restore(sess, ckpt_path)
                print('Model restored.')

            dcvae.set_session(sess)
            dcvae.fit(X_train)

            save_path = saver.save(sess, ckpt_path)
            print("Model saved in path: %s" % save_path)
        else:
            print('\n Evaluate model on test set...')
            saver.restore(sess, ckpt_path)
            print('Model restored.')

            # The sampling loop needs the live session, and runs only in TEST mode.
            dcvae.set_session(sess)
            _show_test_samples(dcvae, X_test)

In [5]:
if __name__ == '__main__':
    # `task` and `PATH` come from the configuration cell; change them there
    # rather than overriding them here.
    if not os.path.exists(PATH):
        os.mkdir(PATH)
    else:
        ans = input('Do you want to overwrite the current model saved at '+PATH+'?\n')
        # A tuple is required: ('n' or 'N') would evaluate to just 'n'.
        if ans and ans[0] in ('n', 'N'):
            PATH = input('Specify the name of the model, a new directory will be created.\n')
            # Raises FileExistsError if the chosen directory already exists,
            # so an existing model is never silently overwritten.
            os.mkdir(PATH)
        else:
            print('Overwriting existing model in '+PATH)

    mnist()

Do you want to overwrite the current model saved at DCVAE_test?
y
Overwriting existing model in DCVAE_test
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

 Training...

 ****** 

Training deep convolutional VAE with a total of 55000 samples distributed in 429 batches, each of size 128

The learning rate set is 0.01, and every 5 epoch a generated sample will be saved to DCVAE_test

 ****** 



KeyboardInterrupt: 