In [1]:
#os stuff
import os
import sys
import h5py as h5
import re

#timing
import time

#numpy
import numpy as np

#tensorflow
import tensorflow as tf
import tensorflow.contrib.keras as tfk

#housekeeping
from scripts.nbfinder import NotebookFinder
sys.meta_path.append(NotebookFinder())
import notebooks.networks.binary_classifier_tf as bc

importing Jupyter notebook from notebooks/networks/binary_classifier_tf.ipynb


# Network Parameters

In [2]:
# Settings applied to the convolutional layers; stored under args['conv_params'].
# The filter weights use Keras' He-normal initializer.
args = dict(
    # input tensor shape; the leading None leaves the batch dimension open
    input_shape=[None, 1, 64, 64],
    # architecture key, used later to pick the threading configuration
    arch='hsw',
    # training-loop reporting / checkpoint cadence
    display_interval=10,
    save_interval=1,
    # optimisation settings
    learning_rate=1.e-5,
    dropout_p=0.5,
    weight_decay=0,  # 0.0001,
    momentum=0.9,
    num_epochs=200,
    # network size
    num_fc_units=512,
    num_layers=3,
    # batch sizes
    train_batch_size=512,  # 480
    validation_batch_size=320,  # 480
    batch_norm=True,
    time=True,
    conv_params={
        'num_filters': 128,
        'filter_size': 3,
        'padding': 'SAME',
        'activation': tf.nn.relu,
        'initializer': tfk.initializers.he_normal(),
    },
)

## Build Network and Functions

In [3]:
print("Building model")

# build_cnn_model returns the dict of variables/input tensors and the network;
# build_functions derives the prediction, loss, accuracy and AUC handles from them.
variables, network = bc.build_cnn_model(args)
pred_fn, loss_fn, accuracy_fn, auc_fn = bc.build_functions(variables, network)

# Register the prediction and loss handles in named graph collections.
for collection_name, collection_op in (('pred_fn', pred_fn), ('loss_fn', loss_fn)):
    tf.add_to_collection(collection_name, collection_op)

Building model


In [4]:
# Show the variables dict and the network object returned by bc.build_cnn_model.
# The single-argument print(...) form produces the same output on Python 2 and
# Python 3 and matches the print(...) calls used in the other cells.
print(variables)
print(network)

