## General modules

In [None]:
__author__ = 'tkurth'
import sys
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn import preprocessing
from sklearn.cross_validation import train_test_split
from nbfinder import NotebookFinder
sys.meta_path.append(NotebookFinder())
%matplotlib inline
import time

## Theano modules

In [None]:
import theano
import theano.tensor as T
import lasagne as ls

## ROOT stuff

In [None]:
sys.path.append('/global/homes/w/wbhimji/cori-envs/nersc-rootpy/lib/python2.7/site-packages/')
import ROOT
import rootpy
import root_numpy as rnp

## Useful functions

In [None]:
# Define a context manager to suppress stdout and stderr.
class suppress_stdout_stderr(object):
    '''
    A context manager for doing a "deep suppression" of stdout and stderr in
    Python, i.e. it silences everything written to file descriptors 1 and 2,
    even if the output originates in a compiled C/Fortran sub-function.

    Exceptions raised inside the block are not suppressed: ``__exit__``
    returns None, so they propagate after the real descriptors are restored.

    All descriptors are acquired in ``__enter__`` and released in
    ``__exit__``, so an instance leaks no file descriptors and can be used
    for several consecutive ``with`` blocks.
    '''
    def __init__(self):
        # Descriptors are allocated lazily in __enter__ (see class docstring).
        self.null_fds = []
        self.save_fds = ()

    @staticmethod
    def _flush_python_streams():
        '''Flush sys.stdout / sys.stderr so buffered text is not misrouted.'''
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except (AttributeError, ValueError):
                # Deliberate best-effort: a missing (None) or closed stream
                # has nothing to flush.
                pass

    def __enter__(self):
        # Push out text printed before the block so it is not swallowed.
        self._flush_python_streams()
        # Open a pair of null files
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Save the actual stdout (1) and stderr (2) file descriptors.
        self.save_fds = (os.dup(1), os.dup(2))
        # Assign the null pointers to stdout and stderr.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        # Text buffered inside the block goes to /dev/null, not out afterwards.
        self._flush_python_streams()
        # Re-assign the real stdout/stderr back to (1) and (2)
        os.dup2(self.save_fds[0], 1)
        os.dup2(self.save_fds[1], 2)
        # Close the null files AND the saved duplicates; the duplicates are
        # no longer needed once dup2 has restored fds 1 and 2.
        for fd in list(self.null_fds) + list(self.save_fds):
            os.close(fd)

## Data iterator

