In [1]:
__author__ = 'tkurth'
import sys
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn import preprocessing
from sklearn.cross_validation import train_test_split
from nbfinder import NotebookFinder
sys.meta_path.append(NotebookFinder())
%matplotlib inline
import time

## ROOT / root_numpy setup

In [2]:
sys.path.append('/global/homes/w/wbhimji/cori-envs/nersc-rootpy/lib/python2.7/site-packages/')
import ROOT
import rootpy
import root_numpy as rnp

Welcome to ROOTaaS 6.06/04


## Useful functions

In [3]:
# Define a context manager to suppress stdout and stderr.
class suppress_stdout_stderr(object):
    '''
    A context manager for doing a "deep suppression" of stdout and stderr in
    Python: everything written to file descriptors 1 and 2 is discarded, even
    output that originates in a compiled C/C++/Fortran sub-function and
    therefore bypasses ``sys.stdout``/``sys.stderr``.

    The /dev/null descriptors and the saved copies of the real stdout/stderr
    are opened in ``__enter__`` and all four are closed in ``__exit__``. This
    means no file descriptors leak and the same instance can be re-entered.

    Exceptions are not suppressed: ``__exit__`` restores the original
    descriptors and returns None, so any exception propagates normally.
    '''
    def __init__(self):
        # Descriptors are allocated per use in __enter__ (not here) so that an
        # instance can be entered more than once.
        self.null_fds = []
        self.save_fds = ()

    @staticmethod
    def _flush_python_streams():
        # Python keeps its own buffers on top of fds 1/2; flush them before the
        # descriptors are swapped so buffered text goes to the intended target.
        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    def __enter__(self):
        self._flush_python_streams()
        # Open a pair of null files.
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Save the actual stdout (1) and stderr (2) file descriptors.
        self.save_fds = (os.dup(1), os.dup(2))
        # Point stdout and stderr at the null files.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        try:
            # Anything still buffered was produced while suppressed: send it to
            # /dev/null now rather than leaking it after the restore.
            self._flush_python_streams()
        finally:
            try:
                # Re-assign the real stdout/stderr back to (1) and (2).
                os.dup2(self.save_fds[0], 1)
                os.dup2(self.save_fds[1], 2)
            finally:
                # Close the null files AND the saved duplicates; the original
                # closed only the null files and leaked two fds per use.
                for fd in list(self.null_fds) + list(self.save_fds):
                    os.close(fd)
                self.null_fds = []
                self.save_fds = ()

## Data iterator