{'bn2_g': <tf.Variable 'bn2_g:0' shape=(128, 32, 32) dtype=float32_ref>, 'images_': <tf.Tensor 'Placeholder:0' shape=(?, 1, 64, 64) dtype=float32>, 'bn1_m': <tf.Variable 'bn1_m:0' shape=(128, 64, 64) dtype=float32_ref>, 'bn2_m': <tf.Variable 'bn2_m:0' shape=(128, 32, 32) dtype=float32_ref>, 'bn1_g': <tf.Variable 'bn1_g:0' shape=(128, 64, 64) dtype=float32_ref>, 'bn2_s': <tf.Variable 'bn2_s:0' shape=(128, 32, 32) dtype=float32_ref>, 'bn2_b': <tf.Variable 'bn2_b:0' shape=(128, 32, 32) dtype=float32_ref>, 'bn1_s': <tf.Variable 'bn1_s:0' shape=(128, 64, 64) dtype=float32_ref>, 'keep_prob_': <tf.Tensor 'Placeholder_1:0' shape=<unknown> dtype=float32>, 'fc2_w': <tf.Variable 'fc2_w:0' shape=(512, 2) dtype=float32_ref>, 'fc1_b': <tf.Variable 'fc1_b:0' shape=(512,) dtype=float32_ref>, 'fc2_b': <tf.Variable 'fc2_b:0' shape=(2,) dtype=float32_ref>, 'fc1_w': <tf.Variable 'fc1_w:0' shape=(8192, 512) dtype=float32_ref>, 'labels_': <tf.Tensor 'Placeholder_2:0' shape=(?, 1) dtype=int32>, 'bn1_b': <tf.

## Setup Iterators

### My Files

In [5]:
# Disabled alternative data source: flip the guard to True (and the guard of
# the other iterator cell to False) to train on these files instead.
if False:
    print("Setting up iterators")
    #paths
    inputpath = '/global/cscratch1/sd/tkurth/atlas_dl/data_delphes_final_64x64'
    logpath = '/project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_logs/hep_classifier_log'
    modelpath = '/project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_models/hep_classifier_models'
    #training files
    #os.listdir returns entries in arbitrary order, so the list is sorted to
    #make the [0:20] selection below the same on every run
    trainfiles = sorted(os.path.join(inputpath, x) for x in os.listdir(inputpath)
                        if x.startswith('hep_train') and x.endswith('.hdf5'))
    trainset=bc.DataSet(trainfiles[0:20])
    #validation files (sorted for the same reason)
    validationfiles = sorted(os.path.join(inputpath, x) for x in os.listdir(inputpath)
                             if x.startswith('hep_valid') and x.endswith('.hdf5'))
    validationset = bc.DataSet(validationfiles[0:20])

### Evan's Files

In [6]:
# Active data source: defines inputpath/logpath/modelpath and the training and
# validation iterators used by the training cell.
if True:
    print("Setting up iterators")
    #paths
    inputpath = '/global/cscratch1/sd/wbhimji/delphes_combined_64imageNoPU'
    logpath = '/project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_logs/hep_classifier_log'
    modelpath = '/project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_models/hep_classifier_models'
    #training files
    #os.listdir returns entries in arbitrary order; sorting gives the iterator
    #the same file order on every run
    trainfiles = sorted(os.path.join(inputpath, x) for x in os.listdir(inputpath)
                        if x.startswith('train_') and x.endswith('.h5'))
    trainset=bc.DataSetEvan(trainfiles)
    #validation files (sorted for the same reason)
    validationfiles = sorted(os.path.join(inputpath, x) for x in os.listdir(inputpath)
                             if x.startswith('val_') and x.endswith('.h5'))
    validationset = bc.DataSetEvan(validationfiles)

Setting up iterators


# Train Model

In [7]:
# Architecture key selected in the run configuration.
arch = args['arch']

#common stuff: Intel OpenMP runtime settings shared by all architectures
os.environ["KMP_BLOCKTIME"] = "1"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"] = "granularity=fine,verbose,compact,1,0"

#arch-specific stuff: (inter-op threads, intra-op threads) per architecture
thread_settings = {
    'hsw': (2, 16),
    'knl': (2, 66),
    'k80': (-1, -1),
}
if arch not in thread_settings:
    # the message lists every key accepted above (k80 was previously missing)
    raise ValueError('Please specify a valid architecture with arch (allowed values: hsw, knl, k80.)')
num_inter_threads, num_intra_threads = thread_settings[arch]

#set the rest
# OMP_NUM_THREADS must be a positive integer, so it is only exported when an
# explicit thread count was chosen (the k80 entry uses -1).
if num_intra_threads > 0:
    os.environ['OMP_NUM_THREADS'] = str(num_intra_threads)
sess_config = tf.ConfigProto(inter_op_parallelism_threads=num_inter_threads,
                             intra_op_parallelism_threads=num_intra_threads,
                             allow_soft_placement=True,
                             log_device_placement=True)

# A single pre-formatted string prints identically on Python 2 and Python 3
# (passing several arguments to print(...) prints a tuple on Python 2).
print("Using %d-way task parallelism with %d-way data parallelism." % (num_inter_threads, num_intra_threads))

('Using ', 2, '-way task parallelism with ', 16, '-way data parallelism.')


In [8]:
# Decide whether training starts from scratch (restart=True) or resumes from a
# previously saved model (restart=False).

# On a first run the model directory may not exist yet; os.listdir would raise
# in that case. Create it so it is simply seen as empty here and is available
# as a target directory for checkpoints.
if not os.path.isdir(modelpath):
    os.makedirs(modelpath)

#determining which model to load: collect the checkpoint metadata files
metafilelist = [os.path.join(modelpath, x) for x in os.listdir(modelpath) if x.endswith('.meta')]

#restart from scratch or restore? no saved metadata means nothing to restore
restart = not metafilelist

/project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_models/hep_classifier_models_debug/hep_classifier_tfmodel_epoch_8.ckpt.meta


In [9]:
#initialize session
print("Start training")


def make_feed_dict(images, labels, weights, keep_prob):
    """Bind one batch to the model input tensors stored in `variables`.

    Args:
        images: value fed to variables['images_'].
        labels: value fed to variables['labels_'].
        weights: value fed to variables['weights_'].
        keep_prob: value fed to variables['keep_prob_'].

    Returns:
        A feed_dict mapping those four tensors to the given values.
    """
    return {variables['images_']: images,
            variables['labels_']: labels,
            variables['weights_']: weights,
            variables['keep_prob_']: keep_prob}


with tf.Session(config=sess_config) as sess:

    #train on training loss
    train_step = tf.train.AdamOptimizer(args['learning_rate']).minimize(loss_fn)

    #create summaries: one histogram per entry of the variables dict
    var_summary = []
    for item in variables:
        var_summary.append(tf.summary.histogram(item,variables[item]))
    summary_loss = tf.summary.scalar("loss",loss_fn)
    # NOTE(review): this summary is not merged into any summary that gets run;
    # accuracy_fn is indexed with [0]/[1] below, so accuracy_fn[0] may be what
    # was intended here -- confirm before using it.
    summary_accuracy = tf.summary.scalar("accuracy",accuracy_fn)
    train_summary = tf.summary.merge([summary_loss]+var_summary)
    # NOTE(review): validation_summary is evaluated below but never written
    # to a summary writer.
    validation_summary = tf.summary.merge([summary_loss])
    train_writer = tf.summary.FileWriter(logpath, sess.graph)

    #make sure the event file is flushed and closed even if training fails
    try:
        # Add ops to initialize the variables.
        init_global_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #saver class:
        model_saver = tf.train.Saver()

        #initialize variables
        sess.run([init_global_op, init_local_op])

        #counter stuff
        trainset.reset()
        validationset.reset()

        #restore weights belonging to graph
        epochs_completed = 0
        if not restart:
            last_model = tf.train.latest_checkpoint(modelpath)
            if last_model is None:
                #latest_checkpoint returns None when no checkpoint state is
                #found; restoring from None would fail, so start fresh instead
                print("No checkpoint state found in %s, starting from scratch."%modelpath)
            else:
                print("Restoring model %s."%last_model)
                model_saver.restore(sess,last_model)
                #the epoch number is encoded in the checkpoint file name
                epochs_completed = int(re.match(r'^.*?\_epoch\_(.*?)\.ckpt.*?$',last_model).groups()[0])
                trainset._epochs_completed = epochs_completed

        #losses
        train_loss=0.
        train_batches=0
        total_batches=0
        train_time=0

        #do training
        while epochs_completed < args['num_epochs']:

            #increment total batch counter
            total_batches+=1

            #get next batch
            images,labels,normweights,_ = trainset.next_batch(args['train_batch_size'])
            #set weights to one
            normweights[:] = 1.

            #update weights
            start_time = time.time()
            _, summary, tmp_loss, pred = sess.run([train_step, train_summary, loss_fn, pred_fn],
                                                  feed_dict=make_feed_dict(images, labels, normweights,
                                                                           args['dropout_p']))
            end_time = time.time()
            train_time += end_time-start_time

            #add to summary
            train_writer.add_summary(summary, total_batches)

            #increment train loss and batch number
            train_loss += tmp_loss
            train_batches += 1

            #determine if we give a short update:
            if train_batches%args['display_interval']==0:
                print("REPORT epoch %d.%d, average training loss %g (%.3f sec/batch)"%(epochs_completed, train_batches,
                                                                                    train_loss/float(train_batches),
                                                                                    train_time/float(train_batches)))

            #check if epoch is done
            if trainset._epochs_completed>epochs_completed:
                epochs_completed=trainset._epochs_completed
                print("COMPLETED epoch %d, average training loss %g (%.3f sec/batch)"%(epochs_completed,
                                                                                     train_loss/float(train_batches),
                                                                                     train_time/float(train_batches)))
                train_loss=0.
                train_batches=0
                train_time=0

                #compute validation loss:
                #reset variables (re-running the local initializer before each validation pass)
                validation_loss=0.
                validation_batches=0
                sess.run(init_local_op)

                #iterate over batches
                while True:
                    #get next batch
                    images,labels,normweights,weights = validationset.next_batch(args['validation_batch_size'])
                    #set weights to 1:
                    normweights[:] = 1.
                    weights[:] = 1.

                    #compute loss (keep_prob 1.0 disables dropout)
                    summary, tmp_loss=sess.run([validation_summary,loss_fn],
                                               feed_dict=make_feed_dict(images, labels, normweights, 1.0))

                    #add loss
                    validation_loss += tmp_loss
                    validation_batches += 1

                    #update accuracy and auc in a single run: both use the
                    #same feed, so one call replaces two separate ones
                    sess.run([accuracy_fn[1], auc_fn[1]],
                             feed_dict=make_feed_dict(images, labels, weights, 1.0))

                    #check if full pass done
                    if validationset._epochs_completed>0:
                        validationset.reset()
                        break

                print("COMPLETED epoch %d, average validation loss %g"%(epochs_completed, validation_loss/float(validation_batches)))
                validation_accuracy = sess.run(accuracy_fn[0])
                print("COMPLETED epoch %d, average validation accu %g"%(epochs_completed, validation_accuracy))
                validation_auc = sess.run(auc_fn[0])
                print("COMPLETED epoch %d, average validation auc %g"%(epochs_completed, validation_auc))

                # Save the variables to disk.
                if epochs_completed%args['save_interval']==0:
                    model_save_path = model_saver.save(sess, modelpath+'/hep_classifier_tfmodel_epoch_'+str(epochs_completed)+'.ckpt')
                    print('Model saved in file: %s'%model_save_path)
    finally:
        train_writer.close()

Start training
Restoring model.
INFO:tensorflow:Restoring parameters from /project/projectdirs/mpccc/tkurth/MANTISSA-HEP/atlas_dl/temp/tensorflow_models/hep_classifier_models/hep_classifier_tfmodel_epoch_8.ckpt


InvalidArgumentError: Default MaxPoolingOp only supports NHWC.
	 [[Node: maxpool1 = MaxPool[T=DT_FLOAT, data_format="NCHW", ksize=[1, 1, 2, 2], padding="SAME", strides=[1, 1, 2, 2], _device="/job:localhost/replica:0/task:0/cpu:0"](Relu)]]

Caused by op u'maxpool1', defined at:
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/runpy.py", line 174, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/global/u2/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/IPython/kernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 474, in start
    ioloop.IOLoop.instance().start()
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tornado/ioloop.py", line 887, in start
    handler_func(fd_obj, events)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 390, in execute_request
    user_expressions, allow_stdin)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/ipykernel/zmqshell.py", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-2be599c03cba>", line 2, in <module>
    variables, network = bc.build_cnn_model(args)
  File "<string>", line 80, in build_cnn_model
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tensorflow/python/ops/nn_ops.py", line 1821, in max_pool
    name=name)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 1638, in _max_pool
    data_format=data_format, name=name)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/global/homes/t/tkurth/.conda/envs/thorstendl/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Default MaxPoolingOp only supports NHWC.
	 [[Node: maxpool1 = MaxPool[T=DT_FLOAT, data_format="NCHW", ksize=[1, 1, 2, 2], padding="SAME", strides=[1, 1, 2, 2], _device="/job:localhost/replica:0/task:0/cpu:0"](Relu)]]