In [None]:
class hep_data_iterator:
    '''Iterate over calorimeter "images" for signal-vs-background training.

    Background and signal ROOT files are listed (one path per line) in two
    config files. Events are read with root_numpy. Each event's clusters are
    binned into a 2D (phi, eta) histogram weighted by cluster energy. The
    resulting images are kept in an in-memory cache of ``num_events_cached``
    events, which ``prefetch`` refills.

    Each cache holds ``num_each_cached`` background events followed by the
    same number of signal events. They are labelled one-hot (column 0 =
    background, column 1 = signal), shuffled together and divided by
    ``max_abs``.

    NOTE(review): stacking exactly two event groups (background, signal)
    assumes ``num_classes == 2``; other values are not supported.
    '''

    # Branches read by the histogramming code; they must appear in ``branches``.
    _PHI_BRANCH = 'CaloCalTopoClustersAuxDyn.calPhi'
    _ETA_BRANCH = 'CaloCalTopoClustersAuxDyn.calEta'
    _E_BRANCH = 'CaloCalTopoClustersAuxDyn.calE'

    #class constructor
    def __init__(self,
                 bg_cfg_file = '../config/BgFileListAug16.txt',
                 sig_cfg_file='../config/SignalFileListAug16.txt',
                 group_name='CollectionTree',
                 branches=['CaloCalTopoClustersAuxDyn.calPhi', 'CaloCalTopoClustersAuxDyn.calEta','CaloCalTopoClustersAuxDyn.calE'],
                 num_events_total=8192,
                 num_events_cached=512,
                 num_classes=2,
                 shuffle=True,
                 preprocess=True,
                 bin_size=0.025,
                 eta_range = [-5,5],
                 phi_range = [-3.14, 3.14],
                 dataset_name='histo'):
        '''Read the file lists, size the caches, compute the rescaling and fill the first cache.

        Args:
            bg_cfg_file, sig_cfg_file: text files listing background / signal
                ROOT files, one path per line (blank lines are ignored).
            group_name: name of the tree read from each file.
            branches: branches to read; must contain the phi, eta and energy
                branch names given by the class constants.
            num_events_total: events delivered per pass of ``get_batch``.
                A falsy value means ``events_per_sig_file`` times the smaller
                number of files. Rounded down to a multiple of ``num_classes``.
            num_events_cached: events held in memory at once; rounded down
                to a multiple of ``num_classes``.
            num_classes: number of label columns (only 2 is supported).
            shuffle: shuffle the file lists initially and on every wrap-around.
            preprocess, dataset_name: accepted for compatibility; unused.
            bin_size: histogram bin width in both phi and eta.
            eta_range, phi_range: histogram extents as [low, high].
        '''
        #general stuff
        self.group_name = group_name
        # copy so the shared default list is never aliased by an instance
        self.branches = list(branches)

        #read the file lists (files are closed again right away)
        self.bg_files = self._read_file_list(bg_cfg_file)
        self.sig_files = self._read_file_list(sig_cfg_file)

        #that is somehow hardcoded
        self.events_per_sig_file = 10000

        #number of total and cached events, number of classes
        self.num_events_total = num_events_total
        self.num_events_cached = num_events_cached
        self.num_classes = num_classes

        #if no total number of events is specified, take everything
        if not self.num_events_total:
            self.num_events_total = self.events_per_sig_file * min(len(self.sig_files), len(self.bg_files))

        #adjust to class frequencies: equal number of events per class
        self.num_events_total -= self.num_events_total % self.num_classes
        # floor division keeps this an int on both Python 2 and Python 3
        self.num_each = self.num_events_total // self.num_classes

        #we assume there are more bg per file than sig, so the number of files
        #is bounded by the files needed for the signal events
        self.num_files = int(np.ceil(self.num_each / float(self.events_per_sig_file)))
        #hack because rootpy does not do well with one file
        if self.num_files == 1:
            self.num_files = 2

        #shuffle file lists in cases where not all files are used:
        self.shuffle = shuffle
        if self.shuffle:
            np.random.shuffle(self.bg_files)
            np.random.shuffle(self.sig_files)

        #restrict the number of files:
        self.bg_files = self.bg_files[:self.num_files]
        self.sig_files = self.sig_files[:self.num_files]

        #index of the next file group to be read by prefetch
        self.filecount = 0

        #histogram geometry
        self.phi_range = list(phi_range)
        self.eta_range = list(eta_range)
        self.bin_size = bin_size
        self.phi_bins = int(np.floor((self.phi_range[1] - self.phi_range[0]) / self.bin_size))
        self.eta_bins = int(np.floor((self.eta_range[1] - self.eta_range[0]) / self.bin_size))

        self.compute_cache()
        self.compute_data_max()

        #prefetch the first cache; get_batch consumes it from position 0
        self.prefetch()
        self.cache_pos = 0

        #store the per-event shapes:
        self.xshape = self.x[0].shape
        self.yshape = self.y[0].shape

    @staticmethod
    def _read_file_list(path):
        '''Return the right-stripped, non-empty lines of the text file ``path``.'''
        with open(path) as handle:
            return [line.rstrip() for line in handle if line.strip()]

    #compute cache pars:
    def compute_cache(self):
        '''Derive per-class and per-file cache sizes from ``num_events_cached``.'''
        #make sure that it is an integer multiple of the number of classes
        self.num_events_cached -= self.num_events_cached % self.num_classes

        #how many per class (floor division: int on Python 2 and 3)
        self.num_each_cached = self.num_events_cached // self.num_classes

        #how many files:
        self.num_files_cached = int(np.ceil(self.num_each_cached / float(self.events_per_sig_file)))

        #hack because rootpy does not do well with one file
        if self.num_files_cached == 1:
            self.num_files_cached = 2

    def _read_arrays(self, files):
        '''Read ``self.branches`` from ``files`` via root2array with start=0, stop=num_each_cached.

        NOTE(review): start/stop are given once for the whole list of files;
        presumably root_numpy applies them to the chained tree, so only the
        leading entries of the group are read. Confirm.
        '''
        return rnp.root2array(files,
                              treename=self.group_name,
                              branches=self.branches,
                              start=0,
                              stop=self.num_each_cached,
                              warn_missing_tree=True)

    def _read_bg_sig(self, lo, hi):
        '''Read files ``lo:hi`` of both lists; return (background, signal) DataFrames.'''
        #so we don't have annoying stderr messages
        with suppress_stdout_stderr():
            bgarray = self._read_arrays(self.bg_files[lo:hi])
            sigarray = self._read_arrays(self.sig_files[lo:hi])
        return pd.DataFrame.from_records(bgarray), pd.DataFrame.from_records(sigarray)

    @staticmethod
    def _abs_max(values):
        '''Largest absolute value in ``values``; 0 for an event without clusters.'''
        values = np.abs(np.asarray(values))
        return values.max() if values.size else 0.

    #compute max over all data
    def compute_data_max(self):
        '''Set ``self.max_abs`` to the largest |cluster energy| found in the file groups.

        Files are processed in groups of ``num_files_cached`` to bound memory.
        Images are later divided by this value. Single-cluster energies thus
        fall in [-1, 1], but a bin summing several clusters can exceed that.
        '''
        self.max_abs = 0.
        for lo in range(0, self.num_files, self.num_files_cached):
            hi = min(self.num_files, lo + self.num_files_cached)
            bgdf, sigdf = self._read_bg_sig(lo, hi)
            energies = pd.concat([bgdf[self._E_BRANCH], sigdf[self._E_BRANCH]])
            group_max = energies.apply(self._abs_max).max()
            self.max_abs = max(group_max, self.max_abs)

    def _histogram_events(self, df):
        '''Bin the first ``num_each_cached`` events of ``df`` into (n, 1, phi_bins, eta_bins) images.

        Raises:
            ValueError: if ``df`` holds fewer than ``num_each_cached`` events.
        '''
        n = self.num_each_cached
        if len(df) < n:
            raise ValueError('expected at least %d events, got %d' % (n, len(df)))
        phis, etas, energies = df[self._PHI_BRANCH], df[self._ETA_BRANCH], df[self._E_BRANCH]
        images = np.zeros((n, 1, self.phi_bins, self.eta_bins))
        for i in range(n):
            images[i, 0] = np.histogram2d(phis[i], etas[i],
                                          bins=(self.phi_bins, self.eta_bins),
                                          weights=energies[i],
                                          range=[self.phi_range, self.eta_range])[0]
        return images

    #fetch the next bunch of data and preprocess
    def prefetch(self):
        '''Refill ``self.x`` / ``self.y`` from the next ``num_files_cached`` files.

        Wraps around to the first file once too few unread files remain. On
        wrap-around the file lists are reshuffled if ``shuffle`` is set.
        '''
        if self.filecount + self.num_files_cached > self.num_files:
            self.filecount = 0
            if self.shuffle:
                np.random.shuffle(self.bg_files)
                np.random.shuffle(self.sig_files)

        lo = self.filecount
        bgdf, sigdf = self._read_bg_sig(lo, lo + self.num_files_cached)
        self.filecount += self.num_files_cached

        #background first, then signal
        self.x = np.vstack((self._histogram_events(bgdf), self._histogram_events(sigdf)))

        #one-hot labels: rows of block c belong to class c (0 = background, 1 = signal)
        n = self.num_each_cached
        self.y = np.zeros((self.num_events_cached, self.num_classes), dtype='int32')
        for c in range(self.num_classes):
            self.y[c * n:(c + 1) * n, c] = 1

        #shuffle images and labels with the same permutation
        pivot = np.arange(self.num_events_cached)
        np.random.shuffle(pivot)
        self.x = self.x[pivot]
        self.y = self.y[pivot]

        #apply rescaling (skipped when no energy was found, to avoid 0-division)
        if self.max_abs > 0:
            self.x /= self.max_abs

    #this is the batch iterator:
    def get_batch(self, batchsize):
        '''Yield ``ceil(num_events_total / batchsize)`` (x, y) batches of ``batchsize`` events.

        The cache is consumed sequentially through ``self.cache_pos``.
        ``prefetch`` is called whenever fewer than ``batchsize`` unused
        events remain; those leftover events are skipped. The cursor
        persists across calls, so consecutive epochs continue through the
        data instead of replaying one cache.
        '''
        #recompute cache if batchsize is too large, and refill it at the new size
        if batchsize > self.num_events_cached:
            self.num_events_cached = batchsize * self.num_classes
            self.compute_cache()
            self.prefetch()
            self.cache_pos = 0

        for _ in range(0, self.num_events_total, batchsize):
            #do we need to prefetch?
            if self.cache_pos + batchsize > self.num_events_cached:
                self.prefetch()
                self.cache_pos = 0
            lo = self.cache_pos
            hi = lo + batchsize
            # advance before yielding so an abandoned generator leaves a consistent cursor
            self.cache_pos = hi
            yield self.x[lo:hi, :, :, :], self.y[lo:hi, :]