In [38]:
class hep_data_iterator(object):
    """Serve batches of energy-weighted (phi, eta) calorimeter images.

    Background and signal events are read with ``rnp.root2array`` from the
    ROOT files listed in two text config files (one path per line). Each
    event's topo-clusters are binned into a 2D phi/eta histogram weighted by
    cluster energy. The images are then served by ``get_batch`` with labels
    0 (background) and 1 (signal).

    Only two sources (background, signal) are read; ``num_classes`` is used
    for the per-class bookkeeping of the cache and total sizes.
    """

    # Column names used to build the per-event histograms.
    _PHI_COL = 'CaloCalTopoClustersAuxDyn.calPhi'
    _ETA_COL = 'CaloCalTopoClustersAuxDyn.calEta'
    _ENERGY_COL = 'CaloCalTopoClustersAuxDyn.calE'

    #class constructor
    def __init__(self,
                 bg_cfg_file = '../config/BgFileListAug16.txt',
                 sig_cfg_file='../config/SignalFileListAug16.txt',
                 group_name='CollectionTree',
                 branches=['CaloCalTopoClustersAuxDyn.calPhi', 'CaloCalTopoClustersAuxDyn.calEta','CaloCalTopoClustersAuxDyn.calE'],
                 num_events_total=50000,
                 num_events_cached=512,
                 num_classes=2,
                 shuffle=True,
                 preprocess=True,
                 bin_size=0.025,
                 eta_range = [-5,5],
                 phi_range = [-3.14, 3.14],
                 dataset_name='histo'):
        """Read the file lists, size the dataset/cache and prefetch the first cache.

        Args:
            bg_cfg_file, sig_cfg_file: text files with one ROOT file path per line.
            group_name: name of the tree read from each ROOT file.
            branches: branches to read; must include the phi/eta/energy columns.
            num_events_total: events served per epoch (falsy -> use all files).
            num_events_cached: events held in memory at once.
            num_classes: number of classes (events are split evenly among them).
            shuffle: shuffle file order and event order within each cache.
            preprocess, dataset_name: accepted and stored, currently unused.
            bin_size, eta_range, phi_range: histogram binning.
        """
        #general stuff; copy list arguments so the shared mutable defaults are never aliased
        self.group_name = group_name
        self.branches = list(branches)

        #file lists (blank lines are ignored)
        self.bg_files = self._read_file_list(bg_cfg_file)
        self.sig_files = self._read_file_list(sig_cfg_file)

        #assumed number of events per signal file (hardcoded)
        self.events_per_sig_file = 10000

        self.num_events_total = num_events_total
        self.num_events_cached = num_events_cached
        self.num_classes = num_classes

        #accepted for interface compatibility; not used by this class yet
        self.preprocess = preprocess
        self.dataset_name = dataset_name

        #if no total number of events is specified, take everything
        if not self.num_events_total:
            self.num_events_total = self.events_per_sig_file * np.min([len(self.sig_files), len(self.bg_files)])

        #balanced classes: round the total down to a multiple of num_classes
        self.num_events_total -= self.num_events_total % self.num_classes
        #floor division keeps this an integer under both Python 2 and 3
        self.num_each = self.num_events_total // self.num_classes

        #we assume there are more bg events per file than sig, so the number of
        #files is bounded by the number needed for the signal events
        self.num_files = int(np.ceil(self.num_each / float(self.events_per_sig_file)))
        #hack because rootpy does not do well with one file
        if self.num_files == 1:
            self.num_files = 2

        #shuffle file lists in cases where not all files are used
        self.shuffle = shuffle
        if self.shuffle:
            np.random.shuffle(self.bg_files)
            np.random.shuffle(self.sig_files)

        #restrict the number of files
        self.bg_files = self.bg_files[:self.num_files]
        self.sig_files = self.sig_files[:self.num_files]

        #set file-counter to zero
        self.filecount = 0

        #histogram binning
        self.phi_range = list(phi_range)
        self.eta_range = list(eta_range)
        self.bin_size = bin_size
        self.phi_bins = int(np.floor((self.phi_range[1] - self.phi_range[0]) / self.bin_size))
        self.eta_bins = int(np.floor((self.eta_range[1] - self.eta_range[0]) / self.bin_size))

        #compute cache size and prefetch the first batch
        self.compute_cache()
        self.prefetch()

    @staticmethod
    def _read_file_list(cfg_file):
        """Return the right-stripped, non-blank lines of ``cfg_file`` (file is closed)."""
        with open(cfg_file) as handle:
            return [line.rstrip() for line in handle if line.strip()]

    #compute cache pars:
    def compute_cache(self):
        """Derive per-class and per-file cache sizes from ``num_events_cached``."""
        #make sure that it is an integer multiple of the number of classes
        self.num_events_cached -= self.num_events_cached % self.num_classes

        #how many per class (integer under Python 2 and 3)
        self.num_each_cached = self.num_events_cached // self.num_classes

        #how many files
        self.num_files_cached = int(np.ceil(self.num_each_cached / float(self.events_per_sig_file)))

        #hack because rootpy does not do well with one file
        if self.num_files_cached == 1:
            self.num_files_cached = 2

    def _read_events(self, files):
        """Read at most ``num_each_cached`` entries of ``self.branches`` from ``files``."""
        return rnp.root2array(files,
                              treename=self.group_name,
                              branches=self.branches,
                              start=0,
                              stop=self.num_each_cached,
                              warn_missing_tree=True)

    def _fill_histograms(self, df, out):
        """Write one energy-weighted phi/eta histogram per event into ``out``.

        ``out`` has shape (n, 1, phi_bins, eta_bins); row i of ``df`` holds the
        per-cluster phi, eta and energy arrays of event i.
        """
        phi_values = df[self._PHI_COL]
        eta_values = df[self._ETA_COL]
        energy_values = df[self._ENERGY_COL]
        for i in range(len(out)):
            out[i, 0] = np.histogram2d(phi_values.iloc[i], eta_values.iloc[i],
                                       bins=(self.phi_bins, self.eta_bins),
                                       weights=energy_values.iloc[i],
                                       range=[self.phi_range, self.eta_range])[0]

    #fetch the next bunch of data and preprocess
    def prefetch(self):
        """Load the next group of files and rebuild ``self.x`` / ``self.y``.

        Raises:
            ValueError: if the files yield fewer events than ``num_each_cached``.
        """
        #wrap around (and reshuffle) if too few unused files remain
        if self.filecount + self.num_files_cached > self.num_files:
            self.filecount = 0
            if self.shuffle:
                np.random.shuffle(self.bg_files)
                np.random.shuffle(self.sig_files)

        file_slice = slice(self.filecount, self.filecount + self.num_files_cached)

        #so we don't have annoying stderr messages from ROOT
        with suppress_stdout_stderr():
            bgarray = self._read_events(self.bg_files[file_slice])
            sigarray = self._read_events(self.sig_files[file_slice])
            self.filecount += self.num_files_cached

        bgdf = pd.DataFrame.from_records(bgarray)
        sigdf = pd.DataFrame.from_records(sigarray)

        n_each = self.num_each_cached
        for label, frame in (('background', bgdf), ('signal', sigdf)):
            if len(frame) < n_each:
                raise ValueError('only {0} {1} events were read, {2} are needed per cache'
                                 .format(len(frame), label, n_each))

        #background in the first half, signal in the second, filled in place
        #(avoids the extra copy a vstack of two separate arrays would need)
        x = np.zeros((2 * n_each, 1, self.phi_bins, self.eta_bins))
        self._fill_histograms(bgdf, x[:n_each])
        self._fill_histograms(sigdf, x[n_each:])

        # 1 means signal, 0 means background; sized to match x
        y = np.zeros(len(x), dtype='int32')
        y[n_each:] = 1

        #mix classes within the cache; otherwise every batch is single-class
        if self.shuffle:
            perm = np.random.permutation(len(y))
            x, y = x[perm], y[perm]

        self.x, self.y = x, y

    #this is the batch iterator:
    def get_batch(self, batchsize):
        """Yield ``(x, y)`` batches of ``batchsize`` events, ``num_events_total`` events per call.

        A new cache is prefetched whenever the current one cannot supply a full
        batch; the unused tail of a cache (if ``batchsize`` does not divide the
        cache size) is skipped.
        """
        #grow the cache (and actually refill it) if the batch does not fit
        if batchsize > self.num_events_cached:
            self.num_events_cached = batchsize * self.num_classes
            self.compute_cache()
            self.prefetch()

        position = 0
        for _ in range(0, self.num_events_total, batchsize):
            #refill when the remaining cached events cannot make a full batch
            if position + batchsize > len(self.y):
                self.prefetch()
                position = 0
            yield self.x[position:position + batchsize], self.y[position:position + batchsize]
            position += batchsize

## Construct network

In [39]:
# Build the iterator with its default configuration. The constructor reads the
# background/signal file lists and immediately prefetches the first cache.
hditer = hep_data_iterator()