In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import pickle
import math

# Weight initializers, both with a standard deviation of 1.
# `normal` is referenced by the layer specifications in `sizes`;
# `trunc_normal` is not used in the cells visible in this notebook.
trunc_normal= tf.truncated_normal_initializer(stddev=1)
normal = tf.random_normal_initializer(stddev=1)

from NN_architectures import *

In [2]:
# Run configuration shared by the cells below.

# Optimiser settings, forwarded to resCNN as `lr` and `beta1`.
LEARNING_RATE = 1e-4
BETA1 = 0.5

# Training schedule, forwarded to resCNN as `batch_size` and `epochs`.
BATCH_SIZE = 128
EPOCHS = 15

# Run mode: 'TRAIN' fits the network and saves a checkpoint,
# 'TEST' restores a previously saved checkpoint instead.
task = 'TRAIN'

# Directory in which the model checkpoint is stored.
PATH = 'resCNN_test'

In [3]:
# Layer specification handed to resCNN (defined in NN_architectures).
#
# NOTE(review): the meaning of the tuple fields is defined by NN_architectures,
# which is not visible from this notebook. They are presumably
#   conv layers:     (n_filters, filter_size, stride, batch_norm, keep_prob, activation, initializer)
#   shortcut layers: (n_filters, filter_size, stride, batch_norm, keep_prob, initializer)
#   maxpool layers:  three integers (pool size / stride / ... — confirm)
#   dense layers:    (n_units, batch_norm, keep_prob, activation, initializer)
# Confirm against that module before relying on this description.
#
# The pooling keys use one spelling ('maxpool_layer_<i>') for every block so
# that a consumer looking layers up by name finds both of them.
sizes = {

        'convblock_layer_0': [(2, 4, 2, False, 0.5, lrelu, normal),
                             (8, 4, 1, True, 0.5, lrelu, normal),
                             (16, 4, 1, False, 0.5, lrelu, normal)],
        'convblock_shortcut_layer_0': [(16, 1, 2, False, 0.5, normal)],

        'maxpool_layer_0': [(4, 2, 1)],

        'convblock_layer_1': [(16, 8, 2, False, 0.5, lrelu, normal),
                             (16, 8, 1, True, 0.5, lrelu, normal),
                             (64, 8, 1, False, 0.5, lrelu, normal)],
        'convblock_shortcut_layer_1': [(64, 8, 2, False, 0.5, normal)],

        'maxpool_layer_1': [(4, 2, 1)],
        'dense_layers': [(1024, True, 0.8, tf.nn.relu, normal)],
        'n_classes': 10
}

In [4]:
def _load_mnist_arrays():
    """Read MNIST and return (X_train, Y_train, X_test, Y_test).

    Images are reshaped to (n_samples, 28, 28, 1) and the one-hot labels to
    (n_samples, 10).
    """
    # Kept as a function-scope import, as in the original cell.
    from tensorflow.examples.tutorials.mnist import input_data
    data = input_data.read_data_sets("MNIST_data/", one_hot=True)

    X_train = data.train.images.reshape(-1, 28, 28, 1)
    Y_train = data.train.labels.reshape(-1, 10)
    X_test = data.test.images.reshape(-1, 28, 28, 1)
    Y_test = data.test.labels.reshape(-1, 10)

    return X_train, Y_train, X_test, Y_test


def mnist():
    """Build a resCNN for MNIST and either train it or restore it from disk.

    Reads the module-level settings `task`, `PATH`, `sizes`, `LEARNING_RATE`,
    `BETA1`, `BATCH_SIZE` and `EPOCHS`.

    * task == 'TRAIN': initialise all variables, resume from an existing
      checkpoint in PATH when one is present, fit on the training set and
      save the checkpoint back to PATH.
    * task == 'TEST': initialise only the non-trainable variables and restore
      the trainable ones from the checkpoint in PATH.

    Returns:
        None.

    Raises:
        ValueError: if `task` is neither 'TRAIN' nor 'TEST'.
    """
    # Fail early with a clear message; otherwise an unsupported task would
    # only surface later as an unbound `init_op`.
    if task not in ('TRAIN', 'TEST'):
        raise ValueError("task must be 'TRAIN' or 'TEST', got %r" % (task,))

    X_train, Y_train, X_test, Y_test = _load_mnist_arrays()

    # Input dimensions handed to the network constructor.
    n_W = X_train.shape[1]
    n_H = X_train.shape[2]
    n_C = X_train.shape[-1]

    tf.reset_default_graph()
    cnn = resCNN(n_W, n_H, n_C, sizes,
                 lr=LEARNING_RATE, beta1=BETA1,
                 batch_size=BATCH_SIZE, epochs=EPOCHS,
                 path=PATH)

    if task == 'TRAIN':
        init_op = tf.global_variables_initializer()
    else:
        # Trainable variables are restored from the checkpoint below, so only
        # the remaining (non-trainable) variables need initialising.
        vars_to_train = tf.trainable_variables()
        vars_all = tf.global_variables()
        vars_to_init = list(set(vars_all) - set(vars_to_train))
        init_op = tf.variables_initializer(vars_to_init)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    checkpoint = os.path.join(PATH, 'CNN_model.ckpt')

    with tf.Session() as sess:

        sess.run(init_op)

        if task == 'TRAIN':
            print('\n Training...')

            # Resume from a previous run when a checkpoint index file exists.
            if os.path.exists(checkpoint + '.index'):
                saver.restore(sess, checkpoint)
                print('Model restored.')

            cnn.set_session(sess)
            cnn.fit(X_train, Y_train)

            save_path = saver.save(sess, checkpoint)
            print("Model saved in path: %s" % save_path)

        else:
            print('\n Evaluate model on test set...')
            saver.restore(sess, checkpoint)
            print('Model restored.')

            cnn.set_session(sess)
            # NOTE(review): X_test / Y_test are loaded but no evaluation call
            # is made here; add one once resCNN's prediction/evaluation API
            # is confirmed against NN_architectures.

    # NOTE(review): the histogram and image-sampling code that used to follow
    # was written for a different dataset (52x64 images with image-shaped
    # targets). It referenced undefined names and indexed a third axis on the
    # 2-D label arrays, so it could never run on MNIST and has been removed.
                
    

In [5]:
def answered_no(ans):
    """Return True when the reply `ans` starts with 'n' or 'N'.

    Surrounding whitespace is ignored; an empty reply counts as "not no".
    """
    # Compare case-insensitively: `ans[0] in ('n' or 'N')` only ever tested
    # against 'n', because `'n' or 'N'` evaluates to 'n'.
    return ans.strip()[:1].lower() == 'n'


if __name__=='__main__':
    # Select the run mode read by mnist().
    task='TRAIN'
    #task='TEST'
    if not os.path.exists(PATH):
        os.mkdir(PATH)

    else:
        # A model directory already exists: ask before reusing it.
        ans = input('Do you want to overwrite the current model saved at '+PATH+'?\n')
        if answered_no(ans):
            PATH = input('Specify the name of the model, a new directory will be created.\n')
            os.mkdir(PATH)
        else:
            print('Overwriting existing model in '+PATH)

    mnist()

Do you want to overwrite the current model saved at resCNN_test?
y
Overwriting existing model in resCNN_test
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Residual Network architecture detected
64 1024


NameError: name 'M_in' is not defined

In [None]:
# Scratch cell: display the spec tuple of the first max-pool layer.
sizes['maxpool_layer_0'][0]