## Construct data iterator

In [None]:
# Instantiate the data iterator with its default configuration; construction
# reads the file lists, computes the rescaling maximum and fills the first cache.
hditer = hep_data_iterator()

## Construct network

In [None]:
#some parameters
keep_prob=0.5
num_filters=128
num_units_dense=1024
initial_learning_rate=0.001

# Lasagne's DropoutLayer takes the probability of *dropping* a unit (p), not
# of keeping it, so convert (identical at 0.5, correct for other values).
drop_prob = 1. - keep_prob

# Glorot gain matching LeakyRectify's default leakiness of 0.01.
leaky_gain = np.sqrt(2./(1+0.01**2))


def _conv_pool_drop(incoming):
    '''Build one conv(3x3, leaky ReLU) -> maxpool(2x2) -> dropout stage.

    Returns (conv, pool, dropout) so each layer keeps its own notebook name.
    '''
    conv = ls.layers.Conv2DLayer(incoming=incoming,
                                 num_filters=num_filters,
                                 filter_size=3,
                                 stride=(1,1),
                                 pad=0,
                                 W=ls.init.HeUniform(),
                                 b=ls.init.Constant(0.),
                                 nonlinearity=ls.nonlinearities.LeakyRectify())
    pool = ls.layers.MaxPool2DLayer(incoming=conv,
                                    pool_size=(2,2),
                                    stride=2,
                                    pad=0)
    drop = ls.layers.DropoutLayer(incoming=pool,
                                  p=drop_prob,
                                  rescale=True)
    return conv, pool, drop


def _dense(incoming, num_units, nonlinearity):
    '''Fully connected layer with the notebook's shared initialisation.'''
    return ls.layers.DenseLayer(incoming=incoming,
                                num_units=num_units,
                                W=ls.init.GlorotUniform(leaky_gain),
                                b=ls.init.Constant(0.0),
                                nonlinearity=nonlinearity)


#input layers (the label layer is not wired into the graph)
l_inp_data = ls.layers.InputLayer((None,hditer.xshape[0],hditer.xshape[1],hditer.xshape[2]))
l_inp_label = ls.layers.InputLayer((None,hditer.yshape[0]))

#four identical conv/pool/dropout stages
l_conv1, l_pool1, l_drop1 = _conv_pool_drop(l_inp_data)
l_conv2, l_pool2, l_drop2 = _conv_pool_drop(l_drop1)
l_conv3, l_pool3, l_drop3 = _conv_pool_drop(l_drop2)
l_conv4, l_pool4, l_drop4 = _conv_pool_drop(l_drop3)

#flatten
l_flat = ls.layers.FlattenLayer(incoming=l_drop4, outdim=2)

#fully connected layers
l_fc1 = _dense(l_flat, num_units_dense, ls.nonlinearities.LeakyRectify())
l_fc2 = _dense(l_fc1, num_units_dense, ls.nonlinearities.LeakyRectify())

#output layer
# NOTE(review): the softmax output keeps the leaky-ReLU gain as before;
# a gain of 1.0 is the more usual choice for a softmax layer.
l_out = _dense(l_fc2, hditer.num_classes, ls.nonlinearities.softmax)

#network
network = [l_inp_data, l_inp_label,
           l_conv1, l_pool1, l_drop1,
           l_conv2, l_pool2, l_drop2,
           l_conv3, l_pool3, l_drop3,
           l_conv4, l_pool4, l_drop4,
           l_flat, l_fc1, l_fc2,
           l_out
          ]

#variables
inp = l_inp_data.input_var
lab = T.imatrix('lab')

#output (stochastic with dropout, and deterministic)
lab_pred = ls.layers.get_output(l_out, {l_inp_data: inp})
lab_pred_det = ls.layers.get_output(l_out, {l_inp_data: inp}, deterministic=True)

#loss functions:
loss = ls.objectives.categorical_crossentropy(lab_pred,lab).mean()
loss_det = ls.objectives.categorical_crossentropy(lab_pred_det,lab).mean()

#accuracy: labels are one-hot rows, so compare predicted class with the
#argmax of the label matrix (not with the matrix itself)
acc_det = T.mean(T.eq(T.argmax(lab_pred_det, axis=1), T.argmax(lab, axis=1)),
                 dtype=theano.config.floatX)

#parameters
params = ls.layers.get_all_params(network, trainable=True)

#updates
updates = ls.updates.adam(loss, params, learning_rate=initial_learning_rate)

#compile network function
# NOTE(review): fnn uses the dropout (non-deterministic) output; for
# inference lab_pred_det is probably intended.
fnn = theano.function([inp], lab_pred)
#training function to minimize
fnn_train = theano.function([inp,lab], loss, updates=updates)
#validation function with loss and accuracy
fnn_validate = theano.function([inp,lab], [loss_det,acc_det])

## Training

In [None]:
num_epochs=10
batchsize=128

for epoch in range(num_epochs):
    # In each epoch, we do a full pass over the training data:
    train_err = 0.
    train_batches = 0
    start_time = time.time()
    for inputs, targets in hditer.get_batch(batchsize):
        batch_err = fnn_train(inputs, targets)
        train_err += batch_err
        train_batches += 1
        # report this batch's loss (not the running sum); function-call form
        # of print works on both Python 2 and Python 3
        print("  batch {}: loss {:.6f}".format(train_batches, float(batch_err)))

    # TODO: add a validation pass; fnn_validate(inputs, targets) already
    # returns [loss, accuracy], but no held-out data iterator exists yet.

    # Then we print the results for this epoch:
    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, num_epochs, time.time() - start_time))
    # max(..., 1) avoids a ZeroDivisionError if the iterator yielded nothing
    print("  training loss:\t\t{:.6f}".format(float(train_err) / max(train_batches, 1)))